From c6342763ab041fd9e3b6176895b131cf15f0f2e7 Mon Sep 17 00:00:00 2001 From: Yunnglin Date: Tue, 10 Mar 2026 16:29:42 +0800 Subject: [PATCH 01/24] update refact --- src/twinkle/server/__main__.py | 26 +- src/twinkle/server/common/__init__.py | 18 + .../server/{tinker => }/common/datum.py | 2 + src/twinkle/server/common/io_utils.py | 332 +++++++++ .../server/{tinker => }/common/router.py | 2 + .../server/{twinkle => }/common/serialize.py | 1 + src/twinkle/server/gateway/__init__.py | 3 + .../server/{tinker => gateway}/proxy.py | 73 +- src/twinkle/server/gateway/server.py | 126 ++++ src/twinkle/server/gateway/tinker_router.py | 289 ++++++++ src/twinkle/server/gateway/twinkle_router.py | 106 +++ src/twinkle/server/launcher.py | 123 +--- src/twinkle/server/model/__init__.py | 3 + src/twinkle/server/model/app.py | 159 +++++ .../common => model/backends}/__init__.py | 0 .../backends}/megatron_model.py | 91 +-- .../model/backends/transformers_model.py | 267 +++++++ src/twinkle/server/model/tinker_handlers.py | 298 ++++++++ src/twinkle/server/model/twinkle_handlers.py | 377 ++++++++++ src/twinkle/server/processor/__init__.py | 3 + .../processor.py => processor/app.py} | 93 +-- src/twinkle/server/sampler/__init__.py | 3 + src/twinkle/server/sampler/app.py | 159 +++++ src/twinkle/server/sampler/tinker_handlers.py | 120 ++++ .../server/sampler/twinkle_handlers.py | 139 ++++ src/twinkle/server/tinker/__init__.py | 18 - src/twinkle/server/tinker/common/__init__.py | 3 - .../server/tinker/common/compat_base.py | 151 ---- src/twinkle/server/tinker/common/io_utils.py | 181 ----- .../tinker/common/transformers_model.py | 148 ---- src/twinkle/server/tinker/model.py | 659 ------------------ src/twinkle/server/tinker/sampler.py | 251 ------- src/twinkle/server/tinker/server.py | 613 ---------------- src/twinkle/server/twinkle/__init__.py | 20 - src/twinkle/server/twinkle/common/io_utils.py | 235 ------- .../twinkle/common/transformers_model.py | 41 -- 
src/twinkle/server/twinkle/model.py | 584 ---------------- src/twinkle/server/twinkle/sampler.py | 308 -------- src/twinkle/server/twinkle/server.py | 270 ------- src/twinkle/server/utils/task_queue.py | 52 ++ src/twinkle_client/http/http_utils.py | 2 +- src/twinkle_client/manager.py | 4 +- src/twinkle_client/types/__init__.py | 1 + src/twinkle_client/types/model.py | 132 ++++ src/twinkle_client/types/processor.py | 30 + src/twinkle_client/types/sampler.py | 68 ++ src/twinkle_client/types/server.py | 16 + src/twinkle_client/types/training.py | 91 +++ 48 files changed, 2913 insertions(+), 3778 deletions(-) create mode 100644 src/twinkle/server/common/__init__.py rename src/twinkle/server/{tinker => }/common/datum.py (97%) create mode 100644 src/twinkle/server/common/io_utils.py rename src/twinkle/server/{tinker => }/common/router.py (96%) rename src/twinkle/server/{twinkle => }/common/serialize.py (97%) create mode 100644 src/twinkle/server/gateway/__init__.py rename src/twinkle/server/{tinker => gateway}/proxy.py (63%) create mode 100644 src/twinkle/server/gateway/server.py create mode 100644 src/twinkle/server/gateway/tinker_router.py create mode 100644 src/twinkle/server/gateway/twinkle_router.py create mode 100644 src/twinkle/server/model/__init__.py create mode 100644 src/twinkle/server/model/app.py rename src/twinkle/server/{twinkle/common => model/backends}/__init__.py (100%) rename src/twinkle/server/{tinker/common => model/backends}/megatron_model.py (58%) create mode 100644 src/twinkle/server/model/backends/transformers_model.py create mode 100644 src/twinkle/server/model/tinker_handlers.py create mode 100644 src/twinkle/server/model/twinkle_handlers.py create mode 100644 src/twinkle/server/processor/__init__.py rename src/twinkle/server/{twinkle/processor.py => processor/app.py} (74%) create mode 100644 src/twinkle/server/sampler/__init__.py create mode 100644 src/twinkle/server/sampler/app.py create mode 100644 
src/twinkle/server/sampler/tinker_handlers.py create mode 100644 src/twinkle/server/sampler/twinkle_handlers.py delete mode 100644 src/twinkle/server/tinker/__init__.py delete mode 100644 src/twinkle/server/tinker/common/__init__.py delete mode 100644 src/twinkle/server/tinker/common/compat_base.py delete mode 100644 src/twinkle/server/tinker/common/io_utils.py delete mode 100644 src/twinkle/server/tinker/common/transformers_model.py delete mode 100644 src/twinkle/server/tinker/model.py delete mode 100644 src/twinkle/server/tinker/sampler.py delete mode 100644 src/twinkle/server/tinker/server.py delete mode 100644 src/twinkle/server/twinkle/__init__.py delete mode 100644 src/twinkle/server/twinkle/common/io_utils.py delete mode 100644 src/twinkle/server/twinkle/common/transformers_model.py delete mode 100644 src/twinkle/server/twinkle/model.py delete mode 100644 src/twinkle/server/twinkle/sampler.py delete mode 100644 src/twinkle/server/twinkle/server.py create mode 100644 src/twinkle_client/types/__init__.py create mode 100644 src/twinkle_client/types/model.py create mode 100644 src/twinkle_client/types/processor.py create mode 100644 src/twinkle_client/types/sampler.py create mode 100644 src/twinkle_client/types/server.py create mode 100644 src/twinkle_client/types/training.py diff --git a/src/twinkle/server/__main__.py b/src/twinkle/server/__main__.py index 17ea2e1f..e18283c3 100644 --- a/src/twinkle/server/__main__.py +++ b/src/twinkle/server/__main__.py @@ -5,12 +5,6 @@ Usage: # From config file python -m twinkle.server --config server_config.yaml - - # With server type override - python -m twinkle.server --config server_config.yaml --server-type tinker - - # Quick start with minimal args - python -m twinkle.server --server-type tinker --port 8000 --model-id "Qwen/Qwen3.5-4B" """ from __future__ import annotations @@ -27,15 +21,12 @@ def create_parser() -> argparse.ArgumentParser: """Create the argument parser.""" parser = argparse.ArgumentParser( prog='python 
-m twinkle.server', - description='Twinkle Server Launcher - Unified launcher for tinker and twinkle servers', + description='Twinkle Server Launcher - Unified launcher supporting both Tinker and Twinkle clients', formatter_class=argparse.RawDescriptionHelpFormatter, epilog=""" Examples: # Start server from YAML config file python -m twinkle.server --config server_config.yaml - - # Start tinker server with specific config - python -m twinkle.server -c config.yaml -t tinker """, ) @@ -49,23 +40,12 @@ def create_parser() -> argparse.ArgumentParser: help='Path to YAML configuration file (required)', ) - # Server type - parser.add_argument( - '-t', - '--server-type', - type=str, - default='twinkle', - choices=['tinker', 'twinkle'], - metavar='TYPE', - help="Server type: 'tinker' or 'twinkle' (default: twinkle)", - ) - # Ray options parser.add_argument( '--namespace', type=str, metavar='NS', - help="Ray namespace (default: 'twinkle_cluster' for tinker, None for twinkle)", + help="Ray namespace (default: 'twinkle_cluster')", ) # Runtime options @@ -97,7 +77,6 @@ def main(args: list[str] | None = None) -> int: try: from twinkle.server.launcher import launch_server - # Config file mode config_path = Path(parsed_args.config) if not config_path.exists(): logger.error(f'Config file not found: {config_path}') @@ -105,7 +84,6 @@ def main(args: list[str] | None = None) -> int: launch_server( config_path=config_path, - server_type=parsed_args.server_type, ray_namespace=parsed_args.namespace, ) diff --git a/src/twinkle/server/common/__init__.py b/src/twinkle/server/common/__init__.py new file mode 100644 index 00000000..495ae39b --- /dev/null +++ b/src/twinkle/server/common/__init__.py @@ -0,0 +1,18 @@ +# Copyright (c) ModelScope Contributors. All rights reserved. 
+from .datum import datum_to_input_feature, extract_rl_feature, input_feature_to_datum +from .io_utils import create_checkpoint_manager, create_training_run_manager, validate_ownership, validate_user_path +from .router import StickyLoraRequestRouter +from .serialize import deserialize_object, serialize_object + +__all__ = [ + 'datum_to_input_feature', + 'extract_rl_feature', + 'input_feature_to_datum', + 'create_checkpoint_manager', + 'create_training_run_manager', + 'validate_user_path', + 'validate_ownership', + 'StickyLoraRequestRouter', + 'deserialize_object', + 'serialize_object', +] diff --git a/src/twinkle/server/tinker/common/datum.py b/src/twinkle/server/common/datum.py similarity index 97% rename from src/twinkle/server/tinker/common/datum.py rename to src/twinkle/server/common/datum.py index 0eb74f82..7dd0ae1c 100644 --- a/src/twinkle/server/tinker/common/datum.py +++ b/src/twinkle/server/common/datum.py @@ -1,3 +1,5 @@ +# Copyright (c) ModelScope Contributors. All rights reserved. +# Moved from tinker/common/datum.py — logic unchanged. from __future__ import annotations import numpy as np diff --git a/src/twinkle/server/common/io_utils.py b/src/twinkle/server/common/io_utils.py new file mode 100644 index 00000000..089a4955 --- /dev/null +++ b/src/twinkle/server/common/io_utils.py @@ -0,0 +1,332 @@ +# Copyright (c) ModelScope Contributors. All rights reserved. +""" +Unified IO utilities for managing training runs and checkpoints. + +Merges tinker/common/io_utils.py and twinkle/common/io_utils.py. +Both client-type implementations share the same underlying base classes; +factory functions accept a ``client_type`` parameter ('tinker' or 'twinkle'). + +Pydantic models that need to be shared with the client live in +``twinkle_client.types.training``. 
+""" +from datetime import datetime +from tinker import types as tinker_types +from typing import Any, Dict, List, Optional + +from twinkle.server.utils.io_utils import (CHECKPOINT_INFO_FILENAME, TRAIN_RUN_INFO_FILENAME, TWINKLE_DEFAULT_SAVE_DIR, + BaseCheckpoint, BaseCheckpointManager, BaseCreateModelRequest, + BaseLoraConfig, BaseParsedCheckpointPath, BaseTrainingRun, + BaseTrainingRunManager, BaseWeightsInfoResponse, Cursor, ResolvedLoadPath, + validate_ownership, validate_user_path) +# Re-export twinkle-native pydantic models from twinkle_client.types +from twinkle_client.types.training import Checkpoint as TwinkleCheckpoint +from twinkle_client.types.training import (CheckpointsListResponse, CreateModelRequest, LoraConfig, + ParsedCheckpointTwinklePath) +from twinkle_client.types.training import TrainingRun as TwinkleTrainingRun +from twinkle_client.types.training import TrainingRunsResponse, WeightsInfoResponse + +__all__ = [ + 'create_checkpoint_manager', + 'create_training_run_manager', + 'validate_user_path', + 'validate_ownership', + 'ResolvedLoadPath', + 'Cursor', + # Twinkle-native models (re-exported for convenience) + 'TwinkleCheckpoint', + 'TwinkleTrainingRun', + 'TrainingRunsResponse', + 'CheckpointsListResponse', + 'WeightsInfoResponse', + 'LoraConfig', + 'CreateModelRequest', + 'ParsedCheckpointTwinklePath', +] + +# --------------------------------------------------------------------------- +# Tinker-specific managers (use tinker.types for model instances) +# --------------------------------------------------------------------------- + + +class TinkerTrainingRunManager(BaseTrainingRunManager): + """Tinker-specific training run manager using tinker.types models.""" + + @property + def train_run_info_filename(self) -> str: + return TRAIN_RUN_INFO_FILENAME + + def _create_training_run(self, model_id: str, run_config: tinker_types.CreateModelRequest) -> Dict[str, Any]: + lora_config = run_config.lora_config + train_run_data = tinker_types.TrainingRun( 
+ training_run_id=model_id, + base_model=run_config.base_model, + model_owner=self.token, + is_lora=True if lora_config else False, + corrupted=False, + lora_rank=lora_config.rank if lora_config else None, + last_request_time=datetime.now(), + last_checkpoint=None, + last_sampler_checkpoint=None, + user_metadata=run_config.user_metadata) + + new_data = train_run_data.model_dump(mode='json') + if lora_config: + new_data['train_unembed'] = lora_config.train_unembed + new_data['train_mlp'] = lora_config.train_mlp + new_data['train_attn'] = lora_config.train_attn + return new_data + + def _parse_training_run(self, data: Dict[str, Any]) -> tinker_types.TrainingRun: + data = self._transform_checkpoint_fields(data) + return tinker_types.TrainingRun(**data) + + def _transform_checkpoint_fields(self, data: Dict[str, Any]) -> Dict[str, Any]: + data = data.copy() + for field in ['last_checkpoint', 'last_sampler_checkpoint']: + if field in data and data[field] is not None: + ckpt = data[field].copy() + if 'twinkle_path' in ckpt and 'tinker_path' not in ckpt: + ckpt['tinker_path'] = ckpt.pop('twinkle_path') + elif 'tinker_path' not in ckpt: + path = ckpt.get('path') or ckpt.get('twinkle_path') + if path: + ckpt['tinker_path'] = path + elif 'checkpoint_id' in ckpt and 'training_run_id' in data: + ckpt['tinker_path'] = f"twinkle://{data['training_run_id']}/{ckpt['checkpoint_id']}" + data[field] = ckpt + return data + + def _create_training_runs_response(self, runs: List[tinker_types.TrainingRun], limit: int, offset: int, + total: int) -> tinker_types.TrainingRunsResponse: + return tinker_types.TrainingRunsResponse( + training_runs=runs, cursor=tinker_types.Cursor(limit=limit, offset=offset, total_count=total)) + + +class TinkerCheckpointManager(BaseCheckpointManager): + """Tinker-specific checkpoint manager using tinker.types models.""" + + @property + def path_prefix(self) -> str: + return 'twinkle://' + + @property + def path_field_name(self) -> str: + return 'tinker_path' + + 
def _create_checkpoint(self, + checkpoint_id, + checkpoint_type, + path, + size_bytes, + public, + base_model=None, + is_lora=False, + lora_rank=None, + train_unembed=None, + train_mlp=None, + train_attn=None, + user_metadata=None) -> Dict[str, Any]: + checkpoint = tinker_types.Checkpoint( + checkpoint_id=checkpoint_id, + checkpoint_type=checkpoint_type, + time=datetime.now(), + tinker_path=path, + size_bytes=size_bytes, + public=public) + result = checkpoint.model_dump(mode='json') + result['base_model'] = base_model + result['is_lora'] = is_lora + result['lora_rank'] = lora_rank + result['train_unembed'] = train_unembed + result['train_mlp'] = train_mlp + result['train_attn'] = train_attn + result['user_metadata'] = user_metadata + return result + + def _parse_checkpoint(self, data: Dict[str, Any]) -> tinker_types.Checkpoint: + data = data.copy() + if 'twinkle_path' in data and 'tinker_path' not in data: + data['tinker_path'] = data.pop('twinkle_path') + elif 'tinker_path' not in data and 'path' in data: + data['tinker_path'] = data.pop('path') + return tinker_types.Checkpoint(**data) + + def _create_checkpoints_response( + self, checkpoints: List[tinker_types.Checkpoint]) -> tinker_types.CheckpointsListResponse: + return tinker_types.CheckpointsListResponse(checkpoints=checkpoints, cursor=None) + + def _create_parsed_path(self, path, training_run_id, checkpoint_type, + checkpoint_id) -> tinker_types.ParsedCheckpointTinkerPath: + return tinker_types.ParsedCheckpointTinkerPath( + tinker_path=path, + training_run_id=training_run_id, + checkpoint_type=checkpoint_type, + checkpoint_id=checkpoint_id, + ) + + def _create_weights_info(self, run_info: Dict[str, Any]) -> tinker_types.WeightsInfoResponse: + return tinker_types.WeightsInfoResponse(**run_info) + + def parse_tinker_path(self, tinker_path: str) -> Optional[tinker_types.ParsedCheckpointTinkerPath]: + return self.parse_path(tinker_path) + + +# 
--------------------------------------------------------------------------- +# Twinkle-specific managers (use twinkle_client.types.training models) +# --------------------------------------------------------------------------- + + +class TwinkleTrainingRunManager(BaseTrainingRunManager): + """Twinkle-specific training run manager.""" + + @property + def train_run_info_filename(self) -> str: + return TRAIN_RUN_INFO_FILENAME + + def _create_training_run(self, model_id: str, run_config: CreateModelRequest) -> Dict[str, Any]: + lora_config = run_config.lora_config + train_run_data = TwinkleTrainingRun( + training_run_id=model_id, + base_model=run_config.base_model, + model_owner=self.token, + is_lora=True if lora_config else False, + corrupted=False, + lora_rank=lora_config.rank if lora_config else None, + last_request_time=datetime.now(), + last_checkpoint=None, + last_sampler_checkpoint=None, + user_metadata=run_config.user_metadata) + + new_data = train_run_data.model_dump(mode='json') + if lora_config: + new_data['train_unembed'] = lora_config.train_unembed + new_data['train_mlp'] = lora_config.train_mlp + new_data['train_attn'] = lora_config.train_attn + return new_data + + def _parse_training_run(self, data: Dict[str, Any]) -> TwinkleTrainingRun: + return TwinkleTrainingRun(**data) + + def _create_training_runs_response(self, runs: List[TwinkleTrainingRun], limit: int, offset: int, + total: int) -> TrainingRunsResponse: + return TrainingRunsResponse(training_runs=runs, cursor=Cursor(limit=limit, offset=offset, total_count=total)) + + def get_with_permission(self, model_id: str) -> Optional[TwinkleTrainingRun]: + run = self.get(model_id) + if run and validate_ownership(self.token, run.model_owner): + return run + return None + + +class TwinkleCheckpointManager(BaseCheckpointManager): + """Twinkle-specific checkpoint manager.""" + + @property + def path_prefix(self) -> str: + return 'twinkle://' + + @property + def path_field_name(self) -> str: + return 
'twinkle_path' + + def _create_checkpoint(self, + checkpoint_id, + checkpoint_type, + path, + size_bytes, + public, + base_model=None, + is_lora=False, + lora_rank=None, + train_unembed=None, + train_mlp=None, + train_attn=None, + user_metadata=None) -> Dict[str, Any]: + checkpoint = TwinkleCheckpoint( + checkpoint_id=checkpoint_id, + checkpoint_type=checkpoint_type, + time=datetime.now(), + twinkle_path=path, + size_bytes=size_bytes, + public=public, + base_model=base_model, + is_lora=is_lora, + lora_rank=lora_rank, + train_unembed=train_unembed, + train_mlp=train_mlp, + train_attn=train_attn, + user_metadata=user_metadata) + return checkpoint.model_dump(mode='json') + + def _parse_checkpoint(self, data: Dict[str, Any]) -> TwinkleCheckpoint: + data = data.copy() + if 'tinker_path' in data and 'twinkle_path' not in data: + data['twinkle_path'] = data.pop('tinker_path') + elif 'twinkle_path' not in data and 'path' in data: + data['twinkle_path'] = data.pop('path') + return TwinkleCheckpoint(**data) + + def get(self, model_id: str, checkpoint_id: str) -> Optional[TwinkleCheckpoint]: + data = self._read_ckpt_info(model_id, checkpoint_id) + if not data: + return None + if 'twinkle_path' not in data and 'tinker_path' not in data and 'path' not in data: + if 'checkpoint_id' in data: + data = data.copy() + data['twinkle_path'] = f"{self.path_prefix}{model_id}/{data['checkpoint_id']}" + return self._parse_checkpoint(data) + + def _create_checkpoints_response(self, checkpoints: List[TwinkleCheckpoint]) -> CheckpointsListResponse: + return CheckpointsListResponse(checkpoints=checkpoints, cursor=None) + + def _create_parsed_path(self, path, training_run_id, checkpoint_type, checkpoint_id) -> ParsedCheckpointTwinklePath: + return ParsedCheckpointTwinklePath( + path=path, + twinkle_path=path, + training_run_id=training_run_id, + checkpoint_type=checkpoint_type, + checkpoint_id=checkpoint_id, + ) + + def _create_weights_info(self, run_info: Dict[str, Any]) -> 
WeightsInfoResponse: + return WeightsInfoResponse( + training_run_id=run_info.get('training_run_id', ''), + base_model=run_info.get('base_model', ''), + model_owner=run_info.get('model_owner', ''), + is_lora=run_info.get('is_lora', False), + lora_rank=run_info.get('lora_rank'), + ) + + def parse_twinkle_path(self, twinkle_path: str) -> Optional[ParsedCheckpointTwinklePath]: + return self.parse_path(twinkle_path) + + +# --------------------------------------------------------------------------- +# Unified factory functions +# --------------------------------------------------------------------------- + + +def create_training_run_manager(token: str, client_type: str = 'twinkle'): + """Create a TrainingRunManager for the given token. + + Args: + token: User authentication token. + client_type: 'tinker' or 'twinkle' (default 'twinkle'). + """ + if client_type == 'tinker': + return TinkerTrainingRunManager(token) + return TwinkleTrainingRunManager(token) + + +def create_checkpoint_manager(token: str, client_type: str = 'twinkle'): + """Create a CheckpointManager for the given token. + + Args: + token: User authentication token. + client_type: 'tinker' or 'twinkle' (default 'twinkle'). + """ + if client_type == 'tinker': + run_mgr = TinkerTrainingRunManager(token) + return TinkerCheckpointManager(token, run_mgr) + run_mgr = TwinkleTrainingRunManager(token) + return TwinkleCheckpointManager(token, run_mgr) diff --git a/src/twinkle/server/tinker/common/router.py b/src/twinkle/server/common/router.py similarity index 96% rename from src/twinkle/server/tinker/common/router.py rename to src/twinkle/server/common/router.py index 19ec8650..27abbfbd 100644 --- a/src/twinkle/server/tinker/common/router.py +++ b/src/twinkle/server/common/router.py @@ -1,3 +1,5 @@ +# Copyright (c) ModelScope Contributors. All rights reserved. +# Moved from tinker/common/router.py — logic unchanged. 
from ray.serve.request_router import (FIFOMixin, MultiplexMixin, PendingRequest, ReplicaID, ReplicaResult, RequestRouter, RunningReplica) from typing import Dict, List, Optional diff --git a/src/twinkle/server/twinkle/common/serialize.py b/src/twinkle/server/common/serialize.py similarity index 97% rename from src/twinkle/server/twinkle/common/serialize.py rename to src/twinkle/server/common/serialize.py index de3ca4bb..f1b3f6dd 100644 --- a/src/twinkle/server/twinkle/common/serialize.py +++ b/src/twinkle/server/common/serialize.py @@ -1,4 +1,5 @@ # Copyright (c) ModelScope Contributors. All rights reserved. +# Moved from twinkle/common/serialize.py — logic unchanged. import json from numbers import Number from peft import LoraConfig diff --git a/src/twinkle/server/gateway/__init__.py b/src/twinkle/server/gateway/__init__.py new file mode 100644 index 00000000..1e6c2cbd --- /dev/null +++ b/src/twinkle/server/gateway/__init__.py @@ -0,0 +1,3 @@ +from .server import build_server_app + +__all__ = ['build_server_app'] diff --git a/src/twinkle/server/tinker/proxy.py b/src/twinkle/server/gateway/proxy.py similarity index 63% rename from src/twinkle/server/tinker/proxy.py rename to src/twinkle/server/gateway/proxy.py index bc429199..5517014e 100644 --- a/src/twinkle/server/tinker/proxy.py +++ b/src/twinkle/server/gateway/proxy.py @@ -1,15 +1,14 @@ # Copyright (c) ModelScope Contributors. All rights reserved. """ -Proxy utilities for forwarding requests to internal services. +Proxy utilities for forwarding requests to internal model/sampler services. -This module provides HTTP proxy functionality to route requests from the Tinker server -to appropriate model or sampler services based on base_model routing. +Moved from tinker/proxy.py. Updated proxy_to_model and proxy_to_sampler +to prepend the 'tinker/' prefix to endpoints so they route to /tinker/* paths +on the unified model/sampler deployments. 
""" - from __future__ import annotations import httpx -import os from fastapi import Request, Response from typing import Any @@ -21,11 +20,13 @@ class ServiceProxy: """HTTP proxy for routing requests to internal model and sampler services. - This proxy handles: + Handles: 1. URL construction using localhost to avoid external routing loops 2. Header forwarding with appropriate cleanup 3. Debug logging for troubleshooting 4. Error handling and response forwarding + + Tinker endpoints are routed to /tinker/ on the unified deployments. """ def __init__( @@ -33,28 +34,18 @@ def __init__( http_options: dict[str, Any] | None = None, route_prefix: str = '/api/v1', ): - """Initialize the service proxy. - - Args: - http_options: HTTP server options (host, port) for internal routing - route_prefix: URL prefix for routing (default: '/api/v1') - """ self.http_options = http_options or {} self.route_prefix = route_prefix - # Disable proxy for internal requests to avoid routing through external proxies + # Disable proxy env vars to avoid external routing self.client = httpx.AsyncClient(timeout=None, trust_env=False) def _build_target_url(self, service_type: str, base_model: str, endpoint: str) -> str: """Build the target URL for internal service routing. - Constructs URLs using localhost to avoid extra external hops. - When requests come from www.modelscope.com/twinkle, we proxy to - localhost:port directly instead of back to modelscope.com. 
- Args: service_type: Either 'model' or 'sampler' base_model: The base model name for routing - endpoint: The target endpoint name + endpoint: The target endpoint name (already includes tinker/ or twinkle/ prefix) Returns: Complete target URL for the internal service @@ -63,7 +54,6 @@ def _build_target_url(self, service_type: str, base_model: str, endpoint: str) - host = self.http_options.get('host', 'localhost') port = self.http_options.get('port', 8000) - # Use localhost for internal routing if host == '0.0.0.0': host = 'localhost' @@ -71,22 +61,11 @@ def _build_target_url(self, service_type: str, base_model: str, endpoint: str) - return f'{base_url}{prefix}/{service_type}/{base_model}/{endpoint}' def _prepare_headers(self, request_headers) -> dict[str, str]: - """Prepare headers for proxying by removing problematic headers. - - Args: - request_headers: Original request headers (case-insensitive from FastAPI) - - Returns: - Cleaned headers safe for proxying - """ + """Prepare headers for proxying by removing problematic headers.""" logger.debug('prepare_headers request_headers=%s', request_headers) - # Convert to dict while preserving case-insensitive lookups for special headers headers = dict(request_headers) - # Remove headers that should not be forwarded headers.pop('host', None) headers.pop('content-length', None) - # Add serve_multiplexed_model_id for sticky sessions if present - # Use case-insensitive lookup from original request_headers request_id = request_headers.get('X-Ray-Serve-Request-Id') if request_id is not None: headers['serve_multiplexed_model_id'] = request_id @@ -101,24 +80,20 @@ async def proxy_request( ) -> Response: """Generic proxy method to forward requests to model or sampler services. - This method consolidates the common proxy logic for both model and sampler endpoints. 
- Args: request: The incoming FastAPI request - endpoint: The target endpoint name (e.g., 'create_model', 'asample') + endpoint: The target endpoint path (e.g., 'tinker/create_model') base_model: The base model name for routing - service_type: Either 'model' or 'sampler' to determine the target service + service_type: Either 'model' or 'sampler' Returns: Proxied response from the target service """ body_bytes = await request.body() target_url = self._build_target_url(service_type, base_model, endpoint) - # Pass original request.headers (case-insensitive) instead of dict conversion headers = self._prepare_headers(request.headers) try: - # Debug logging for troubleshooting proxy issues logger.debug( 'proxy_request service=%s endpoint=%s target_url=%s request_id=%s', service_type, @@ -127,7 +102,6 @@ async def proxy_request( headers.get('serve_multiplexed_model_id'), ) - # Forward the request to the target service response = await self.client.request( method=request.method, url=target_url, @@ -136,7 +110,6 @@ async def proxy_request( params=request.query_params, ) - # Debug logging for response logger.debug( 'proxy_response status=%s body_preview=%s', response.status_code, @@ -154,31 +127,21 @@ async def proxy_request( return Response(content=f'Proxy Error: {str(e)}', status_code=502) async def proxy_to_model(self, request: Request, endpoint: str, base_model: str) -> Response: - """Proxy request to model endpoint. - - Routes the request to the appropriate model deployment based on base_model. + """Proxy request to model's tinker endpoint (/tinker/). 
Args: request: The incoming FastAPI request - endpoint: The target endpoint name (e.g., 'create_model', 'forward') + endpoint: The tinker endpoint name (e.g., 'create_model', 'forward') base_model: The base model name for routing - - Returns: - Proxied response from the model service """ - return await self.proxy_request(request, endpoint, base_model, 'model') + return await self.proxy_request(request, f'tinker/{endpoint}', base_model, 'model') async def proxy_to_sampler(self, request: Request, endpoint: str, base_model: str) -> Response: - """Proxy request to sampler endpoint. - - Routes the request to the appropriate sampler deployment based on base_model. + """Proxy request to sampler's tinker endpoint (/tinker/). Args: request: The incoming FastAPI request - endpoint: The target endpoint name (e.g., 'asample') + endpoint: The tinker endpoint name (e.g., 'asample') base_model: The base model name for routing - - Returns: - Proxied response from the sampler service """ - return await self.proxy_request(request, endpoint, base_model, 'sampler') + return await self.proxy_request(request, f'tinker/{endpoint}', base_model, 'sampler') diff --git a/src/twinkle/server/gateway/server.py b/src/twinkle/server/gateway/server.py new file mode 100644 index 00000000..9071a814 --- /dev/null +++ b/src/twinkle/server/gateway/server.py @@ -0,0 +1,126 @@ +# Copyright (c) ModelScope Contributors. All rights reserved. +""" +Unified Gateway Server. + +A single Ray Serve deployment that serves both Tinker (/tinker/*) and +Twinkle (/twinkle/*) management and proxy endpoints. 
+""" +from __future__ import annotations + +import asyncio +from fastapi import FastAPI, HTTPException, Request +from ray import serve +from tinker import types as tinker_types +from typing import Any + +from twinkle.server.utils.state import get_server_state +from twinkle.server.utils.validation import verify_request_token +from twinkle.utils.logger import get_logger +from .proxy import ServiceProxy +from .tinker_router import tinker_router +from .twinkle_router import twinkle_router + +logger = get_logger() + + +def build_server_app(deploy_options: dict[str, Any], + supported_models: list | None = None, + server_config: dict[str, Any] = {}, + http_options: dict[str, Any] | None = None, + **kwargs): + """Build and configure the unified gateway server application. + + Serves Tinker endpoints at /tinker/* and Twinkle endpoints at /twinkle/*. + + Args: + deploy_options: Ray Serve deployment configuration + supported_models: List of supported base models for tinker validation + server_config: Server configuration options + http_options: HTTP server options (host, port) for internal proxy routing + **kwargs: Additional keyword arguments (route_prefix, etc.) + + Returns: + Configured Ray Serve deployment bound with options + """ + app = FastAPI() + + @app.middleware('http') + async def verify_token(request: Request, call_next): + return await verify_request_token(request=request, call_next=call_next) + + @app.middleware('http') + async def inject_dependencies(request: Request, call_next): + """Middleware to inject GatewayServer dependencies into request.state. + + This must run after GatewayServer is instantiated. We use a marker + set by the first request to initialize the state reference. 
+ """ + # The GatewayServer instance will set itself on the app state + server = getattr(app.state, 'gateway_server', None) + if server: + server._setup_request_state(request) + return await call_next(request) + + @serve.deployment(name='GatewayServer') + @serve.ingress(app) + class GatewayServer: + """Unified gateway server handling both Tinker and Twinkle API clients.""" + + def __init__(self, + supported_models: list | None = None, + server_config: dict[str, Any] = {}, + http_options: dict[str, Any] | None = None, + **kwargs) -> None: + self.state = get_server_state(**server_config) + self.route_prefix = kwargs.get('route_prefix', '/api/v1') + self.http_options = http_options or {} + self.proxy = ServiceProxy(http_options=http_options, route_prefix=self.route_prefix) + self.supported_models = self._normalize_models(supported_models) or [ + tinker_types.SupportedModel(model_name='Qwen/Qwen3-30B-A3B-Instruct-2507'), + ] + self._modelscope_config_lock = asyncio.Lock() + # Register self on app state so middleware can access dependencies + app.state.gateway_server = self + + def _normalize_models(self, supported_models): + if not supported_models: + return [] + normalized = [] + for item in supported_models: + if isinstance(item, tinker_types.SupportedModel): + normalized.append(item) + elif isinstance(item, dict): + normalized.append(tinker_types.SupportedModel(**item)) + elif isinstance(item, str): + normalized.append(tinker_types.SupportedModel(model_name=item)) + return normalized + + def _validate_base_model(self, base_model: str) -> None: + supported_model_names = [m.model_name for m in self.supported_models] + if base_model not in supported_model_names: + raise HTTPException( + status_code=400, + detail=f"Base model '{base_model}' is not supported. 
" + f"Supported models: {', '.join(supported_model_names)}") + + def _get_base_model(self, model_id: str) -> str: + metadata = self.state.get_model_metadata(model_id) + if metadata and metadata.get('base_model'): + return metadata['base_model'] + raise HTTPException(status_code=404, detail=f'Model {model_id} not found') + + def _setup_request_state(self, request: Request): + """Inject dependencies into request.state for router handlers.""" + request.state.server_state = self.state + request.state.proxy = self.proxy + request.state.supported_models = self.supported_models + request.state.modelscope_config_lock = self._modelscope_config_lock + request.state.validate_base_model = self._validate_base_model + request.state.get_base_model = self._get_base_model + + # Include routers for Tinker and Twinkle endpoints + app.include_router(tinker_router, prefix='/tinker') + app.include_router(twinkle_router, prefix='/twinkle') + + return GatewayServer.options(**deploy_options).bind( + supported_models=supported_models, server_config=server_config, http_options=http_options, **kwargs) diff --git a/src/twinkle/server/gateway/tinker_router.py b/src/twinkle/server/gateway/tinker_router.py new file mode 100644 index 00000000..f4290587 --- /dev/null +++ b/src/twinkle/server/gateway/tinker_router.py @@ -0,0 +1,289 @@ +# Copyright (c) ModelScope Contributors. All rights reserved. +""" +Tinker-compatible gateway router. + +Provides all tinker management and proxy endpoints under /tinker/* prefix. +Extracted from tinker/server.py — same endpoint logic, now on an APIRouter. 
+""" +from __future__ import annotations + +import asyncio +import os +from fastapi import APIRouter, HTTPException, Request, Response +from tinker import types +from typing import Any + +from twinkle.hub import HubOperation +from twinkle.server.common.io_utils import create_checkpoint_manager, create_training_run_manager +from twinkle.server.utils.task_queue import QueueState +from twinkle.server.utils.validation import get_token_from_request +from twinkle.utils.logger import get_logger + +logger = get_logger() + +tinker_router = APIRouter() + + +@tinker_router.get('/healthz') +async def healthz(request: Request) -> types.HealthResponse: + return types.HealthResponse(status='ok') + + +@tinker_router.get('/get_server_capabilities') +async def get_server_capabilities(request: Request) -> types.GetServerCapabilitiesResponse: + # GatewayServer injects self.supported_models via request.state in middleware + supported_models = getattr(request.state, 'supported_models', []) + return types.GetServerCapabilitiesResponse(supported_models=supported_models) + + +@tinker_router.post('/telemetry') +async def telemetry(request: Request, body: types.TelemetrySendRequest) -> types.TelemetryResponse: + return types.TelemetryResponse(status='accepted') + + +@tinker_router.post('/create_session') +async def create_session(request: Request, body: types.CreateSessionRequest) -> types.CreateSessionResponse: + state = request.state.server_state + session_id = state.create_session(body.model_dump()) + return types.CreateSessionResponse(session_id=session_id) + + +@tinker_router.post('/session_heartbeat') +async def session_heartbeat(request: Request, body: types.SessionHeartbeatRequest) -> types.SessionHeartbeatResponse: + state = request.state.server_state + alive = state.touch_session(body.session_id) + if not alive: + raise HTTPException(status_code=404, detail='Unknown session') + return types.SessionHeartbeatResponse() + + +@tinker_router.post('/create_sampling_session') +async def 
create_sampling_session(request: Request, + body: types.CreateSamplingSessionRequest) -> types.CreateSamplingSessionResponse: + state = request.state.server_state + sampling_session_id = state.create_sampling_session(body.model_dump()) + return types.CreateSamplingSessionResponse(sampling_session_id=sampling_session_id) + + +@tinker_router.post('/retrieve_future') +async def retrieve_future(request: Request, body: types.FutureRetrieveRequest) -> Any: + """Retrieve the result of an async task with long polling.""" + state = request.state.server_state + request_id = body.request_id + max_wait = float(os.environ.get('TWINKLE_LONG_POLL_TIMEOUT', '30')) + poll_interval = float(os.environ.get('TWINKLE_POLL_INTERVAL', '0.5')) + start = asyncio.get_event_loop().time() + + while True: + record = state.get_future(request_id) + + if record is None: + return {'type': 'try_again'} + + status = record.get('status') + if status not in ('pending', 'queued', 'running', 'rate_limited'): + break + + if asyncio.get_event_loop().time() - start >= max_wait: + response_data = {'type': 'try_again'} + if queue_state := record.get('queue_state'): + response_data['queue_state'] = queue_state + if queue_state_reason := record.get('queue_state_reason'): + response_data['queue_state_reason'] = queue_state_reason + return response_data + + await asyncio.sleep(poll_interval) + + record = state.get_future(request_id) + if not record: + return {'type': 'try_again'} + + status = record.get('status') + + if status == 'rate_limited': + return { + 'type': 'try_again', + 'queue_state': QueueState.PAUSED_RATE_LIMIT.value, + 'queue_state_reason': record.get('reason', 'Rate limit exceeded') + } + + if status == 'failed': + result = record.get('result', {}) + return {'error': result.get('error', 'Unknown error'), 'category': result.get('category', 'Server')} + + result = record.get('result') + if result is None: + raise HTTPException(status_code=500, detail='Task completed but no result found') + + if 
hasattr(result, 'model_dump'): + return result.model_dump() + return result + + +# --- Training Runs Endpoints --- + + +@tinker_router.get('/training_runs') +async def get_training_runs(request: Request, limit: int = 20, offset: int = 0) -> types.TrainingRunsResponse: + token = get_token_from_request(request) + training_run_manager = create_training_run_manager(token, client_type='tinker') + return training_run_manager.list_runs(limit=limit, offset=offset) + + +@tinker_router.get('/training_runs/{run_id}') +async def get_training_run(request: Request, run_id: str) -> types.TrainingRun: + token = get_token_from_request(request) + training_run_manager = create_training_run_manager(token, client_type='tinker') + run = training_run_manager.get(run_id) + if not run: + raise HTTPException(status_code=404, detail=f'Training run {run_id} not found') + return run + + +@tinker_router.get('/training_runs/{run_id}/checkpoints') +async def get_run_checkpoints(request: Request, run_id: str) -> types.CheckpointsListResponse: + token = get_token_from_request(request) + checkpoint_manager = create_checkpoint_manager(token, client_type='tinker') + response = checkpoint_manager.list_checkpoints(run_id) + if not response: + raise HTTPException(status_code=404, detail=f'Training run {run_id} not found') + return response + + +@tinker_router.delete('/training_runs/{run_id}/checkpoints/{checkpoint_id:path}') +async def delete_run_checkpoint(request: Request, run_id: str, checkpoint_id: str) -> Any: + token = get_token_from_request(request) + checkpoint_manager = create_checkpoint_manager(token, client_type='tinker') + success = checkpoint_manager.delete(run_id, checkpoint_id) + if not success: + raise HTTPException(status_code=404, detail=f'Checkpoint {checkpoint_id} not found for run {run_id}') + return None + + +@tinker_router.post('/weights_info') +async def weights_info(request: Request, body: dict[str, Any]) -> types.WeightsInfoResponse: + token = get_token_from_request(request) + 
checkpoint_manager = create_checkpoint_manager(token, client_type='tinker') + tinker_path = body.get('tinker_path') + response = checkpoint_manager.get_weights_info(tinker_path) + if not response: + raise HTTPException(status_code=404, detail=f'Weights at {tinker_path} not found') + return response + + +@tinker_router.post('/training_runs/{run_id}/checkpoints/{checkpoint_id:path}/publish') +async def publish_checkpoint(request: Request, run_id: str, checkpoint_id: str) -> Response: + token = get_token_from_request(request) + modelscope_config_lock = request.state.modelscope_config_lock + + training_run_manager = create_training_run_manager(token, client_type='tinker') + checkpoint_manager = create_checkpoint_manager(token, client_type='tinker') + + run = training_run_manager.get(run_id) + if not run: + raise HTTPException(status_code=404, detail=f'Training run {run_id} not found or access denied') + + checkpoint = checkpoint_manager.get(run_id, checkpoint_id) + if not checkpoint: + raise HTTPException(status_code=404, detail=f'Checkpoint {checkpoint_id} not found') + + checkpoint_dir = str(checkpoint_manager.get_ckpt_dir(run_id, checkpoint_id)) + + async with modelscope_config_lock: + try: + from modelscope.hub.api import HubApi, ModelScopeConfig + hub_api = HubApi(token=token) + hub_api.login() + username = ModelScopeConfig.get_user_info()[0] + except Exception as e: + logger.error(f'Failed to get username from ModelScope: {e}') + raise HTTPException( + status_code=401, detail='Failed to get username from ModelScope. 
Please ensure your token is valid.') + + checkpoint_name = checkpoint_id.split('/')[-1] + hub_model_id = f'{username}/{run_id}_{checkpoint_name}' + HubOperation.async_push_to_hub(repo_id=hub_model_id, folder_path=checkpoint_dir, token=token, private=True) + + return Response(status_code=204) + + +# --- Model Proxy Endpoints --- + + +@tinker_router.post('/create_model') +async def create_model(request: Request, body: types.CreateModelRequest) -> Any: + proxy = request.state.proxy + validate_base_model = request.state.validate_base_model + validate_base_model(body.base_model) + return await proxy.proxy_to_model(request, 'create_model', body.base_model) + + +@tinker_router.post('/get_info') +async def get_info(request: Request, body: types.GetInfoRequest) -> Any: + proxy = request.state.proxy + get_base_model = request.state.get_base_model + return await proxy.proxy_to_model(request, 'get_info', get_base_model(body.model_id)) + + +@tinker_router.post('/unload_model') +async def unload_model(request: Request, body: types.UnloadModelRequest) -> Any: + proxy = request.state.proxy + get_base_model = request.state.get_base_model + return await proxy.proxy_to_model(request, 'unload_model', get_base_model(body.model_id)) + + +@tinker_router.post('/forward') +async def forward(request: Request, body: types.ForwardRequest) -> Any: + proxy = request.state.proxy + get_base_model = request.state.get_base_model + return await proxy.proxy_to_model(request, 'forward', get_base_model(body.model_id)) + + +@tinker_router.post('/forward_backward') +async def forward_backward(request: Request, body: types.ForwardBackwardRequest) -> Any: + proxy = request.state.proxy + get_base_model = request.state.get_base_model + return await proxy.proxy_to_model(request, 'forward_backward', get_base_model(body.model_id)) + + +@tinker_router.post('/optim_step') +async def optim_step(request: Request, body: types.OptimStepRequest) -> Any: + proxy = request.state.proxy + get_base_model = 
request.state.get_base_model + return await proxy.proxy_to_model(request, 'optim_step', get_base_model(body.model_id)) + + +@tinker_router.post('/save_weights') +async def save_weights(request: Request, body: types.SaveWeightsRequest) -> Any: + proxy = request.state.proxy + get_base_model = request.state.get_base_model + return await proxy.proxy_to_model(request, 'save_weights', get_base_model(body.model_id)) + + +@tinker_router.post('/load_weights') +async def load_weights(request: Request, body: types.LoadWeightsRequest) -> Any: + proxy = request.state.proxy + get_base_model = request.state.get_base_model + return await proxy.proxy_to_model(request, 'load_weights', get_base_model(body.model_id)) + + +# --- Sampler Proxy Endpoints --- + + +@tinker_router.post('/asample') +async def asample(request: Request, body: types.SampleRequest) -> Any: + proxy = request.state.proxy + state = request.state.server_state + base_model = body.base_model + if not base_model and body.sampling_session_id: + session = state.get_sampling_session(body.sampling_session_id) + if session: + base_model = session.get('base_model') + return await proxy.proxy_to_sampler(request, 'asample', base_model) + + +@tinker_router.post('/save_weights_for_sampler') +async def save_weights_for_sampler(request: Request, body: types.SaveWeightsForSamplerRequest) -> Any: + proxy = request.state.proxy + get_base_model = request.state.get_base_model + return await proxy.proxy_to_model(request, 'save_weights_for_sampler', get_base_model(body.model_id)) diff --git a/src/twinkle/server/gateway/twinkle_router.py b/src/twinkle/server/gateway/twinkle_router.py new file mode 100644 index 00000000..9cd2af75 --- /dev/null +++ b/src/twinkle/server/gateway/twinkle_router.py @@ -0,0 +1,106 @@ +# Copyright (c) ModelScope Contributors. All rights reserved. +""" +Twinkle-native gateway router. + +Provides all twinkle management endpoints under /twinkle/* prefix. 
+Extracted from twinkle/server.py — same endpoint logic, now on an APIRouter. +""" +from __future__ import annotations + +from fastapi import APIRouter, HTTPException, Request +from pydantic import BaseModel +from typing import Any + +from twinkle.server.common.io_utils import create_checkpoint_manager, create_training_run_manager, validate_user_path +from twinkle.server.utils.validation import get_token_from_request +from twinkle.utils.logger import get_logger +from twinkle_client.types.server import DeleteCheckpointResponse, HealthResponse +from twinkle_client.types.training import (CheckpointsListResponse, TrainingRun, TrainingRunsResponse, + WeightsInfoResponse) + +logger = get_logger() + +twinkle_router = APIRouter() + + +class WeightsInfoRequest(BaseModel): + twinkle_path: str + + +@twinkle_router.get('/healthz', response_model=HealthResponse) +async def healthz(request: Request) -> HealthResponse: + return HealthResponse(status='ok') + + +@twinkle_router.get('/training_runs', response_model=TrainingRunsResponse) +async def get_training_runs(request: Request, limit: int = 20, offset: int = 0) -> TrainingRunsResponse: + token = get_token_from_request(request) + training_run_manager = create_training_run_manager(token, client_type='twinkle') + return training_run_manager.list_runs(limit=limit, offset=offset) + + +@twinkle_router.get('/training_runs/{run_id}', response_model=TrainingRun) +async def get_training_run(request: Request, run_id: str) -> TrainingRun: + token = get_token_from_request(request) + training_run_manager = create_training_run_manager(token, client_type='twinkle') + run = training_run_manager.get_with_permission(run_id) + if not run: + raise HTTPException(status_code=404, detail=f'Training run {run_id} not found or access denied') + return run + + +@twinkle_router.get('/training_runs/{run_id}/checkpoints', response_model=CheckpointsListResponse) +async def get_run_checkpoints(request: Request, run_id: str) -> CheckpointsListResponse: + token 
= get_token_from_request(request) + checkpoint_manager = create_checkpoint_manager(token, client_type='twinkle') + response = checkpoint_manager.list_checkpoints(run_id) + if response is None: + raise HTTPException(status_code=404, detail=f'Training run {run_id} not found or access denied') + return response + + +@twinkle_router.delete('/training_runs/{run_id}/checkpoints/{checkpoint_id:path}') +async def delete_run_checkpoint(request: Request, run_id: str, checkpoint_id: str) -> DeleteCheckpointResponse: + token = get_token_from_request(request) + + if not validate_user_path(token, checkpoint_id): + raise HTTPException(status_code=400, detail='Invalid checkpoint path: path traversal not allowed') + + checkpoint_manager = create_checkpoint_manager(token, client_type='twinkle') + success = checkpoint_manager.delete(run_id, checkpoint_id) + if not success: + raise HTTPException(status_code=404, detail=f'Checkpoint {checkpoint_id} not found or access denied') + + return DeleteCheckpointResponse(success=True, message=f'Checkpoint {checkpoint_id} deleted successfully') + + +@twinkle_router.post('/weights_info', response_model=WeightsInfoResponse) +async def weights_info(request: Request, body: WeightsInfoRequest) -> WeightsInfoResponse: + token = get_token_from_request(request) + checkpoint_manager = create_checkpoint_manager(token, client_type='twinkle') + response = checkpoint_manager.get_weights_info(body.twinkle_path) + if response is None: + raise HTTPException(status_code=404, detail=f'Weights at {body.twinkle_path} not found or access denied') + return response + + +@twinkle_router.get('/checkpoint_path/{run_id}/{checkpoint_id:path}') +async def get_checkpoint_path(request: Request, run_id: str, checkpoint_id: str) -> dict[str, str]: + token = get_token_from_request(request) + + if not validate_user_path(token, checkpoint_id): + raise HTTPException(status_code=400, detail='Invalid checkpoint path: path traversal not allowed') + + training_run_manager = 
create_training_run_manager(token, client_type='twinkle') + checkpoint_manager = create_checkpoint_manager(token, client_type='twinkle') + + run = training_run_manager.get(run_id) + if not run: + raise HTTPException(status_code=404, detail=f'Training run {run_id} not found or access denied') + + checkpoint = checkpoint_manager.get(run_id, checkpoint_id) + if not checkpoint: + raise HTTPException(status_code=404, detail=f'Checkpoint {checkpoint_id} not found') + + ckpt_dir = checkpoint_manager.get_ckpt_dir(run_id, checkpoint_id) + return {'path': str(ckpt_dir), 'twinkle_path': checkpoint.twinkle_path} diff --git a/src/twinkle/server/launcher.py b/src/twinkle/server/launcher.py index 843418c2..53b88350 100644 --- a/src/twinkle/server/launcher.py +++ b/src/twinkle/server/launcher.py @@ -2,8 +2,8 @@ """ Unified Server Launcher for Twinkle. -This module provides a unified way to launch both tinker and twinkle servers -with support for YAML config files, Python dict config, and CLI. +This module provides a unified way to launch the server with support for +YAML config files, Python dict config, and CLI. Usage: # From YAML config @@ -12,7 +12,6 @@ # From Python dict launch_server(config={ - "server_type": "tinker", "http_options": {"host": "0.0.0.0", "port": 8000}, "applications": [...] }) @@ -33,26 +32,17 @@ class ServerLauncher: """ - Unified server launcher for tinker and twinkle servers. + Unified server launcher. - This class handles Ray/Serve initialization and application deployment - for both tinker and twinkle server types. + This class handles Ray/Serve initialization and application deployment. 
Attributes: - server_type: The type of server ('tinker' or 'twinkle') config: The server configuration dictionary ray_namespace: The Ray namespace for the cluster """ - # Mapping of simplified import_path names to actual builder functions - # These will be populated lazily to avoid circular imports - _TINKER_BUILDERS: dict[str, str] = { - 'server': 'build_server_app', - 'model': 'build_model_app', - 'sampler': 'build_sampler_app', - } - - _TWINKLE_BUILDERS: dict[str, str] = { + # Mapping of simplified import_path names to builder function names + _BUILDERS: dict[str, str] = { 'server': 'build_server_app', 'model': 'build_model_app', 'sampler': 'build_sampler_app', @@ -61,7 +51,6 @@ class ServerLauncher: def __init__( self, - server_type: str = 'twinkle', config: dict[str, Any] | None = None, ray_namespace: str | None = None, ): @@ -69,14 +58,9 @@ def __init__( Initialize the server launcher. Args: - server_type: Server type ('tinker' or 'twinkle') config: Configuration dictionary - ray_namespace: Ray namespace (default: 'twinkle_cluster' for tinker, None for twinkle) + ray_namespace: Ray namespace (default: 'twinkle_cluster') """ - if server_type not in ('tinker', 'twinkle'): - raise ValueError(f"server_type must be 'tinker' or 'twinkle', got '{server_type}'") - - self.server_type = server_type self.config = config or {} self.ray_namespace = ray_namespace self._builders: dict[str, Callable] = {} @@ -84,30 +68,21 @@ def __init__( self._serve_started = False def _get_builders(self) -> dict[str, Callable]: - """ - Get the appropriate builder functions for the server type. 
- - Returns: - Dictionary mapping import_path names to builder functions - """ + """Get the builder functions for all app types.""" if self._builders: return self._builders - if self.server_type == 'tinker': - from twinkle.server.tinker import build_model_app, build_sampler_app, build_server_app - self._builders = { - 'build_server_app': build_server_app, - 'build_model_app': build_model_app, - 'build_sampler_app': build_sampler_app, - } - else: # twinkle - from twinkle.server.twinkle import build_model_app, build_processor_app, build_sampler_app, build_server_app - self._builders = { - 'build_server_app': build_server_app, - 'build_model_app': build_model_app, - 'build_sampler_app': build_sampler_app, - 'build_processor_app': build_processor_app, - } + from twinkle.server.gateway import build_server_app + from twinkle.server.model import build_model_app + from twinkle.server.processor import build_processor_app + from twinkle.server.sampler import build_sampler_app + + self._builders = { + 'build_server_app': build_server_app, + 'build_model_app': build_model_app, + 'build_sampler_app': build_sampler_app, + 'build_processor_app': build_processor_app, + } return self._builders @@ -116,7 +91,7 @@ def _resolve_builder(self, import_path: str) -> Callable: Resolve an import_path to a builder function. 
Args: - import_path: The import path from config (e.g., 'server', 'main:build_server_app') + import_path: The import path from config (e.g., 'server', 'model') Returns: The builder function @@ -125,11 +100,10 @@ def _resolve_builder(self, import_path: str) -> Callable: ValueError: If the import_path cannot be resolved """ builders = self._get_builders() - builder_map = self._TINKER_BUILDERS if self.server_type == 'tinker' else self._TWINKLE_BUILDERS - # Try to resolve through the mapping - if import_path in builder_map: - builder_name = builder_map[import_path] + # Try to resolve through the simplified name mapping + if import_path in self._BUILDERS: + builder_name = self._BUILDERS[import_path] if builder_name in builders: return builders[builder_name] @@ -137,8 +111,8 @@ def _resolve_builder(self, import_path: str) -> Callable: if import_path in builders: return builders[import_path] - raise ValueError(f"Unknown import_path '{import_path}' for server_type '{self.server_type}'. " - f'Available: {list(builder_map.keys())}') + raise ValueError(f"Unknown import_path '{import_path}'. 
" + f'Available: {list(self._BUILDERS.keys())}') def _init_ray(self) -> None: """Initialize Ray if not already initialized.""" @@ -147,14 +121,10 @@ def _init_ray(self) -> None: import ray - # Determine namespace namespace = self.ray_namespace or self.config.get('ray_namespace') or 'twinkle_cluster' - init_kwargs = {} - init_kwargs['namespace'] = namespace - if not ray.is_initialized(): - ray.init(**init_kwargs) + ray.init(namespace=namespace) logger.info(f'Ray initialized with namespace={namespace}') self._ray_initialized = True @@ -166,19 +136,16 @@ def _start_serve(self) -> None: from ray import serve - # Shutdown any existing serve instance try: serve.shutdown() - time.sleep(2) # Wait for cleanup + time.sleep(2) except Exception: pass - # Get http_options from config http_options = self.config.get('http_options', {}) if isinstance(http_options, dict): http_options = dict(http_options) else: - # Handle OmegaConf or other config objects http_options = dict(http_options) if http_options else {} serve.start(http_options=http_options) @@ -187,8 +154,7 @@ def _start_serve(self) -> None: self._serve_started = True def _deploy_application(self, app_config: dict[str, Any]) -> None: - """ - Deploy a single application. + """Deploy a single application. Args: app_config: Application configuration dictionary @@ -203,15 +169,12 @@ def _deploy_application(self, app_config: dict[str, Any]) -> None: logger.info(f'Starting {name} at {route_prefix}...') - # Resolve builder function builder = self._resolve_builder(import_path) - # Build deploy_options from deployments config deploy_options = {} if deployments: deploy_config = deployments[0] if isinstance(deploy_config, dict): - # Copy all deployment options from the config, except 'name'. 
deploy_options = {k: v for k, v in deploy_config.items() if k != 'name'} # Pass http_options to server apps for internal proxy routing @@ -219,16 +182,13 @@ def _deploy_application(self, app_config: dict[str, Any]) -> None: if import_path == 'server' and http_options: args['http_options'] = http_options - # Build and deploy the application app = builder(deploy_options=deploy_options, **{k: v for k, v in args.items()}) serve.run(app, name=name, route_prefix=route_prefix) logger.info(f'Deployed {name} at {route_prefix}') def launch(self) -> None: - """ - Launch the server with all configured applications. - """ + """Launch the server with all configured applications.""" self._init_ray() self._start_serve() @@ -237,15 +197,12 @@ def launch(self) -> None: logger.warning('No applications configured') return - # Deploy each application for app_config in applications: if isinstance(app_config, dict): self._deploy_application(app_config) else: - # Handle OmegaConf or other config objects self._deploy_application(dict(app_config)) - # Print endpoints http_options = self.config.get('http_options', {}) host = http_options.get('host', 'localhost') port = http_options.get('port', 8000) @@ -264,7 +221,6 @@ def launch(self) -> None: def from_yaml( cls, config_path: str | Path, - server_type: str = 'twinkle', ray_namespace: str | None = None, ) -> ServerLauncher: """ @@ -272,7 +228,6 @@ def from_yaml( Args: config_path: Path to the YAML config file - server_type: Server type ('tinker' or 'twinkle'), default is 'twinkle' ray_namespace: Override Ray namespace from config Returns: @@ -287,12 +242,7 @@ def from_yaml( config = OmegaConf.load(config_path) config_dict = OmegaConf.to_container(config, resolve=True) - # Override server_type from config if specified - if 'server_type' in config_dict: - server_type = config_dict['server_type'] - return cls( - server_type=server_type, config=config_dict, ray_namespace=ray_namespace or config_dict.get('ray_namespace'), ) @@ -301,7 +251,6 @@ 
def from_yaml( def launch_server( config: dict[str, Any] | None = None, config_path: str | Path | None = None, - server_type: str = 'twinkle', ray_namespace: str | None = None, ) -> ServerLauncher: """ @@ -312,7 +261,6 @@ def launch_server( Args: config: Configuration dictionary (takes precedence over config_path) config_path: Path to YAML config file - server_type: Server type ('tinker' or 'twinkle'), default is 'twinkle' ray_namespace: Ray namespace Returns: @@ -322,15 +270,11 @@ def launch_server( ValueError: If neither config nor config_path is provided Examples: - # From YAML config (twinkle mode) + # From YAML config launch_server(config_path="server_config.yaml") - # From YAML config (tinker mode) - launch_server(config_path="server_config.yaml", server_type="tinker") - # From Python dict launch_server(config={ - "server_type": "tinker", "http_options": {"host": "0.0.0.0", "port": 8000}, "applications": [...] }) @@ -338,21 +282,14 @@ def launch_server( if config is None and config_path is None: raise ValueError("Either 'config' or 'config_path' must be provided") - launcher: ServerLauncher - if config is not None: - # From Python dict config - override with config's server_type if specified - final_server_type = config.get('server_type', server_type) launcher = ServerLauncher( - server_type=final_server_type, config=config, ray_namespace=ray_namespace or config.get('ray_namespace'), ) else: - # From YAML config file launcher = ServerLauncher.from_yaml( config_path=config_path, - server_type=server_type, ray_namespace=ray_namespace, ) diff --git a/src/twinkle/server/model/__init__.py b/src/twinkle/server/model/__init__.py new file mode 100644 index 00000000..1a203083 --- /dev/null +++ b/src/twinkle/server/model/__init__.py @@ -0,0 +1,3 @@ +from .app import build_model_app + +__all__ = ['build_model_app'] diff --git a/src/twinkle/server/model/app.py b/src/twinkle/server/model/app.py new file mode 100644 index 00000000..e5926f99 --- /dev/null +++ 
b/src/twinkle/server/model/app.py @@ -0,0 +1,159 @@ +# Copyright (c) ModelScope Contributors. All rights reserved. +""" +Unified model management application. + +Builds a single Ray Serve deployment (ModelManagement) that simultaneously handles +both Tinker (/tinker/*) and Twinkle (/twinkle/*) model endpoints. +""" +from fastapi import FastAPI, Request +from ray import serve +from ray.serve.config import RequestRouterConfig +from typing import Any, Dict, Optional + +import twinkle +from twinkle import DeviceGroup, DeviceMesh +from twinkle.server.utils.adapter_manager import AdapterManagerMixin +from twinkle.server.utils.state import ServerStateProxy, get_server_state +from twinkle.server.utils.task_queue import TaskQueueConfig, TaskQueueMixin +from twinkle.server.utils.validation import get_token_from_request, verify_request_token +from twinkle.utils.logger import get_logger +from ..common.router import StickyLoraRequestRouter +from ..utils import wrap_builder_with_device_group_env +from .tinker_handlers import TinkerModelHandlers +from .twinkle_handlers import TwinkleModelHandlers + +logger = get_logger() + + +def build_model_app(model_id: str, + nproc_per_node: int, + device_group: Dict[str, Any], + device_mesh: Dict[str, Any], + deploy_options: Dict[str, Any], + use_megatron: bool = False, + adapter_config: Dict[str, Any] = {}, + queue_config: Optional[Dict[str, Any]] = None, + **kwargs): + """Build a unified model management application for distributed training. + + Supports both Tinker (polling-style) and Twinkle (synchronous) clients. 
+ + Args: + model_id: Base model identifier (e.g., "Qwen/Qwen2.5-0.5B-Instruct") + nproc_per_node: Number of processes per node for distributed training + device_group: Device group configuration dict + device_mesh: Device mesh configuration dict for tensor parallelism + deploy_options: Ray Serve deployment options + use_megatron: Whether to use Megatron backend (vs Transformers) + adapter_config: Adapter lifecycle config (timeout, per-token limits) + queue_config: Task queue configuration (rate limiting, etc.) + **kwargs: Additional model initialization arguments + + Returns: + Configured Ray Serve deployment bound with parameters + """ + app = FastAPI() + # Mutable list so inner route functions can capture the model_id + model_id_ref = [model_id] + + @app.middleware('http') + async def verify_token(request: Request, call_next): + return await verify_request_token(request=request, call_next=call_next) + + @serve.deployment( + name='ModelManagement', + request_router_config=RequestRouterConfig(request_router_class=StickyLoraRequestRouter), + ) + @serve.ingress(app) + class ModelManagement(TaskQueueMixin, AdapterManagerMixin): + """Unified model management service. 
+ + Handles: + - Base model and multiple LoRA adapters (multi-user) + - Tinker training operations via /tinker/* endpoints (async/polling) + - Twinkle training operations via /twinkle/* endpoints (synchronous) + - Adapter lifecycle via AdapterManagerMixin + - Per-user rate limiting via TaskQueueMixin + """ + + def __init__(self, + nproc_per_node: int, + device_group: Dict[str, Any], + device_mesh: Dict[str, Any], + use_megatron: bool = False, + queue_config: Optional[Dict[str, Any]] = None, + **kwargs): + self.device_group = DeviceGroup(**device_group) + twinkle.initialize( + mode='ray', nproc_per_node=nproc_per_node, groups=[self.device_group], lazy_collect=False) + if 'mesh_dim_names' in device_mesh: + self.device_mesh = DeviceMesh(**device_mesh) + else: + self.device_mesh = DeviceMesh.from_sizes(**device_mesh) + self.use_megatron = use_megatron + self.replica_id = serve.get_replica_context().replica_id.unique_id + self.max_loras = kwargs.get('max_loras', 5) + + # Choose model backend + if use_megatron: + from ..model.backends.megatron_model import TwinkleCompatMegatronModel + self.model = TwinkleCompatMegatronModel( + model_id=model_id, + device_mesh=self.device_mesh, + remote_group=self.device_group.name, + instance_id=self.replica_id, + **kwargs) + else: + from ..model.backends.transformers_model import TwinkleCompatTransformersModel + self.model = TwinkleCompatTransformersModel( + model_id=model_id, + device_mesh=self.device_mesh, + remote_group=self.device_group.name, + instance_id=self.replica_id, + **kwargs) + + self.base_model = model_id + self.state: ServerStateProxy = get_server_state() + self.state.register_replica(self.replica_id, self.max_loras) + + # Initialize mixins + self._init_task_queue(TaskQueueConfig.from_dict(queue_config)) + self._init_adapter_manager(**adapter_config) + self.start_adapter_countdown() + + @serve.multiplexed(max_num_models_per_replica=kwargs.get('max_loras', 5)) + async def _sticky_entry(self, sticky_key: str): + return 
sticky_key + + async def _ensure_sticky(self): + sticky_key = serve.get_multiplexed_model_id() + await self._sticky_entry(sticky_key) + + async def _on_request_start(self, request: Request) -> str: + await self._ensure_sticky() + token = get_token_from_request(request) + return token + + def __del__(self): + self.state.unregister_replica(self.replica_id) + + def _cleanup_adapter(self, adapter_name: str) -> None: + if self.get_adapter_info(adapter_name): + self.clear_adapter_state(adapter_name) + self.model.remove_adapter(adapter_name) + self.unregister_adapter(adapter_name) + self.state.unload_model(adapter_name) + + def _on_adapter_expired(self, adapter_name: str) -> None: + self.fail_pending_tasks_for_model(adapter_name, reason='Adapter expired') + self._cleanup_adapter(adapter_name) + + # Register routes from both handler mixins + TinkerModelHandlers._register_tinker_routes(app, model_id_ref) + TwinkleModelHandlers._register_twinkle_routes(app, model_id_ref) + + return ModelManagement.options(**deploy_options).bind(nproc_per_node, device_group, device_mesh, use_megatron, + queue_config, **kwargs) + + +build_model_app = wrap_builder_with_device_group_env(build_model_app) diff --git a/src/twinkle/server/twinkle/common/__init__.py b/src/twinkle/server/model/backends/__init__.py similarity index 100% rename from src/twinkle/server/twinkle/common/__init__.py rename to src/twinkle/server/model/backends/__init__.py diff --git a/src/twinkle/server/tinker/common/megatron_model.py b/src/twinkle/server/model/backends/megatron_model.py similarity index 58% rename from src/twinkle/server/tinker/common/megatron_model.py rename to src/twinkle/server/model/backends/megatron_model.py index ebd4df76..61868247 100644 --- a/src/twinkle/server/tinker/common/megatron_model.py +++ b/src/twinkle/server/model/backends/megatron_model.py @@ -1,19 +1,21 @@ # Copyright (c) ModelScope Contributors. All rights reserved. - +""" +Megatron backend model for the unified model deployment. 
+Moved from tinker/common/megatron_model.py — imports updated. +""" import torch from tinker import types from typing import TYPE_CHECKING, Any, List, Optional, Tuple from twinkle import remote_class, remote_function +from twinkle.server.common.datum import datum_to_input_feature, extract_rl_feature +from twinkle.server.model.backends.transformers_model import (TwinkleCompatModelBase, clean_metrics, + collect_forward_backward_results) from twinkle.utils import exists, requires -from .compat_base import TwinkleCompatModelBase, clean_metrics, collect_forward_backward_results -from .datum import datum_to_input_feature, extract_rl_feature -from .io_utils import create_checkpoint_manager if TYPE_CHECKING: from twinkle.model.megatron import MultiLoraMegatronModel as _MegatronBase elif exists('megatron_core'): - # Use module-level import to trigger LazyModule's __getattr__ correctly import twinkle.model.megatron as megatron_module _MegatronBase = megatron_module.MultiLoraMegatronModel else: @@ -26,74 +28,37 @@ def __init__(self, *args, **kwargs): @remote_class(execute='all') class TwinkleCompatMegatronModel(_MegatronBase, TwinkleCompatModelBase): - """ - Compatibility wrapper around :class:`MultiLoraMegatronModel` for Twinkle/Tinker. - - This class adapts the core `MultiLoraMegatronModel` API to the data types and - remote-call semantics used by Twinkle: - - * Inputs to :meth:`forward_backward` and :meth:`forward_only` are provided as - ``List[types.Datum]`` and are converted to the underlying model's - ``InputFeature`` format via :func:`datum_to_input_feature`. - * The outputs are a list of dictionaries, one per input example, containing: + """Compatibility wrapper around MultiLoraMegatronModel for Twinkle/Tinker. - - ``"logprobs"``: token-level log-probabilities as ``types.TensorData``. - - ``"elementwise_loss"``: per-token (masked) NLL loss as ``types.TensorData``. 
- - These are derived from the underlying logits by applying ``log_softmax`` - and slicing to the label sequence length. - * :meth:`forward_backward` returns a tuple of (outputs, loss) where loss is a - Python scalar for the aggregated loss. - * :meth:`step` accepts optimizer hyperparameters as :class:`types.AdamParams`, - and updates the optimizer configuration before calling the base ``step``. - - Note: Megatron uses combined forward_backward instead of separate forward/backward. - This wrapper provides a direct forward_backward interface. + Moved from tinker/common/megatron_model.py — logic unchanged. """ @remote_function(dispatch='slice_dp', collect=collect_forward_backward_results, sync=True) def forward_backward(self, *, inputs: List[types.Datum], adapter_name: str, loss_fn: str, **kwargs): - """Combined forward and backward pass. - - Returns: - Tuple of (outputs, loss) where outputs is a list of dicts with - 'logprobs' and 'elementwise_loss', and loss is a scalar. - """ + """Combined forward and backward pass.""" if loss_fn == 'importance_sampling': - super().set_loss( - 'GRPOLoss', - adapter_name=adapter_name, - epsilon=0.2, # Default GRPO epsilon - beta=0.0) # No KL penalty by default - # Get template for input processing + super().set_loss('GRPOLoss', adapter_name=adapter_name, epsilon=0.2, beta=0.0) template = self.get_template(adapter_name=adapter_name) - # Convert Datum to InputFeature input_features = datum_to_input_feature(inputs, template) - # Extract old_logps and advantages using common utility loss_values = extract_rl_feature(inputs) loss_kwargs = kwargs.copy() loss_kwargs.update(loss_values) - # Megatron forward_backward returns loss directly outputs = super().forward_backward(inputs=input_features, adapter_name=adapter_name, **loss_kwargs) loss = outputs.get('loss', None) logits_list = outputs.get('logits', []) logps = outputs.get('logps', []) - # When PP enabled, only logits from last stage are available if logits_list is None and logps is 
None: return [None, None] logits = None if logits_list is not None: - # Process logits to match transformers output format if isinstance(logits_list, torch.Tensor): logits = logits_list.detach() else: - # Concatenate logits from multiple microbatches logits = torch.cat([logit.detach() for logit in logits_list], dim=0) logps = logps.detach().cpu() results = self._get_forward_output(inputs, logits, logps) - # Convert loss to scalar if isinstance(loss, torch.Tensor): loss = loss.item() else: @@ -104,14 +69,9 @@ def forward_backward(self, *, inputs: List[types.Datum], adapter_name: str, loss @remote_function(dispatch='slice_dp', collect='flatten') def forward_only(self, *, inputs: List[types.Datum], **kwargs): """Forward pass without gradient computation.""" - # Get template for input processing template = self.get_template(**kwargs) - # Convert Datum to InputFeature input_features = datum_to_input_feature(inputs, template) - outputs = super().forward_only(inputs=input_features, **kwargs) - - # Get logits logits = outputs.get('logits', None) logps = outputs.get('logps', None) @@ -122,23 +82,17 @@ def forward_only(self, *, inputs: List[types.Datum], **kwargs): logits = torch.cat([logit.detach().cpu() for logit in logits], dim=0) results = self._get_forward_output(inputs, logits, logps) else: - # If no logits available (non-last PP stage), return empty results results = [{'logprobs': None, 'elementwise_loss': None} for _ in inputs] return results @remote_function(dispatch='all') def step(self, *, adam_params: types.AdamParams, **kwargs): - """Optimizer step with AdamParams configuration. - - Updates the optimizer configuration and performs the step. 
- """ + """Optimizer step with AdamParams configuration.""" adapter_name = kwargs.get('adapter_name') optimizer_config = self.optimizer_group.get(adapter_name) if optimizer_config and optimizer_config.optimizer: - # Update optimizer config with adam_params - # Megatron optimizer handles gradient clipping internally opt = optimizer_config.optimizer if hasattr(opt, 'chained_optimizers'): for chained_opt in opt.chained_optimizers: @@ -151,9 +105,7 @@ def step(self, *, adam_params: types.AdamParams, **kwargs): if adam_params.grad_clip_norm > 0: chained_opt.config.clip_grad = adam_params.grad_clip_norm - # Perform optimizer step super().step(**kwargs) - # Zero gradients super().zero_grad(**kwargs) @remote_function(collect='first', lazy_collect=False) @@ -163,27 +115,14 @@ def calculate_metric(self, is_training, **kwargs): @remote_function(dispatch='all', sync=True) def load(self, checkpoint_dir: str, **kwargs): - """ - Load checkpoint with token-based isolation support. - - Args: - checkpoint_dir: The twinkle:// path to the checkpoint or hub model ID - **kwargs: Additional keyword arguments including optional 'token' - """ - # Extract token from kwargs if provided (for user isolation) + """Load checkpoint with token-based isolation support.""" token = kwargs.pop('token', None) if not token: raise ValueError('Token is required for loading checkpoints') - - # Create checkpoint manager with the token - checkpoint_manager = create_checkpoint_manager(token) - - # Use resolve_load_path to handle path resolution + from twinkle.server.common.io_utils import create_checkpoint_manager + checkpoint_manager = create_checkpoint_manager(token, client_type='tinker') resolved = checkpoint_manager.resolve_load_path(checkpoint_dir) - if resolved.is_twinkle_path: - # Load from twinkle checkpoint return super().load(name=resolved.checkpoint_name, output_dir=resolved.checkpoint_dir, **kwargs) else: - # Load from hub return super().load(name=resolved.checkpoint_name, **kwargs) diff --git 
a/src/twinkle/server/model/backends/transformers_model.py b/src/twinkle/server/model/backends/transformers_model.py new file mode 100644 index 00000000..2802741c --- /dev/null +++ b/src/twinkle/server/model/backends/transformers_model.py @@ -0,0 +1,267 @@ +# Copyright (c) ModelScope Contributors. All rights reserved. +""" +Backend model implementations for the unified model deployment. + +Contains two classes: +- TwinkleCompatTransformersModel: tinker-compat wrapper (Datum-based I/O), + moved from tinker/common/transformers_model.py. +- TwinkleCompatTransformersModelNative: twinkle-native wrapper + (InputFeature/Trajectory-based I/O), moved from twinkle/common/transformers_model.py. +""" +import numpy as np +import torch +from collections.abc import Mapping +from tinker import types +from typing import Any, List, Union + +# --------------------------------------------------------------------------- +# Shared helpers (moved from tinker/common/compat_base.py) +# --------------------------------------------------------------------------- +from twinkle import DeviceMesh, remote_class, remote_function +from twinkle.data_format import InputFeature, Trajectory +from twinkle.model import MultiLoraTransformersModel +from twinkle.server.common.datum import datum_to_input_feature, extract_rl_feature +from twinkle.template import Template + + +def collect_forward_backward_results(results, device_mesh: DeviceMesh): + """Custom collect function for forward_backward that handles list [outputs, loss].""" + if not results: + return results + + pp_last_ranks = None + if device_mesh.pp_world_size > 1: + pp_last_ranks = set(device_mesh.get_pp_last_ranks()) + + tp_last_ranks = None + if device_mesh.tp_world_size > 1: + tp_last_ranks = set(device_mesh.get_tp_last_ranks()) + + mesh_flat = device_mesh.mesh.flatten() + + all_outputs = [] + all_losses = [] + for i, result in enumerate(results): + rank = mesh_flat[i] if i < len(mesh_flat) else -1 + + if pp_last_ranks is not None: + if rank 
not in pp_last_ranks: + continue + + if tp_last_ranks is not None: + if rank not in tp_last_ranks: + continue + + if result is None: + continue + + outputs, loss = result + if outputs is None or loss is None: + continue + all_outputs.extend(outputs) + all_losses.append(loss) + + if all_losses: + avg_loss = float(np.mean(all_losses)) + else: + avg_loss = 0.0 + + return [all_outputs, avg_loss] + + +def clean_metrics(metrics: dict) -> dict: + import re + from numbers import Number + + def _to_float(v): + if isinstance(v, (float, int, Number, np.generic, str)): + try: + return float(v) + except Exception: + return None + if isinstance(v, torch.Tensor) and v.numel() == 1: + try: + return float(v.item()) + except Exception: + return None + return None + + cleaned = {} + for key, value in metrics.items(): + fv = _to_float(value) + if fv is not None: + cleaned[key] = fv + continue + + if isinstance(value, str): + s = value.strip() + if s: + try: + head, unit = s.split() + cleaned[f'{key}/{unit}'] = float(head) + except Exception: + m = re.match(r'^([+-]?(?:\d+(?:\.\d*)?|\.\d+)(?:[eE][+-]?\d+)?)', s) + if m: + cleaned[key] = float(m.group(1)) + + return cleaned + + +class TwinkleCompatModelBase: + """Base class containing common logic for Twinkle compatibility wrappers.""" + + def get_template(self, adapter_name: str) -> Template: + return self.optimizer_group[adapter_name].template + + @staticmethod + def _get_forward_output(inputs: List[types.Datum], logits: torch.Tensor, logps: torch.Tensor) -> List[dict]: + """Convert raw logits to the expected output format with logprobs and elementwise_loss.""" + from twinkle.utils.torch_utils import selective_log_softmax + device = logits.device if logits is not None else logps.device + results = [] + if logits is None: + logits = [None] * len(inputs) + for idx, (feature, logit) in enumerate(zip(inputs, logits)): + labels = feature.loss_fn_inputs['target_tokens'].to_torch().long().view(-1).to(device) + weights = 
feature.loss_fn_inputs['weights'].to_torch().view(-1).to(device) + + seq_len = labels.numel() + + if logps is None: + assert logits is not None + feature_logits = logit[:seq_len, :] + token_log_probs = selective_log_softmax(feature_logits, labels) + else: + token_log_probs = logps[idx, :seq_len] + + elementwise_loss = -token_log_probs * weights + + results.append({ + 'logprobs': types.TensorData.from_torch(token_log_probs.cpu()), + 'elementwise_loss': types.TensorData.from_torch(elementwise_loss.cpu()) + }) + return results + + +# --------------------------------------------------------------------------- +# Tinker-compat Transformers model (Datum-based I/O) +# --------------------------------------------------------------------------- + + +@remote_class() +class TwinkleCompatTransformersModel(MultiLoraTransformersModel, TwinkleCompatModelBase): + """Tinker-compatible wrapper around MultiLoraTransformersModel. + + Input/output is in tinker Datum / TensorData format. + Moved from tinker/common/transformers_model.py. 
+ """ + + @remote_function(dispatch='slice_dp', collect='flatten') + def forward_only(self, *, inputs: List[types.Datum], **kwargs): + template = self.get_template(**kwargs) + input_features = datum_to_input_feature(inputs, template) + outputs = super().forward_only(inputs=input_features, **kwargs) + logits = outputs['logits'].detach().cpu() + logps = outputs.get('logps', None) + if logps is not None: + logps = logps.detach().cpu() + results = self._get_forward_output(inputs, logits, logps) + return results + + @remote_function(dispatch='slice_dp', collect=collect_forward_backward_results) + def forward_backward(self, *, inputs: List[types.Datum], adapter_name: str, loss_fn: str, **kwargs): + if loss_fn == 'cross_entropy': + super().set_loss('CrossEntropyLoss', adapter_name=adapter_name) + elif loss_fn == 'importance_sampling': + super().set_loss('GRPOLoss', adapter_name=adapter_name, epsilon=0.2, beta=0.0) + else: + super().set_loss('CrossEntropyLoss', adapter_name=adapter_name) + template = self.get_template(adapter_name) + input_features = datum_to_input_feature(inputs, template) + outputs = super().forward(inputs=input_features, adapter_name=adapter_name, **kwargs) + loss_values = extract_rl_feature(inputs) + loss_kwargs = kwargs.copy() + loss_kwargs.update(loss_values) + loss = super().calculate_loss(adapter_name=adapter_name, **loss_kwargs) + super().backward(adapter_name=adapter_name, **kwargs) + logits = outputs['logits'].detach() + logps = outputs.get('logps', None) + if logps is not None: + logps = logps.detach().cpu() + results = self._get_forward_output(inputs, logits, logps) + return [results, loss] + + @remote_function() + def step(self, *, adam_params: types.AdamParams, **kwargs): + grad_clip_norm = adam_params.grad_clip_norm + if grad_clip_norm > 0.0: + self.clip_grad_norm(max_grad_norm=grad_clip_norm, norm_type=2, **kwargs) + optim_params = { + 'lr': adam_params.learning_rate, + 'eps': adam_params.eps, + 'betas': (adam_params.beta1, 
adam_params.beta2), + 'weight_decay': adam_params.weight_decay, + } + super().step(optim_params=optim_params, **kwargs) + super().zero_grad(**kwargs) + + @remote_function(collect='first', lazy_collect=False) + def calculate_metric(self, is_training, **kwargs): + metric = super().calculate_metric(is_training, **kwargs) + return clean_metrics(metric) + + @remote_function() + def load(self, checkpoint_dir: str, **kwargs): + """Load checkpoint with token-based isolation support.""" + token = kwargs.pop('token', None) + if not token: + raise ValueError('Token is required for loading checkpoints') + from twinkle.server.common.io_utils import create_checkpoint_manager + checkpoint_manager = create_checkpoint_manager(token, client_type='tinker') + resolved = checkpoint_manager.resolve_load_path(checkpoint_dir) + if resolved.is_twinkle_path: + return super().load(name=resolved.checkpoint_name, output_dir=resolved.checkpoint_dir, **kwargs) + else: + return super().load(name=resolved.checkpoint_name, **kwargs) + + +# --------------------------------------------------------------------------- +# Twinkle-native Transformers model (InputFeature/Trajectory-based I/O) +# --------------------------------------------------------------------------- + + +@remote_class() +class TwinkleCompatTransformersModelNative(MultiLoraTransformersModel): + """Twinkle-native wrapper around MultiLoraTransformersModel. + + Input/output is in native InputFeature / Trajectory format. + Moved from twinkle/common/transformers_model.py. 
+ """ + + @staticmethod + def _to_cpu_safe_output(obj: Any) -> Any: + """Convert nested outputs into CPU-safe Python objects for HTTP transport.""" + from twinkle.utils import torch_util + + if isinstance(obj, torch.Tensor): + tensor = torch_util.to_local_tensor(obj).detach().cpu() + if tensor.numel() == 1: + return tensor.item() + return tensor.tolist() + if isinstance(obj, np.ndarray): + if obj.size == 1: + return obj.item() + return obj.tolist() + if isinstance(obj, np.generic): + return obj.item() + if isinstance(obj, Mapping): + return {key: TwinkleCompatTransformersModelNative._to_cpu_safe_output(value) for key, value in obj.items()} + if isinstance(obj, (list, tuple)): + return [TwinkleCompatTransformersModelNative._to_cpu_safe_output(value) for value in obj] + return obj + + @remote_function(dispatch='slice_dp', collect='mean') + def forward_backward(self, *, inputs: Union[InputFeature, List[InputFeature], Trajectory, List[Trajectory]], + **kwargs): + output = super().forward_backward(inputs=inputs, **kwargs) + return self._to_cpu_safe_output(output) diff --git a/src/twinkle/server/model/tinker_handlers.py b/src/twinkle/server/model/tinker_handlers.py new file mode 100644 index 00000000..66280676 --- /dev/null +++ b/src/twinkle/server/model/tinker_handlers.py @@ -0,0 +1,298 @@ +# Copyright (c) ModelScope Contributors. All rights reserved. +""" +Tinker-compatible model handler mixin. + +All endpoints are prefixed /tinker/... and use schedule_task() returning UntypedAPIFuture. +""" +import traceback +from fastapi import FastAPI, Request +from peft import LoraConfig +from tinker import types +from typing import Any + +from twinkle.server.common.io_utils import create_checkpoint_manager, create_training_run_manager +from twinkle.server.utils.validation import get_token_from_request +from twinkle.utils.logger import get_logger + +logger = get_logger() + + +class TinkerModelHandlers: + """ + Mixin providing Tinker-compatible model management endpoints. 
+ + Expects the combined class to also inherit TaskQueueMixin and AdapterManagerMixin, + and to have: + self.model, self.state, self.device_mesh, self.base_model, self.replica_id + """ + + @staticmethod + def _register_tinker_routes(app: FastAPI, model_id_ref: list): + """Register all tinker routes on the given FastAPI app. + + This is called once during build_model_app to wire routes. + model_id_ref is a mutable list so we can capture the closure variable. + """ + + @app.post('/tinker/create_model') + async def create_model(self, request: Request, body: types.CreateModelRequest) -> types.UntypedAPIFuture: + token = await self._on_request_start(request) + model_id = model_id_ref[0] + + async def _create_adapter(): + _model_id = None + try: + _model_id = self.state.register_model(body.model_dump(), token=token, replica_id=self.replica_id) + if body.lora_config: + lora_cfg = LoraConfig(r=body.lora_config.rank, target_modules='all-linear') + adapter_name = self.get_adapter_name(adapter_name=_model_id) + self.register_adapter(adapter_name, token, session_id=body.session_id) + self.model.add_adapter_to_model(adapter_name=adapter_name, config_or_dir=lora_cfg) + self.model.set_template('Template', adapter_name=adapter_name, model_id=model_id) + self.model.set_processor('InputProcessor', adapter_name=adapter_name) + self.model.set_optimizer('Adam', adapter_name=adapter_name) + self.set_adapter_state(adapter_name, 'grad_ready', False) + training_run_manager = create_training_run_manager(token, client_type='tinker') + training_run_manager.save(_model_id, body) + return types.CreateModelResponse(model_id=_model_id) + except Exception: + if _model_id: + adapter_name = self.get_adapter_name(adapter_name=_model_id) + self._cleanup_adapter(adapter_name) + logger.error(traceback.format_exc()) + return types.RequestFailedResponse( + error=traceback.format_exc(), + category=types.RequestErrorCategory.Server, + ) + + return await self.schedule_task(_create_adapter, token=token, 
task_type='create_model') + + @app.post('/tinker/get_info') + async def get_info(self, request: Request, body: types.GetInfoRequest) -> types.GetInfoResponse: + token = await self._on_request_start(request) + model_id = model_id_ref[0] + training_run_manager = create_training_run_manager(token, client_type='tinker') + metadata = training_run_manager.get(str(body.model_id)) + model_name = metadata.base_model if metadata else model_id + lora_rank = None + is_lora = False + if metadata and hasattr(metadata, 'lora_rank') and metadata.lora_rank: + lora_rank = metadata.lora_rank + is_lora = metadata.is_lora + return types.GetInfoResponse( + model_data=types.ModelData(model_name=model_name), + model_id=body.model_id, + is_lora=is_lora, + lora_rank=lora_rank, + model_name=model_name, + ) + + @app.post('/tinker/unload_model') + async def unload_model(self, request: Request, body: types.UnloadModelRequest) -> types.UntypedAPIFuture: + token = await self._on_request_start(request) + + async def _do_unload(): + adapter_name = self.get_adapter_name(adapter_name=body.model_id) + self._cleanup_adapter(adapter_name) + return types.UnloadModelResponse(model_id=body.model_id) + + return await self.schedule_task(_do_unload, model_id=body.model_id, token=token, task_type='unload_model') + + @app.post('/tinker/forward') + async def forward(self, request: Request, body: types.ForwardRequest) -> types.UntypedAPIFuture: + token = await self._on_request_start(request) + + async def _do_forward(): + try: + adapter_name = self.get_adapter_name(adapter_name=body.model_id) + self.assert_adapter_exists(adapter_name=adapter_name) + self.touch_adapter(adapter_name) + datum_list = body.forward_input.data + loss_fn_config = body.forward_input.loss_fn_config or {} + output = self.model.forward_only(inputs=datum_list, adapter_name=adapter_name) + loss = self.model.calculate_loss(adapter_name=adapter_name, **loss_fn_config) + return types.ForwardBackwardOutput( + 
loss_fn_output_type='CrossEntropyLossReturn', + loss_fn_outputs=output, + metrics={'loss:sum': loss}, + ) + except Exception: + logger.error(traceback.format_exc()) + return types.RequestFailedResponse( + error=traceback.format_exc(), + category=types.RequestErrorCategory.Server, + ) + + datum_list = body.forward_input.data + input_tokens = sum(len(d.model_input.to_ints()) for d in datum_list) + batch_size = len(datum_list) + return await self.schedule_task( + _do_forward, + model_id=body.model_id, + token=token, + input_tokens=input_tokens, + batch_size=batch_size, + data_world_size=self.device_mesh.data_world_size, + task_type='forward', + ) + + @app.post('/tinker/forward_backward') + async def forward_backward(self, request: Request, + body: types.ForwardBackwardRequest) -> types.UntypedAPIFuture: + token = await self._on_request_start(request) + + async def _do_forward_backward(): + try: + adapter_name = self.get_adapter_name(adapter_name=body.model_id) + self.assert_adapter_exists(adapter_name=adapter_name) + self.touch_adapter(adapter_name) + datum_list = body.forward_backward_input.data + loss_fn = body.forward_backward_input.loss_fn + loss_fn_config = body.forward_backward_input.loss_fn_config or {} + output, loss = self.model.forward_backward( + inputs=datum_list, adapter_name=adapter_name, loss_fn=loss_fn, **loss_fn_config) + if loss_fn == 'importance_sampling': + output_type = 'ImportanceSamplingLossReturn' + else: + output_type = 'CrossEntropyLossReturn' + self.set_adapter_state(adapter_name, 'grad_ready', True) + return types.ForwardBackwardOutput( + loss_fn_output_type=output_type, + loss_fn_outputs=output, + metrics={'loss:avg': loss}, + ) + except Exception: + logger.error(traceback.format_exc()) + return types.RequestFailedResponse( + error=traceback.format_exc(), + category=types.RequestErrorCategory.Server, + ) + + datum_list = body.forward_backward_input.data + input_tokens = sum(len(d.model_input.to_ints()) for d in datum_list) + batch_size = 
len(datum_list) + return await self.schedule_task( + _do_forward_backward, + model_id=body.model_id, + token=token, + input_tokens=input_tokens, + batch_size=batch_size, + data_world_size=self.device_mesh.data_world_size, + task_type='forward_backward', + ) + + @app.post('/tinker/optim_step') + async def optim_step(self, request: Request, body: types.OptimStepRequest) -> types.UntypedAPIFuture: + token = await self._on_request_start(request) + + async def _do_optim(): + try: + adapter_name = self.get_adapter_name(adapter_name=body.model_id) + self.assert_adapter_exists(adapter_name=adapter_name) + if not self.get_adapter_state(adapter_name, 'grad_ready', False): + raise RuntimeError( + f'No accumulated gradients for adapter={adapter_name}; call forward_backward before optim_step' # noqa: E501 + ) + self.touch_adapter(adapter_name) + self.model.step(adam_params=body.adam_params, adapter_name=adapter_name) + self.set_adapter_state(adapter_name, 'grad_ready', False) + metrics = self.model.calculate_metric(is_training=True, adapter_name=adapter_name) + return types.OptimStepResponse(metrics=metrics) + except Exception: + logger.error(traceback.format_exc()) + return types.RequestFailedResponse( + error=traceback.format_exc(), + category=types.RequestErrorCategory.Server, + ) + + return await self.schedule_task(_do_optim, model_id=body.model_id, token=token, task_type='optim_step') + + @app.post('/tinker/save_weights') + async def save_weights(self, request: Request, body: types.SaveWeightsRequest) -> types.UntypedAPIFuture: + token = await self._on_request_start(request) + + async def _do_save(): + try: + adapter_name = self.get_adapter_name(adapter_name=body.model_id) + self.assert_adapter_exists(adapter_name=adapter_name) + self.touch_adapter(adapter_name) + checkpoint_manager = create_checkpoint_manager(token, client_type='tinker') + checkpoint_name = checkpoint_manager.get_ckpt_name(body.path) + save_dir = checkpoint_manager.get_save_dir(model_id=body.model_id, 
is_sampler=False) + self.model.save( + name=checkpoint_name, output_dir=save_dir, adapter_name=adapter_name, save_optimizer=True) + tinker_path = checkpoint_manager.save(body.model_id, name=checkpoint_name, is_sampler=False) + return types.SaveWeightsResponse(path=tinker_path, type='save_weights') + except Exception: + logger.error(traceback.format_exc()) + return types.RequestFailedResponse( + error=traceback.format_exc(), + category=types.RequestErrorCategory.Server, + ) + + return await self.schedule_task(_do_save, model_id=body.model_id, token=token, task_type='save_weights') + + @app.post('/tinker/save_weights_for_sampler') + async def save_weights_for_sampler(self, request: Request, + body: types.SaveWeightsForSamplerRequest) -> types.UntypedAPIFuture: + token = await self._on_request_start(request) + + async def _do_save_for_sampler(): + try: + adapter_name = self.get_adapter_name(adapter_name=body.model_id) + self.assert_adapter_exists(adapter_name=adapter_name) + self.touch_adapter(adapter_name) + checkpoint_manager = create_checkpoint_manager(token, client_type='tinker') + checkpoint_name = checkpoint_manager.get_ckpt_name(body.path) + save_dir = checkpoint_manager.get_save_dir(model_id=body.model_id, is_sampler=True) + tinker_path = checkpoint_manager.save(body.model_id, name=checkpoint_name, is_sampler=True) + logger.info(f'Saving weights to {save_dir}') + self.model.save( + name=checkpoint_name, output_dir=save_dir, adapter_name=adapter_name, save_optimizer=False) + payload = body.model_dump() + payload['model_path'] = tinker_path + metadata = self.state.get_model_metadata(body.model_id) or {} + if metadata.get('base_model'): + payload['base_model'] = metadata['base_model'] + sampling_session_id = self.state.create_sampling_session(payload) + return types.SaveWeightsForSamplerResponseInternal( + path=None, sampling_session_id=sampling_session_id) + except Exception: + logger.error(traceback.format_exc()) + return types.RequestFailedResponse( + 
error=traceback.format_exc(), + category=types.RequestErrorCategory.Server, + ) + + return await self.schedule_task( + _do_save_for_sampler, model_id=body.model_id, token=token, task_type='save_weights_for_sampler') + + @app.post('/tinker/load_weights') + async def load_weights(self, request: Request, body: types.LoadWeightsRequest) -> types.UntypedAPIFuture: + token = await self._on_request_start(request) + + async def _do_load(): + try: + assert self.model is not None, 'Model not loaded, please load model first' + adapter_name = self.get_adapter_name(adapter_name=body.model_id) + self.assert_adapter_exists(adapter_name=adapter_name) + self.touch_adapter(adapter_name) + self.model.load( + checkpoint_dir=body.path, load_optimizer=body.optimizer, adapter_name=adapter_name, token=token) + self.set_adapter_state(adapter_name, 'grad_ready', False) + return types.LoadWeightsResponse(path=body.path, type='load_weights') + except Exception: + logger.error(traceback.format_exc()) + return types.RequestFailedResponse( + error=traceback.format_exc(), + category=types.RequestErrorCategory.Server, + ) + + return await self.schedule_task(_do_load, model_id=body.model_id, token=token, task_type='load_weights') + + # Tinker uses {request_id}-{adapter_name} prefix via self.get_adapter_name() + # which is inherited from AdapterManagerMixin (no-op here; method kept for clarity). + @staticmethod + def get_adapter_name(adapter_name: Any) -> Any: + """Returns adapter_name as-is; overridden by AdapterManagerMixin in the combined class.""" + return adapter_name diff --git a/src/twinkle/server/model/twinkle_handlers.py b/src/twinkle/server/model/twinkle_handlers.py new file mode 100644 index 00000000..e6b2243b --- /dev/null +++ b/src/twinkle/server/model/twinkle_handlers.py @@ -0,0 +1,377 @@ +# Copyright (c) ModelScope Contributors. All rights reserved. +""" +Twinkle-native model handler mixin. + +All endpoints are prefixed /twinkle/... 
and use schedule_task_and_wait() returning +results directly (synchronous from the client's perspective). +""" +import traceback +from fastapi import FastAPI, Request +from peft import LoraConfig +from typing import Any, Optional + +from twinkle.data_format import InputFeature, Trajectory +from twinkle.server.common.io_utils import create_checkpoint_manager, create_training_run_manager +from twinkle.server.common.serialize import deserialize_object +from twinkle.utils.logger import get_logger +from twinkle_client.types.model import (AdapterRequest, AddAdapterRequest, CalculateMetricRequest, CreateRequest, + ForwardOnlyRequest, ForwardRequest, GetStateDictRequest, HeartbeatRequest, + LoadRequest, SaveRequest, SetLossRequest, SetLrSchedulerRequest, + SetOptimizerRequest, SetProcessorRequest, SetTemplateRequest, + UploadToHubRequest) + +logger = get_logger() + + +def _parse_inputs(inputs: Any): + """Convert raw dict/list inputs to InputFeature or Trajectory objects.""" + if isinstance(inputs, list) and inputs: + first = inputs[0] + if isinstance(first, dict) and 'input_ids' in first: + return [InputFeature(**item) for item in inputs] + else: + return [Trajectory(**item) for item in inputs] + elif isinstance(inputs, dict): + if 'input_ids' in inputs: + return [InputFeature(**inputs)] + else: + return [Trajectory(**inputs)] + return inputs + + +class TwinkleModelHandlers: + """ + Mixin providing Twinkle-native model management endpoints. + + Expects the combined class to also inherit TaskQueueMixin and AdapterManagerMixin, + and to have: self.model, self.state, self.base_model + The get_adapter_name static method uses request.state.request_id prefix. + """ + + @staticmethod + def _register_twinkle_routes(app: FastAPI, model_id_ref: list): + """Register all twinkle routes on the given FastAPI app. + + model_id_ref is a mutable list containing [model_id] for closure capture. 
+ """ + + @app.post('/twinkle/create') + async def create(self, request: Request, body: CreateRequest): + return {'status': 'ok'} + + @staticmethod + def _get_twinkle_adapter_name(request: Request, adapter_name: Optional[str]) -> Optional[str]: + if adapter_name is None or adapter_name == '': + return None + return request.state.request_id + '-' + adapter_name + + @app.post('/twinkle/forward') + async def forward(self, request: Request, body: ForwardRequest): + adapter_name = self._get_twinkle_adapter_name(request, body.adapter_name) + + async def _task(): + self.assert_adapter_exists(adapter_name=adapter_name) + extra_kwargs = body.model_extra or {} + inputs = _parse_inputs(body.inputs) + ret = self.model.forward(inputs=inputs, adapter_name=adapter_name, **extra_kwargs) + return {'result': ret} + + return await self.schedule_task_and_wait(_task, task_type='forward') + + @app.post('/twinkle/forward_only') + async def forward_only(self, request: Request, body: ForwardOnlyRequest): + adapter_name = self._get_twinkle_adapter_name(request, body.adapter_name) + + async def _task(): + self.assert_adapter_exists(adapter_name=adapter_name) + extra_kwargs = body.model_extra or {} + inputs = _parse_inputs(body.inputs) + ret = self.model.forward_only(inputs=inputs, adapter_name=adapter_name, **extra_kwargs) + return {'result': ret} + + return await self.schedule_task_and_wait(_task, task_type='forward_only') + + @app.post('/twinkle/calculate_loss') + async def calculate_loss(self, request: Request, body: AdapterRequest): + adapter_name = self._get_twinkle_adapter_name(request, body.adapter_name) + + async def _task(): + self.assert_adapter_exists(adapter_name=adapter_name) + extra_kwargs = body.model_extra or {} + ret = self.model.calculate_loss(adapter_name=adapter_name, **extra_kwargs) + return {'result': ret} + + return await self.schedule_task_and_wait(_task, task_type='calculate_loss') + + @app.post('/twinkle/backward') + async def backward(self, request: Request, body: 
AdapterRequest): + adapter_name = self._get_twinkle_adapter_name(request, body.adapter_name) + + async def _task(): + self.assert_adapter_exists(adapter_name=adapter_name) + extra_kwargs = body.model_extra or {} + ret = self.model.backward(adapter_name=adapter_name, **extra_kwargs) + return {'result': ret} + + return await self.schedule_task_and_wait(_task, task_type='backward') + + @app.post('/twinkle/forward_backward') + async def forward_backward(self, request: Request, body: ForwardRequest): + adapter_name = self._get_twinkle_adapter_name(request, body.adapter_name) + + async def _task(): + self.assert_adapter_exists(adapter_name=adapter_name) + extra_kwargs = body.model_extra or {} + inputs = _parse_inputs(body.inputs) + ret = self.model.forward_backward(inputs=inputs, adapter_name=adapter_name, **extra_kwargs) + return {'result': ret} + + return await self.schedule_task_and_wait(_task, task_type='forward_backward') + + @app.post('/twinkle/clip_grad_norm') + async def clip_grad_norm(self, request: Request, body: AdapterRequest): + adapter_name = self._get_twinkle_adapter_name(request, body.adapter_name) + + async def _task(): + self.assert_adapter_exists(adapter_name=adapter_name) + extra_kwargs = body.model_extra or {} + ret = self.model.clip_grad_norm(adapter_name=adapter_name, **extra_kwargs) + return {'result': str(ret)} + + return await self.schedule_task_and_wait(_task, task_type='clip_grad_norm') + + @app.post('/twinkle/step') + async def step(self, request: Request, body: AdapterRequest): + adapter_name = self._get_twinkle_adapter_name(request, body.adapter_name) + + async def _task(): + self.assert_adapter_exists(adapter_name=adapter_name) + extra_kwargs = body.model_extra or {} + ret = self.model.step(adapter_name=adapter_name, **extra_kwargs) + return {'result': ret} + + return await self.schedule_task_and_wait(_task, task_type='step') + + @app.post('/twinkle/zero_grad') + async def zero_grad(self, request: Request, body: AdapterRequest): + 
adapter_name = self._get_twinkle_adapter_name(request, body.adapter_name) + + async def _task(): + self.assert_adapter_exists(adapter_name=adapter_name) + extra_kwargs = body.model_extra or {} + ret = self.model.zero_grad(adapter_name=adapter_name, **extra_kwargs) + return {'result': ret} + + return await self.schedule_task_and_wait(_task, task_type='zero_grad') + + @app.post('/twinkle/lr_step') + async def lr_step(self, request: Request, body: AdapterRequest): + adapter_name = self._get_twinkle_adapter_name(request, body.adapter_name) + + async def _task(): + self.assert_adapter_exists(adapter_name=adapter_name) + extra_kwargs = body.model_extra or {} + ret = self.model.lr_step(adapter_name=adapter_name, **extra_kwargs) + return {'result': ret} + + return await self.schedule_task_and_wait(_task, task_type='lr_step') + + @app.post('/twinkle/get_train_configs') + async def get_train_configs(self, request: Request, body: AdapterRequest): + adapter_name = self._get_twinkle_adapter_name(request, body.adapter_name) + + async def _task(): + self.assert_adapter_exists(adapter_name=adapter_name) + extra_kwargs = body.model_extra or {} + ret = self.model.get_train_configs(adapter_name=adapter_name, **extra_kwargs) + return {'result': ret} + + return await self.schedule_task_and_wait(_task, task_type='get_train_configs') + + @app.post('/twinkle/set_loss') + async def set_loss(self, request: Request, body: SetLossRequest): + adapter_name = self._get_twinkle_adapter_name(request, body.adapter_name) + + async def _task(): + self.assert_adapter_exists(adapter_name=adapter_name) + extra_kwargs = body.model_extra or {} + ret = self.model.set_loss(body.loss_cls, adapter_name=adapter_name, **extra_kwargs) + return {'result': ret} + + return await self.schedule_task_and_wait(_task, task_type='set_loss') + + @app.post('/twinkle/set_optimizer') + async def set_optimizer(self, request: Request, body: SetOptimizerRequest): + adapter_name = self._get_twinkle_adapter_name(request, 
body.adapter_name) + + async def _task(): + self.assert_adapter_exists(adapter_name=adapter_name) + extra_kwargs = body.model_extra or {} + ret = self.model.set_optimizer(body.optimizer_cls, adapter_name=adapter_name, **extra_kwargs) + return {'result': ret} + + return await self.schedule_task_and_wait(_task, task_type='set_optimizer') + + @app.post('/twinkle/set_lr_scheduler') + async def set_lr_scheduler(self, request: Request, body: SetLrSchedulerRequest): + adapter_name = self._get_twinkle_adapter_name(request, body.adapter_name) + + async def _task(): + self.assert_adapter_exists(adapter_name=adapter_name) + extra_kwargs = body.model_extra or {} + ret = self.model.set_lr_scheduler(body.scheduler_cls, adapter_name=adapter_name, **extra_kwargs) + return {'result': ret} + + return await self.schedule_task_and_wait(_task, task_type='set_lr_scheduler') + + @app.post('/twinkle/save') + async def save(self, request: Request, body: SaveRequest): + adapter_name = self._get_twinkle_adapter_name(request, body.adapter_name) + + async def _task(): + self.assert_adapter_exists(adapter_name=adapter_name) + extra_kwargs = body.model_extra or {} + token = request.state.token + checkpoint_manager = create_checkpoint_manager(token, client_type='twinkle') + checkpoint_name = checkpoint_manager.get_ckpt_name(body.name) + save_dir = checkpoint_manager.get_save_dir(model_id=adapter_name, is_sampler=False) + checkpoint_dir = self.model.save( + name=checkpoint_name, + output_dir=save_dir, + adapter_name=adapter_name, + save_optimizer=body.save_optimizer, + **extra_kwargs) + twinkle_path = checkpoint_manager.save(model_id=adapter_name, name=checkpoint_name, is_sampler=False) + return {'result': twinkle_path, 'checkpoint_dir': checkpoint_dir} + + return await self.schedule_task_and_wait(_task, task_type='save') + + @app.post('/twinkle/load') + async def load(self, request: Request, body: LoadRequest): + adapter_name = self._get_twinkle_adapter_name(request, body.adapter_name) + + async 
def _task(): + self.assert_adapter_exists(adapter_name=adapter_name) + extra_kwargs = body.model_extra or {} + token = request.state.token + checkpoint_manager = create_checkpoint_manager(token, client_type='twinkle') + resolved = checkpoint_manager.resolve_load_path(body.name) + ret = self.model.load( + name=resolved.checkpoint_name, + output_dir=resolved.checkpoint_dir, + adapter_name=adapter_name, + load_optimizer=body.load_optimizer, + token=token, + **extra_kwargs) + return {'result': ret} + + return await self.schedule_task_and_wait(_task, task_type='load') + + @app.post('/twinkle/upload_to_hub') + async def upload_to_hub(self, request: Request, body: UploadToHubRequest): + + async def _task(): + token = request.state.token + if body.checkpoint_dir.startswith('twinkle://'): + checkpoint_manager = create_checkpoint_manager(token, client_type='twinkle') + parsed = checkpoint_manager.parse_twinkle_path(body.checkpoint_dir) + if not parsed: + raise ValueError(f'Invalid twinkle path format: {body.checkpoint_dir}') + checkpoint_id = parsed.checkpoint_id + model_id_to_load = parsed.training_run_id + checkpoint = checkpoint_manager.get(model_id_to_load, checkpoint_id) + if not checkpoint: + raise ValueError(f'Checkpoint not found or access denied: {body.checkpoint_dir}') + checkpoint_dir = str( + checkpoint_manager.get_ckpt_dir(model_id=model_id_to_load, checkpoint_id=checkpoint_id)) + else: + checkpoint_dir = body.checkpoint_dir + self.model.upload_to_hub( + checkpoint_dir=checkpoint_dir, + hub_model_id=body.hub_model_id, + hub_token=body.hub_token or token, + async_upload=body.async_upload) + return {'result': body.hub_model_id} + + return await self.schedule_task_and_wait(_task, task_type='upload_to_hub') + + @app.post('/twinkle/add_adapter_to_model') + async def add_adapter_to_model(self, request: Request, body: AddAdapterRequest): + assert body.adapter_name, 'You need to specify a valid `adapter_name`' + adapter_name = self._get_twinkle_adapter_name(request, 
body.adapter_name) + model_id = model_id_ref[0] + + async def _task(): + config = deserialize_object(body.config) + extra_kwargs = body.model_extra or {} + token = request.state.token + training_run_manager = create_training_run_manager(token, client_type='twinkle') + self.register_adapter(adapter_name, token) + self.model.add_adapter_to_model(adapter_name, config, **extra_kwargs) + from twinkle.server.common.io_utils import CreateModelRequest + from twinkle.server.common.io_utils import LoraConfig as IoLoraConfig + lora_config = None + if isinstance(config, LoraConfig): + lora_config = IoLoraConfig(rank=config.r, train_unembed=False, train_mlp=True, train_attn=True) + run_config = CreateModelRequest( + base_model=model_id, lora_config=lora_config, user_metadata={'adapter_name': body.adapter_name}) + training_run_manager.save(adapter_name, run_config) + return {'status': 'ok', 'adapter_name': adapter_name} + + return await self.schedule_task_and_wait(_task, task_type='add_adapter_to_model') + + @app.post('/twinkle/set_template') + async def set_template(self, request: Request, body: SetTemplateRequest): + adapter_name = self._get_twinkle_adapter_name(request, body.adapter_name) + + async def _task(): + self.assert_adapter_exists(adapter_name=adapter_name) + extra_kwargs = body.model_extra or {} + ret = self.model.set_template(body.template_cls, adapter_name=adapter_name, **extra_kwargs) + return {'result': ret} + + return await self.schedule_task_and_wait(_task, task_type='set_template') + + @app.post('/twinkle/set_processor') + async def set_processor(self, request: Request, body: SetProcessorRequest): + adapter_name = self._get_twinkle_adapter_name(request, body.adapter_name) + + async def _task(): + self.assert_adapter_exists(adapter_name=adapter_name) + extra_kwargs = body.model_extra or {} + ret = self.model.set_processor(body.processor_cls, adapter_name=adapter_name, **extra_kwargs) + return {'result': ret} + + return await self.schedule_task_and_wait(_task, 
task_type='set_processor') + + @app.post('/twinkle/heartbeat') + async def heartbeat(self, request: Request, body: HeartbeatRequest): + adapter_name = self._get_twinkle_adapter_name(request, body.adapter_name) + self.assert_adapter_exists(adapter_name=adapter_name) + self.touch_adapter(adapter_name) + return {'status': 'ok'} + + @app.post('/twinkle/calculate_metric') + async def calculate_metric(self, request: Request, body: CalculateMetricRequest): + adapter_name = self._get_twinkle_adapter_name(request, body.adapter_name) + + async def _task(): + self.assert_adapter_exists(adapter_name=adapter_name) + extra_kwargs = body.model_extra or {} + ret = self.model.calculate_metric( + is_training=body.is_training, adapter_name=adapter_name, **extra_kwargs) + return {'result': ret} + + return await self.schedule_task_and_wait(_task, task_type='calculate_metric') + + @app.post('/twinkle/get_state_dict') + async def get_state_dict(self, request: Request, body: GetStateDictRequest): + adapter_name = self._get_twinkle_adapter_name(request, body.adapter_name) + + async def _task(): + self.assert_adapter_exists(adapter_name=adapter_name) + extra_kwargs = body.model_extra or {} + ret = self.model.get_state_dict(adapter_name=adapter_name, **extra_kwargs) + return {'result': ret} + + return await self.schedule_task_and_wait(_task, task_type='get_state_dict') diff --git a/src/twinkle/server/processor/__init__.py b/src/twinkle/server/processor/__init__.py new file mode 100644 index 00000000..4032f5bf --- /dev/null +++ b/src/twinkle/server/processor/__init__.py @@ -0,0 +1,3 @@ +from .app import build_processor_app + +__all__ = ['build_processor_app'] diff --git a/src/twinkle/server/twinkle/processor.py b/src/twinkle/server/processor/app.py similarity index 74% rename from src/twinkle/server/twinkle/processor.py rename to src/twinkle/server/processor/app.py index cbead9b7..68cddc20 100644 --- a/src/twinkle/server/twinkle/processor.py +++ b/src/twinkle/server/processor/app.py @@ -1,44 
+1,43 @@ # Copyright (c) ModelScope Contributors. All rights reserved. +""" +Processor management application (moved from twinkle/processor.py). + +Provides a Ray Serve deployment for managing distributed processors +(datasets, dataloaders, preprocessors, rewards, templates, weight loaders, etc.). +""" import importlib import os import threading import uuid from fastapi import FastAPI, HTTPException, Request -from pydantic import BaseModel from ray import serve from typing import Any, Dict import twinkle from twinkle import DeviceGroup, DeviceMesh, get_logger +from twinkle.server.common.serialize import deserialize_object from twinkle.server.utils.state import ServerStateProxy, get_server_state from twinkle.server.utils.validation import verify_request_token -from .common.serialize import deserialize_object +from twinkle_client.types.processor import ProcessorCallRequest, ProcessorCreateRequest, ProcessorHeartbeatRequest logger = get_logger() -class CreateRequest(BaseModel): - processor_type: str - class_type: str - - class Config: - extra = 'allow' - - -class HeartbeatRequest(BaseModel): - processor_id: str - - -class CallRequest(BaseModel): - processor_id: str - function: str - - class Config: - extra = 'allow' - - def build_processor_app(nproc_per_node: int, ncpu_proc_per_node: int, device_group: Dict[str, Any], device_mesh: Dict[str, Any], deploy_options: Dict[str, Any], **kwargs): + """Build the processor management application. 
+ + Args: + nproc_per_node: Number of GPU processes per node + ncpu_proc_per_node: Number of CPU processes per node + device_group: Device group configuration dict + device_mesh: Device mesh configuration dict + deploy_options: Ray Serve deployment options + **kwargs: Additional arguments + + Returns: + Ray Serve deployment bound with configuration + """ app = FastAPI() @app.middleware('http') @@ -50,6 +49,11 @@ async def verify_token(request: Request, call_next): @serve.deployment(name='ProcessorManagement') @serve.ingress(app) class ProcessorManagement: + """Processor management service. + + Manages lifecycle and invocation of distributed processor objects + (datasets, dataloaders, rewards, templates, etc.). + """ COUNT_DOWN = 60 * 30 @@ -105,11 +109,10 @@ def handle_processor_count(self, token: str, add: bool): self.state.pop_config(user_key) @app.post('/create') - def create(self, request: Request, body: CreateRequest): - + def create(self, request: Request, body: ProcessorCreateRequest): processor_type_name = body.processor_type class_type = body.class_type - kwargs = body.model_extra or {} + _kwargs = body.model_extra or {} assert processor_type_name in processors, f'Invalid processor type: {processor_type_name}' processor_module = importlib.import_module(f'twinkle.{processor_type_name}') @@ -118,26 +121,29 @@ def create(self, request: Request, body: CreateRequest): processor_id = str(uuid.uuid4().hex) self.key_token_dict[processor_id] = request.state.token - kwargs.pop('remote_group', None) - kwargs.pop('device_mesh', None) + _kwargs.pop('remote_group', None) + _kwargs.pop('device_mesh', None) - _kwargs = {} - for key, value in kwargs.items(): + resolved_kwargs = {} + for key, value in _kwargs.items(): if isinstance(value, str) and value.startswith('pid:'): ref_id = value[4:] - _kwargs[key] = self.resource_dict[ref_id] + resolved_kwargs[key] = self.resource_dict[ref_id] else: value = deserialize_object(value) - _kwargs[key] = value + resolved_kwargs[key] = 
value processor = getattr(processor_module, class_type)( - remote_group=self.device_group.name, device_mesh=self.device_mesh, instance_id=processor_id, **_kwargs) + remote_group=self.device_group.name, + device_mesh=self.device_mesh, + instance_id=processor_id, + **resolved_kwargs) self.resource_dict[processor_id] = processor self.resource_records[processor_id] = 0 return {'processor_id': 'pid:' + processor_id} @app.post('/heartbeat') - def heartbeat(self, body: HeartbeatRequest): + def heartbeat(self, body: ProcessorHeartbeatRequest): processor_ids = body.processor_id.split(',') for _id in processor_ids: if _id and _id in self.resource_dict: @@ -145,10 +151,10 @@ def heartbeat(self, body: HeartbeatRequest): return {'status': 'ok'} @app.post('/call') - def call(self, body: CallRequest): + def call(self, body: ProcessorCallRequest): processor_id = body.processor_id function_name = body.function - kwargs = body.model_extra or {} + _kwargs = body.model_extra or {} processor_id = processor_id[4:] self.assert_processor_exists(processor_id=processor_id) processor = self.resource_dict.get(processor_id) @@ -157,28 +163,25 @@ def call(self, body: CallRequest): assert function is not None, f'`{function_name}` not found in {processor.__class__}' assert hasattr(function, '_execute'), f'Cannot call inner method of {processor.__class__}' - _kwargs = {} - for key, value in kwargs.items(): + resolved_kwargs = {} + for key, value in _kwargs.items(): if isinstance(value, str) and value.startswith('pid:'): ref_id = value[4:] - _kwargs[key] = self.resource_dict[ref_id] + resolved_kwargs[key] = self.resource_dict[ref_id] else: value = deserialize_object(value) - _kwargs[key] = value + resolved_kwargs[key] = value # Special handling for __next__ to catch StopIteration - # We convert StopIteration to HTTP 410 (Gone) which semantically means - # "the resource (next item) is no longer available" if function_name == '__next__': try: - result = function(**_kwargs) + result = 
function(**resolved_kwargs) return {'result': result} except StopIteration: - # Use HTTP 410 Gone to indicate iterator exhausted - # This is a clean signal that won't be confused with errors + # HTTP 410 Gone signals iterator exhausted raise HTTPException(status_code=410, detail='Iterator exhausted') - result = function(**_kwargs) + result = function(**resolved_kwargs) if function_name == '__iter__': return {'result': 'ok'} else: diff --git a/src/twinkle/server/sampler/__init__.py b/src/twinkle/server/sampler/__init__.py new file mode 100644 index 00000000..58db9098 --- /dev/null +++ b/src/twinkle/server/sampler/__init__.py @@ -0,0 +1,3 @@ +from .app import build_sampler_app + +__all__ = ['build_sampler_app'] diff --git a/src/twinkle/server/sampler/app.py b/src/twinkle/server/sampler/app.py new file mode 100644 index 00000000..265ccd45 --- /dev/null +++ b/src/twinkle/server/sampler/app.py @@ -0,0 +1,159 @@ +# Copyright (c) ModelScope Contributors. All rights reserved. +""" +Unified sampler management application. + +Builds a single Ray Serve deployment (SamplerManagement) that simultaneously handles +both Tinker (/tinker/asample) and Twinkle (/twinkle/*) sampler endpoints. 
+""" +from fastapi import FastAPI, Request +from ray import serve +from typing import Any, Dict, Optional + +import twinkle +from twinkle import DeviceGroup, DeviceMesh +from twinkle.server.utils.adapter_manager import AdapterManagerMixin +from twinkle.server.utils.state import ServerStateProxy, get_server_state +from twinkle.server.utils.task_queue import TaskQueueConfig, TaskQueueMixin +from twinkle.server.utils.validation import get_token_from_request, verify_request_token +from twinkle.utils.logger import get_logger +from ..utils import wrap_builder_with_device_group_env +from .tinker_handlers import TinkerSamplerHandlers +from .twinkle_handlers import TwinkleSamplerHandlers + +logger = get_logger() + + +def build_sampler_app(model_id: str, + nproc_per_node: int, + device_group: Dict[str, Any], + device_mesh: Dict[str, Any], + deploy_options: Dict[str, Any], + sampler_type: str = 'vllm', + engine_args: Optional[Dict[str, Any]] = None, + adapter_config: Optional[Dict[str, Any]] = None, + queue_config: Optional[Dict[str, Any]] = None, + **kwargs): + """Build a unified sampler application for text generation inference. + + Supports both Tinker (polling-style /tinker/asample) and + Twinkle (synchronous /twinkle/*) sampler clients. + + Args: + model_id: Model identifier (e.g., "Qwen/Qwen2.5-0.5B-Instruct") + nproc_per_node: Number of processes per node + device_group: Device group configuration dict + device_mesh: Device mesh configuration dict for parallelism + deploy_options: Ray Serve deployment options + sampler_type: Type of sampler to use ('vllm' or 'torch') + engine_args: Additional engine arguments for the sampler + adapter_config: Adapter lifecycle config (timeout, per-token limits) + queue_config: Task queue configuration dict (rps_limit, tps_limit, etc.) 
+ **kwargs: Additional arguments passed to the sampler + + Returns: + Ray Serve deployment bound with configuration + """ + app = FastAPI( + title='Unified Sampler', + description='REST API for distributed text generation inference (Tinker + Twinkle)', + version='1.0.0') + + @app.middleware('http') + async def verify_token(request: Request, call_next): + return await verify_request_token(request=request, call_next=call_next) + + @serve.deployment(name='SamplerManagement') + @serve.ingress(app) + class SamplerManagement(TaskQueueMixin, AdapterManagerMixin): + """Unified sampler management service. + + Manages: + - vLLM or Torch sampler initialization and lifecycle + - Tinker inference requests (/tinker/asample) with rate limiting via TaskQueueMixin + - Twinkle inference requests (/twinkle/*) calling sampler directly + - Adapter lifecycle via AdapterManagerMixin + - Template configuration for trajectory encoding + """ + + def __init__(self, + nproc_per_node: int, + device_group: Dict[str, Any], + device_mesh: Dict[str, Any], + sampler_type: str = 'vllm', + engine_args: Optional[Dict[str, Any]] = None, + adapter_config: Optional[Dict[str, Any]] = None, + queue_config: Optional[Dict[str, Any]] = None, + **kwargs): + self.device_group = DeviceGroup(**device_group) + twinkle.initialize( + mode='ray', nproc_per_node=nproc_per_node, groups=[self.device_group], lazy_collect=False) + if 'mesh_dim_names' in device_mesh: + self.device_mesh = DeviceMesh(**device_mesh) + else: + self.device_mesh = DeviceMesh.from_sizes(**device_mesh) + self.sampler_type = sampler_type + replica_context = serve.get_replica_context() + replica_id = replica_context.replica_id.unique_id + + # Initialize sampler based on type + if sampler_type == 'vllm': + from twinkle.sampler import vLLMSampler + sampler_kwargs = engine_args or {} + self.sampler = vLLMSampler( + model_id=model_id, + engine_args=sampler_kwargs, + device_mesh=self.device_mesh, + remote_group=self.device_group.name, + 
instance_id=replica_id, + **{ + k: v + for k, v in kwargs.items() if k not in ['engine_args'] + }) + else: + from twinkle.sampler import TorchSampler + self.sampler = TorchSampler( + model_id=model_id, + device_mesh=self.device_mesh, + instance_id=replica_id, + remote_group=self.device_group.name, + **kwargs) + + self.sampler.set_template('Template', model_id=model_id) + self.state: ServerStateProxy = get_server_state() + + # Initialize both mixins + self._init_task_queue(TaskQueueConfig.from_dict(queue_config)) + _adapter_config = adapter_config or {} + self._init_adapter_manager(**_adapter_config) + self.start_adapter_countdown() + + @serve.multiplexed(max_num_models_per_replica=5) + async def _sticky_entry(self, sticky_key: str): + return sticky_key + + async def _ensure_sticky(self): + sticky_key = serve.get_multiplexed_model_id() + await self._sticky_entry(sticky_key) + + async def _on_request_start(self, request: Request) -> str: + await self._ensure_sticky() + token = get_token_from_request(request) + return token + + def _on_adapter_expired(self, adapter_name: str, token: str = None) -> None: + """Handle expired adapters by removing them from the sampler.""" + try: + self.sampler.remove_adapter(adapter_name) + logger.info(f'Removed expired adapter {adapter_name}') + except Exception as e: + logger.warning(f'Failed to remove expired adapter {adapter_name}: {e}') + + # Register routes from both handler mixins + TinkerSamplerHandlers._register_tinker_sampler_routes(app) + TwinkleSamplerHandlers._register_twinkle_sampler_routes(app) + + return SamplerManagement.options(**deploy_options).bind(nproc_per_node, device_group, device_mesh, sampler_type, + engine_args, adapter_config, queue_config, **kwargs) + + +build_sampler_app = wrap_builder_with_device_group_env(build_sampler_app) diff --git a/src/twinkle/server/sampler/tinker_handlers.py b/src/twinkle/server/sampler/tinker_handlers.py new file mode 100644 index 00000000..acb2bb0e --- /dev/null +++ 
b/src/twinkle/server/sampler/tinker_handlers.py @@ -0,0 +1,120 @@ +# Copyright (c) ModelScope Contributors. All rights reserved. +""" +Tinker-compatible sampler handler mixin. + +Provides POST /tinker/asample using schedule_task() returning UntypedAPIFuture. +""" +import os +import traceback +from fastapi import FastAPI, Request +from tinker import types + +from twinkle.data_format import SamplingParams +from twinkle.server.common.io_utils import create_checkpoint_manager +from twinkle.utils.logger import get_logger + +logger = get_logger() + + +class TinkerSamplerHandlers: + """ + Mixin providing Tinker-compatible sampler endpoint. + + Expects the combined class to also inherit TaskQueueMixin and to have: + self.sampler, self.state + """ + + @staticmethod + def _register_tinker_sampler_routes(app: FastAPI): + """Register the tinker sampler route on the given FastAPI app.""" + + @app.post('/tinker/asample') + async def asample(self, request: Request, body: types.SampleRequest) -> types.UntypedAPIFuture: + """Execute text generation (inference) for Tinker clients. 
+ + Args: + request: FastAPI request with auth token + body: SampleRequest with prompt, sampling params, and adapter info + + Returns: + UntypedAPIFuture wrapping SampleResponse with generated sequences + """ + from twinkle.server.utils.validation import get_token_from_request + token = await self._on_request_start(request) + + async def _do_sample(): + try: + # Extract prompt token IDs from ModelInput + prompt_inputs = {'input_ids': body.prompt.to_ints()} + + # Get model_path from body or sampling session + model_path = body.model_path + if not model_path and body.sampling_session_id: + session = self.state.get_sampling_session(body.sampling_session_id) + if session: + model_path = session.get('model_path') + + # Parse and resolve adapter URI from model_path + adapter_uri = None + if model_path: + checkpoint_manager = create_checkpoint_manager(token, client_type='tinker') + adapter_name, adapter_uri = checkpoint_manager.parse_adapter_uri(model_path) + + # Validate adapter URI + if not adapter_uri or not os.path.exists(adapter_uri): + return types.RequestFailedResponse( + error=f'Adapter URI {model_path} does not exist. 
Please check the model_path.', + category=types.RequestErrorCategory.User, + ) + + # Convert tinker SamplingParams to twinkle SamplingParams if needed + sampling_params = None + if body.sampling_params: + sampling_params = SamplingParams( + max_tokens=body.sampling_params.max_tokens or 256, + temperature=body.sampling_params.temperature or 1.0, + top_p=body.sampling_params.top_p, + top_k=body.sampling_params.top_k, + stop=body.sampling_params.stop, + ) + + response = self.sampler.sample( + inputs=[prompt_inputs] * body.num_samples, + sampling_params=sampling_params, + adapter_path=adapter_uri, + ) + + # Convert twinkle SampleResponse to tinker types + tinker_sequences = [] + for seq in response.sequences: + logprobs = None + if seq.logprobs is not None: + if any(lp is None for lp in seq.logprobs): + logprobs = None + else: + logprobs = list(seq.logprobs) + tinker_sequences.append( + types.SampledSequence( + stop_reason=seq.stop_reason, + tokens=list(seq.tokens), + logprobs=logprobs, + )) + return types.SampleResponse( + sequences=tinker_sequences, + prompt_logprobs=response.prompt_logprobs, + topk_prompt_logprobs=response.topk_prompt_logprobs, + ) + except Exception: + logger.error(traceback.format_exc()) + return types.RequestFailedResponse( + error=traceback.format_exc(), + category=types.RequestErrorCategory.Server, + ) + + input_tokens = len(body.prompt.to_ints()) + return await self.schedule_task( + _do_sample, + token=token, + input_tokens=input_tokens, + task_type='sample', + ) diff --git a/src/twinkle/server/sampler/twinkle_handlers.py b/src/twinkle/server/sampler/twinkle_handlers.py new file mode 100644 index 00000000..b35ac404 --- /dev/null +++ b/src/twinkle/server/sampler/twinkle_handlers.py @@ -0,0 +1,139 @@ +# Copyright (c) ModelScope Contributors. All rights reserved. +""" +Twinkle-native sampler handler mixin. + +Provides /twinkle/* sampler endpoints that call the sampler directly (no queue needed). 
+""" +import traceback +from fastapi import FastAPI, Request +from typing import Optional + +from twinkle.data_format import InputFeature, SamplingParams, Trajectory +from twinkle.utils.logger import get_logger +from twinkle_client.types.sampler import (AddAdapterRequest, AddAdapterResponse, CreateResponse, HeartbeatRequest, + HeartbeatResponse, SampleRequest, SampleResponseModel, SetTemplateRequest, + SetTemplateResponse) + +logger = get_logger() + + +class TwinkleSamplerHandlers: + """ + Mixin providing Twinkle-native sampler endpoints. + + Expects the combined class to also have: + self.sampler, self.state + The class should also inherit AdapterManagerMixin for adapter lifecycle. + """ + + @staticmethod + def _register_twinkle_sampler_routes(app: FastAPI): + """Register all twinkle sampler routes on the given FastAPI app.""" + + @staticmethod + def _get_twinkle_sampler_adapter_name(request: Request, adapter_name: Optional[str]) -> Optional[str]: + if adapter_name is None or adapter_name == '': + return None + return request.state.request_id + '-' + adapter_name + + @app.post('/twinkle/create', response_model=CreateResponse) + def create(self, request: Request) -> CreateResponse: + """Health check / session creation endpoint.""" + return CreateResponse() + + @app.post('/twinkle/sample', response_model=SampleResponseModel) + def sample(self, request: Request, body: SampleRequest) -> SampleResponseModel: + """Sample completions from the model. + + Supports Trajectory or InputFeature inputs, with optional LoRA adapter. 
+ """ + try: + # Resolve adapter + adapter_path = None + adapter_name = body.adapter_name or '' + full_adapter_name = _get_twinkle_sampler_adapter_name(request, adapter_name) or '' + + if body.adapter_uri: + from twinkle.server.common.io_utils import create_checkpoint_manager + from twinkle.server.utils.validation import get_token_from_request + token = get_token_from_request(request) + checkpoint_manager = create_checkpoint_manager(token, client_type='twinkle') + _, adapter_path = checkpoint_manager.parse_adapter_uri(body.adapter_uri) + + # Parse inputs + inputs = body.inputs + if isinstance(inputs, list) and inputs: + first = inputs[0] + if isinstance(first, dict) and 'input_ids' in first: + inputs = [InputFeature(**item) for item in inputs] + else: + inputs = [Trajectory(**item) for item in inputs] + elif isinstance(inputs, dict): + if 'input_ids' in inputs: + inputs = [InputFeature(**inputs)] + else: + inputs = [Trajectory(**inputs)] + + # Build sampling params + params = None + if body.sampling_params: + params = SamplingParams.from_dict(body.sampling_params) + + # Call sampler + response = self.sampler.sample( + inputs, + params, + adapter_name=full_adapter_name, + adapter_path=adapter_path, + num_samples=body.num_samples, + ) + if callable(response): + response = response() + + sequences = [] + for seq in response.sequences: + sequences.append({ + 'stop_reason': seq.stop_reason, + 'tokens': list(seq.tokens), + 'logprobs': list(seq.logprobs) if seq.logprobs is not None else None, + }) + + return SampleResponseModel( + sequences=sequences, + prompt_logprobs=response.prompt_logprobs, + topk_prompt_logprobs=response.topk_prompt_logprobs, + ) + except Exception: + logger.error(traceback.format_exc()) + raise + + @app.post('/twinkle/set_template', response_model=SetTemplateResponse) + def set_template(self, request: Request, body: SetTemplateRequest) -> SetTemplateResponse: + """Set the chat template for encoding Trajectory inputs.""" + extra_kwargs = 
body.model_extra or {} + self.sampler.set_template(body.template_cls, **extra_kwargs) + return SetTemplateResponse() + + @app.post('/twinkle/add_adapter_to_sampler', response_model=AddAdapterResponse) + def add_adapter_to_sampler(self, request: Request, body: AddAdapterRequest) -> AddAdapterResponse: + """Add a LoRA adapter to the sampler.""" + assert body.adapter_name, 'You need to specify a valid `adapter_name`' + full_adapter_name = _get_twinkle_sampler_adapter_name(request, body.adapter_name) + from twinkle.server.utils.validation import get_token_from_request + token = get_token_from_request(request) + + from peft import LoraConfig + config = LoraConfig(**body.config) if isinstance(body.config, dict) else body.config + + self.register_adapter(full_adapter_name, token) + self.sampler.add_adapter_to_sampler(full_adapter_name, config) + + return AddAdapterResponse(adapter_name=full_adapter_name) + + @app.post('/twinkle/heartbeat', response_model=HeartbeatResponse) + def heartbeat(self, request: Request, body: HeartbeatRequest) -> HeartbeatResponse: + """Keep an adapter alive by resetting its inactivity timer.""" + full_adapter_name = _get_twinkle_sampler_adapter_name(request, body.adapter_name) + self.assert_adapter_exists(adapter_name=full_adapter_name) + self.touch_adapter(full_adapter_name) + return HeartbeatResponse() diff --git a/src/twinkle/server/tinker/__init__.py b/src/twinkle/server/tinker/__init__.py deleted file mode 100644 index 40688d64..00000000 --- a/src/twinkle/server/tinker/__init__.py +++ /dev/null @@ -1,18 +0,0 @@ -# Copyright (c) ModelScope Contributors. All rights reserved. 
-import sys -from typing import TYPE_CHECKING - -from twinkle.utils.import_utils import _LazyModule - -_import_structure = { - 'model': ['build_model_app'], - 'sampler': ['build_sampler_app'], - 'server': ['build_server_app'], -} - -if TYPE_CHECKING: - from .model import build_model_app - from .sampler import build_sampler_app - from .server import build_server_app -else: - sys.modules[__name__] = _LazyModule(__name__, __file__, _import_structure, module_spec=__spec__) diff --git a/src/twinkle/server/tinker/common/__init__.py b/src/twinkle/server/tinker/common/__init__.py deleted file mode 100644 index ae59d58f..00000000 --- a/src/twinkle/server/tinker/common/__init__.py +++ /dev/null @@ -1,3 +0,0 @@ -# Copyright (c) ModelScope Contributors. All rights reserved. -from twinkle.utils import exists, requires -from .datum import datum_to_input_feature, input_feature_to_datum diff --git a/src/twinkle/server/tinker/common/compat_base.py b/src/twinkle/server/tinker/common/compat_base.py deleted file mode 100644 index 62e22ff6..00000000 --- a/src/twinkle/server/tinker/common/compat_base.py +++ /dev/null @@ -1,151 +0,0 @@ -import numpy as np -import torch -from tinker import types -from typing import List - -from twinkle import DeviceMesh -from twinkle.template import Template - - -def collect_forward_backward_results(results, device_mesh: DeviceMesh): - """Custom collect function for forward_backward that handles list [outputs, loss]. 
- - Args: - results: List of lists from each worker, where each list is [outputs_list, loss_float] - - Returns: - List of [flattened_outputs, averaged_loss] - """ - if not results: - return results - - # Filter for last pipeline stage if PP is enabled - pp_last_ranks = None - if device_mesh.pp_world_size > 1: - pp_last_ranks = set(device_mesh.get_pp_last_ranks()) - - # Filter for last tp rank if TP is enabled - tp_last_ranks = None - if device_mesh.tp_world_size > 1: - tp_last_ranks = set(device_mesh.get_tp_last_ranks()) - - mesh_flat = device_mesh.mesh.flatten() - - # results is a list of lists: [[outputs1, loss1], [outputs2, loss2], ...] - # Flatten outputs (first element of each list) - all_outputs = [] - all_losses = [] - for i, result in enumerate(results): - rank = mesh_flat[i] if i < len(mesh_flat) else -1 - - # Only collect from the last PP rank to avoid duplicates - if pp_last_ranks is not None: - if rank not in pp_last_ranks: - continue - - # Only collect from the last TP rank to avoid duplicates - if tp_last_ranks is not None: - if rank not in tp_last_ranks: - continue - - if result is None: - continue - - outputs, loss = result - if outputs is None or loss is None: - continue - all_outputs.extend(outputs) - all_losses.append(loss) - - # Average the losses - if all_losses: - avg_loss = float(np.mean(all_losses)) - else: - avg_loss = 0.0 - - return [all_outputs, avg_loss] - - -def clean_metrics(metrics: dict) -> dict: - import re - from numbers import Number - - def _to_float(v): - # python numeric / numpy scalar - if isinstance(v, (float, int, Number, np.generic, str)): - try: - return float(v) - except Exception: - return None - # 0-d torch tensor - if isinstance(v, torch.Tensor) and v.numel() == 1: - try: - return float(v.item()) - except Exception: - return None - return None - - cleaned = {} - for key, value in metrics.items(): - fv = _to_float(value) - if fv is not None: - cleaned[key] = fv - continue - - # handle common metric strings: "123 
seconds", "1.23 iters/s" - if isinstance(value, str): - s = value.strip() - if s: - try: - head, unit = s.split() # ignore unit/tail - cleaned[f'{key}/{unit}'] = float(head) - except Exception: - m = re.match(r'^([+-]?(?:\d+(?:\.\d*)?|\.\d+)(?:[eE][+-]?\d+)?)', s) - if m: - cleaned[key] = float(m.group(1)) - - return cleaned - - -class TwinkleCompatModelBase: - """Base class containing common logic for Twinkle compatibility wrappers.""" - - def get_template(self, adapter_name: str) -> Template: - return self.optimizer_group[adapter_name].template - - @staticmethod - def _get_forward_output(inputs: List[types.Datum], logits: torch.Tensor, logps: torch.Tensor) -> List[dict]: - """Convert raw logits to the expected output format with logprobs and elementwise_loss.""" - from twinkle.utils.torch_utils import selective_log_softmax - device = logits.device if logits is not None else logps.device - results = [] - if logits is None: - logits = [None] * len(inputs) - for idx, (feature, logit) in enumerate(zip(inputs, logits)): - # Ensure 1D shape and correct device to avoid dimension mismatch and device errors - labels = feature.loss_fn_inputs['target_tokens'].to_torch().long().view(-1).to(device) # shape (seq_len,) - weights = feature.loss_fn_inputs['weights'].to_torch().view(-1).to(device) # shape (seq_len,) - - # Slice logits to match the sequence length of labels - # Labels are assumed to be already shifted/aligned with logits - seq_len = labels.numel() - - if logps is None: - assert logits is not None - # Check if index is within logits bounds - # Right padding - feature_logits = logit[:seq_len, :] - - # Calculate log probs for all labels - token_log_probs = selective_log_softmax(feature_logits, labels) - else: - token_log_probs = logps[idx, :seq_len] - - # elementwise_loss: positive NLL loss (0.0 where masked) - elementwise_loss = -token_log_probs * weights - - results.append({ - 'logprobs': types.TensorData.from_torch(token_log_probs.cpu()), - 'elementwise_loss': 
types.TensorData.from_torch(elementwise_loss.cpu()) - }) - return results diff --git a/src/twinkle/server/tinker/common/io_utils.py b/src/twinkle/server/tinker/common/io_utils.py deleted file mode 100644 index f3128e99..00000000 --- a/src/twinkle/server/tinker/common/io_utils.py +++ /dev/null @@ -1,181 +0,0 @@ -# Copyright (c) ModelScope Contributors. All rights reserved. -"""Tinker-specific IO utilities for managing training runs and checkpoints. - -This module extends the base IO utilities with Tinker-specific implementations. -It uses types from the tinker package for compatibility with the Tinker API. -""" -from datetime import datetime -from tinker import types -from typing import Any, Dict, List, Optional - -from twinkle.server.utils.io_utils import (CHECKPOINT_INFO_FILENAME, TRAIN_RUN_INFO_FILENAME, TWINKLE_DEFAULT_SAVE_DIR, - BaseCheckpointManager, BaseTrainingRunManager, ResolvedLoadPath, - validate_ownership, validate_user_path) - -# ----- Tinker Training Run Manager ----- - - -class TrainingRunManager(BaseTrainingRunManager): - """Tinker-specific training run manager using tinker.types models.""" - - @property - def train_run_info_filename(self) -> str: - return TRAIN_RUN_INFO_FILENAME - - def _create_training_run(self, model_id: str, run_config: types.CreateModelRequest) -> Dict[str, Any]: - """Create training run data from model_id and run_config.""" - lora_config = run_config.lora_config - train_run_data = types.TrainingRun( - training_run_id=model_id, - base_model=run_config.base_model, - model_owner=self.token, - is_lora=True if lora_config else False, - corrupted=False, - lora_rank=lora_config.rank if lora_config else None, - last_request_time=datetime.now(), - last_checkpoint=None, - last_sampler_checkpoint=None, - user_metadata=run_config.user_metadata) - - new_data = train_run_data.model_dump(mode='json') - # Store lora config details separately if needed - if lora_config: - new_data['train_unembed'] = lora_config.train_unembed - 
new_data['train_mlp'] = lora_config.train_mlp - new_data['train_attn'] = lora_config.train_attn - - return new_data - - def _parse_training_run(self, data: Dict[str, Any]) -> types.TrainingRun: - """Parse training run data into TrainingRun model.""" - # Transform checkpoint data to ensure tinker_path field exists - data = self._transform_checkpoint_fields(data) - return types.TrainingRun(**data) - - def _transform_checkpoint_fields(self, data: Dict[str, Any]) -> Dict[str, Any]: - """Transform checkpoint data to ensure compatibility with tinker types. - - Handles cases where: - - last_checkpoint/last_sampler_checkpoint might have twinkle_path instead of tinker_path - - Missing path field that needs to be constructed from other data - """ - data = data.copy() - for field in ['last_checkpoint', 'last_sampler_checkpoint']: - if field in data and data[field] is not None: - ckpt = data[field].copy() - # If twinkle_path exists but tinker_path doesn't, use twinkle_path - if 'twinkle_path' in ckpt and 'tinker_path' not in ckpt: - ckpt['tinker_path'] = ckpt.pop('twinkle_path') - # If neither exists, try to construct from checkpoint_id - elif 'tinker_path' not in ckpt: - # Try to get path from any available path field - path = ckpt.get('path') or ckpt.get('twinkle_path') - if path: - ckpt['tinker_path'] = path - elif 'checkpoint_id' in ckpt and 'training_run_id' in data: - # Construct path from components - ckpt['tinker_path'] = f"twinkle://{data['training_run_id']}/{ckpt['checkpoint_id']}" - data[field] = ckpt - return data - - def _create_training_runs_response(self, runs: List[types.TrainingRun], limit: int, offset: int, - total: int) -> types.TrainingRunsResponse: - """Create a training runs response.""" - return types.TrainingRunsResponse( - training_runs=runs, cursor=types.Cursor(limit=limit, offset=offset, total_count=total)) - - -# ----- Tinker Checkpoint Manager ----- - - -class CheckpointManager(BaseCheckpointManager): - """Tinker-specific checkpoint manager using 
tinker.types models.""" - - @property - def path_prefix(self) -> str: - return 'twinkle://' - - @property - def path_field_name(self) -> str: - return 'tinker_path' - - def _create_checkpoint(self, - checkpoint_id: str, - checkpoint_type: str, - path: str, - size_bytes: int, - public: bool, - base_model: Optional[str] = None, - is_lora: bool = False, - lora_rank: Optional[int] = None, - train_unembed: Optional[bool] = None, - train_mlp: Optional[bool] = None, - train_attn: Optional[bool] = None, - user_metadata: Optional[Dict[str, Any]] = None) -> Dict[str, Any]: - """Create checkpoint data.""" - # Create base checkpoint using tinker types - checkpoint = types.Checkpoint( - checkpoint_id=checkpoint_id, - checkpoint_type=checkpoint_type, - time=datetime.now(), - tinker_path=path, - size_bytes=size_bytes, - public=public) - result = checkpoint.model_dump(mode='json') - - # Add training run info fields (may not be supported by external types.Checkpoint) - result['base_model'] = base_model - result['is_lora'] = is_lora - result['lora_rank'] = lora_rank - result['train_unembed'] = train_unembed - result['train_mlp'] = train_mlp - result['train_attn'] = train_attn - result['user_metadata'] = user_metadata - - return result - - def _parse_checkpoint(self, data: Dict[str, Any]) -> types.Checkpoint: - """Parse checkpoint data into Checkpoint model.""" - data = data.copy() - # Transform twinkle_path to tinker_path if needed - if 'twinkle_path' in data and 'tinker_path' not in data: - data['tinker_path'] = data.pop('twinkle_path') - elif 'tinker_path' not in data and 'path' in data: - data['tinker_path'] = data.pop('path') - return types.Checkpoint(**data) - - def _create_checkpoints_response(self, checkpoints: List[types.Checkpoint]) -> types.CheckpointsListResponse: - """Create a checkpoints list response.""" - return types.CheckpointsListResponse(checkpoints=checkpoints, cursor=None) - - def _create_parsed_path(self, path: str, training_run_id: str, checkpoint_type: str, - 
checkpoint_id: str) -> types.ParsedCheckpointTinkerPath: - """Create a parsed path model.""" - return types.ParsedCheckpointTinkerPath( - tinker_path=path, - training_run_id=training_run_id, - checkpoint_type=checkpoint_type, - checkpoint_id=checkpoint_id, - ) - - def _create_weights_info(self, run_info: Dict[str, Any]) -> types.WeightsInfoResponse: - """Create weights info from run info.""" - return types.WeightsInfoResponse(**run_info) - - def parse_tinker_path(self, tinker_path: str) -> Optional[types.ParsedCheckpointTinkerPath]: - """Parse a twinkle:// path into its components (alias for parse_path).""" - return self.parse_path(tinker_path) - - -# ----- Factory Functions ----- - - -def create_training_run_manager(token: str) -> TrainingRunManager: - """Create a TrainingRunManager for the given token.""" - return TrainingRunManager(token) - - -def create_checkpoint_manager(token: str) -> CheckpointManager: - """Create a CheckpointManager for the given token.""" - training_run_manager = TrainingRunManager(token) - return CheckpointManager(token, training_run_manager) diff --git a/src/twinkle/server/tinker/common/transformers_model.py b/src/twinkle/server/tinker/common/transformers_model.py deleted file mode 100644 index 98ae0134..00000000 --- a/src/twinkle/server/tinker/common/transformers_model.py +++ /dev/null @@ -1,148 +0,0 @@ -from tinker import types -from typing import List - -from twinkle import remote_class, remote_function -from twinkle.model import MultiLoraTransformersModel -from .compat_base import TwinkleCompatModelBase, clean_metrics, collect_forward_backward_results -from .datum import datum_to_input_feature, extract_rl_feature -from .io_utils import create_checkpoint_manager - - -@remote_class() -class TwinkleCompatTransformersModel(MultiLoraTransformersModel, TwinkleCompatModelBase): - """ - Compatibility wrapper around :class:`MultiLoraTransformersModel` for Twinkle/Tinker. 
- - This class adapts the core `MultiLoraTransformersModel` API to the data types and - remote-call semantics used by Twinkle: - - * Inputs to :meth:`forward` and :meth:`forward_only` are provided as - ``List[types.Datum]`` and are converted to the underlying model's - ``InputFeature`` format via :func:`datum_to_input_feature`. - * The outputs of :meth:`forward` and :meth:`forward_only` are not the raw - transformer outputs; instead they are a list of dictionaries, one per - input example, containing: - - - ``"logprobs"``: token-level log-probabilities as ``types.TensorData``. - - ``"elementwise_loss"``: per-token (masked) NLL loss as ``types.TensorData``. - - These are derived from the underlying logits by applying ``log_softmax`` - and slicing to the label sequence length. - * :meth:`calculate_loss` returns a Python scalar (via ``tensor.item()``) - and is exposed as a remote function with ``collect='sum'``, so the - distributed caller receives an aggregated scalar loss instead of a - tensor object. - * :meth:`step` accepts optimizer hyperparameters as :class:`types.AdamParams`, - performs optional gradient clipping, translates them into the optimizer - configuration expected by the base class, invokes the base ``step`` - implementation, and finally zeros gradients. - - Overall, this wrapper ensures that callers using Twinkle's higher-level - ``Datum``/``TensorData`` abstractions and remote functions can interact - with a ``MultiLoraTransformersModel`` instance without needing to know its - internal input feature schema, output structure, or optimizer API. 
- """ - - @remote_function(dispatch='slice_dp', collect='flatten') - def forward_only(self, *, inputs: List[types.Datum], **kwargs): - # Get template for input processing - template = self.get_template(**kwargs) - # Convert Datum to InputFeature - input_features = datum_to_input_feature(inputs, template) - outputs = super().forward_only(inputs=input_features, **kwargs) - # shape (batch_size, seq_len, vocab_size) - logits = outputs['logits'].detach().cpu() - logps = outputs.get('logps', None) - if logps is not None: - logps = logps.detach().cpu() - results = self._get_forward_output(inputs, logits, logps) - return results - - @remote_function(dispatch='slice_dp', collect=collect_forward_backward_results) - def forward_backward(self, *, inputs: List[types.Datum], adapter_name: str, loss_fn: str, **kwargs): - # Set loss first based on loss_fn - if loss_fn == 'cross_entropy': - super().set_loss('CrossEntropyLoss', adapter_name=adapter_name) - elif loss_fn == 'importance_sampling': - super().set_loss( - 'GRPOLoss', - adapter_name=adapter_name, - epsilon=0.2, # Default GRPO epsilon - beta=0.0) # No KL penalty by default - else: - super().set_loss('CrossEntropyLoss', adapter_name=adapter_name) - # Get template for input processing - template = self.get_template(adapter_name) - - # Convert Datum to InputFeature - input_features = datum_to_input_feature(inputs, template) - - # Forward pass - outputs = super().forward(inputs=input_features, adapter_name=adapter_name, **kwargs) - - # Calculate loss with extra parameters - # Extract old_logps and advantages using common utility - loss_values = extract_rl_feature(inputs) - loss_kwargs = kwargs.copy() - loss_kwargs.update(loss_values) - loss = super().calculate_loss(adapter_name=adapter_name, **loss_kwargs) - - # Backward pass - super().backward(adapter_name=adapter_name, **kwargs) - - # shape (batch_size, seq_len, vocab_size) - logits = outputs['logits'].detach() - logps = outputs.get('logps', None) - if logps is not None: - 
logps = logps.detach().cpu() - results = self._get_forward_output(inputs, logits, logps) - return [results, loss] - - @remote_function() - def step(self, *, adam_params: types.AdamParams, **kwargs): - # Gradient clipping - grad_clip_norm = adam_params.grad_clip_norm - if grad_clip_norm > 0.0: - self.clip_grad_norm(max_grad_norm=grad_clip_norm, norm_type=2, **kwargs) - # Optimizer step - optim_params = { - 'lr': adam_params.learning_rate, - 'eps': adam_params.eps, - 'betas': (adam_params.beta1, adam_params.beta2), - 'weight_decay': adam_params.weight_decay, - } - super().step(optim_params=optim_params, **kwargs) - # Zero gradients - super().zero_grad(**kwargs) - - @remote_function(collect='first', lazy_collect=False) - def calculate_metric(self, is_training, **kwargs): - metric = super().calculate_metric(is_training, **kwargs) - return clean_metrics(metric) - - @remote_function() - def load(self, checkpoint_dir: str, **kwargs): - """ - Load checkpoint with token-based isolation support. - - Args: - checkpoint_dir: The twinkle:// path to the checkpoint or hub model ID - **kwargs: Additional keyword arguments including optional 'token' - """ - # Extract token from kwargs if provided (for user isolation) - token = kwargs.pop('token', None) - if not token: - raise ValueError('Token is required for loading checkpoints') - - # Create checkpoint manager with the token - checkpoint_manager = create_checkpoint_manager(token) - - # Use resolve_load_path to handle path resolution - resolved = checkpoint_manager.resolve_load_path(checkpoint_dir) - - if resolved.is_twinkle_path: - # Load from twinkle checkpoint - return super().load(name=resolved.checkpoint_name, output_dir=resolved.checkpoint_dir, **kwargs) - else: - # Load from hub - return super().load(name=resolved.checkpoint_name, **kwargs) diff --git a/src/twinkle/server/tinker/model.py b/src/twinkle/server/tinker/model.py deleted file mode 100644 index 80778c36..00000000 --- a/src/twinkle/server/tinker/model.py +++ 
/dev/null @@ -1,659 +0,0 @@ -# Copyright (c) ModelScope Contributors. All rights reserved. -""" -Tinker-compatible model management server. - -This module provides a Ray Serve deployment that manages distributed training models. -It handles: -1. Model and adapter lifecycle (create, load, unload) -2. Training operations (forward, backward, optimizer steps) -3. Checkpoint management (save/load weights) -4. Multi-user support with token-based isolation -""" -import traceback -from fastapi import FastAPI, Request -from peft import LoraConfig -from ray import serve -from ray.serve.config import RequestRouterConfig -from tinker import types -from typing import Any, Dict, Optional - -import twinkle -from twinkle import DeviceGroup, DeviceMesh -from twinkle.server.utils.adapter_manager import AdapterManagerMixin -from twinkle.server.utils.state import ServerStateProxy, get_server_state -from twinkle.server.utils.task_queue import TaskQueueConfig, TaskQueueMixin -from twinkle.server.utils.validation import get_token_from_request, verify_request_token -from twinkle.utils.logger import get_logger -from ..utils import wrap_builder_with_device_group_env -from .common.io_utils import create_checkpoint_manager, create_training_run_manager -from .common.router import StickyLoraRequestRouter - -logger = get_logger() - - -def build_model_app(model_id: str, - nproc_per_node: int, - device_group: Dict[str, Any], - device_mesh: Dict[str, Any], - deploy_options: Dict[str, Any], - use_megatron: bool = False, - adapter_config: Dict[str, Any] = {}, - queue_config: Optional[Dict[str, Any]] = {}, - **kwargs): - """Build a model management application for distributed training. - - This factory function creates a Ray Serve deployment that manages a training model - with support for multiple adapters (LoRA) and multi-user isolation. 
- - Args: - model_id: Base model identifier (e.g., "Qwen/Qwen2.5-0.5B-Instruct") - nproc_per_node: Number of processes per node for distributed training - device_group: Device group configuration dict - device_mesh: Device mesh configuration dict for tensor parallelism - deploy_options: Ray Serve deployment options - use_megatron: Whether to use Megatron backend (vs Transformers) - queue_config: Task queue configuration (rate limiting, etc.) - **kwargs: Additional model initialization arguments - - Returns: - Configured Ray Serve deployment bound with parameters - """ - app = FastAPI() - - @app.middleware('http') - async def verify_token(request: Request, call_next): - """Middleware to verify authentication token for all requests.""" - return await verify_request_token(request=request, call_next=call_next) - - @serve.deployment( - name='ModelManagement', - request_router_config=RequestRouterConfig(request_router_class=StickyLoraRequestRouter, ), - ) - @serve.ingress(app) - class ModelManagement(TaskQueueMixin, AdapterManagerMixin): - """Model management service handling training operations. - - This class manages: - - Base model and multiple adapter instances (multi-user LoRA) - - Training operations (forward, backward, optimizer steps) - - Adapter lifecycle with automatic cleanup via AdapterManagerMixin - - Per-user adapter limits and tracking - """ - - def __init__(self, - nproc_per_node: int, - device_group: Dict[str, Any], - device_mesh: Dict[str, Any], - use_megatron: bool = False, - queue_config: Optional[Dict[str, Any]] = None, - **kwargs): - """Initialize the model management service. 
- - Args: - nproc_per_node: Number of processes per node - device_group: Device group configuration - device_mesh: Device mesh configuration for parallelism - use_megatron: Whether to use Megatron backend - queue_config: Task queue configuration dict - **kwargs: Additional model initialization arguments - """ - self.device_group = DeviceGroup(**device_group) - twinkle.initialize( - mode='ray', nproc_per_node=nproc_per_node, groups=[self.device_group], lazy_collect=False) - if 'mesh_dim_names' in device_mesh: - self.device_mesh = DeviceMesh(**device_mesh) - else: - self.device_mesh = DeviceMesh.from_sizes(**device_mesh) - self.use_megatron = use_megatron - self.replica_id = serve.get_replica_context().replica_id.unique_id - self.max_loras = kwargs.get('max_loras', 5) - # Initialize model immediately - choose backend based on use_megatron - if use_megatron: - from .common.megatron_model import TwinkleCompatMegatronModel - self.model = TwinkleCompatMegatronModel( - model_id=model_id, - device_mesh=self.device_mesh, - remote_group=self.device_group.name, - instance_id=self.replica_id, - **kwargs) - else: - from .common.transformers_model import TwinkleCompatTransformersModel - self.model = TwinkleCompatTransformersModel( - model_id=model_id, - device_mesh=self.device_mesh, - remote_group=self.device_group.name, - instance_id=self.replica_id, - **kwargs) - self.base_model = model_id - self.state: ServerStateProxy = get_server_state() - - # Register this replica so the router can track capacity - self.state.register_replica(self.replica_id, self.max_loras) - - # Initialize task queue - self._init_task_queue(TaskQueueConfig.from_dict(queue_config)) - - self._init_adapter_manager(**adapter_config) - self.start_adapter_countdown() - - """ - This is a cache system, we must change to sticky routing - Reference docs: - 1. [Now]https://docs.ray.io/en/latest/serve/model-multiplexing.html - 2. https://docs.ray.io/en/latest/serve/llm/architecture/routing-policies.html - 3. 
https://github.com/ray-project/ray/pull/56855/changes - 4. Direct call actor instead of http or handler in server.py - """ - - @serve.multiplexed(max_num_models_per_replica=kwargs.get('max_loras', 5)) - async def _sticky_entry(self, sticky_key: str): - return sticky_key - - async def _ensure_sticky(self): - sticky_key = serve.get_multiplexed_model_id() - await self._sticky_entry(sticky_key) - - async def _on_request_start(self, request: Request) -> str: - await self._ensure_sticky() - token = get_token_from_request(request) - return token - - def __del__(self): - self.state.unregister_replica(self.replica_id) - - def _cleanup_adapter(self, adapter_name: str) -> None: - """Common adapter cleanup logic used by both manual unload and automatic expiration. - - This method handles: - 1. Clearing adapter state - 2. Removing adapter from model - 3. Unregistering from adapter manager - 4. Removing from server state - - Args: - adapter_name: Name of the adapter to clean up - """ - # Remove from model if it exists - if self.get_adapter_info(adapter_name): - # Clear adapter state - self.clear_adapter_state(adapter_name) - - self.model.remove_adapter(adapter_name) - # Unregister from adapter manager - self.unregister_adapter(adapter_name) - - # Remove from server state - self.state.unload_model(adapter_name) - - def _on_adapter_expired(self, adapter_name: str) -> None: - # Called from AdapterManagerMixin's countdown thread. - # Fail any pending tasks for this adapter/model. - self.fail_pending_tasks_for_model(adapter_name, reason='Adapter expired') - # Perform common cleanup (without token since it's automatic) - self._cleanup_adapter(adapter_name) - - @app.post('/create_model') - async def create_model(self, request: Request, body: types.CreateModelRequest) -> types.UntypedAPIFuture: - """Create a new model adapter for training. - - This endpoint: - 1. Registers the model in server state - 2. Creates a LoRA adapter with specified config - 3. 
Sets up processor, loss, and optimizer for the adapter - 4. Saves metadata to training run manager - - Args: - request: FastAPI request with auth token - body: CreateModelRequest with base_model and lora_config - - Returns: - UntypedAPIFuture wrapping CreateModelResponse with model_id - """ - token = await self._on_request_start(request) - - async def _create_adapter(): - model_id = None - try: - # Register a new model_id for each create_model call - model_id = self.state.register_model(body.model_dump(), token=token, replica_id=self.replica_id) - - # Create a new LoRA adapter for the model - if body.lora_config: - # TODO: support more lora config parameters, train_unembed, etc. - lora_cfg = LoraConfig(r=body.lora_config.rank, target_modules='all-linear') - - adapter_name = self.get_adapter_name(adapter_name=model_id) - - # Register adapter FIRST - self.register_adapter(adapter_name, token, session_id=body.session_id) - - # Create adapter AFTER successful registration - self.model.add_adapter_to_model(adapter_name=adapter_name, config_or_dir=lora_cfg) - - self.model.set_template('Template', adapter_name=adapter_name, model_id=self.base_model) - self.model.set_processor('InputProcessor', adapter_name=adapter_name) - self.model.set_optimizer('Adam', adapter_name=adapter_name) - - # Fresh adapter has no accumulated gradients. - self.set_adapter_state(adapter_name, 'grad_ready', False) - - training_run_manager = create_training_run_manager(token) - training_run_manager.save(model_id, body) - - return types.CreateModelResponse(model_id=model_id) - except Exception: - # Ensure we don't leave stale grad state. 
- if model_id: - adapter_name = self.get_adapter_name(adapter_name=model_id) - self._cleanup_adapter(adapter_name) - - logger.error(traceback.format_exc()) - return types.RequestFailedResponse( - error=traceback.format_exc(), - category=types.RequestErrorCategory.Server, - ) - - return await self.schedule_task( - _create_adapter, - token=token, - task_type='create_model', - ) - - @app.post('/get_info') - async def get_info(self, request: Request, body: types.GetInfoRequest) -> types.GetInfoResponse: - """Get information about a model. - - Args: - request: FastAPI request with auth token - body: GetInfoRequest with model_id - - Returns: - GetInfoResponse with model metadata (name, lora_rank, etc.) - """ - token = await self._on_request_start(request) - # Note: get_info doesn't require token for reading metadata in tinker - # Using a default token or None since this is read-only - training_run_manager = create_training_run_manager(token) - metadata = training_run_manager.get(str(body.model_id)) - model_name = metadata.base_model if metadata else model_id - lora_rank = None - is_lora = False - if metadata and hasattr(metadata, 'lora_rank') and metadata.lora_rank: - lora_rank = metadata.lora_rank - is_lora = metadata.is_lora - return types.GetInfoResponse( - model_data=types.ModelData(model_name=model_name), - model_id=body.model_id, - is_lora=is_lora, - lora_rank=lora_rank, - model_name=model_name, - ) - - @app.post('/unload_model') - async def unload_model(self, request: Request, body: types.UnloadModelRequest) -> types.UntypedAPIFuture: - """Unload a model adapter from memory. - - Removes the adapter and updates user adapter counts. 
- - Args: - request: FastAPI request with auth token - body: UnloadModelRequest with model_id - - Returns: - UntypedAPIFuture wrapping UnloadModelResponse - """ - token = await self._on_request_start(request) - - async def _do_unload(): - # Only remove adapter, not the base model - adapter_name = self.get_adapter_name(adapter_name=body.model_id) - # Use common cleanup logic - self._cleanup_adapter(adapter_name) - return types.UnloadModelResponse(model_id=body.model_id) - - return await self.schedule_task( - _do_unload, - model_id=body.model_id, - token=token, - task_type='unload_model', - ) - - @app.post('/forward') - async def forward(self, request: Request, body: types.ForwardRequest) -> types.UntypedAPIFuture: - """Execute forward pass without backward pass. - - Used for inference or evaluation without gradient computation. - - Args: - request: FastAPI request with auth token - body: ForwardRequest with input data - - Returns: - UntypedAPIFuture wrapping ForwardBackwardOutput with loss - """ - token = await self._on_request_start(request) - - async def _do_forward(): - try: - adapter_name = self.get_adapter_name(adapter_name=body.model_id) - self.assert_adapter_exists(adapter_name=adapter_name) - - # Touch adapter to reset inactivity counter - self.touch_adapter(adapter_name) - - datum_list = body.forward_input.data - loss_fn_config = body.forward_input.loss_fn_config or {} - - output = self.model.forward_only(inputs=datum_list, adapter_name=adapter_name) - loss = self.model.calculate_loss(adapter_name=adapter_name, **loss_fn_config) - return types.ForwardBackwardOutput( - loss_fn_output_type='CrossEntropyLossReturn', - loss_fn_outputs=output, - metrics={'loss:sum': loss}, - ) - except Exception: - logger.error(traceback.format_exc()) - return types.RequestFailedResponse( - error=traceback.format_exc(), - category=types.RequestErrorCategory.Server, - ) - - # Calculate input tokens and batch size for validation - datum_list = body.forward_input.data - 
input_tokens = sum(len(d.model_input.to_ints()) for d in datum_list) - batch_size = len(datum_list) - return await self.schedule_task( - _do_forward, - model_id=body.model_id, - token=token, - input_tokens=input_tokens, - batch_size=batch_size, - data_world_size=self.device_mesh.data_world_size, - task_type='forward', - ) - - @app.post('/forward_backward') - async def forward_backward(self, request: Request, - body: types.ForwardBackwardRequest) -> types.UntypedAPIFuture: - """Execute forward and backward pass for training. - - This combines forward pass and gradient computation. The implementation - differs based on backend: - - Megatron: Uses combined forward_backward method - - Transformers: Separate forward, calculate_loss, backward calls - - Args: - request: FastAPI request with auth token - body: ForwardBackwardRequest with training data - - Returns: - UntypedAPIFuture wrapping ForwardBackwardOutput with loss and metrics - """ - token = await self._on_request_start(request) - - async def _do_forward_backward(): - try: - adapter_name = self.get_adapter_name(adapter_name=body.model_id) - self.assert_adapter_exists(adapter_name=adapter_name) - - # Touch adapter to reset inactivity counter - self.touch_adapter(adapter_name) - - datum_list = body.forward_backward_input.data - loss_fn = body.forward_backward_input.loss_fn - loss_fn_config = body.forward_backward_input.loss_fn_config or {} - - # Unified forward_backward for both Megatron and Transformers - output, loss = self.model.forward_backward( - inputs=datum_list, adapter_name=adapter_name, loss_fn=loss_fn, **loss_fn_config) - if loss_fn == 'importance_sampling': - output_type = 'ImportanceSamplingLossReturn' - else: - output_type = 'CrossEntropyLossReturn' - # Mark gradients as ready after a successful forward_backward. 
- self.set_adapter_state(adapter_name, 'grad_ready', True) - return types.ForwardBackwardOutput( - loss_fn_output_type=output_type, - loss_fn_outputs=output, - metrics={'loss:avg': loss}, - ) - except Exception: - logger.error(traceback.format_exc()) - return types.RequestFailedResponse( - error=traceback.format_exc(), - category=types.RequestErrorCategory.Server, - ) - - # Calculate input tokens and batch size for validation - datum_list = body.forward_backward_input.data - input_tokens = sum(len(d.model_input.to_ints()) for d in datum_list) - batch_size = len(datum_list) - return await self.schedule_task( - _do_forward_backward, - model_id=body.model_id, - token=token, - input_tokens=input_tokens, - batch_size=batch_size, - data_world_size=self.device_mesh.data_world_size, - task_type='forward_backward', - ) - - @app.post('/optim_step') - async def optim_step(self, request: Request, body: types.OptimStepRequest) -> types.UntypedAPIFuture: - """Execute optimizer step to update model weights. - - Applies accumulated gradients to update adapter parameters. - - Args: - request: FastAPI request with auth token - body: OptimStepRequest with optimizer parameters - - Returns: - UntypedAPIFuture wrapping OptimStepResponse - """ - token = await self._on_request_start(request) - - async def _do_optim(): - try: - adapter_name = self.get_adapter_name(adapter_name=body.model_id) - self.assert_adapter_exists(adapter_name=adapter_name) - - # Disallow empty step (must have at least one forward_backward since last step) - if not self.get_adapter_state(adapter_name, 'grad_ready', False): - raise RuntimeError( - f'No accumulated gradients for adapter={adapter_name}; call forward_backward before optim_step' # noqa: E501 - ) - - # Touch adapter to reset inactivity counter - self.touch_adapter(adapter_name) - - self.model.step(adam_params=body.adam_params, adapter_name=adapter_name) - # Clear grad-ready after a successful step. 
- self.set_adapter_state(adapter_name, 'grad_ready', False) - metrics = self.model.calculate_metric(is_training=True, adapter_name=adapter_name) - return types.OptimStepResponse(metrics=metrics) - except Exception: - logger.error(traceback.format_exc()) - return types.RequestFailedResponse( - error=traceback.format_exc(), - category=types.RequestErrorCategory.Server, - ) - - return await self.schedule_task( - _do_optim, - model_id=body.model_id, - token=token, - task_type='optim_step', - ) - - @app.post('/save_weights') - async def save_weights(self, request: Request, body: types.SaveWeightsRequest) -> types.UntypedAPIFuture: - """Save model adapter weights to storage. - - Saves both model weights and optimizer state for training resumption. - Uses token-based isolation for user-specific storage. - - Args: - request: FastAPI request with auth token - body: SaveWeightsRequest with path and model_id - - Returns: - UntypedAPIFuture wrapping SaveWeightsResponse with saved path - """ - token = await self._on_request_start(request) - - async def _do_save(): - try: - adapter_name = self.get_adapter_name(adapter_name=body.model_id) - self.assert_adapter_exists(adapter_name=adapter_name) - - # Touch adapter to reset inactivity counter - self.touch_adapter(adapter_name) - - checkpoint_manager = create_checkpoint_manager(token) - - # get save dir with token-based isolation - checkpoint_name = checkpoint_manager.get_ckpt_name(body.path) - save_dir = checkpoint_manager.get_save_dir(model_id=body.model_id, is_sampler=False) - - self.model.save( - name=checkpoint_name, output_dir=save_dir, adapter_name=adapter_name, save_optimizer=True) - - tinker_path = checkpoint_manager.save(body.model_id, name=checkpoint_name, is_sampler=False) - - return types.SaveWeightsResponse(path=tinker_path, type='save_weights') - except Exception: - logger.error(traceback.format_exc()) - return types.RequestFailedResponse( - error=traceback.format_exc(), - category=types.RequestErrorCategory.Server, - 
) - - return await self.schedule_task( - _do_save, - model_id=body.model_id, - token=token, - task_type='save_weights', - ) - - @app.post('/save_weights_for_sampler') - async def save_weights_for_sampler(self, request: Request, - body: types.SaveWeightsForSamplerRequest) -> types.UntypedAPIFuture: - """Save/convert weights for inference use. - - Saves adapter weights without optimizer state for use with sampler. - Creates a sampling session for tracking. - - Args: - request: FastAPI request with auth token - body: SaveWeightsForSamplerRequest with model_id and path - - Returns: - UntypedAPIFuture wrapping SaveWeightsForSamplerResponseInternal - """ - token = await self._on_request_start(request) - - async def _do_save_for_sampler(): - try: - - adapter_name = self.get_adapter_name(adapter_name=body.model_id) - self.assert_adapter_exists(adapter_name=adapter_name) - - # Touch adapter to reset inactivity counter - self.touch_adapter(adapter_name) - - checkpoint_manager = create_checkpoint_manager(token) - - # get save dir with token-based isolation - checkpoint_name = checkpoint_manager.get_ckpt_name(body.path) - save_dir = checkpoint_manager.get_save_dir(model_id=body.model_id, is_sampler=True) - # NOTE: Need to save meta first to ensure only one sample weight exists - tinker_path = checkpoint_manager.save(body.model_id, name=checkpoint_name, is_sampler=True) - - logger.info(f'Saving weights to {save_dir}') - # Save weights with save_optimizer=False for sampler use - self.model.save( - name=checkpoint_name, output_dir=save_dir, adapter_name=adapter_name, save_optimizer=False) - - # Create sampling session with resolved model_path/base_model. 
- payload = body.model_dump() - payload['model_path'] = tinker_path - metadata = self.state.get_model_metadata(body.model_id) or {} - if metadata.get('base_model'): - payload['base_model'] = metadata['base_model'] - sampling_session_id = self.state.create_sampling_session(payload) - - return types.SaveWeightsForSamplerResponseInternal( - path=None, # Disable path return for internal use - sampling_session_id=sampling_session_id) - except Exception: - logger.error(traceback.format_exc()) - return types.RequestFailedResponse( - error=traceback.format_exc(), - category=types.RequestErrorCategory.Server, - ) - - return await self.schedule_task( - _do_save_for_sampler, - model_id=body.model_id, - token=token, - task_type='save_weights_for_sampler', - ) - - @app.post('/load_weights') - async def load_weights(self, request: Request, body: types.LoadWeightsRequest) -> types.UntypedAPIFuture: - """Load model adapter weights from storage. - - Loads weights and optionally optimizer state for training resumption. - Uses token-based isolation for user-specific storage access. - - Args: - request: FastAPI request with auth token - body: LoadWeightsRequest with path and optimizer flag - - Returns: - UntypedAPIFuture wrapping LoadWeightsResponse - """ - token = await self._on_request_start(request) - - async def _do_load(): - try: - assert self.model is not None, 'Model not loaded, please load model first' - - adapter_name = self.get_adapter_name(adapter_name=body.model_id) - self.assert_adapter_exists(adapter_name=adapter_name) - - # Touch adapter to reset inactivity counter - self.touch_adapter(adapter_name) - - weight_path = body.path - load_optimizer = body.optimizer - - self.model.load( - checkpoint_dir=weight_path, - load_optimizer=load_optimizer, - adapter_name=adapter_name, - token=token) - - # Loading a checkpoint should reset step readiness. 
- self.set_adapter_state(adapter_name, 'grad_ready', False) - return types.LoadWeightsResponse(path=body.path, type='load_weights') - except Exception: - logger.error(traceback.format_exc()) - return types.RequestFailedResponse( - error=traceback.format_exc(), - category=types.RequestErrorCategory.Server, - ) - - return await self.schedule_task( - _do_load, - model_id=body.model_id, - token=token, - task_type='load_weights', - ) - - return ModelManagement.options(**deploy_options).bind(nproc_per_node, device_group, device_mesh, use_megatron, - queue_config, **kwargs) - - -build_model_app = wrap_builder_with_device_group_env(build_model_app) diff --git a/src/twinkle/server/tinker/sampler.py b/src/twinkle/server/tinker/sampler.py deleted file mode 100644 index 406524f3..00000000 --- a/src/twinkle/server/tinker/sampler.py +++ /dev/null @@ -1,251 +0,0 @@ -# Copyright (c) ModelScope Contributors. All rights reserved. -""" -Tinker-compatible sampler (inference) server. - -This module provides a Ray Serve deployment for distributed text generation/inference. -It supports: -1. vLLM and Torch sampler backends -2. LoRA adapter loading via adapter URIs -3. Multi-user inference with rate limiting -4. 
Flexible sampling parameters -""" -import os -import traceback -from fastapi import FastAPI, Request -from ray import serve -from tinker import types -from typing import Any, Dict, Optional - -import twinkle -from twinkle import DeviceGroup, DeviceMesh -from twinkle.data_format import SamplingParams -from twinkle.server.utils.state import ServerStateProxy, get_server_state -from twinkle.server.utils.task_queue import TaskQueueConfig, TaskQueueMixin -from twinkle.server.utils.validation import get_token_from_request, verify_request_token -from twinkle.utils.logger import get_logger -from ..utils import wrap_builder_with_device_group_env -from .common.io_utils import create_checkpoint_manager - -logger = get_logger() - - -def build_sampler_app(model_id: str, - nproc_per_node: int, - device_group: Dict[str, Any], - device_mesh: Dict[str, Any], - deploy_options: Dict[str, Any], - sampler_type: str = 'vllm', - engine_args: Optional[Dict[str, Any]] = None, - queue_config: Optional[Dict[str, Any]] = None, - **kwargs): - """Build a sampler application for tinker-compatible inference. - - This factory function creates a Ray Serve deployment that manages a sampler - (inference engine) with support for LoRA adapters and rate limiting. - - Args: - model_id: Model identifier (e.g., "ms://Qwen/Qwen2.5-0.5B-Instruct") - nproc_per_node: Number of processes per node - device_group: Device group configuration dict - device_mesh: Device mesh configuration dict for parallelism - deploy_options: Ray Serve deployment options - sampler_type: Type of sampler to use ('vllm' or 'torch') - engine_args: Additional engine arguments for the sampler - queue_config: Task queue configuration dict (rps_limit, tps_limit, etc.) 
- **kwargs: Additional arguments passed to the sampler - - Returns: - Ray Serve deployment bound with configuration - """ - app = FastAPI() - - @app.middleware('http') - async def verify_token(request: Request, call_next): - """Middleware to verify authentication token for all requests.""" - return await verify_request_token(request=request, call_next=call_next) - - @serve.deployment(name='SamplerManagement') - @serve.ingress(app) - class SamplerManagement(TaskQueueMixin): - """Sampler management service for text generation inference. - - This class manages: - - vLLM or Torch sampler initialization and lifecycle - - Inference requests with LoRA adapter support - - Rate limiting via task queue - - Sampling parameter conversion between Tinker and Twinkle formats - """ - - def __init__(self, - nproc_per_node: int, - device_group: Dict[str, Any], - device_mesh: Dict[str, Any], - sampler_type: str = 'vllm', - engine_args: Optional[Dict[str, Any]] = None, - queue_config: Optional[Dict[str, Any]] = None, - **kwargs): - """Initialize the sampler management service. 
- - Args: - nproc_per_node: Number of processes per node - device_group: Device group configuration - device_mesh: Device mesh configuration for parallelism - sampler_type: Type of sampler ('vllm' or 'torch') - engine_args: Additional engine arguments for sampler - queue_config: Task queue configuration dict - **kwargs: Additional sampler initialization arguments - """ - self.device_group = DeviceGroup(**device_group) - twinkle.initialize( - mode='ray', nproc_per_node=nproc_per_node, groups=[self.device_group], lazy_collect=False) - if 'mesh_dim_names' in device_mesh: - self.device_mesh = DeviceMesh(**device_mesh) - else: - self.device_mesh = DeviceMesh.from_sizes(**device_mesh) - self.sampler_type = sampler_type - replica_context = serve.get_replica_context() - replica_id = replica_context.replica_id.unique_id - - # Initialize sampler based on type - if sampler_type == 'vllm': - from twinkle.sampler import vLLMSampler - sampler_kwargs = engine_args or {} - self.sampler = vLLMSampler( - model_id=model_id, - engine_args=sampler_kwargs, - device_mesh=self.device_mesh, - remote_group=self.device_group.name, - instance_id=replica_id, - **{ - k: v - for k, v in kwargs.items() if k not in ['engine_args'] - }) - else: # torch sampler - from twinkle.sampler import TorchSampler - self.sampler = TorchSampler(model_id=model_id, device_mesh=self.device_mesh, **kwargs) - self.sampler.set_template('Template', model_id=model_id) - self.state: ServerStateProxy = get_server_state() - self._init_task_queue(TaskQueueConfig.from_dict(queue_config)) - - @serve.multiplexed(max_num_models_per_replica=5) - async def _sticky_entry(self, sticky_key: str): - return sticky_key - - async def _ensure_sticky(self): - sticky_key = serve.get_multiplexed_model_id() - await self._sticky_entry(sticky_key) - - async def _on_request_start(self, request: Request) -> str: - await self._ensure_sticky() - token = get_token_from_request(request) - return token - - @app.post('/asample') - async def 
asample(self, request: Request, body: types.SampleRequest) -> types.UntypedAPIFuture: - """Execute text generation (inference). - - This endpoint: - 1. Extracts prompt token IDs from the request - 2. Determines adapter URI from model_path if provided - 3. Converts Tinker sampling params to Twinkle format - 4. Calls the sampler engine to generate text - 5. Converts results back to Tinker format - - Args: - request: FastAPI request with auth token - body: SampleRequest with prompt, sampling params, and adapter info - - Returns: - UntypedAPIFuture wrapping SampleResponse with generated sequences - """ - token = await self._on_request_start(request) - - async def _do_sample(): - try: - # Extract prompt token IDs from ModelInput - prompt_inputs = {'input_ids': body.prompt.to_ints()} - - # Get model_path: use body.model_path or look up from sampling session - model_path = body.model_path - if not model_path and body.sampling_session_id: - session = self.state.get_sampling_session(body.sampling_session_id) - if session: - model_path = session.get('model_path') - - # Parse and resolve adapter URI from model_path - adapter_uri = None - if model_path: - checkpoint_manager = create_checkpoint_manager(token) - adapter_name, adapter_uri = checkpoint_manager.parse_adapter_uri(model_path) - - # Validate adapter URI existence if provided - if not adapter_uri or not os.path.exists(adapter_uri): - return types.RequestFailedResponse( - error=f'Adapter URI {model_path} does not exist. 
Please check the model_path.', - category=types.RequestErrorCategory.User, - ) - - # Convert tinker SamplingParams to twinkle SamplingParams if needed - sampling_params = None - if body.sampling_params: - sampling_params = SamplingParams( - max_tokens=body.sampling_params.max_tokens or 256, - temperature=body.sampling_params.temperature or 1.0, - top_p=body.sampling_params.top_p, - top_k=body.sampling_params.top_k, - stop=body.sampling_params.stop, - ) - - # Only request logprobs when the client asks for them. Some backends may - # return None entries in logprobs, which breaks pydantic validation. - response = self.sampler.sample( - inputs=[prompt_inputs] * body.num_samples, # For speed up - sampling_params=sampling_params, - adapter_path=adapter_uri, - # adapter_name=adapter_name, - ) - - # Convert twinkle SampleResponse to tinker types.SampleResponse - tinker_sequences = [] - for seq in response.sequences: - logprobs = None - if seq.logprobs is not None: - if any(lp is None for lp in seq.logprobs): - # Fix: backend can emit None logprobs for some tokens, which triggers - # pydantic "Input should be a valid number" errors in SampleResponse. - # We drop the field to keep the response valid. 
- logprobs = None - else: - logprobs = list(seq.logprobs) - tinker_sequences.append( - types.SampledSequence( - stop_reason=seq.stop_reason, - tokens=list(seq.tokens), - logprobs=logprobs, - )) - return types.SampleResponse( - sequences=tinker_sequences, - prompt_logprobs=response.prompt_logprobs, - topk_prompt_logprobs=response.topk_prompt_logprobs, - ) - except Exception: - logger.error(traceback.format_exc()) - return types.RequestFailedResponse( - error=traceback.format_exc(), - category=types.RequestErrorCategory.Server, - ) - - # Calculate input tokens for rate limiting - input_tokens = len(body.prompt.to_ints()) - return await self.schedule_task( - _do_sample, - token=token, - input_tokens=input_tokens, - task_type='sample', - ) - - return SamplerManagement.options(**deploy_options).bind(nproc_per_node, device_group, device_mesh, sampler_type, - engine_args, queue_config, **kwargs) - - -build_sampler_app = wrap_builder_with_device_group_env(build_sampler_app) diff --git a/src/twinkle/server/tinker/server.py b/src/twinkle/server/tinker/server.py deleted file mode 100644 index 81543c58..00000000 --- a/src/twinkle/server/tinker/server.py +++ /dev/null @@ -1,613 +0,0 @@ -# Copyright (c) ModelScope Contributors. All rights reserved. -""" -Tinker-compatible server implementation. - -This module provides a Ray Serve-based server that implements the Tinker API for distributed -training and inference. It acts as a routing layer that: -1. Handles client requests and validates tokens -2. Manages training runs and checkpoints with user isolation -3. 
Proxies requests to appropriate model or sampler deployments based on base_model -""" - -from __future__ import annotations - -import asyncio -import os -from fastapi import FastAPI, HTTPException, Request, Response -from ray import serve -from tinker import types -from typing import Any, Dict, List, Optional - -from twinkle.hub import HubOperation -from twinkle.server.utils.state import get_server_state -from twinkle.server.utils.task_queue import QueueState -from twinkle.server.utils.validation import get_token_from_request, verify_request_token -from twinkle.utils.logger import get_logger -from .common.io_utils import create_checkpoint_manager, create_training_run_manager -from .proxy import ServiceProxy - -logger = get_logger() - - -def build_server_app(deploy_options: dict[str, Any], - supported_models: list[types.SupportedModel] | None = None, - server_config: dict[str, Any] = {}, - http_options: dict[str, Any] | None = None, - **kwargs): - """Build and configure the Tinker-compatible server application. - - This factory function creates a FastAPI application with Ray Serve deployment - that handles routing, authentication, and proxying for training and inference. - - Args: - deploy_options: Ray Serve deployment configuration (num_replicas, etc.) - supported_models: List of supported base models for validation - server_config: Server configuration options (per_token_adapter_limit, etc.) - **kwargs: Additional keyword arguments (route_prefix, etc.) - - Returns: - Configured Ray Serve deployment bound with options - """ - app = FastAPI() - - @app.middleware('http') - async def verify_token(request: Request, call_next): - """Middleware to verify authentication token for all requests.""" - return await verify_request_token(request=request, call_next=call_next) - - @serve.deployment(name='TinkerCompatServer') - @serve.ingress(app) - class TinkerCompatServer: - """Main server class handling Tinker API endpoints and request routing. 
- - This class manages: - - Server state and session management - - Request validation and authentication - - Proxying to model/sampler deployments - - Training run and checkpoint CRUD operations - """ - - def __init__(self, - supported_models: list[types.SupportedModel] | None = None, - server_config: dict[str, Any] = {}, - http_options: dict[str, Any] | None = None, - **kwargs) -> None: - """Initialize the Tinker-compatible server. - - Args: - supported_models: List of supported base models for validation - server_config: Server configuration options - http_options: HTTP server options (host, port) for internal proxy routing - **kwargs: Additional configuration (route_prefix, etc.) - """ - self.state = get_server_state(**server_config) - self.route_prefix = kwargs.get('route_prefix', '/api/v1') - self.http_options = http_options or {} - - # Initialize service proxy for routing requests to model/sampler services - self.proxy = ServiceProxy(http_options=http_options, route_prefix=self.route_prefix) - - self.supported_models = self.normalize_models(supported_models) or [ - types.SupportedModel(model_name='Qwen/Qwen3-30B-A3B-Instruct-2507'), - ] - # Lock for ModelScope config file operations (login writes, get_user_info reads) - self._modelscope_config_lock = asyncio.Lock() - - def normalize_models(self, supported_models): - # Normalize supported_models to objects; passing raw dicts can trigger internal errors - # when creating LoRA training clients via the tinker API. - if not supported_models: - return [] - normalized = [] - for item in supported_models: - if isinstance(item, types.SupportedModel): - normalized.append(item) - elif isinstance(item, dict): - normalized.append(types.SupportedModel(**item)) - elif isinstance(item, str): - normalized.append(types.SupportedModel(model_name=item)) - return normalized - - def _validate_base_model(self, base_model: str) -> None: - """Validate that base_model is in supported_models list. 
- - Args: - base_model: The base model name to validate - - Raises: - HTTPException: If base_model is not supported - """ - supported_model_names = [m.model_name for m in self.supported_models] - if base_model not in supported_model_names: - raise HTTPException( - status_code=400, - detail=f"Base model '{base_model}' is not supported. " - f"Supported models: {', '.join(supported_model_names)}") - - def _get_base_model(self, model_id: str) -> str: - """Get base_model for a model_id from state metadata. - - Args: - model_id: The model identifier to lookup - - Returns: - The base model name - - Raises: - HTTPException: If model_id not found in state - """ - metadata = self.state.get_model_metadata(model_id) - if metadata and metadata.get('base_model'): - return metadata['base_model'] - raise HTTPException(status_code=404, detail=f'Model {model_id} not found') - - # --- Endpoints --------------------------------------------------------- - - @app.get('/healthz') - async def healthz(self, request: Request) -> types.HealthResponse: - """Health check endpoint. - - Returns: - HealthResponse indicating server is operational - """ - return types.HealthResponse(status='ok') - - @app.get('/get_server_capabilities') - async def get_server_capabilities(self, request: Request) -> types.GetServerCapabilitiesResponse: - """Get server capabilities including supported models. - - Returns: - GetServerCapabilitiesResponse with list of supported models - """ - return types.GetServerCapabilitiesResponse(supported_models=self.supported_models) - - @app.post('/telemetry') - async def telemetry(self, request: Request, body: types.TelemetrySendRequest) -> types.TelemetryResponse: - """Accept telemetry data from clients. - - Note: Telemetry is accepted but not persisted; this endpoint is intentionally lightweight. 
- - Returns: - TelemetryResponse indicating data was accepted - """ - return types.TelemetryResponse(status='accepted') - - @app.post('/create_session') - async def create_session(self, request: Request, - body: types.CreateSessionRequest) -> types.CreateSessionResponse: - """Create a new training session. - - Args: - body: Session creation parameters - - Returns: - CreateSessionResponse with new session_id - """ - session_id = self.state.create_session(body.model_dump()) - return types.CreateSessionResponse(session_id=session_id) - - @app.post('/session_heartbeat') - async def session_heartbeat(self, request: Request, - body: types.SessionHeartbeatRequest) -> types.SessionHeartbeatResponse: - """Keep a session alive via heartbeat. - - Args: - body: Heartbeat request with session_id - - Returns: - SessionHeartbeatResponse if session is alive - - Raises: - HTTPException: If session not found - """ - alive = self.state.touch_session(body.session_id) - if not alive: - raise HTTPException(status_code=404, detail='Unknown session') - return types.SessionHeartbeatResponse() - - @app.post('/create_sampling_session') - async def create_sampling_session( - self, request: Request, - body: types.CreateSamplingSessionRequest) -> types.CreateSamplingSessionResponse: - """Create a new sampling (inference) session. - - Args: - body: Sampling session creation parameters - - Returns: - CreateSamplingSessionResponse with new sampling_session_id - """ - sampling_session_id = self.state.create_sampling_session(body.model_dump()) - return types.CreateSamplingSessionResponse(sampling_session_id=sampling_session_id) - - @app.post('/retrieve_future') - async def retrieve_future(self, request: Request, body: types.FutureRetrieveRequest) -> Any: - """Retrieve the result of an async task with long polling. - - Server waits up to 30s for task completion instead of immediately returning try_again. - This reduces client polling frequency from ~100 req/s to ~1 req/30s. 
- """ - request_id = body.request_id - max_wait = float(os.environ.get('TWINKLE_LONG_POLL_TIMEOUT', '30')) - poll_interval = float(os.environ.get('TWINKLE_POLL_INTERVAL', '0.5')) - start = asyncio.get_event_loop().time() - - # Long poll: wait for task completion or timeout - while True: - record = self.state.get_future(request_id) - - if record is None: - return {'type': 'try_again'} - - status = record.get('status') - - # Task finished, return immediately - if status not in ('pending', 'queued', 'running', 'rate_limited'): - break - - # Timeout, let client retry - if asyncio.get_event_loop().time() - start >= max_wait: - response_data = {'type': 'try_again'} - if queue_state := record.get('queue_state'): - response_data['queue_state'] = queue_state - if queue_state_reason := record.get('queue_state_reason'): - response_data['queue_state_reason'] = queue_state_reason - return response_data - - await asyncio.sleep(poll_interval) - - # Handle final result - record = self.state.get_future(request_id) - if not record: - return {'type': 'try_again'} - - status = record.get('status') - - if status == 'rate_limited': - return { - 'type': 'try_again', - 'queue_state': QueueState.PAUSED_RATE_LIMIT.value, - 'queue_state_reason': record.get('reason', 'Rate limit exceeded') - } - - if status == 'failed': - result = record.get('result', {}) - return {'error': result.get('error', 'Unknown error'), 'category': result.get('category', 'Server')} - - result = record.get('result') - if result is None: - raise HTTPException(status_code=500, detail='Task completed but no result found') - - if hasattr(result, 'model_dump'): - return result.model_dump() - return result - - # --- Restful Endpoints ------------------------------------------ - - @app.get('/training_runs') - async def get_training_runs(self, - request: Request, - limit: int = 20, - offset: int = 0) -> types.TrainingRunsResponse: - """ - List training runs for the current user. 
- - Uses token-based isolation to only show runs owned by the requesting user. - - Args: - request: FastAPI request with token in state - limit: Maximum number of results - offset: Pagination offset - - Returns: - TrainingRunsResponse with user's training runs - """ - token = get_token_from_request(request) - training_run_manager = create_training_run_manager(token) - return training_run_manager.list_runs(limit=limit, offset=offset) - - @app.get('/training_runs/{run_id}') - async def get_training_run(self, request: Request, run_id: str) -> types.TrainingRun: - """ - Get a specific training run. - - Uses token-based isolation to verify user owns the run. - - Args: - request: FastAPI request with token in state - run_id: The training run identifier - - Returns: - TrainingRun details - - Raises: - HTTPException 404 if run not found in user's token directory - """ - token = get_token_from_request(request) - training_run_manager = create_training_run_manager(token) - run = training_run_manager.get(run_id) - if not run: - raise HTTPException(status_code=404, detail=f'Training run {run_id} not found') - return run - - @app.get('/training_runs/{run_id}/checkpoints') - async def get_run_checkpoints(self, request: Request, run_id: str) -> types.CheckpointsListResponse: - """ - List checkpoints for a training run. - - Uses token-based isolation to verify user owns the run. 
- - Args: - request: FastAPI request with token in state - run_id: The training run identifier - - Returns: - CheckpointsListResponse with list of checkpoints - - Raises: - HTTPException 404 if run not found in user's token directory - """ - token = get_token_from_request(request) - checkpoint_manager = create_checkpoint_manager(token) - response = checkpoint_manager.list_checkpoints(run_id) - if not response: - raise HTTPException(status_code=404, detail=f'Training run {run_id} not found') - return response - - @app.delete('/training_runs/{run_id}/checkpoints/{checkpoint_id:path}') - async def delete_run_checkpoint(self, request: Request, run_id: str, checkpoint_id: str) -> Any: - """ - Delete a checkpoint from a training run. - - Uses token-based isolation to verify user owns the checkpoint. - - Args: - request: FastAPI request with token in state - run_id: The training run identifier - checkpoint_id: The checkpoint identifier (path) - - Returns: - None (200 OK) if successful - - Raises: - HTTPException 404 if checkpoint not found in user's token directory - """ - token = get_token_from_request(request) - checkpoint_manager = create_checkpoint_manager(token) - success = checkpoint_manager.delete(run_id, checkpoint_id) - if not success: - raise HTTPException(status_code=404, detail=f'Checkpoint {checkpoint_id} not found for run {run_id}') - return None - - @app.post('/weights_info') - async def weights_info(self, request: Request, body: dict[str, Any]) -> types.WeightsInfoResponse: - """ - Get weights information from a tinker path. - - Uses token-based isolation to verify user owns the weights. 
- - Args: - request: FastAPI request with token in state - body: Dict with 'tinker_path' key - - Returns: - WeightsInfoResponse with weight details - - Raises: - HTTPException 404 if weights not found in user's token directory - """ - token = get_token_from_request(request) - checkpoint_manager = create_checkpoint_manager(token) - tinker_path = body.get('tinker_path') - response = checkpoint_manager.get_weights_info(tinker_path) - if not response: - raise HTTPException(status_code=404, detail=f'Weights at {tinker_path} not found') - return response - - @app.post('/training_runs/{run_id}/checkpoints/{checkpoint_id:path}/publish') - async def publish_checkpoint(self, request: Request, run_id: str, checkpoint_id: str) -> Response: - """ - Publish a checkpoint to the hub. - - This endpoint uploads a checkpoint to a hub repository. The hub_model_id - is automatically generated from the checkpoint content and user token. - The upload is performed asynchronously by default. - - Args: - request: FastAPI request object (contains token in state) - run_id: The training run identifier - checkpoint_id: The checkpoint identifier (can include path like weights/checkpoint_name) - - Returns: - Response with 204 No Content status - - Raises: - HTTPException 404 if checkpoint not found or access denied - """ - token = get_token_from_request(request) - - training_run_manager = create_training_run_manager(token) - checkpoint_manager = create_checkpoint_manager(token) - - # Check ownership and get training run info - run = training_run_manager.get(run_id) - if not run: - raise HTTPException(status_code=404, detail=f'Training run {run_id} not found or access denied') - - # Get checkpoint with token-based path - checkpoint = checkpoint_manager.get(run_id, checkpoint_id) - if not checkpoint: - raise HTTPException(status_code=404, detail=f'Checkpoint {checkpoint_id} not found') - - # Get the filesystem path for the checkpoint - checkpoint_dir = str(checkpoint_manager.get_ckpt_dir(run_id, 
checkpoint_id)) - - # Generate hub_model_id from checkpoint content and user token - # Format: {username}/{run_id}_{checkpoint_name} - # Use lock to prevent race conditions when multiple requests access ModelScope config file - async with self._modelscope_config_lock: - try: - from modelscope.hub.api import HubApi, ModelScopeConfig - hub_api = HubApi(token=token) - hub_api.login() # Save user info to local - username = ModelScopeConfig.get_user_info()[0] - except Exception as e: - logger.error(f'Failed to get username from ModelScope: {e}') - raise HTTPException( - status_code=401, - detail='Failed to get username from ModelScope. Please ensure your token is valid.') - - # Extract checkpoint name from checkpoint_id (e.g., "weights/step-8" -> "step-8") - checkpoint_name = checkpoint_id.split('/')[-1] - hub_model_id = f'{username}/{run_id}_{checkpoint_name}' - - # Upload to hub asynchronously with default async_upload=True - HubOperation.async_push_to_hub(repo_id=hub_model_id, folder_path=checkpoint_dir, token=token, private=True) - - # Return 204 No Content (successful with no response body) - return Response(status_code=204) - - # --- Proxy Endpoints --------------------------------------------------------- - - # --- Model Proxy Endpoints ---------------------------------------- - - @app.post('/create_model') - async def create_model(self, request: Request, body: types.CreateModelRequest) -> Any: - """Create a new model (adapter) for training. - - Args: - body: Model creation request with base_model and config - - Returns: - Proxied response from model service - """ - self._validate_base_model(body.base_model) - return await self.proxy.proxy_to_model(request, 'create_model', body.base_model) - - @app.post('/get_info') - async def get_info(self, request: Request, body: types.GetInfoRequest) -> Any: - """Get information about a model. 
- - Args: - body: Info request with model_id - - Returns: - Proxied response from model service - """ - return await self.proxy.proxy_to_model(request, 'get_info', self._get_base_model(body.model_id)) - - @app.post('/unload_model') - async def unload_model(self, request: Request, body: types.UnloadModelRequest) -> Any: - """Unload a model adapter from memory. - - Args: - body: Unload request with model_id - - Returns: - Proxied response from model service - """ - return await self.proxy.proxy_to_model(request, 'unload_model', self._get_base_model(body.model_id)) - - @app.post('/forward') - async def forward(self, request: Request, body: types.ForwardRequest) -> Any: - """Execute forward pass without backward. - - Args: - body: Forward request with inputs - - Returns: - Proxied response from model service - """ - return await self.proxy.proxy_to_model(request, 'forward', self._get_base_model(body.model_id)) - - @app.post('/forward_backward') - async def forward_backward(self, request: Request, body: types.ForwardBackwardRequest) -> Any: - """Execute forward and backward pass for training. - - Args: - body: Forward-backward request with inputs - - Returns: - Proxied response from model service - """ - return await self.proxy.proxy_to_model(request, 'forward_backward', self._get_base_model(body.model_id)) - - @app.post('/optim_step') - async def optim_step(self, request: Request, body: types.OptimStepRequest) -> Any: - """Execute optimizer step to update model weights. - - Args: - body: Optimizer step request with parameters - - Returns: - Proxied response from model service - """ - return await self.proxy.proxy_to_model(request, 'optim_step', self._get_base_model(body.model_id)) - - @app.post('/save_weights') - async def save_weights(self, request: Request, body: types.SaveWeightsRequest) -> Any: - """Save model weights to storage. 
- - Args: - body: Save weights request with path - - Returns: - Proxied response from model service - """ - return await self.proxy.proxy_to_model(request, 'save_weights', self._get_base_model(body.model_id)) - - @app.post('/load_weights') - async def load_weights(self, request: Request, body: types.LoadWeightsRequest) -> Any: - """Load model weights from storage. - - Args: - body: Load weights request with path - - Returns: - Proxied response from model service - """ - return await self.proxy.proxy_to_model(request, 'load_weights', self._get_base_model(body.model_id)) - - # --- Sampler Proxy Endpoints ---------------------------------------- - - @app.post('/asample') - async def asample(self, request: Request, body: types.SampleRequest) -> Any: - """Execute text generation (inference). - - Proxies the request to the sampler service based on base_model. - The sampler handles model_path resolution from sampling session. - - Args: - body: Sample request with prompt and sampling parameters - - Returns: - Proxied response from sampler service - """ - base_model = body.base_model - - # If base_model not provided, look up from sampling session - if not base_model and body.sampling_session_id: - session = self.state.get_sampling_session(body.sampling_session_id) - if session: - base_model = session.get('base_model') - - return await self.proxy.proxy_to_sampler(request, 'asample', base_model) - - @app.post('/save_weights_for_sampler') - async def save_weights_for_sampler(self, request: Request, body: types.SaveWeightsForSamplerRequest) -> Any: - """Save/convert weights for inference use. - - This endpoint proxies to the model service to save weights for sampler. 
- - Args: - body: Save weights request with model_id - - Returns: - Proxied response from model service - """ - # Proxy to model service for save_weights_for_sampler - base_model = self._get_base_model(body.model_id) - return await self.proxy.proxy_to_model(request, 'save_weights_for_sampler', base_model) - - return TinkerCompatServer.options(**deploy_options).bind( - supported_models=supported_models, server_config=server_config, http_options=http_options, **kwargs) diff --git a/src/twinkle/server/twinkle/__init__.py b/src/twinkle/server/twinkle/__init__.py deleted file mode 100644 index 7371b1d7..00000000 --- a/src/twinkle/server/twinkle/__init__.py +++ /dev/null @@ -1,20 +0,0 @@ -# Copyright (c) ModelScope Contributors. All rights reserved. -import sys -from typing import TYPE_CHECKING - -from twinkle.utils.import_utils import _LazyModule - -_import_structure = { - 'model': ['build_model_app'], - 'processor': ['build_processor_app'], - 'sampler': ['build_sampler_app'], - 'server': ['build_server_app'], -} - -if TYPE_CHECKING: - from .model import build_model_app - from .processor import build_processor_app - from .sampler import build_sampler_app - from .server import build_server_app -else: - sys.modules[__name__] = _LazyModule(__name__, __file__, _import_structure, module_spec=__spec__) diff --git a/src/twinkle/server/twinkle/common/io_utils.py b/src/twinkle/server/twinkle/common/io_utils.py deleted file mode 100644 index 4693c381..00000000 --- a/src/twinkle/server/twinkle/common/io_utils.py +++ /dev/null @@ -1,235 +0,0 @@ -# Copyright (c) ModelScope Contributors. All rights reserved. -""" -Twinkle-specific IO utilities for managing training runs and checkpoints. - -This module extends the base IO utilities with Twinkle-specific implementations. 
-""" -from datetime import datetime -from pydantic import BaseModel -from typing import Any, Dict, List, Optional - -from twinkle.server.utils.io_utils import (CHECKPOINT_INFO_FILENAME, TRAIN_RUN_INFO_FILENAME, TWINKLE_DEFAULT_SAVE_DIR, - BaseCheckpoint, BaseCheckpointManager, BaseCreateModelRequest, - BaseLoraConfig, BaseParsedCheckpointPath, BaseTrainingRun, - BaseTrainingRunManager, BaseWeightsInfoResponse, Cursor, ResolvedLoadPath, - validate_ownership, validate_user_path) - -# ----- Twinkle-specific Pydantic Models ----- - - -class Checkpoint(BaseCheckpoint): - """Twinkle checkpoint model.""" - twinkle_path: str - - -class TrainingRun(BaseTrainingRun): - """Twinkle training run model.""" - pass - - -class TrainingRunsResponse(BaseModel): - training_runs: List[TrainingRun] - cursor: Cursor - - -class CheckpointsListResponse(BaseModel): - checkpoints: List[Checkpoint] - cursor: Optional[Cursor] = None - - -class ParsedCheckpointTwinklePath(BaseParsedCheckpointPath): - """Twinkle-specific parsed path model.""" - twinkle_path: str - - -class WeightsInfoResponse(BaseWeightsInfoResponse): - """Twinkle weights info response.""" - pass - - -class LoraConfig(BaseLoraConfig): - """Twinkle LoRA configuration.""" - pass - - -class CreateModelRequest(BaseCreateModelRequest): - """Twinkle create model request.""" - lora_config: Optional[LoraConfig] = None - - -# ----- Twinkle Training Run Manager ----- - - -class TrainingRunManager(BaseTrainingRunManager): - """Twinkle-specific training run manager.""" - - @property - def train_run_info_filename(self) -> str: - return TRAIN_RUN_INFO_FILENAME - - def _create_training_run(self, model_id: str, run_config: CreateModelRequest) -> Dict[str, Any]: - """Create training run data from model_id and run_config.""" - lora_config = run_config.lora_config - train_run_data = TrainingRun( - training_run_id=model_id, - base_model=run_config.base_model, - model_owner=self.token, - is_lora=True if lora_config else False, - corrupted=False, - 
lora_rank=lora_config.rank if lora_config else None, - last_request_time=datetime.now(), - last_checkpoint=None, - last_sampler_checkpoint=None, - user_metadata=run_config.user_metadata) - - new_data = train_run_data.model_dump(mode='json') - # Store lora config details separately if needed - if lora_config: - new_data['train_unembed'] = lora_config.train_unembed - new_data['train_mlp'] = lora_config.train_mlp - new_data['train_attn'] = lora_config.train_attn - - return new_data - - def _parse_training_run(self, data: Dict[str, Any]) -> TrainingRun: - """Parse training run data into TrainingRun model.""" - return TrainingRun(**data) - - def _create_training_runs_response(self, runs: List[TrainingRun], limit: int, offset: int, - total: int) -> TrainingRunsResponse: - """Create a training runs response.""" - return TrainingRunsResponse(training_runs=runs, cursor=Cursor(limit=limit, offset=offset, total_count=total)) - - def get_with_permission(self, model_id: str) -> Optional[TrainingRun]: - """ - Get training run with ownership validation. 
- - Args: - model_id: The model identifier - - Returns: - TrainingRun if found and owned by user, None otherwise - """ - run = self.get(model_id) - if run and validate_ownership(self.token, run.model_owner): - return run - return None - - -# ----- Twinkle Checkpoint Manager ----- - - -class CheckpointManager(BaseCheckpointManager): - """Twinkle-specific checkpoint manager.""" - - @property - def path_prefix(self) -> str: - return 'twinkle://' - - @property - def path_field_name(self) -> str: - return 'twinkle_path' - - def _create_checkpoint(self, - checkpoint_id: str, - checkpoint_type: str, - path: str, - size_bytes: int, - public: bool, - base_model: Optional[str] = None, - is_lora: bool = False, - lora_rank: Optional[int] = None, - train_unembed: Optional[bool] = None, - train_mlp: Optional[bool] = None, - train_attn: Optional[bool] = None, - user_metadata: Optional[Dict[str, Any]] = None) -> Dict[str, Any]: - """Create checkpoint data.""" - checkpoint = Checkpoint( - checkpoint_id=checkpoint_id, - checkpoint_type=checkpoint_type, - time=datetime.now(), - twinkle_path=path, - size_bytes=size_bytes, - public=public, - base_model=base_model, - is_lora=is_lora, - lora_rank=lora_rank, - train_unembed=train_unembed, - train_mlp=train_mlp, - train_attn=train_attn, - user_metadata=user_metadata) - return checkpoint.model_dump(mode='json') - - def _parse_checkpoint(self, data: Dict[str, Any]) -> Checkpoint: - """Parse checkpoint data into Checkpoint model.""" - data = data.copy() - # Transform tinker_path to twinkle_path if needed - if 'tinker_path' in data and 'twinkle_path' not in data: - data['twinkle_path'] = data.pop('tinker_path') - elif 'twinkle_path' not in data and 'path' in data: - data['twinkle_path'] = data.pop('path') - return Checkpoint(**data) - - def get(self, model_id: str, checkpoint_id: str) -> Optional[Checkpoint]: - """ - Get checkpoint metadata with backwards compatibility. 
- - Args: - model_id: The model identifier - checkpoint_id: The checkpoint identifier - - Returns: - Checkpoint object or None if not found - """ - data = self._read_ckpt_info(model_id, checkpoint_id) - if not data: - return None - # Handle backwards compatibility: construct twinkle_path if missing - if 'twinkle_path' not in data and 'tinker_path' not in data and 'path' not in data: - if 'checkpoint_id' in data: - data = data.copy() - data['twinkle_path'] = f"{self.path_prefix}{model_id}/{data['checkpoint_id']}" - return self._parse_checkpoint(data) - - def _create_checkpoints_response(self, checkpoints: List[Checkpoint]) -> CheckpointsListResponse: - """Create a checkpoints list response.""" - return CheckpointsListResponse(checkpoints=checkpoints, cursor=None) - - def _create_parsed_path(self, path: str, training_run_id: str, checkpoint_type: str, - checkpoint_id: str) -> ParsedCheckpointTwinklePath: - """Create a parsed path model.""" - return ParsedCheckpointTwinklePath( - path=path, - twinkle_path=path, - training_run_id=training_run_id, - checkpoint_type=checkpoint_type, - checkpoint_id=checkpoint_id, - ) - - def _create_weights_info(self, run_info: Dict[str, Any]) -> WeightsInfoResponse: - """Create weights info from run info.""" - return WeightsInfoResponse( - training_run_id=run_info.get('training_run_id', ''), - base_model=run_info.get('base_model', ''), - model_owner=run_info.get('model_owner', ''), - is_lora=run_info.get('is_lora', False), - lora_rank=run_info.get('lora_rank'), - ) - - def parse_twinkle_path(self, twinkle_path: str) -> Optional[ParsedCheckpointTwinklePath]: - """Parse a twinkle:// path into its components (alias for parse_path).""" - return self.parse_path(twinkle_path) - - -# ----- Factory Functions ----- - - -def create_training_run_manager(token: str) -> TrainingRunManager: - """Create a TrainingRunManager for the given token.""" - return TrainingRunManager(token) - - -def create_checkpoint_manager(token: str) -> CheckpointManager: - 
"""Create a CheckpointManager for the given token.""" - training_run_manager = TrainingRunManager(token) - return CheckpointManager(token, training_run_manager) diff --git a/src/twinkle/server/twinkle/common/transformers_model.py b/src/twinkle/server/twinkle/common/transformers_model.py deleted file mode 100644 index c67a0a28..00000000 --- a/src/twinkle/server/twinkle/common/transformers_model.py +++ /dev/null @@ -1,41 +0,0 @@ -# Copyright (c) ModelScope Contributors. All rights reserved. -import numpy as np -import torch -from collections.abc import Mapping -from typing import Any, List, Union - -from twinkle import remote_class, remote_function -from twinkle.data_format import InputFeature, Trajectory -from twinkle.model import MultiLoraTransformersModel - - -@remote_class() -class TwinkleCompatTransformersModel(MultiLoraTransformersModel): - - @staticmethod - def _to_cpu_safe_output(obj: Any) -> Any: - """Convert nested outputs into CPU-safe Python objects for HTTP transport.""" - from twinkle.utils import torch_util - - if isinstance(obj, torch.Tensor): - tensor = torch_util.to_local_tensor(obj).detach().cpu() - if tensor.numel() == 1: - return tensor.item() - return tensor.tolist() - if isinstance(obj, np.ndarray): - if obj.size == 1: - return obj.item() - return obj.tolist() - if isinstance(obj, np.generic): - return obj.item() - if isinstance(obj, Mapping): - return {key: TwinkleCompatTransformersModel._to_cpu_safe_output(value) for key, value in obj.items()} - if isinstance(obj, (list, tuple)): - return [TwinkleCompatTransformersModel._to_cpu_safe_output(value) for value in obj] - return obj - - @remote_function(dispatch='slice_dp', collect='mean') - def forward_backward(self, *, inputs: Union[InputFeature, List[InputFeature], Trajectory, List[Trajectory]], - **kwargs): - output = super().forward_backward(inputs=inputs, **kwargs) - return self._to_cpu_safe_output(output) diff --git a/src/twinkle/server/twinkle/model.py b/src/twinkle/server/twinkle/model.py 
deleted file mode 100644 index 4bf4bf4b..00000000 --- a/src/twinkle/server/twinkle/model.py +++ /dev/null @@ -1,584 +0,0 @@ -# Copyright (c) ModelScope Contributors. All rights reserved. -import os -from fastapi import FastAPI, Request -from peft import LoraConfig -from pydantic import BaseModel -from ray import serve -from typing import Any, Dict, Optional - -import twinkle -from twinkle import DeviceGroup, DeviceMesh -from twinkle.data_format import InputFeature, Trajectory -from twinkle.server.utils.adapter_manager import AdapterManagerMixin -from twinkle.server.utils.state import ServerStateProxy, get_server_state -from twinkle.server.utils.validation import verify_request_token -from twinkle.utils.logger import get_logger -from .common.io_utils import CreateModelRequest -from .common.io_utils import LoraConfig as IoLoraConfig -from .common.io_utils import create_checkpoint_manager, create_training_run_manager -from .common.serialize import deserialize_object - -logger = get_logger() - - -class CreateRequest(BaseModel): - - class Config: - extra = 'allow' - - -class ForwardRequest(BaseModel): - inputs: Any - adapter_name: str - - class Config: - extra = 'allow' - - -class ForwardOnlyRequest(BaseModel): - inputs: Any - adapter_name: Optional[str] = None - - class Config: - extra = 'allow' - - -class AdapterRequest(BaseModel): - adapter_name: str - - class Config: - extra = 'allow' - - -class SetLossRequest(BaseModel): - loss_cls: str - adapter_name: str - - class Config: - extra = 'allow' - - -class SetOptimizerRequest(BaseModel): - optimizer_cls: str - adapter_name: str - - class Config: - extra = 'allow' - - -class SetLrSchedulerRequest(BaseModel): - scheduler_cls: str - adapter_name: str - - class Config: - extra = 'allow' - - -class SaveRequest(BaseModel): - adapter_name: str - save_optimizer: bool = False - name: Optional[str] = None - - class Config: - extra = 'allow' - - -class UploadToHubRequest(BaseModel): - checkpoint_dir: str - hub_model_id: str - 
hub_token: Optional[str] = None - async_upload: bool = True - - class Config: - extra = 'allow' - - -class LoadRequest(BaseModel): - adapter_name: str - load_optimizer: bool = False - name: str - - class Config: - extra = 'allow' - - -class AddAdapterRequest(BaseModel): - adapter_name: str - config: str - - class Config: - extra = 'allow' - - -class SetTemplateRequest(BaseModel): - template_cls: str - adapter_name: str - - class Config: - extra = 'allow' - - -class SetProcessorRequest(BaseModel): - processor_cls: str - adapter_name: str - - class Config: - extra = 'allow' - - -class HeartbeatRequest(BaseModel): - adapter_name: str - - -class CalculateMetricRequest(BaseModel): - adapter_name: str - is_training: bool = True - - class Config: - extra = 'allow' - - -class GetStateDictRequest(BaseModel): - adapter_name: str - - class Config: - extra = 'allow' - - -def build_model_app(model_id: str, - nproc_per_node: int, - device_group: Dict[str, Any], - device_mesh: Dict[str, Any], - deploy_options: Dict[str, Any], - use_megatron: bool = False, - adapter_config: Dict[str, Any] = {}, - **kwargs): - app = FastAPI() - - @app.middleware('http') - async def verify_token(request: Request, call_next): - return await verify_request_token(request=request, call_next=call_next) - - @serve.deployment(name='ModelManagement') - @serve.ingress(app) - class ModelManagement(AdapterManagerMixin): - - def __init__(self, nproc_per_node: int, device_group: Dict[str, Any], device_mesh: Dict[str, Any]): - self.device_group = DeviceGroup(**device_group) - twinkle.initialize( - mode='ray', nproc_per_node=nproc_per_node, groups=[self.device_group], lazy_collect=False) - if 'mesh_dim_names' in device_mesh: - self.device_mesh = DeviceMesh(**device_mesh) - else: - self.device_mesh = DeviceMesh.from_sizes(**device_mesh) - replica_context = serve.get_replica_context() - replica_id = replica_context.replica_id.unique_id - if use_megatron: - from twinkle.model import MultiLoraMegatronModel - 
self.model = MultiLoraMegatronModel( - model_id=model_id, - device_mesh=self.device_mesh, - remote_group=self.device_group.name, - instance_id=replica_id, - **kwargs) - else: - from .common.transformers_model import TwinkleCompatTransformersModel - self.model = TwinkleCompatTransformersModel( - model_id=model_id, - device_mesh=self.device_mesh, - remote_group=self.device_group.name, - instance_id=replica_id, - **kwargs) - - # Initialize state before adapter manager (mixin needs self.state) - self.state: ServerStateProxy = get_server_state() - - # Initialize adapter manager from mixin - self._init_adapter_manager(**adapter_config) - self.start_adapter_countdown() - - def _on_adapter_expired(self, adapter_name: str) -> None: - """Handle adapter expiration by removing it from the model. - - This method is called automatically by AdapterManagerMixin when - an adapter exceeds its timeout or TTL. - - Args: - adapter_name: Name of the expired adapter to remove. - """ - # Remove from model if it exists - if self.get_adapter_info(adapter_name): - # Clear adapter state - self.clear_adapter_state(adapter_name) - # Unregister from adapter manager - self.unregister_adapter(adapter_name) - - # Remove from server state - self.state.unload_model(adapter_name) - # Remove adapter from model - self.model.remove_adapter(adapter_name) - - @app.post('/create') - def create(self, request: Request, body: CreateRequest): - return {'status': 'ok'} - - @staticmethod - def get_adapter_name(request: Request, adapter_name: Optional[str]) -> Optional[str]: - if adapter_name is None or adapter_name == '': - return None - return request.state.request_id + '-' + adapter_name - - @app.post('/forward') - def forward(self, request: Request, body: ForwardRequest): - adapter_name = self.get_adapter_name(request, adapter_name=body.adapter_name) - self.assert_adapter_exists(adapter_name=adapter_name) - extra_kwargs = body.model_extra or {} - inputs = body.inputs - if isinstance(inputs, list): - _input = 
inputs[0] - if 'input_ids' in _input: - inputs = [InputFeature(**_input) for _input in inputs] - else: - inputs = [Trajectory(**_input) for _input in inputs] - else: - assert isinstance(inputs, dict) - inputs = InputFeature(**inputs) if 'input_ids' in inputs else Trajectory(**inputs) - ret = self.model.forward(inputs=inputs, adapter_name=adapter_name, **extra_kwargs) - return {'result': ret} - - @app.post('/forward_only') - def forward_only(self, request: Request, body: ForwardOnlyRequest): - adapter_name = self.get_adapter_name(request, adapter_name=body.adapter_name) - self.assert_adapter_exists(adapter_name=adapter_name) - extra_kwargs = body.model_extra or {} - inputs = body.inputs - if isinstance(inputs, list): - _input = inputs[0] - if 'input_ids' in _input: - inputs = [InputFeature(**_input) for _input in inputs] - else: - inputs = [Trajectory(**_input) for _input in inputs] - else: - assert isinstance(inputs, dict) - inputs = InputFeature(**inputs) if 'input_ids' in inputs else Trajectory(**inputs) - ret = self.model.forward_only(inputs=inputs, adapter_name=adapter_name, **extra_kwargs) - return {'result': ret} - - @app.post('/calculate_loss') - def calculate_loss(self, request: Request, body: AdapterRequest): - adapter_name = self.get_adapter_name(request, adapter_name=body.adapter_name) - self.assert_adapter_exists(adapter_name=adapter_name) - extra_kwargs = body.model_extra or {} - ret = self.model.calculate_loss(adapter_name=adapter_name, **extra_kwargs) - return {'result': ret} - - @app.post('/backward') - def backward(self, request: Request, body: AdapterRequest): - adapter_name = self.get_adapter_name(request, adapter_name=body.adapter_name) - self.assert_adapter_exists(adapter_name=adapter_name) - extra_kwargs = body.model_extra or {} - ret = self.model.backward(adapter_name=adapter_name, **extra_kwargs) - return {'result': ret} - - @app.post('/forward_backward') - def forward_backward(self, request: Request, body: ForwardRequest): - adapter_name = 
self.get_adapter_name(request, adapter_name=body.adapter_name) - self.assert_adapter_exists(adapter_name=adapter_name) - extra_kwargs = body.model_extra or {} - inputs = body.inputs - if isinstance(inputs, list): - _input = inputs[0] - if 'input_ids' in _input: - inputs = [InputFeature(**_input) for _input in inputs] - else: - inputs = [Trajectory(**_input) for _input in inputs] - else: - assert isinstance(inputs, dict) - inputs = InputFeature(**inputs) if 'input_ids' in inputs else Trajectory(**inputs) - ret = self.model.forward_backward(inputs=inputs, adapter_name=adapter_name, **extra_kwargs) - return {'result': ret} - - @app.post('/get_train_configs') - def get_train_configs(self, request: Request, body: AdapterRequest): - adapter_name = self.get_adapter_name(request, adapter_name=body.adapter_name) - self.assert_adapter_exists(adapter_name=adapter_name) - extra_kwargs = body.model_extra or {} - ret = self.model.get_train_configs(adapter_name=adapter_name, **extra_kwargs) - return {'result': ret} - - @app.post('/clip_grad_norm') - def clip_grad_norm(self, request: Request, body: AdapterRequest): - adapter_name = self.get_adapter_name(request, adapter_name=body.adapter_name) - self.assert_adapter_exists(adapter_name=adapter_name) - extra_kwargs = body.model_extra or {} - ret = self.model.clip_grad_norm(adapter_name=adapter_name, **extra_kwargs) - return {'result': str(ret)} - - @app.post('/step') - def step(self, request: Request, body: AdapterRequest): - adapter_name = self.get_adapter_name(request, adapter_name=body.adapter_name) - self.assert_adapter_exists(adapter_name=adapter_name) - extra_kwargs = body.model_extra or {} - ret = self.model.step(adapter_name=adapter_name, **extra_kwargs) - return {'result': ret} - - @app.post('/zero_grad') - def zero_grad(self, request: Request, body: AdapterRequest): - adapter_name = self.get_adapter_name(request, adapter_name=body.adapter_name) - self.assert_adapter_exists(adapter_name=adapter_name) - extra_kwargs = 
body.model_extra or {} - ret = self.model.zero_grad(adapter_name=adapter_name, **extra_kwargs) - return {'result': ret} - - @app.post('/lr_step') - def lr_step(self, request: Request, body: AdapterRequest): - adapter_name = self.get_adapter_name(request, adapter_name=body.adapter_name) - self.assert_adapter_exists(adapter_name=adapter_name) - extra_kwargs = body.model_extra or {} - ret = self.model.lr_step(adapter_name=adapter_name, **extra_kwargs) - return {'result': ret} - - @app.post('/set_loss') - def set_loss(self, request: Request, body: SetLossRequest): - adapter_name = self.get_adapter_name(request, adapter_name=body.adapter_name) - self.assert_adapter_exists(adapter_name=adapter_name) - extra_kwargs = body.model_extra or {} - ret = self.model.set_loss(body.loss_cls, adapter_name=adapter_name, **extra_kwargs) - return {'result': ret} - - @app.post('/set_optimizer') - def set_optimizer(self, request: Request, body: SetOptimizerRequest): - adapter_name = self.get_adapter_name(request, adapter_name=body.adapter_name) - self.assert_adapter_exists(adapter_name=adapter_name) - extra_kwargs = body.model_extra or {} - ret = self.model.set_optimizer(body.optimizer_cls, adapter_name=adapter_name, **extra_kwargs) - return {'result': ret} - - @app.post('/set_lr_scheduler') - def set_lr_scheduler(self, request: Request, body: SetLrSchedulerRequest): - adapter_name = self.get_adapter_name(request, adapter_name=body.adapter_name) - self.assert_adapter_exists(adapter_name=adapter_name) - extra_kwargs = body.model_extra or {} - ret = self.model.set_lr_scheduler(body.scheduler_cls, adapter_name=adapter_name, **extra_kwargs) - return {'result': ret} - - @app.post('/save') - def save(self, request: Request, body: SaveRequest): - """ - Save adapter weights with token-based isolation. - - This endpoint: - 1. Saves adapter weights to token-specific directory - 2. 
Saves checkpoint metadata with ownership tracking - - Args: - request: FastAPI request object (contains token in state) - body: SaveRequest with adapter_name, name, and save_optimizer flag - - Returns: - Dict with result containing the twinkle:// path to saved checkpoint - """ - adapter_name = self.get_adapter_name(request, adapter_name=body.adapter_name) - self.assert_adapter_exists(adapter_name=adapter_name) - extra_kwargs = body.model_extra or {} - - # Extract token for directory isolation - token = request.state.token - checkpoint_manager = create_checkpoint_manager(token) - - # Get checkpoint name and save directory with token-based path - checkpoint_name = checkpoint_manager.get_ckpt_name(body.name) - save_dir = checkpoint_manager.get_save_dir(model_id=adapter_name, is_sampler=False) - - # Save the model weights - checkpoint_dir = self.model.save( - name=checkpoint_name, - output_dir=save_dir, - adapter_name=adapter_name, - save_optimizer=body.save_optimizer, - **extra_kwargs) - - # Save checkpoint metadata - twinkle_path = checkpoint_manager.save(model_id=adapter_name, name=checkpoint_name, is_sampler=False) - - return {'result': twinkle_path, 'checkpoint_dir': checkpoint_dir} - - @app.post('/load') - def load(self, request: Request, body: LoadRequest): - """ - Load adapter weights with token-based access validation. - - This endpoint: - 1. Validates user has access to the checkpoint - 2. 
Loads weights from token-specific directory - - Args: - request: FastAPI request object (contains token in state) - body: LoadRequest with adapter_name, name, and load_optimizer flag - - Returns: - Dict with result indicating load status - """ - adapter_name = self.get_adapter_name(request, adapter_name=body.adapter_name) - self.assert_adapter_exists(adapter_name=adapter_name) - extra_kwargs = body.model_extra or {} - - # Extract token for directory isolation - token = request.state.token - checkpoint_manager = create_checkpoint_manager(token) - - # Use resolve_load_path to handle path resolution - resolved = checkpoint_manager.resolve_load_path(body.name) - - # Load from twinkle checkpoint directory - ret = self.model.load( - name=resolved.checkpoint_name, - output_dir=resolved.checkpoint_dir, - adapter_name=adapter_name, - load_optimizer=body.load_optimizer, - token=token, - **extra_kwargs) - - return {'result': ret} - - @app.post('/upload_to_hub') - def upload_to_hub(self, request: Request, body: UploadToHubRequest): - """ - Upload model checkpoint to hub. - - This endpoint uploads a previously saved checkpoint to a hub repository. 
- - Args: - request: FastAPI request object (contains token in state) - body: UploadToHubRequest with checkpoint_dir, hub_model_id, hub_token, and async_upload - - Returns: - Dict with success status and message - """ - token = request.state.token - - # Check if body.name is a twinkle:// path or a simple checkpoint name - if body.checkpoint_dir.startswith('twinkle://'): - # Parse twinkle:// path - checkpoint_manager = create_checkpoint_manager(token) - parsed = checkpoint_manager.parse_twinkle_path(body.checkpoint_dir) - if not parsed: - raise ValueError(f'Invalid twinkle path format: {body.checkpoint_dir}') - # parsed.checkpoint_id is like "weights/step-8" - checkpoint_id = parsed.checkpoint_id - - # Use the training_run_id from the path as the model_id - model_id_to_load = parsed.training_run_id - - # Verify checkpoint exists and user has access - checkpoint = checkpoint_manager.get(model_id_to_load, checkpoint_id) - if not checkpoint: - raise ValueError(f'Checkpoint not found or access denied: {body.checkpoint_dir}') - - # Get the actual directory path for the specific checkpoint - checkpoint_dir = str( - checkpoint_manager.get_ckpt_dir(model_id=model_id_to_load, checkpoint_id=checkpoint_id)) - else: - checkpoint_dir = body.checkpoint_dir - - # Call the model's upload_to_hub method - self.model.upload_to_hub( - checkpoint_dir=checkpoint_dir, - hub_model_id=body.hub_model_id, - hub_token=body.hub_token or token, - async_upload=body.async_upload) - - return {'result': body.hub_model_id} - - @app.post('/add_adapter_to_model') - def add_adapter_to_model(self, request: Request, body: AddAdapterRequest): - """ - Add a new adapter to the model. - - This endpoint: - 1. Creates a new adapter with the specified configuration - 2. Registers it in the adapter tracking system - 3. 
Saves training run metadata with token-based isolation - - Args: - request: FastAPI request object (contains token in state) - body: AddAdapterRequest with adapter_name and config - - Returns: - Dict with status and adapter_name - """ - assert body.adapter_name, 'You need to specify a valid `adapter_name`' - adapter_name = self.get_adapter_name(request, adapter_name=body.adapter_name) - config = deserialize_object(body.config) - extra_kwargs = body.model_extra or {} - - # Extract token for metadata storage - token = request.state.token - training_run_manager = create_training_run_manager(token) - - # Register adapter FIRST - self.register_adapter(adapter_name, token) - - # Create adapter AFTER successful registration - self.model.add_adapter_to_model(adapter_name, config, **extra_kwargs) - - # Save training run metadata (similar to tinker's create_model) - # Create a training run config from the adapter configuration - lora_config = None - if isinstance(config, LoraConfig): - lora_config = IoLoraConfig( - rank=config.r, - train_unembed=False, # Default values - train_mlp=True, - train_attn=True) - - run_config = CreateModelRequest( - base_model=model_id, # Use the model_id from build_model_app - lora_config=lora_config, - user_metadata={'adapter_name': body.adapter_name}) - - # Save training run metadata with token-based isolation - training_run_manager.save(adapter_name, run_config) - - return {'status': 'ok', 'adapter_name': adapter_name} - - @app.post('/set_template') - def set_template(self, request: Request, body: SetTemplateRequest): - adapter_name = self.get_adapter_name(request, adapter_name=body.adapter_name) - self.assert_adapter_exists(adapter_name=adapter_name) - extra_kwargs = body.model_extra or {} - ret = self.model.set_template(body.template_cls, adapter_name=adapter_name, **extra_kwargs) - return {'result': ret} - - @app.post('/set_processor') - def set_processor(self, request: Request, body: SetProcessorRequest): - adapter_name = 
self.get_adapter_name(request, adapter_name=body.adapter_name) - self.assert_adapter_exists(adapter_name=adapter_name) - extra_kwargs = body.model_extra or {} - ret = self.model.set_processor(body.processor_cls, adapter_name=adapter_name, **extra_kwargs) - return {'result': ret} - - @app.post('/heartbeat') - def heartbeat(self, request: Request, body: HeartbeatRequest): - adapter_name = self.get_adapter_name(request, adapter_name=body.adapter_name) - self.assert_adapter_exists(adapter_name=adapter_name) - self.touch_adapter(adapter_name) - return {'status': 'ok'} - - @app.post('/calculate_metric') - def calculate_metric(self, request: Request, body: CalculateMetricRequest): - adapter_name = self.get_adapter_name(request, adapter_name=body.adapter_name) - self.assert_adapter_exists(adapter_name=adapter_name) - extra_kwargs = body.model_extra or {} - ret = self.model.calculate_metric(is_training=body.is_training, adapter_name=adapter_name, **extra_kwargs) - return {'result': ret} - - @app.post('/get_state_dict') - def get_state_dict(self, request: Request, body: GetStateDictRequest): - adapter_name = self.get_adapter_name(request, adapter_name=body.adapter_name) - self.assert_adapter_exists(adapter_name=adapter_name) - extra_kwargs = body.model_extra or {} - ret = self.model.get_state_dict(adapter_name=adapter_name, **extra_kwargs) - return {'result': ret} - - return ModelManagement.options(**deploy_options).bind(nproc_per_node, device_group, device_mesh) diff --git a/src/twinkle/server/twinkle/sampler.py b/src/twinkle/server/twinkle/sampler.py deleted file mode 100644 index 27ffd694..00000000 --- a/src/twinkle/server/twinkle/sampler.py +++ /dev/null @@ -1,308 +0,0 @@ -# Copyright (c) ModelScope Contributors. All rights reserved. -""" -Twinkle sampler (inference) server. - -This module provides a Ray Serve deployment for distributed text generation/inference. -It supports: -1. vLLM and Torch sampler backends -2. 
LoRA adapter loading via adapter URIs (twinkle:// paths or local paths) -3. Multi-user inference with adapter lifecycle management -4. Flexible sampling parameters -""" -import traceback -from fastapi import FastAPI, Request -from pydantic import BaseModel, Field -from ray import serve -from typing import Any, Dict, List, Optional, Union - -import twinkle -from twinkle import DeviceGroup, DeviceMesh -from twinkle.data_format import InputFeature, SamplingParams, Trajectory -from twinkle.server.utils.adapter_manager import AdapterManagerMixin -from twinkle.server.utils.state import ServerStateProxy, get_server_state -from twinkle.server.utils.validation import get_token_from_request, verify_request_token -from twinkle.utils.logger import get_logger - -logger = get_logger() - -# ----- Request/Response Models ----- - - -class SampleRequest(BaseModel): - """Request body for the /sample endpoint.""" - inputs: Any = Field(..., description='List of Trajectory or InputFeature dicts') - sampling_params: Optional[Dict[str, Any]] = Field( - None, description='Sampling parameters (max_tokens, temperature, etc.)') - adapter_name: str = Field('', description='Adapter name for LoRA inference') - adapter_uri: Optional[str] = Field( - None, description='Adapter URI (twinkle:// path or local path) for LoRA inference') - num_samples: int = Field(1, description='Number of completions to generate per prompt') - - -class SampleResponseModel(BaseModel): - """Response body for the /sample endpoint.""" - sequences: List[Dict[str, - Any]] = Field(..., - description='List of sampled sequences, each with tokens, logprobs, stop_reason') - prompt_logprobs: Optional[List[Optional[float]]] = None - topk_prompt_logprobs: Optional[List[Optional[List]]] = None - - -class SetTemplateRequest(BaseModel): - """Request body for the /set_template endpoint.""" - template_cls: str = Field(..., description="Template class name (e.g. 
'Template')") - adapter_name: str = Field('', description='Adapter name to associate the template with') - - class Config: - extra = 'allow' - - -class SetTemplateResponse(BaseModel): - """Response body for the /set_template endpoint.""" - status: str = 'ok' - - -class AddAdapterRequest(BaseModel): - """Request body for the /add_adapter_to_sampler endpoint.""" - adapter_name: str = Field(..., description='Name of the adapter to add') - config: Any = Field(..., description='LoRA configuration dict') - - -class AddAdapterResponse(BaseModel): - """Response body for the /add_adapter_to_sampler endpoint.""" - status: str = 'ok' - adapter_name: str - - -class HeartbeatRequest(BaseModel): - """Request body for the /heartbeat endpoint.""" - adapter_name: str = Field(..., description='Adapter name to keep alive') - - -class HeartbeatResponse(BaseModel): - """Response body for the /heartbeat endpoint.""" - status: str = 'ok' - - -class CreateResponse(BaseModel): - """Response body for the /create endpoint.""" - status: str = 'ok' - - -# ----- Application Builder ----- - - -def build_sampler_app(model_id: str, - nproc_per_node: int = 1, - device_group: Dict[str, Any] = None, - device_mesh: Dict[str, Any] = None, - deploy_options: Dict[str, Any] = None, - sampler_type: str = 'vllm', - engine_args: Optional[Dict[str, Any]] = None, - adapter_config: Optional[Dict[str, Any]] = None, - **kwargs): - """Build a sampler application for text generation inference. 
- - Args: - model_id: Model identifier (e.g., "Qwen/Qwen3.5-4B") - nproc_per_node: Number of GPU processes per node - device_group: Device group configuration dict - device_mesh: Device mesh configuration dict for parallelism - deploy_options: Ray Serve deployment options - sampler_type: Type of sampler to use ('vllm' or 'torch') - engine_args: Additional engine arguments for the sampler - adapter_config: Adapter lifecycle config (adapter_timeout, per_token_adapter_limit) - **kwargs: Additional arguments passed to the sampler - - Returns: - Ray Serve deployment bound with configuration - """ - app = FastAPI( - title='Twinkle Sampler', description='REST API for distributed text generation inference', version='1.0.0') - - @app.middleware('http') - async def verify_token(request: Request, call_next): - return await verify_request_token(request=request, call_next=call_next) - - @serve.deployment(name='SamplerManagement') - @serve.ingress(app) - class SamplerManagement(AdapterManagerMixin): - """Sampler management service for text generation inference. 
- - Manages: - - vLLM or Torch sampler initialization and lifecycle - - Adapter lifecycle via AdapterManagerMixin - - Inference requests with LoRA adapter support - - Template configuration for trajectory encoding - """ - - def __init__(self, - nproc_per_node: int, - device_group: Dict[str, Any], - device_mesh: Dict[str, Any], - sampler_type: str = 'vllm', - engine_args: Optional[Dict[str, Any]] = None, - adapter_config: Optional[Dict[str, Any]] = None, - **kwargs): - self.device_group = DeviceGroup(**device_group) - twinkle.initialize( - mode='ray', nproc_per_node=nproc_per_node, groups=[self.device_group], lazy_collect=False) - if 'mesh_dim_names' in device_mesh: - self.device_mesh = DeviceMesh(**device_mesh) - else: - self.device_mesh = DeviceMesh.from_sizes(**device_mesh) - self.sampler_type = sampler_type - replica_context = serve.get_replica_context() - replica_id = replica_context.replica_id.unique_id - # Initialize sampler based on type - if sampler_type == 'vllm': - from twinkle.sampler import vLLMSampler - sampler_kwargs = engine_args or {} - self.sampler = vLLMSampler( - model_id=model_id, - engine_args=sampler_kwargs, - device_mesh=self.device_mesh, - remote_group=self.device_group.name, - instance_id=replica_id, - **{ - k: v - for k, v in kwargs.items() if k not in ['engine_args'] - }) - else: - from twinkle.sampler import TorchSampler - self.sampler = TorchSampler( - model_id=model_id, - device_mesh=self.device_mesh, - instance_id=replica_id, - remote_group=self.device_group.name, - **kwargs) - - # Initialize state and adapter manager - self.state: ServerStateProxy = get_server_state() - _adapter_config = adapter_config or {} - self._init_adapter_manager(**_adapter_config) - self.start_adapter_countdown() - - def _on_adapter_expired(self, adapter_name: str, token: str) -> None: - """Handle expired adapters by removing them from the sampler.""" - try: - self.sampler.remove_adapter(adapter_name) - logger.info(f'Removed expired adapter {adapter_name}') - 
# Adapter count is now tracked dynamically, no manual update needed - except Exception as e: - logger.warning(f'Failed to remove expired adapter {adapter_name}: {e}') - - @staticmethod - def _get_adapter_name(request: Request, adapter_name: Optional[str]) -> Optional[str]: - if adapter_name is None or adapter_name == '': - return None - return request.state.request_id + '-' + adapter_name - - @app.post('/create', response_model=CreateResponse) - def create(self, request: Request) -> CreateResponse: - """Health check / session creation endpoint.""" - return CreateResponse() - - @app.post('/sample', response_model=SampleResponseModel) - def sample(self, request: Request, body: SampleRequest) -> SampleResponseModel: - """Sample completions from the model. - - Supports: - - Trajectory inputs (messages-based, requires template to be set) - - InputFeature inputs (pre-tokenized input_ids) - - LoRA adapter via adapter_name or adapter_uri (twinkle:// path) - - Multiple completions per prompt via num_samples - """ - try: - # Resolve adapter - adapter_path = None - adapter_name = body.adapter_name or '' - full_adapter_name = self._get_adapter_name(request, adapter_name) or '' - - if body.adapter_uri: - from .common.io_utils import create_checkpoint_manager - token = get_token_from_request(request) - checkpoint_manager = create_checkpoint_manager(token) - _, adapter_path = checkpoint_manager.parse_adapter_uri(body.adapter_uri) - - # Parse inputs - inputs = body.inputs - if isinstance(inputs, list) and inputs: - first = inputs[0] - if isinstance(first, dict) and 'input_ids' in first: - inputs = [InputFeature(**item) for item in inputs] - else: - inputs = [Trajectory(**item) for item in inputs] - elif isinstance(inputs, dict): - if 'input_ids' in inputs: - inputs = [InputFeature(**inputs)] - else: - inputs = [Trajectory(**inputs)] - - # Build sampling params - params = None - if body.sampling_params: - params = SamplingParams.from_dict(body.sampling_params) - - # Call sampler - 
response = self.sampler.sample( - inputs, - params, - adapter_name=full_adapter_name, - adapter_path=adapter_path, - num_samples=body.num_samples, - ) - if callable(response): - response = response() - - # Convert to response model - sequences = [] - for seq in response.sequences: - sequences.append({ - 'stop_reason': seq.stop_reason, - 'tokens': list(seq.tokens), - 'logprobs': list(seq.logprobs) if seq.logprobs is not None else None, - }) - - return SampleResponseModel( - sequences=sequences, - prompt_logprobs=response.prompt_logprobs, - topk_prompt_logprobs=response.topk_prompt_logprobs, - ) - except Exception: - logger.error(traceback.format_exc()) - raise - - @app.post('/set_template', response_model=SetTemplateResponse) - def set_template(self, request: Request, body: SetTemplateRequest) -> SetTemplateResponse: - """Set the chat template for encoding Trajectory inputs.""" - extra_kwargs = body.model_extra or {} - self.sampler.set_template(body.template_cls, **extra_kwargs) - return SetTemplateResponse() - - @app.post('/add_adapter_to_sampler', response_model=AddAdapterResponse) - def add_adapter_to_sampler(self, request: Request, body: AddAdapterRequest) -> AddAdapterResponse: - """Add a LoRA adapter to the sampler.""" - assert body.adapter_name, 'You need to specify a valid `adapter_name`' - full_adapter_name = self._get_adapter_name(request, body.adapter_name) - token = get_token_from_request(request) - - from peft import LoraConfig - config = LoraConfig(**body.config) if isinstance(body.config, dict) else body.config - - self.register_adapter(full_adapter_name, token) - - self.sampler.add_adapter_to_sampler(full_adapter_name, config) - - return AddAdapterResponse(adapter_name=full_adapter_name) - - @app.post('/heartbeat', response_model=HeartbeatResponse) - def heartbeat(self, request: Request, body: HeartbeatRequest) -> HeartbeatResponse: - """Keep an adapter alive by resetting its inactivity timer.""" - full_adapter_name = self._get_adapter_name(request, 
body.adapter_name) - self.assert_adapter_exists(adapter_name=full_adapter_name) - self.touch_adapter(full_adapter_name) - return HeartbeatResponse() - - return SamplerManagement.options(**deploy_options).bind(nproc_per_node, device_group, device_mesh, sampler_type, - engine_args, adapter_config, **kwargs) diff --git a/src/twinkle/server/twinkle/server.py b/src/twinkle/server/twinkle/server.py deleted file mode 100644 index 86857647..00000000 --- a/src/twinkle/server/twinkle/server.py +++ /dev/null @@ -1,270 +0,0 @@ -# Copyright (c) ModelScope Contributors. All rights reserved. -""" -Twinkle REST API Server - -This module provides a FastAPI server with REST API endpoints for: -- Training run management (list, get, update) -- Checkpoint management (list, delete) -- Weights info retrieval - -All endpoints include permission control to ensure users can only -access their own resources. -""" -from __future__ import annotations - -from fastapi import FastAPI, HTTPException, Request -from pydantic import BaseModel -from ray import serve -from typing import Any - -from twinkle.server.utils.state import ServerStateProxy, get_server_state -from twinkle.server.utils.validation import get_token_from_request, verify_request_token -from .common.io_utils import (CheckpointsListResponse, TrainingRun, TrainingRunsResponse, WeightsInfoResponse, - create_checkpoint_manager, create_training_run_manager, validate_user_path) - -# ----- Request/Response Models ----- - - -class HealthResponse(BaseModel): - status: str - - -class WeightsInfoRequest(BaseModel): - twinkle_path: str - - -class DeleteCheckpointResponse(BaseModel): - success: bool - message: str - - -class ErrorResponse(BaseModel): - detail: str - - -def build_server_app(deploy_options: dict[str, Any], **kwargs): - """ - Build the Twinkle REST API server application. - - This function creates a FastAPI application wrapped in a Ray Serve deployment - that provides REST API endpoints for managing training runs and checkpoints. 
- - Args: - deploy_options: Ray Serve deployment options (num_replicas, etc.) - **kwargs: Additional configuration options - - Returns: - A Ray Serve deployment handle - """ - app = FastAPI( - title='Twinkle Server', description='REST API for managing training runs and checkpoints', version='1.0.0') - - @app.middleware('http') - async def verify_token(request: Request, call_next): - """Verify authentication token for all requests.""" - return await verify_request_token(request=request, call_next=call_next) - - @serve.deployment(name='TwinkleServer') - @serve.ingress(app) - class TwinkleServer: - """ - Twinkle REST API Server. - - This server provides endpoints for: - - Health checks - - Training run management - - Checkpoint management - - Weights info retrieval - - All modifying operations (delete, etc.) are protected by permission checks - to ensure users can only modify their own resources. - """ - - def __init__(self, **kwargs) -> None: - self.state: ServerStateProxy = get_server_state() - self.route_prefix = kwargs.get('route_prefix', '/api/v1') - - def _get_user_token(self, request: Request) -> str: - """Extract user token from request state.""" - return get_token_from_request(request) - - # ----- Health Check ----- - - @app.get('/healthz', response_model=HealthResponse) - async def healthz(self, request: Request) -> HealthResponse: - """ - Health check endpoint. - - Returns: - HealthResponse with status "ok" if server is healthy - """ - return HealthResponse(status='ok') - - # ----- Training Runs Endpoints ----- - - @app.get('/training_runs', response_model=TrainingRunsResponse) - async def get_training_runs(self, request: Request, limit: int = 20, offset: int = 0) -> TrainingRunsResponse: - """ - List training runs. - - Returns training runs owned by the current user. 
- - Args: - limit: Maximum number of results (default: 20) - offset: Offset for pagination (default: 0) - - Returns: - TrainingRunsResponse with list of training runs and pagination info - """ - token = self._get_user_token(request) - training_run_manager = create_training_run_manager(token) - return training_run_manager.list_runs(limit=limit, offset=offset) - - @app.get('/training_runs/{run_id}', response_model=TrainingRun) - async def get_training_run(self, request: Request, run_id: str) -> TrainingRun: - """ - Get details of a specific training run. - - Users can only view their own training runs. - - Args: - run_id: The training run identifier - - Returns: - TrainingRun details - - Raises: - HTTPException 404 if run not found or not owned by user - """ - token = self._get_user_token(request) - training_run_manager = create_training_run_manager(token) - run = training_run_manager.get_with_permission(run_id) - if not run: - raise HTTPException(status_code=404, detail=f'Training run {run_id} not found or access denied') - return run - - @app.get('/training_runs/{run_id}/checkpoints', response_model=CheckpointsListResponse) - async def get_run_checkpoints(self, request: Request, run_id: str) -> CheckpointsListResponse: - """ - List checkpoints for a training run. - - Users can only view checkpoints for their own training runs. 
- - Args: - run_id: The training run identifier - - Returns: - CheckpointsListResponse with list of checkpoints - - Raises: - HTTPException 404 if run not found or not owned by user - """ - token = self._get_user_token(request) - checkpoint_manager = create_checkpoint_manager(token) - response = checkpoint_manager.list_checkpoints(run_id) - if response is None: - raise HTTPException(status_code=404, detail=f'Training run {run_id} not found or access denied') - return response - - @app.delete('/training_runs/{run_id}/checkpoints/{checkpoint_id:path}') - async def delete_run_checkpoint(self, request: Request, run_id: str, - checkpoint_id: str) -> DeleteCheckpointResponse: - """ - Delete a checkpoint from a training run. - - Users can only delete checkpoints from their own training runs. - Path traversal (using ..) is not allowed. - - Args: - run_id: The training run identifier - checkpoint_id: The checkpoint identifier (can include path like weights/checkpoint_name) - - Returns: - DeleteCheckpointResponse indicating success or failure - - Raises: - HTTPException 400 for invalid paths - HTTPException 403 if not owned by user - HTTPException 404 if checkpoint not found - """ - token = self._get_user_token(request) - - # Validate path safety - if not validate_user_path(token, checkpoint_id): - raise HTTPException(status_code=400, detail='Invalid checkpoint path: path traversal not allowed') - - checkpoint_manager = create_checkpoint_manager(token) - success = checkpoint_manager.delete(run_id, checkpoint_id) - if not success: - raise HTTPException(status_code=404, detail=f'Checkpoint {checkpoint_id} not found or access denied') - - return DeleteCheckpointResponse(success=True, message=f'Checkpoint {checkpoint_id} deleted successfully') - - @app.post('/weights_info', response_model=WeightsInfoResponse) - async def weights_info(self, request: Request, body: WeightsInfoRequest) -> WeightsInfoResponse: - """ - Get information about saved weights. 
- - Users can only view info for their own weights. - - Args: - body: Request containing the twinkle_path - - Returns: - WeightsInfoResponse with weight details - - Raises: - HTTPException 404 if weights not found or not owned by user - """ - token = self._get_user_token(request) - checkpoint_manager = create_checkpoint_manager(token) - response = checkpoint_manager.get_weights_info(body.twinkle_path) - if response is None: - raise HTTPException( - status_code=404, detail=f'Weights at {body.twinkle_path} not found or access denied') - return response - - # ----- Checkpoint Path Resolution ----- - - @app.get('/checkpoint_path/{run_id}/{checkpoint_id:path}') - async def get_checkpoint_path(self, request: Request, run_id: str, checkpoint_id: str) -> dict[str, str]: - """ - Get the filesystem path for a checkpoint. - - This endpoint resolves a checkpoint ID to its actual filesystem path, - which can be used for loading weights during resume training. - - Args: - run_id: The training run identifier - checkpoint_id: The checkpoint identifier - - Returns: - Dict with 'path' key containing the filesystem path - - Raises: - HTTPException 403/404 for permission/not found errors - """ - token = self._get_user_token(request) - - # Validate path safety - if not validate_user_path(token, checkpoint_id): - raise HTTPException(status_code=400, detail='Invalid checkpoint path: path traversal not allowed') - - training_run_manager = create_training_run_manager(token) - checkpoint_manager = create_checkpoint_manager(token) - - # Check ownership - run = training_run_manager.get(run_id) - if not run: - raise HTTPException(status_code=404, detail=f'Training run {run_id} not found or access denied') - - # Get checkpoint with token-based path - checkpoint = checkpoint_manager.get(run_id, checkpoint_id) - if not checkpoint: - raise HTTPException(status_code=404, detail=f'Checkpoint {checkpoint_id} not found') - - # Return the filesystem path - ckpt_dir = 
checkpoint_manager.get_ckpt_dir(run_id, checkpoint_id) - return {'path': str(ckpt_dir), 'twinkle_path': checkpoint.twinkle_path} - - return TwinkleServer.options(**deploy_options).bind(**kwargs) diff --git a/src/twinkle/server/utils/task_queue.py b/src/twinkle/server/utils/task_queue.py index 39511659..33025324 100644 --- a/src/twinkle/server/utils/task_queue.py +++ b/src/twinkle/server/utils/task_queue.py @@ -544,6 +544,58 @@ def get_rate_limiter_memory_stats(self) -> dict[str, Any]: """ return self._rate_limiter.get_memory_stats() + async def schedule_task_and_wait( + self, + coro_factory: Callable[[], Coroutine], + model_id: str | None = None, + token: str | None = None, + input_tokens: int = 0, + task_type: str | None = None, + ) -> Any: + """Schedule an async task and wait for its result synchronously. + + This is the twinkle-side counterpart to :meth:`schedule_task`. + It enqueues the task through the same serial worker, then blocks + (via async sleep) until the task completes, and returns the result + directly instead of a future reference dict. + + Args: + coro_factory: Factory that creates the coroutine to execute. + model_id: Optional model_id to associate with the result. + token: Optional user token for rate limiting. + input_tokens: Number of input tokens for tps rate limiting. + task_type: Optional task type for logging/observability. + + Returns: + The direct return value of the coroutine. + + Raises: + RuntimeError: If the task fails. 
+ """ + future_ref = await self.schedule_task( + coro_factory, + model_id=model_id, + token=token, + input_tokens=input_tokens, + task_type=task_type, + ) + request_id = future_ref.get('request_id') + if request_id is None: + # Pre-flight check failed; surface the error from the stored future + raise RuntimeError(f'Task scheduling failed: {future_ref}') + + while True: + record = self.state.get_future(request_id) + if record and record.get('status') not in ('pending', 'queued', 'running'): + break + await asyncio.sleep(0.05) + + if record['status'] == 'failed': + error = record.get('result', {}).get('error', 'Unknown error') + raise RuntimeError(error) + + return record['result'] + async def shutdown_task_queue(self) -> None: """Gracefully shutdown the task queue and cleanup tasks. diff --git a/src/twinkle_client/http/http_utils.py b/src/twinkle_client/http/http_utils.py index 522b46af..6d86c1c7 100644 --- a/src/twinkle_client/http/http_utils.py +++ b/src/twinkle_client/http/http_utils.py @@ -42,7 +42,7 @@ def _serialize_params(params: Dict[str, Any]) -> Dict[str, Any]: if hasattr(value, 'processor_id'): serialized[key] = value.processor_id elif hasattr(value, '__dict__'): - from twinkle.server.twinkle.common.serialize import serialize_object + from twinkle.server.common.serialize import serialize_object serialized[key] = serialize_object(value) else: serialized[key] = value diff --git a/src/twinkle_client/manager.py b/src/twinkle_client/manager.py index f0f987a6..5419bf6f 100644 --- a/src/twinkle_client/manager.py +++ b/src/twinkle_client/manager.py @@ -3,8 +3,8 @@ from typing import Any, Dict, List, Optional -# Reuse Pydantic models from server -from twinkle.server.twinkle.common.io_utils import Checkpoint, Cursor, TrainingRun +# Shared Pydantic models +from twinkle_client.types.training import Checkpoint, Cursor, TrainingRun from .http.http_utils import http_get, http_post diff --git a/src/twinkle_client/types/__init__.py b/src/twinkle_client/types/__init__.py 
new file mode 100644 index 00000000..85b3e739 --- /dev/null +++ b/src/twinkle_client/types/__init__.py @@ -0,0 +1 @@ +# Copyright (c) ModelScope Contributors. All rights reserved. diff --git a/src/twinkle_client/types/model.py b/src/twinkle_client/types/model.py new file mode 100644 index 00000000..b3b9b6c4 --- /dev/null +++ b/src/twinkle_client/types/model.py @@ -0,0 +1,132 @@ +# Copyright (c) ModelScope Contributors. All rights reserved. +""" +Pydantic request/response models for twinkle model management endpoints. + +These models are used by both the server-side handler and the twinkle client. +""" +from pydantic import BaseModel +from typing import Any, Optional + + +class CreateRequest(BaseModel): + + class Config: + extra = 'allow' + + +class ForwardRequest(BaseModel): + inputs: Any + adapter_name: str + + class Config: + extra = 'allow' + + +class ForwardOnlyRequest(BaseModel): + inputs: Any + adapter_name: Optional[str] = None + + class Config: + extra = 'allow' + + +class AdapterRequest(BaseModel): + adapter_name: str + + class Config: + extra = 'allow' + + +class SetLossRequest(BaseModel): + loss_cls: str + adapter_name: str + + class Config: + extra = 'allow' + + +class SetOptimizerRequest(BaseModel): + optimizer_cls: str + adapter_name: str + + class Config: + extra = 'allow' + + +class SetLrSchedulerRequest(BaseModel): + scheduler_cls: str + adapter_name: str + + class Config: + extra = 'allow' + + +class SaveRequest(BaseModel): + adapter_name: str + save_optimizer: bool = False + name: Optional[str] = None + + class Config: + extra = 'allow' + + +class UploadToHubRequest(BaseModel): + checkpoint_dir: str + hub_model_id: str + hub_token: Optional[str] = None + async_upload: bool = True + + class Config: + extra = 'allow' + + +class LoadRequest(BaseModel): + adapter_name: str + load_optimizer: bool = False + name: str + + class Config: + extra = 'allow' + + +class AddAdapterRequest(BaseModel): + adapter_name: str + config: str + + class Config: + extra 
= 'allow' + + +class SetTemplateRequest(BaseModel): + template_cls: str + adapter_name: str + + class Config: + extra = 'allow' + + +class SetProcessorRequest(BaseModel): + processor_cls: str + adapter_name: str + + class Config: + extra = 'allow' + + +class HeartbeatRequest(BaseModel): + adapter_name: str + + +class CalculateMetricRequest(BaseModel): + adapter_name: str + is_training: bool = True + + class Config: + extra = 'allow' + + +class GetStateDictRequest(BaseModel): + adapter_name: str + + class Config: + extra = 'allow' diff --git a/src/twinkle_client/types/processor.py b/src/twinkle_client/types/processor.py new file mode 100644 index 00000000..feac393e --- /dev/null +++ b/src/twinkle_client/types/processor.py @@ -0,0 +1,30 @@ +# Copyright (c) ModelScope Contributors. All rights reserved. +""" +Pydantic request/response models for twinkle processor endpoints. + +These models are used by both the server-side handler and the twinkle client. + +Note: Class names are prefixed with 'Processor' to avoid name collisions when +importing from twinkle_client.types alongside model.py classes. +""" +from pydantic import BaseModel + + +class ProcessorCreateRequest(BaseModel): + processor_type: str + class_type: str + + class Config: + extra = 'allow' + + +class ProcessorHeartbeatRequest(BaseModel): + processor_id: str + + +class ProcessorCallRequest(BaseModel): + processor_id: str + function: str + + class Config: + extra = 'allow' diff --git a/src/twinkle_client/types/sampler.py b/src/twinkle_client/types/sampler.py new file mode 100644 index 00000000..303316a9 --- /dev/null +++ b/src/twinkle_client/types/sampler.py @@ -0,0 +1,68 @@ +# Copyright (c) ModelScope Contributors. All rights reserved. +""" +Pydantic request/response models for twinkle sampler endpoints. + +These models are used by both the server-side handler and the twinkle client. 
+""" +from pydantic import BaseModel, Field +from typing import Any, Dict, List, Optional + + +class SampleRequest(BaseModel): + """Request body for the /sample endpoint.""" + inputs: Any = Field(..., description='List of Trajectory or InputFeature dicts') + sampling_params: Optional[Dict[str, Any]] = Field( + None, description='Sampling parameters (max_tokens, temperature, etc.)') + adapter_name: str = Field('', description='Adapter name for LoRA inference') + adapter_uri: Optional[str] = Field( + None, description='Adapter URI (twinkle:// path or local path) for LoRA inference') + num_samples: int = Field(1, description='Number of completions to generate per prompt') + + +class SampleResponseModel(BaseModel): + """Response body for the /sample endpoint.""" + sequences: List[Dict[str, Any]] = Field( + ..., description='List of sampled sequences, each with tokens, logprobs, stop_reason') + prompt_logprobs: Optional[List[Optional[float]]] = None + topk_prompt_logprobs: Optional[List[Optional[List]]] = None + + +class SetTemplateRequest(BaseModel): + """Request body for the /set_template endpoint.""" + template_cls: str = Field(..., description="Template class name (e.g. 
'Template')") + adapter_name: str = Field('', description='Adapter name to associate the template with') + + class Config: + extra = 'allow' + + +class SetTemplateResponse(BaseModel): + """Response body for the /set_template endpoint.""" + status: str = 'ok' + + +class AddAdapterRequest(BaseModel): + """Request body for the /add_adapter_to_sampler endpoint.""" + adapter_name: str = Field(..., description='Name of the adapter to add') + config: Any = Field(..., description='LoRA configuration dict') + + +class AddAdapterResponse(BaseModel): + """Response body for the /add_adapter_to_sampler endpoint.""" + status: str = 'ok' + adapter_name: str + + +class HeartbeatRequest(BaseModel): + """Request body for the /heartbeat endpoint.""" + adapter_name: str = Field(..., description='Adapter name to keep alive') + + +class HeartbeatResponse(BaseModel): + """Response body for the /heartbeat endpoint.""" + status: str = 'ok' + + +class CreateResponse(BaseModel): + """Response body for the /create endpoint.""" + status: str = 'ok' diff --git a/src/twinkle_client/types/server.py b/src/twinkle_client/types/server.py new file mode 100644 index 00000000..f9e79e7b --- /dev/null +++ b/src/twinkle_client/types/server.py @@ -0,0 +1,16 @@ +# Copyright (c) ModelScope Contributors. All rights reserved. +"""Shared Pydantic response models for the twinkle server health/error endpoints.""" +from pydantic import BaseModel + + +class HealthResponse(BaseModel): + status: str + + +class DeleteCheckpointResponse(BaseModel): + success: bool + message: str + + +class ErrorResponse(BaseModel): + detail: str diff --git a/src/twinkle_client/types/training.py b/src/twinkle_client/types/training.py new file mode 100644 index 00000000..4c8cba83 --- /dev/null +++ b/src/twinkle_client/types/training.py @@ -0,0 +1,91 @@ +# Copyright (c) ModelScope Contributors. All rights reserved. +""" +Shared Pydantic models for twinkle training runs and checkpoints. 
+ +These types are used both by twinkle_client (as request/response shapes) +and by twinkle.server.common.io_utils (as persistence models). +""" +from datetime import datetime +from pydantic import BaseModel +from typing import Any, Dict, List, Optional + + +class Cursor(BaseModel): + limit: int + offset: int + total_count: int + + +class Checkpoint(BaseModel): + """Twinkle checkpoint model.""" + checkpoint_id: str + checkpoint_type: str + time: datetime + size_bytes: int + public: bool = False + twinkle_path: str + # Training run info (stored for hub downloads) + base_model: Optional[str] = None + is_lora: bool = False + lora_rank: Optional[int] = None + train_unembed: Optional[bool] = None + train_mlp: Optional[bool] = None + train_attn: Optional[bool] = None + user_metadata: Optional[Dict[str, Any]] = None + + +class TrainingRun(BaseModel): + """Twinkle training run model.""" + training_run_id: str + base_model: str + model_owner: str + is_lora: bool = False + corrupted: bool = False + lora_rank: Optional[int] = None + last_request_time: Optional[datetime] = None + last_checkpoint: Optional[Dict[str, Any]] = None + last_sampler_checkpoint: Optional[Dict[str, Any]] = None + user_metadata: Optional[Dict[str, Any]] = None + + +class TrainingRunsResponse(BaseModel): + training_runs: List[TrainingRun] + cursor: Cursor + + +class CheckpointsListResponse(BaseModel): + checkpoints: List[Checkpoint] + cursor: Optional[Cursor] = None + + +class ParsedCheckpointTwinklePath(BaseModel): + """Twinkle-specific parsed path model.""" + path: str + twinkle_path: str + training_run_id: str + checkpoint_type: str + checkpoint_id: str + + +class WeightsInfoResponse(BaseModel): + """Twinkle weights info response.""" + training_run_id: str + base_model: str + model_owner: str + is_lora: bool = False + lora_rank: Optional[int] = None + + +class LoraConfig(BaseModel): + """Twinkle LoRA configuration.""" + rank: int = 8 + train_unembed: bool = False + train_mlp: bool = True + train_attn: 
bool = True + + +class CreateModelRequest(BaseModel): + """Twinkle create model request.""" + base_model: str + lora_config: Optional[LoraConfig] = None + user_metadata: Optional[Dict[str, Any]] = None From d9c424ae3c7deba549b746b3ba17e3f5bbb2db0f Mon Sep 17 00:00:00 2001 From: Yunnglin Date: Tue, 10 Mar 2026 17:18:53 +0800 Subject: [PATCH 02/24] update refact --- src/twinkle/server/gateway/server.py | 36 +-- .../server/gateway/tinker_gateway_handlers.py | 256 ++++++++++++++++ src/twinkle/server/gateway/tinker_router.py | 289 ------------------ .../gateway/twinkle_gateway_handlers.py | 105 +++++++ src/twinkle/server/gateway/twinkle_router.py | 106 ------- src/twinkle/server/model/app.py | 6 +- src/twinkle/server/model/backends/common.py | 130 ++++++++ .../server/model/backends/megatron_model.py | 3 +- .../model/backends/transformers_model.py | 215 +++---------- src/twinkle/server/model/tinker_handlers.py | 5 +- src/twinkle/server/model/twinkle_handlers.py | 18 +- src/twinkle_client/types/server.py | 4 + 12 files changed, 555 insertions(+), 618 deletions(-) create mode 100644 src/twinkle/server/gateway/tinker_gateway_handlers.py delete mode 100644 src/twinkle/server/gateway/tinker_router.py create mode 100644 src/twinkle/server/gateway/twinkle_gateway_handlers.py delete mode 100644 src/twinkle/server/gateway/twinkle_router.py create mode 100644 src/twinkle/server/model/backends/common.py diff --git a/src/twinkle/server/gateway/server.py b/src/twinkle/server/gateway/server.py index 9071a814..767a0e67 100644 --- a/src/twinkle/server/gateway/server.py +++ b/src/twinkle/server/gateway/server.py @@ -17,8 +17,8 @@ from twinkle.server.utils.validation import verify_request_token from twinkle.utils.logger import get_logger from .proxy import ServiceProxy -from .tinker_router import tinker_router -from .twinkle_router import twinkle_router +from .tinker_gateway_handlers import TinkerGatewayHandlers +from .twinkle_gateway_handlers import TwinkleGatewayHandlers logger = 
get_logger() @@ -48,22 +48,9 @@ def build_server_app(deploy_options: dict[str, Any], async def verify_token(request: Request, call_next): return await verify_request_token(request=request, call_next=call_next) - @app.middleware('http') - async def inject_dependencies(request: Request, call_next): - """Middleware to inject GatewayServer dependencies into request.state. - - This must run after GatewayServer is instantiated. We use a marker - set by the first request to initialize the state reference. - """ - # The GatewayServer instance will set itself on the app state - server = getattr(app.state, 'gateway_server', None) - if server: - server._setup_request_state(request) - return await call_next(request) - @serve.deployment(name='GatewayServer') @serve.ingress(app) - class GatewayServer: + class GatewayServer(TinkerGatewayHandlers, TwinkleGatewayHandlers): """Unified gateway server handling both Tinker and Twinkle API clients.""" def __init__(self, @@ -79,8 +66,6 @@ def __init__(self, tinker_types.SupportedModel(model_name='Qwen/Qwen3-30B-A3B-Instruct-2507'), ] self._modelscope_config_lock = asyncio.Lock() - # Register self on app state so middleware can access dependencies - app.state.gateway_server = self def _normalize_models(self, supported_models): if not supported_models: @@ -109,18 +94,9 @@ def _get_base_model(self, model_id: str) -> str: return metadata['base_model'] raise HTTPException(status_code=404, detail=f'Model {model_id} not found') - def _setup_request_state(self, request: Request): - """Inject dependencies into request.state for router handlers.""" - request.state.server_state = self.state - request.state.proxy = self.proxy - request.state.supported_models = self.supported_models - request.state.modelscope_config_lock = self._modelscope_config_lock - request.state.validate_base_model = self._validate_base_model - request.state.get_base_model = self._get_base_model - - # Include routers for Tinker and Twinkle endpoints - 
app.include_router(tinker_router, prefix='/tinker') - app.include_router(twinkle_router, prefix='/twinkle') + # Register routes from both handler mixins + TinkerGatewayHandlers._register_tinker_routes(app) + TwinkleGatewayHandlers._register_twinkle_routes(app) return GatewayServer.options(**deploy_options).bind( supported_models=supported_models, server_config=server_config, http_options=http_options, **kwargs) diff --git a/src/twinkle/server/gateway/tinker_gateway_handlers.py b/src/twinkle/server/gateway/tinker_gateway_handlers.py new file mode 100644 index 00000000..08db7f91 --- /dev/null +++ b/src/twinkle/server/gateway/tinker_gateway_handlers.py @@ -0,0 +1,256 @@ +# Copyright (c) ModelScope Contributors. All rights reserved. +""" +Tinker-compatible gateway handler mixin. + +All endpoints are prefixed /tinker/* and registered via _register_tinker_routes(app). +Route closures use self.* directly (no request.state injection needed). +""" +from __future__ import annotations + +import asyncio +import os +from fastapi import FastAPI, HTTPException, Request, Response +from tinker import types +from typing import Any + +from twinkle.hub import HubOperation +from twinkle.server.common.io_utils import create_checkpoint_manager, create_training_run_manager +from twinkle.server.utils.task_queue import QueueState +from twinkle.server.utils.validation import get_token_from_request +from twinkle.utils.logger import get_logger + +logger = get_logger() + + +class TinkerGatewayHandlers: + """ + Mixin providing Tinker-compatible gateway endpoints. 
+ + Expects the combined class to have: + self.state, self.proxy, self.supported_models, + self._modelscope_config_lock, self._validate_base_model(), self._get_base_model() + """ + + @staticmethod + def _register_tinker_routes(app: FastAPI): + """Register all /tinker/* routes on the given FastAPI app.""" + + @app.get('/tinker/healthz') + async def healthz(self, request: Request) -> types.HealthResponse: + return types.HealthResponse(status='ok') + + @app.get('/tinker/get_server_capabilities') + async def get_server_capabilities(self, request: Request) -> types.GetServerCapabilitiesResponse: + return types.GetServerCapabilitiesResponse(supported_models=self.supported_models) + + @app.post('/tinker/telemetry') + async def telemetry(self, request: Request, body: types.TelemetrySendRequest) -> types.TelemetryResponse: + return types.TelemetryResponse(status='accepted') + + @app.post('/tinker/create_session') + async def create_session(self, request: Request, + body: types.CreateSessionRequest) -> types.CreateSessionResponse: + session_id = self.state.create_session(body.model_dump()) + return types.CreateSessionResponse(session_id=session_id) + + @app.post('/tinker/session_heartbeat') + async def session_heartbeat(self, request: Request, + body: types.SessionHeartbeatRequest) -> types.SessionHeartbeatResponse: + alive = self.state.touch_session(body.session_id) + if not alive: + raise HTTPException(status_code=404, detail='Unknown session') + return types.SessionHeartbeatResponse() + + @app.post('/tinker/create_sampling_session') + async def create_sampling_session( + self, request: Request, + body: types.CreateSamplingSessionRequest) -> types.CreateSamplingSessionResponse: + sampling_session_id = self.state.create_sampling_session(body.model_dump()) + return types.CreateSamplingSessionResponse(sampling_session_id=sampling_session_id) + + @app.post('/tinker/retrieve_future') + async def retrieve_future(self, request: Request, body: types.FutureRetrieveRequest) -> Any: 
+ """Retrieve the result of an async task with long polling.""" + request_id = body.request_id + max_wait = float(os.environ.get('TWINKLE_LONG_POLL_TIMEOUT', '30')) + poll_interval = float(os.environ.get('TWINKLE_POLL_INTERVAL', '0.5')) + start = asyncio.get_event_loop().time() + + while True: + record = self.state.get_future(request_id) + + if record is None: + return {'type': 'try_again'} + + status = record.get('status') + if status not in ('pending', 'queued', 'running', 'rate_limited'): + break + + if asyncio.get_event_loop().time() - start >= max_wait: + response_data = {'type': 'try_again'} + if queue_state := record.get('queue_state'): + response_data['queue_state'] = queue_state + if queue_state_reason := record.get('queue_state_reason'): + response_data['queue_state_reason'] = queue_state_reason + return response_data + + await asyncio.sleep(poll_interval) + + record = self.state.get_future(request_id) + if not record: + return {'type': 'try_again'} + + status = record.get('status') + + if status == 'rate_limited': + return { + 'type': 'try_again', + 'queue_state': QueueState.PAUSED_RATE_LIMIT.value, + 'queue_state_reason': record.get('reason', 'Rate limit exceeded') + } + + if status == 'failed': + result = record.get('result', {}) + return {'error': result.get('error', 'Unknown error'), 'category': result.get('category', 'Server')} + + result = record.get('result') + if result is None: + raise HTTPException(status_code=500, detail='Task completed but no result found') + + if hasattr(result, 'model_dump'): + return result.model_dump() + return result + + # --- Training Runs Endpoints --- + + @app.get('/tinker/training_runs') + async def get_training_runs(self, + request: Request, + limit: int = 20, + offset: int = 0) -> types.TrainingRunsResponse: + token = get_token_from_request(request) + training_run_manager = create_training_run_manager(token, client_type='tinker') + return training_run_manager.list_runs(limit=limit, offset=offset) + + 
@app.get('/tinker/training_runs/{run_id}') + async def get_training_run(self, request: Request, run_id: str) -> types.TrainingRun: + token = get_token_from_request(request) + training_run_manager = create_training_run_manager(token, client_type='tinker') + run = training_run_manager.get(run_id) + if not run: + raise HTTPException(status_code=404, detail=f'Training run {run_id} not found') + return run + + @app.get('/tinker/training_runs/{run_id}/checkpoints') + async def get_run_checkpoints(self, request: Request, run_id: str) -> types.CheckpointsListResponse: + token = get_token_from_request(request) + checkpoint_manager = create_checkpoint_manager(token, client_type='tinker') + response = checkpoint_manager.list_checkpoints(run_id) + if not response: + raise HTTPException(status_code=404, detail=f'Training run {run_id} not found') + return response + + @app.delete('/tinker/training_runs/{run_id}/checkpoints/{checkpoint_id:path}') + async def delete_run_checkpoint(self, request: Request, run_id: str, checkpoint_id: str) -> Any: + token = get_token_from_request(request) + checkpoint_manager = create_checkpoint_manager(token, client_type='tinker') + success = checkpoint_manager.delete(run_id, checkpoint_id) + if not success: + raise HTTPException(status_code=404, detail=f'Checkpoint {checkpoint_id} not found for run {run_id}') + return None + + @app.post('/tinker/weights_info') + async def weights_info(self, request: Request, body: dict[str, Any]) -> types.WeightsInfoResponse: + token = get_token_from_request(request) + checkpoint_manager = create_checkpoint_manager(token, client_type='tinker') + tinker_path = body.get('tinker_path') + response = checkpoint_manager.get_weights_info(tinker_path) + if not response: + raise HTTPException(status_code=404, detail=f'Weights at {tinker_path} not found') + return response + + @app.post('/tinker/training_runs/{run_id}/checkpoints/{checkpoint_id:path}/publish') + async def publish_checkpoint(self, request: Request, run_id: 
str, checkpoint_id: str) -> Response: + token = get_token_from_request(request) + + training_run_manager = create_training_run_manager(token, client_type='tinker') + checkpoint_manager = create_checkpoint_manager(token, client_type='tinker') + + run = training_run_manager.get(run_id) + if not run: + raise HTTPException(status_code=404, detail=f'Training run {run_id} not found or access denied') + + checkpoint = checkpoint_manager.get(run_id, checkpoint_id) + if not checkpoint: + raise HTTPException(status_code=404, detail=f'Checkpoint {checkpoint_id} not found') + + checkpoint_dir = str(checkpoint_manager.get_ckpt_dir(run_id, checkpoint_id)) + + async with self._modelscope_config_lock: + try: + from modelscope.hub.api import HubApi, ModelScopeConfig + hub_api = HubApi(token=token) + hub_api.login() + username = ModelScopeConfig.get_user_info()[0] + except Exception as e: + logger.error(f'Failed to get username from ModelScope: {e}') + raise HTTPException( + status_code=401, + detail='Failed to get username from ModelScope. 
Please ensure your token is valid.') + + checkpoint_name = checkpoint_id.split('/')[-1] + hub_model_id = f'{username}/{run_id}_{checkpoint_name}' + HubOperation.async_push_to_hub(repo_id=hub_model_id, folder_path=checkpoint_dir, token=token, private=True) + + return Response(status_code=204) + + # --- Model Proxy Endpoints --- + + @app.post('/tinker/create_model') + async def create_model(self, request: Request, body: types.CreateModelRequest) -> Any: + self._validate_base_model(body.base_model) + return await self.proxy.proxy_to_model(request, 'create_model', body.base_model) + + @app.post('/tinker/get_info') + async def get_info(self, request: Request, body: types.GetInfoRequest) -> Any: + return await self.proxy.proxy_to_model(request, 'get_info', self._get_base_model(body.model_id)) + + @app.post('/tinker/unload_model') + async def unload_model(self, request: Request, body: types.UnloadModelRequest) -> Any: + return await self.proxy.proxy_to_model(request, 'unload_model', self._get_base_model(body.model_id)) + + @app.post('/tinker/forward') + async def forward(self, request: Request, body: types.ForwardRequest) -> Any: + return await self.proxy.proxy_to_model(request, 'forward', self._get_base_model(body.model_id)) + + @app.post('/tinker/forward_backward') + async def forward_backward(self, request: Request, body: types.ForwardBackwardRequest) -> Any: + return await self.proxy.proxy_to_model(request, 'forward_backward', self._get_base_model(body.model_id)) + + @app.post('/tinker/optim_step') + async def optim_step(self, request: Request, body: types.OptimStepRequest) -> Any: + return await self.proxy.proxy_to_model(request, 'optim_step', self._get_base_model(body.model_id)) + + @app.post('/tinker/save_weights') + async def save_weights(self, request: Request, body: types.SaveWeightsRequest) -> Any: + return await self.proxy.proxy_to_model(request, 'save_weights', self._get_base_model(body.model_id)) + + @app.post('/tinker/load_weights') + async def 
load_weights(self, request: Request, body: types.LoadWeightsRequest) -> Any: + return await self.proxy.proxy_to_model(request, 'load_weights', self._get_base_model(body.model_id)) + + # --- Sampler Proxy Endpoints --- + + @app.post('/tinker/asample') + async def asample(self, request: Request, body: types.SampleRequest) -> Any: + base_model = body.base_model + if not base_model and body.sampling_session_id: + session = self.state.get_sampling_session(body.sampling_session_id) + if session: + base_model = session.get('base_model') + return await self.proxy.proxy_to_sampler(request, 'asample', base_model) + + @app.post('/tinker/save_weights_for_sampler') + async def save_weights_for_sampler(self, request: Request, body: types.SaveWeightsForSamplerRequest) -> Any: + return await self.proxy.proxy_to_model(request, 'save_weights_for_sampler', + self._get_base_model(body.model_id)) diff --git a/src/twinkle/server/gateway/tinker_router.py b/src/twinkle/server/gateway/tinker_router.py deleted file mode 100644 index f4290587..00000000 --- a/src/twinkle/server/gateway/tinker_router.py +++ /dev/null @@ -1,289 +0,0 @@ -# Copyright (c) ModelScope Contributors. All rights reserved. -""" -Tinker-compatible gateway router. - -Provides all tinker management and proxy endpoints under /tinker/* prefix. -Extracted from tinker/server.py — same endpoint logic, now on an APIRouter. 
-""" -from __future__ import annotations - -import asyncio -import os -from fastapi import APIRouter, HTTPException, Request, Response -from tinker import types -from typing import Any - -from twinkle.hub import HubOperation -from twinkle.server.common.io_utils import create_checkpoint_manager, create_training_run_manager -from twinkle.server.utils.task_queue import QueueState -from twinkle.server.utils.validation import get_token_from_request -from twinkle.utils.logger import get_logger - -logger = get_logger() - -tinker_router = APIRouter() - - -@tinker_router.get('/healthz') -async def healthz(request: Request) -> types.HealthResponse: - return types.HealthResponse(status='ok') - - -@tinker_router.get('/get_server_capabilities') -async def get_server_capabilities(request: Request) -> types.GetServerCapabilitiesResponse: - # GatewayServer injects self.supported_models via request.state in middleware - supported_models = getattr(request.state, 'supported_models', []) - return types.GetServerCapabilitiesResponse(supported_models=supported_models) - - -@tinker_router.post('/telemetry') -async def telemetry(request: Request, body: types.TelemetrySendRequest) -> types.TelemetryResponse: - return types.TelemetryResponse(status='accepted') - - -@tinker_router.post('/create_session') -async def create_session(request: Request, body: types.CreateSessionRequest) -> types.CreateSessionResponse: - state = request.state.server_state - session_id = state.create_session(body.model_dump()) - return types.CreateSessionResponse(session_id=session_id) - - -@tinker_router.post('/session_heartbeat') -async def session_heartbeat(request: Request, body: types.SessionHeartbeatRequest) -> types.SessionHeartbeatResponse: - state = request.state.server_state - alive = state.touch_session(body.session_id) - if not alive: - raise HTTPException(status_code=404, detail='Unknown session') - return types.SessionHeartbeatResponse() - - -@tinker_router.post('/create_sampling_session') -async def 
create_sampling_session(request: Request, - body: types.CreateSamplingSessionRequest) -> types.CreateSamplingSessionResponse: - state = request.state.server_state - sampling_session_id = state.create_sampling_session(body.model_dump()) - return types.CreateSamplingSessionResponse(sampling_session_id=sampling_session_id) - - -@tinker_router.post('/retrieve_future') -async def retrieve_future(request: Request, body: types.FutureRetrieveRequest) -> Any: - """Retrieve the result of an async task with long polling.""" - state = request.state.server_state - request_id = body.request_id - max_wait = float(os.environ.get('TWINKLE_LONG_POLL_TIMEOUT', '30')) - poll_interval = float(os.environ.get('TWINKLE_POLL_INTERVAL', '0.5')) - start = asyncio.get_event_loop().time() - - while True: - record = state.get_future(request_id) - - if record is None: - return {'type': 'try_again'} - - status = record.get('status') - if status not in ('pending', 'queued', 'running', 'rate_limited'): - break - - if asyncio.get_event_loop().time() - start >= max_wait: - response_data = {'type': 'try_again'} - if queue_state := record.get('queue_state'): - response_data['queue_state'] = queue_state - if queue_state_reason := record.get('queue_state_reason'): - response_data['queue_state_reason'] = queue_state_reason - return response_data - - await asyncio.sleep(poll_interval) - - record = state.get_future(request_id) - if not record: - return {'type': 'try_again'} - - status = record.get('status') - - if status == 'rate_limited': - return { - 'type': 'try_again', - 'queue_state': QueueState.PAUSED_RATE_LIMIT.value, - 'queue_state_reason': record.get('reason', 'Rate limit exceeded') - } - - if status == 'failed': - result = record.get('result', {}) - return {'error': result.get('error', 'Unknown error'), 'category': result.get('category', 'Server')} - - result = record.get('result') - if result is None: - raise HTTPException(status_code=500, detail='Task completed but no result found') - - if 
hasattr(result, 'model_dump'): - return result.model_dump() - return result - - -# --- Training Runs Endpoints --- - - -@tinker_router.get('/training_runs') -async def get_training_runs(request: Request, limit: int = 20, offset: int = 0) -> types.TrainingRunsResponse: - token = get_token_from_request(request) - training_run_manager = create_training_run_manager(token, client_type='tinker') - return training_run_manager.list_runs(limit=limit, offset=offset) - - -@tinker_router.get('/training_runs/{run_id}') -async def get_training_run(request: Request, run_id: str) -> types.TrainingRun: - token = get_token_from_request(request) - training_run_manager = create_training_run_manager(token, client_type='tinker') - run = training_run_manager.get(run_id) - if not run: - raise HTTPException(status_code=404, detail=f'Training run {run_id} not found') - return run - - -@tinker_router.get('/training_runs/{run_id}/checkpoints') -async def get_run_checkpoints(request: Request, run_id: str) -> types.CheckpointsListResponse: - token = get_token_from_request(request) - checkpoint_manager = create_checkpoint_manager(token, client_type='tinker') - response = checkpoint_manager.list_checkpoints(run_id) - if not response: - raise HTTPException(status_code=404, detail=f'Training run {run_id} not found') - return response - - -@tinker_router.delete('/training_runs/{run_id}/checkpoints/{checkpoint_id:path}') -async def delete_run_checkpoint(request: Request, run_id: str, checkpoint_id: str) -> Any: - token = get_token_from_request(request) - checkpoint_manager = create_checkpoint_manager(token, client_type='tinker') - success = checkpoint_manager.delete(run_id, checkpoint_id) - if not success: - raise HTTPException(status_code=404, detail=f'Checkpoint {checkpoint_id} not found for run {run_id}') - return None - - -@tinker_router.post('/weights_info') -async def weights_info(request: Request, body: dict[str, Any]) -> types.WeightsInfoResponse: - token = get_token_from_request(request) - 
checkpoint_manager = create_checkpoint_manager(token, client_type='tinker') - tinker_path = body.get('tinker_path') - response = checkpoint_manager.get_weights_info(tinker_path) - if not response: - raise HTTPException(status_code=404, detail=f'Weights at {tinker_path} not found') - return response - - -@tinker_router.post('/training_runs/{run_id}/checkpoints/{checkpoint_id:path}/publish') -async def publish_checkpoint(request: Request, run_id: str, checkpoint_id: str) -> Response: - token = get_token_from_request(request) - modelscope_config_lock = request.state.modelscope_config_lock - - training_run_manager = create_training_run_manager(token, client_type='tinker') - checkpoint_manager = create_checkpoint_manager(token, client_type='tinker') - - run = training_run_manager.get(run_id) - if not run: - raise HTTPException(status_code=404, detail=f'Training run {run_id} not found or access denied') - - checkpoint = checkpoint_manager.get(run_id, checkpoint_id) - if not checkpoint: - raise HTTPException(status_code=404, detail=f'Checkpoint {checkpoint_id} not found') - - checkpoint_dir = str(checkpoint_manager.get_ckpt_dir(run_id, checkpoint_id)) - - async with modelscope_config_lock: - try: - from modelscope.hub.api import HubApi, ModelScopeConfig - hub_api = HubApi(token=token) - hub_api.login() - username = ModelScopeConfig.get_user_info()[0] - except Exception as e: - logger.error(f'Failed to get username from ModelScope: {e}') - raise HTTPException( - status_code=401, detail='Failed to get username from ModelScope. 
Please ensure your token is valid.') - - checkpoint_name = checkpoint_id.split('/')[-1] - hub_model_id = f'{username}/{run_id}_{checkpoint_name}' - HubOperation.async_push_to_hub(repo_id=hub_model_id, folder_path=checkpoint_dir, token=token, private=True) - - return Response(status_code=204) - - -# --- Model Proxy Endpoints --- - - -@tinker_router.post('/create_model') -async def create_model(request: Request, body: types.CreateModelRequest) -> Any: - proxy = request.state.proxy - validate_base_model = request.state.validate_base_model - validate_base_model(body.base_model) - return await proxy.proxy_to_model(request, 'create_model', body.base_model) - - -@tinker_router.post('/get_info') -async def get_info(request: Request, body: types.GetInfoRequest) -> Any: - proxy = request.state.proxy - get_base_model = request.state.get_base_model - return await proxy.proxy_to_model(request, 'get_info', get_base_model(body.model_id)) - - -@tinker_router.post('/unload_model') -async def unload_model(request: Request, body: types.UnloadModelRequest) -> Any: - proxy = request.state.proxy - get_base_model = request.state.get_base_model - return await proxy.proxy_to_model(request, 'unload_model', get_base_model(body.model_id)) - - -@tinker_router.post('/forward') -async def forward(request: Request, body: types.ForwardRequest) -> Any: - proxy = request.state.proxy - get_base_model = request.state.get_base_model - return await proxy.proxy_to_model(request, 'forward', get_base_model(body.model_id)) - - -@tinker_router.post('/forward_backward') -async def forward_backward(request: Request, body: types.ForwardBackwardRequest) -> Any: - proxy = request.state.proxy - get_base_model = request.state.get_base_model - return await proxy.proxy_to_model(request, 'forward_backward', get_base_model(body.model_id)) - - -@tinker_router.post('/optim_step') -async def optim_step(request: Request, body: types.OptimStepRequest) -> Any: - proxy = request.state.proxy - get_base_model = 
request.state.get_base_model - return await proxy.proxy_to_model(request, 'optim_step', get_base_model(body.model_id)) - - -@tinker_router.post('/save_weights') -async def save_weights(request: Request, body: types.SaveWeightsRequest) -> Any: - proxy = request.state.proxy - get_base_model = request.state.get_base_model - return await proxy.proxy_to_model(request, 'save_weights', get_base_model(body.model_id)) - - -@tinker_router.post('/load_weights') -async def load_weights(request: Request, body: types.LoadWeightsRequest) -> Any: - proxy = request.state.proxy - get_base_model = request.state.get_base_model - return await proxy.proxy_to_model(request, 'load_weights', get_base_model(body.model_id)) - - -# --- Sampler Proxy Endpoints --- - - -@tinker_router.post('/asample') -async def asample(request: Request, body: types.SampleRequest) -> Any: - proxy = request.state.proxy - state = request.state.server_state - base_model = body.base_model - if not base_model and body.sampling_session_id: - session = state.get_sampling_session(body.sampling_session_id) - if session: - base_model = session.get('base_model') - return await proxy.proxy_to_sampler(request, 'asample', base_model) - - -@tinker_router.post('/save_weights_for_sampler') -async def save_weights_for_sampler(request: Request, body: types.SaveWeightsForSamplerRequest) -> Any: - proxy = request.state.proxy - get_base_model = request.state.get_base_model - return await proxy.proxy_to_model(request, 'save_weights_for_sampler', get_base_model(body.model_id)) diff --git a/src/twinkle/server/gateway/twinkle_gateway_handlers.py b/src/twinkle/server/gateway/twinkle_gateway_handlers.py new file mode 100644 index 00000000..4d19cbf1 --- /dev/null +++ b/src/twinkle/server/gateway/twinkle_gateway_handlers.py @@ -0,0 +1,105 @@ +# Copyright (c) ModelScope Contributors. All rights reserved. +""" +Twinkle-native gateway handler mixin. + +All endpoints are prefixed /twinkle/* and registered via _register_twinkle_routes(app). 
+Route closures use self.* directly (no request.state injection needed). +""" +from __future__ import annotations + +from fastapi import FastAPI, HTTPException, Request + +from twinkle.server.common.io_utils import create_checkpoint_manager, create_training_run_manager, validate_user_path +from twinkle.server.utils.validation import get_token_from_request +from twinkle.utils.logger import get_logger +from twinkle_client.types.server import DeleteCheckpointResponse, HealthResponse, WeightsInfoRequest +from twinkle_client.types.training import (CheckpointsListResponse, TrainingRun, TrainingRunsResponse, + WeightsInfoResponse) + +logger = get_logger() + + +class TwinkleGatewayHandlers: + """ + Mixin providing Twinkle-native gateway management endpoints. + + Expects the combined class to have: self.state + """ + + @staticmethod + def _register_twinkle_routes(app: FastAPI): + """Register all /twinkle/* routes on the given FastAPI app.""" + + @app.get('/twinkle/healthz', response_model=HealthResponse) + async def healthz(self, request: Request) -> HealthResponse: + return HealthResponse(status='ok') + + @app.get('/twinkle/training_runs', response_model=TrainingRunsResponse) + async def get_training_runs(self, request: Request, limit: int = 20, offset: int = 0) -> TrainingRunsResponse: + token = get_token_from_request(request) + training_run_manager = create_training_run_manager(token, client_type='twinkle') + return training_run_manager.list_runs(limit=limit, offset=offset) + + @app.get('/twinkle/training_runs/{run_id}', response_model=TrainingRun) + async def get_training_run(self, request: Request, run_id: str) -> TrainingRun: + token = get_token_from_request(request) + training_run_manager = create_training_run_manager(token, client_type='twinkle') + run = training_run_manager.get_with_permission(run_id) + if not run: + raise HTTPException(status_code=404, detail=f'Training run {run_id} not found or access denied') + return run + + 
@app.get('/twinkle/training_runs/{run_id}/checkpoints', response_model=CheckpointsListResponse) + async def get_run_checkpoints(self, request: Request, run_id: str) -> CheckpointsListResponse: + token = get_token_from_request(request) + checkpoint_manager = create_checkpoint_manager(token, client_type='twinkle') + response = checkpoint_manager.list_checkpoints(run_id) + if response is None: + raise HTTPException(status_code=404, detail=f'Training run {run_id} not found or access denied') + return response + + @app.delete('/twinkle/training_runs/{run_id}/checkpoints/{checkpoint_id:path}') + async def delete_run_checkpoint(self, request: Request, run_id: str, + checkpoint_id: str) -> DeleteCheckpointResponse: + token = get_token_from_request(request) + + if not validate_user_path(token, checkpoint_id): + raise HTTPException(status_code=400, detail='Invalid checkpoint path: path traversal not allowed') + + checkpoint_manager = create_checkpoint_manager(token, client_type='twinkle') + success = checkpoint_manager.delete(run_id, checkpoint_id) + if not success: + raise HTTPException(status_code=404, detail=f'Checkpoint {checkpoint_id} not found or access denied') + + return DeleteCheckpointResponse(success=True, message=f'Checkpoint {checkpoint_id} deleted successfully') + + @app.post('/twinkle/weights_info', response_model=WeightsInfoResponse) + async def weights_info(self, request: Request, body: WeightsInfoRequest) -> WeightsInfoResponse: + token = get_token_from_request(request) + checkpoint_manager = create_checkpoint_manager(token, client_type='twinkle') + response = checkpoint_manager.get_weights_info(body.twinkle_path) + if response is None: + raise HTTPException( + status_code=404, detail=f'Weights at {body.twinkle_path} not found or access denied') + return response + + @app.get('/twinkle/checkpoint_path/{run_id}/{checkpoint_id:path}') + async def get_checkpoint_path(self, request: Request, run_id: str, checkpoint_id: str) -> dict[str, str]: + token = 
get_token_from_request(request) + + if not validate_user_path(token, checkpoint_id): + raise HTTPException(status_code=400, detail='Invalid checkpoint path: path traversal not allowed') + + training_run_manager = create_training_run_manager(token, client_type='twinkle') + checkpoint_manager = create_checkpoint_manager(token, client_type='twinkle') + + run = training_run_manager.get(run_id) + if not run: + raise HTTPException(status_code=404, detail=f'Training run {run_id} not found or access denied') + + checkpoint = checkpoint_manager.get(run_id, checkpoint_id) + if not checkpoint: + raise HTTPException(status_code=404, detail=f'Checkpoint {checkpoint_id} not found') + + ckpt_dir = checkpoint_manager.get_ckpt_dir(run_id, checkpoint_id) + return {'path': str(ckpt_dir), 'twinkle_path': checkpoint.twinkle_path} diff --git a/src/twinkle/server/gateway/twinkle_router.py b/src/twinkle/server/gateway/twinkle_router.py deleted file mode 100644 index 9cd2af75..00000000 --- a/src/twinkle/server/gateway/twinkle_router.py +++ /dev/null @@ -1,106 +0,0 @@ -# Copyright (c) ModelScope Contributors. All rights reserved. -""" -Twinkle-native gateway router. - -Provides all twinkle management endpoints under /twinkle/* prefix. -Extracted from twinkle/server.py — same endpoint logic, now on an APIRouter. 
-""" -from __future__ import annotations - -from fastapi import APIRouter, HTTPException, Request -from pydantic import BaseModel -from typing import Any - -from twinkle.server.common.io_utils import create_checkpoint_manager, create_training_run_manager, validate_user_path -from twinkle.server.utils.validation import get_token_from_request -from twinkle.utils.logger import get_logger -from twinkle_client.types.server import DeleteCheckpointResponse, HealthResponse -from twinkle_client.types.training import (CheckpointsListResponse, TrainingRun, TrainingRunsResponse, - WeightsInfoResponse) - -logger = get_logger() - -twinkle_router = APIRouter() - - -class WeightsInfoRequest(BaseModel): - twinkle_path: str - - -@twinkle_router.get('/healthz', response_model=HealthResponse) -async def healthz(request: Request) -> HealthResponse: - return HealthResponse(status='ok') - - -@twinkle_router.get('/training_runs', response_model=TrainingRunsResponse) -async def get_training_runs(request: Request, limit: int = 20, offset: int = 0) -> TrainingRunsResponse: - token = get_token_from_request(request) - training_run_manager = create_training_run_manager(token, client_type='twinkle') - return training_run_manager.list_runs(limit=limit, offset=offset) - - -@twinkle_router.get('/training_runs/{run_id}', response_model=TrainingRun) -async def get_training_run(request: Request, run_id: str) -> TrainingRun: - token = get_token_from_request(request) - training_run_manager = create_training_run_manager(token, client_type='twinkle') - run = training_run_manager.get_with_permission(run_id) - if not run: - raise HTTPException(status_code=404, detail=f'Training run {run_id} not found or access denied') - return run - - -@twinkle_router.get('/training_runs/{run_id}/checkpoints', response_model=CheckpointsListResponse) -async def get_run_checkpoints(request: Request, run_id: str) -> CheckpointsListResponse: - token = get_token_from_request(request) - checkpoint_manager = 
create_checkpoint_manager(token, client_type='twinkle') - response = checkpoint_manager.list_checkpoints(run_id) - if response is None: - raise HTTPException(status_code=404, detail=f'Training run {run_id} not found or access denied') - return response - - -@twinkle_router.delete('/training_runs/{run_id}/checkpoints/{checkpoint_id:path}') -async def delete_run_checkpoint(request: Request, run_id: str, checkpoint_id: str) -> DeleteCheckpointResponse: - token = get_token_from_request(request) - - if not validate_user_path(token, checkpoint_id): - raise HTTPException(status_code=400, detail='Invalid checkpoint path: path traversal not allowed') - - checkpoint_manager = create_checkpoint_manager(token, client_type='twinkle') - success = checkpoint_manager.delete(run_id, checkpoint_id) - if not success: - raise HTTPException(status_code=404, detail=f'Checkpoint {checkpoint_id} not found or access denied') - - return DeleteCheckpointResponse(success=True, message=f'Checkpoint {checkpoint_id} deleted successfully') - - -@twinkle_router.post('/weights_info', response_model=WeightsInfoResponse) -async def weights_info(request: Request, body: WeightsInfoRequest) -> WeightsInfoResponse: - token = get_token_from_request(request) - checkpoint_manager = create_checkpoint_manager(token, client_type='twinkle') - response = checkpoint_manager.get_weights_info(body.twinkle_path) - if response is None: - raise HTTPException(status_code=404, detail=f'Weights at {body.twinkle_path} not found or access denied') - return response - - -@twinkle_router.get('/checkpoint_path/{run_id}/{checkpoint_id:path}') -async def get_checkpoint_path(request: Request, run_id: str, checkpoint_id: str) -> dict[str, str]: - token = get_token_from_request(request) - - if not validate_user_path(token, checkpoint_id): - raise HTTPException(status_code=400, detail='Invalid checkpoint path: path traversal not allowed') - - training_run_manager = create_training_run_manager(token, client_type='twinkle') - 
checkpoint_manager = create_checkpoint_manager(token, client_type='twinkle') - - run = training_run_manager.get(run_id) - if not run: - raise HTTPException(status_code=404, detail=f'Training run {run_id} not found or access denied') - - checkpoint = checkpoint_manager.get(run_id, checkpoint_id) - if not checkpoint: - raise HTTPException(status_code=404, detail=f'Checkpoint {checkpoint_id} not found') - - ckpt_dir = checkpoint_manager.get_ckpt_dir(run_id, checkpoint_id) - return {'path': str(ckpt_dir), 'twinkle_path': checkpoint.twinkle_path} diff --git a/src/twinkle/server/model/app.py b/src/twinkle/server/model/app.py index e5926f99..9887ace2 100644 --- a/src/twinkle/server/model/app.py +++ b/src/twinkle/server/model/app.py @@ -53,8 +53,6 @@ def build_model_app(model_id: str, Configured Ray Serve deployment bound with parameters """ app = FastAPI() - # Mutable list so inner route functions can capture the model_id - model_id_ref = [model_id] @app.middleware('http') async def verify_token(request: Request, call_next): @@ -149,8 +147,8 @@ def _on_adapter_expired(self, adapter_name: str) -> None: self._cleanup_adapter(adapter_name) # Register routes from both handler mixins - TinkerModelHandlers._register_tinker_routes(app, model_id_ref) - TwinkleModelHandlers._register_twinkle_routes(app, model_id_ref) + TinkerModelHandlers._register_tinker_routes(app, model_id) + TwinkleModelHandlers._register_twinkle_routes(app, model_id) return ModelManagement.options(**deploy_options).bind(nproc_per_node, device_group, device_mesh, use_megatron, queue_config, **kwargs) diff --git a/src/twinkle/server/model/backends/common.py b/src/twinkle/server/model/backends/common.py new file mode 100644 index 00000000..3553d657 --- /dev/null +++ b/src/twinkle/server/model/backends/common.py @@ -0,0 +1,130 @@ +# Copyright (c) ModelScope Contributors. All rights reserved. +""" +Shared helpers and base classes for backend model implementations. 
+""" +import numpy as np +import re +import torch +from numbers import Number +from tinker import types +from typing import List + +from twinkle import DeviceMesh +from twinkle.template import Template + + +def collect_forward_backward_results(results, device_mesh: DeviceMesh): + """Custom collect function for forward_backward that handles list [outputs, loss].""" + if not results: + return results + + pp_last_ranks = None + if device_mesh.pp_world_size > 1: + pp_last_ranks = set(device_mesh.get_pp_last_ranks()) + + tp_last_ranks = None + if device_mesh.tp_world_size > 1: + tp_last_ranks = set(device_mesh.get_tp_last_ranks()) + + mesh_flat = device_mesh.mesh.flatten() + + all_outputs = [] + all_losses = [] + for i, result in enumerate(results): + rank = mesh_flat[i] if i < len(mesh_flat) else -1 + + if pp_last_ranks is not None: + if rank not in pp_last_ranks: + continue + + if tp_last_ranks is not None: + if rank not in tp_last_ranks: + continue + + if result is None: + continue + + outputs, loss = result + if outputs is None or loss is None: + continue + all_outputs.extend(outputs) + all_losses.append(loss) + + if all_losses: + avg_loss = float(np.mean(all_losses)) + else: + avg_loss = 0.0 + + return [all_outputs, avg_loss] + + +def clean_metrics(metrics: dict) -> dict: + + def _to_float(v): + if isinstance(v, (float, int, Number, np.generic, str)): + try: + return float(v) + except Exception: + return None + if isinstance(v, torch.Tensor) and v.numel() == 1: + try: + return float(v.item()) + except Exception: + return None + return None + + cleaned = {} + for key, value in metrics.items(): + fv = _to_float(value) + if fv is not None: + cleaned[key] = fv + continue + + if isinstance(value, str): + s = value.strip() + if s: + try: + head, unit = s.split() + cleaned[f'{key}/{unit}'] = float(head) + except Exception: + m = re.match(r'^([+-]?(?:\d+(?:\.\d*)?|\.\d+)(?:[eE][+-]?\d+)?)', s) + if m: + cleaned[key] = float(m.group(1)) + + return cleaned + + +class 
TwinkleCompatModelBase: + """Base class containing common logic for Twinkle compatibility wrappers.""" + + def get_template(self, adapter_name: str) -> Template: + return self.optimizer_group[adapter_name].template + + @staticmethod + def _get_forward_output(inputs: List[types.Datum], logits: torch.Tensor, logps: torch.Tensor) -> List[dict]: + """Convert raw logits to the expected output format with logprobs and elementwise_loss.""" + from twinkle.utils.torch_utils import selective_log_softmax + device = logits.device if logits is not None else logps.device + results = [] + if logits is None: + logits = [None] * len(inputs) + for idx, (feature, logit) in enumerate(zip(inputs, logits)): + labels = feature.loss_fn_inputs['target_tokens'].to_torch().long().view(-1).to(device) + weights = feature.loss_fn_inputs['weights'].to_torch().view(-1).to(device) + + seq_len = labels.numel() + + if logps is None: + assert logits is not None + feature_logits = logit[:seq_len, :] + token_log_probs = selective_log_softmax(feature_logits, labels) + else: + token_log_probs = logps[idx, :seq_len] + + elementwise_loss = -token_log_probs * weights + + results.append({ + 'logprobs': types.TensorData.from_torch(token_log_probs.cpu()), + 'elementwise_loss': types.TensorData.from_torch(elementwise_loss.cpu()) + }) + return results diff --git a/src/twinkle/server/model/backends/megatron_model.py b/src/twinkle/server/model/backends/megatron_model.py index 61868247..29570c42 100644 --- a/src/twinkle/server/model/backends/megatron_model.py +++ b/src/twinkle/server/model/backends/megatron_model.py @@ -9,8 +9,7 @@ from twinkle import remote_class, remote_function from twinkle.server.common.datum import datum_to_input_feature, extract_rl_feature -from twinkle.server.model.backends.transformers_model import (TwinkleCompatModelBase, clean_metrics, - collect_forward_backward_results) +from twinkle.server.model.backends.common import TwinkleCompatModelBase, clean_metrics, 
collect_forward_backward_results from twinkle.utils import exists, requires if TYPE_CHECKING: diff --git a/src/twinkle/server/model/backends/transformers_model.py b/src/twinkle/server/model/backends/transformers_model.py index 2802741c..ef4194bf 100644 --- a/src/twinkle/server/model/backends/transformers_model.py +++ b/src/twinkle/server/model/backends/transformers_model.py @@ -2,11 +2,9 @@ """ Backend model implementations for the unified model deployment. -Contains two classes: -- TwinkleCompatTransformersModel: tinker-compat wrapper (Datum-based I/O), - moved from tinker/common/transformers_model.py. -- TwinkleCompatTransformersModelNative: twinkle-native wrapper - (InputFeature/Trajectory-based I/O), moved from twinkle/common/transformers_model.py. +Contains one unified class: +- TwinkleCompatTransformersModel: handles both tinker (Datum-based I/O) via /tinker/* + endpoints and twinkle-native (InputFeature/Trajectory-based I/O) via /twinkle/* endpoints. """ import numpy as np import torch @@ -14,147 +12,51 @@ from tinker import types from typing import Any, List, Union -# --------------------------------------------------------------------------- -# Shared helpers (moved from tinker/common/compat_base.py) -# --------------------------------------------------------------------------- -from twinkle import DeviceMesh, remote_class, remote_function +from twinkle import remote_class, remote_function from twinkle.data_format import InputFeature, Trajectory from twinkle.model import MultiLoraTransformersModel from twinkle.server.common.datum import datum_to_input_feature, extract_rl_feature -from twinkle.template import Template +from twinkle.server.model.backends.common import TwinkleCompatModelBase, clean_metrics, collect_forward_backward_results -def collect_forward_backward_results(results, device_mesh: DeviceMesh): - """Custom collect function for forward_backward that handles list [outputs, loss].""" - if not results: - return results - - pp_last_ranks = None - 
if device_mesh.pp_world_size > 1: - pp_last_ranks = set(device_mesh.get_pp_last_ranks()) - - tp_last_ranks = None - if device_mesh.tp_world_size > 1: - tp_last_ranks = set(device_mesh.get_tp_last_ranks()) - - mesh_flat = device_mesh.mesh.flatten() - - all_outputs = [] - all_losses = [] - for i, result in enumerate(results): - rank = mesh_flat[i] if i < len(mesh_flat) else -1 - - if pp_last_ranks is not None: - if rank not in pp_last_ranks: - continue - - if tp_last_ranks is not None: - if rank not in tp_last_ranks: - continue - - if result is None: - continue - - outputs, loss = result - if outputs is None or loss is None: - continue - all_outputs.extend(outputs) - all_losses.append(loss) - - if all_losses: - avg_loss = float(np.mean(all_losses)) - else: - avg_loss = 0.0 - - return [all_outputs, avg_loss] - - -def clean_metrics(metrics: dict) -> dict: - import re - from numbers import Number - - def _to_float(v): - if isinstance(v, (float, int, Number, np.generic, str)): - try: - return float(v) - except Exception: - return None - if isinstance(v, torch.Tensor) and v.numel() == 1: - try: - return float(v.item()) - except Exception: - return None - return None - - cleaned = {} - for key, value in metrics.items(): - fv = _to_float(value) - if fv is not None: - cleaned[key] = fv - continue - - if isinstance(value, str): - s = value.strip() - if s: - try: - head, unit = s.split() - cleaned[f'{key}/{unit}'] = float(head) - except Exception: - m = re.match(r'^([+-]?(?:\d+(?:\.\d*)?|\.\d+)(?:[eE][+-]?\d+)?)', s) - if m: - cleaned[key] = float(m.group(1)) - - return cleaned - +@remote_class() +class TwinkleCompatTransformersModel(MultiLoraTransformersModel, TwinkleCompatModelBase): + """Unified wrapper around MultiLoraTransformersModel. -class TwinkleCompatModelBase: - """Base class containing common logic for Twinkle compatibility wrappers.""" + Handles both: + - Tinker-compat I/O (Datum / TensorData) via /tinker/* endpoints. 
+ - Twinkle-native I/O (InputFeature / Trajectory) via /twinkle/* endpoints. + """ - def get_template(self, adapter_name: str) -> Template: - return self.optimizer_group[adapter_name].template + # ------------------------------------------------------------------ + # Shared helper: CPU-safe serialisation for HTTP transport + # ------------------------------------------------------------------ @staticmethod - def _get_forward_output(inputs: List[types.Datum], logits: torch.Tensor, logps: torch.Tensor) -> List[dict]: - """Convert raw logits to the expected output format with logprobs and elementwise_loss.""" - from twinkle.utils.torch_utils import selective_log_softmax - device = logits.device if logits is not None else logps.device - results = [] - if logits is None: - logits = [None] * len(inputs) - for idx, (feature, logit) in enumerate(zip(inputs, logits)): - labels = feature.loss_fn_inputs['target_tokens'].to_torch().long().view(-1).to(device) - weights = feature.loss_fn_inputs['weights'].to_torch().view(-1).to(device) - - seq_len = labels.numel() - - if logps is None: - assert logits is not None - feature_logits = logit[:seq_len, :] - token_log_probs = selective_log_softmax(feature_logits, labels) - else: - token_log_probs = logps[idx, :seq_len] - - elementwise_loss = -token_log_probs * weights - - results.append({ - 'logprobs': types.TensorData.from_torch(token_log_probs.cpu()), - 'elementwise_loss': types.TensorData.from_torch(elementwise_loss.cpu()) - }) - return results - - -# --------------------------------------------------------------------------- -# Tinker-compat Transformers model (Datum-based I/O) -# --------------------------------------------------------------------------- - + def _to_cpu_safe_output(obj: Any) -> Any: + """Convert nested outputs into CPU-safe Python objects for HTTP transport.""" + from twinkle.utils import torch_util -@remote_class() -class TwinkleCompatTransformersModel(MultiLoraTransformersModel, TwinkleCompatModelBase): - 
"""Tinker-compatible wrapper around MultiLoraTransformersModel. + if isinstance(obj, torch.Tensor): + tensor = torch_util.to_local_tensor(obj).detach().cpu() + if tensor.numel() == 1: + return tensor.item() + return tensor.tolist() + if isinstance(obj, np.ndarray): + if obj.size == 1: + return obj.item() + return obj.tolist() + if isinstance(obj, np.generic): + return obj.item() + if isinstance(obj, Mapping): + return {key: TwinkleCompatTransformersModel._to_cpu_safe_output(value) for key, value in obj.items()} + if isinstance(obj, (list, tuple)): + return [TwinkleCompatTransformersModel._to_cpu_safe_output(value) for value in obj] + return obj - Input/output is in tinker Datum / TensorData format. - Moved from tinker/common/transformers_model.py. - """ + # ------------------------------------------------------------------ + # Tinker-compat methods (Datum-based I/O) + # ------------------------------------------------------------------ @remote_function(dispatch='slice_dp', collect='flatten') def forward_only(self, *, inputs: List[types.Datum], **kwargs): @@ -224,44 +126,13 @@ def load(self, checkpoint_dir: str, **kwargs): else: return super().load(name=resolved.checkpoint_name, **kwargs) - -# --------------------------------------------------------------------------- -# Twinkle-native Transformers model (InputFeature/Trajectory-based I/O) -# --------------------------------------------------------------------------- - - -@remote_class() -class TwinkleCompatTransformersModelNative(MultiLoraTransformersModel): - """Twinkle-native wrapper around MultiLoraTransformersModel. - - Input/output is in native InputFeature / Trajectory format. - Moved from twinkle/common/transformers_model.py. 
- """ - - @staticmethod - def _to_cpu_safe_output(obj: Any) -> Any: - """Convert nested outputs into CPU-safe Python objects for HTTP transport.""" - from twinkle.utils import torch_util - - if isinstance(obj, torch.Tensor): - tensor = torch_util.to_local_tensor(obj).detach().cpu() - if tensor.numel() == 1: - return tensor.item() - return tensor.tolist() - if isinstance(obj, np.ndarray): - if obj.size == 1: - return obj.item() - return obj.tolist() - if isinstance(obj, np.generic): - return obj.item() - if isinstance(obj, Mapping): - return {key: TwinkleCompatTransformersModelNative._to_cpu_safe_output(value) for key, value in obj.items()} - if isinstance(obj, (list, tuple)): - return [TwinkleCompatTransformersModelNative._to_cpu_safe_output(value) for value in obj] - return obj + # ------------------------------------------------------------------ + # Twinkle-native methods (InputFeature/Trajectory-based I/O) + # ------------------------------------------------------------------ @remote_function(dispatch='slice_dp', collect='mean') - def forward_backward(self, *, inputs: Union[InputFeature, List[InputFeature], Trajectory, List[Trajectory]], - **kwargs): + def twinkle_forward_backward(self, *, inputs: Union[InputFeature, List[InputFeature], Trajectory, List[Trajectory]], + **kwargs): + """Forward+backward for twinkle-native clients (InputFeature/Trajectory I/O).""" output = super().forward_backward(inputs=inputs, **kwargs) return self._to_cpu_safe_output(output) diff --git a/src/twinkle/server/model/tinker_handlers.py b/src/twinkle/server/model/tinker_handlers.py index 66280676..3577cd93 100644 --- a/src/twinkle/server/model/tinker_handlers.py +++ b/src/twinkle/server/model/tinker_handlers.py @@ -27,17 +27,15 @@ class TinkerModelHandlers: """ @staticmethod - def _register_tinker_routes(app: FastAPI, model_id_ref: list): + def _register_tinker_routes(app: FastAPI, model_id: str): """Register all tinker routes on the given FastAPI app. 
This is called once during build_model_app to wire routes. - model_id_ref is a mutable list so we can capture the closure variable. """ @app.post('/tinker/create_model') async def create_model(self, request: Request, body: types.CreateModelRequest) -> types.UntypedAPIFuture: token = await self._on_request_start(request) - model_id = model_id_ref[0] async def _create_adapter(): _model_id = None @@ -70,7 +68,6 @@ async def _create_adapter(): @app.post('/tinker/get_info') async def get_info(self, request: Request, body: types.GetInfoRequest) -> types.GetInfoResponse: token = await self._on_request_start(request) - model_id = model_id_ref[0] training_run_manager = create_training_run_manager(token, client_type='tinker') metadata = training_run_manager.get(str(body.model_id)) model_name = metadata.base_model if metadata else model_id diff --git a/src/twinkle/server/model/twinkle_handlers.py b/src/twinkle/server/model/twinkle_handlers.py index e6b2243b..66035051 100644 --- a/src/twinkle/server/model/twinkle_handlers.py +++ b/src/twinkle/server/model/twinkle_handlers.py @@ -49,11 +49,8 @@ class TwinkleModelHandlers: """ @staticmethod - def _register_twinkle_routes(app: FastAPI, model_id_ref: list): - """Register all twinkle routes on the given FastAPI app. - - model_id_ref is a mutable list containing [model_id] for closure capture. 
- """ + def _register_twinkle_routes(app: FastAPI, model_id: str): + """Register all twinkle routes on the given FastAPI app.""" @app.post('/twinkle/create') async def create(self, request: Request, body: CreateRequest): @@ -123,7 +120,7 @@ async def _task(): self.assert_adapter_exists(adapter_name=adapter_name) extra_kwargs = body.model_extra or {} inputs = _parse_inputs(body.inputs) - ret = self.model.forward_backward(inputs=inputs, adapter_name=adapter_name, **extra_kwargs) + ret = self.model.twinkle_forward_backward(inputs=inputs, adapter_name=adapter_name, **extra_kwargs) return {'result': ret} return await self.schedule_task_and_wait(_task, task_type='forward_backward') @@ -226,12 +223,12 @@ async def _task(): @app.post('/twinkle/save') async def save(self, request: Request, body: SaveRequest): + token = await self._on_request_start(request) adapter_name = self._get_twinkle_adapter_name(request, body.adapter_name) async def _task(): self.assert_adapter_exists(adapter_name=adapter_name) extra_kwargs = body.model_extra or {} - token = request.state.token checkpoint_manager = create_checkpoint_manager(token, client_type='twinkle') checkpoint_name = checkpoint_manager.get_ckpt_name(body.name) save_dir = checkpoint_manager.get_save_dir(model_id=adapter_name, is_sampler=False) @@ -248,12 +245,12 @@ async def _task(): @app.post('/twinkle/load') async def load(self, request: Request, body: LoadRequest): + token = await self._on_request_start(request) adapter_name = self._get_twinkle_adapter_name(request, body.adapter_name) async def _task(): self.assert_adapter_exists(adapter_name=adapter_name) extra_kwargs = body.model_extra or {} - token = request.state.token checkpoint_manager = create_checkpoint_manager(token, client_type='twinkle') resolved = checkpoint_manager.resolve_load_path(body.name) ret = self.model.load( @@ -269,9 +266,9 @@ async def _task(): @app.post('/twinkle/upload_to_hub') async def upload_to_hub(self, request: Request, body: UploadToHubRequest): + 
token = await self._on_request_start(request) async def _task(): - token = request.state.token if body.checkpoint_dir.startswith('twinkle://'): checkpoint_manager = create_checkpoint_manager(token, client_type='twinkle') parsed = checkpoint_manager.parse_twinkle_path(body.checkpoint_dir) @@ -298,13 +295,12 @@ async def _task(): @app.post('/twinkle/add_adapter_to_model') async def add_adapter_to_model(self, request: Request, body: AddAdapterRequest): assert body.adapter_name, 'You need to specify a valid `adapter_name`' + token = await self._on_request_start(request) adapter_name = self._get_twinkle_adapter_name(request, body.adapter_name) - model_id = model_id_ref[0] async def _task(): config = deserialize_object(body.config) extra_kwargs = body.model_extra or {} - token = request.state.token training_run_manager = create_training_run_manager(token, client_type='twinkle') self.register_adapter(adapter_name, token) self.model.add_adapter_to_model(adapter_name, config, **extra_kwargs) diff --git a/src/twinkle_client/types/server.py b/src/twinkle_client/types/server.py index f9e79e7b..058da8d8 100644 --- a/src/twinkle_client/types/server.py +++ b/src/twinkle_client/types/server.py @@ -14,3 +14,7 @@ class DeleteCheckpointResponse(BaseModel): class ErrorResponse(BaseModel): detail: str + + +class WeightsInfoRequest(BaseModel): + twinkle_path: str From df5c3a0837196a38b7da31e783d7f60309845f42 Mon Sep 17 00:00:00 2001 From: Yunnglin Date: Tue, 10 Mar 2026 17:57:01 +0800 Subject: [PATCH 03/24] update refact cookbook --- .../megatron/server.py | 0 .../megatron}/server_config.yaml | 3 - .../megatron/server_config_4b.yaml} | 3 - .../transformer/server.py | 0 .../transformer/server_config.yaml | 3 - .../sample.py | 0 .../self_cognition.py | 0 .../short_math_grpo.py | 0 .../tinker/modelscope_service/server.py | 21 --- .../{custom_service => self_host}/lora.py | 0 .../{custom_service => self_host}/sample.py | 0 .../self_cognition.py | 0 .../short_math_grpo.py | 0 
cookbook/client/twinkle/megatron/server.py | 20 --- .../twinkle/megatron/server_config.yaml | 85 ------------ .../client/twinkle/{ => self_host}/grpo.py | 0 .../client/twinkle/{ => self_host}/sample.py | 0 .../{ => self_host}/self_congnition.py | 0 cookbook/client/twinkle/transformer/server.py | 20 --- .../twinkle/transformer/server_config.yaml | 123 ------------------ .../server/gateway/tinker_gateway_handlers.py | 50 +++---- 21 files changed, 25 insertions(+), 303 deletions(-) rename cookbook/client/{tinker/custom_service => server}/megatron/server.py (100%) rename cookbook/client/{tinker/modelscope_service => server/megatron}/server_config.yaml (98%) rename cookbook/client/{tinker/custom_service/megatron/server_config.yaml => server/megatron/server_config_4b.yaml} (98%) rename cookbook/client/{tinker/custom_service => server}/transformer/server.py (100%) rename cookbook/client/{tinker/custom_service => server}/transformer/server_config.yaml (97%) rename cookbook/client/tinker/{modelscope_service => modelscope}/sample.py (100%) rename cookbook/client/tinker/{modelscope_service => modelscope}/self_cognition.py (100%) rename cookbook/client/tinker/{modelscope_service => modelscope}/short_math_grpo.py (100%) delete mode 100644 cookbook/client/tinker/modelscope_service/server.py rename cookbook/client/tinker/{custom_service => self_host}/lora.py (100%) rename cookbook/client/tinker/{custom_service => self_host}/sample.py (100%) rename cookbook/client/tinker/{custom_service => self_host}/self_cognition.py (100%) rename cookbook/client/tinker/{custom_service => self_host}/short_math_grpo.py (100%) delete mode 100644 cookbook/client/twinkle/megatron/server.py delete mode 100644 cookbook/client/twinkle/megatron/server_config.yaml rename cookbook/client/twinkle/{ => self_host}/grpo.py (100%) rename cookbook/client/twinkle/{ => self_host}/sample.py (100%) rename cookbook/client/twinkle/{ => self_host}/self_congnition.py (100%) delete mode 100644 
cookbook/client/twinkle/transformer/server.py delete mode 100644 cookbook/client/twinkle/transformer/server_config.yaml diff --git a/cookbook/client/tinker/custom_service/megatron/server.py b/cookbook/client/server/megatron/server.py similarity index 100% rename from cookbook/client/tinker/custom_service/megatron/server.py rename to cookbook/client/server/megatron/server.py diff --git a/cookbook/client/tinker/modelscope_service/server_config.yaml b/cookbook/client/server/megatron/server_config.yaml similarity index 98% rename from cookbook/client/tinker/modelscope_service/server_config.yaml rename to cookbook/client/server/megatron/server_config.yaml index 18b0c1d2..7b2c9768 100644 --- a/cookbook/client/tinker/modelscope_service/server_config.yaml +++ b/cookbook/client/server/megatron/server_config.yaml @@ -1,8 +1,5 @@ # Twinkle Server Configuration - Tinker-Compatible Transformers Backend -# Server protocol type: "tinker" enables the Tinker-compatible API -server_type: tinker - # proxy_location: determines where the HTTP proxy runs. # "EveryNode" means each Ray node runs its own proxy (good for multi-node). proxy_location: EveryNode diff --git a/cookbook/client/tinker/custom_service/megatron/server_config.yaml b/cookbook/client/server/megatron/server_config_4b.yaml similarity index 98% rename from cookbook/client/tinker/custom_service/megatron/server_config.yaml rename to cookbook/client/server/megatron/server_config_4b.yaml index a8103b76..12dcc68f 100644 --- a/cookbook/client/tinker/custom_service/megatron/server_config.yaml +++ b/cookbook/client/server/megatron/server_config_4b.yaml @@ -1,8 +1,5 @@ # Twinkle Server Configuration - Tinker-Compatible Transformers Backend -# Server protocol type: "tinker" enables the Tinker-compatible API -server_type: tinker - # proxy_location: determines where the HTTP proxy runs. # "EveryNode" means each Ray node runs its own proxy (good for multi-node). 
proxy_location: EveryNode diff --git a/cookbook/client/tinker/custom_service/transformer/server.py b/cookbook/client/server/transformer/server.py similarity index 100% rename from cookbook/client/tinker/custom_service/transformer/server.py rename to cookbook/client/server/transformer/server.py diff --git a/cookbook/client/tinker/custom_service/transformer/server_config.yaml b/cookbook/client/server/transformer/server_config.yaml similarity index 97% rename from cookbook/client/tinker/custom_service/transformer/server_config.yaml rename to cookbook/client/server/transformer/server_config.yaml index e79ad6f2..c8db0b3d 100644 --- a/cookbook/client/tinker/custom_service/transformer/server_config.yaml +++ b/cookbook/client/server/transformer/server_config.yaml @@ -1,8 +1,5 @@ # Twinkle Server Configuration - Tinker-Compatible Transformers Backend -# Server protocol type: "tinker" enables the Tinker-compatible API -server_type: tinker - # proxy_location: determines where the HTTP proxy runs. # "EveryNode" means each Ray node runs its own proxy (good for multi-node). 
proxy_location: EveryNode diff --git a/cookbook/client/tinker/modelscope_service/sample.py b/cookbook/client/tinker/modelscope/sample.py similarity index 100% rename from cookbook/client/tinker/modelscope_service/sample.py rename to cookbook/client/tinker/modelscope/sample.py diff --git a/cookbook/client/tinker/modelscope_service/self_cognition.py b/cookbook/client/tinker/modelscope/self_cognition.py similarity index 100% rename from cookbook/client/tinker/modelscope_service/self_cognition.py rename to cookbook/client/tinker/modelscope/self_cognition.py diff --git a/cookbook/client/tinker/modelscope_service/short_math_grpo.py b/cookbook/client/tinker/modelscope/short_math_grpo.py similarity index 100% rename from cookbook/client/tinker/modelscope_service/short_math_grpo.py rename to cookbook/client/tinker/modelscope/short_math_grpo.py diff --git a/cookbook/client/tinker/modelscope_service/server.py b/cookbook/client/tinker/modelscope_service/server.py deleted file mode 100644 index e38f43a4..00000000 --- a/cookbook/client/tinker/modelscope_service/server.py +++ /dev/null @@ -1,21 +0,0 @@ -# Twinkle Server Launcher - Tinker-Compatible Megatron Backend -# -# This script starts the Twinkle server with Tinker-compatible API support -# using the Megatron model backend. -# It reads the server_config.yaml in the same directory for all -# configuration (model, deployment settings, etc.). -# Run this script BEFORE running the client training script (lora.py). 
- -import os - -# Enable Ray debug mode for verbose logging during development -os.environ['TWINKLE_TRUST_REMOTE_CODE'] = '1' - -from twinkle.server import launch_server - -# Resolve the path to server_config.yaml relative to this script's location -file_dir = os.path.abspath(os.path.dirname(__file__)) -config_path = os.path.join(file_dir, 'server_config.yaml') - -# Launch the Twinkle server — this call blocks until the server is shut down -launch_server(config_path=config_path) diff --git a/cookbook/client/tinker/custom_service/lora.py b/cookbook/client/tinker/self_host/lora.py similarity index 100% rename from cookbook/client/tinker/custom_service/lora.py rename to cookbook/client/tinker/self_host/lora.py diff --git a/cookbook/client/tinker/custom_service/sample.py b/cookbook/client/tinker/self_host/sample.py similarity index 100% rename from cookbook/client/tinker/custom_service/sample.py rename to cookbook/client/tinker/self_host/sample.py diff --git a/cookbook/client/tinker/custom_service/self_cognition.py b/cookbook/client/tinker/self_host/self_cognition.py similarity index 100% rename from cookbook/client/tinker/custom_service/self_cognition.py rename to cookbook/client/tinker/self_host/self_cognition.py diff --git a/cookbook/client/tinker/custom_service/short_math_grpo.py b/cookbook/client/tinker/self_host/short_math_grpo.py similarity index 100% rename from cookbook/client/tinker/custom_service/short_math_grpo.py rename to cookbook/client/tinker/self_host/short_math_grpo.py diff --git a/cookbook/client/twinkle/megatron/server.py b/cookbook/client/twinkle/megatron/server.py deleted file mode 100644 index 3e58a5a9..00000000 --- a/cookbook/client/twinkle/megatron/server.py +++ /dev/null @@ -1,20 +0,0 @@ -# Twinkle Server Launcher - Megatron Backend -# -# This script starts the Twinkle server using Ray Serve with Megatron support. -# It reads the server_config.yaml in the same directory for all -# configuration (model, processor, deployment settings, etc.). 
-# Run this script BEFORE running the client training script (lora.py). - -import os - -# Enable Ray debug mode for verbose logging during development -os.environ['RAY_DEBUG'] = '1' - -from twinkle.server import launch_server - -# Resolve the path to server_config.yaml relative to this script's location -file_dir = os.path.abspath(os.path.dirname(__file__)) -config_path = os.path.join(file_dir, 'server_config.yaml') - -# Launch the Twinkle server — this call blocks until the server is shut down -launch_server(config_path=config_path) diff --git a/cookbook/client/twinkle/megatron/server_config.yaml b/cookbook/client/twinkle/megatron/server_config.yaml deleted file mode 100644 index c8efe648..00000000 --- a/cookbook/client/twinkle/megatron/server_config.yaml +++ /dev/null @@ -1,85 +0,0 @@ -# Twinkle Server Configuration - Megatron Backend - -# Server protocol type: "twinkle" for the native Twinkle client protocol -server_type: twinkle - -# proxy_location: determines where the HTTP proxy runs. -# "EveryNode" means each Ray node runs its own proxy (good for multi-node). -proxy_location: EveryNode - -# HTTP listener settings -http_options: - host: 0.0.0.0 # Listen on all network interfaces - port: 8000 # Port number for the server - -# Applications: each entry defines a service component deployed on the server -applications: - - # 1. TwinkleServer - The central management server - # Handles client connections, training run tracking, checkpoint listing. 
- - name: server - route_prefix: /server # API endpoint prefix - import_path: server # Python module to import - args: - server_config: - per_token_model_limit: 3 # Maximum number of models (adapters) per token (server-globally enforced) - deployments: - - name: TwinkleServer - autoscaling_config: - min_replicas: 1 # Minimum number of replicas - max_replicas: 1 # Maximum number of replicas - target_ongoing_requests: 128 # Target concurrent requests per replica - ray_actor_options: - num_cpus: 0.1 # CPU resources allocated to this actor - - # 2. Model Service - Hosts the base model for training (Megatron backend) - # This is the actual model worker that performs forward/backward passes. - - name: models-Qwen3.5-4B - route_prefix: /models/Qwen/Qwen3.5-4B # REST path for this model - import_path: model - args: - use_megatron: true # Use Megatron-LM backend (not HuggingFace) - mixed_precision: bf16 - model_id: "ms://Qwen/Qwen3.5-4B" # ModelScope model identifier to load - nproc_per_node: 2 # Number of GPU processes per node - device_group: # Logical device group for this model - name: model - ranks: 2 # Number of GPUs to use - device_type: cuda - device_mesh: # Distributed training mesh configuration - device_type: cuda - dp_size: 2 # Data parallel size - adapter_config: - adapter_timeout: 1800 # Seconds before idle adapter unload - deployments: - - name: ModelManagement - autoscaling_config: - min_replicas: 1 - max_replicas: 1 - target_ongoing_requests: 16 - ray_actor_options: - num_cpus: 0.1 - - # 3. Processor Service - Handles data preprocessing on CPU - # Runs tokenization, template application, and other CPU-bound tasks. 
- - name: processor - route_prefix: /processors - import_path: processor - args: - nproc_per_node: 2 # Number of processor workers per node - ncpu_proc_per_node: 2 # Number of CPU processes per node - device_group: - name: model - ranks: 2 # Number of CPU workers to use - device_type: CPU - device_mesh: - device_type: CPU - dp_size: 2 # Data parallel size - deployments: - - name: ProcessorManagement - autoscaling_config: - min_replicas: 1 - max_replicas: 1 - target_ongoing_requests: 128 - ray_actor_options: - num_cpus: 0.1 diff --git a/cookbook/client/twinkle/grpo.py b/cookbook/client/twinkle/self_host/grpo.py similarity index 100% rename from cookbook/client/twinkle/grpo.py rename to cookbook/client/twinkle/self_host/grpo.py diff --git a/cookbook/client/twinkle/sample.py b/cookbook/client/twinkle/self_host/sample.py similarity index 100% rename from cookbook/client/twinkle/sample.py rename to cookbook/client/twinkle/self_host/sample.py diff --git a/cookbook/client/twinkle/self_congnition.py b/cookbook/client/twinkle/self_host/self_congnition.py similarity index 100% rename from cookbook/client/twinkle/self_congnition.py rename to cookbook/client/twinkle/self_host/self_congnition.py diff --git a/cookbook/client/twinkle/transformer/server.py b/cookbook/client/twinkle/transformer/server.py deleted file mode 100644 index ba84e2dd..00000000 --- a/cookbook/client/twinkle/transformer/server.py +++ /dev/null @@ -1,20 +0,0 @@ -# Twinkle Server Launcher - Transformers Backend -# -# This script starts the Twinkle server using Ray Serve. -# It reads the server_config.yaml in the same directory for all -# configuration (model, processor, deployment settings, etc.). -# Run this script BEFORE running the client training script (lora.py). 
- -import os - -# Enable Ray debug mode for verbose logging during development -os.environ['RAY_DEBUG'] = '1' - -from twinkle.server import launch_server - -# Resolve the path to server_config.yaml relative to this script's location -file_dir = os.path.abspath(os.path.dirname(__file__)) -config_path = os.path.join(file_dir, 'server_config.yaml') - -# Launch the Twinkle server — this call blocks until the server is shut down -launch_server(config_path=config_path) diff --git a/cookbook/client/twinkle/transformer/server_config.yaml b/cookbook/client/twinkle/transformer/server_config.yaml deleted file mode 100644 index e16ced6a..00000000 --- a/cookbook/client/twinkle/transformer/server_config.yaml +++ /dev/null @@ -1,123 +0,0 @@ -# Twinkle Server Configuration - Transformers Backend - -# Server protocol type: "twinkle" for the native Twinkle client protocol -server_type: twinkle - -# proxy_location: determines where the HTTP proxy runs. -# "EveryNode" means each Ray node runs its own proxy (good for multi-node). -proxy_location: EveryNode - -# HTTP listener settings -http_options: - host: 0.0.0.0 # Listen on all network interfaces - port: 8000 # Port number for the server - -# Applications: each entry defines a service component deployed on the server -applications: - - # 1. TwinkleServer - The central management server - # Handles client connections, training run tracking, checkpoint listing. - - name: server - route_prefix: /server # API endpoint prefix - import_path: server # Python module to import - args: - server_config: - per_token_model_limit: 3 # Maximum number of models (adapters) per token (server-globally enforced) - deployments: - - name: TwinkleServer - autoscaling_config: - min_replicas: 1 # Minimum number of replicas - max_replicas: 1 # Maximum number of replicas - target_ongoing_requests: 128 # Target concurrent requests per replica - ray_actor_options: - num_cpus: 0.1 # CPU resources allocated to this actor - - # 2. 
Model Service - Hosts the base model for training - # This is the actual model worker that performs forward/backward passes. - - name: models-Qwen3.5-4B - route_prefix: /models/Qwen/Qwen3.5-4B # REST path for this model - import_path: model - args: - use_megatron: false # Use HuggingFace Transformers (not Megatron) - model_id: "ms://Qwen/Qwen3.5-4B" # ModelScope model identifier to load - adapter_config: - adapter_timeout: 1800 # Seconds before an idle adapter is unloaded - nproc_per_node: 2 # Number of GPU processes per node - device_group: # Logical device group for this model - name: model - ranks: 2 # Number of GPUs to use - device_type: cuda - device_mesh: # Distributed training mesh configuration - device_type: cuda - dp_size: 2 # Mesh dimension names: 'dp' = data parallel - deployments: - - name: ModelManagement - autoscaling_config: - min_replicas: 1 - max_replicas: 1 - target_ongoing_requests: 16 - ray_actor_options: - num_cpus: 0.1 - runtime_env: - env_vars: - TWINKLE_TRUST_REMOTE_CODE: "0" - - # 3. Processor Service - Handles data preprocessing on CPU - # Runs tokenization, template application, and other CPU-bound tasks. - - name: processor - route_prefix: /processors - import_path: processor - args: - nproc_per_node: 2 # Number of processor workers per node - ncpu_proc_per_node: 2 # Number of CPU processes per node - device_group: - name: model - ranks: 2 # Number of CPU workers to use - device_type: CPU - device_mesh: - device_type: CPU - dp_size: 2 # Data parallel size - deployments: - - name: ProcessorManagement - autoscaling_config: - min_replicas: 1 - max_replicas: 1 - target_ongoing_requests: 128 - ray_actor_options: - num_cpus: 0.1 - runtime_env: - env_vars: - TWINKLE_TRUST_REMOTE_CODE: "0" - - # 4. Sampler Service - Handles text generation inference - # Uses vLLM for efficient batched generation with optional LoRA adapters. 
- - name: sampler-Qwen3.5-4B - route_prefix: /samplers/Qwen/Qwen3.5-4B # REST path for this sampler - import_path: sampler - args: - model_id: "ms://Qwen/Qwen3.5-4B" # ModelScope model identifier to load - sampler_type: vllm # Sampler backend (vllm or torch) - nproc_per_node: 2 # Number of GPU processes per node - engine_args: # vLLM engine configuration - gpu_memory_utilization: 0.4 - max_model_len: 1024 - adapter_config: # Adapter lifecycle management - adapter_timeout: 1800 # Seconds before idle adapter is unloaded - device_group: - name: sampler - ranks: 1 # Number of GPUs to use - device_type: cuda - device_mesh: - device_type: cuda - dp_size: 1 - deployments: - - name: SamplerManagement - autoscaling_config: - min_replicas: 1 - max_replicas: 1 - target_ongoing_requests: 16 - ray_actor_options: - num_cpus: 0.1 - runtime_env: - env_vars: - TWINKLE_TRUST_REMOTE_CODE: "0" diff --git a/src/twinkle/server/gateway/tinker_gateway_handlers.py b/src/twinkle/server/gateway/tinker_gateway_handlers.py index 08db7f91..27ce34e6 100644 --- a/src/twinkle/server/gateway/tinker_gateway_handlers.py +++ b/src/twinkle/server/gateway/tinker_gateway_handlers.py @@ -2,7 +2,7 @@ """ Tinker-compatible gateway handler mixin. -All endpoints are prefixed /tinker/* and registered via _register_tinker_routes(app). +All endpoints are prefixed /* and registered via _register_tinker_routes(app). Route closures use self.* directly (no request.state injection needed). 
""" from __future__ import annotations @@ -33,27 +33,27 @@ class TinkerGatewayHandlers: @staticmethod def _register_tinker_routes(app: FastAPI): - """Register all /tinker/* routes on the given FastAPI app.""" + """Register all /* routes on the given FastAPI app.""" - @app.get('/tinker/healthz') + @app.get('/healthz') async def healthz(self, request: Request) -> types.HealthResponse: return types.HealthResponse(status='ok') - @app.get('/tinker/get_server_capabilities') + @app.get('/get_server_capabilities') async def get_server_capabilities(self, request: Request) -> types.GetServerCapabilitiesResponse: return types.GetServerCapabilitiesResponse(supported_models=self.supported_models) - @app.post('/tinker/telemetry') + @app.post('/telemetry') async def telemetry(self, request: Request, body: types.TelemetrySendRequest) -> types.TelemetryResponse: return types.TelemetryResponse(status='accepted') - @app.post('/tinker/create_session') + @app.post('/create_session') async def create_session(self, request: Request, body: types.CreateSessionRequest) -> types.CreateSessionResponse: session_id = self.state.create_session(body.model_dump()) return types.CreateSessionResponse(session_id=session_id) - @app.post('/tinker/session_heartbeat') + @app.post('/session_heartbeat') async def session_heartbeat(self, request: Request, body: types.SessionHeartbeatRequest) -> types.SessionHeartbeatResponse: alive = self.state.touch_session(body.session_id) @@ -61,14 +61,14 @@ async def session_heartbeat(self, request: Request, raise HTTPException(status_code=404, detail='Unknown session') return types.SessionHeartbeatResponse() - @app.post('/tinker/create_sampling_session') + @app.post('/create_sampling_session') async def create_sampling_session( self, request: Request, body: types.CreateSamplingSessionRequest) -> types.CreateSamplingSessionResponse: sampling_session_id = self.state.create_sampling_session(body.model_dump()) return 
types.CreateSamplingSessionResponse(sampling_session_id=sampling_session_id) - @app.post('/tinker/retrieve_future') + @app.post('/retrieve_future') async def retrieve_future(self, request: Request, body: types.FutureRetrieveRequest) -> Any: """Retrieve the result of an async task with long polling.""" request_id = body.request_id @@ -123,7 +123,7 @@ async def retrieve_future(self, request: Request, body: types.FutureRetrieveRequ # --- Training Runs Endpoints --- - @app.get('/tinker/training_runs') + @app.get('/training_runs') async def get_training_runs(self, request: Request, limit: int = 20, @@ -132,7 +132,7 @@ async def get_training_runs(self, training_run_manager = create_training_run_manager(token, client_type='tinker') return training_run_manager.list_runs(limit=limit, offset=offset) - @app.get('/tinker/training_runs/{run_id}') + @app.get('/training_runs/{run_id}') async def get_training_run(self, request: Request, run_id: str) -> types.TrainingRun: token = get_token_from_request(request) training_run_manager = create_training_run_manager(token, client_type='tinker') @@ -141,7 +141,7 @@ async def get_training_run(self, request: Request, run_id: str) -> types.Trainin raise HTTPException(status_code=404, detail=f'Training run {run_id} not found') return run - @app.get('/tinker/training_runs/{run_id}/checkpoints') + @app.get('/training_runs/{run_id}/checkpoints') async def get_run_checkpoints(self, request: Request, run_id: str) -> types.CheckpointsListResponse: token = get_token_from_request(request) checkpoint_manager = create_checkpoint_manager(token, client_type='tinker') @@ -150,7 +150,7 @@ async def get_run_checkpoints(self, request: Request, run_id: str) -> types.Chec raise HTTPException(status_code=404, detail=f'Training run {run_id} not found') return response - @app.delete('/tinker/training_runs/{run_id}/checkpoints/{checkpoint_id:path}') + @app.delete('/training_runs/{run_id}/checkpoints/{checkpoint_id:path}') async def delete_run_checkpoint(self, 
request: Request, run_id: str, checkpoint_id: str) -> Any: token = get_token_from_request(request) checkpoint_manager = create_checkpoint_manager(token, client_type='tinker') @@ -159,7 +159,7 @@ async def delete_run_checkpoint(self, request: Request, run_id: str, checkpoint_ raise HTTPException(status_code=404, detail=f'Checkpoint {checkpoint_id} not found for run {run_id}') return None - @app.post('/tinker/weights_info') + @app.post('/weights_info') async def weights_info(self, request: Request, body: dict[str, Any]) -> types.WeightsInfoResponse: token = get_token_from_request(request) checkpoint_manager = create_checkpoint_manager(token, client_type='tinker') @@ -169,7 +169,7 @@ async def weights_info(self, request: Request, body: dict[str, Any]) -> types.We raise HTTPException(status_code=404, detail=f'Weights at {tinker_path} not found') return response - @app.post('/tinker/training_runs/{run_id}/checkpoints/{checkpoint_id:path}/publish') + @app.post('/training_runs/{run_id}/checkpoints/{checkpoint_id:path}/publish') async def publish_checkpoint(self, request: Request, run_id: str, checkpoint_id: str) -> Response: token = get_token_from_request(request) @@ -206,42 +206,42 @@ async def publish_checkpoint(self, request: Request, run_id: str, checkpoint_id: # --- Model Proxy Endpoints --- - @app.post('/tinker/create_model') + @app.post('/create_model') async def create_model(self, request: Request, body: types.CreateModelRequest) -> Any: self._validate_base_model(body.base_model) return await self.proxy.proxy_to_model(request, 'create_model', body.base_model) - @app.post('/tinker/get_info') + @app.post('/get_info') async def get_info(self, request: Request, body: types.GetInfoRequest) -> Any: return await self.proxy.proxy_to_model(request, 'get_info', self._get_base_model(body.model_id)) - @app.post('/tinker/unload_model') + @app.post('/unload_model') async def unload_model(self, request: Request, body: types.UnloadModelRequest) -> Any: return await 
self.proxy.proxy_to_model(request, 'unload_model', self._get_base_model(body.model_id)) - @app.post('/tinker/forward') + @app.post('/forward') async def forward(self, request: Request, body: types.ForwardRequest) -> Any: return await self.proxy.proxy_to_model(request, 'forward', self._get_base_model(body.model_id)) - @app.post('/tinker/forward_backward') + @app.post('/forward_backward') async def forward_backward(self, request: Request, body: types.ForwardBackwardRequest) -> Any: return await self.proxy.proxy_to_model(request, 'forward_backward', self._get_base_model(body.model_id)) - @app.post('/tinker/optim_step') + @app.post('/optim_step') async def optim_step(self, request: Request, body: types.OptimStepRequest) -> Any: return await self.proxy.proxy_to_model(request, 'optim_step', self._get_base_model(body.model_id)) - @app.post('/tinker/save_weights') + @app.post('/save_weights') async def save_weights(self, request: Request, body: types.SaveWeightsRequest) -> Any: return await self.proxy.proxy_to_model(request, 'save_weights', self._get_base_model(body.model_id)) - @app.post('/tinker/load_weights') + @app.post('/load_weights') async def load_weights(self, request: Request, body: types.LoadWeightsRequest) -> Any: return await self.proxy.proxy_to_model(request, 'load_weights', self._get_base_model(body.model_id)) # --- Sampler Proxy Endpoints --- - @app.post('/tinker/asample') + @app.post('/asample') async def asample(self, request: Request, body: types.SampleRequest) -> Any: base_model = body.base_model if not base_model and body.sampling_session_id: @@ -250,7 +250,7 @@ async def asample(self, request: Request, body: types.SampleRequest) -> Any: base_model = session.get('base_model') return await self.proxy.proxy_to_sampler(request, 'asample', base_model) - @app.post('/tinker/save_weights_for_sampler') + @app.post('/save_weights_for_sampler') async def save_weights_for_sampler(self, request: Request, body: types.SaveWeightsForSamplerRequest) -> Any: return 
await self.proxy.proxy_to_model(request, 'save_weights_for_sampler', self._get_base_model(body.model_id)) From 1b824cd0858717eeb77a46c8881b976a68d3531c Mon Sep 17 00:00:00 2001 From: Yunnglin Date: Wed, 11 Mar 2026 15:20:51 +0800 Subject: [PATCH 04/24] update refact server --- cookbook/client/server/megatron/server.py | 2 +- .../server/transformer/server_config.yaml | 66 +++++------ src/twinkle/server/gateway/server.py | 27 +++-- .../server/gateway/tinker_gateway_handlers.py | 112 ++++++++++-------- .../gateway/twinkle_gateway_handlers.py | 22 ++-- src/twinkle/server/model/app.py | 8 +- src/twinkle/server/model/tinker_handlers.py | 1 - src/twinkle/server/sampler/app.py | 8 +- 8 files changed, 133 insertions(+), 113 deletions(-) diff --git a/cookbook/client/server/megatron/server.py b/cookbook/client/server/megatron/server.py index e38f43a4..abce8cf6 100644 --- a/cookbook/client/server/megatron/server.py +++ b/cookbook/client/server/megatron/server.py @@ -15,7 +15,7 @@ # Resolve the path to server_config.yaml relative to this script's location file_dir = os.path.abspath(os.path.dirname(__file__)) -config_path = os.path.join(file_dir, 'server_config.yaml') +config_path = os.path.join(file_dir, 'server_config_4b.yaml') # Launch the Twinkle server — this call blocks until the server is shut down launch_server(config_path=config_path) diff --git a/cookbook/client/server/transformer/server_config.yaml b/cookbook/client/server/transformer/server_config.yaml index c8db0b3d..e3cbfac3 100644 --- a/cookbook/client/server/transformer/server_config.yaml +++ b/cookbook/client/server/transformer/server_config.yaml @@ -67,36 +67,36 @@ applications: # 3. Sampler Service - Runs inference / sampling using vLLM engine # Used for generating text from the model (e.g., evaluating LoRA results). 
- - name: sampler-Qwen3.5-4B - route_prefix: /api/v1/sampler/Qwen/Qwen3.5-4B - import_path: sampler - args: - model_id: "ms://Qwen/Qwen3.5-4B" # ModelScope model identifier - nproc_per_node: 2 # Number of GPU processes per node - sampler_type: vllm # Inference engine: 'vllm' (fast) or 'torch' (TorchSampler) - engine_args: # vLLM engine-specific settings - max_model_len: 4096 # Maximum sequence length the engine supports - gpu_memory_utilization: 0.5 # Fraction of GPU memory to use (0.0-1.0) - enable_lora: true # Allow loading LoRA adapters during inference - logprobs_mode: processed_logprobs # Logprobs mode for sampling results - device_group: # Logical device group for the sampler - name: sampler - ranks: 1 # Number of GPUs to use - device_type: cuda - device_mesh: - device_type: cuda - dp_size: 1 - queue_config: - rps_limit: 100 # Max requests per second - tps_limit: 100000 # Max tokens per second - deployments: - - name: SamplerManagement - autoscaling_config: - min_replicas: 1 - max_replicas: 1 - target_ongoing_requests: 16 - ray_actor_options: - num_cpus: 0.1 - runtime_env: - env_vars: - TWINKLE_TRUST_REMOTE_CODE: "0" + # - name: sampler-Qwen3.5-4B + # route_prefix: /api/v1/sampler/Qwen/Qwen3.5-4B + # import_path: sampler + # args: + # model_id: "ms://Qwen/Qwen3.5-4B" # ModelScope model identifier + # nproc_per_node: 2 # Number of GPU processes per node + # sampler_type: vllm # Inference engine: 'vllm' (fast) or 'torch' (TorchSampler) + # engine_args: # vLLM engine-specific settings + # max_model_len: 4096 # Maximum sequence length the engine supports + # gpu_memory_utilization: 0.5 # Fraction of GPU memory to use (0.0-1.0) + # enable_lora: true # Allow loading LoRA adapters during inference + # logprobs_mode: processed_logprobs # Logprobs mode for sampling results + # device_group: # Logical device group for the sampler + # name: sampler + # ranks: 1 # Number of GPUs to use + # device_type: cuda + # device_mesh: + # device_type: cuda + # dp_size: 1 + # 
queue_config: + # rps_limit: 100 # Max requests per second + # tps_limit: 100000 # Max tokens per second + # deployments: + # - name: SamplerManagement + # autoscaling_config: + # min_replicas: 1 + # max_replicas: 1 + # target_ongoing_requests: 16 + # ray_actor_options: + # num_cpus: 0.1 + # runtime_env: + # env_vars: + # TWINKLE_TRUST_REMOTE_CODE: "0" diff --git a/src/twinkle/server/gateway/server.py b/src/twinkle/server/gateway/server.py index 767a0e67..1ce1f1bd 100644 --- a/src/twinkle/server/gateway/server.py +++ b/src/twinkle/server/gateway/server.py @@ -42,14 +42,21 @@ def build_server_app(deploy_options: dict[str, Any], Returns: Configured Ray Serve deployment bound with options """ - app = FastAPI() - @app.middleware('http') - async def verify_token(request: Request, call_next): - return await verify_request_token(request=request, call_next=call_next) + def gateway_app(): + """Called once per replica at init time to build a fresh FastAPI app.""" + from ray import serve + app = FastAPI() + + @app.middleware('http') + async def verify_token(request: Request, call_next): + return await verify_request_token(request=request, call_next=call_next) + + self_fn = lambda: serve.get_replica_context().servable_object # noqa: E731 + TinkerGatewayHandlers._register_tinker_routes(app, self_fn) + TwinkleGatewayHandlers._register_twinkle_routes(app, self_fn) + return app - @serve.deployment(name='GatewayServer') - @serve.ingress(app) class GatewayServer(TinkerGatewayHandlers, TwinkleGatewayHandlers): """Unified gateway server handling both Tinker and Twinkle API clients.""" @@ -94,9 +101,7 @@ def _get_base_model(self, model_id: str) -> str: return metadata['base_model'] raise HTTPException(status_code=404, detail=f'Model {model_id} not found') - # Register routes from both handler mixins - TinkerGatewayHandlers._register_tinker_routes(app) - TwinkleGatewayHandlers._register_twinkle_routes(app) - - return GatewayServer.options(**deploy_options).bind( + 
GatewayServerWithIngress = serve.ingress(gateway_app)(GatewayServer) + DeploymentClass = serve.deployment(name='GatewayServer')(GatewayServerWithIngress) + return DeploymentClass.options(**deploy_options).bind( supported_models=supported_models, server_config=server_config, http_options=http_options, **kwargs) diff --git a/src/twinkle/server/gateway/tinker_gateway_handlers.py b/src/twinkle/server/gateway/tinker_gateway_handlers.py index 27ce34e6..dc893e09 100644 --- a/src/twinkle/server/gateway/tinker_gateway_handlers.py +++ b/src/twinkle/server/gateway/tinker_gateway_handlers.py @@ -11,7 +11,11 @@ import os from fastapi import FastAPI, HTTPException, Request, Response from tinker import types -from typing import Any +from typing import TYPE_CHECKING, Any + +if TYPE_CHECKING: + from twinkle.server.utils.state.server_state import ServerStateProxy + from .proxy import ServiceProxy from twinkle.hub import HubOperation from twinkle.server.common.io_utils import create_checkpoint_manager, create_training_run_manager @@ -30,54 +34,55 @@ class TinkerGatewayHandlers: self.state, self.proxy, self.supported_models, self._modelscope_config_lock, self._validate_base_model(), self._get_base_model() """ + state: ServerStateProxy + proxy: ServiceProxy @staticmethod - def _register_tinker_routes(app: FastAPI): + def _register_tinker_routes(app: FastAPI, self_fn): """Register all /* routes on the given FastAPI app.""" @app.get('/healthz') - async def healthz(self, request: Request) -> types.HealthResponse: + async def healthz(request: Request) -> types.HealthResponse: return types.HealthResponse(status='ok') @app.get('/get_server_capabilities') - async def get_server_capabilities(self, request: Request) -> types.GetServerCapabilitiesResponse: - return types.GetServerCapabilitiesResponse(supported_models=self.supported_models) + async def get_server_capabilities(request: Request) -> types.GetServerCapabilitiesResponse: + return 
types.GetServerCapabilitiesResponse(supported_models=self_fn().supported_models) @app.post('/telemetry') - async def telemetry(self, request: Request, body: types.TelemetrySendRequest) -> types.TelemetryResponse: + async def telemetry(request: Request, body: types.TelemetrySendRequest) -> types.TelemetryResponse: return types.TelemetryResponse(status='accepted') @app.post('/create_session') - async def create_session(self, request: Request, - body: types.CreateSessionRequest) -> types.CreateSessionResponse: - session_id = self.state.create_session(body.model_dump()) + async def create_session(request: Request, body: types.CreateSessionRequest) -> types.CreateSessionResponse: + session_id = self_fn().state.create_session(body.model_dump()) return types.CreateSessionResponse(session_id=session_id) @app.post('/session_heartbeat') - async def session_heartbeat(self, request: Request, + async def session_heartbeat(request: Request, body: types.SessionHeartbeatRequest) -> types.SessionHeartbeatResponse: - alive = self.state.touch_session(body.session_id) + alive = self_fn().state.touch_session(body.session_id) if not alive: raise HTTPException(status_code=404, detail='Unknown session') return types.SessionHeartbeatResponse() @app.post('/create_sampling_session') async def create_sampling_session( - self, request: Request, - body: types.CreateSamplingSessionRequest) -> types.CreateSamplingSessionResponse: - sampling_session_id = self.state.create_sampling_session(body.model_dump()) + request: Request, body: types.CreateSamplingSessionRequest) -> types.CreateSamplingSessionResponse: + sampling_session_id = self_fn().state.create_sampling_session(body.model_dump()) return types.CreateSamplingSessionResponse(sampling_session_id=sampling_session_id) @app.post('/retrieve_future') - async def retrieve_future(self, request: Request, body: types.FutureRetrieveRequest) -> Any: + async def retrieve_future(request: Request, body: types.FutureRetrieveRequest) -> Any: """Retrieve the 
result of an async task with long polling.""" request_id = body.request_id max_wait = float(os.environ.get('TWINKLE_LONG_POLL_TIMEOUT', '30')) poll_interval = float(os.environ.get('TWINKLE_POLL_INTERVAL', '0.5')) start = asyncio.get_event_loop().time() + gw = self_fn() while True: - record = self.state.get_future(request_id) + record = gw.state.get_future(request_id) if record is None: return {'type': 'try_again'} @@ -96,7 +101,7 @@ async def retrieve_future(self, request: Request, body: types.FutureRetrieveRequ await asyncio.sleep(poll_interval) - record = self.state.get_future(request_id) + record = gw.state.get_future(request_id) if not record: return {'type': 'try_again'} @@ -124,16 +129,13 @@ async def retrieve_future(self, request: Request, body: types.FutureRetrieveRequ # --- Training Runs Endpoints --- @app.get('/training_runs') - async def get_training_runs(self, - request: Request, - limit: int = 20, - offset: int = 0) -> types.TrainingRunsResponse: + async def get_training_runs(request: Request, limit: int = 20, offset: int = 0) -> types.TrainingRunsResponse: token = get_token_from_request(request) training_run_manager = create_training_run_manager(token, client_type='tinker') return training_run_manager.list_runs(limit=limit, offset=offset) @app.get('/training_runs/{run_id}') - async def get_training_run(self, request: Request, run_id: str) -> types.TrainingRun: + async def get_training_run(request: Request, run_id: str) -> types.TrainingRun: token = get_token_from_request(request) training_run_manager = create_training_run_manager(token, client_type='tinker') run = training_run_manager.get(run_id) @@ -142,7 +144,7 @@ async def get_training_run(self, request: Request, run_id: str) -> types.Trainin return run @app.get('/training_runs/{run_id}/checkpoints') - async def get_run_checkpoints(self, request: Request, run_id: str) -> types.CheckpointsListResponse: + async def get_run_checkpoints(request: Request, run_id: str) -> types.CheckpointsListResponse: 
token = get_token_from_request(request) checkpoint_manager = create_checkpoint_manager(token, client_type='tinker') response = checkpoint_manager.list_checkpoints(run_id) @@ -151,7 +153,7 @@ async def get_run_checkpoints(self, request: Request, run_id: str) -> types.Chec return response @app.delete('/training_runs/{run_id}/checkpoints/{checkpoint_id:path}') - async def delete_run_checkpoint(self, request: Request, run_id: str, checkpoint_id: str) -> Any: + async def delete_run_checkpoint(request: Request, run_id: str, checkpoint_id: str) -> Any: token = get_token_from_request(request) checkpoint_manager = create_checkpoint_manager(token, client_type='tinker') success = checkpoint_manager.delete(run_id, checkpoint_id) @@ -160,7 +162,7 @@ async def delete_run_checkpoint(self, request: Request, run_id: str, checkpoint_ return None @app.post('/weights_info') - async def weights_info(self, request: Request, body: dict[str, Any]) -> types.WeightsInfoResponse: + async def weights_info(request: Request, body: dict[str, Any]) -> types.WeightsInfoResponse: token = get_token_from_request(request) checkpoint_manager = create_checkpoint_manager(token, client_type='tinker') tinker_path = body.get('tinker_path') @@ -170,8 +172,9 @@ async def weights_info(self, request: Request, body: dict[str, Any]) -> types.We return response @app.post('/training_runs/{run_id}/checkpoints/{checkpoint_id:path}/publish') - async def publish_checkpoint(self, request: Request, run_id: str, checkpoint_id: str) -> Response: + async def publish_checkpoint(request: Request, run_id: str, checkpoint_id: str) -> Response: token = get_token_from_request(request) + gw = self_fn() training_run_manager = create_training_run_manager(token, client_type='tinker') checkpoint_manager = create_checkpoint_manager(token, client_type='tinker') @@ -186,7 +189,7 @@ async def publish_checkpoint(self, request: Request, run_id: str, checkpoint_id: checkpoint_dir = str(checkpoint_manager.get_ckpt_dir(run_id, checkpoint_id)) 
- async with self._modelscope_config_lock: + async with gw._modelscope_config_lock: try: from modelscope.hub.api import HubApi, ModelScopeConfig hub_api = HubApi(token=token) @@ -207,50 +210,59 @@ async def publish_checkpoint(self, request: Request, run_id: str, checkpoint_id: # --- Model Proxy Endpoints --- @app.post('/create_model') - async def create_model(self, request: Request, body: types.CreateModelRequest) -> Any: - self._validate_base_model(body.base_model) - return await self.proxy.proxy_to_model(request, 'create_model', body.base_model) + async def create_model(request: Request, body: types.CreateModelRequest) -> Any: + gw = self_fn() + gw._validate_base_model(body.base_model) + return await gw.proxy.proxy_to_model(request, 'create_model', body.base_model) @app.post('/get_info') - async def get_info(self, request: Request, body: types.GetInfoRequest) -> Any: - return await self.proxy.proxy_to_model(request, 'get_info', self._get_base_model(body.model_id)) + async def get_info(request: Request, body: types.GetInfoRequest) -> Any: + gw = self_fn() + return await gw.proxy.proxy_to_model(request, 'get_info', gw._get_base_model(body.model_id)) @app.post('/unload_model') - async def unload_model(self, request: Request, body: types.UnloadModelRequest) -> Any: - return await self.proxy.proxy_to_model(request, 'unload_model', self._get_base_model(body.model_id)) + async def unload_model(request: Request, body: types.UnloadModelRequest) -> Any: + gw = self_fn() + return await gw.proxy.proxy_to_model(request, 'unload_model', gw._get_base_model(body.model_id)) @app.post('/forward') - async def forward(self, request: Request, body: types.ForwardRequest) -> Any: - return await self.proxy.proxy_to_model(request, 'forward', self._get_base_model(body.model_id)) + async def forward(request: Request, body: types.ForwardRequest) -> Any: + gw = self_fn() + return await gw.proxy.proxy_to_model(request, 'forward', gw._get_base_model(body.model_id)) 
@app.post('/forward_backward') - async def forward_backward(self, request: Request, body: types.ForwardBackwardRequest) -> Any: - return await self.proxy.proxy_to_model(request, 'forward_backward', self._get_base_model(body.model_id)) + async def forward_backward(request: Request, body: types.ForwardBackwardRequest) -> Any: + gw = self_fn() + return await gw.proxy.proxy_to_model(request, 'forward_backward', gw._get_base_model(body.model_id)) @app.post('/optim_step') - async def optim_step(self, request: Request, body: types.OptimStepRequest) -> Any: - return await self.proxy.proxy_to_model(request, 'optim_step', self._get_base_model(body.model_id)) + async def optim_step(request: Request, body: types.OptimStepRequest) -> Any: + gw = self_fn() + return await gw.proxy.proxy_to_model(request, 'optim_step', gw._get_base_model(body.model_id)) @app.post('/save_weights') - async def save_weights(self, request: Request, body: types.SaveWeightsRequest) -> Any: - return await self.proxy.proxy_to_model(request, 'save_weights', self._get_base_model(body.model_id)) + async def save_weights(request: Request, body: types.SaveWeightsRequest) -> Any: + gw = self_fn() + return await gw.proxy.proxy_to_model(request, 'save_weights', gw._get_base_model(body.model_id)) @app.post('/load_weights') - async def load_weights(self, request: Request, body: types.LoadWeightsRequest) -> Any: - return await self.proxy.proxy_to_model(request, 'load_weights', self._get_base_model(body.model_id)) + async def load_weights(request: Request, body: types.LoadWeightsRequest) -> Any: + gw = self_fn() + return await gw.proxy.proxy_to_model(request, 'load_weights', gw._get_base_model(body.model_id)) # --- Sampler Proxy Endpoints --- @app.post('/asample') - async def asample(self, request: Request, body: types.SampleRequest) -> Any: + async def asample(request: Request, body: types.SampleRequest) -> Any: + gw = self_fn() base_model = body.base_model if not base_model and body.sampling_session_id: - session = 
self.state.get_sampling_session(body.sampling_session_id) + session = gw.state.get_sampling_session(body.sampling_session_id) if session: base_model = session.get('base_model') - return await self.proxy.proxy_to_sampler(request, 'asample', base_model) + return await gw.proxy.proxy_to_sampler(request, 'asample', base_model) @app.post('/save_weights_for_sampler') - async def save_weights_for_sampler(self, request: Request, body: types.SaveWeightsForSamplerRequest) -> Any: - return await self.proxy.proxy_to_model(request, 'save_weights_for_sampler', - self._get_base_model(body.model_id)) + async def save_weights_for_sampler(request: Request, body: types.SaveWeightsForSamplerRequest) -> Any: + gw = self_fn() + return await gw.proxy.proxy_to_model(request, 'save_weights_for_sampler', gw._get_base_model(body.model_id)) diff --git a/src/twinkle/server/gateway/twinkle_gateway_handlers.py b/src/twinkle/server/gateway/twinkle_gateway_handlers.py index 4d19cbf1..5939f1b3 100644 --- a/src/twinkle/server/gateway/twinkle_gateway_handlers.py +++ b/src/twinkle/server/gateway/twinkle_gateway_handlers.py @@ -8,6 +8,10 @@ from __future__ import annotations from fastapi import FastAPI, HTTPException, Request +from typing import TYPE_CHECKING + +if TYPE_CHECKING: + from twinkle.server.utils.state.server_state import ServerStateProxy from twinkle.server.common.io_utils import create_checkpoint_manager, create_training_run_manager, validate_user_path from twinkle.server.utils.validation import get_token_from_request @@ -25,23 +29,24 @@ class TwinkleGatewayHandlers: Expects the combined class to have: self.state """ + state: ServerStateProxy @staticmethod - def _register_twinkle_routes(app: FastAPI): + def _register_twinkle_routes(app: FastAPI, self_fn): """Register all /twinkle/* routes on the given FastAPI app.""" @app.get('/twinkle/healthz', response_model=HealthResponse) - async def healthz(self, request: Request) -> HealthResponse: + async def healthz(request: Request) -> 
HealthResponse: return HealthResponse(status='ok') @app.get('/twinkle/training_runs', response_model=TrainingRunsResponse) - async def get_training_runs(self, request: Request, limit: int = 20, offset: int = 0) -> TrainingRunsResponse: + async def get_training_runs(request: Request, limit: int = 20, offset: int = 0) -> TrainingRunsResponse: token = get_token_from_request(request) training_run_manager = create_training_run_manager(token, client_type='twinkle') return training_run_manager.list_runs(limit=limit, offset=offset) @app.get('/twinkle/training_runs/{run_id}', response_model=TrainingRun) - async def get_training_run(self, request: Request, run_id: str) -> TrainingRun: + async def get_training_run(request: Request, run_id: str) -> TrainingRun: token = get_token_from_request(request) training_run_manager = create_training_run_manager(token, client_type='twinkle') run = training_run_manager.get_with_permission(run_id) @@ -50,7 +55,7 @@ async def get_training_run(self, request: Request, run_id: str) -> TrainingRun: return run @app.get('/twinkle/training_runs/{run_id}/checkpoints', response_model=CheckpointsListResponse) - async def get_run_checkpoints(self, request: Request, run_id: str) -> CheckpointsListResponse: + async def get_run_checkpoints(request: Request, run_id: str) -> CheckpointsListResponse: token = get_token_from_request(request) checkpoint_manager = create_checkpoint_manager(token, client_type='twinkle') response = checkpoint_manager.list_checkpoints(run_id) @@ -59,8 +64,7 @@ async def get_run_checkpoints(self, request: Request, run_id: str) -> Checkpoint return response @app.delete('/twinkle/training_runs/{run_id}/checkpoints/{checkpoint_id:path}') - async def delete_run_checkpoint(self, request: Request, run_id: str, - checkpoint_id: str) -> DeleteCheckpointResponse: + async def delete_run_checkpoint(request: Request, run_id: str, checkpoint_id: str) -> DeleteCheckpointResponse: token = get_token_from_request(request) if not 
validate_user_path(token, checkpoint_id): @@ -74,7 +78,7 @@ async def delete_run_checkpoint(self, request: Request, run_id: str, return DeleteCheckpointResponse(success=True, message=f'Checkpoint {checkpoint_id} deleted successfully') @app.post('/twinkle/weights_info', response_model=WeightsInfoResponse) - async def weights_info(self, request: Request, body: WeightsInfoRequest) -> WeightsInfoResponse: + async def weights_info(request: Request, body: WeightsInfoRequest) -> WeightsInfoResponse: token = get_token_from_request(request) checkpoint_manager = create_checkpoint_manager(token, client_type='twinkle') response = checkpoint_manager.get_weights_info(body.twinkle_path) @@ -84,7 +88,7 @@ async def weights_info(self, request: Request, body: WeightsInfoRequest) -> Weig return response @app.get('/twinkle/checkpoint_path/{run_id}/{checkpoint_id:path}') - async def get_checkpoint_path(self, request: Request, run_id: str, checkpoint_id: str) -> dict[str, str]: + async def get_checkpoint_path(request: Request, run_id: str, checkpoint_id: str) -> dict[str, str]: token = get_token_from_request(request) if not validate_user_path(token, checkpoint_id): diff --git a/src/twinkle/server/model/app.py b/src/twinkle/server/model/app.py index 9887ace2..72618616 100644 --- a/src/twinkle/server/model/app.py +++ b/src/twinkle/server/model/app.py @@ -58,6 +58,10 @@ def build_model_app(model_id: str, async def verify_token(request: Request, call_next): return await verify_request_token(request=request, call_next=call_next) + # Register routes BEFORE @serve.ingress so Ray Serve captures them at decoration time + TinkerModelHandlers._register_tinker_routes(app, model_id) + TwinkleModelHandlers._register_twinkle_routes(app, model_id) + @serve.deployment( name='ModelManagement', request_router_config=RequestRouterConfig(request_router_class=StickyLoraRequestRouter), @@ -146,10 +150,6 @@ def _on_adapter_expired(self, adapter_name: str) -> None: 
self.fail_pending_tasks_for_model(adapter_name, reason='Adapter expired') self._cleanup_adapter(adapter_name) - # Register routes from both handler mixins - TinkerModelHandlers._register_tinker_routes(app, model_id) - TwinkleModelHandlers._register_twinkle_routes(app, model_id) - return ModelManagement.options(**deploy_options).bind(nproc_per_node, device_group, device_mesh, use_megatron, queue_config, **kwargs) diff --git a/src/twinkle/server/model/tinker_handlers.py b/src/twinkle/server/model/tinker_handlers.py index 3577cd93..3e65ce06 100644 --- a/src/twinkle/server/model/tinker_handlers.py +++ b/src/twinkle/server/model/tinker_handlers.py @@ -11,7 +11,6 @@ from typing import Any from twinkle.server.common.io_utils import create_checkpoint_manager, create_training_run_manager -from twinkle.server.utils.validation import get_token_from_request from twinkle.utils.logger import get_logger logger = get_logger() diff --git a/src/twinkle/server/sampler/app.py b/src/twinkle/server/sampler/app.py index 265ccd45..c3bff40b 100644 --- a/src/twinkle/server/sampler/app.py +++ b/src/twinkle/server/sampler/app.py @@ -62,6 +62,10 @@ def build_sampler_app(model_id: str, async def verify_token(request: Request, call_next): return await verify_request_token(request=request, call_next=call_next) + # Register routes BEFORE @serve.ingress so Ray Serve captures them at decoration time + TinkerSamplerHandlers._register_tinker_sampler_routes(app) + TwinkleSamplerHandlers._register_twinkle_sampler_routes(app) + @serve.deployment(name='SamplerManagement') @serve.ingress(app) class SamplerManagement(TaskQueueMixin, AdapterManagerMixin): @@ -148,10 +152,6 @@ def _on_adapter_expired(self, adapter_name: str, token: str = None) -> None: except Exception as e: logger.warning(f'Failed to remove expired adapter {adapter_name}: {e}') - # Register routes from both handler mixins - TinkerSamplerHandlers._register_tinker_sampler_routes(app) - 
TwinkleSamplerHandlers._register_twinkle_sampler_routes(app) - return SamplerManagement.options(**deploy_options).bind(nproc_per_node, device_group, device_mesh, sampler_type, engine_args, adapter_config, queue_config, **kwargs) From 699c9716d160a0dec8d6e12b8939cb07b9cf6afb Mon Sep 17 00:00:00 2001 From: Yunnglin Date: Wed, 11 Mar 2026 15:47:53 +0800 Subject: [PATCH 05/24] update refact server --- src/twinkle/server/gateway/server.py | 30 +- .../server/gateway/tinker_gateway_handlers.py | 459 +++++++++--------- .../gateway/twinkle_gateway_handlers.py | 163 +++---- 3 files changed, 313 insertions(+), 339 deletions(-) diff --git a/src/twinkle/server/gateway/server.py b/src/twinkle/server/gateway/server.py index 1ce1f1bd..06f9f37a 100644 --- a/src/twinkle/server/gateway/server.py +++ b/src/twinkle/server/gateway/server.py @@ -17,8 +17,8 @@ from twinkle.server.utils.validation import verify_request_token from twinkle.utils.logger import get_logger from .proxy import ServiceProxy -from .tinker_gateway_handlers import TinkerGatewayHandlers -from .twinkle_gateway_handlers import TwinkleGatewayHandlers +from .tinker_gateway_handlers import _register_tinker_routes +from .twinkle_gateway_handlers import _register_twinkle_routes logger = get_logger() @@ -43,21 +43,21 @@ def build_server_app(deploy_options: dict[str, Any], Configured Ray Serve deployment bound with options """ - def gateway_app(): - """Called once per replica at init time to build a fresh FastAPI app.""" - from ray import serve - app = FastAPI() + # Build the FastAPI app and register all routes BEFORE serve.ingress so that + # the frozen app contains the complete route table (visible to ProxyActor). 
+ app = FastAPI() - @app.middleware('http') - async def verify_token(request: Request, call_next): - return await verify_request_token(request=request, call_next=call_next) + @app.middleware('http') + async def verify_token(request: Request, call_next): + return await verify_request_token(request=request, call_next=call_next) - self_fn = lambda: serve.get_replica_context().servable_object # noqa: E731 - TinkerGatewayHandlers._register_tinker_routes(app, self_fn) - TwinkleGatewayHandlers._register_twinkle_routes(app, self_fn) - return app + def get_gw(): + return serve.get_replica_context().servable_object - class GatewayServer(TinkerGatewayHandlers, TwinkleGatewayHandlers): + _register_tinker_routes(app, get_gw) + _register_twinkle_routes(app, get_gw) + + class GatewayServer: """Unified gateway server handling both Tinker and Twinkle API clients.""" def __init__(self, @@ -101,7 +101,7 @@ def _get_base_model(self, model_id: str) -> str: return metadata['base_model'] raise HTTPException(status_code=404, detail=f'Model {model_id} not found') - GatewayServerWithIngress = serve.ingress(gateway_app)(GatewayServer) + GatewayServerWithIngress = serve.ingress(app)(GatewayServer) DeploymentClass = serve.deployment(name='GatewayServer')(GatewayServerWithIngress) return DeploymentClass.options(**deploy_options).bind( supported_models=supported_models, server_config=server_config, http_options=http_options, **kwargs) diff --git a/src/twinkle/server/gateway/tinker_gateway_handlers.py b/src/twinkle/server/gateway/tinker_gateway_handlers.py index dc893e09..7949e8ff 100644 --- a/src/twinkle/server/gateway/tinker_gateway_handlers.py +++ b/src/twinkle/server/gateway/tinker_gateway_handlers.py @@ -1,15 +1,15 @@ # Copyright (c) ModelScope Contributors. All rights reserved. """ -Tinker-compatible gateway handler mixin. +Tinker-compatible gateway handlers. -All endpoints are prefixed /* and registered via _register_tinker_routes(app). 
-Route closures use self.* directly (no request.state injection needed). +All endpoints are prefixed /* and registered via _register_tinker_routes(app, self_fn). +self_fn is injected via FastAPI Depends to obtain the GatewayServer instance at request time. """ from __future__ import annotations import asyncio import os -from fastapi import FastAPI, HTTPException, Request, Response +from fastapi import Depends, FastAPI, HTTPException, Request, Response from tinker import types from typing import TYPE_CHECKING, Any @@ -26,243 +26,228 @@ logger = get_logger() -class TinkerGatewayHandlers: - """ - Mixin providing Tinker-compatible gateway endpoints. +def _register_tinker_routes(app: FastAPI, self_fn) -> None: + """Register all /* Tinker routes on the given FastAPI app. - Expects the combined class to have: - self.state, self.proxy, self.supported_models, - self._modelscope_config_lock, self._validate_base_model(), self._get_base_model() + self_fn is a zero-argument callable that returns the current GatewayServer + replica instance (e.g. ``lambda: serve.get_replica_context().servable_object``). + It is wired in via ``Depends`` so it is resolved lazily at request time. 
""" - state: ServerStateProxy - proxy: ServiceProxy - - @staticmethod - def _register_tinker_routes(app: FastAPI, self_fn): - """Register all /* routes on the given FastAPI app.""" - - @app.get('/healthz') - async def healthz(request: Request) -> types.HealthResponse: - return types.HealthResponse(status='ok') - - @app.get('/get_server_capabilities') - async def get_server_capabilities(request: Request) -> types.GetServerCapabilitiesResponse: - return types.GetServerCapabilitiesResponse(supported_models=self_fn().supported_models) - - @app.post('/telemetry') - async def telemetry(request: Request, body: types.TelemetrySendRequest) -> types.TelemetryResponse: - return types.TelemetryResponse(status='accepted') - - @app.post('/create_session') - async def create_session(request: Request, body: types.CreateSessionRequest) -> types.CreateSessionResponse: - session_id = self_fn().state.create_session(body.model_dump()) - return types.CreateSessionResponse(session_id=session_id) - - @app.post('/session_heartbeat') - async def session_heartbeat(request: Request, - body: types.SessionHeartbeatRequest) -> types.SessionHeartbeatResponse: - alive = self_fn().state.touch_session(body.session_id) - if not alive: - raise HTTPException(status_code=404, detail='Unknown session') - return types.SessionHeartbeatResponse() - - @app.post('/create_sampling_session') - async def create_sampling_session( - request: Request, body: types.CreateSamplingSessionRequest) -> types.CreateSamplingSessionResponse: - sampling_session_id = self_fn().state.create_sampling_session(body.model_dump()) - return types.CreateSamplingSessionResponse(sampling_session_id=sampling_session_id) - - @app.post('/retrieve_future') - async def retrieve_future(request: Request, body: types.FutureRetrieveRequest) -> Any: - """Retrieve the result of an async task with long polling.""" - request_id = body.request_id - max_wait = float(os.environ.get('TWINKLE_LONG_POLL_TIMEOUT', '30')) - poll_interval = 
float(os.environ.get('TWINKLE_POLL_INTERVAL', '0.5')) - start = asyncio.get_event_loop().time() - gw = self_fn() - - while True: - record = gw.state.get_future(request_id) - - if record is None: - return {'type': 'try_again'} - - status = record.get('status') - if status not in ('pending', 'queued', 'running', 'rate_limited'): - break - - if asyncio.get_event_loop().time() - start >= max_wait: - response_data = {'type': 'try_again'} - if queue_state := record.get('queue_state'): - response_data['queue_state'] = queue_state - if queue_state_reason := record.get('queue_state_reason'): - response_data['queue_state_reason'] = queue_state_reason - return response_data - - await asyncio.sleep(poll_interval) + @app.get('/healthz') + async def healthz(request: Request) -> types.HealthResponse: + return types.HealthResponse(status='ok') + + @app.get('/get_server_capabilities') + async def get_server_capabilities(request: Request, gw=Depends(self_fn)) -> types.GetServerCapabilitiesResponse: + return types.GetServerCapabilitiesResponse(supported_models=gw.supported_models) + + @app.post('/telemetry') + async def telemetry(request: Request, body: types.TelemetrySendRequest) -> types.TelemetryResponse: + return types.TelemetryResponse(status='accepted') + + @app.post('/create_session') + async def create_session(request: Request, body: types.CreateSessionRequest, + gw=Depends(self_fn)) -> types.CreateSessionResponse: + session_id = gw.state.create_session(body.model_dump()) + return types.CreateSessionResponse(session_id=session_id) + + @app.post('/session_heartbeat') + async def session_heartbeat( + request: Request, body: types.SessionHeartbeatRequest, + gw=Depends(self_fn)) -> types.SessionHeartbeatResponse: # noqa: E125 + alive = gw.state.touch_session(body.session_id) + if not alive: + raise HTTPException(status_code=404, detail='Unknown session') + return types.SessionHeartbeatResponse() + + @app.post('/create_sampling_session') + async def create_sampling_session( + 
request: Request, body: types.CreateSamplingSessionRequest, + gw=Depends(self_fn)) -> types.CreateSamplingSessionResponse: # noqa: E125 + sampling_session_id = gw.state.create_sampling_session(body.model_dump()) + return types.CreateSamplingSessionResponse(sampling_session_id=sampling_session_id) + + @app.post('/retrieve_future') + async def retrieve_future(request: Request, body: types.FutureRetrieveRequest, gw=Depends(self_fn)) -> Any: + """Retrieve the result of an async task with long polling.""" + request_id = body.request_id + max_wait = float(os.environ.get('TWINKLE_LONG_POLL_TIMEOUT', '30')) + poll_interval = float(os.environ.get('TWINKLE_POLL_INTERVAL', '0.5')) + start = asyncio.get_event_loop().time() + + while True: record = gw.state.get_future(request_id) - if not record: + + if record is None: return {'type': 'try_again'} status = record.get('status') - - if status == 'rate_limited': - return { - 'type': 'try_again', - 'queue_state': QueueState.PAUSED_RATE_LIMIT.value, - 'queue_state_reason': record.get('reason', 'Rate limit exceeded') - } - - if status == 'failed': - result = record.get('result', {}) - return {'error': result.get('error', 'Unknown error'), 'category': result.get('category', 'Server')} - - result = record.get('result') - if result is None: - raise HTTPException(status_code=500, detail='Task completed but no result found') - - if hasattr(result, 'model_dump'): - return result.model_dump() - return result - - # --- Training Runs Endpoints --- - - @app.get('/training_runs') - async def get_training_runs(request: Request, limit: int = 20, offset: int = 0) -> types.TrainingRunsResponse: - token = get_token_from_request(request) - training_run_manager = create_training_run_manager(token, client_type='tinker') - return training_run_manager.list_runs(limit=limit, offset=offset) - - @app.get('/training_runs/{run_id}') - async def get_training_run(request: Request, run_id: str) -> types.TrainingRun: - token = get_token_from_request(request) - 
training_run_manager = create_training_run_manager(token, client_type='tinker') - run = training_run_manager.get(run_id) - if not run: - raise HTTPException(status_code=404, detail=f'Training run {run_id} not found') - return run - - @app.get('/training_runs/{run_id}/checkpoints') - async def get_run_checkpoints(request: Request, run_id: str) -> types.CheckpointsListResponse: - token = get_token_from_request(request) - checkpoint_manager = create_checkpoint_manager(token, client_type='tinker') - response = checkpoint_manager.list_checkpoints(run_id) - if not response: - raise HTTPException(status_code=404, detail=f'Training run {run_id} not found') - return response - - @app.delete('/training_runs/{run_id}/checkpoints/{checkpoint_id:path}') - async def delete_run_checkpoint(request: Request, run_id: str, checkpoint_id: str) -> Any: - token = get_token_from_request(request) - checkpoint_manager = create_checkpoint_manager(token, client_type='tinker') - success = checkpoint_manager.delete(run_id, checkpoint_id) - if not success: - raise HTTPException(status_code=404, detail=f'Checkpoint {checkpoint_id} not found for run {run_id}') - return None - - @app.post('/weights_info') - async def weights_info(request: Request, body: dict[str, Any]) -> types.WeightsInfoResponse: - token = get_token_from_request(request) - checkpoint_manager = create_checkpoint_manager(token, client_type='tinker') - tinker_path = body.get('tinker_path') - response = checkpoint_manager.get_weights_info(tinker_path) - if not response: - raise HTTPException(status_code=404, detail=f'Weights at {tinker_path} not found') - return response - - @app.post('/training_runs/{run_id}/checkpoints/{checkpoint_id:path}/publish') - async def publish_checkpoint(request: Request, run_id: str, checkpoint_id: str) -> Response: - token = get_token_from_request(request) - gw = self_fn() - - training_run_manager = create_training_run_manager(token, client_type='tinker') - checkpoint_manager = 
create_checkpoint_manager(token, client_type='tinker') - - run = training_run_manager.get(run_id) - if not run: - raise HTTPException(status_code=404, detail=f'Training run {run_id} not found or access denied') - - checkpoint = checkpoint_manager.get(run_id, checkpoint_id) - if not checkpoint: - raise HTTPException(status_code=404, detail=f'Checkpoint {checkpoint_id} not found') - - checkpoint_dir = str(checkpoint_manager.get_ckpt_dir(run_id, checkpoint_id)) - - async with gw._modelscope_config_lock: - try: - from modelscope.hub.api import HubApi, ModelScopeConfig - hub_api = HubApi(token=token) - hub_api.login() - username = ModelScopeConfig.get_user_info()[0] - except Exception as e: - logger.error(f'Failed to get username from ModelScope: {e}') - raise HTTPException( - status_code=401, - detail='Failed to get username from ModelScope. Please ensure your token is valid.') - - checkpoint_name = checkpoint_id.split('/')[-1] - hub_model_id = f'{username}/{run_id}_{checkpoint_name}' - HubOperation.async_push_to_hub(repo_id=hub_model_id, folder_path=checkpoint_dir, token=token, private=True) - - return Response(status_code=204) - - # --- Model Proxy Endpoints --- - - @app.post('/create_model') - async def create_model(request: Request, body: types.CreateModelRequest) -> Any: - gw = self_fn() - gw._validate_base_model(body.base_model) - return await gw.proxy.proxy_to_model(request, 'create_model', body.base_model) - - @app.post('/get_info') - async def get_info(request: Request, body: types.GetInfoRequest) -> Any: - gw = self_fn() - return await gw.proxy.proxy_to_model(request, 'get_info', gw._get_base_model(body.model_id)) - - @app.post('/unload_model') - async def unload_model(request: Request, body: types.UnloadModelRequest) -> Any: - gw = self_fn() - return await gw.proxy.proxy_to_model(request, 'unload_model', gw._get_base_model(body.model_id)) - - @app.post('/forward') - async def forward(request: Request, body: types.ForwardRequest) -> Any: - gw = self_fn() - 
return await gw.proxy.proxy_to_model(request, 'forward', gw._get_base_model(body.model_id)) - - @app.post('/forward_backward') - async def forward_backward(request: Request, body: types.ForwardBackwardRequest) -> Any: - gw = self_fn() - return await gw.proxy.proxy_to_model(request, 'forward_backward', gw._get_base_model(body.model_id)) - - @app.post('/optim_step') - async def optim_step(request: Request, body: types.OptimStepRequest) -> Any: - gw = self_fn() - return await gw.proxy.proxy_to_model(request, 'optim_step', gw._get_base_model(body.model_id)) - - @app.post('/save_weights') - async def save_weights(request: Request, body: types.SaveWeightsRequest) -> Any: - gw = self_fn() - return await gw.proxy.proxy_to_model(request, 'save_weights', gw._get_base_model(body.model_id)) - - @app.post('/load_weights') - async def load_weights(request: Request, body: types.LoadWeightsRequest) -> Any: - gw = self_fn() - return await gw.proxy.proxy_to_model(request, 'load_weights', gw._get_base_model(body.model_id)) - - # --- Sampler Proxy Endpoints --- - - @app.post('/asample') - async def asample(request: Request, body: types.SampleRequest) -> Any: - gw = self_fn() - base_model = body.base_model - if not base_model and body.sampling_session_id: - session = gw.state.get_sampling_session(body.sampling_session_id) - if session: - base_model = session.get('base_model') - return await gw.proxy.proxy_to_sampler(request, 'asample', base_model) - - @app.post('/save_weights_for_sampler') - async def save_weights_for_sampler(request: Request, body: types.SaveWeightsForSamplerRequest) -> Any: - gw = self_fn() - return await gw.proxy.proxy_to_model(request, 'save_weights_for_sampler', gw._get_base_model(body.model_id)) + if status not in ('pending', 'queued', 'running', 'rate_limited'): + break + + if asyncio.get_event_loop().time() - start >= max_wait: + response_data = {'type': 'try_again'} + if queue_state := record.get('queue_state'): + response_data['queue_state'] = queue_state + 
if queue_state_reason := record.get('queue_state_reason'): + response_data['queue_state_reason'] = queue_state_reason + return response_data + + await asyncio.sleep(poll_interval) + + record = gw.state.get_future(request_id) + if not record: + return {'type': 'try_again'} + + status = record.get('status') + + if status == 'rate_limited': + return { + 'type': 'try_again', + 'queue_state': QueueState.PAUSED_RATE_LIMIT.value, + 'queue_state_reason': record.get('reason', 'Rate limit exceeded') + } + + if status == 'failed': + result = record.get('result', {}) + return {'error': result.get('error', 'Unknown error'), 'category': result.get('category', 'Server')} + + result = record.get('result') + if result is None: + raise HTTPException(status_code=500, detail='Task completed but no result found') + + if hasattr(result, 'model_dump'): + return result.model_dump() + return result + + # --- Training Runs Endpoints --- + + @app.get('/training_runs') + async def get_training_runs(request: Request, limit: int = 20, offset: int = 0) -> types.TrainingRunsResponse: + token = get_token_from_request(request) + training_run_manager = create_training_run_manager(token, client_type='tinker') + return training_run_manager.list_runs(limit=limit, offset=offset) + + @app.get('/training_runs/{run_id}') + async def get_training_run(request: Request, run_id: str) -> types.TrainingRun: + token = get_token_from_request(request) + training_run_manager = create_training_run_manager(token, client_type='tinker') + run = training_run_manager.get(run_id) + if not run: + raise HTTPException(status_code=404, detail=f'Training run {run_id} not found') + return run + + @app.get('/training_runs/{run_id}/checkpoints') + async def get_run_checkpoints(request: Request, run_id: str) -> types.CheckpointsListResponse: + token = get_token_from_request(request) + checkpoint_manager = create_checkpoint_manager(token, client_type='tinker') + response = checkpoint_manager.list_checkpoints(run_id) + if not 
response: + raise HTTPException(status_code=404, detail=f'Training run {run_id} not found') + return response + + @app.delete('/training_runs/{run_id}/checkpoints/{checkpoint_id:path}') + async def delete_run_checkpoint(request: Request, run_id: str, checkpoint_id: str) -> Any: + token = get_token_from_request(request) + checkpoint_manager = create_checkpoint_manager(token, client_type='tinker') + success = checkpoint_manager.delete(run_id, checkpoint_id) + if not success: + raise HTTPException(status_code=404, detail=f'Checkpoint {checkpoint_id} not found for run {run_id}') + return None + + @app.post('/weights_info') + async def weights_info(request: Request, body: dict[str, Any]) -> types.WeightsInfoResponse: + token = get_token_from_request(request) + checkpoint_manager = create_checkpoint_manager(token, client_type='tinker') + tinker_path = body.get('tinker_path') + response = checkpoint_manager.get_weights_info(tinker_path) + if not response: + raise HTTPException(status_code=404, detail=f'Weights at {tinker_path} not found') + return response + + @app.post('/training_runs/{run_id}/checkpoints/{checkpoint_id:path}/publish') + async def publish_checkpoint(request: Request, run_id: str, checkpoint_id: str, gw=Depends(self_fn)) -> Response: + token = get_token_from_request(request) + + training_run_manager = create_training_run_manager(token, client_type='tinker') + checkpoint_manager = create_checkpoint_manager(token, client_type='tinker') + + run = training_run_manager.get(run_id) + if not run: + raise HTTPException(status_code=404, detail=f'Training run {run_id} not found or access denied') + + checkpoint = checkpoint_manager.get(run_id, checkpoint_id) + if not checkpoint: + raise HTTPException(status_code=404, detail=f'Checkpoint {checkpoint_id} not found') + + checkpoint_dir = str(checkpoint_manager.get_ckpt_dir(run_id, checkpoint_id)) + + async with gw._modelscope_config_lock: + try: + from modelscope.hub.api import HubApi, ModelScopeConfig + hub_api = 
HubApi(token=token) + hub_api.login() + username = ModelScopeConfig.get_user_info()[0] + except Exception as e: + logger.error(f'Failed to get username from ModelScope: {e}') + raise HTTPException( + status_code=401, + detail='Failed to get username from ModelScope. Please ensure your token is valid.') + + checkpoint_name = checkpoint_id.split('/')[-1] + hub_model_id = f'{username}/{run_id}_{checkpoint_name}' + HubOperation.async_push_to_hub(repo_id=hub_model_id, folder_path=checkpoint_dir, token=token, private=True) + + return Response(status_code=204) + + # --- Model Proxy Endpoints --- + + @app.post('/create_model') + async def create_model(request: Request, body: types.CreateModelRequest, gw=Depends(self_fn)) -> Any: + gw._validate_base_model(body.base_model) + return await gw.proxy.proxy_to_model(request, 'create_model', body.base_model) + + @app.post('/get_info') + async def get_info(request: Request, body: types.GetInfoRequest, gw=Depends(self_fn)) -> Any: + return await gw.proxy.proxy_to_model(request, 'get_info', gw._get_base_model(body.model_id)) + + @app.post('/unload_model') + async def unload_model(request: Request, body: types.UnloadModelRequest, gw=Depends(self_fn)) -> Any: + return await gw.proxy.proxy_to_model(request, 'unload_model', gw._get_base_model(body.model_id)) + + @app.post('/forward') + async def forward(request: Request, body: types.ForwardRequest, gw=Depends(self_fn)) -> Any: + return await gw.proxy.proxy_to_model(request, 'forward', gw._get_base_model(body.model_id)) + + @app.post('/forward_backward') + async def forward_backward(request: Request, body: types.ForwardBackwardRequest, gw=Depends(self_fn)) -> Any: + return await gw.proxy.proxy_to_model(request, 'forward_backward', gw._get_base_model(body.model_id)) + + @app.post('/optim_step') + async def optim_step(request: Request, body: types.OptimStepRequest, gw=Depends(self_fn)) -> Any: + return await gw.proxy.proxy_to_model(request, 'optim_step', gw._get_base_model(body.model_id)) + 
+ @app.post('/save_weights') + async def save_weights(request: Request, body: types.SaveWeightsRequest, gw=Depends(self_fn)) -> Any: + return await gw.proxy.proxy_to_model(request, 'save_weights', gw._get_base_model(body.model_id)) + + @app.post('/load_weights') + async def load_weights(request: Request, body: types.LoadWeightsRequest, gw=Depends(self_fn)) -> Any: + return await gw.proxy.proxy_to_model(request, 'load_weights', gw._get_base_model(body.model_id)) + + # --- Sampler Proxy Endpoints --- + + @app.post('/asample') + async def asample(request: Request, body: types.SampleRequest, gw=Depends(self_fn)) -> Any: + base_model = body.base_model + if not base_model and body.sampling_session_id: + session = gw.state.get_sampling_session(body.sampling_session_id) + if session: + base_model = session.get('base_model') + return await gw.proxy.proxy_to_sampler(request, 'asample', base_model) + + @app.post('/save_weights_for_sampler') + async def save_weights_for_sampler(request: Request, body: types.SaveWeightsForSamplerRequest, + gw=Depends(self_fn)) -> Any: + return await gw.proxy.proxy_to_model(request, 'save_weights_for_sampler', gw._get_base_model(body.model_id)) diff --git a/src/twinkle/server/gateway/twinkle_gateway_handlers.py b/src/twinkle/server/gateway/twinkle_gateway_handlers.py index 5939f1b3..28bae7ae 100644 --- a/src/twinkle/server/gateway/twinkle_gateway_handlers.py +++ b/src/twinkle/server/gateway/twinkle_gateway_handlers.py @@ -1,9 +1,8 @@ # Copyright (c) ModelScope Contributors. All rights reserved. """ -Twinkle-native gateway handler mixin. +Twinkle-native gateway handlers. -All endpoints are prefixed /twinkle/* and registered via _register_twinkle_routes(app). -Route closures use self.* directly (no request.state injection needed). +All endpoints are prefixed /twinkle/* and registered via _register_twinkle_routes(app, self_fn). 
""" from __future__ import annotations @@ -23,87 +22,77 @@ logger = get_logger() -class TwinkleGatewayHandlers: - """ - Mixin providing Twinkle-native gateway management endpoints. - - Expects the combined class to have: self.state - """ - state: ServerStateProxy - - @staticmethod - def _register_twinkle_routes(app: FastAPI, self_fn): - """Register all /twinkle/* routes on the given FastAPI app.""" - - @app.get('/twinkle/healthz', response_model=HealthResponse) - async def healthz(request: Request) -> HealthResponse: - return HealthResponse(status='ok') - - @app.get('/twinkle/training_runs', response_model=TrainingRunsResponse) - async def get_training_runs(request: Request, limit: int = 20, offset: int = 0) -> TrainingRunsResponse: - token = get_token_from_request(request) - training_run_manager = create_training_run_manager(token, client_type='twinkle') - return training_run_manager.list_runs(limit=limit, offset=offset) - - @app.get('/twinkle/training_runs/{run_id}', response_model=TrainingRun) - async def get_training_run(request: Request, run_id: str) -> TrainingRun: - token = get_token_from_request(request) - training_run_manager = create_training_run_manager(token, client_type='twinkle') - run = training_run_manager.get_with_permission(run_id) - if not run: - raise HTTPException(status_code=404, detail=f'Training run {run_id} not found or access denied') - return run - - @app.get('/twinkle/training_runs/{run_id}/checkpoints', response_model=CheckpointsListResponse) - async def get_run_checkpoints(request: Request, run_id: str) -> CheckpointsListResponse: - token = get_token_from_request(request) - checkpoint_manager = create_checkpoint_manager(token, client_type='twinkle') - response = checkpoint_manager.list_checkpoints(run_id) - if response is None: - raise HTTPException(status_code=404, detail=f'Training run {run_id} not found or access denied') - return response - - @app.delete('/twinkle/training_runs/{run_id}/checkpoints/{checkpoint_id:path}') - async 
def delete_run_checkpoint(request: Request, run_id: str, checkpoint_id: str) -> DeleteCheckpointResponse: - token = get_token_from_request(request) - - if not validate_user_path(token, checkpoint_id): - raise HTTPException(status_code=400, detail='Invalid checkpoint path: path traversal not allowed') - - checkpoint_manager = create_checkpoint_manager(token, client_type='twinkle') - success = checkpoint_manager.delete(run_id, checkpoint_id) - if not success: - raise HTTPException(status_code=404, detail=f'Checkpoint {checkpoint_id} not found or access denied') - - return DeleteCheckpointResponse(success=True, message=f'Checkpoint {checkpoint_id} deleted successfully') - - @app.post('/twinkle/weights_info', response_model=WeightsInfoResponse) - async def weights_info(request: Request, body: WeightsInfoRequest) -> WeightsInfoResponse: - token = get_token_from_request(request) - checkpoint_manager = create_checkpoint_manager(token, client_type='twinkle') - response = checkpoint_manager.get_weights_info(body.twinkle_path) - if response is None: - raise HTTPException( - status_code=404, detail=f'Weights at {body.twinkle_path} not found or access denied') - return response - - @app.get('/twinkle/checkpoint_path/{run_id}/{checkpoint_id:path}') - async def get_checkpoint_path(request: Request, run_id: str, checkpoint_id: str) -> dict[str, str]: - token = get_token_from_request(request) - - if not validate_user_path(token, checkpoint_id): - raise HTTPException(status_code=400, detail='Invalid checkpoint path: path traversal not allowed') - - training_run_manager = create_training_run_manager(token, client_type='twinkle') - checkpoint_manager = create_checkpoint_manager(token, client_type='twinkle') - - run = training_run_manager.get(run_id) - if not run: - raise HTTPException(status_code=404, detail=f'Training run {run_id} not found or access denied') - - checkpoint = checkpoint_manager.get(run_id, checkpoint_id) - if not checkpoint: - raise HTTPException(status_code=404, 
detail=f'Checkpoint {checkpoint_id} not found') - - ckpt_dir = checkpoint_manager.get_ckpt_dir(run_id, checkpoint_id) - return {'path': str(ckpt_dir), 'twinkle_path': checkpoint.twinkle_path} +def _register_twinkle_routes(app: FastAPI, self_fn) -> None: + """Register all /twinkle/* routes on the given FastAPI app.""" + + @app.get('/twinkle/healthz', response_model=HealthResponse) + async def healthz(request: Request) -> HealthResponse: + return HealthResponse(status='ok') + + @app.get('/twinkle/training_runs', response_model=TrainingRunsResponse) + async def get_training_runs(request: Request, limit: int = 20, offset: int = 0) -> TrainingRunsResponse: + token = get_token_from_request(request) + training_run_manager = create_training_run_manager(token, client_type='twinkle') + return training_run_manager.list_runs(limit=limit, offset=offset) + + @app.get('/twinkle/training_runs/{run_id}', response_model=TrainingRun) + async def get_training_run(request: Request, run_id: str) -> TrainingRun: + token = get_token_from_request(request) + training_run_manager = create_training_run_manager(token, client_type='twinkle') + run = training_run_manager.get_with_permission(run_id) + if not run: + raise HTTPException(status_code=404, detail=f'Training run {run_id} not found or access denied') + return run + + @app.get('/twinkle/training_runs/{run_id}/checkpoints', response_model=CheckpointsListResponse) + async def get_run_checkpoints(request: Request, run_id: str) -> CheckpointsListResponse: + token = get_token_from_request(request) + checkpoint_manager = create_checkpoint_manager(token, client_type='twinkle') + response = checkpoint_manager.list_checkpoints(run_id) + if response is None: + raise HTTPException(status_code=404, detail=f'Training run {run_id} not found or access denied') + return response + + @app.delete('/twinkle/training_runs/{run_id}/checkpoints/{checkpoint_id:path}') + async def delete_run_checkpoint(request: Request, run_id: str, checkpoint_id: str) -> 
DeleteCheckpointResponse: + token = get_token_from_request(request) + + if not validate_user_path(token, checkpoint_id): + raise HTTPException(status_code=400, detail='Invalid checkpoint path: path traversal not allowed') + + checkpoint_manager = create_checkpoint_manager(token, client_type='twinkle') + success = checkpoint_manager.delete(run_id, checkpoint_id) + if not success: + raise HTTPException(status_code=404, detail=f'Checkpoint {checkpoint_id} not found or access denied') + + return DeleteCheckpointResponse(success=True, message=f'Checkpoint {checkpoint_id} deleted successfully') + + @app.post('/twinkle/weights_info', response_model=WeightsInfoResponse) + async def weights_info(request: Request, body: WeightsInfoRequest) -> WeightsInfoResponse: + token = get_token_from_request(request) + checkpoint_manager = create_checkpoint_manager(token, client_type='twinkle') + response = checkpoint_manager.get_weights_info(body.twinkle_path) + if response is None: + raise HTTPException(status_code=404, detail=f'Weights at {body.twinkle_path} not found or access denied') + return response + + @app.get('/twinkle/checkpoint_path/{run_id}/{checkpoint_id:path}') + async def get_checkpoint_path(request: Request, run_id: str, checkpoint_id: str) -> dict[str, str]: + token = get_token_from_request(request) + + if not validate_user_path(token, checkpoint_id): + raise HTTPException(status_code=400, detail='Invalid checkpoint path: path traversal not allowed') + + training_run_manager = create_training_run_manager(token, client_type='twinkle') + checkpoint_manager = create_checkpoint_manager(token, client_type='twinkle') + + run = training_run_manager.get(run_id) + if not run: + raise HTTPException(status_code=404, detail=f'Training run {run_id} not found or access denied') + + checkpoint = checkpoint_manager.get(run_id, checkpoint_id) + if not checkpoint: + raise HTTPException(status_code=404, detail=f'Checkpoint {checkpoint_id} not found') + + ckpt_dir = 
checkpoint_manager.get_ckpt_dir(run_id, checkpoint_id) + return {'path': str(ckpt_dir), 'twinkle_path': checkpoint.twinkle_path} From 7c37423a612c0a0aab62e8291b5692dd15ad765b Mon Sep 17 00:00:00 2001 From: Yunnglin Date: Wed, 11 Mar 2026 17:01:28 +0800 Subject: [PATCH 06/24] update server app --- README.md | 2 +- README_ZH.md | 2 +- .../tinker/modelscope/self_cognition.py | 2 +- .../client/tinker/self_host/self_cognition.py | 2 +- .../Usage Guide/Introduction-with-Qwen3.5.md | 2 +- docs/source_en/Usage Guide/Quick-Start.md | 2 +- .../Tinker-Compatible-Client.md | 2 +- .../Usage Guide/Train-as-a-Service.md | 2 +- ...00\344\275\263\345\256\236\350\267\265.md" | 2 +- ...53\351\200\237\345\274\200\345\247\213.md" | 2 +- ...71\345\256\242\346\210\267\347\253\257.md" | 2 +- ...55\347\273\203\346\234\215\345\212\241.md" | 2 +- src/twinkle/server/gateway/server.py | 100 ++- .../server/gateway/tinker_gateway_handlers.py | 105 +-- .../gateway/twinkle_gateway_handlers.py | 5 +- src/twinkle/server/model/app.py | 205 +++--- src/twinkle/server/model/backends/common.py | 11 +- src/twinkle/server/model/tinker_handlers.py | 562 ++++++++-------- src/twinkle/server/model/twinkle_handlers.py | 615 +++++++++--------- src/twinkle/server/sampler/app.py | 207 +++--- src/twinkle/server/sampler/tinker_handlers.py | 193 +++--- .../server/sampler/twinkle_handlers.py | 245 +++---- 22 files changed, 1168 insertions(+), 1104 deletions(-) diff --git a/README.md b/README.md index d0693836..877a3890 100644 --- a/README.md +++ b/README.md @@ -213,7 +213,7 @@ from twinkle import init_tinker_client from twinkle.dataloader import DataLoader from twinkle.dataset import Dataset, DatasetMeta from twinkle.preprocessor import SelfCognitionProcessor -from twinkle.server.tinker.common import input_feature_to_datum +from twinkle.server.common import input_feature_to_datum base_model = 'ms://Qwen/Qwen3-30B-A3B-Instruct-2507' base_url='your-base-url' diff --git a/README_ZH.md b/README_ZH.md index 
11a6cccc..2ded262f 100644 --- a/README_ZH.md +++ b/README_ZH.md @@ -193,7 +193,7 @@ from twinkle import init_tinker_client from twinkle.dataloader import DataLoader from twinkle.dataset import Dataset, DatasetMeta from twinkle.preprocessor import SelfCognitionProcessor -from twinkle.server.tinker.common import input_feature_to_datum +from twinkle.server.common import input_feature_to_datum base_model = 'ms://Qwen/Qwen3-30B-A3B-Instruct-2507' base_url='your-base-url' diff --git a/cookbook/client/tinker/modelscope/self_cognition.py b/cookbook/client/tinker/modelscope/self_cognition.py index f8d2a607..cb3b1700 100644 --- a/cookbook/client/tinker/modelscope/self_cognition.py +++ b/cookbook/client/tinker/modelscope/self_cognition.py @@ -15,7 +15,7 @@ from twinkle.dataloader import DataLoader from twinkle.dataset import Dataset, DatasetMeta from twinkle.preprocessor import SelfCognitionProcessor -from twinkle.server.tinker.common import input_feature_to_datum +from twinkle.server.common import input_feature_to_datum # Initialize the Tinker client before importing ServiceClient init_tinker_client() diff --git a/cookbook/client/tinker/self_host/self_cognition.py b/cookbook/client/tinker/self_host/self_cognition.py index e285cc7f..81125e53 100644 --- a/cookbook/client/tinker/self_host/self_cognition.py +++ b/cookbook/client/tinker/self_host/self_cognition.py @@ -16,7 +16,7 @@ from twinkle.dataloader import DataLoader from twinkle.dataset import Dataset, DatasetMeta from twinkle.preprocessor import SelfCognitionProcessor -from twinkle.server.tinker.common import input_feature_to_datum +from twinkle.server.common import input_feature_to_datum # Initialize the Tinker client before importing ServiceClient init_tinker_client() diff --git a/docs/source_en/Usage Guide/Introduction-with-Qwen3.5.md b/docs/source_en/Usage Guide/Introduction-with-Qwen3.5.md index 2f67e37b..3ef72e3e 100644 --- a/docs/source_en/Usage Guide/Introduction-with-Qwen3.5.md +++ b/docs/source_en/Usage 
Guide/Introduction-with-Qwen3.5.md @@ -458,7 +458,7 @@ from twinkle import init_tinker_client from twinkle.dataloader import DataLoader from twinkle.dataset import Dataset, DatasetMeta from twinkle.preprocessor import SelfCognitionProcessor -from twinkle.server.tinker.common import input_feature_to_datum +from twinkle.server.common import input_feature_to_datum # Initialize Tinker client (must be called before importing ServiceClient) init_tinker_client() diff --git a/docs/source_en/Usage Guide/Quick-Start.md b/docs/source_en/Usage Guide/Quick-Start.md index 24820fea..cde5bf19 100644 --- a/docs/source_en/Usage Guide/Quick-Start.md +++ b/docs/source_en/Usage Guide/Quick-Start.md @@ -679,7 +679,7 @@ from tinker import ServiceClient from twinkle.dataloader import DataLoader from twinkle.dataset import Dataset, DatasetMeta from twinkle.preprocessor import SelfCognitionProcessor -from twinkle.server.tinker.common import input_feature_to_datum +from twinkle.server.common import input_feature_to_datum # The base model to fine-tune / evaluate base_model = 'ms://Qwen/Qwen3.5-4B' diff --git a/docs/source_en/Usage Guide/Server and Client/Tinker-Compatible-Client.md b/docs/source_en/Usage Guide/Server and Client/Tinker-Compatible-Client.md index a01fd141..e44f3cea 100644 --- a/docs/source_en/Usage Guide/Server and Client/Tinker-Compatible-Client.md +++ b/docs/source_en/Usage Guide/Server and Client/Tinker-Compatible-Client.md @@ -143,7 +143,7 @@ from twinkle import init_tinker_client from twinkle.dataloader import DataLoader from twinkle.dataset import Dataset, DatasetMeta from twinkle.preprocessor import SelfCognitionProcessor -from twinkle.server.tinker.common import input_feature_to_datum +from twinkle.server.common import input_feature_to_datum # Initialize Tinker client before importing ServiceClient init_tinker_client() diff --git a/docs/source_en/Usage Guide/Train-as-a-Service.md b/docs/source_en/Usage Guide/Train-as-a-Service.md index fd6c30f3..29828091 100644 --- 
a/docs/source_en/Usage Guide/Train-as-a-Service.md +++ b/docs/source_en/Usage Guide/Train-as-a-Service.md @@ -28,7 +28,7 @@ from twinkle_client import init_tinker_client from twinkle.dataloader import DataLoader from twinkle.dataset import Dataset, DatasetMeta from twinkle.preprocessor import SelfCognitionProcessor -from twinkle.server.tinker.common import input_feature_to_datum +from twinkle.server.common import input_feature_to_datum base_model = 'ms://Qwen/Qwen3-30B-A3B-Instruct-2507' base_url='http://www.modelscope.cn/twinkle' diff --git "a/docs/source_zh/\344\275\277\347\224\250\346\214\207\345\274\225/Qwen3.5\346\234\200\344\275\263\345\256\236\350\267\265.md" "b/docs/source_zh/\344\275\277\347\224\250\346\214\207\345\274\225/Qwen3.5\346\234\200\344\275\263\345\256\236\350\267\265.md" index 8b86b9b0..cfb57655 100644 --- "a/docs/source_zh/\344\275\277\347\224\250\346\214\207\345\274\225/Qwen3.5\346\234\200\344\275\263\345\256\236\350\267\265.md" +++ "b/docs/source_zh/\344\275\277\347\224\250\346\214\207\345\274\225/Qwen3.5\346\234\200\344\275\263\345\256\236\350\267\265.md" @@ -458,7 +458,7 @@ from twinkle import init_tinker_client from twinkle.dataloader import DataLoader from twinkle.dataset import Dataset, DatasetMeta from twinkle.preprocessor import SelfCognitionProcessor -from twinkle.server.tinker.common import input_feature_to_datum +from twinkle.server.common import input_feature_to_datum # 初始化 Tinker 客户端(必须在导入 ServiceClient 之前) init_tinker_client() diff --git "a/docs/source_zh/\344\275\277\347\224\250\346\214\207\345\274\225/\345\277\253\351\200\237\345\274\200\345\247\213.md" "b/docs/source_zh/\344\275\277\347\224\250\346\214\207\345\274\225/\345\277\253\351\200\237\345\274\200\345\247\213.md" index 5e4cbf0d..b79126af 100644 --- "a/docs/source_zh/\344\275\277\347\224\250\346\214\207\345\274\225/\345\277\253\351\200\237\345\274\200\345\247\213.md" +++ 
"b/docs/source_zh/\344\275\277\347\224\250\346\214\207\345\274\225/\345\277\253\351\200\237\345\274\200\345\247\213.md" @@ -681,7 +681,7 @@ from tinker import ServiceClient from twinkle.dataloader import DataLoader from twinkle.dataset import Dataset, DatasetMeta from twinkle.preprocessor import SelfCognitionProcessor -from twinkle.server.tinker.common import input_feature_to_datum +from twinkle.server.common import input_feature_to_datum # The base model to fine-tune / evaluate base_model = 'Qwen/Qwen3.5-4B' diff --git "a/docs/source_zh/\344\275\277\347\224\250\346\214\207\345\274\225/\346\234\215\345\212\241\347\253\257\345\222\214\345\256\242\346\210\267\347\253\257/Tinker\345\205\274\345\256\271\345\256\242\346\210\267\347\253\257.md" "b/docs/source_zh/\344\275\277\347\224\250\346\214\207\345\274\225/\346\234\215\345\212\241\347\253\257\345\222\214\345\256\242\346\210\267\347\253\257/Tinker\345\205\274\345\256\271\345\256\242\346\210\267\347\253\257.md" index 11b51303..27db69b2 100644 --- "a/docs/source_zh/\344\275\277\347\224\250\346\214\207\345\274\225/\346\234\215\345\212\241\347\253\257\345\222\214\345\256\242\346\210\267\347\253\257/Tinker\345\205\274\345\256\271\345\256\242\346\210\267\347\253\257.md" +++ "b/docs/source_zh/\344\275\277\347\224\250\346\214\207\345\274\225/\346\234\215\345\212\241\347\253\257\345\222\214\345\256\242\346\210\267\347\253\257/Tinker\345\205\274\345\256\271\345\256\242\346\210\267\347\253\257.md" @@ -143,7 +143,7 @@ from twinkle import init_tinker_client from twinkle.dataloader import DataLoader from twinkle.dataset import Dataset, DatasetMeta from twinkle.preprocessor import SelfCognitionProcessor -from twinkle.server.tinker.common import input_feature_to_datum +from twinkle.server.common import input_feature_to_datum # 在导入 ServiceClient 之前,先初始化 Tinker 客户端 init_tinker_client() diff --git "a/docs/source_zh/\344\275\277\347\224\250\346\214\207\345\274\225/\350\256\255\347\273\203\346\234\215\345\212\241.md" 
"b/docs/source_zh/\344\275\277\347\224\250\346\214\207\345\274\225/\350\256\255\347\273\203\346\234\215\345\212\241.md" index c0d5b68f..5d0272c3 100644 --- "a/docs/source_zh/\344\275\277\347\224\250\346\214\207\345\274\225/\350\256\255\347\273\203\346\234\215\345\212\241.md" +++ "b/docs/source_zh/\344\275\277\347\224\250\346\214\207\345\274\225/\350\256\255\347\273\203\346\234\215\345\212\241.md" @@ -31,7 +31,7 @@ from twinkle_client import init_tinker_client from twinkle.dataloader import DataLoader from twinkle.dataset import Dataset, DatasetMeta from twinkle.preprocessor import SelfCognitionProcessor -from twinkle.server.tinker.common import input_feature_to_datum +from twinkle.server.common import input_feature_to_datum base_model = 'ms://Qwen/Qwen3-30B-A3B-Instruct-2507' base_url='http://www.modelscope.cn/twinkle' diff --git a/src/twinkle/server/gateway/server.py b/src/twinkle/server/gateway/server.py index 06f9f37a..dd591ccf 100644 --- a/src/twinkle/server/gateway/server.py +++ b/src/twinkle/server/gateway/server.py @@ -23,6 +23,51 @@ logger = get_logger() +class GatewayServer: + """Unified gateway server handling both Tinker and Twinkle API clients.""" + + def __init__(self, + supported_models: list | None = None, + server_config: dict[str, Any] = {}, + http_options: dict[str, Any] | None = None, + **kwargs) -> None: + self.state = get_server_state(**server_config) + self.route_prefix = kwargs.get('route_prefix', '/api/v1') + self.http_options = http_options or {} + self.proxy = ServiceProxy(http_options=http_options, route_prefix=self.route_prefix) + self.supported_models = self._normalize_models(supported_models) or [ + tinker_types.SupportedModel(model_name='Qwen/Qwen3-30B-A3B-Instruct-2507'), + ] + self._modelscope_config_lock = asyncio.Lock() + + def _normalize_models(self, supported_models): + if not supported_models: + return [] + normalized = [] + for item in supported_models: + if isinstance(item, tinker_types.SupportedModel): + 
normalized.append(item) + elif isinstance(item, dict): + normalized.append(tinker_types.SupportedModel(**item)) + elif isinstance(item, str): + normalized.append(tinker_types.SupportedModel(model_name=item)) + return normalized + + def _validate_base_model(self, base_model: str) -> None: + supported_model_names = [m.model_name for m in self.supported_models] + if base_model not in supported_model_names: + raise HTTPException( + status_code=400, + detail=f"Base model '{base_model}' is not supported. " + f"Supported models: {', '.join(supported_model_names)}") + + def _get_base_model(self, model_id: str) -> str: + metadata = self.state.get_model_metadata(model_id) + if metadata and metadata.get('base_model'): + return metadata['base_model'] + raise HTTPException(status_code=404, detail=f'Model {model_id} not found') + + def build_server_app(deploy_options: dict[str, Any], supported_models: list | None = None, server_config: dict[str, Any] = {}, @@ -30,7 +75,7 @@ def build_server_app(deploy_options: dict[str, Any], **kwargs): """Build and configure the unified gateway server application. - Serves Tinker endpoints at /tinker/* and Twinkle endpoints at /twinkle/*. + Serves Tinker endpoints at /* and Twinkle endpoints at /twinkle/*. Args: deploy_options: Ray Serve deployment configuration @@ -42,64 +87,17 @@ def build_server_app(deploy_options: dict[str, Any], Returns: Configured Ray Serve deployment bound with options """ - - # Build the FastAPI app and register all routes BEFORE serve.ingress so that - # the frozen app contains the complete route table (visible to ProxyActor). 
app = FastAPI() @app.middleware('http') async def verify_token(request: Request, call_next): return await verify_request_token(request=request, call_next=call_next) - def get_gw(): + def get_self() -> GatewayServer: return serve.get_replica_context().servable_object - _register_tinker_routes(app, get_gw) - _register_twinkle_routes(app, get_gw) - - class GatewayServer: - """Unified gateway server handling both Tinker and Twinkle API clients.""" - - def __init__(self, - supported_models: list | None = None, - server_config: dict[str, Any] = {}, - http_options: dict[str, Any] | None = None, - **kwargs) -> None: - self.state = get_server_state(**server_config) - self.route_prefix = kwargs.get('route_prefix', '/api/v1') - self.http_options = http_options or {} - self.proxy = ServiceProxy(http_options=http_options, route_prefix=self.route_prefix) - self.supported_models = self._normalize_models(supported_models) or [ - tinker_types.SupportedModel(model_name='Qwen/Qwen3-30B-A3B-Instruct-2507'), - ] - self._modelscope_config_lock = asyncio.Lock() - - def _normalize_models(self, supported_models): - if not supported_models: - return [] - normalized = [] - for item in supported_models: - if isinstance(item, tinker_types.SupportedModel): - normalized.append(item) - elif isinstance(item, dict): - normalized.append(tinker_types.SupportedModel(**item)) - elif isinstance(item, str): - normalized.append(tinker_types.SupportedModel(model_name=item)) - return normalized - - def _validate_base_model(self, base_model: str) -> None: - supported_model_names = [m.model_name for m in self.supported_models] - if base_model not in supported_model_names: - raise HTTPException( - status_code=400, - detail=f"Base model '{base_model}' is not supported. 
" - f"Supported models: {', '.join(supported_model_names)}") - - def _get_base_model(self, model_id: str) -> str: - metadata = self.state.get_model_metadata(model_id) - if metadata and metadata.get('base_model'): - return metadata['base_model'] - raise HTTPException(status_code=404, detail=f'Model {model_id} not found') + _register_tinker_routes(app, get_self) + _register_twinkle_routes(app, get_self) GatewayServerWithIngress = serve.ingress(app)(GatewayServer) DeploymentClass = serve.deployment(name='GatewayServer')(GatewayServerWithIngress) diff --git a/src/twinkle/server/gateway/tinker_gateway_handlers.py b/src/twinkle/server/gateway/tinker_gateway_handlers.py index 7949e8ff..e3a3b503 100644 --- a/src/twinkle/server/gateway/tinker_gateway_handlers.py +++ b/src/twinkle/server/gateway/tinker_gateway_handlers.py @@ -11,11 +11,10 @@ import os from fastapi import Depends, FastAPI, HTTPException, Request, Response from tinker import types -from typing import TYPE_CHECKING, Any +from typing import TYPE_CHECKING, Any, Callable if TYPE_CHECKING: - from twinkle.server.utils.state.server_state import ServerStateProxy - from .proxy import ServiceProxy + from .server import GatewayServer from twinkle.hub import HubOperation from twinkle.server.common.io_utils import create_checkpoint_manager, create_training_run_manager @@ -26,7 +25,7 @@ logger = get_logger() -def _register_tinker_routes(app: FastAPI, self_fn) -> None: +def _register_tinker_routes(app: FastAPI, self_fn: Callable[[], GatewayServer]) -> None: """Register all /* Tinker routes on the given FastAPI app. 
self_fn is a zero-argument callable that returns the current GatewayServer @@ -39,37 +38,45 @@ async def healthz(request: Request) -> types.HealthResponse: return types.HealthResponse(status='ok') @app.get('/get_server_capabilities') - async def get_server_capabilities(request: Request, gw=Depends(self_fn)) -> types.GetServerCapabilitiesResponse: - return types.GetServerCapabilitiesResponse(supported_models=gw.supported_models) + async def get_server_capabilities( + request: Request, + self: GatewayServer = Depends(self_fn), + ) -> types.GetServerCapabilitiesResponse: + return types.GetServerCapabilitiesResponse(supported_models=self.supported_models) @app.post('/telemetry') async def telemetry(request: Request, body: types.TelemetrySendRequest) -> types.TelemetryResponse: return types.TelemetryResponse(status='accepted') @app.post('/create_session') - async def create_session(request: Request, body: types.CreateSessionRequest, - gw=Depends(self_fn)) -> types.CreateSessionResponse: - session_id = gw.state.create_session(body.model_dump()) + async def create_session( + request: Request, + body: types.CreateSessionRequest, + self: GatewayServer = Depends(self_fn), + ) -> types.CreateSessionResponse: + session_id = self.state.create_session(body.model_dump()) return types.CreateSessionResponse(session_id=session_id) @app.post('/session_heartbeat') async def session_heartbeat( - request: Request, body: types.SessionHeartbeatRequest, - gw=Depends(self_fn)) -> types.SessionHeartbeatResponse: # noqa: E125 - alive = gw.state.touch_session(body.session_id) + request: Request, body: types.SessionHeartbeatRequest, self: GatewayServer = Depends(self_fn) + ) -> types.SessionHeartbeatResponse: # noqa: E125 + alive = self.state.touch_session(body.session_id) if not alive: raise HTTPException(status_code=404, detail='Unknown session') return types.SessionHeartbeatResponse() @app.post('/create_sampling_session') async def create_sampling_session( - request: Request, body: 
types.CreateSamplingSessionRequest, - gw=Depends(self_fn)) -> types.CreateSamplingSessionResponse: # noqa: E125 - sampling_session_id = gw.state.create_sampling_session(body.model_dump()) + request: Request, body: types.CreateSamplingSessionRequest, self: GatewayServer = Depends(self_fn) + ) -> types.CreateSamplingSessionResponse: # noqa: E125 + sampling_session_id = self.state.create_sampling_session(body.model_dump()) return types.CreateSamplingSessionResponse(sampling_session_id=sampling_session_id) @app.post('/retrieve_future') - async def retrieve_future(request: Request, body: types.FutureRetrieveRequest, gw=Depends(self_fn)) -> Any: + async def retrieve_future(request: Request, + body: types.FutureRetrieveRequest, + self: GatewayServer = Depends(self_fn)) -> Any: """Retrieve the result of an async task with long polling.""" request_id = body.request_id max_wait = float(os.environ.get('TWINKLE_LONG_POLL_TIMEOUT', '30')) @@ -77,7 +84,7 @@ async def retrieve_future(request: Request, body: types.FutureRetrieveRequest, g start = asyncio.get_event_loop().time() while True: - record = gw.state.get_future(request_id) + record = self.state.get_future(request_id) if record is None: return {'type': 'try_again'} @@ -96,7 +103,7 @@ async def retrieve_future(request: Request, body: types.FutureRetrieveRequest, g await asyncio.sleep(poll_interval) - record = gw.state.get_future(request_id) + record = self.state.get_future(request_id) if not record: return {'type': 'try_again'} @@ -167,7 +174,10 @@ async def weights_info(request: Request, body: dict[str, Any]) -> types.WeightsI return response @app.post('/training_runs/{run_id}/checkpoints/{checkpoint_id:path}/publish') - async def publish_checkpoint(request: Request, run_id: str, checkpoint_id: str, gw=Depends(self_fn)) -> Response: + async def publish_checkpoint(request: Request, + run_id: str, + checkpoint_id: str, + self: GatewayServer = Depends(self_fn)) -> Response: token = get_token_from_request(request) 
training_run_manager = create_training_run_manager(token, client_type='tinker') @@ -183,7 +193,7 @@ async def publish_checkpoint(request: Request, run_id: str, checkpoint_id: str, checkpoint_dir = str(checkpoint_manager.get_ckpt_dir(run_id, checkpoint_id)) - async with gw._modelscope_config_lock: + async with self._modelscope_config_lock: try: from modelscope.hub.api import HubApi, ModelScopeConfig hub_api = HubApi(token=token) @@ -204,50 +214,59 @@ async def publish_checkpoint(request: Request, run_id: str, checkpoint_id: str, # --- Model Proxy Endpoints --- @app.post('/create_model') - async def create_model(request: Request, body: types.CreateModelRequest, gw=Depends(self_fn)) -> Any: - gw._validate_base_model(body.base_model) - return await gw.proxy.proxy_to_model(request, 'create_model', body.base_model) + async def create_model(request: Request, body: types.CreateModelRequest, + self: GatewayServer = Depends(self_fn)) -> Any: + self._validate_base_model(body.base_model) + return await self.proxy.proxy_to_model(request, 'create_model', body.base_model) @app.post('/get_info') - async def get_info(request: Request, body: types.GetInfoRequest, gw=Depends(self_fn)) -> Any: - return await gw.proxy.proxy_to_model(request, 'get_info', gw._get_base_model(body.model_id)) + async def get_info(request: Request, body: types.GetInfoRequest, self: GatewayServer = Depends(self_fn)) -> Any: + return await self.proxy.proxy_to_model(request, 'get_info', self._get_base_model(body.model_id)) @app.post('/unload_model') - async def unload_model(request: Request, body: types.UnloadModelRequest, gw=Depends(self_fn)) -> Any: - return await gw.proxy.proxy_to_model(request, 'unload_model', gw._get_base_model(body.model_id)) + async def unload_model(request: Request, body: types.UnloadModelRequest, + self: GatewayServer = Depends(self_fn)) -> Any: + return await self.proxy.proxy_to_model(request, 'unload_model', self._get_base_model(body.model_id)) @app.post('/forward') - async def 
forward(request: Request, body: types.ForwardRequest, gw=Depends(self_fn)) -> Any: - return await gw.proxy.proxy_to_model(request, 'forward', gw._get_base_model(body.model_id)) + async def forward(request: Request, body: types.ForwardRequest, self: GatewayServer = Depends(self_fn)) -> Any: + return await self.proxy.proxy_to_model(request, 'forward', self._get_base_model(body.model_id)) @app.post('/forward_backward') - async def forward_backward(request: Request, body: types.ForwardBackwardRequest, gw=Depends(self_fn)) -> Any: - return await gw.proxy.proxy_to_model(request, 'forward_backward', gw._get_base_model(body.model_id)) + async def forward_backward(request: Request, + body: types.ForwardBackwardRequest, + self: GatewayServer = Depends(self_fn)) -> Any: + return await self.proxy.proxy_to_model(request, 'forward_backward', self._get_base_model(body.model_id)) @app.post('/optim_step') - async def optim_step(request: Request, body: types.OptimStepRequest, gw=Depends(self_fn)) -> Any: - return await gw.proxy.proxy_to_model(request, 'optim_step', gw._get_base_model(body.model_id)) + async def optim_step(request: Request, body: types.OptimStepRequest, self: GatewayServer = Depends(self_fn)) -> Any: + return await self.proxy.proxy_to_model(request, 'optim_step', self._get_base_model(body.model_id)) @app.post('/save_weights') - async def save_weights(request: Request, body: types.SaveWeightsRequest, gw=Depends(self_fn)) -> Any: - return await gw.proxy.proxy_to_model(request, 'save_weights', gw._get_base_model(body.model_id)) + async def save_weights(request: Request, body: types.SaveWeightsRequest, + self: GatewayServer = Depends(self_fn)) -> Any: + return await self.proxy.proxy_to_model(request, 'save_weights', self._get_base_model(body.model_id)) @app.post('/load_weights') - async def load_weights(request: Request, body: types.LoadWeightsRequest, gw=Depends(self_fn)) -> Any: - return await gw.proxy.proxy_to_model(request, 'load_weights', 
gw._get_base_model(body.model_id)) + async def load_weights(request: Request, body: types.LoadWeightsRequest, + self: GatewayServer = Depends(self_fn)) -> Any: + return await self.proxy.proxy_to_model(request, 'load_weights', self._get_base_model(body.model_id)) # --- Sampler Proxy Endpoints --- @app.post('/asample') - async def asample(request: Request, body: types.SampleRequest, gw=Depends(self_fn)) -> Any: + async def asample(request: Request, body: types.SampleRequest, self: GatewayServer = Depends(self_fn)) -> Any: base_model = body.base_model if not base_model and body.sampling_session_id: - session = gw.state.get_sampling_session(body.sampling_session_id) + session = self.state.get_sampling_session(body.sampling_session_id) if session: base_model = session.get('base_model') - return await gw.proxy.proxy_to_sampler(request, 'asample', base_model) + return await self.proxy.proxy_to_sampler(request, 'asample', base_model) @app.post('/save_weights_for_sampler') - async def save_weights_for_sampler(request: Request, body: types.SaveWeightsForSamplerRequest, - gw=Depends(self_fn)) -> Any: - return await gw.proxy.proxy_to_model(request, 'save_weights_for_sampler', gw._get_base_model(body.model_id)) + async def save_weights_for_sampler( + request: Request, + body: types.SaveWeightsForSamplerRequest, + self: GatewayServer = Depends(self_fn), + ) -> Any: + return await self.proxy.proxy_to_model(request, 'save_weights_for_sampler', self._get_base_model(body.model_id)) diff --git a/src/twinkle/server/gateway/twinkle_gateway_handlers.py b/src/twinkle/server/gateway/twinkle_gateway_handlers.py index 28bae7ae..471878d9 100644 --- a/src/twinkle/server/gateway/twinkle_gateway_handlers.py +++ b/src/twinkle/server/gateway/twinkle_gateway_handlers.py @@ -7,10 +7,11 @@ from __future__ import annotations from fastapi import FastAPI, HTTPException, Request -from typing import TYPE_CHECKING +from typing import TYPE_CHECKING, Callable if TYPE_CHECKING: from 
twinkle.server.utils.state.server_state import ServerStateProxy + from .server import GatewayServer from twinkle.server.common.io_utils import create_checkpoint_manager, create_training_run_manager, validate_user_path from twinkle.server.utils.validation import get_token_from_request @@ -22,7 +23,7 @@ logger = get_logger() -def _register_twinkle_routes(app: FastAPI, self_fn) -> None: +def _register_twinkle_routes(app: FastAPI, self_fn: Callable[[], GatewayServer]) -> None: """Register all /twinkle/* routes on the given FastAPI app.""" @app.get('/twinkle/healthz', response_model=HealthResponse) diff --git a/src/twinkle/server/model/app.py b/src/twinkle/server/model/app.py index 72618616..b9294fd6 100644 --- a/src/twinkle/server/model/app.py +++ b/src/twinkle/server/model/app.py @@ -5,6 +5,8 @@ Builds a single Ray Serve deployment (ModelManagement) that simultaneously handles both Tinker (/tinker/*) and Twinkle (/twinkle/*) model endpoints. """ +from __future__ import annotations + from fastapi import FastAPI, Request from ray import serve from ray.serve.config import RequestRouterConfig @@ -19,20 +21,105 @@ from twinkle.utils.logger import get_logger from ..common.router import StickyLoraRequestRouter from ..utils import wrap_builder_with_device_group_env -from .tinker_handlers import TinkerModelHandlers -from .twinkle_handlers import TwinkleModelHandlers +from .tinker_handlers import _register_tinker_routes +from .twinkle_handlers import _register_twinkle_routes logger = get_logger() +class ModelManagement(TaskQueueMixin, AdapterManagerMixin): + """Unified model management service. 
+ + Handles: + - Base model and multiple LoRA adapters (multi-user) + - Tinker training operations via /tinker/* endpoints (async/polling) + - Twinkle training operations via /twinkle/* endpoints (synchronous) + - Adapter lifecycle via AdapterManagerMixin + - Per-user rate limiting via TaskQueueMixin + """ + + def __init__(self, + model_id: str, + nproc_per_node: int, + device_group: dict[str, Any], + device_mesh: dict[str, Any], + use_megatron: bool = False, + adapter_config: dict[str, Any] = {}, + queue_config: dict[str, Any] | None = None, + **kwargs): + self.device_group = DeviceGroup(**device_group) + twinkle.initialize(mode='ray', nproc_per_node=nproc_per_node, groups=[self.device_group], lazy_collect=False) + if 'mesh_dim_names' in device_mesh: + self.device_mesh = DeviceMesh(**device_mesh) + else: + self.device_mesh = DeviceMesh.from_sizes(**device_mesh) + self.use_megatron = use_megatron + self.replica_id = serve.get_replica_context().replica_id.unique_id + self.max_loras = kwargs.get('max_loras', 5) + self.base_model = model_id + + # Choose model backend + if use_megatron: + from ..model.backends.megatron_model import TwinkleCompatMegatronModel + self.model = TwinkleCompatMegatronModel( + model_id=model_id, + device_mesh=self.device_mesh, + remote_group=self.device_group.name, + instance_id=self.replica_id, + **kwargs) + else: + from ..model.backends.transformers_model import TwinkleCompatTransformersModel + self.model = TwinkleCompatTransformersModel( + model_id=model_id, + device_mesh=self.device_mesh, + remote_group=self.device_group.name, + instance_id=self.replica_id, + **kwargs) + + self.state: ServerStateProxy = get_server_state() + self.state.register_replica(self.replica_id, self.max_loras) + + # Initialize mixins + self._init_task_queue(TaskQueueConfig.from_dict(queue_config)) + self._init_adapter_manager(**adapter_config) + self.start_adapter_countdown() + + @serve.multiplexed(max_num_models_per_replica=5) + async def _sticky_entry(self, 
sticky_key: str): + return sticky_key + + async def _ensure_sticky(self): + sticky_key = serve.get_multiplexed_model_id() + await self._sticky_entry(sticky_key) + + async def _on_request_start(self, request: Request) -> str: + await self._ensure_sticky() + token = get_token_from_request(request) + return token + + def __del__(self): + self.state.unregister_replica(self.replica_id) + + def _cleanup_adapter(self, adapter_name: str) -> None: + if self.get_adapter_info(adapter_name): + self.clear_adapter_state(adapter_name) + self.model.remove_adapter(adapter_name) + self.unregister_adapter(adapter_name) + self.state.unload_model(adapter_name) + + def _on_adapter_expired(self, adapter_name: str) -> None: + self.fail_pending_tasks_for_model(adapter_name, reason='Adapter expired') + self._cleanup_adapter(adapter_name) + + def build_model_app(model_id: str, nproc_per_node: int, - device_group: Dict[str, Any], - device_mesh: Dict[str, Any], - deploy_options: Dict[str, Any], + device_group: dict[str, Any], + device_mesh: dict[str, Any], + deploy_options: dict[str, Any], use_megatron: bool = False, - adapter_config: Dict[str, Any] = {}, - queue_config: Optional[Dict[str, Any]] = None, + adapter_config: dict[str, Any] = {}, + queue_config: dict[str, Any] | None = None, **kwargs): """Build a unified model management application for distributed training. @@ -52,106 +139,28 @@ def build_model_app(model_id: str, Returns: Configured Ray Serve deployment bound with parameters """ + # Build the FastAPI app and register all routes BEFORE serve.ingress so that + # the frozen app contains the complete route table (visible to ProxyActor). 
app = FastAPI() @app.middleware('http') async def verify_token(request: Request, call_next): return await verify_request_token(request=request, call_next=call_next) - # Register routes BEFORE @serve.ingress so Ray Serve captures them at decoration time - TinkerModelHandlers._register_tinker_routes(app, model_id) - TwinkleModelHandlers._register_twinkle_routes(app, model_id) + def get_self() -> ModelManagement: + return serve.get_replica_context().servable_object + + _register_tinker_routes(app, get_self) + _register_twinkle_routes(app, get_self) - @serve.deployment( + ModelManagementWithIngress = serve.ingress(app)(ModelManagement) + DeploymentClass = serve.deployment( name='ModelManagement', request_router_config=RequestRouterConfig(request_router_class=StickyLoraRequestRouter), - ) - @serve.ingress(app) - class ModelManagement(TaskQueueMixin, AdapterManagerMixin): - """Unified model management service. - - Handles: - - Base model and multiple LoRA adapters (multi-user) - - Tinker training operations via /tinker/* endpoints (async/polling) - - Twinkle training operations via /twinkle/* endpoints (synchronous) - - Adapter lifecycle via AdapterManagerMixin - - Per-user rate limiting via TaskQueueMixin - """ - - def __init__(self, - nproc_per_node: int, - device_group: Dict[str, Any], - device_mesh: Dict[str, Any], - use_megatron: bool = False, - queue_config: Optional[Dict[str, Any]] = None, - **kwargs): - self.device_group = DeviceGroup(**device_group) - twinkle.initialize( - mode='ray', nproc_per_node=nproc_per_node, groups=[self.device_group], lazy_collect=False) - if 'mesh_dim_names' in device_mesh: - self.device_mesh = DeviceMesh(**device_mesh) - else: - self.device_mesh = DeviceMesh.from_sizes(**device_mesh) - self.use_megatron = use_megatron - self.replica_id = serve.get_replica_context().replica_id.unique_id - self.max_loras = kwargs.get('max_loras', 5) - - # Choose model backend - if use_megatron: - from ..model.backends.megatron_model import 
TwinkleCompatMegatronModel - self.model = TwinkleCompatMegatronModel( - model_id=model_id, - device_mesh=self.device_mesh, - remote_group=self.device_group.name, - instance_id=self.replica_id, - **kwargs) - else: - from ..model.backends.transformers_model import TwinkleCompatTransformersModel - self.model = TwinkleCompatTransformersModel( - model_id=model_id, - device_mesh=self.device_mesh, - remote_group=self.device_group.name, - instance_id=self.replica_id, - **kwargs) - - self.base_model = model_id - self.state: ServerStateProxy = get_server_state() - self.state.register_replica(self.replica_id, self.max_loras) - - # Initialize mixins - self._init_task_queue(TaskQueueConfig.from_dict(queue_config)) - self._init_adapter_manager(**adapter_config) - self.start_adapter_countdown() - - @serve.multiplexed(max_num_models_per_replica=kwargs.get('max_loras', 5)) - async def _sticky_entry(self, sticky_key: str): - return sticky_key - - async def _ensure_sticky(self): - sticky_key = serve.get_multiplexed_model_id() - await self._sticky_entry(sticky_key) - - async def _on_request_start(self, request: Request) -> str: - await self._ensure_sticky() - token = get_token_from_request(request) - return token - - def __del__(self): - self.state.unregister_replica(self.replica_id) - - def _cleanup_adapter(self, adapter_name: str) -> None: - if self.get_adapter_info(adapter_name): - self.clear_adapter_state(adapter_name) - self.model.remove_adapter(adapter_name) - self.unregister_adapter(adapter_name) - self.state.unload_model(adapter_name) - - def _on_adapter_expired(self, adapter_name: str) -> None: - self.fail_pending_tasks_for_model(adapter_name, reason='Adapter expired') - self._cleanup_adapter(adapter_name) - - return ModelManagement.options(**deploy_options).bind(nproc_per_node, device_group, device_mesh, use_megatron, - queue_config, **kwargs) + )( + ModelManagementWithIngress) + return DeploymentClass.options(**deploy_options).bind(model_id, nproc_per_node, device_group, 
device_mesh, + use_megatron, adapter_config, queue_config, **kwargs) build_model_app = wrap_builder_with_device_group_env(build_model_app) diff --git a/src/twinkle/server/model/backends/common.py b/src/twinkle/server/model/backends/common.py index 3553d657..e1f62e23 100644 --- a/src/twinkle/server/model/backends/common.py +++ b/src/twinkle/server/model/backends/common.py @@ -84,7 +84,7 @@ def _to_float(v): s = value.strip() if s: try: - head, unit = s.split() + head, unit = s.split(maxsplit=1) cleaned[f'{key}/{unit}'] = float(head) except Exception: m = re.match(r'^([+-]?(?:\d+(?:\.\d*)?|\.\d+)(?:[eE][+-]?\d+)?)', s) @@ -104,7 +104,12 @@ def get_template(self, adapter_name: str) -> Template: def _get_forward_output(inputs: List[types.Datum], logits: torch.Tensor, logps: torch.Tensor) -> List[dict]: """Convert raw logits to the expected output format with logprobs and elementwise_loss.""" from twinkle.utils.torch_utils import selective_log_softmax - device = logits.device if logits is not None else logps.device + if logps is not None: + device = logps.device + elif logits is not None: + device = logits.device + else: + raise ValueError('At least one of logits or logps must be provided.') results = [] if logits is None: logits = [None] * len(inputs) @@ -115,7 +120,7 @@ def _get_forward_output(inputs: List[types.Datum], logits: torch.Tensor, logps: seq_len = labels.numel() if logps is None: - assert logits is not None + assert logit is not None, 'logit must not be None when logps is None' feature_logits = logit[:seq_len, :] token_log_probs = selective_log_softmax(feature_logits, labels) else: diff --git a/src/twinkle/server/model/tinker_handlers.py b/src/twinkle/server/model/tinker_handlers.py index 3e65ce06..ca59b808 100644 --- a/src/twinkle/server/model/tinker_handlers.py +++ b/src/twinkle/server/model/tinker_handlers.py @@ -3,12 +3,18 @@ Tinker-compatible model handler mixin. All endpoints are prefixed /tinker/... and use schedule_task() returning UntypedAPIFuture. 
+self_fn is injected via FastAPI Depends to obtain the ModelManagement instance at request time. """ +from __future__ import annotations + import traceback -from fastapi import FastAPI, Request +from fastapi import Depends, FastAPI, Request from peft import LoraConfig from tinker import types -from typing import Any +from typing import TYPE_CHECKING, Any, Callable + +if TYPE_CHECKING: + from .app import ModelManagement from twinkle.server.common.io_utils import create_checkpoint_manager, create_training_run_manager from twinkle.utils.logger import get_logger @@ -16,279 +22,287 @@ logger = get_logger() -class TinkerModelHandlers: - """ - Mixin providing Tinker-compatible model management endpoints. +def _register_tinker_routes(app: FastAPI, self_fn: Callable[[], ModelManagement]) -> None: + """Register all /tinker/* routes on the given FastAPI app. - Expects the combined class to also inherit TaskQueueMixin and AdapterManagerMixin, - and to have: - self.model, self.state, self.device_mesh, self.base_model, self.replica_id + self_fn is a zero-argument callable that returns the current ModelManagement + replica instance. It is wired in via Depends so it is resolved lazily at request time. """ - @staticmethod - def _register_tinker_routes(app: FastAPI, model_id: str): - """Register all tinker routes on the given FastAPI app. - - This is called once during build_model_app to wire routes. 
- """ - - @app.post('/tinker/create_model') - async def create_model(self, request: Request, body: types.CreateModelRequest) -> types.UntypedAPIFuture: - token = await self._on_request_start(request) - - async def _create_adapter(): - _model_id = None - try: - _model_id = self.state.register_model(body.model_dump(), token=token, replica_id=self.replica_id) - if body.lora_config: - lora_cfg = LoraConfig(r=body.lora_config.rank, target_modules='all-linear') - adapter_name = self.get_adapter_name(adapter_name=_model_id) - self.register_adapter(adapter_name, token, session_id=body.session_id) - self.model.add_adapter_to_model(adapter_name=adapter_name, config_or_dir=lora_cfg) - self.model.set_template('Template', adapter_name=adapter_name, model_id=model_id) - self.model.set_processor('InputProcessor', adapter_name=adapter_name) - self.model.set_optimizer('Adam', adapter_name=adapter_name) - self.set_adapter_state(adapter_name, 'grad_ready', False) - training_run_manager = create_training_run_manager(token, client_type='tinker') - training_run_manager.save(_model_id, body) - return types.CreateModelResponse(model_id=_model_id) - except Exception: - if _model_id: - adapter_name = self.get_adapter_name(adapter_name=_model_id) - self._cleanup_adapter(adapter_name) - logger.error(traceback.format_exc()) - return types.RequestFailedResponse( - error=traceback.format_exc(), - category=types.RequestErrorCategory.Server, - ) - - return await self.schedule_task(_create_adapter, token=token, task_type='create_model') - - @app.post('/tinker/get_info') - async def get_info(self, request: Request, body: types.GetInfoRequest) -> types.GetInfoResponse: - token = await self._on_request_start(request) - training_run_manager = create_training_run_manager(token, client_type='tinker') - metadata = training_run_manager.get(str(body.model_id)) - model_name = metadata.base_model if metadata else model_id - lora_rank = None - is_lora = False - if metadata and hasattr(metadata, 'lora_rank') 
and metadata.lora_rank: - lora_rank = metadata.lora_rank - is_lora = metadata.is_lora - return types.GetInfoResponse( - model_data=types.ModelData(model_name=model_name), - model_id=body.model_id, - is_lora=is_lora, - lora_rank=lora_rank, - model_name=model_name, - ) - - @app.post('/tinker/unload_model') - async def unload_model(self, request: Request, body: types.UnloadModelRequest) -> types.UntypedAPIFuture: - token = await self._on_request_start(request) - - async def _do_unload(): - adapter_name = self.get_adapter_name(adapter_name=body.model_id) - self._cleanup_adapter(adapter_name) - return types.UnloadModelResponse(model_id=body.model_id) - - return await self.schedule_task(_do_unload, model_id=body.model_id, token=token, task_type='unload_model') - - @app.post('/tinker/forward') - async def forward(self, request: Request, body: types.ForwardRequest) -> types.UntypedAPIFuture: - token = await self._on_request_start(request) - - async def _do_forward(): - try: - adapter_name = self.get_adapter_name(adapter_name=body.model_id) - self.assert_adapter_exists(adapter_name=adapter_name) - self.touch_adapter(adapter_name) - datum_list = body.forward_input.data - loss_fn_config = body.forward_input.loss_fn_config or {} - output = self.model.forward_only(inputs=datum_list, adapter_name=adapter_name) - loss = self.model.calculate_loss(adapter_name=adapter_name, **loss_fn_config) - return types.ForwardBackwardOutput( - loss_fn_output_type='CrossEntropyLossReturn', - loss_fn_outputs=output, - metrics={'loss:sum': loss}, - ) - except Exception: - logger.error(traceback.format_exc()) - return types.RequestFailedResponse( - error=traceback.format_exc(), - category=types.RequestErrorCategory.Server, - ) - - datum_list = body.forward_input.data - input_tokens = sum(len(d.model_input.to_ints()) for d in datum_list) - batch_size = len(datum_list) - return await self.schedule_task( - _do_forward, - model_id=body.model_id, - token=token, - input_tokens=input_tokens, - 
batch_size=batch_size, - data_world_size=self.device_mesh.data_world_size, - task_type='forward', - ) - - @app.post('/tinker/forward_backward') - async def forward_backward(self, request: Request, - body: types.ForwardBackwardRequest) -> types.UntypedAPIFuture: - token = await self._on_request_start(request) - - async def _do_forward_backward(): - try: - adapter_name = self.get_adapter_name(adapter_name=body.model_id) - self.assert_adapter_exists(adapter_name=adapter_name) - self.touch_adapter(adapter_name) - datum_list = body.forward_backward_input.data - loss_fn = body.forward_backward_input.loss_fn - loss_fn_config = body.forward_backward_input.loss_fn_config or {} - output, loss = self.model.forward_backward( - inputs=datum_list, adapter_name=adapter_name, loss_fn=loss_fn, **loss_fn_config) - if loss_fn == 'importance_sampling': - output_type = 'ImportanceSamplingLossReturn' - else: - output_type = 'CrossEntropyLossReturn' - self.set_adapter_state(adapter_name, 'grad_ready', True) - return types.ForwardBackwardOutput( - loss_fn_output_type=output_type, - loss_fn_outputs=output, - metrics={'loss:avg': loss}, - ) - except Exception: - logger.error(traceback.format_exc()) - return types.RequestFailedResponse( - error=traceback.format_exc(), - category=types.RequestErrorCategory.Server, - ) - - datum_list = body.forward_backward_input.data - input_tokens = sum(len(d.model_input.to_ints()) for d in datum_list) - batch_size = len(datum_list) - return await self.schedule_task( - _do_forward_backward, - model_id=body.model_id, - token=token, - input_tokens=input_tokens, - batch_size=batch_size, - data_world_size=self.device_mesh.data_world_size, - task_type='forward_backward', - ) - - @app.post('/tinker/optim_step') - async def optim_step(self, request: Request, body: types.OptimStepRequest) -> types.UntypedAPIFuture: - token = await self._on_request_start(request) - - async def _do_optim(): - try: - adapter_name = self.get_adapter_name(adapter_name=body.model_id) - 
self.assert_adapter_exists(adapter_name=adapter_name) - if not self.get_adapter_state(adapter_name, 'grad_ready', False): - raise RuntimeError( - f'No accumulated gradients for adapter={adapter_name}; call forward_backward before optim_step' # noqa: E501 - ) - self.touch_adapter(adapter_name) - self.model.step(adam_params=body.adam_params, adapter_name=adapter_name) - self.set_adapter_state(adapter_name, 'grad_ready', False) - metrics = self.model.calculate_metric(is_training=True, adapter_name=adapter_name) - return types.OptimStepResponse(metrics=metrics) - except Exception: - logger.error(traceback.format_exc()) - return types.RequestFailedResponse( - error=traceback.format_exc(), - category=types.RequestErrorCategory.Server, - ) - - return await self.schedule_task(_do_optim, model_id=body.model_id, token=token, task_type='optim_step') - - @app.post('/tinker/save_weights') - async def save_weights(self, request: Request, body: types.SaveWeightsRequest) -> types.UntypedAPIFuture: - token = await self._on_request_start(request) - - async def _do_save(): - try: - adapter_name = self.get_adapter_name(adapter_name=body.model_id) - self.assert_adapter_exists(adapter_name=adapter_name) - self.touch_adapter(adapter_name) - checkpoint_manager = create_checkpoint_manager(token, client_type='tinker') - checkpoint_name = checkpoint_manager.get_ckpt_name(body.path) - save_dir = checkpoint_manager.get_save_dir(model_id=body.model_id, is_sampler=False) - self.model.save( - name=checkpoint_name, output_dir=save_dir, adapter_name=adapter_name, save_optimizer=True) - tinker_path = checkpoint_manager.save(body.model_id, name=checkpoint_name, is_sampler=False) - return types.SaveWeightsResponse(path=tinker_path, type='save_weights') - except Exception: - logger.error(traceback.format_exc()) - return types.RequestFailedResponse( - error=traceback.format_exc(), - category=types.RequestErrorCategory.Server, - ) - - return await self.schedule_task(_do_save, model_id=body.model_id, 
token=token, task_type='save_weights') - - @app.post('/tinker/save_weights_for_sampler') - async def save_weights_for_sampler(self, request: Request, - body: types.SaveWeightsForSamplerRequest) -> types.UntypedAPIFuture: - token = await self._on_request_start(request) - - async def _do_save_for_sampler(): - try: - adapter_name = self.get_adapter_name(adapter_name=body.model_id) - self.assert_adapter_exists(adapter_name=adapter_name) - self.touch_adapter(adapter_name) - checkpoint_manager = create_checkpoint_manager(token, client_type='tinker') - checkpoint_name = checkpoint_manager.get_ckpt_name(body.path) - save_dir = checkpoint_manager.get_save_dir(model_id=body.model_id, is_sampler=True) - tinker_path = checkpoint_manager.save(body.model_id, name=checkpoint_name, is_sampler=True) - logger.info(f'Saving weights to {save_dir}') - self.model.save( - name=checkpoint_name, output_dir=save_dir, adapter_name=adapter_name, save_optimizer=False) - payload = body.model_dump() - payload['model_path'] = tinker_path - metadata = self.state.get_model_metadata(body.model_id) or {} - if metadata.get('base_model'): - payload['base_model'] = metadata['base_model'] - sampling_session_id = self.state.create_sampling_session(payload) - return types.SaveWeightsForSamplerResponseInternal( - path=None, sampling_session_id=sampling_session_id) - except Exception: - logger.error(traceback.format_exc()) - return types.RequestFailedResponse( - error=traceback.format_exc(), - category=types.RequestErrorCategory.Server, - ) - - return await self.schedule_task( - _do_save_for_sampler, model_id=body.model_id, token=token, task_type='save_weights_for_sampler') - - @app.post('/tinker/load_weights') - async def load_weights(self, request: Request, body: types.LoadWeightsRequest) -> types.UntypedAPIFuture: - token = await self._on_request_start(request) - - async def _do_load(): - try: - assert self.model is not None, 'Model not loaded, please load model first' - adapter_name = 
self.get_adapter_name(adapter_name=body.model_id) - self.assert_adapter_exists(adapter_name=adapter_name) - self.touch_adapter(adapter_name) - self.model.load( - checkpoint_dir=body.path, load_optimizer=body.optimizer, adapter_name=adapter_name, token=token) + @app.post('/tinker/create_model') + async def create_model( + request: Request, + body: types.CreateModelRequest, + self: ModelManagement = Depends(self_fn), + ) -> types.UntypedAPIFuture: + token = await self._on_request_start(request) + + async def _create_adapter(): + _model_id = None + try: + _model_id = self.state.register_model(body.model_dump(), token=token, replica_id=self.replica_id) + if body.lora_config: + lora_cfg = LoraConfig(r=body.lora_config.rank, target_modules='all-linear') + adapter_name = self.get_adapter_name(adapter_name=_model_id) + self.register_adapter(adapter_name, token, session_id=body.session_id) + self.model.add_adapter_to_model(adapter_name=adapter_name, config_or_dir=lora_cfg) + self.model.set_template('Template', adapter_name=adapter_name, model_id=self.base_model) + self.model.set_processor('InputProcessor', adapter_name=adapter_name) + self.model.set_optimizer('Adam', adapter_name=adapter_name) self.set_adapter_state(adapter_name, 'grad_ready', False) - return types.LoadWeightsResponse(path=body.path, type='load_weights') - except Exception: - logger.error(traceback.format_exc()) - return types.RequestFailedResponse( - error=traceback.format_exc(), - category=types.RequestErrorCategory.Server, - ) - - return await self.schedule_task(_do_load, model_id=body.model_id, token=token, task_type='load_weights') - - # Tinker uses {request_id}-{adapter_name} prefix via self.get_adapter_name() - # which is inherited from AdapterManagerMixin (no-op here; method kept for clarity). 
- @staticmethod - def get_adapter_name(adapter_name: Any) -> Any: - """Returns adapter_name as-is; overridden by AdapterManagerMixin in the combined class.""" - return adapter_name + training_run_manager = create_training_run_manager(token, client_type='tinker') + training_run_manager.save(_model_id, body) + return types.CreateModelResponse(model_id=_model_id) + except Exception: + if _model_id: + adapter_name = self.get_adapter_name(adapter_name=_model_id) + self._cleanup_adapter(adapter_name) + logger.error(traceback.format_exc()) + return types.RequestFailedResponse( + error=traceback.format_exc(), + category=types.RequestErrorCategory.Server, + ) + + return await self.schedule_task(_create_adapter, token=token, task_type='create_model') + + @app.post('/tinker/get_info') + async def get_info(request: Request, body: types.GetInfoRequest, + self: ModelManagement = Depends(self_fn)) -> types.GetInfoResponse: + token = await self._on_request_start(request) + training_run_manager = create_training_run_manager(token, client_type='tinker') + metadata = training_run_manager.get(str(body.model_id)) + model_name = metadata.base_model if metadata else self.base_model + lora_rank = None + is_lora = False + if metadata and hasattr(metadata, 'lora_rank') and metadata.lora_rank: + lora_rank = metadata.lora_rank + is_lora = metadata.is_lora + return types.GetInfoResponse( + model_data=types.ModelData(model_name=model_name), + model_id=body.model_id, + is_lora=is_lora, + lora_rank=lora_rank, + model_name=model_name, + ) + + @app.post('/tinker/unload_model') + async def unload_model( + request: Request, + body: types.UnloadModelRequest, + self: ModelManagement = Depends(self_fn), + ) -> types.UntypedAPIFuture: + token = await self._on_request_start(request) + + async def _do_unload(): + adapter_name = self.get_adapter_name(adapter_name=body.model_id) + self._cleanup_adapter(adapter_name) + return types.UnloadModelResponse(model_id=body.model_id) + + return await 
self.schedule_task(_do_unload, model_id=body.model_id, token=token, task_type='unload_model') + + @app.post('/tinker/forward') + async def forward(request: Request, body: types.ForwardRequest, + self: ModelManagement = Depends(self_fn)) -> types.UntypedAPIFuture: + token = await self._on_request_start(request) + + async def _do_forward(): + try: + adapter_name = self.get_adapter_name(adapter_name=body.model_id) + self.assert_adapter_exists(adapter_name=adapter_name) + self.touch_adapter(adapter_name) + datum_list = body.forward_input.data + loss_fn_config = body.forward_input.loss_fn_config or {} + output = self.model.forward_only(inputs=datum_list, adapter_name=adapter_name) + loss = self.model.calculate_loss(adapter_name=adapter_name, **loss_fn_config) + return types.ForwardBackwardOutput( + loss_fn_output_type='CrossEntropyLossReturn', + loss_fn_outputs=output, + metrics={'loss:sum': loss}, + ) + except Exception: + logger.error(traceback.format_exc()) + return types.RequestFailedResponse( + error=traceback.format_exc(), + category=types.RequestErrorCategory.Server, + ) + + datum_list = body.forward_input.data + input_tokens = sum(len(d.model_input.to_ints()) for d in datum_list) + batch_size = len(datum_list) + return await self.schedule_task( + _do_forward, + model_id=body.model_id, + token=token, + input_tokens=input_tokens, + batch_size=batch_size, + data_world_size=self.device_mesh.data_world_size, + task_type='forward', + ) + + @app.post('/tinker/forward_backward') + async def forward_backward( + request: Request, + body: types.ForwardBackwardRequest, + self: ModelManagement = Depends(self_fn), + ) -> types.UntypedAPIFuture: + token = await self._on_request_start(request) + + async def _do_forward_backward(): + try: + adapter_name = self.get_adapter_name(adapter_name=body.model_id) + self.assert_adapter_exists(adapter_name=adapter_name) + self.touch_adapter(adapter_name) + datum_list = body.forward_backward_input.data + loss_fn = 
body.forward_backward_input.loss_fn + loss_fn_config = body.forward_backward_input.loss_fn_config or {} + output, loss = self.model.forward_backward( + inputs=datum_list, adapter_name=adapter_name, loss_fn=loss_fn, **loss_fn_config) + output_type = ('ImportanceSamplingLossReturn' + if loss_fn == 'importance_sampling' else 'CrossEntropyLossReturn') + self.set_adapter_state(adapter_name, 'grad_ready', True) + return types.ForwardBackwardOutput( + loss_fn_output_type=output_type, + loss_fn_outputs=output, + metrics={'loss:avg': loss}, + ) + except Exception: + logger.error(traceback.format_exc()) + return types.RequestFailedResponse( + error=traceback.format_exc(), + category=types.RequestErrorCategory.Server, + ) + + datum_list = body.forward_backward_input.data + input_tokens = sum(len(d.model_input.to_ints()) for d in datum_list) + batch_size = len(datum_list) + return await self.schedule_task( + _do_forward_backward, + model_id=body.model_id, + token=token, + input_tokens=input_tokens, + batch_size=batch_size, + data_world_size=self.device_mesh.data_world_size, + task_type='forward_backward', + ) + + @app.post('/tinker/optim_step') + async def optim_step( + request: Request, + body: types.OptimStepRequest, + self: ModelManagement = Depends(self_fn), + ) -> types.UntypedAPIFuture: + token = await self._on_request_start(request) + + async def _do_optim(): + try: + adapter_name = self.get_adapter_name(adapter_name=body.model_id) + self.assert_adapter_exists(adapter_name=adapter_name) + if not self.get_adapter_state(adapter_name, 'grad_ready', False): + raise RuntimeError(f'No accumulated gradients for adapter={adapter_name}; ' + 'call forward_backward before optim_step') + self.touch_adapter(adapter_name) + self.model.step(adam_params=body.adam_params, adapter_name=adapter_name) + self.set_adapter_state(adapter_name, 'grad_ready', False) + metrics = self.model.calculate_metric(is_training=True, adapter_name=adapter_name) + return 
types.OptimStepResponse(metrics=metrics) + except Exception: + logger.error(traceback.format_exc()) + return types.RequestFailedResponse( + error=traceback.format_exc(), + category=types.RequestErrorCategory.Server, + ) + + return await self.schedule_task(_do_optim, model_id=body.model_id, token=token, task_type='optim_step') + + @app.post('/tinker/save_weights') + async def save_weights( + request: Request, + body: types.SaveWeightsRequest, + self: ModelManagement = Depends(self_fn), + ) -> types.UntypedAPIFuture: + token = await self._on_request_start(request) + + async def _do_save(): + try: + adapter_name = self.get_adapter_name(adapter_name=body.model_id) + self.assert_adapter_exists(adapter_name=adapter_name) + self.touch_adapter(adapter_name) + checkpoint_manager = create_checkpoint_manager(token, client_type='tinker') + checkpoint_name = checkpoint_manager.get_ckpt_name(body.path) + save_dir = checkpoint_manager.get_save_dir(model_id=body.model_id, is_sampler=False) + self.model.save( + name=checkpoint_name, output_dir=save_dir, adapter_name=adapter_name, save_optimizer=True) + tinker_path = checkpoint_manager.save(body.model_id, name=checkpoint_name, is_sampler=False) + return types.SaveWeightsResponse(path=tinker_path, type='save_weights') + except Exception: + logger.error(traceback.format_exc()) + return types.RequestFailedResponse( + error=traceback.format_exc(), + category=types.RequestErrorCategory.Server, + ) + + return await self.schedule_task(_do_save, model_id=body.model_id, token=token, task_type='save_weights') + + @app.post('/tinker/save_weights_for_sampler') + async def save_weights_for_sampler( + request: Request, + body: types.SaveWeightsForSamplerRequest, + self: ModelManagement = Depends(self_fn), + ) -> types.UntypedAPIFuture: + token = await self._on_request_start(request) + + async def _do_save_for_sampler(): + try: + adapter_name = self.get_adapter_name(adapter_name=body.model_id) + 
self.assert_adapter_exists(adapter_name=adapter_name) + self.touch_adapter(adapter_name) + checkpoint_manager = create_checkpoint_manager(token, client_type='tinker') + checkpoint_name = checkpoint_manager.get_ckpt_name(body.path) + save_dir = checkpoint_manager.get_save_dir(model_id=body.model_id, is_sampler=True) + tinker_path = checkpoint_manager.save(body.model_id, name=checkpoint_name, is_sampler=True) + logger.info(f'Saving weights to {save_dir}') + self.model.save( + name=checkpoint_name, output_dir=save_dir, adapter_name=adapter_name, save_optimizer=False) + payload = body.model_dump() + payload['model_path'] = tinker_path + metadata = self.state.get_model_metadata(body.model_id) or {} + if metadata.get('base_model'): + payload['base_model'] = metadata['base_model'] + sampling_session_id = self.state.create_sampling_session(payload) + return types.SaveWeightsForSamplerResponseInternal(path=None, sampling_session_id=sampling_session_id) + except Exception: + logger.error(traceback.format_exc()) + return types.RequestFailedResponse( + error=traceback.format_exc(), + category=types.RequestErrorCategory.Server, + ) + + return await self.schedule_task( + _do_save_for_sampler, model_id=body.model_id, token=token, task_type='save_weights_for_sampler') + + @app.post('/tinker/load_weights') + async def load_weights( + request: Request, + body: types.LoadWeightsRequest, + self: ModelManagement = Depends(self_fn), + ) -> types.UntypedAPIFuture: + token = await self._on_request_start(request) + + async def _do_load(): + try: + assert self.model is not None, 'Model not loaded, please load model first' + adapter_name = self.get_adapter_name(adapter_name=body.model_id) + self.assert_adapter_exists(adapter_name=adapter_name) + self.touch_adapter(adapter_name) + self.model.load( + checkpoint_dir=body.path, load_optimizer=body.optimizer, adapter_name=adapter_name, token=token) + self.set_adapter_state(adapter_name, 'grad_ready', False) + return 
types.LoadWeightsResponse(path=body.path, type='load_weights') + except Exception: + logger.error(traceback.format_exc()) + return types.RequestFailedResponse( + error=traceback.format_exc(), + category=types.RequestErrorCategory.Server, + ) + + return await self.schedule_task(_do_load, model_id=body.model_id, token=token, task_type='load_weights') diff --git a/src/twinkle/server/model/twinkle_handlers.py b/src/twinkle/server/model/twinkle_handlers.py index 66035051..c59928f4 100644 --- a/src/twinkle/server/model/twinkle_handlers.py +++ b/src/twinkle/server/model/twinkle_handlers.py @@ -4,11 +4,17 @@ All endpoints are prefixed /twinkle/... and use schedule_task_and_wait() returning results directly (synchronous from the client's perspective). +self_fn is injected via FastAPI Depends to obtain the ModelManagement instance at request time. """ +from __future__ import annotations + import traceback -from fastapi import FastAPI, Request +from fastapi import Depends, FastAPI, Request from peft import LoraConfig -from typing import Any, Optional +from typing import TYPE_CHECKING, Any, Callable, Optional + +if TYPE_CHECKING: + from .app import ModelManagement from twinkle.data_format import InputFeature, Trajectory from twinkle.server.common.io_utils import create_checkpoint_manager, create_training_run_manager @@ -39,335 +45,330 @@ def _parse_inputs(inputs: Any): return inputs -class TwinkleModelHandlers: - """ - Mixin providing Twinkle-native model management endpoints. +def _get_twinkle_adapter_name(request: Request, adapter_name: str | None) -> str | None: + """Build the per-request adapter name from the request_id prefix.""" + if adapter_name is None or adapter_name == '': + return None + return request.state.request_id + '-' + adapter_name + - Expects the combined class to also inherit TaskQueueMixin and AdapterManagerMixin, - and to have: self.model, self.state, self.base_model - The get_adapter_name static method uses request.state.request_id prefix. 
+def _register_twinkle_routes(app: FastAPI, self_fn: Callable[[], ModelManagement]) -> None: + """Register all /twinkle/* routes on the given FastAPI app. + + self_fn is a zero-argument callable that returns the current ModelManagement + replica instance. It is wired in via Depends so it is resolved lazily at request time. """ - @staticmethod - def _register_twinkle_routes(app: FastAPI, model_id: str): - """Register all twinkle routes on the given FastAPI app.""" - - @app.post('/twinkle/create') - async def create(self, request: Request, body: CreateRequest): - return {'status': 'ok'} - - @staticmethod - def _get_twinkle_adapter_name(request: Request, adapter_name: Optional[str]) -> Optional[str]: - if adapter_name is None or adapter_name == '': - return None - return request.state.request_id + '-' + adapter_name + @app.post('/twinkle/create') + async def create(request: Request, body: CreateRequest, self: ModelManagement = Depends(self_fn)): + return {'status': 'ok'} + + @app.post('/twinkle/forward') + async def forward(request: Request, body: ForwardRequest, self: ModelManagement = Depends(self_fn)): + adapter_name = _get_twinkle_adapter_name(request, body.adapter_name) + + async def _task(): + self.assert_adapter_exists(adapter_name=adapter_name) + extra_kwargs = body.model_extra or {} + inputs = _parse_inputs(body.inputs) + ret = self.model.forward(inputs=inputs, adapter_name=adapter_name, **extra_kwargs) + return {'result': ret} - @app.post('/twinkle/forward') - async def forward(self, request: Request, body: ForwardRequest): - adapter_name = self._get_twinkle_adapter_name(request, body.adapter_name) + return await self.schedule_task_and_wait(_task, task_type='forward') - async def _task(): - self.assert_adapter_exists(adapter_name=adapter_name) - extra_kwargs = body.model_extra or {} - inputs = _parse_inputs(body.inputs) - ret = self.model.forward(inputs=inputs, adapter_name=adapter_name, **extra_kwargs) - return {'result': ret} + 
@app.post('/twinkle/forward_only') + async def forward_only(request: Request, body: ForwardOnlyRequest, self: ModelManagement = Depends(self_fn)): + adapter_name = _get_twinkle_adapter_name(request, body.adapter_name) - return await self.schedule_task_and_wait(_task, task_type='forward') - - @app.post('/twinkle/forward_only') - async def forward_only(self, request: Request, body: ForwardOnlyRequest): - adapter_name = self._get_twinkle_adapter_name(request, body.adapter_name) - - async def _task(): - self.assert_adapter_exists(adapter_name=adapter_name) - extra_kwargs = body.model_extra or {} - inputs = _parse_inputs(body.inputs) - ret = self.model.forward_only(inputs=inputs, adapter_name=adapter_name, **extra_kwargs) - return {'result': ret} + async def _task(): + self.assert_adapter_exists(adapter_name=adapter_name) + extra_kwargs = body.model_extra or {} + inputs = _parse_inputs(body.inputs) + ret = self.model.forward_only(inputs=inputs, adapter_name=adapter_name, **extra_kwargs) + return {'result': ret} - return await self.schedule_task_and_wait(_task, task_type='forward_only') - - @app.post('/twinkle/calculate_loss') - async def calculate_loss(self, request: Request, body: AdapterRequest): - adapter_name = self._get_twinkle_adapter_name(request, body.adapter_name) - - async def _task(): - self.assert_adapter_exists(adapter_name=adapter_name) - extra_kwargs = body.model_extra or {} - ret = self.model.calculate_loss(adapter_name=adapter_name, **extra_kwargs) - return {'result': ret} + return await self.schedule_task_and_wait(_task, task_type='forward_only') - return await self.schedule_task_and_wait(_task, task_type='calculate_loss') - - @app.post('/twinkle/backward') - async def backward(self, request: Request, body: AdapterRequest): - adapter_name = self._get_twinkle_adapter_name(request, body.adapter_name) - - async def _task(): - self.assert_adapter_exists(adapter_name=adapter_name) - extra_kwargs = body.model_extra or {} - ret = 
self.model.backward(adapter_name=adapter_name, **extra_kwargs) - return {'result': ret} + @app.post('/twinkle/calculate_loss') + async def calculate_loss(request: Request, body: AdapterRequest, self: ModelManagement = Depends(self_fn)): + adapter_name = _get_twinkle_adapter_name(request, body.adapter_name) - return await self.schedule_task_and_wait(_task, task_type='backward') - - @app.post('/twinkle/forward_backward') - async def forward_backward(self, request: Request, body: ForwardRequest): - adapter_name = self._get_twinkle_adapter_name(request, body.adapter_name) - - async def _task(): - self.assert_adapter_exists(adapter_name=adapter_name) - extra_kwargs = body.model_extra or {} - inputs = _parse_inputs(body.inputs) - ret = self.model.twinkle_forward_backward(inputs=inputs, adapter_name=adapter_name, **extra_kwargs) - return {'result': ret} - - return await self.schedule_task_and_wait(_task, task_type='forward_backward') - - @app.post('/twinkle/clip_grad_norm') - async def clip_grad_norm(self, request: Request, body: AdapterRequest): - adapter_name = self._get_twinkle_adapter_name(request, body.adapter_name) - - async def _task(): - self.assert_adapter_exists(adapter_name=adapter_name) - extra_kwargs = body.model_extra or {} - ret = self.model.clip_grad_norm(adapter_name=adapter_name, **extra_kwargs) - return {'result': str(ret)} - - return await self.schedule_task_and_wait(_task, task_type='clip_grad_norm') - - @app.post('/twinkle/step') - async def step(self, request: Request, body: AdapterRequest): - adapter_name = self._get_twinkle_adapter_name(request, body.adapter_name) - - async def _task(): - self.assert_adapter_exists(adapter_name=adapter_name) - extra_kwargs = body.model_extra or {} - ret = self.model.step(adapter_name=adapter_name, **extra_kwargs) - return {'result': ret} - - return await self.schedule_task_and_wait(_task, task_type='step') - - @app.post('/twinkle/zero_grad') - async def zero_grad(self, request: Request, body: AdapterRequest): - 
adapter_name = self._get_twinkle_adapter_name(request, body.adapter_name) - - async def _task(): - self.assert_adapter_exists(adapter_name=adapter_name) - extra_kwargs = body.model_extra or {} - ret = self.model.zero_grad(adapter_name=adapter_name, **extra_kwargs) - return {'result': ret} - - return await self.schedule_task_and_wait(_task, task_type='zero_grad') - - @app.post('/twinkle/lr_step') - async def lr_step(self, request: Request, body: AdapterRequest): - adapter_name = self._get_twinkle_adapter_name(request, body.adapter_name) - - async def _task(): - self.assert_adapter_exists(adapter_name=adapter_name) - extra_kwargs = body.model_extra or {} - ret = self.model.lr_step(adapter_name=adapter_name, **extra_kwargs) - return {'result': ret} + async def _task(): + self.assert_adapter_exists(adapter_name=adapter_name) + extra_kwargs = body.model_extra or {} + ret = self.model.calculate_loss(adapter_name=adapter_name, **extra_kwargs) + return {'result': ret} - return await self.schedule_task_and_wait(_task, task_type='lr_step') + return await self.schedule_task_and_wait(_task, task_type='calculate_loss') - @app.post('/twinkle/get_train_configs') - async def get_train_configs(self, request: Request, body: AdapterRequest): - adapter_name = self._get_twinkle_adapter_name(request, body.adapter_name) - - async def _task(): - self.assert_adapter_exists(adapter_name=adapter_name) - extra_kwargs = body.model_extra or {} - ret = self.model.get_train_configs(adapter_name=adapter_name, **extra_kwargs) - return {'result': ret} + @app.post('/twinkle/backward') + async def backward(request: Request, body: AdapterRequest, self: ModelManagement = Depends(self_fn)): + adapter_name = _get_twinkle_adapter_name(request, body.adapter_name) - return await self.schedule_task_and_wait(_task, task_type='get_train_configs') + async def _task(): + self.assert_adapter_exists(adapter_name=adapter_name) + extra_kwargs = body.model_extra or {} + ret = 
self.model.backward(adapter_name=adapter_name, **extra_kwargs) + return {'result': ret} - @app.post('/twinkle/set_loss') - async def set_loss(self, request: Request, body: SetLossRequest): - adapter_name = self._get_twinkle_adapter_name(request, body.adapter_name) - - async def _task(): - self.assert_adapter_exists(adapter_name=adapter_name) - extra_kwargs = body.model_extra or {} - ret = self.model.set_loss(body.loss_cls, adapter_name=adapter_name, **extra_kwargs) - return {'result': ret} + return await self.schedule_task_and_wait(_task, task_type='backward') - return await self.schedule_task_and_wait(_task, task_type='set_loss') + @app.post('/twinkle/forward_backward') + async def forward_backward(request: Request, body: ForwardRequest, self: ModelManagement = Depends(self_fn)): + adapter_name = _get_twinkle_adapter_name(request, body.adapter_name) - @app.post('/twinkle/set_optimizer') - async def set_optimizer(self, request: Request, body: SetOptimizerRequest): - adapter_name = self._get_twinkle_adapter_name(request, body.adapter_name) - - async def _task(): - self.assert_adapter_exists(adapter_name=adapter_name) - extra_kwargs = body.model_extra or {} - ret = self.model.set_optimizer(body.optimizer_cls, adapter_name=adapter_name, **extra_kwargs) - return {'result': ret} + async def _task(): + self.assert_adapter_exists(adapter_name=adapter_name) + extra_kwargs = body.model_extra or {} + inputs = _parse_inputs(body.inputs) + ret = self.model.twinkle_forward_backward(inputs=inputs, adapter_name=adapter_name, **extra_kwargs) + return {'result': ret} - return await self.schedule_task_and_wait(_task, task_type='set_optimizer') + return await self.schedule_task_and_wait(_task, task_type='forward_backward') - @app.post('/twinkle/set_lr_scheduler') - async def set_lr_scheduler(self, request: Request, body: SetLrSchedulerRequest): - adapter_name = self._get_twinkle_adapter_name(request, body.adapter_name) - - async def _task(): - 
self.assert_adapter_exists(adapter_name=adapter_name) - extra_kwargs = body.model_extra or {} - ret = self.model.set_lr_scheduler(body.scheduler_cls, adapter_name=adapter_name, **extra_kwargs) - return {'result': ret} + @app.post('/twinkle/clip_grad_norm') + async def clip_grad_norm(request: Request, body: AdapterRequest, self: ModelManagement = Depends(self_fn)): + adapter_name = _get_twinkle_adapter_name(request, body.adapter_name) - return await self.schedule_task_and_wait(_task, task_type='set_lr_scheduler') + async def _task(): + self.assert_adapter_exists(adapter_name=adapter_name) + extra_kwargs = body.model_extra or {} + ret = self.model.clip_grad_norm(adapter_name=adapter_name, **extra_kwargs) + return {'result': str(ret)} - @app.post('/twinkle/save') - async def save(self, request: Request, body: SaveRequest): - token = await self._on_request_start(request) - adapter_name = self._get_twinkle_adapter_name(request, body.adapter_name) - - async def _task(): - self.assert_adapter_exists(adapter_name=adapter_name) - extra_kwargs = body.model_extra or {} - checkpoint_manager = create_checkpoint_manager(token, client_type='twinkle') - checkpoint_name = checkpoint_manager.get_ckpt_name(body.name) - save_dir = checkpoint_manager.get_save_dir(model_id=adapter_name, is_sampler=False) - checkpoint_dir = self.model.save( - name=checkpoint_name, - output_dir=save_dir, - adapter_name=adapter_name, - save_optimizer=body.save_optimizer, - **extra_kwargs) - twinkle_path = checkpoint_manager.save(model_id=adapter_name, name=checkpoint_name, is_sampler=False) - return {'result': twinkle_path, 'checkpoint_dir': checkpoint_dir} - - return await self.schedule_task_and_wait(_task, task_type='save') - - @app.post('/twinkle/load') - async def load(self, request: Request, body: LoadRequest): - token = await self._on_request_start(request) - adapter_name = self._get_twinkle_adapter_name(request, body.adapter_name) - - async def _task(): - 
self.assert_adapter_exists(adapter_name=adapter_name) - extra_kwargs = body.model_extra or {} + return await self.schedule_task_and_wait(_task, task_type='clip_grad_norm') + + @app.post('/twinkle/step') + async def step(request: Request, body: AdapterRequest, self: ModelManagement = Depends(self_fn)): + adapter_name = _get_twinkle_adapter_name(request, body.adapter_name) + + async def _task(): + self.assert_adapter_exists(adapter_name=adapter_name) + extra_kwargs = body.model_extra or {} + ret = self.model.step(adapter_name=adapter_name, **extra_kwargs) + return {'result': ret} + + return await self.schedule_task_and_wait(_task, task_type='step') + + @app.post('/twinkle/zero_grad') + async def zero_grad(request: Request, body: AdapterRequest, self: ModelManagement = Depends(self_fn)): + adapter_name = _get_twinkle_adapter_name(request, body.adapter_name) + + async def _task(): + self.assert_adapter_exists(adapter_name=adapter_name) + extra_kwargs = body.model_extra or {} + ret = self.model.zero_grad(adapter_name=adapter_name, **extra_kwargs) + return {'result': ret} + + return await self.schedule_task_and_wait(_task, task_type='zero_grad') + + @app.post('/twinkle/lr_step') + async def lr_step(request: Request, body: AdapterRequest, self: ModelManagement = Depends(self_fn)): + adapter_name = _get_twinkle_adapter_name(request, body.adapter_name) + + async def _task(): + self.assert_adapter_exists(adapter_name=adapter_name) + extra_kwargs = body.model_extra or {} + ret = self.model.lr_step(adapter_name=adapter_name, **extra_kwargs) + return {'result': ret} + + return await self.schedule_task_and_wait(_task, task_type='lr_step') + + @app.post('/twinkle/get_train_configs') + async def get_train_configs(request: Request, body: AdapterRequest, self: ModelManagement = Depends(self_fn)): + adapter_name = _get_twinkle_adapter_name(request, body.adapter_name) + + async def _task(): + self.assert_adapter_exists(adapter_name=adapter_name) + extra_kwargs = body.model_extra or {} 
+ ret = self.model.get_train_configs(adapter_name=adapter_name, **extra_kwargs) + return {'result': ret} + + return await self.schedule_task_and_wait(_task, task_type='get_train_configs') + + @app.post('/twinkle/set_loss') + async def set_loss(request: Request, body: SetLossRequest, self: ModelManagement = Depends(self_fn)): + adapter_name = _get_twinkle_adapter_name(request, body.adapter_name) + + async def _task(): + self.assert_adapter_exists(adapter_name=adapter_name) + extra_kwargs = body.model_extra or {} + ret = self.model.set_loss(body.loss_cls, adapter_name=adapter_name, **extra_kwargs) + return {'result': ret} + + return await self.schedule_task_and_wait(_task, task_type='set_loss') + + @app.post('/twinkle/set_optimizer') + async def set_optimizer(request: Request, body: SetOptimizerRequest, self: ModelManagement = Depends(self_fn)): + adapter_name = _get_twinkle_adapter_name(request, body.adapter_name) + + async def _task(): + self.assert_adapter_exists(adapter_name=adapter_name) + extra_kwargs = body.model_extra or {} + ret = self.model.set_optimizer(body.optimizer_cls, adapter_name=adapter_name, **extra_kwargs) + return {'result': ret} + + return await self.schedule_task_and_wait(_task, task_type='set_optimizer') + + @app.post('/twinkle/set_lr_scheduler') + async def set_lr_scheduler(request: Request, body: SetLrSchedulerRequest, self: ModelManagement = Depends(self_fn)): + adapter_name = _get_twinkle_adapter_name(request, body.adapter_name) + + async def _task(): + self.assert_adapter_exists(adapter_name=adapter_name) + extra_kwargs = body.model_extra or {} + ret = self.model.set_lr_scheduler(body.scheduler_cls, adapter_name=adapter_name, **extra_kwargs) + return {'result': ret} + + return await self.schedule_task_and_wait(_task, task_type='set_lr_scheduler') + + @app.post('/twinkle/save') + async def save(request: Request, body: SaveRequest, self: ModelManagement = Depends(self_fn)): + token = await self._on_request_start(request) + adapter_name = 
_get_twinkle_adapter_name(request, body.adapter_name) + + async def _task(): + self.assert_adapter_exists(adapter_name=adapter_name) + extra_kwargs = body.model_extra or {} + checkpoint_manager = create_checkpoint_manager(token, client_type='twinkle') + checkpoint_name = checkpoint_manager.get_ckpt_name(body.name) + save_dir = checkpoint_manager.get_save_dir(model_id=adapter_name, is_sampler=False) + checkpoint_dir = self.model.save( + name=checkpoint_name, + output_dir=save_dir, + adapter_name=adapter_name, + save_optimizer=body.save_optimizer, + **extra_kwargs) + twinkle_path = checkpoint_manager.save(model_id=adapter_name, name=checkpoint_name, is_sampler=False) + return {'result': twinkle_path, 'checkpoint_dir': checkpoint_dir} + + return await self.schedule_task_and_wait(_task, task_type='save') + + @app.post('/twinkle/load') + async def load(request: Request, body: LoadRequest, self: ModelManagement = Depends(self_fn)): + token = await self._on_request_start(request) + adapter_name = _get_twinkle_adapter_name(request, body.adapter_name) + + async def _task(): + self.assert_adapter_exists(adapter_name=adapter_name) + extra_kwargs = body.model_extra or {} + checkpoint_manager = create_checkpoint_manager(token, client_type='twinkle') + resolved = checkpoint_manager.resolve_load_path(body.name) + ret = self.model.load( + name=resolved.checkpoint_name, + output_dir=resolved.checkpoint_dir, + adapter_name=adapter_name, + load_optimizer=body.load_optimizer, + token=token, + **extra_kwargs) + return {'result': ret} + + return await self.schedule_task_and_wait(_task, task_type='load') + + @app.post('/twinkle/upload_to_hub') + async def upload_to_hub(request: Request, body: UploadToHubRequest, self: ModelManagement = Depends(self_fn)): + token = await self._on_request_start(request) + + async def _task(): + if body.checkpoint_dir.startswith('twinkle://'): checkpoint_manager = create_checkpoint_manager(token, client_type='twinkle') - resolved = 
checkpoint_manager.resolve_load_path(body.name) - ret = self.model.load( - name=resolved.checkpoint_name, - output_dir=resolved.checkpoint_dir, - adapter_name=adapter_name, - load_optimizer=body.load_optimizer, - token=token, - **extra_kwargs) - return {'result': ret} - - return await self.schedule_task_and_wait(_task, task_type='load') - - @app.post('/twinkle/upload_to_hub') - async def upload_to_hub(self, request: Request, body: UploadToHubRequest): - token = await self._on_request_start(request) - - async def _task(): - if body.checkpoint_dir.startswith('twinkle://'): - checkpoint_manager = create_checkpoint_manager(token, client_type='twinkle') - parsed = checkpoint_manager.parse_twinkle_path(body.checkpoint_dir) - if not parsed: - raise ValueError(f'Invalid twinkle path format: {body.checkpoint_dir}') - checkpoint_id = parsed.checkpoint_id - model_id_to_load = parsed.training_run_id - checkpoint = checkpoint_manager.get(model_id_to_load, checkpoint_id) - if not checkpoint: - raise ValueError(f'Checkpoint not found or access denied: {body.checkpoint_dir}') - checkpoint_dir = str( - checkpoint_manager.get_ckpt_dir(model_id=model_id_to_load, checkpoint_id=checkpoint_id)) - else: - checkpoint_dir = body.checkpoint_dir - self.model.upload_to_hub( - checkpoint_dir=checkpoint_dir, - hub_model_id=body.hub_model_id, - hub_token=body.hub_token or token, - async_upload=body.async_upload) - return {'result': body.hub_model_id} - - return await self.schedule_task_and_wait(_task, task_type='upload_to_hub') - - @app.post('/twinkle/add_adapter_to_model') - async def add_adapter_to_model(self, request: Request, body: AddAdapterRequest): - assert body.adapter_name, 'You need to specify a valid `adapter_name`' - token = await self._on_request_start(request) - adapter_name = self._get_twinkle_adapter_name(request, body.adapter_name) - - async def _task(): - config = deserialize_object(body.config) - extra_kwargs = body.model_extra or {} - training_run_manager = 
create_training_run_manager(token, client_type='twinkle') - self.register_adapter(adapter_name, token) - self.model.add_adapter_to_model(adapter_name, config, **extra_kwargs) - from twinkle.server.common.io_utils import CreateModelRequest - from twinkle.server.common.io_utils import LoraConfig as IoLoraConfig - lora_config = None - if isinstance(config, LoraConfig): - lora_config = IoLoraConfig(rank=config.r, train_unembed=False, train_mlp=True, train_attn=True) - run_config = CreateModelRequest( - base_model=model_id, lora_config=lora_config, user_metadata={'adapter_name': body.adapter_name}) - training_run_manager.save(adapter_name, run_config) - return {'status': 'ok', 'adapter_name': adapter_name} - - return await self.schedule_task_and_wait(_task, task_type='add_adapter_to_model') - - @app.post('/twinkle/set_template') - async def set_template(self, request: Request, body: SetTemplateRequest): - adapter_name = self._get_twinkle_adapter_name(request, body.adapter_name) - - async def _task(): - self.assert_adapter_exists(adapter_name=adapter_name) - extra_kwargs = body.model_extra or {} - ret = self.model.set_template(body.template_cls, adapter_name=adapter_name, **extra_kwargs) - return {'result': ret} - - return await self.schedule_task_and_wait(_task, task_type='set_template') - - @app.post('/twinkle/set_processor') - async def set_processor(self, request: Request, body: SetProcessorRequest): - adapter_name = self._get_twinkle_adapter_name(request, body.adapter_name) - - async def _task(): - self.assert_adapter_exists(adapter_name=adapter_name) - extra_kwargs = body.model_extra or {} - ret = self.model.set_processor(body.processor_cls, adapter_name=adapter_name, **extra_kwargs) - return {'result': ret} - - return await self.schedule_task_and_wait(_task, task_type='set_processor') - - @app.post('/twinkle/heartbeat') - async def heartbeat(self, request: Request, body: HeartbeatRequest): - adapter_name = self._get_twinkle_adapter_name(request, body.adapter_name) 
+ parsed = checkpoint_manager.parse_twinkle_path(body.checkpoint_dir) + if not parsed: + raise ValueError(f'Invalid twinkle path format: {body.checkpoint_dir}') + checkpoint_id = parsed.checkpoint_id + model_id_to_load = parsed.training_run_id + checkpoint = checkpoint_manager.get(model_id_to_load, checkpoint_id) + if not checkpoint: + raise ValueError(f'Checkpoint not found or access denied: {body.checkpoint_dir}') + checkpoint_dir = str( + checkpoint_manager.get_ckpt_dir(model_id=model_id_to_load, checkpoint_id=checkpoint_id)) + else: + checkpoint_dir = body.checkpoint_dir + self.model.upload_to_hub( + checkpoint_dir=checkpoint_dir, + hub_model_id=body.hub_model_id, + hub_token=body.hub_token or token, + async_upload=body.async_upload) + return {'result': body.hub_model_id} + + return await self.schedule_task_and_wait(_task, task_type='upload_to_hub') + + @app.post('/twinkle/add_adapter_to_model') + async def add_adapter_to_model(request: Request, body: AddAdapterRequest, self: ModelManagement = Depends(self_fn)): + assert body.adapter_name, 'You need to specify a valid `adapter_name`' + token = await self._on_request_start(request) + adapter_name = _get_twinkle_adapter_name(request, body.adapter_name) + + async def _task(): + config = deserialize_object(body.config) + extra_kwargs = body.model_extra or {} + training_run_manager = create_training_run_manager(token, client_type='twinkle') + self.register_adapter(adapter_name, token) + self.model.add_adapter_to_model(adapter_name, config, **extra_kwargs) + from twinkle.server.common.io_utils import CreateModelRequest + from twinkle.server.common.io_utils import LoraConfig as IoLoraConfig + lora_config = None + if isinstance(config, LoraConfig): + lora_config = IoLoraConfig(rank=config.r, train_unembed=False, train_mlp=True, train_attn=True) + run_config = CreateModelRequest( + base_model=self.base_model, lora_config=lora_config, user_metadata={'adapter_name': body.adapter_name}) + 
training_run_manager.save(adapter_name, run_config) + return {'status': 'ok', 'adapter_name': adapter_name} + + return await self.schedule_task_and_wait(_task, task_type='add_adapter_to_model') + + @app.post('/twinkle/set_template') + async def set_template(request: Request, body: SetTemplateRequest, self: ModelManagement = Depends(self_fn)): + adapter_name = _get_twinkle_adapter_name(request, body.adapter_name) + + async def _task(): self.assert_adapter_exists(adapter_name=adapter_name) - self.touch_adapter(adapter_name) - return {'status': 'ok'} + extra_kwargs = body.model_extra or {} + ret = self.model.set_template(body.template_cls, adapter_name=adapter_name, **extra_kwargs) + return {'result': ret} + + return await self.schedule_task_and_wait(_task, task_type='set_template') - @app.post('/twinkle/calculate_metric') - async def calculate_metric(self, request: Request, body: CalculateMetricRequest): - adapter_name = self._get_twinkle_adapter_name(request, body.adapter_name) + @app.post('/twinkle/set_processor') + async def set_processor(request: Request, body: SetProcessorRequest, self: ModelManagement = Depends(self_fn)): + adapter_name = _get_twinkle_adapter_name(request, body.adapter_name) - async def _task(): - self.assert_adapter_exists(adapter_name=adapter_name) - extra_kwargs = body.model_extra or {} - ret = self.model.calculate_metric( - is_training=body.is_training, adapter_name=adapter_name, **extra_kwargs) - return {'result': ret} + async def _task(): + self.assert_adapter_exists(adapter_name=adapter_name) + extra_kwargs = body.model_extra or {} + ret = self.model.set_processor(body.processor_cls, adapter_name=adapter_name, **extra_kwargs) + return {'result': ret} - return await self.schedule_task_and_wait(_task, task_type='calculate_metric') + return await self.schedule_task_and_wait(_task, task_type='set_processor') - @app.post('/twinkle/get_state_dict') - async def get_state_dict(self, request: Request, body: GetStateDictRequest): - adapter_name = 
self._get_twinkle_adapter_name(request, body.adapter_name) + @app.post('/twinkle/heartbeat') + async def heartbeat(request: Request, body: HeartbeatRequest, self: ModelManagement = Depends(self_fn)): + adapter_name = _get_twinkle_adapter_name(request, body.adapter_name) + self.assert_adapter_exists(adapter_name=adapter_name) + self.touch_adapter(adapter_name) + return {'status': 'ok'} - async def _task(): - self.assert_adapter_exists(adapter_name=adapter_name) - extra_kwargs = body.model_extra or {} - ret = self.model.get_state_dict(adapter_name=adapter_name, **extra_kwargs) - return {'result': ret} + @app.post('/twinkle/calculate_metric') + async def calculate_metric(request: Request, body: CalculateMetricRequest, + self: ModelManagement = Depends(self_fn)): + adapter_name = _get_twinkle_adapter_name(request, body.adapter_name) + + async def _task(): + self.assert_adapter_exists(adapter_name=adapter_name) + extra_kwargs = body.model_extra or {} + ret = self.model.calculate_metric(is_training=body.is_training, adapter_name=adapter_name, **extra_kwargs) + return {'result': ret} + + return await self.schedule_task_and_wait(_task, task_type='calculate_metric') + + @app.post('/twinkle/get_state_dict') + async def get_state_dict(request: Request, body: GetStateDictRequest, self: ModelManagement = Depends(self_fn)): + adapter_name = _get_twinkle_adapter_name(request, body.adapter_name) + + async def _task(): + self.assert_adapter_exists(adapter_name=adapter_name) + extra_kwargs = body.model_extra or {} + ret = self.model.get_state_dict(adapter_name=adapter_name, **extra_kwargs) + return {'result': ret} - return await self.schedule_task_and_wait(_task, task_type='get_state_dict') + return await self.schedule_task_and_wait(_task, task_type='get_state_dict') diff --git a/src/twinkle/server/sampler/app.py b/src/twinkle/server/sampler/app.py index c3bff40b..c69a6956 100644 --- a/src/twinkle/server/sampler/app.py +++ b/src/twinkle/server/sampler/app.py @@ -5,6 +5,8 @@ Builds a 
single Ray Serve deployment (SamplerManagement) that simultaneously handles both Tinker (/tinker/asample) and Twinkle (/twinkle/*) sampler endpoints. """ +from __future__ import annotations + from fastapi import FastAPI, Request from ray import serve from typing import Any, Dict, Optional @@ -17,21 +19,106 @@ from twinkle.server.utils.validation import get_token_from_request, verify_request_token from twinkle.utils.logger import get_logger from ..utils import wrap_builder_with_device_group_env -from .tinker_handlers import TinkerSamplerHandlers -from .twinkle_handlers import TwinkleSamplerHandlers +from .tinker_handlers import _register_tinker_sampler_routes +from .twinkle_handlers import _register_twinkle_sampler_routes logger = get_logger() +class SamplerManagement(TaskQueueMixin, AdapterManagerMixin): + """Unified sampler management service. + + Manages: + - vLLM or Torch sampler initialization and lifecycle + - Tinker inference requests (/tinker/asample) with rate limiting via TaskQueueMixin + - Twinkle inference requests (/twinkle/*) calling sampler directly + - Adapter lifecycle via AdapterManagerMixin + - Template configuration for trajectory encoding + """ + + def __init__(self, + model_id: str, + nproc_per_node: int, + device_group: dict[str, Any], + device_mesh: dict[str, Any], + sampler_type: str = 'vllm', + engine_args: dict[str, Any] | None = None, + adapter_config: dict[str, Any] | None = None, + queue_config: dict[str, Any] | None = None, + **kwargs): + self.device_group = DeviceGroup(**device_group) + twinkle.initialize(mode='ray', nproc_per_node=nproc_per_node, groups=[self.device_group], lazy_collect=False) + if 'mesh_dim_names' in device_mesh: + self.device_mesh = DeviceMesh(**device_mesh) + else: + self.device_mesh = DeviceMesh.from_sizes(**device_mesh) + self.sampler_type = sampler_type + replica_context = serve.get_replica_context() + replica_id = replica_context.replica_id.unique_id + + # Initialize sampler based on type + if sampler_type == 
'vllm': + from twinkle.sampler import vLLMSampler + sampler_kwargs = engine_args or {} + self.sampler = vLLMSampler( + model_id=model_id, + engine_args=sampler_kwargs, + device_mesh=self.device_mesh, + remote_group=self.device_group.name, + instance_id=replica_id, + **{ + k: v + for k, v in kwargs.items() if k not in ['engine_args'] + }) + else: + from twinkle.sampler import TorchSampler + self.sampler = TorchSampler( + model_id=model_id, + device_mesh=self.device_mesh, + instance_id=replica_id, + remote_group=self.device_group.name, + **kwargs) + + self.sampler.set_template('Template', model_id=model_id) + self.state: ServerStateProxy = get_server_state() + + # Initialize both mixins + self._init_task_queue(TaskQueueConfig.from_dict(queue_config)) + _adapter_config = adapter_config or {} + self._init_adapter_manager(**_adapter_config) + self.start_adapter_countdown() + + @serve.multiplexed(max_num_models_per_replica=5) + async def _sticky_entry(self, sticky_key: str): + return sticky_key + + async def _ensure_sticky(self): + sticky_key = serve.get_multiplexed_model_id() + await self._sticky_entry(sticky_key) + + async def _on_request_start(self, request: Request) -> str: + await self._ensure_sticky() + token = get_token_from_request(request) + return token + + def _on_adapter_expired(self, adapter_name: str, token: str = None) -> None: + """Handle expired adapters by removing them from the sampler.""" + try: + self.sampler.remove_adapter(adapter_name) + logger.info(f'Removed expired adapter {adapter_name}') + except Exception as e: + logger.warning(f'Failed to remove expired adapter {adapter_name}: {e}') + + def build_sampler_app(model_id: str, nproc_per_node: int, - device_group: Dict[str, Any], - device_mesh: Dict[str, Any], - deploy_options: Dict[str, Any], + device_group: dict[str, Any], + device_mesh: dict[str, Any], + deploy_options: dict[str, Any], sampler_type: str = 'vllm', - engine_args: Optional[Dict[str, Any]] = None, - adapter_config: 
Optional[Dict[str, Any]] = None, - queue_config: Optional[Dict[str, Any]] = None, + engine_args: dict[str, Any] | None = None, + adapter_config: dict[str, Any] | None = None, + queue_config: dict[str, Any] | None = None, **kwargs): """Build a unified sampler application for text generation inference. @@ -53,6 +140,8 @@ def build_sampler_app(model_id: str, Returns: Ray Serve deployment bound with configuration """ + # Build the FastAPI app and register all routes BEFORE serve.ingress so that + # the frozen app contains the complete route table (visible to ProxyActor). app = FastAPI( title='Unified Sampler', description='REST API for distributed text generation inference (Tinker + Twinkle)', @@ -62,98 +151,18 @@ def build_sampler_app(model_id: str, async def verify_token(request: Request, call_next): return await verify_request_token(request=request, call_next=call_next) + def get_self() -> SamplerManagement: + return serve.get_replica_context().servable_object + # Register routes BEFORE @serve.ingress so Ray Serve captures them at decoration time - TinkerSamplerHandlers._register_tinker_sampler_routes(app) - TwinkleSamplerHandlers._register_twinkle_sampler_routes(app) - - @serve.deployment(name='SamplerManagement') - @serve.ingress(app) - class SamplerManagement(TaskQueueMixin, AdapterManagerMixin): - """Unified sampler management service. 
- - Manages: - - vLLM or Torch sampler initialization and lifecycle - - Tinker inference requests (/tinker/asample) with rate limiting via TaskQueueMixin - - Twinkle inference requests (/twinkle/*) calling sampler directly - - Adapter lifecycle via AdapterManagerMixin - - Template configuration for trajectory encoding - """ - - def __init__(self, - nproc_per_node: int, - device_group: Dict[str, Any], - device_mesh: Dict[str, Any], - sampler_type: str = 'vllm', - engine_args: Optional[Dict[str, Any]] = None, - adapter_config: Optional[Dict[str, Any]] = None, - queue_config: Optional[Dict[str, Any]] = None, - **kwargs): - self.device_group = DeviceGroup(**device_group) - twinkle.initialize( - mode='ray', nproc_per_node=nproc_per_node, groups=[self.device_group], lazy_collect=False) - if 'mesh_dim_names' in device_mesh: - self.device_mesh = DeviceMesh(**device_mesh) - else: - self.device_mesh = DeviceMesh.from_sizes(**device_mesh) - self.sampler_type = sampler_type - replica_context = serve.get_replica_context() - replica_id = replica_context.replica_id.unique_id - - # Initialize sampler based on type - if sampler_type == 'vllm': - from twinkle.sampler import vLLMSampler - sampler_kwargs = engine_args or {} - self.sampler = vLLMSampler( - model_id=model_id, - engine_args=sampler_kwargs, - device_mesh=self.device_mesh, - remote_group=self.device_group.name, - instance_id=replica_id, - **{ - k: v - for k, v in kwargs.items() if k not in ['engine_args'] - }) - else: - from twinkle.sampler import TorchSampler - self.sampler = TorchSampler( - model_id=model_id, - device_mesh=self.device_mesh, - instance_id=replica_id, - remote_group=self.device_group.name, - **kwargs) - - self.sampler.set_template('Template', model_id=model_id) - self.state: ServerStateProxy = get_server_state() - - # Initialize both mixins - self._init_task_queue(TaskQueueConfig.from_dict(queue_config)) - _adapter_config = adapter_config or {} - self._init_adapter_manager(**_adapter_config) - 
self.start_adapter_countdown() - - @serve.multiplexed(max_num_models_per_replica=5) - async def _sticky_entry(self, sticky_key: str): - return sticky_key - - async def _ensure_sticky(self): - sticky_key = serve.get_multiplexed_model_id() - await self._sticky_entry(sticky_key) - - async def _on_request_start(self, request: Request) -> str: - await self._ensure_sticky() - token = get_token_from_request(request) - return token - - def _on_adapter_expired(self, adapter_name: str, token: str = None) -> None: - """Handle expired adapters by removing them from the sampler.""" - try: - self.sampler.remove_adapter(adapter_name) - logger.info(f'Removed expired adapter {adapter_name}') - except Exception as e: - logger.warning(f'Failed to remove expired adapter {adapter_name}: {e}') - - return SamplerManagement.options(**deploy_options).bind(nproc_per_node, device_group, device_mesh, sampler_type, - engine_args, adapter_config, queue_config, **kwargs) + _register_tinker_sampler_routes(app, get_self) + _register_twinkle_sampler_routes(app, get_self) + + SamplerManagementWithIngress = serve.ingress(app)(SamplerManagement) + DeploymentClass = serve.deployment(name='SamplerManagement')(SamplerManagementWithIngress) + return DeploymentClass.options(**deploy_options).bind(model_id, nproc_per_node, device_group, device_mesh, + sampler_type, engine_args, adapter_config, queue_config, + **kwargs) build_sampler_app = wrap_builder_with_device_group_env(build_sampler_app) diff --git a/src/twinkle/server/sampler/tinker_handlers.py b/src/twinkle/server/sampler/tinker_handlers.py index acb2bb0e..21bee277 100644 --- a/src/twinkle/server/sampler/tinker_handlers.py +++ b/src/twinkle/server/sampler/tinker_handlers.py @@ -4,10 +4,16 @@ Provides POST /tinker/asample using schedule_task() returning UntypedAPIFuture. 
""" +from __future__ import annotations + import os import traceback -from fastapi import FastAPI, Request +from fastapi import Depends, FastAPI, Request from tinker import types +from typing import TYPE_CHECKING, Callable + +if TYPE_CHECKING: + from .app import SamplerManagement from twinkle.data_format import SamplingParams from twinkle.server.common.io_utils import create_checkpoint_manager @@ -16,105 +22,100 @@ logger = get_logger() -class TinkerSamplerHandlers: - """ - Mixin providing Tinker-compatible sampler endpoint. +def _register_tinker_sampler_routes(app: FastAPI, self_fn: Callable[[], SamplerManagement]) -> None: + """Register the tinker sampler route on the given FastAPI app. - Expects the combined class to also inherit TaskQueueMixin and to have: - self.sampler, self.state + self_fn is a zero-argument callable returning the current SamplerManagement replica instance. + It is wired in via Depends so it is resolved lazily at request time. """ - @staticmethod - def _register_tinker_sampler_routes(app: FastAPI): - """Register the tinker sampler route on the given FastAPI app.""" - - @app.post('/tinker/asample') - async def asample(self, request: Request, body: types.SampleRequest) -> types.UntypedAPIFuture: - """Execute text generation (inference) for Tinker clients. 
- - Args: - request: FastAPI request with auth token - body: SampleRequest with prompt, sampling params, and adapter info - - Returns: - UntypedAPIFuture wrapping SampleResponse with generated sequences - """ - from twinkle.server.utils.validation import get_token_from_request - token = await self._on_request_start(request) - - async def _do_sample(): - try: - # Extract prompt token IDs from ModelInput - prompt_inputs = {'input_ids': body.prompt.to_ints()} - - # Get model_path from body or sampling session - model_path = body.model_path - if not model_path and body.sampling_session_id: - session = self.state.get_sampling_session(body.sampling_session_id) - if session: - model_path = session.get('model_path') - - # Parse and resolve adapter URI from model_path - adapter_uri = None - if model_path: - checkpoint_manager = create_checkpoint_manager(token, client_type='tinker') - adapter_name, adapter_uri = checkpoint_manager.parse_adapter_uri(model_path) - - # Validate adapter URI - if not adapter_uri or not os.path.exists(adapter_uri): - return types.RequestFailedResponse( - error=f'Adapter URI {model_path} does not exist. Please check the model_path.', - category=types.RequestErrorCategory.User, - ) - - # Convert tinker SamplingParams to twinkle SamplingParams if needed - sampling_params = None - if body.sampling_params: - sampling_params = SamplingParams( - max_tokens=body.sampling_params.max_tokens or 256, - temperature=body.sampling_params.temperature or 1.0, - top_p=body.sampling_params.top_p, - top_k=body.sampling_params.top_k, - stop=body.sampling_params.stop, - ) - - response = self.sampler.sample( - inputs=[prompt_inputs] * body.num_samples, - sampling_params=sampling_params, - adapter_path=adapter_uri, + @app.post('/tinker/asample') + async def asample(request: Request, body: types.SampleRequest, + self: SamplerManagement = Depends(self_fn)) -> types.UntypedAPIFuture: + """Execute text generation (inference) for Tinker clients. 
+ + Args: + request: FastAPI request with auth token + body: SampleRequest with prompt, sampling params, and adapter info + + Returns: + UntypedAPIFuture wrapping SampleResponse with generated sequences + """ + token = await self._on_request_start(request) + + async def _do_sample(): + try: + # Extract prompt token IDs from ModelInput + prompt_inputs = {'input_ids': body.prompt.to_ints()} + + # Get model_path from body or sampling session + model_path = body.model_path + if not model_path and body.sampling_session_id: + session = self.state.get_sampling_session(body.sampling_session_id) + if session: + model_path = session.get('model_path') + + # Parse and resolve adapter URI from model_path + adapter_uri = None + if model_path: + checkpoint_manager = create_checkpoint_manager(token, client_type='tinker') + adapter_name, adapter_uri = checkpoint_manager.parse_adapter_uri(model_path) + + # Validate adapter URI + if not adapter_uri or not os.path.exists(adapter_uri): + return types.RequestFailedResponse( + error=f'Adapter URI {model_path} does not exist. 
Please check the model_path.', + category=types.RequestErrorCategory.User, ) - # Convert twinkle SampleResponse to tinker types - tinker_sequences = [] - for seq in response.sequences: - logprobs = None - if seq.logprobs is not None: - if any(lp is None for lp in seq.logprobs): - logprobs = None - else: - logprobs = list(seq.logprobs) - tinker_sequences.append( - types.SampledSequence( - stop_reason=seq.stop_reason, - tokens=list(seq.tokens), - logprobs=logprobs, - )) - return types.SampleResponse( - sequences=tinker_sequences, - prompt_logprobs=response.prompt_logprobs, - topk_prompt_logprobs=response.topk_prompt_logprobs, - ) - except Exception: - logger.error(traceback.format_exc()) - return types.RequestFailedResponse( - error=traceback.format_exc(), - category=types.RequestErrorCategory.Server, + # Convert tinker SamplingParams to twinkle SamplingParams if needed + sampling_params = None + if body.sampling_params: + sampling_params = SamplingParams( + max_tokens=body.sampling_params.max_tokens or 256, + temperature=body.sampling_params.temperature or 1.0, + top_p=body.sampling_params.top_p, + top_k=body.sampling_params.top_k, + stop=body.sampling_params.stop, ) - input_tokens = len(body.prompt.to_ints()) - return await self.schedule_task( - _do_sample, - token=token, - input_tokens=input_tokens, - task_type='sample', - ) + response = self.sampler.sample( + inputs=[prompt_inputs] * body.num_samples, + sampling_params=sampling_params, + adapter_path=adapter_uri, + ) + + # Convert twinkle SampleResponse to tinker types + tinker_sequences = [] + for seq in response.sequences: + logprobs = None + if seq.logprobs is not None: + if any(lp is None for lp in seq.logprobs): + logprobs = None + else: + logprobs = list(seq.logprobs) + tinker_sequences.append( + types.SampledSequence( + stop_reason=seq.stop_reason, + tokens=list(seq.tokens), + logprobs=logprobs, + )) + return types.SampleResponse( + sequences=tinker_sequences, + prompt_logprobs=response.prompt_logprobs, + 
topk_prompt_logprobs=response.topk_prompt_logprobs, + ) + except Exception: + logger.error(traceback.format_exc()) + return types.RequestFailedResponse( + error=traceback.format_exc(), + category=types.RequestErrorCategory.Server, + ) + + input_tokens = len(body.prompt.to_ints()) + return await self.schedule_task( + _do_sample, + token=token, + input_tokens=input_tokens, + task_type='sample', + ) diff --git a/src/twinkle/server/sampler/twinkle_handlers.py b/src/twinkle/server/sampler/twinkle_handlers.py index b35ac404..860561e4 100644 --- a/src/twinkle/server/sampler/twinkle_handlers.py +++ b/src/twinkle/server/sampler/twinkle_handlers.py @@ -4,9 +4,14 @@ Provides /twinkle/* sampler endpoints that call the sampler directly (no queue needed). """ +from __future__ import annotations + import traceback -from fastapi import FastAPI, Request -from typing import Optional +from fastapi import Depends, FastAPI, Request +from typing import TYPE_CHECKING, Callable, Optional + +if TYPE_CHECKING: + from .app import SamplerManagement from twinkle.data_format import InputFeature, SamplingParams, Trajectory from twinkle.utils.logger import get_logger @@ -17,123 +22,125 @@ logger = get_logger() -class TwinkleSamplerHandlers: - """ - Mixin providing Twinkle-native sampler endpoints. +def _get_twinkle_sampler_adapter_name(request: Request, adapter_name: str | None) -> str | None: + """Prefix the adapter name with the request ID for per-request isolation.""" + if adapter_name is None or adapter_name == '': + return None + return request.state.request_id + '-' + adapter_name + + +def _register_twinkle_sampler_routes(app: FastAPI, self_fn: Callable[[], SamplerManagement]) -> None: + """Register all /twinkle/* sampler routes on the given FastAPI app. - Expects the combined class to also have: - self.sampler, self.state - The class should also inherit AdapterManagerMixin for adapter lifecycle. + self_fn is a zero-argument callable returning the current SamplerManagement replica instance. 
+ It is wired in via Depends so it is resolved lazily at request time. """ - @staticmethod - def _register_twinkle_sampler_routes(app: FastAPI): - """Register all twinkle sampler routes on the given FastAPI app.""" - - @staticmethod - def _get_twinkle_sampler_adapter_name(request: Request, adapter_name: Optional[str]) -> Optional[str]: - if adapter_name is None or adapter_name == '': - return None - return request.state.request_id + '-' + adapter_name - - @app.post('/twinkle/create', response_model=CreateResponse) - def create(self, request: Request) -> CreateResponse: - """Health check / session creation endpoint.""" - return CreateResponse() - - @app.post('/twinkle/sample', response_model=SampleResponseModel) - def sample(self, request: Request, body: SampleRequest) -> SampleResponseModel: - """Sample completions from the model. - - Supports Trajectory or InputFeature inputs, with optional LoRA adapter. - """ - try: - # Resolve adapter - adapter_path = None - adapter_name = body.adapter_name or '' - full_adapter_name = _get_twinkle_sampler_adapter_name(request, adapter_name) or '' - - if body.adapter_uri: - from twinkle.server.common.io_utils import create_checkpoint_manager - from twinkle.server.utils.validation import get_token_from_request - token = get_token_from_request(request) - checkpoint_manager = create_checkpoint_manager(token, client_type='twinkle') - _, adapter_path = checkpoint_manager.parse_adapter_uri(body.adapter_uri) - - # Parse inputs - inputs = body.inputs - if isinstance(inputs, list) and inputs: - first = inputs[0] - if isinstance(first, dict) and 'input_ids' in first: - inputs = [InputFeature(**item) for item in inputs] - else: - inputs = [Trajectory(**item) for item in inputs] - elif isinstance(inputs, dict): - if 'input_ids' in inputs: - inputs = [InputFeature(**inputs)] - else: - inputs = [Trajectory(**inputs)] - - # Build sampling params - params = None - if body.sampling_params: - params = SamplingParams.from_dict(body.sampling_params) 
- - # Call sampler - response = self.sampler.sample( - inputs, - params, - adapter_name=full_adapter_name, - adapter_path=adapter_path, - num_samples=body.num_samples, - ) - if callable(response): - response = response() - - sequences = [] - for seq in response.sequences: - sequences.append({ - 'stop_reason': seq.stop_reason, - 'tokens': list(seq.tokens), - 'logprobs': list(seq.logprobs) if seq.logprobs is not None else None, - }) - - return SampleResponseModel( - sequences=sequences, - prompt_logprobs=response.prompt_logprobs, - topk_prompt_logprobs=response.topk_prompt_logprobs, - ) - except Exception: - logger.error(traceback.format_exc()) - raise - - @app.post('/twinkle/set_template', response_model=SetTemplateResponse) - def set_template(self, request: Request, body: SetTemplateRequest) -> SetTemplateResponse: - """Set the chat template for encoding Trajectory inputs.""" - extra_kwargs = body.model_extra or {} - self.sampler.set_template(body.template_cls, **extra_kwargs) - return SetTemplateResponse() - - @app.post('/twinkle/add_adapter_to_sampler', response_model=AddAdapterResponse) - def add_adapter_to_sampler(self, request: Request, body: AddAdapterRequest) -> AddAdapterResponse: - """Add a LoRA adapter to the sampler.""" - assert body.adapter_name, 'You need to specify a valid `adapter_name`' - full_adapter_name = _get_twinkle_sampler_adapter_name(request, body.adapter_name) - from twinkle.server.utils.validation import get_token_from_request - token = get_token_from_request(request) - - from peft import LoraConfig - config = LoraConfig(**body.config) if isinstance(body.config, dict) else body.config - - self.register_adapter(full_adapter_name, token) - self.sampler.add_adapter_to_sampler(full_adapter_name, config) - - return AddAdapterResponse(adapter_name=full_adapter_name) - - @app.post('/twinkle/heartbeat', response_model=HeartbeatResponse) - def heartbeat(self, request: Request, body: HeartbeatRequest) -> HeartbeatResponse: - """Keep an adapter alive 
by resetting its inactivity timer.""" - full_adapter_name = _get_twinkle_sampler_adapter_name(request, body.adapter_name) - self.assert_adapter_exists(adapter_name=full_adapter_name) - self.touch_adapter(full_adapter_name) - return HeartbeatResponse() + @app.post('/twinkle/create', response_model=CreateResponse) + def create(request: Request, self: SamplerManagement = Depends(self_fn)) -> CreateResponse: + """Health check / session creation endpoint.""" + return CreateResponse() + + @app.post('/twinkle/sample', response_model=SampleResponseModel) + def sample(request: Request, body: SampleRequest, + self: SamplerManagement = Depends(self_fn)) -> SampleResponseModel: + """Sample completions from the model. + + Supports Trajectory or InputFeature inputs, with optional LoRA adapter. + """ + try: + # Resolve adapter + adapter_path = None + adapter_name = body.adapter_name or '' + full_adapter_name = _get_twinkle_sampler_adapter_name(request, adapter_name) or '' + + if body.adapter_uri: + from twinkle.server.common.io_utils import create_checkpoint_manager + from twinkle.server.utils.validation import get_token_from_request + token = get_token_from_request(request) + checkpoint_manager = create_checkpoint_manager(token, client_type='twinkle') + _, adapter_path = checkpoint_manager.parse_adapter_uri(body.adapter_uri) + + # Parse inputs + inputs = body.inputs + if isinstance(inputs, list) and inputs: + first = inputs[0] + if isinstance(first, dict) and 'input_ids' in first: + inputs = [InputFeature(**item) for item in inputs] + else: + inputs = [Trajectory(**item) for item in inputs] + elif isinstance(inputs, dict): + if 'input_ids' in inputs: + inputs = [InputFeature(**inputs)] + else: + inputs = [Trajectory(**inputs)] + + # Build sampling params + params = None + if body.sampling_params: + params = SamplingParams.from_dict(body.sampling_params) + + # Call sampler + response = self.sampler.sample( + inputs, + params, + adapter_name=full_adapter_name, + 
adapter_path=adapter_path, + num_samples=body.num_samples, + ) + if callable(response): + response = response() + + sequences = [] + for seq in response.sequences: + sequences.append({ + 'stop_reason': seq.stop_reason, + 'tokens': list(seq.tokens), + 'logprobs': list(seq.logprobs) if seq.logprobs is not None else None, + }) + + return SampleResponseModel( + sequences=sequences, + prompt_logprobs=response.prompt_logprobs, + topk_prompt_logprobs=response.topk_prompt_logprobs, + ) + except Exception: + logger.error(traceback.format_exc()) + raise + + @app.post('/twinkle/set_template', response_model=SetTemplateResponse) + def set_template(request: Request, body: SetTemplateRequest, + self: SamplerManagement = Depends(self_fn)) -> SetTemplateResponse: + """Set the chat template for encoding Trajectory inputs.""" + extra_kwargs = body.model_extra or {} + self.sampler.set_template(body.template_cls, **extra_kwargs) + return SetTemplateResponse() + + @app.post('/twinkle/add_adapter_to_sampler', response_model=AddAdapterResponse) + def add_adapter_to_sampler( + request: Request, + body: AddAdapterRequest, + self: SamplerManagement = Depends(self_fn), + ) -> AddAdapterResponse: + """Add a LoRA adapter to the sampler.""" + assert body.adapter_name, 'You need to specify a valid `adapter_name`' + full_adapter_name = _get_twinkle_sampler_adapter_name(request, body.adapter_name) + from twinkle.server.utils.validation import get_token_from_request + token = get_token_from_request(request) + + from peft import LoraConfig + config = LoraConfig(**body.config) if isinstance(body.config, dict) else body.config + + self.register_adapter(full_adapter_name, token) + self.sampler.add_adapter_to_sampler(full_adapter_name, config) + + return AddAdapterResponse(adapter_name=full_adapter_name) + + @app.post('/twinkle/heartbeat', response_model=HeartbeatResponse) + def heartbeat(request: Request, body: HeartbeatRequest, + self: SamplerManagement = Depends(self_fn)) -> HeartbeatResponse: + 
"""Keep an adapter alive by resetting its inactivity timer.""" + full_adapter_name = _get_twinkle_sampler_adapter_name(request, body.adapter_name) + self.assert_adapter_exists(adapter_name=full_adapter_name) + self.touch_adapter(full_adapter_name) + return HeartbeatResponse() From 3bc466f0c5bdd87299a34249ee0e9852d5aff854 Mon Sep 17 00:00:00 2001 From: Yunnglin Date: Wed, 11 Mar 2026 18:10:41 +0800 Subject: [PATCH 07/24] update io --- client_tools/client_generator.py | 12 +- src/twinkle/server/common/io_utils.py | 290 +----------------- src/twinkle/server/common/tinker_io_utils.py | 134 ++++++++ src/twinkle/server/common/twinkle_io_utils.py | 141 +++++++++ .../gateway/twinkle_gateway_handlers.py | 1 - src/twinkle/server/processor/app.py | 6 +- src/twinkle_client/dataloader/dataloader.py | 12 +- src/twinkle_client/dataset/base.py | 22 +- .../dataset/iterable_dataset.py | 14 +- .../dataset/iterable_packing_dataset.py | 12 +- src/twinkle_client/dataset/lazy_dataset.py | 12 +- src/twinkle_client/dataset/packing_dataset.py | 10 +- src/twinkle_client/http/heartbeat.py | 5 +- src/twinkle_client/http/http_utils.py | 1 - src/twinkle_client/manager.py | 4 +- .../model/multi_lora_transformers.py | 2 +- src/twinkle_client/processor/base.py | 6 +- src/twinkle_client/sampler/vllm_sampler.py | 2 +- 18 files changed, 347 insertions(+), 339 deletions(-) create mode 100644 src/twinkle/server/common/tinker_io_utils.py create mode 100644 src/twinkle/server/common/twinkle_io_utils.py diff --git a/client_tools/client_generator.py b/client_tools/client_generator.py index c337c464..858f4872 100644 --- a/client_tools/client_generator.py +++ b/client_tools/client_generator.py @@ -274,7 +274,7 @@ def build_method(name: str, signature: str) -> str: code = f''' def {name}(self{sig_part}): response = http_post( - url=f'{{self.server_url}}/processors/call', + url=f'{{self.server_url}}/call', json_data={{ 'processor_id': self.processor_id, 'function': '{name}', @@ -288,7 +288,7 @@ def 
{name}(self{sig_part}): code += ''' def __next__(self): response = http_post( - url=f'{self.server_url}/processors/call', + url=f'{self.server_url}/call', json_data={ 'processor_id': self.processor_id, 'function': '__next__', @@ -346,10 +346,10 @@ class {class_name}({inheritance}): def __init__({init_params}): from twinkle_client.http import get_base_url - self.server_url = get_base_url() + self.server_url = f'{{get_base_url()}}/processors/twinkle' response = http_post( - url=f'{{self.server_url}}/processors/create', + url=f'{{self.server_url}}/create', json_data={{ 'processor_type': '{processor_type}', 'class_type': '{class_name}', @@ -466,7 +466,7 @@ def __init__(self, model_id: str, **kwargs): self.model_id = model_id if '://' in model_id: model_id = model_id.split('://')[1] - self.server_url = f'{self.server_url}/models/{model_id}' + self.server_url = f'{self.server_url}/models/{model_id}/twinkle' self.adapter_name = None response = http_post( url=f'{self.server_url}/create', @@ -743,7 +743,7 @@ def __init__(self, model_id: str, **kwargs): self.adapter_name = None if '://' in model_id: model_id = model_id.split('://')[1] - self.server_url = f'{self.server_url}/samplers/{model_id}' + self.server_url = f'{self.server_url}/samplers/{model_id}/twinkle' response = http_post( url=f'{self.server_url}/create', json_data=kwargs diff --git a/src/twinkle/server/common/io_utils.py b/src/twinkle/server/common/io_utils.py index 089a4955..9c389de1 100644 --- a/src/twinkle/server/common/io_utils.py +++ b/src/twinkle/server/common/io_utils.py @@ -2,25 +2,19 @@ """ Unified IO utilities for managing training runs and checkpoints. -Merges tinker/common/io_utils.py and twinkle/common/io_utils.py. -Both client-type implementations share the same underlying base classes; -factory functions accept a ``client_type`` parameter ('tinker' or 'twinkle'). 
+Manager implementations live in dedicated modules: + - ``tinker_io_utils`` : Tinker-specific managers (use tinker.types) + - ``twinkle_io_utils`` : Twinkle-specific managers (use twinkle_client.types.training) -Pydantic models that need to be shared with the client live in -``twinkle_client.types.training``. +This module re-exports everything and provides the factory functions +``create_training_run_manager`` and ``create_checkpoint_manager``. """ -from datetime import datetime -from tinker import types as tinker_types -from typing import Any, Dict, List, Optional - -from twinkle.server.utils.io_utils import (CHECKPOINT_INFO_FILENAME, TRAIN_RUN_INFO_FILENAME, TWINKLE_DEFAULT_SAVE_DIR, - BaseCheckpoint, BaseCheckpointManager, BaseCreateModelRequest, - BaseLoraConfig, BaseParsedCheckpointPath, BaseTrainingRun, - BaseTrainingRunManager, BaseWeightsInfoResponse, Cursor, ResolvedLoadPath, - validate_ownership, validate_user_path) +from twinkle.server.common.tinker_io_utils import TinkerCheckpointManager, TinkerTrainingRunManager +from twinkle.server.common.twinkle_io_utils import TwinkleCheckpointManager, TwinkleTrainingRunManager +from twinkle.server.utils.io_utils import ResolvedLoadPath, validate_ownership, validate_user_path # Re-export twinkle-native pydantic models from twinkle_client.types from twinkle_client.types.training import Checkpoint as TwinkleCheckpoint -from twinkle_client.types.training import (CheckpointsListResponse, CreateModelRequest, LoraConfig, +from twinkle_client.types.training import (CheckpointsListResponse, CreateModelRequest, Cursor, LoraConfig, ParsedCheckpointTwinklePath) from twinkle_client.types.training import TrainingRun as TwinkleTrainingRun from twinkle_client.types.training import TrainingRunsResponse, WeightsInfoResponse @@ -32,6 +26,10 @@ 'validate_ownership', 'ResolvedLoadPath', 'Cursor', + 'TinkerTrainingRunManager', + 'TinkerCheckpointManager', + 'TwinkleTrainingRunManager', + 'TwinkleCheckpointManager', # Twinkle-native 
models (re-exported for convenience) 'TwinkleCheckpoint', 'TwinkleTrainingRun', @@ -43,268 +41,6 @@ 'ParsedCheckpointTwinklePath', ] -# --------------------------------------------------------------------------- -# Tinker-specific managers (use tinker.types for model instances) -# --------------------------------------------------------------------------- - - -class TinkerTrainingRunManager(BaseTrainingRunManager): - """Tinker-specific training run manager using tinker.types models.""" - - @property - def train_run_info_filename(self) -> str: - return TRAIN_RUN_INFO_FILENAME - - def _create_training_run(self, model_id: str, run_config: tinker_types.CreateModelRequest) -> Dict[str, Any]: - lora_config = run_config.lora_config - train_run_data = tinker_types.TrainingRun( - training_run_id=model_id, - base_model=run_config.base_model, - model_owner=self.token, - is_lora=True if lora_config else False, - corrupted=False, - lora_rank=lora_config.rank if lora_config else None, - last_request_time=datetime.now(), - last_checkpoint=None, - last_sampler_checkpoint=None, - user_metadata=run_config.user_metadata) - - new_data = train_run_data.model_dump(mode='json') - if lora_config: - new_data['train_unembed'] = lora_config.train_unembed - new_data['train_mlp'] = lora_config.train_mlp - new_data['train_attn'] = lora_config.train_attn - return new_data - - def _parse_training_run(self, data: Dict[str, Any]) -> tinker_types.TrainingRun: - data = self._transform_checkpoint_fields(data) - return tinker_types.TrainingRun(**data) - - def _transform_checkpoint_fields(self, data: Dict[str, Any]) -> Dict[str, Any]: - data = data.copy() - for field in ['last_checkpoint', 'last_sampler_checkpoint']: - if field in data and data[field] is not None: - ckpt = data[field].copy() - if 'twinkle_path' in ckpt and 'tinker_path' not in ckpt: - ckpt['tinker_path'] = ckpt.pop('twinkle_path') - elif 'tinker_path' not in ckpt: - path = ckpt.get('path') or ckpt.get('twinkle_path') - if path: - 
ckpt['tinker_path'] = path - elif 'checkpoint_id' in ckpt and 'training_run_id' in data: - ckpt['tinker_path'] = f"twinkle://{data['training_run_id']}/{ckpt['checkpoint_id']}" - data[field] = ckpt - return data - - def _create_training_runs_response(self, runs: List[tinker_types.TrainingRun], limit: int, offset: int, - total: int) -> tinker_types.TrainingRunsResponse: - return tinker_types.TrainingRunsResponse( - training_runs=runs, cursor=tinker_types.Cursor(limit=limit, offset=offset, total_count=total)) - - -class TinkerCheckpointManager(BaseCheckpointManager): - """Tinker-specific checkpoint manager using tinker.types models.""" - - @property - def path_prefix(self) -> str: - return 'twinkle://' - - @property - def path_field_name(self) -> str: - return 'tinker_path' - - def _create_checkpoint(self, - checkpoint_id, - checkpoint_type, - path, - size_bytes, - public, - base_model=None, - is_lora=False, - lora_rank=None, - train_unembed=None, - train_mlp=None, - train_attn=None, - user_metadata=None) -> Dict[str, Any]: - checkpoint = tinker_types.Checkpoint( - checkpoint_id=checkpoint_id, - checkpoint_type=checkpoint_type, - time=datetime.now(), - tinker_path=path, - size_bytes=size_bytes, - public=public) - result = checkpoint.model_dump(mode='json') - result['base_model'] = base_model - result['is_lora'] = is_lora - result['lora_rank'] = lora_rank - result['train_unembed'] = train_unembed - result['train_mlp'] = train_mlp - result['train_attn'] = train_attn - result['user_metadata'] = user_metadata - return result - - def _parse_checkpoint(self, data: Dict[str, Any]) -> tinker_types.Checkpoint: - data = data.copy() - if 'twinkle_path' in data and 'tinker_path' not in data: - data['tinker_path'] = data.pop('twinkle_path') - elif 'tinker_path' not in data and 'path' in data: - data['tinker_path'] = data.pop('path') - return tinker_types.Checkpoint(**data) - - def _create_checkpoints_response( - self, checkpoints: List[tinker_types.Checkpoint]) -> 
tinker_types.CheckpointsListResponse: - return tinker_types.CheckpointsListResponse(checkpoints=checkpoints, cursor=None) - - def _create_parsed_path(self, path, training_run_id, checkpoint_type, - checkpoint_id) -> tinker_types.ParsedCheckpointTinkerPath: - return tinker_types.ParsedCheckpointTinkerPath( - tinker_path=path, - training_run_id=training_run_id, - checkpoint_type=checkpoint_type, - checkpoint_id=checkpoint_id, - ) - - def _create_weights_info(self, run_info: Dict[str, Any]) -> tinker_types.WeightsInfoResponse: - return tinker_types.WeightsInfoResponse(**run_info) - - def parse_tinker_path(self, tinker_path: str) -> Optional[tinker_types.ParsedCheckpointTinkerPath]: - return self.parse_path(tinker_path) - - -# --------------------------------------------------------------------------- -# Twinkle-specific managers (use twinkle_client.types.training models) -# --------------------------------------------------------------------------- - - -class TwinkleTrainingRunManager(BaseTrainingRunManager): - """Twinkle-specific training run manager.""" - - @property - def train_run_info_filename(self) -> str: - return TRAIN_RUN_INFO_FILENAME - - def _create_training_run(self, model_id: str, run_config: CreateModelRequest) -> Dict[str, Any]: - lora_config = run_config.lora_config - train_run_data = TwinkleTrainingRun( - training_run_id=model_id, - base_model=run_config.base_model, - model_owner=self.token, - is_lora=True if lora_config else False, - corrupted=False, - lora_rank=lora_config.rank if lora_config else None, - last_request_time=datetime.now(), - last_checkpoint=None, - last_sampler_checkpoint=None, - user_metadata=run_config.user_metadata) - - new_data = train_run_data.model_dump(mode='json') - if lora_config: - new_data['train_unembed'] = lora_config.train_unembed - new_data['train_mlp'] = lora_config.train_mlp - new_data['train_attn'] = lora_config.train_attn - return new_data - - def _parse_training_run(self, data: Dict[str, Any]) -> 
TwinkleTrainingRun: - return TwinkleTrainingRun(**data) - - def _create_training_runs_response(self, runs: List[TwinkleTrainingRun], limit: int, offset: int, - total: int) -> TrainingRunsResponse: - return TrainingRunsResponse(training_runs=runs, cursor=Cursor(limit=limit, offset=offset, total_count=total)) - - def get_with_permission(self, model_id: str) -> Optional[TwinkleTrainingRun]: - run = self.get(model_id) - if run and validate_ownership(self.token, run.model_owner): - return run - return None - - -class TwinkleCheckpointManager(BaseCheckpointManager): - """Twinkle-specific checkpoint manager.""" - - @property - def path_prefix(self) -> str: - return 'twinkle://' - - @property - def path_field_name(self) -> str: - return 'twinkle_path' - - def _create_checkpoint(self, - checkpoint_id, - checkpoint_type, - path, - size_bytes, - public, - base_model=None, - is_lora=False, - lora_rank=None, - train_unembed=None, - train_mlp=None, - train_attn=None, - user_metadata=None) -> Dict[str, Any]: - checkpoint = TwinkleCheckpoint( - checkpoint_id=checkpoint_id, - checkpoint_type=checkpoint_type, - time=datetime.now(), - twinkle_path=path, - size_bytes=size_bytes, - public=public, - base_model=base_model, - is_lora=is_lora, - lora_rank=lora_rank, - train_unembed=train_unembed, - train_mlp=train_mlp, - train_attn=train_attn, - user_metadata=user_metadata) - return checkpoint.model_dump(mode='json') - - def _parse_checkpoint(self, data: Dict[str, Any]) -> TwinkleCheckpoint: - data = data.copy() - if 'tinker_path' in data and 'twinkle_path' not in data: - data['twinkle_path'] = data.pop('tinker_path') - elif 'twinkle_path' not in data and 'path' in data: - data['twinkle_path'] = data.pop('path') - return TwinkleCheckpoint(**data) - - def get(self, model_id: str, checkpoint_id: str) -> Optional[TwinkleCheckpoint]: - data = self._read_ckpt_info(model_id, checkpoint_id) - if not data: - return None - if 'twinkle_path' not in data and 'tinker_path' not in data and 'path' not 
in data: - if 'checkpoint_id' in data: - data = data.copy() - data['twinkle_path'] = f"{self.path_prefix}{model_id}/{data['checkpoint_id']}" - return self._parse_checkpoint(data) - - def _create_checkpoints_response(self, checkpoints: List[TwinkleCheckpoint]) -> CheckpointsListResponse: - return CheckpointsListResponse(checkpoints=checkpoints, cursor=None) - - def _create_parsed_path(self, path, training_run_id, checkpoint_type, checkpoint_id) -> ParsedCheckpointTwinklePath: - return ParsedCheckpointTwinklePath( - path=path, - twinkle_path=path, - training_run_id=training_run_id, - checkpoint_type=checkpoint_type, - checkpoint_id=checkpoint_id, - ) - - def _create_weights_info(self, run_info: Dict[str, Any]) -> WeightsInfoResponse: - return WeightsInfoResponse( - training_run_id=run_info.get('training_run_id', ''), - base_model=run_info.get('base_model', ''), - model_owner=run_info.get('model_owner', ''), - is_lora=run_info.get('is_lora', False), - lora_rank=run_info.get('lora_rank'), - ) - - def parse_twinkle_path(self, twinkle_path: str) -> Optional[ParsedCheckpointTwinklePath]: - return self.parse_path(twinkle_path) - - -# --------------------------------------------------------------------------- -# Unified factory functions -# --------------------------------------------------------------------------- - def create_training_run_manager(token: str, client_type: str = 'twinkle'): """Create a TrainingRunManager for the given token. diff --git a/src/twinkle/server/common/tinker_io_utils.py b/src/twinkle/server/common/tinker_io_utils.py new file mode 100644 index 00000000..90d49482 --- /dev/null +++ b/src/twinkle/server/common/tinker_io_utils.py @@ -0,0 +1,134 @@ +# Copyright (c) ModelScope Contributors. All rights reserved. +""" +Tinker-specific IO managers for training runs and checkpoints. + +Uses ``tinker.types`` models for all serialization and response construction. 
+""" +from datetime import datetime +from tinker import types as tinker_types +from typing import Any, Dict, List, Optional + +from twinkle.server.utils.io_utils import TRAIN_RUN_INFO_FILENAME, BaseCheckpointManager, BaseTrainingRunManager + + +class TinkerTrainingRunManager(BaseTrainingRunManager): + """Tinker-specific training run manager using tinker.types models.""" + + @property + def train_run_info_filename(self) -> str: + return TRAIN_RUN_INFO_FILENAME + + def _create_training_run(self, model_id: str, run_config: tinker_types.CreateModelRequest) -> Dict[str, Any]: + lora_config = run_config.lora_config + train_run_data = tinker_types.TrainingRun( + training_run_id=model_id, + base_model=run_config.base_model, + model_owner=self.token, + is_lora=True if lora_config else False, + corrupted=False, + lora_rank=lora_config.rank if lora_config else None, + last_request_time=datetime.now(), + last_checkpoint=None, + last_sampler_checkpoint=None, + user_metadata=run_config.user_metadata) + + new_data = train_run_data.model_dump(mode='json') + if lora_config: + new_data['train_unembed'] = lora_config.train_unembed + new_data['train_mlp'] = lora_config.train_mlp + new_data['train_attn'] = lora_config.train_attn + return new_data + + def _parse_training_run(self, data: Dict[str, Any]) -> tinker_types.TrainingRun: + data = self._transform_checkpoint_fields(data) + return tinker_types.TrainingRun(**data) + + def _transform_checkpoint_fields(self, data: Dict[str, Any]) -> Dict[str, Any]: + data = data.copy() + for field in ['last_checkpoint', 'last_sampler_checkpoint']: + if field in data and data[field] is not None: + ckpt = data[field].copy() + if 'twinkle_path' in ckpt and 'tinker_path' not in ckpt: + ckpt['tinker_path'] = ckpt.pop('twinkle_path') + elif 'tinker_path' not in ckpt: + path = ckpt.get('path') or ckpt.get('twinkle_path') + if path: + ckpt['tinker_path'] = path + elif 'checkpoint_id' in ckpt and 'training_run_id' in data: + ckpt['tinker_path'] = 
f"twinkle://{data['training_run_id']}/{ckpt['checkpoint_id']}" + data[field] = ckpt + return data + + def _create_training_runs_response(self, runs: List[tinker_types.TrainingRun], limit: int, offset: int, + total: int) -> tinker_types.TrainingRunsResponse: + return tinker_types.TrainingRunsResponse( + training_runs=runs, cursor=tinker_types.Cursor(limit=limit, offset=offset, total_count=total)) + + +class TinkerCheckpointManager(BaseCheckpointManager): + """Tinker-specific checkpoint manager using tinker.types models.""" + + @property + def path_prefix(self) -> str: + return 'twinkle://' + + @property + def path_field_name(self) -> str: + return 'tinker_path' + + def _create_checkpoint(self, + checkpoint_id, + checkpoint_type, + path, + size_bytes, + public, + base_model=None, + is_lora=False, + lora_rank=None, + train_unembed=None, + train_mlp=None, + train_attn=None, + user_metadata=None) -> Dict[str, Any]: + checkpoint = tinker_types.Checkpoint( + checkpoint_id=checkpoint_id, + checkpoint_type=checkpoint_type, + time=datetime.now(), + tinker_path=path, + size_bytes=size_bytes, + public=public) + result = checkpoint.model_dump(mode='json') + result['base_model'] = base_model + result['is_lora'] = is_lora + result['lora_rank'] = lora_rank + result['train_unembed'] = train_unembed + result['train_mlp'] = train_mlp + result['train_attn'] = train_attn + result['user_metadata'] = user_metadata + return result + + def _parse_checkpoint(self, data: Dict[str, Any]) -> tinker_types.Checkpoint: + data = data.copy() + if 'twinkle_path' in data and 'tinker_path' not in data: + data['tinker_path'] = data.pop('twinkle_path') + elif 'tinker_path' not in data and 'path' in data: + data['tinker_path'] = data.pop('path') + return tinker_types.Checkpoint(**data) + + def _create_checkpoints_response( + self, checkpoints: List[tinker_types.Checkpoint]) -> tinker_types.CheckpointsListResponse: + return tinker_types.CheckpointsListResponse(checkpoints=checkpoints, cursor=None) + + def 
_create_parsed_path(self, path, training_run_id, checkpoint_type, + checkpoint_id) -> tinker_types.ParsedCheckpointTinkerPath: + return tinker_types.ParsedCheckpointTinkerPath( + tinker_path=path, + training_run_id=training_run_id, + checkpoint_type=checkpoint_type, + checkpoint_id=checkpoint_id, + ) + + def _create_weights_info(self, run_info: Dict[str, Any]) -> tinker_types.WeightsInfoResponse: + return tinker_types.WeightsInfoResponse(**run_info) + + def parse_tinker_path(self, tinker_path: str) -> Optional[tinker_types.ParsedCheckpointTinkerPath]: + return self.parse_path(tinker_path) diff --git a/src/twinkle/server/common/twinkle_io_utils.py b/src/twinkle/server/common/twinkle_io_utils.py new file mode 100644 index 00000000..7661f1dc --- /dev/null +++ b/src/twinkle/server/common/twinkle_io_utils.py @@ -0,0 +1,141 @@ +# Copyright (c) ModelScope Contributors. All rights reserved. +""" +Twinkle-specific IO managers for training runs and checkpoints. + +Uses ``twinkle_client.types.training`` models for all serialization and response construction. 
+""" +from datetime import datetime +from typing import Any, Dict, List, Optional + +from twinkle.server.utils.io_utils import (TRAIN_RUN_INFO_FILENAME, BaseCheckpointManager, BaseTrainingRunManager, + validate_ownership) +from twinkle_client.types.training import Checkpoint as TwinkleCheckpoint +from twinkle_client.types.training import (CheckpointsListResponse, CreateModelRequest, Cursor, + ParsedCheckpointTwinklePath) +from twinkle_client.types.training import TrainingRun as TwinkleTrainingRun +from twinkle_client.types.training import TrainingRunsResponse, WeightsInfoResponse + + +class TwinkleTrainingRunManager(BaseTrainingRunManager): + """Twinkle-specific training run manager.""" + + @property + def train_run_info_filename(self) -> str: + return TRAIN_RUN_INFO_FILENAME + + def _create_training_run(self, model_id: str, run_config: CreateModelRequest) -> Dict[str, Any]: + lora_config = run_config.lora_config + train_run_data = TwinkleTrainingRun( + training_run_id=model_id, + base_model=run_config.base_model, + model_owner=self.token, + is_lora=True if lora_config else False, + corrupted=False, + lora_rank=lora_config.rank if lora_config else None, + last_request_time=datetime.now(), + last_checkpoint=None, + last_sampler_checkpoint=None, + user_metadata=run_config.user_metadata) + + new_data = train_run_data.model_dump(mode='json') + if lora_config: + new_data['train_unembed'] = lora_config.train_unembed + new_data['train_mlp'] = lora_config.train_mlp + new_data['train_attn'] = lora_config.train_attn + return new_data + + def _parse_training_run(self, data: Dict[str, Any]) -> TwinkleTrainingRun: + return TwinkleTrainingRun(**data) + + def _create_training_runs_response(self, runs: List[TwinkleTrainingRun], limit: int, offset: int, + total: int) -> TrainingRunsResponse: + return TrainingRunsResponse(training_runs=runs, cursor=Cursor(limit=limit, offset=offset, total_count=total)) + + def get_with_permission(self, model_id: str) -> Optional[TwinkleTrainingRun]: 
+ run = self.get(model_id) + if run and validate_ownership(self.token, run.model_owner): + return run + return None + + +class TwinkleCheckpointManager(BaseCheckpointManager): + """Twinkle-specific checkpoint manager.""" + + @property + def path_prefix(self) -> str: + return 'twinkle://' + + @property + def path_field_name(self) -> str: + return 'twinkle_path' + + def _create_checkpoint(self, + checkpoint_id, + checkpoint_type, + path, + size_bytes, + public, + base_model=None, + is_lora=False, + lora_rank=None, + train_unembed=None, + train_mlp=None, + train_attn=None, + user_metadata=None) -> Dict[str, Any]: + checkpoint = TwinkleCheckpoint( + checkpoint_id=checkpoint_id, + checkpoint_type=checkpoint_type, + time=datetime.now(), + twinkle_path=path, + size_bytes=size_bytes, + public=public, + base_model=base_model, + is_lora=is_lora, + lora_rank=lora_rank, + train_unembed=train_unembed, + train_mlp=train_mlp, + train_attn=train_attn, + user_metadata=user_metadata) + return checkpoint.model_dump(mode='json') + + def _parse_checkpoint(self, data: Dict[str, Any]) -> TwinkleCheckpoint: + data = data.copy() + if 'tinker_path' in data and 'twinkle_path' not in data: + data['twinkle_path'] = data.pop('tinker_path') + elif 'twinkle_path' not in data and 'path' in data: + data['twinkle_path'] = data.pop('path') + return TwinkleCheckpoint(**data) + + def get(self, model_id: str, checkpoint_id: str) -> Optional[TwinkleCheckpoint]: + data = self._read_ckpt_info(model_id, checkpoint_id) + if not data: + return None + if 'twinkle_path' not in data and 'tinker_path' not in data and 'path' not in data: + if 'checkpoint_id' in data: + data = data.copy() + data['twinkle_path'] = f"{self.path_prefix}{model_id}/{data['checkpoint_id']}" + return self._parse_checkpoint(data) + + def _create_checkpoints_response(self, checkpoints: List[TwinkleCheckpoint]) -> CheckpointsListResponse: + return CheckpointsListResponse(checkpoints=checkpoints, cursor=None) + + def _create_parsed_path(self, 
path, training_run_id, checkpoint_type, checkpoint_id) -> ParsedCheckpointTwinklePath: + return ParsedCheckpointTwinklePath( + path=path, + twinkle_path=path, + training_run_id=training_run_id, + checkpoint_type=checkpoint_type, + checkpoint_id=checkpoint_id, + ) + + def _create_weights_info(self, run_info: Dict[str, Any]) -> WeightsInfoResponse: + return WeightsInfoResponse( + training_run_id=run_info.get('training_run_id', ''), + base_model=run_info.get('base_model', ''), + model_owner=run_info.get('model_owner', ''), + is_lora=run_info.get('is_lora', False), + lora_rank=run_info.get('lora_rank'), + ) + + def parse_twinkle_path(self, twinkle_path: str) -> Optional[ParsedCheckpointTwinklePath]: + return self.parse_path(twinkle_path) diff --git a/src/twinkle/server/gateway/twinkle_gateway_handlers.py b/src/twinkle/server/gateway/twinkle_gateway_handlers.py index 471878d9..a13484f6 100644 --- a/src/twinkle/server/gateway/twinkle_gateway_handlers.py +++ b/src/twinkle/server/gateway/twinkle_gateway_handlers.py @@ -10,7 +10,6 @@ from typing import TYPE_CHECKING, Callable if TYPE_CHECKING: - from twinkle.server.utils.state.server_state import ServerStateProxy from .server import GatewayServer from twinkle.server.common.io_utils import create_checkpoint_manager, create_training_run_manager, validate_user_path diff --git a/src/twinkle/server/processor/app.py b/src/twinkle/server/processor/app.py index 68cddc20..fb55b453 100644 --- a/src/twinkle/server/processor/app.py +++ b/src/twinkle/server/processor/app.py @@ -108,7 +108,7 @@ def handle_processor_count(self, token: str, add: bool): if cur_count <= 0: self.state.pop_config(user_key) - @app.post('/create') + @app.post('/twinkle/create') def create(self, request: Request, body: ProcessorCreateRequest): processor_type_name = body.processor_type class_type = body.class_type @@ -142,7 +142,7 @@ def create(self, request: Request, body: ProcessorCreateRequest): self.resource_records[processor_id] = 0 return {'processor_id': 
'pid:' + processor_id} - @app.post('/heartbeat') + @app.post('/twinkle/heartbeat') def heartbeat(self, body: ProcessorHeartbeatRequest): processor_ids = body.processor_id.split(',') for _id in processor_ids: @@ -150,7 +150,7 @@ def heartbeat(self, body: ProcessorHeartbeatRequest): self.resource_records[_id] = 0 return {'status': 'ok'} - @app.post('/call') + @app.post('/twinkle/call') def call(self, body: ProcessorCallRequest): processor_id = body.processor_id function_name = body.function diff --git a/src/twinkle_client/dataloader/dataloader.py b/src/twinkle_client/dataloader/dataloader.py index 3cd2b564..8178163b 100644 --- a/src/twinkle_client/dataloader/dataloader.py +++ b/src/twinkle_client/dataloader/dataloader.py @@ -19,10 +19,10 @@ class DataLoader(object): def __init__(self, dataset: Union[Dataset, Callable], **kwargs): from twinkle_client.http import get_base_url - self.server_url = get_base_url() + self.server_url = f'{get_base_url()}/processors/twinkle' response = http_post( - url=f'{self.server_url}/processors/create', + url=f'{self.server_url}/create', json_data={ 'processor_type': 'dataloader', 'class_type': 'DataLoader', @@ -42,7 +42,7 @@ def __del__(self): def __len__(self): response = http_post( - url=f'{self.server_url}/processors/call', + url=f'{self.server_url}/call', json_data={ 'processor_id': self.processor_id, 'function': '__len__', @@ -55,7 +55,7 @@ def __len__(self): def set_processor(self, processor_cls: Union[Type[InputProcessor], str, InputProcessor, Callable], **kwargs): response = http_post( - url=f'{self.server_url}/processors/call', + url=f'{self.server_url}/call', json_data={ 'processor_id': self.processor_id, 'function': 'set_processor', @@ -69,7 +69,7 @@ def set_processor(self, processor_cls: Union[Type[InputProcessor], str, InputPro def __iter__(self): response = http_post( - url=f'{self.server_url}/processors/call', + url=f'{self.server_url}/call', json_data={ 'processor_id': self.processor_id, 'function': '__iter__', @@ -81,7 
+81,7 @@ def __iter__(self): def __next__(self): response = http_post( - url=f'{self.server_url}/processors/call', + url=f'{self.server_url}/call', json_data={ 'processor_id': self.processor_id, 'function': '__next__', diff --git a/src/twinkle_client/dataset/base.py b/src/twinkle_client/dataset/base.py index 3d5b5062..5effa069 100644 --- a/src/twinkle_client/dataset/base.py +++ b/src/twinkle_client/dataset/base.py @@ -22,10 +22,10 @@ class Dataset(object): def __init__(self, dataset_meta: DatasetMeta, **kwargs): from twinkle_client.http import get_base_url - self.server_url = get_base_url() + self.server_url = f'{get_base_url()}/processors/twinkle' response = http_post( - url=f'{self.server_url}/processors/create', + url=f'{self.server_url}/create', json_data={ 'processor_type': 'dataset', 'class_type': 'Dataset', @@ -45,7 +45,7 @@ def __del__(self): def set_template(self, template_func: Union[Template, Type[Template], str], **kwargs): response = http_post( - url=f'{self.server_url}/processors/call', + url=f'{self.server_url}/call', json_data={ 'processor_id': self.processor_id, 'function': 'set_template', @@ -59,7 +59,7 @@ def set_template(self, template_func: Union[Template, Type[Template], str], **kw def encode(self, add_generation_prompt: bool = False, **kwargs): response = http_post( - url=f'{self.server_url}/processors/call', + url=f'{self.server_url}/call', json_data={ 'processor_id': self.processor_id, 'function': 'encode', @@ -73,7 +73,7 @@ def encode(self, add_generation_prompt: bool = False, **kwargs): def check(self, **kwargs): response = http_post( - url=f'{self.server_url}/processors/call', + url=f'{self.server_url}/call', json_data={ 'processor_id': self.processor_id, 'function': 'check', @@ -87,7 +87,7 @@ def check(self, **kwargs): def map(self, preprocess_func: Union[Preprocessor, Callable, str, Type[Preprocessor]], dataset_meta: DatasetMeta = None, init_args: Dict[str, Any] = None, **kwargs): response = http_post( - 
url=f'{self.server_url}/processors/call', + url=f'{self.server_url}/call', json_data={ 'processor_id': self.processor_id, 'function': 'map', @@ -101,7 +101,7 @@ def map(self, preprocess_func: Union[Preprocessor, Callable, str, Type[Preproces def filter(self, filter_func: Union[Callable, str, Type[DataFilter], DataFilter], dataset_meta: DatasetMeta = None, init_args: Dict[str, Any] = None, **kwargs): response = http_post( - url=f'{self.server_url}/processors/call', + url=f'{self.server_url}/call', json_data={ 'processor_id': self.processor_id, 'function': 'filter', @@ -115,7 +115,7 @@ def filter(self, filter_func: Union[Callable, str, Type[DataFilter], DataFilter] def add_dataset(self, dataset_meta: DatasetMeta, **kwargs): response = http_post( - url=f'{self.server_url}/processors/call', + url=f'{self.server_url}/call', json_data={ 'processor_id': self.processor_id, 'function': 'add_dataset', @@ -129,7 +129,7 @@ def add_dataset(self, dataset_meta: DatasetMeta, **kwargs): def mix_dataset(self, interleave = True): response = http_post( - url=f'{self.server_url}/processors/call', + url=f'{self.server_url}/call', json_data={ 'processor_id': self.processor_id, 'function': 'mix_dataset', @@ -142,7 +142,7 @@ def mix_dataset(self, interleave = True): def __getitem__(self, idx): response = http_post( - url=f'{self.server_url}/processors/call', + url=f'{self.server_url}/call', json_data={ 'processor_id': self.processor_id, 'function': '__getitem__', @@ -155,7 +155,7 @@ def __getitem__(self, idx): def __len__(self): response = http_post( - url=f'{self.server_url}/processors/call', + url=f'{self.server_url}/call', json_data={ 'processor_id': self.processor_id, 'function': '__len__', diff --git a/src/twinkle_client/dataset/iterable_dataset.py b/src/twinkle_client/dataset/iterable_dataset.py index 347d1012..3bc3fe6c 100644 --- a/src/twinkle_client/dataset/iterable_dataset.py +++ b/src/twinkle_client/dataset/iterable_dataset.py @@ -19,10 +19,10 @@ class 
IterableDataset(IterableDataset): def __init__(self, dataset_meta: DatasetMeta, **kwargs): from twinkle_client.http import get_base_url - self.server_url = get_base_url() + self.server_url = f'{get_base_url()}/processors/twinkle' response = http_post( - url=f'{self.server_url}/processors/create', + url=f'{self.server_url}/create', json_data={ 'processor_type': 'dataset', 'class_type': 'IterableDataset', @@ -42,7 +42,7 @@ def __del__(self): def add_dataset(self, dataset_meta: DatasetMeta, **kwargs): response = http_post( - url=f'{self.server_url}/processors/call', + url=f'{self.server_url}/call', json_data={ 'processor_id': self.processor_id, 'function': 'add_dataset', @@ -56,7 +56,7 @@ def add_dataset(self, dataset_meta: DatasetMeta, **kwargs): def __len__(self): response = http_post( - url=f'{self.server_url}/processors/call', + url=f'{self.server_url}/call', json_data={ 'processor_id': self.processor_id, 'function': '__len__', @@ -69,7 +69,7 @@ def __len__(self): def __getitem__(self, idx): response = http_post( - url=f'{self.server_url}/processors/call', + url=f'{self.server_url}/call', json_data={ 'processor_id': self.processor_id, 'function': '__getitem__', @@ -82,7 +82,7 @@ def __getitem__(self, idx): def __iter__(self): response = http_post( - url=f'{self.server_url}/processors/call', + url=f'{self.server_url}/call', json_data={ 'processor_id': self.processor_id, 'function': '__iter__', @@ -94,7 +94,7 @@ def __iter__(self): def __next__(self): response = http_post( - url=f'{self.server_url}/processors/call', + url=f'{self.server_url}/call', json_data={ 'processor_id': self.processor_id, 'function': '__next__', diff --git a/src/twinkle_client/dataset/iterable_packing_dataset.py b/src/twinkle_client/dataset/iterable_packing_dataset.py index ce2d918d..9e15b52a 100644 --- a/src/twinkle_client/dataset/iterable_packing_dataset.py +++ b/src/twinkle_client/dataset/iterable_packing_dataset.py @@ -21,10 +21,10 @@ class IterablePackingDataset(IterableDataset): def 
__init__(self, dataset_meta: DatasetMeta, packing_interval: int = 128, packing_num_proc: int = 1, cyclic: bool = False, **kwargs): from twinkle_client.http import get_base_url - self.server_url = get_base_url() + self.server_url = f'{get_base_url()}/processors/twinkle' response = http_post( - url=f'{self.server_url}/processors/create', + url=f'{self.server_url}/create', json_data={ 'processor_type': 'dataset', 'class_type': 'IterablePackingDataset', @@ -44,7 +44,7 @@ def __del__(self): def set_template(self, template_cls: Union[Type[Template], str, Template], **kwargs): response = http_post( - url=f'{self.server_url}/processors/call', + url=f'{self.server_url}/call', json_data={ 'processor_id': self.processor_id, 'function': 'set_template', @@ -58,7 +58,7 @@ def set_template(self, template_cls: Union[Type[Template], str, Template], **kwa def pack_dataset(self): response = http_post( - url=f'{self.server_url}/processors/call', + url=f'{self.server_url}/call', json_data={ 'processor_id': self.processor_id, 'function': 'pack_dataset', @@ -71,7 +71,7 @@ def pack_dataset(self): def __iter__(self): response = http_post( - url=f'{self.server_url}/processors/call', + url=f'{self.server_url}/call', json_data={ 'processor_id': self.processor_id, 'function': '__iter__', @@ -83,7 +83,7 @@ def __iter__(self): def __next__(self): response = http_post( - url=f'{self.server_url}/processors/call', + url=f'{self.server_url}/call', json_data={ 'processor_id': self.processor_id, 'function': '__next__', diff --git a/src/twinkle_client/dataset/lazy_dataset.py b/src/twinkle_client/dataset/lazy_dataset.py index ce8178b1..106fd3a9 100644 --- a/src/twinkle_client/dataset/lazy_dataset.py +++ b/src/twinkle_client/dataset/lazy_dataset.py @@ -19,10 +19,10 @@ class LazyDataset(Dataset): def __init__(self, dataset_meta: DatasetMeta, **kwargs): from twinkle_client.http import get_base_url - self.server_url = get_base_url() + self.server_url = f'{get_base_url()}/processors/twinkle' response = 
http_post( - url=f'{self.server_url}/processors/create', + url=f'{self.server_url}/create', json_data={ 'processor_type': 'dataset', 'class_type': 'LazyDataset', @@ -42,7 +42,7 @@ def __del__(self): def encode(self, **kwargs): response = http_post( - url=f'{self.server_url}/processors/call', + url=f'{self.server_url}/call', json_data={ 'processor_id': self.processor_id, 'function': 'encode', @@ -56,7 +56,7 @@ def encode(self, **kwargs): def check(self, **kwargs): response = http_post( - url=f'{self.server_url}/processors/call', + url=f'{self.server_url}/call', json_data={ 'processor_id': self.processor_id, 'function': 'check', @@ -70,7 +70,7 @@ def check(self, **kwargs): def __getitem__(self, idx): response = http_post( - url=f'{self.server_url}/processors/call', + url=f'{self.server_url}/call', json_data={ 'processor_id': self.processor_id, 'function': '__getitem__', @@ -83,7 +83,7 @@ def __getitem__(self, idx): def __len__(self): response = http_post( - url=f'{self.server_url}/processors/call', + url=f'{self.server_url}/call', json_data={ 'processor_id': self.processor_id, 'function': '__len__', diff --git a/src/twinkle_client/dataset/packing_dataset.py b/src/twinkle_client/dataset/packing_dataset.py index 0d91546f..37cb36ca 100644 --- a/src/twinkle_client/dataset/packing_dataset.py +++ b/src/twinkle_client/dataset/packing_dataset.py @@ -19,10 +19,10 @@ class PackingDataset(Dataset): def __init__(self, dataset_meta: DatasetMeta, packing_num_proc: int = 1, **kwargs): from twinkle_client.http import get_base_url - self.server_url = get_base_url() + self.server_url = f'{get_base_url()}/processors/twinkle' response = http_post( - url=f'{self.server_url}/processors/create', + url=f'{self.server_url}/create', json_data={ 'processor_type': 'dataset', 'class_type': 'PackingDataset', @@ -42,7 +42,7 @@ def __del__(self): def pack_dataset(self): response = http_post( - url=f'{self.server_url}/processors/call', + url=f'{self.server_url}/call', json_data={ 'processor_id': 
self.processor_id, 'function': 'pack_dataset', @@ -55,7 +55,7 @@ def pack_dataset(self): def __getitem__(self, index): response = http_post( - url=f'{self.server_url}/processors/call', + url=f'{self.server_url}/call', json_data={ 'processor_id': self.processor_id, 'function': '__getitem__', @@ -68,7 +68,7 @@ def __getitem__(self, index): def __len__(self): response = http_post( - url=f'{self.server_url}/processors/call', + url=f'{self.server_url}/call', json_data={ 'processor_id': self.processor_id, 'function': '__len__', diff --git a/src/twinkle_client/http/heartbeat.py b/src/twinkle_client/http/heartbeat.py index 4a42f75a..c348a4f6 100644 --- a/src/twinkle_client/http/heartbeat.py +++ b/src/twinkle_client/http/heartbeat.py @@ -4,7 +4,7 @@ from typing import Callable, Dict, Optional, Set from .http_utils import http_post -from .utils import TWINKLE_SERVER_URL +from .utils import get_base_url class HeartbeatManager: @@ -33,7 +33,6 @@ def __init__(self): return self._initialized = True - self.server_url = TWINKLE_SERVER_URL # Processor heartbeat management self.processor_ids: Set[str] = set() @@ -52,7 +51,7 @@ def __init__(self): def processor_heartbeat_func(self, processor_id_list: str): response = http_post( - url=f'{self.server_url}/processors/heartbeat', json_data={'processor_id': processor_id_list}) + url=f'{get_base_url()}/processors/twinkle/heartbeat', json_data={'processor_id': processor_id_list}) response.raise_for_status() def register_processor(self, processor_id: str): diff --git a/src/twinkle_client/http/http_utils.py b/src/twinkle_client/http/http_utils.py index 6d86c1c7..74f3c3bd 100644 --- a/src/twinkle_client/http/http_utils.py +++ b/src/twinkle_client/http/http_utils.py @@ -1,5 +1,4 @@ import requests -from numbers import Number from typing import Any, Callable, Dict, Mapping, Optional from .utils import get_api_key, get_base_url, get_request_id diff --git a/src/twinkle_client/manager.py b/src/twinkle_client/manager.py index 5419bf6f..874126f5 
100644 --- a/src/twinkle_client/manager.py +++ b/src/twinkle_client/manager.py @@ -31,10 +31,10 @@ class TwinkleClient: base_url: Base URL of the Twinkle server (e.g., "http://localhost:8000"). api_key: API key for authentication. If not provided, uses TWINKLE_SERVER_TOKEN environment variable - route_prefix: API route prefix (default: "/server") + route_prefix: API route prefix (default: "/api/v1/twinkle") """ - def __init__(self, base_url: str = None, api_key: str = None, route_prefix: str | None = '/server'): + def __init__(self, base_url: str = None, api_key: str = None, route_prefix: str | None = '/api/v1/twinkle'): self.base_url = base_url self.api_key = api_key self.route_prefix = route_prefix.rstrip('/') if route_prefix else '' diff --git a/src/twinkle_client/model/multi_lora_transformers.py b/src/twinkle_client/model/multi_lora_transformers.py index f681c96b..04be1ade 100644 --- a/src/twinkle_client/model/multi_lora_transformers.py +++ b/src/twinkle_client/model/multi_lora_transformers.py @@ -30,7 +30,7 @@ def __init__(self, model_id: str, **kwargs): self.model_id = model_id if '://' in model_id: model_id = model_id.split('://')[1] - self.server_url = f'{self.server_url}/models/{model_id}' + self.server_url = f'{self.server_url}/models/{model_id}/twinkle' self.adapter_name = None response = http_post( url=f'{self.server_url}/create', diff --git a/src/twinkle_client/processor/base.py b/src/twinkle_client/processor/base.py index d59572a7..0dfa3aa6 100644 --- a/src/twinkle_client/processor/base.py +++ b/src/twinkle_client/processor/base.py @@ -19,10 +19,10 @@ class InputProcessor(object): def __init__(self, device_mesh: Optional[DeviceMesh] = None, padding_free: bool = False, framework: Literal['transformers', 'megatron'] = 'transformers', **kwargs): from twinkle_client.http import get_base_url - self.server_url = get_base_url() + self.server_url = f'{get_base_url()}/processors/twinkle' response = http_post( - url=f'{self.server_url}/processors/create', + 
url=f'{self.server_url}/create', json_data={ 'processor_type': 'processor', 'class_type': 'InputProcessor', @@ -42,7 +42,7 @@ def __del__(self): def __call__(self, inputs: Union[InputFeature, List[InputFeature]], **kwargs): response = http_post( - url=f'{self.server_url}/processors/call', + url=f'{self.server_url}/call', json_data={ 'processor_id': self.processor_id, 'function': '__call__', diff --git a/src/twinkle_client/sampler/vllm_sampler.py b/src/twinkle_client/sampler/vllm_sampler.py index 907881a4..93004779 100644 --- a/src/twinkle_client/sampler/vllm_sampler.py +++ b/src/twinkle_client/sampler/vllm_sampler.py @@ -30,7 +30,7 @@ def __init__(self, model_id: str, **kwargs): self.adapter_name = None if '://' in model_id: model_id = model_id.split('://')[1] - self.server_url = f'{self.server_url}/samplers/{model_id}' + self.server_url = f'{self.server_url}/samplers/{model_id}/twinkle' response = http_post( url=f'{self.server_url}/create', json_data=kwargs From eec089a9c01e4983f36dceaa244d109501402820 Mon Sep 17 00:00:00 2001 From: Yunnglin Date: Wed, 11 Mar 2026 19:31:45 +0800 Subject: [PATCH 08/24] update io --- client_tools/client_generator.py | 6 +- .../server/transformer/server_config.yaml | 23 +++++++ src/twinkle/server/common/__init__.py | 4 +- .../server/common/checkpoint_factory.py | 39 +++++++++++ src/twinkle/server/common/io_utils.py | 68 ------------------- ...inker_io_utils.py => tinker_checkpoint.py} | 4 +- ...nkle_io_utils.py => twinkle_checkpoint.py} | 34 +++++----- .../server/gateway/tinker_gateway_handlers.py | 2 +- .../gateway/twinkle_gateway_handlers.py | 3 +- .../server/model/backends/megatron_model.py | 2 +- .../model/backends/transformers_model.py | 2 +- src/twinkle/server/model/tinker_handlers.py | 2 +- src/twinkle/server/model/twinkle_handlers.py | 6 +- src/twinkle/server/sampler/tinker_handlers.py | 2 +- .../server/sampler/twinkle_handlers.py | 2 +- src/twinkle/server/types/__init__.py | 1 + src/twinkle/server/types/checkpoint.py | 23 
+++++++ src/twinkle/server/utils/__init__.py | 4 +- .../utils/{io_utils.py => checkpoint_base.py} | 40 ++++------- src/twinkle_client/dataloader/dataloader.py | 2 +- src/twinkle_client/dataset/base.py | 2 +- .../dataset/iterable_dataset.py | 2 +- .../dataset/iterable_packing_dataset.py | 2 +- src/twinkle_client/dataset/lazy_dataset.py | 2 +- src/twinkle_client/dataset/packing_dataset.py | 2 +- src/twinkle_client/http/utils.py | 6 +- src/twinkle_client/manager.py | 8 +-- .../model/multi_lora_transformers.py | 2 +- src/twinkle_client/processor/base.py | 2 +- src/twinkle_client/sampler/vllm_sampler.py | 2 +- src/twinkle_client/utils/patch_tinker.py | 4 +- 31 files changed, 153 insertions(+), 150 deletions(-) create mode 100644 src/twinkle/server/common/checkpoint_factory.py delete mode 100644 src/twinkle/server/common/io_utils.py rename src/twinkle/server/common/{tinker_io_utils.py => tinker_checkpoint.py} (96%) rename src/twinkle/server/common/{twinkle_io_utils.py => twinkle_checkpoint.py} (79%) create mode 100644 src/twinkle/server/types/__init__.py create mode 100644 src/twinkle/server/types/checkpoint.py rename src/twinkle/server/utils/{io_utils.py => checkpoint_base.py} (96%) diff --git a/client_tools/client_generator.py b/client_tools/client_generator.py index 858f4872..f4224894 100644 --- a/client_tools/client_generator.py +++ b/client_tools/client_generator.py @@ -347,7 +347,7 @@ class {class_name}({inheritance}): def __init__({init_params}): from twinkle_client.http import get_base_url - self.server_url = f'{{get_base_url()}}/processors/twinkle' + self.server_url = f'{{get_base_url()}}/processor/twinkle' response = http_post( url=f'{{self.server_url}}/create', json_data={{ @@ -466,7 +466,7 @@ def __init__(self, model_id: str, **kwargs): self.model_id = model_id if '://' in model_id: model_id = model_id.split('://')[1] - self.server_url = f'{self.server_url}/models/{model_id}/twinkle' + self.server_url = f'{self.server_url}/model/{model_id}/twinkle' 
self.adapter_name = None response = http_post( url=f'{self.server_url}/create', @@ -743,7 +743,7 @@ def __init__(self, model_id: str, **kwargs): self.adapter_name = None if '://' in model_id: model_id = model_id.split('://')[1] - self.server_url = f'{self.server_url}/samplers/{model_id}/twinkle' + self.server_url = f'{self.server_url}/sampler/{model_id}/twinkle' response = http_post( url=f'{self.server_url}/create', json_data=kwargs diff --git a/cookbook/client/server/transformer/server_config.yaml b/cookbook/client/server/transformer/server_config.yaml index e3cbfac3..c5c584ef 100644 --- a/cookbook/client/server/transformer/server_config.yaml +++ b/cookbook/client/server/transformer/server_config.yaml @@ -100,3 +100,26 @@ applications: # runtime_env: # env_vars: # TWINKLE_TRUST_REMOTE_CODE: "0" + + # 4. Processor Service - Runs dataset / input processing workers + - name: processor + route_prefix: /api/v1/processor + import_path: processor + args: + nproc_per_node: 2 # processor workers per node + ncpu_proc_per_node: 2 # CPU processes per node + device_group: + name: model + ranks: 2 + device_type: CPU + device_mesh: + device_type: CPU + dp_size: 2 # data parallel size + deployments: + - name: ProcessorManagement + autoscaling_config: + min_replicas: 1 + max_replicas: 1 + target_ongoing_requests: 128 + ray_actor_options: + num_cpus: 0.1 diff --git a/src/twinkle/server/common/__init__.py b/src/twinkle/server/common/__init__.py index 495ae39b..bb00e2bd 100644 --- a/src/twinkle/server/common/__init__.py +++ b/src/twinkle/server/common/__init__.py @@ -1,6 +1,6 @@ # Copyright (c) ModelScope Contributors. All rights reserved.
+from .checkpoint_factory import create_checkpoint_manager, create_training_run_manager from .datum import datum_to_input_feature, extract_rl_feature, input_feature_to_datum -from .io_utils import create_checkpoint_manager, create_training_run_manager, validate_ownership, validate_user_path from .router import StickyLoraRequestRouter from .serialize import deserialize_object, serialize_object @@ -10,8 +10,6 @@ 'input_feature_to_datum', 'create_checkpoint_manager', 'create_training_run_manager', - 'validate_user_path', - 'validate_ownership', 'StickyLoraRequestRouter', 'deserialize_object', 'serialize_object', diff --git a/src/twinkle/server/common/checkpoint_factory.py b/src/twinkle/server/common/checkpoint_factory.py new file mode 100644 index 00000000..cbb2f2c6 --- /dev/null +++ b/src/twinkle/server/common/checkpoint_factory.py @@ -0,0 +1,39 @@ +# Copyright (c) ModelScope Contributors. All rights reserved. +""" +Factory functions for creating checkpoint and training-run manager instances. + +Use these functions as the entry point rather than instantiating managers directly: + + from twinkle.server.common.checkpoint_factory import ( + create_checkpoint_manager, + create_training_run_manager, + ) +""" +from twinkle.server.common.tinker_checkpoint import TinkerCheckpointManager, TinkerTrainingRunManager +from twinkle.server.common.twinkle_checkpoint import TwinkleCheckpointManager, TwinkleTrainingRunManager + + +def create_training_run_manager(token: str, client_type: str = 'twinkle'): + """Create a TrainingRunManager for the given token. + + Args: + token: User authentication token. + client_type: 'tinker' or 'twinkle' (default 'twinkle'). + """ + if client_type == 'tinker': + return TinkerTrainingRunManager(token) + return TwinkleTrainingRunManager(token) + + +def create_checkpoint_manager(token: str, client_type: str = 'twinkle'): + """Create a CheckpointManager for the given token. + + Args: + token: User authentication token. 
+ client_type: 'tinker' or 'twinkle' (default 'twinkle'). + """ + if client_type == 'tinker': + run_mgr = TinkerTrainingRunManager(token) + return TinkerCheckpointManager(token, run_mgr) + run_mgr = TwinkleTrainingRunManager(token) + return TwinkleCheckpointManager(token, run_mgr) diff --git a/src/twinkle/server/common/io_utils.py b/src/twinkle/server/common/io_utils.py deleted file mode 100644 index 9c389de1..00000000 --- a/src/twinkle/server/common/io_utils.py +++ /dev/null @@ -1,68 +0,0 @@ -# Copyright (c) ModelScope Contributors. All rights reserved. -""" -Unified IO utilities for managing training runs and checkpoints. - -Manager implementations live in dedicated modules: - - ``tinker_io_utils`` : Tinker-specific managers (use tinker.types) - - ``twinkle_io_utils`` : Twinkle-specific managers (use twinkle_client.types.training) - -This module re-exports everything and provides the factory functions -``create_training_run_manager`` and ``create_checkpoint_manager``. -""" -from twinkle.server.common.tinker_io_utils import TinkerCheckpointManager, TinkerTrainingRunManager -from twinkle.server.common.twinkle_io_utils import TwinkleCheckpointManager, TwinkleTrainingRunManager -from twinkle.server.utils.io_utils import ResolvedLoadPath, validate_ownership, validate_user_path -# Re-export twinkle-native pydantic models from twinkle_client.types -from twinkle_client.types.training import Checkpoint as TwinkleCheckpoint -from twinkle_client.types.training import (CheckpointsListResponse, CreateModelRequest, Cursor, LoraConfig, - ParsedCheckpointTwinklePath) -from twinkle_client.types.training import TrainingRun as TwinkleTrainingRun -from twinkle_client.types.training import TrainingRunsResponse, WeightsInfoResponse - -__all__ = [ - 'create_checkpoint_manager', - 'create_training_run_manager', - 'validate_user_path', - 'validate_ownership', - 'ResolvedLoadPath', - 'Cursor', - 'TinkerTrainingRunManager', - 'TinkerCheckpointManager', - 'TwinkleTrainingRunManager', - 
'TwinkleCheckpointManager', - # Twinkle-native models (re-exported for convenience) - 'TwinkleCheckpoint', - 'TwinkleTrainingRun', - 'TrainingRunsResponse', - 'CheckpointsListResponse', - 'WeightsInfoResponse', - 'LoraConfig', - 'CreateModelRequest', - 'ParsedCheckpointTwinklePath', -] - - -def create_training_run_manager(token: str, client_type: str = 'twinkle'): - """Create a TrainingRunManager for the given token. - - Args: - token: User authentication token. - client_type: 'tinker' or 'twinkle' (default 'twinkle'). - """ - if client_type == 'tinker': - return TinkerTrainingRunManager(token) - return TwinkleTrainingRunManager(token) - - -def create_checkpoint_manager(token: str, client_type: str = 'twinkle'): - """Create a CheckpointManager for the given token. - - Args: - token: User authentication token. - client_type: 'tinker' or 'twinkle' (default 'twinkle'). - """ - if client_type == 'tinker': - run_mgr = TinkerTrainingRunManager(token) - return TinkerCheckpointManager(token, run_mgr) - run_mgr = TwinkleTrainingRunManager(token) - return TwinkleCheckpointManager(token, run_mgr) diff --git a/src/twinkle/server/common/tinker_io_utils.py b/src/twinkle/server/common/tinker_checkpoint.py similarity index 96% rename from src/twinkle/server/common/tinker_io_utils.py rename to src/twinkle/server/common/tinker_checkpoint.py index 90d49482..fa7e5a11 100644 --- a/src/twinkle/server/common/tinker_io_utils.py +++ b/src/twinkle/server/common/tinker_checkpoint.py @@ -1,6 +1,6 @@ # Copyright (c) ModelScope Contributors. All rights reserved. """ -Tinker-specific IO managers for training runs and checkpoints. +Tinker-specific checkpoint and training-run managers. Uses ``tinker.types`` models for all serialization and response construction. 
""" @@ -8,7 +8,7 @@ from tinker import types as tinker_types from typing import Any, Dict, List, Optional -from twinkle.server.utils.io_utils import TRAIN_RUN_INFO_FILENAME, BaseCheckpointManager, BaseTrainingRunManager +from twinkle.server.utils.checkpoint_base import TRAIN_RUN_INFO_FILENAME, BaseCheckpointManager, BaseTrainingRunManager class TinkerTrainingRunManager(BaseTrainingRunManager): diff --git a/src/twinkle/server/common/twinkle_io_utils.py b/src/twinkle/server/common/twinkle_checkpoint.py similarity index 79% rename from src/twinkle/server/common/twinkle_io_utils.py rename to src/twinkle/server/common/twinkle_checkpoint.py index 7661f1dc..4b77d581 100644 --- a/src/twinkle/server/common/twinkle_io_utils.py +++ b/src/twinkle/server/common/twinkle_checkpoint.py @@ -1,19 +1,17 @@ # Copyright (c) ModelScope Contributors. All rights reserved. """ -Twinkle-specific IO managers for training runs and checkpoints. +Twinkle-specific checkpoint and training-run managers. Uses ``twinkle_client.types.training`` models for all serialization and response construction. 
""" from datetime import datetime from typing import Any, Dict, List, Optional -from twinkle.server.utils.io_utils import (TRAIN_RUN_INFO_FILENAME, BaseCheckpointManager, BaseTrainingRunManager, - validate_ownership) -from twinkle_client.types.training import Checkpoint as TwinkleCheckpoint -from twinkle_client.types.training import (CheckpointsListResponse, CreateModelRequest, Cursor, - ParsedCheckpointTwinklePath) -from twinkle_client.types.training import TrainingRun as TwinkleTrainingRun -from twinkle_client.types.training import TrainingRunsResponse, WeightsInfoResponse +from twinkle.server.utils.checkpoint_base import (TRAIN_RUN_INFO_FILENAME, BaseCheckpointManager, + BaseTrainingRunManager, validate_ownership) +from twinkle_client.types.training import (Checkpoint, CheckpointsListResponse, CreateModelRequest, Cursor, + ParsedCheckpointTwinklePath, TrainingRun, TrainingRunsResponse, + WeightsInfoResponse) class TwinkleTrainingRunManager(BaseTrainingRunManager): @@ -25,7 +23,7 @@ def train_run_info_filename(self) -> str: def _create_training_run(self, model_id: str, run_config: CreateModelRequest) -> Dict[str, Any]: lora_config = run_config.lora_config - train_run_data = TwinkleTrainingRun( + train_run_data = TrainingRun( training_run_id=model_id, base_model=run_config.base_model, model_owner=self.token, @@ -44,14 +42,14 @@ def _create_training_run(self, model_id: str, run_config: CreateModelRequest) -> new_data['train_attn'] = lora_config.train_attn return new_data - def _parse_training_run(self, data: Dict[str, Any]) -> TwinkleTrainingRun: - return TwinkleTrainingRun(**data) + def _parse_training_run(self, data: Dict[str, Any]) -> TrainingRun: + return TrainingRun(**data) - def _create_training_runs_response(self, runs: List[TwinkleTrainingRun], limit: int, offset: int, + def _create_training_runs_response(self, runs: List[TrainingRun], limit: int, offset: int, total: int) -> TrainingRunsResponse: return TrainingRunsResponse(training_runs=runs, 
cursor=Cursor(limit=limit, offset=offset, total_count=total)) - def get_with_permission(self, model_id: str) -> Optional[TwinkleTrainingRun]: + def get_with_permission(self, model_id: str) -> Optional[TrainingRun]: run = self.get(model_id) if run and validate_ownership(self.token, run.model_owner): return run @@ -82,7 +80,7 @@ def _create_checkpoint(self, train_mlp=None, train_attn=None, user_metadata=None) -> Dict[str, Any]: - checkpoint = TwinkleCheckpoint( + checkpoint = Checkpoint( checkpoint_id=checkpoint_id, checkpoint_type=checkpoint_type, time=datetime.now(), @@ -98,15 +96,15 @@ def _create_checkpoint(self, user_metadata=user_metadata) return checkpoint.model_dump(mode='json') - def _parse_checkpoint(self, data: Dict[str, Any]) -> TwinkleCheckpoint: + def _parse_checkpoint(self, data: Dict[str, Any]) -> Checkpoint: data = data.copy() if 'tinker_path' in data and 'twinkle_path' not in data: data['twinkle_path'] = data.pop('tinker_path') elif 'twinkle_path' not in data and 'path' in data: data['twinkle_path'] = data.pop('path') - return TwinkleCheckpoint(**data) + return Checkpoint(**data) - def get(self, model_id: str, checkpoint_id: str) -> Optional[TwinkleCheckpoint]: + def get(self, model_id: str, checkpoint_id: str) -> Optional[Checkpoint]: data = self._read_ckpt_info(model_id, checkpoint_id) if not data: return None @@ -116,7 +114,7 @@ def get(self, model_id: str, checkpoint_id: str) -> Optional[TwinkleCheckpoint]: data['twinkle_path'] = f"{self.path_prefix}{model_id}/{data['checkpoint_id']}" return self._parse_checkpoint(data) - def _create_checkpoints_response(self, checkpoints: List[TwinkleCheckpoint]) -> CheckpointsListResponse: + def _create_checkpoints_response(self, checkpoints: List[Checkpoint]) -> CheckpointsListResponse: return CheckpointsListResponse(checkpoints=checkpoints, cursor=None) def _create_parsed_path(self, path, training_run_id, checkpoint_type, checkpoint_id) -> ParsedCheckpointTwinklePath: diff --git 
a/src/twinkle/server/gateway/tinker_gateway_handlers.py b/src/twinkle/server/gateway/tinker_gateway_handlers.py index e3a3b503..b545ad63 100644 --- a/src/twinkle/server/gateway/tinker_gateway_handlers.py +++ b/src/twinkle/server/gateway/tinker_gateway_handlers.py @@ -17,7 +17,7 @@ from .server import GatewayServer from twinkle.hub import HubOperation -from twinkle.server.common.io_utils import create_checkpoint_manager, create_training_run_manager +from twinkle.server.common.checkpoint_factory import create_checkpoint_manager, create_training_run_manager from twinkle.server.utils.task_queue import QueueState from twinkle.server.utils.validation import get_token_from_request from twinkle.utils.logger import get_logger diff --git a/src/twinkle/server/gateway/twinkle_gateway_handlers.py b/src/twinkle/server/gateway/twinkle_gateway_handlers.py index a13484f6..8a22cc6f 100644 --- a/src/twinkle/server/gateway/twinkle_gateway_handlers.py +++ b/src/twinkle/server/gateway/twinkle_gateway_handlers.py @@ -12,7 +12,8 @@ if TYPE_CHECKING: from .server import GatewayServer -from twinkle.server.common.io_utils import create_checkpoint_manager, create_training_run_manager, validate_user_path +from twinkle.server.common.checkpoint_factory import create_checkpoint_manager, create_training_run_manager +from twinkle.server.utils.checkpoint_base import validate_user_path from twinkle.server.utils.validation import get_token_from_request from twinkle.utils.logger import get_logger from twinkle_client.types.server import DeleteCheckpointResponse, HealthResponse, WeightsInfoRequest diff --git a/src/twinkle/server/model/backends/megatron_model.py b/src/twinkle/server/model/backends/megatron_model.py index 29570c42..ae47cfc1 100644 --- a/src/twinkle/server/model/backends/megatron_model.py +++ b/src/twinkle/server/model/backends/megatron_model.py @@ -118,7 +118,7 @@ def load(self, checkpoint_dir: str, **kwargs): token = kwargs.pop('token', None) if not token: raise ValueError('Token is 
required for loading checkpoints') - from twinkle.server.common.io_utils import create_checkpoint_manager + from twinkle.server.common.checkpoint_factory import create_checkpoint_manager checkpoint_manager = create_checkpoint_manager(token, client_type='tinker') resolved = checkpoint_manager.resolve_load_path(checkpoint_dir) if resolved.is_twinkle_path: diff --git a/src/twinkle/server/model/backends/transformers_model.py b/src/twinkle/server/model/backends/transformers_model.py index ef4194bf..6cbc401b 100644 --- a/src/twinkle/server/model/backends/transformers_model.py +++ b/src/twinkle/server/model/backends/transformers_model.py @@ -118,7 +118,7 @@ def load(self, checkpoint_dir: str, **kwargs): token = kwargs.pop('token', None) if not token: raise ValueError('Token is required for loading checkpoints') - from twinkle.server.common.io_utils import create_checkpoint_manager + from twinkle.server.common.checkpoint_factory import create_checkpoint_manager checkpoint_manager = create_checkpoint_manager(token, client_type='tinker') resolved = checkpoint_manager.resolve_load_path(checkpoint_dir) if resolved.is_twinkle_path: diff --git a/src/twinkle/server/model/tinker_handlers.py b/src/twinkle/server/model/tinker_handlers.py index ca59b808..b89a5936 100644 --- a/src/twinkle/server/model/tinker_handlers.py +++ b/src/twinkle/server/model/tinker_handlers.py @@ -16,7 +16,7 @@ if TYPE_CHECKING: from .app import ModelManagement -from twinkle.server.common.io_utils import create_checkpoint_manager, create_training_run_manager +from twinkle.server.common.checkpoint_factory import create_checkpoint_manager, create_training_run_manager from twinkle.utils.logger import get_logger logger = get_logger() diff --git a/src/twinkle/server/model/twinkle_handlers.py b/src/twinkle/server/model/twinkle_handlers.py index c59928f4..43dc8c73 100644 --- a/src/twinkle/server/model/twinkle_handlers.py +++ b/src/twinkle/server/model/twinkle_handlers.py @@ -17,7 +17,7 @@ from .app import 
ModelManagement from twinkle.data_format import InputFeature, Trajectory -from twinkle.server.common.io_utils import create_checkpoint_manager, create_training_run_manager +from twinkle.server.common.checkpoint_factory import create_checkpoint_manager, create_training_run_manager from twinkle.server.common.serialize import deserialize_object from twinkle.utils.logger import get_logger from twinkle_client.types.model import (AdapterRequest, AddAdapterRequest, CalculateMetricRequest, CreateRequest, @@ -305,8 +305,8 @@ async def _task(): training_run_manager = create_training_run_manager(token, client_type='twinkle') self.register_adapter(adapter_name, token) self.model.add_adapter_to_model(adapter_name, config, **extra_kwargs) - from twinkle.server.common.io_utils import CreateModelRequest - from twinkle.server.common.io_utils import LoraConfig as IoLoraConfig + from twinkle_client.types.training import CreateModelRequest + from twinkle_client.types.training import LoraConfig as IoLoraConfig lora_config = None if isinstance(config, LoraConfig): lora_config = IoLoraConfig(rank=config.r, train_unembed=False, train_mlp=True, train_attn=True) diff --git a/src/twinkle/server/sampler/tinker_handlers.py b/src/twinkle/server/sampler/tinker_handlers.py index 21bee277..4cd574be 100644 --- a/src/twinkle/server/sampler/tinker_handlers.py +++ b/src/twinkle/server/sampler/tinker_handlers.py @@ -16,7 +16,7 @@ from .app import SamplerManagement from twinkle.data_format import SamplingParams -from twinkle.server.common.io_utils import create_checkpoint_manager +from twinkle.server.common.checkpoint_factory import create_checkpoint_manager from twinkle.utils.logger import get_logger logger = get_logger() diff --git a/src/twinkle/server/sampler/twinkle_handlers.py b/src/twinkle/server/sampler/twinkle_handlers.py index 860561e4..f2aea1cb 100644 --- a/src/twinkle/server/sampler/twinkle_handlers.py +++ b/src/twinkle/server/sampler/twinkle_handlers.py @@ -55,7 +55,7 @@ def sample(request: 
Request, body: SampleRequest, full_adapter_name = _get_twinkle_sampler_adapter_name(request, adapter_name) or '' if body.adapter_uri: - from twinkle.server.common.io_utils import create_checkpoint_manager + from twinkle.server.common.checkpoint_factory import create_checkpoint_manager from twinkle.server.utils.validation import get_token_from_request token = get_token_from_request(request) checkpoint_manager = create_checkpoint_manager(token, client_type='twinkle') diff --git a/src/twinkle/server/types/__init__.py b/src/twinkle/server/types/__init__.py new file mode 100644 index 00000000..85b3e739 --- /dev/null +++ b/src/twinkle/server/types/__init__.py @@ -0,0 +1 @@ +# Copyright (c) ModelScope Contributors. All rights reserved. diff --git a/src/twinkle/server/types/checkpoint.py b/src/twinkle/server/types/checkpoint.py new file mode 100644 index 00000000..fe89cb41 --- /dev/null +++ b/src/twinkle/server/types/checkpoint.py @@ -0,0 +1,23 @@ +# Copyright (c) ModelScope Contributors. All rights reserved. +""" +Server-specific Pydantic models for checkpoint path resolution. +""" +from pydantic import BaseModel +from typing import Optional + + +class ResolvedLoadPath(BaseModel): + """Result of resolving a load path. 
+ + Attributes: + checkpoint_name: The name of the checkpoint (e.g., 'step-8' or hub model id) + checkpoint_dir: The directory containing the checkpoint, or None if loading from hub + is_twinkle_path: Whether the path was a twinkle:// path + training_run_id: The training run ID (only set for twinkle:// paths) + checkpoint_id: The checkpoint ID (only set for twinkle:// paths) + """ + checkpoint_name: str + checkpoint_dir: Optional[str] = None + is_twinkle_path: bool = False + training_run_id: Optional[str] = None + checkpoint_id: Optional[str] = None diff --git a/src/twinkle/server/utils/__init__.py b/src/twinkle/server/utils/__init__.py index dca07caf..f9855709 100644 --- a/src/twinkle/server/utils/__init__.py +++ b/src/twinkle/server/utils/__init__.py @@ -1,7 +1,7 @@ # Copyright (c) ModelScope Contributors. All rights reserved. from .adapter_manager import AdapterManagerMixin +from .checkpoint_base import (TRAIN_RUN_INFO_FILENAME, TWINKLE_DEFAULT_SAVE_DIR, BaseCheckpointManager, BaseFileManager, + BaseTrainingRunManager) from .device_utils import auto_fill_device_group_visible_devices, wrap_builder_with_device_group_env -from .io_utils import (TRAIN_RUN_INFO_FILENAME, TWINKLE_DEFAULT_SAVE_DIR, BaseCheckpointManager, BaseFileManager, - BaseTrainingRunManager) from .rate_limiter import RateLimiter from .task_queue import QueueState, TaskQueueConfig, TaskQueueMixin, TaskStatus diff --git a/src/twinkle/server/utils/io_utils.py b/src/twinkle/server/utils/checkpoint_base.py similarity index 96% rename from src/twinkle/server/utils/io_utils.py rename to src/twinkle/server/utils/checkpoint_base.py index 1a95b6c2..1cf49b97 100644 --- a/src/twinkle/server/utils/io_utils.py +++ b/src/twinkle/server/utils/checkpoint_base.py @@ -1,10 +1,16 @@ # Copyright (c) ModelScope Contributors. All rights reserved. """ -Base IO utilities for managing training runs and checkpoints. +Base infrastructure for checkpoint and training-run persistence. 
-This module provides abstract base classes that encapsulate common logic for -file-based storage of training run metadata and checkpoint information. -Both tinker and twinkle servers inherit from these classes. +Provides: +- Constants and path-hashing utilities +- Permission-check helpers (``validate_user_path``, ``validate_ownership``) +- Internal Pydantic base specs used as type constraints for the generic managers +- Abstract base managers: ``BaseTrainingRunManager``, ``BaseCheckpointManager`` + +Concrete implementations live in: + - ``twinkle.server.common.tinker_checkpoint`` + - ``twinkle.server.common.twinkle_checkpoint`` """ import hashlib import hmac @@ -20,6 +26,7 @@ from twinkle import get_logger from twinkle.hub import HubOperation +from twinkle.server.types.checkpoint import ResolvedLoadPath logger = get_logger() @@ -41,13 +48,7 @@ def _hash_token(token: str) -> str: return hmac.new(_TOKEN_SALT, token.encode('utf-8'), hashlib.sha256).hexdigest()[:16] -# ----- Common Pydantic Models ----- - - -class Cursor(BaseModel): - limit: int - offset: int - total_count: int +# ----- Internal Pydantic Base Specs ----- class BaseCheckpoint(BaseModel): @@ -104,23 +105,6 @@ class BaseParsedCheckpointPath(BaseModel): checkpoint_id: str -class ResolvedLoadPath(BaseModel): - """Result of resolving a load path. 
- - Attributes: - checkpoint_name: The name of the checkpoint (e.g., 'step-8' or hub model id) - checkpoint_dir: The directory containing the checkpoint, or None if loading from hub - is_twinkle_path: Whether the path was a twinkle:// path - training_run_id: The training run ID (only set for twinkle:// paths) - checkpoint_id: The checkpoint ID (only set for twinkle:// paths) - """ - checkpoint_name: str - checkpoint_dir: Optional[str] = None - is_twinkle_path: bool = False - training_run_id: Optional[str] = None - checkpoint_id: Optional[str] = None - - class BaseWeightsInfoResponse(BaseModel): """Base model for weights info response.""" training_run_id: str diff --git a/src/twinkle_client/dataloader/dataloader.py b/src/twinkle_client/dataloader/dataloader.py index 8178163b..477149c5 100644 --- a/src/twinkle_client/dataloader/dataloader.py +++ b/src/twinkle_client/dataloader/dataloader.py @@ -20,7 +20,7 @@ class DataLoader(object): def __init__(self, dataset: Union[Dataset, Callable], **kwargs): from twinkle_client.http import get_base_url - self.server_url = f'{get_base_url()}/processors/twinkle' + self.server_url = f'{get_base_url()}/processor/twinkle' response = http_post( url=f'{self.server_url}/create', json_data={ diff --git a/src/twinkle_client/dataset/base.py b/src/twinkle_client/dataset/base.py index 5effa069..d2b57ab4 100644 --- a/src/twinkle_client/dataset/base.py +++ b/src/twinkle_client/dataset/base.py @@ -23,7 +23,7 @@ class Dataset(object): def __init__(self, dataset_meta: DatasetMeta, **kwargs): from twinkle_client.http import get_base_url - self.server_url = f'{get_base_url()}/processors/twinkle' + self.server_url = f'{get_base_url()}/processor/twinkle' response = http_post( url=f'{self.server_url}/create', json_data={ diff --git a/src/twinkle_client/dataset/iterable_dataset.py b/src/twinkle_client/dataset/iterable_dataset.py index 3bc3fe6c..1aa0b006 100644 --- a/src/twinkle_client/dataset/iterable_dataset.py +++ 
b/src/twinkle_client/dataset/iterable_dataset.py @@ -20,7 +20,7 @@ class IterableDataset(IterableDataset): def __init__(self, dataset_meta: DatasetMeta, **kwargs): from twinkle_client.http import get_base_url - self.server_url = f'{get_base_url()}/processors/twinkle' + self.server_url = f'{get_base_url()}/processor/twinkle' response = http_post( url=f'{self.server_url}/create', json_data={ diff --git a/src/twinkle_client/dataset/iterable_packing_dataset.py b/src/twinkle_client/dataset/iterable_packing_dataset.py index 9e15b52a..d0dd1582 100644 --- a/src/twinkle_client/dataset/iterable_packing_dataset.py +++ b/src/twinkle_client/dataset/iterable_packing_dataset.py @@ -22,7 +22,7 @@ class IterablePackingDataset(IterableDataset): def __init__(self, dataset_meta: DatasetMeta, packing_interval: int = 128, packing_num_proc: int = 1, cyclic: bool = False, **kwargs): from twinkle_client.http import get_base_url - self.server_url = f'{get_base_url()}/processors/twinkle' + self.server_url = f'{get_base_url()}/processor/twinkle' response = http_post( url=f'{self.server_url}/create', json_data={ diff --git a/src/twinkle_client/dataset/lazy_dataset.py b/src/twinkle_client/dataset/lazy_dataset.py index 106fd3a9..52109a8e 100644 --- a/src/twinkle_client/dataset/lazy_dataset.py +++ b/src/twinkle_client/dataset/lazy_dataset.py @@ -20,7 +20,7 @@ class LazyDataset(Dataset): def __init__(self, dataset_meta: DatasetMeta, **kwargs): from twinkle_client.http import get_base_url - self.server_url = f'{get_base_url()}/processors/twinkle' + self.server_url = f'{get_base_url()}/processor/twinkle' response = http_post( url=f'{self.server_url}/create', json_data={ diff --git a/src/twinkle_client/dataset/packing_dataset.py b/src/twinkle_client/dataset/packing_dataset.py index 37cb36ca..175b8cd1 100644 --- a/src/twinkle_client/dataset/packing_dataset.py +++ b/src/twinkle_client/dataset/packing_dataset.py @@ -20,7 +20,7 @@ class PackingDataset(Dataset): def __init__(self, dataset_meta: 
DatasetMeta, packing_num_proc: int = 1, **kwargs): from twinkle_client.http import get_base_url - self.server_url = f'{get_base_url()}/processors/twinkle' + self.server_url = f'{get_base_url()}/processor/twinkle' response = http_post( url=f'{self.server_url}/create', json_data={ diff --git a/src/twinkle_client/http/utils.py b/src/twinkle_client/http/utils.py index ad49ffe1..4ec1230b 100644 --- a/src/twinkle_client/http/utils.py +++ b/src/twinkle_client/http/utils.py @@ -23,7 +23,11 @@ def set_base_url(url: str): def get_base_url() -> Optional[str]: """Get the current base URL from context or environment variable.""" - return _base_url_context.get() or TWINKLE_SERVER_URL + base_url = _base_url_context.get() or TWINKLE_SERVER_URL + # if not ends with '/api/v1' then append it + if not base_url.endswith('/api/v1'): + base_url += '/api/v1' + return base_url def clear_base_url(): diff --git a/src/twinkle_client/manager.py b/src/twinkle_client/manager.py index 874126f5..2c3c0e96 100644 --- a/src/twinkle_client/manager.py +++ b/src/twinkle_client/manager.py @@ -5,7 +5,7 @@ # Shared Pydantic models from twinkle_client.types.training import Checkpoint, Cursor, TrainingRun -from .http.http_utils import http_get, http_post +from .http import http_get, http_post, get_base_url class TwinkleClientError(Exception): @@ -31,11 +31,11 @@ class TwinkleClient: base_url: Base URL of the Twinkle server (e.g., "http://localhost:8000"). api_key: API key for authentication. 
If not provided, uses TWINKLE_SERVER_TOKEN environment variable - route_prefix: API route prefix (default: "/api/v1/twinkle") + route_prefix: API route prefix (default: "/twinkle") """ - def __init__(self, base_url: str = None, api_key: str = None, route_prefix: str | None = '/api/v1/twinkle'): - self.base_url = base_url + def __init__(self, base_url: str = None, api_key: str = None, route_prefix: str | None = '/twinkle'): + self.base_url = base_url or get_base_url() self.api_key = api_key self.route_prefix = route_prefix.rstrip('/') if route_prefix else '' diff --git a/src/twinkle_client/model/multi_lora_transformers.py b/src/twinkle_client/model/multi_lora_transformers.py index 04be1ade..bb91ad3a 100644 --- a/src/twinkle_client/model/multi_lora_transformers.py +++ b/src/twinkle_client/model/multi_lora_transformers.py @@ -30,7 +30,7 @@ def __init__(self, model_id: str, **kwargs): self.model_id = model_id if '://' in model_id: model_id = model_id.split('://')[1] - self.server_url = f'{self.server_url}/models/{model_id}/twinkle' + self.server_url = f'{self.server_url}/model/{model_id}/twinkle' self.adapter_name = None response = http_post( url=f'{self.server_url}/create', diff --git a/src/twinkle_client/processor/base.py b/src/twinkle_client/processor/base.py index 0dfa3aa6..63aaa9c4 100644 --- a/src/twinkle_client/processor/base.py +++ b/src/twinkle_client/processor/base.py @@ -20,7 +20,7 @@ class InputProcessor(object): def __init__(self, device_mesh: Optional[DeviceMesh] = None, padding_free: bool = False, framework: Literal['transformers', 'megatron'] = 'transformers', **kwargs): from twinkle_client.http import get_base_url - self.server_url = f'{get_base_url()}/processors/twinkle' + self.server_url = f'{get_base_url()}/processor/twinkle' response = http_post( url=f'{self.server_url}/create', json_data={ diff --git a/src/twinkle_client/sampler/vllm_sampler.py b/src/twinkle_client/sampler/vllm_sampler.py index 93004779..3c703569 100644 --- 
a/src/twinkle_client/sampler/vllm_sampler.py +++ b/src/twinkle_client/sampler/vllm_sampler.py @@ -30,7 +30,7 @@ def __init__(self, model_id: str, **kwargs): self.adapter_name = None if '://' in model_id: model_id = model_id.split('://')[1] - self.server_url = f'{self.server_url}/samplers/{model_id}/twinkle' + self.server_url = f'{self.server_url}/sampler/{model_id}/twinkle' response = http_post( url=f'{self.server_url}/create', json_data=kwargs diff --git a/src/twinkle_client/utils/patch_tinker.py b/src/twinkle_client/utils/patch_tinker.py index 826274ae..5f6d955e 100644 --- a/src/twinkle_client/utils/patch_tinker.py +++ b/src/twinkle_client/utils/patch_tinker.py @@ -53,10 +53,10 @@ def _patched_async_tinker_init( # Get api_key from environment if not provided if api_key is None: - api_key = os.environ.get('TINKER_API_KEY') + api_key = os.environ.get('TWINKLE_SERVER_TOKEN') if api_key is None: raise TinkerError( - 'The api_key client option must be set either by passing api_key to the client or by setting the TINKER_API_KEY environment variable' + 'The api_key client option must be set either by passing api_key to the client or by setting the TWINKLE_SERVER_TOKEN environment variable' ) # REMOVED: api_key 'tml-' prefix validation # Original code: From 3bd78aebbf4a091026c586edc8701140fbf4eba3 Mon Sep 17 00:00:00 2001 From: Yunnglin Date: Wed, 11 Mar 2026 20:24:54 +0800 Subject: [PATCH 09/24] update twinkle --- client_tools/client_generator.py | 204 ++++++------- src/twinkle/server/gateway/proxy.py | 2 +- .../gateway/twinkle_gateway_handlers.py | 24 +- src/twinkle_client/__init__.py | 63 ++-- src/twinkle_client/dataloader/dataloader.py | 9 +- src/twinkle_client/dataset/base.py | 9 +- .../dataset/iterable_dataset.py | 9 +- .../dataset/iterable_packing_dataset.py | 9 +- src/twinkle_client/dataset/lazy_dataset.py | 9 +- src/twinkle_client/dataset/packing_dataset.py | 9 +- src/twinkle_client/http/__init__.py | 6 +- src/twinkle_client/http/heartbeat.py | 2 +- 
src/twinkle_client/http/http_utils.py | 6 +- src/twinkle_client/http/utils.py | 16 + src/twinkle_client/manager.py | 283 +++++++++++------- .../model/multi_lora_transformers.py | 147 ++++----- src/twinkle_client/processor/base.py | 9 +- src/twinkle_client/sampler/vllm_sampler.py | 48 +-- src/twinkle_client/types/__init__.py | 36 +++ src/twinkle_client/types/model.py | 111 ++++++- src/twinkle_client/types/session.py | 24 ++ 21 files changed, 618 insertions(+), 417 deletions(-) create mode 100644 src/twinkle_client/types/session.py diff --git a/client_tools/client_generator.py b/client_tools/client_generator.py index f4224894..60276a4e 100644 --- a/client_tools/client_generator.py +++ b/client_tools/client_generator.py @@ -243,7 +243,7 @@ def build_imports() -> Tuple[List[str], str]: if typing_imports: lines.append(f"from typing import {', '.join(sorted(typing_imports))}") lines.extend([ - 'from twinkle_client.http import http_post, heartbeat_manager', + 'from twinkle_client.http import http_post', ]) lines.extend(sorted(twinkle_imports)) @@ -358,13 +358,6 @@ def __init__({init_params}): ) response.raise_for_status() self.processor_id = response.json()['processor_id'] - heartbeat_manager.register_processor(self.processor_id) - - def __del__(self): - try: - heartbeat_manager.unregister_processor(self.processor_id) - except: - pass ''' @@ -444,18 +437,36 @@ def generate_models(): client_module_path = src_client_path / 'model' client_module_path.mkdir(parents=True, exist_ok=True) - model_code = AUTO_GEN_WARNING + '''from typing import Any, Optional, Union, Type, Dict, Literal, List -import uuid -from twinkle_client.http import http_post, heartbeat_manager -from twinkle import DeviceMesh -from twinkle.data_format import InputFeature, Trajectory + model_code = AUTO_GEN_WARNING + '''from typing import Any, Dict, Optional +from twinkle_client.http import http_post +from twinkle_client.types.model import ( + BackwardResponse, + CalculateLossResponse, + 
CalculateMetricResponse, + ClipGradNormResponse, + ForwardBackwardResponse, + ForwardResponse, + GetStateDictResponse, + GetTrainConfigsResponse, + LoadResponse, + LrStepResponse, + SaveResponse, + SetLossResponse, + SetLrSchedulerResponse, + SetOptimizerResponse, + SetProcessorResponse, + SetTemplateResponse, + StepResponse, + UploadToHubResponse, + ZeroGradResponse, +) class MultiLoraTransformersModel: """Client wrapper for TwinkleModel that calls server HTTP endpoints. This client manages adapters and sends training/inference requests to the model server. - Each adapter has its own lifecycle managed through automatic heartbeats. + The server-side session (managed by TwinkleClient) keeps the model alive. """ def __init__(self, model_id: str, **kwargs): @@ -473,208 +484,193 @@ def __init__(self, model_id: str, **kwargs): ) response.raise_for_status() - def _send_adapter_heartbeat(self): - """Internal method to send adapter heartbeat.""" - response = http_post( - url=f'{self.server_url}/heartbeat', - json_data={'adapter_name': self.adapter_name} - ) - response.raise_for_status() - - def add_adapter_to_model(self, adapter_name: str, config: Dict[str, Any], **kwargs): - """Add a new adapter to the model and start automatic heartbeat.""" + def add_adapter_to_model(self, adapter_name: str, config: Dict[str, Any], **kwargs) -> None: + """Add a new adapter to the model.""" response = http_post( url=f'{self.server_url}/add_adapter_to_model', json_data={'adapter_name': adapter_name, 'config': config, **kwargs} ) response.raise_for_status() - - # Register adapter for automatic heartbeat after successful creation self.adapter_name = adapter_name - heartbeat_manager.register_adapter( - self.adapter_name, - self._send_adapter_heartbeat - ) - def __del__(self): - """Cleanup: unregister adapter from heartbeat manager.""" - try: - heartbeat_manager.unregister_adapter(self.adapter_name) - except: - pass - - def forward(self, inputs: Any, **kwargs): + def forward(self, inputs: Any, 
**kwargs) -> ForwardResponse: """Execute forward pass on the model.""" response = http_post( url=f'{self.server_url}/forward', json_data={'inputs': inputs, 'adapter_name': self.adapter_name, **kwargs} ) response.raise_for_status() - return response.json()['result'] + return ForwardResponse(**response.json()) - def forward_only(self, inputs: Any, **kwargs): + def forward_only(self, inputs: Any, **kwargs) -> ForwardResponse: """Execute forward pass without gradient computation.""" response = http_post( url=f'{self.server_url}/forward_only', json_data={'inputs': inputs, 'adapter_name': self.adapter_name, **kwargs} ) response.raise_for_status() - return response.json()['result'] + return ForwardResponse(**response.json()) - def calculate_loss(self, **kwargs): + def calculate_loss(self, **kwargs) -> CalculateLossResponse: """Calculate loss from model outputs.""" response = http_post( url=f'{self.server_url}/calculate_loss', json_data={'adapter_name': self.adapter_name, **kwargs} ) response.raise_for_status() - return response.json()['result'] + return CalculateLossResponse(**response.json()) - def get_train_configs(self, **kwargs): - """Get training configs""" + def get_train_configs(self, **kwargs) -> GetTrainConfigsResponse: + """Get training configs.""" response = http_post( url=f'{self.server_url}/get_train_configs', json_data={'adapter_name': self.adapter_name, **kwargs} ) response.raise_for_status() - return response.json()['result'] + return GetTrainConfigsResponse(**response.json()) - def backward(self, **kwargs): + def backward(self, **kwargs) -> BackwardResponse: """Execute backward pass.""" response = http_post( url=f'{self.server_url}/backward', json_data={'adapter_name': self.adapter_name, **kwargs} ) response.raise_for_status() - return response.json()['result'] + return BackwardResponse(**response.json()) - def forward_backward(self, inputs: Any, **kwargs): + def forward_backward(self, inputs: Any, **kwargs) -> ForwardBackwardResponse: """Execute combined 
forward and backward pass.""" response = http_post( url=f'{self.server_url}/forward_backward', json_data={'inputs': inputs, 'adapter_name': self.adapter_name, **kwargs} ) response.raise_for_status() - return response.json()['result'] + return ForwardBackwardResponse(**response.json()) - def step(self, **kwargs): + def step(self, **kwargs) -> StepResponse: """Execute optimizer step.""" response = http_post( url=f'{self.server_url}/step', json_data={'adapter_name': self.adapter_name, **kwargs} ) response.raise_for_status() - return response.json()['result'] + return StepResponse(**response.json()) - def zero_grad(self, **kwargs): + def zero_grad(self, **kwargs) -> ZeroGradResponse: """Zero out gradients.""" response = http_post( url=f'{self.server_url}/zero_grad', json_data={'adapter_name': self.adapter_name, **kwargs} ) response.raise_for_status() - return response.json()['result'] + return ZeroGradResponse(**response.json()) - def lr_step(self, **kwargs): + def lr_step(self, **kwargs) -> LrStepResponse: """Execute learning rate scheduler step.""" response = http_post( url=f'{self.server_url}/lr_step', json_data={'adapter_name': self.adapter_name, **kwargs} ) response.raise_for_status() - return response.json()['result'] + return LrStepResponse(**response.json()) - def set_loss(self, loss_cls: str, **kwargs): + def set_loss(self, loss_cls: str, **kwargs) -> SetLossResponse: """Set the loss function.""" response = http_post( url=f'{self.server_url}/set_loss', json_data={'loss_cls': loss_cls, 'adapter_name': self.adapter_name, **kwargs} ) response.raise_for_status() - return response.json()['result'] + return SetLossResponse(**response.json()) - def clip_grad_norm(self, max_grad_norm: float=1.0, norm_type=2, **kwargs): - """Set the loss function.""" + def clip_grad_norm(self, max_grad_norm: float = 1.0, norm_type: int = 2, **kwargs) -> ClipGradNormResponse: + """Clip gradient norm.""" response = http_post( url=f'{self.server_url}/clip_grad_norm', 
json_data={'max_grad_norm': max_grad_norm, 'norm_type': norm_type, 'adapter_name': self.adapter_name, **kwargs} ) response.raise_for_status() - return response.json()['result'] + return ClipGradNormResponse(**response.json()) - def set_optimizer(self, optimizer_cls: str, **kwargs): + def set_optimizer(self, optimizer_cls: str, **kwargs) -> SetOptimizerResponse: """Set the optimizer.""" response = http_post( url=f'{self.server_url}/set_optimizer', json_data={'optimizer_cls': optimizer_cls, 'adapter_name': self.adapter_name, **kwargs} ) response.raise_for_status() - return response.json()['result'] + return SetOptimizerResponse(**response.json()) - def set_lr_scheduler(self, scheduler_cls: str, **kwargs): + def set_lr_scheduler(self, scheduler_cls: str, **kwargs) -> SetLrSchedulerResponse: """Set the learning rate scheduler.""" response = http_post( url=f'{self.server_url}/set_lr_scheduler', json_data={'scheduler_cls': scheduler_cls, 'adapter_name': self.adapter_name, **kwargs} ) response.raise_for_status() - return response.json()['result'] + return SetLrSchedulerResponse(**response.json()) - def save(self, name: str, **kwargs): + def save(self, name: str, **kwargs) -> SaveResponse: """Save model checkpoint.""" response = http_post( url=f'{self.server_url}/save', json_data={'name': name, 'adapter_name': self.adapter_name, **kwargs} ) response.raise_for_status() - return response.json()['result'] + return SaveResponse(**response.json()) - def load(self, name: str, **kwargs): + def load(self, name: str, **kwargs) -> LoadResponse: """Load model checkpoint.""" response = http_post( url=f'{self.server_url}/load', json_data={'name': name, 'adapter_name': self.adapter_name, **kwargs} ) response.raise_for_status() - return response.json()['result'] + return LoadResponse(**response.json()) - def set_template(self, template_cls: str, **kwargs): + def set_template(self, template_cls: str, **kwargs) -> SetTemplateResponse: """Set the template for data processing.""" response = 
http_post( url=f'{self.server_url}/set_template', json_data={'template_cls': template_cls, 'adapter_name': self.adapter_name, 'model_id': self.model_id, **kwargs} ) response.raise_for_status() - return response.json()['result'] + return SetTemplateResponse(**response.json()) - def set_processor(self, processor_cls: str, **kwargs): + def set_processor(self, processor_cls: str, **kwargs) -> SetProcessorResponse: """Set the input processor.""" response = http_post( url=f'{self.server_url}/set_processor', json_data={'processor_cls': processor_cls, 'adapter_name': self.adapter_name, **kwargs} ) response.raise_for_status() - return response.json()['result'] + return SetProcessorResponse(**response.json()) - def calculate_metric(self, is_training: bool = True, **kwargs): + def calculate_metric(self, is_training: bool = True, **kwargs) -> CalculateMetricResponse: """Calculate metrics from model outputs.""" response = http_post( url=f'{self.server_url}/calculate_metric', json_data={'is_training': is_training, 'adapter_name': self.adapter_name, **kwargs} ) response.raise_for_status() - return response.json()['result'] + return CalculateMetricResponse(**response.json()) - def get_state_dict(self, **kwargs): + def get_state_dict(self, **kwargs) -> GetStateDictResponse: """Get model state dictionary.""" response = http_post( url=f'{self.server_url}/get_state_dict', json_data={'adapter_name': self.adapter_name, **kwargs} ) response.raise_for_status() - return response.json()['result'] + return GetStateDictResponse(**response.json()) - def upload_to_hub(self, checkpoint_dir: str, hub_model_id: str, hub_token: Optional[str] = None, async_upload: bool = True): + def upload_to_hub( + self, + checkpoint_dir: str, + hub_model_id: str, + hub_token: Optional[str] = None, + async_upload: bool = True, + ) -> UploadToHubResponse: """Upload model checkpoint to hub. 
Args: @@ -689,11 +685,11 @@ def upload_to_hub(self, checkpoint_dir: str, hub_model_id: str, hub_token: Optio 'checkpoint_dir': checkpoint_dir, 'hub_model_id': hub_model_id, 'hub_token': hub_token, - 'async_upload': async_upload + 'async_upload': async_upload, } ) response.raise_for_status() - return response.json() + return UploadToHubResponse(**response.json()) ''' # Write the model client file @@ -721,9 +717,10 @@ def generate_samplers(): client_module_path = src_client_path / 'sampler' client_module_path.mkdir(parents=True, exist_ok=True) - sampler_code = AUTO_GEN_WARNING + '''from typing import Any, Optional, List, Dict, Union -from twinkle_client.http import http_post, heartbeat_manager + sampler_code = AUTO_GEN_WARNING + '''from typing import Any, Dict, List, Optional, Union +from twinkle_client.http import http_post from twinkle.sampler.base import Sampler +from twinkle_client.types.sampler import AddAdapterResponse, SampleResponseModel, SetTemplateResponse from peft import PeftConfig from twinkle.data_format import Trajectory, InputFeature @@ -732,7 +729,7 @@ class vLLMSampler(Sampler): """Client wrapper for Sampler that calls server HTTP endpoints. This client manages sampling operations and adapter synchronization with the sampler server. - Each adapter has its own lifecycle managed through automatic heartbeats. + The server-side session (managed by TwinkleClient) keeps the sampler alive. 
""" def __init__(self, model_id: str, **kwargs): @@ -750,18 +747,8 @@ def __init__(self, model_id: str, **kwargs): ) response.raise_for_status() - def _send_adapter_heartbeat(self): - """Internal method to send adapter heartbeat.""" - if not self.adapter_name: - return - response = http_post( - url=f'{self.server_url}/heartbeat', - json_data={'adapter_name': self.adapter_name} - ) - response.raise_for_status() - - def add_adapter_to_sampler(self, adapter_name: str, config: PeftConfig, **kwargs): - """Add a new adapter to the sampler and start automatic heartbeat.""" + def add_adapter_to_sampler(self, adapter_name: str, config: PeftConfig, **kwargs) -> AddAdapterResponse: + """Add a new adapter to the sampler.""" if isinstance(config, PeftConfig): config = config.__dict__ response = http_post( @@ -769,23 +756,8 @@ def add_adapter_to_sampler(self, adapter_name: str, config: PeftConfig, **kwargs json_data={'adapter_name': adapter_name, 'config': config, **kwargs} ) response.raise_for_status() - - # Register adapter for automatic heartbeat after successful creation self.adapter_name = adapter_name - heartbeat_manager.register_adapter( - self.adapter_name, - self._send_adapter_heartbeat - ) - - return response.json() - - def __del__(self): - """Cleanup: unregister adapter from heartbeat manager.""" - try: - if self.adapter_name: - heartbeat_manager.unregister_adapter(self.adapter_name) - except: - pass + return AddAdapterResponse(**response.json()) def sample( self, @@ -794,7 +766,7 @@ def sample( adapter_name: str = '', adapter_uri: Optional[str] = None, num_samples: int = 1, - ) -> Dict[str, Any]: + ) -> SampleResponseModel: """Sample from the model. Args: @@ -805,7 +777,7 @@ def sample( num_samples: Number of completions to generate per prompt. Returns: - Dict with 'sequences' list, each containing tokens, logprobs, stop_reason. + SampleResponseModel with 'sequences' list, each containing tokens, logprobs, stop_reason. 
""" json_data = { 'inputs': inputs, @@ -821,16 +793,16 @@ def sample( json_data=json_data ) response.raise_for_status() - return response.json() + return SampleResponseModel(**response.json()) - def set_template(self, template_cls: str, adapter_name: str = '', **kwargs): + def set_template(self, template_cls: str, adapter_name: str = '', **kwargs) -> SetTemplateResponse: """Set the template for encoding trajectories.""" response = http_post( url=f'{self.server_url}/set_template', json_data={'template_cls': template_cls, 'adapter_name': adapter_name, **kwargs} ) response.raise_for_status() - return response.json() + return SetTemplateResponse(**response.json()) ''' # Write the sampler client file diff --git a/src/twinkle/server/gateway/proxy.py b/src/twinkle/server/gateway/proxy.py index 5517014e..e8346d6c 100644 --- a/src/twinkle/server/gateway/proxy.py +++ b/src/twinkle/server/gateway/proxy.py @@ -67,7 +67,7 @@ def _prepare_headers(self, request_headers) -> dict[str, str]: headers.pop('host', None) headers.pop('content-length', None) request_id = request_headers.get('X-Ray-Serve-Request-Id') - if request_id is not None: + if request_id is not None and not request_headers.get('serve_multiplexed_model_id'): headers['serve_multiplexed_model_id'] = request_id return headers diff --git a/src/twinkle/server/gateway/twinkle_gateway_handlers.py b/src/twinkle/server/gateway/twinkle_gateway_handlers.py index 8a22cc6f..3d6ca314 100644 --- a/src/twinkle/server/gateway/twinkle_gateway_handlers.py +++ b/src/twinkle/server/gateway/twinkle_gateway_handlers.py @@ -6,7 +6,7 @@ """ from __future__ import annotations -from fastapi import FastAPI, HTTPException, Request +from fastapi import Depends, FastAPI, HTTPException, Request from typing import TYPE_CHECKING, Callable if TYPE_CHECKING: @@ -17,6 +17,8 @@ from twinkle.server.utils.validation import get_token_from_request from twinkle.utils.logger import get_logger from twinkle_client.types.server import DeleteCheckpointResponse, 
HealthResponse, WeightsInfoRequest +from twinkle_client.types.session import (CreateSessionRequest, CreateSessionResponse, SessionHeartbeatRequest, + SessionHeartbeatResponse) from twinkle_client.types.training import (CheckpointsListResponse, TrainingRun, TrainingRunsResponse, WeightsInfoResponse) @@ -30,6 +32,26 @@ def _register_twinkle_routes(app: FastAPI, self_fn: Callable[[], GatewayServer]) async def healthz(request: Request) -> HealthResponse: return HealthResponse(status='ok') + @app.post('/twinkle/create_session', response_model=CreateSessionResponse) + async def create_session( + request: Request, + body: CreateSessionRequest, + self: GatewayServer = Depends(self_fn), + ) -> CreateSessionResponse: + session_id = self.state.create_session(body.model_dump()) + return CreateSessionResponse(session_id=session_id) + + @app.post('/twinkle/session_heartbeat', response_model=SessionHeartbeatResponse) + async def session_heartbeat( + request: Request, + body: SessionHeartbeatRequest, + self: GatewayServer = Depends(self_fn), + ) -> SessionHeartbeatResponse: + alive = self.state.touch_session(body.session_id) + if not alive: + raise HTTPException(status_code=404, detail='Unknown session') + return SessionHeartbeatResponse() + @app.get('/twinkle/training_runs', response_model=TrainingRunsResponse) async def get_training_runs(request: Request, limit: int = 20, offset: int = 0) -> TrainingRunsResponse: token = get_token_from_request(request) diff --git a/src/twinkle_client/__init__.py b/src/twinkle_client/__init__.py index 58c43a37..a87eba5d 100644 --- a/src/twinkle_client/__init__.py +++ b/src/twinkle_client/__init__.py @@ -1,10 +1,13 @@ # Copyright (c) ModelScope Contributors. All rights reserved. from __future__ import annotations +from typing import Optional + def init_tinker_client(**kwargs) -> None: """Initialize Tinker client with Twinkle-specific headers. 
- After calling this function, users can directly use: + After calling this function, users can directly use:: + from tinker import ServiceClient client = ServiceClient(base_url='...', api_key='...') @@ -13,39 +16,57 @@ def init_tinker_client(**kwargs) -> None: Args: **kwargs: Additional keyword arguments (currently unused, reserved for future) - Example: - >>> from twinkle import init_tinker_client + Example:: + + >>> from twinkle_client import init_tinker_client >>> init_tinker_client() >>> from tinker import ServiceClient >>> client = ServiceClient(base_url='http://localhost:8000', api_key='your_token') """ from twinkle.utils import requires - + requires('tinker') from twinkle_client.utils.patch_tinker import patch_tinker - # Apply patches to tinker library (includes header injection) patch_tinker() -def init_twinkle_client(base_url: str | None = None, api_key: str | None = None, **kwargs) -> TwinkleClient: +def init_twinkle_client( + base_url: Optional[str] = None, + api_key: Optional[str] = None, + session_heartbeat_interval: int = 30, + **kwargs, +) -> 'TwinkleClient': """ - Initialize a Twinkle client and setup context variables. + Initialize a Twinkle client. + + This function: + + * Resolves ``base_url`` and ``api_key`` (env-vars as fallbacks). + * Sets both values into the shared context so that all other client objects + (``MultiLoraTransformersModel``, ``vLLMSampler``, processor clients) created + afterwards automatically inherit the same server configuration. + * Creates a server-side session and stores the ``session_id`` in context so + every subsequent HTTP request carries it in ``X-Twinkle-Session-Id``. + * Starts a background thread that touches the session every + ``session_heartbeat_interval`` seconds. + + Args: + base_url: Twinkle server base URL. Falls back to ``TWINKLE_SERVER_URL``. + api_key: Authentication token. Falls back to ``TWINKLE_SERVER_TOKEN``. + session_heartbeat_interval: Seconds between session touch calls (default: 30). 
+ **kwargs: Additional keyword arguments forwarded to :class:`TwinkleClient`. + + Returns: + An initialised :class:`~twinkle_client.manager.TwinkleClient` instance. """ - from .http.utils import get_api_key, get_base_url, set_api_key, set_base_url - from .manager import TwinkleClient, TwinkleClientError - - if base_url is not None: - set_base_url(base_url) - else: - base_url = get_base_url() - - if api_key is not None: - set_api_key(api_key) - else: - api_key = get_api_key() - - return TwinkleClient(base_url=base_url, api_key=api_key, **kwargs) + from .manager import TwinkleClient + return TwinkleClient( + base_url=base_url, + api_key=api_key, + session_heartbeat_interval=session_heartbeat_interval, + **kwargs, + ) __all__ = ['init_tinker_client', 'init_twinkle_client'] diff --git a/src/twinkle_client/dataloader/dataloader.py b/src/twinkle_client/dataloader/dataloader.py index 477149c5..0a067ddd 100644 --- a/src/twinkle_client/dataloader/dataloader.py +++ b/src/twinkle_client/dataloader/dataloader.py @@ -10,7 +10,7 @@ # ============================================================================ from typing import Callable, Type, Union -from twinkle_client.http import http_post, heartbeat_manager +from twinkle_client.http import http_post from twinkle.dataset import Dataset from twinkle.processor import InputProcessor @@ -31,13 +31,6 @@ def __init__(self, dataset: Union[Dataset, Callable], **kwargs): ) response.raise_for_status() self.processor_id = response.json()['processor_id'] - heartbeat_manager.register_processor(self.processor_id) - - def __del__(self): - try: - heartbeat_manager.unregister_processor(self.processor_id) - except: - pass def __len__(self): diff --git a/src/twinkle_client/dataset/base.py b/src/twinkle_client/dataset/base.py index d2b57ab4..0487f733 100644 --- a/src/twinkle_client/dataset/base.py +++ b/src/twinkle_client/dataset/base.py @@ -10,7 +10,7 @@ # ============================================================================ from typing 
import Any, Callable, Dict, Type, Union -from twinkle_client.http import http_post, heartbeat_manager +from twinkle_client.http import http_post from twinkle.dataset import Dataset from twinkle.dataset import DatasetMeta from twinkle.preprocessor import DataFilter @@ -34,13 +34,6 @@ def __init__(self, dataset_meta: DatasetMeta, **kwargs): ) response.raise_for_status() self.processor_id = response.json()['processor_id'] - heartbeat_manager.register_processor(self.processor_id) - - def __del__(self): - try: - heartbeat_manager.unregister_processor(self.processor_id) - except: - pass def set_template(self, template_func: Union[Template, Type[Template], str], **kwargs): diff --git a/src/twinkle_client/dataset/iterable_dataset.py b/src/twinkle_client/dataset/iterable_dataset.py index 1aa0b006..25c48919 100644 --- a/src/twinkle_client/dataset/iterable_dataset.py +++ b/src/twinkle_client/dataset/iterable_dataset.py @@ -9,7 +9,7 @@ # 2. Run: python client_tools/client_generator.py # ============================================================================ -from twinkle_client.http import http_post, heartbeat_manager +from twinkle_client.http import http_post from twinkle.dataset import Dataset from twinkle.dataset import DatasetMeta from torch.utils.data import IterableDataset @@ -31,13 +31,6 @@ def __init__(self, dataset_meta: DatasetMeta, **kwargs): ) response.raise_for_status() self.processor_id = response.json()['processor_id'] - heartbeat_manager.register_processor(self.processor_id) - - def __del__(self): - try: - heartbeat_manager.unregister_processor(self.processor_id) - except: - pass def add_dataset(self, dataset_meta: DatasetMeta, **kwargs): diff --git a/src/twinkle_client/dataset/iterable_packing_dataset.py b/src/twinkle_client/dataset/iterable_packing_dataset.py index d0dd1582..12a958a4 100644 --- a/src/twinkle_client/dataset/iterable_packing_dataset.py +++ b/src/twinkle_client/dataset/iterable_packing_dataset.py @@ -10,7 +10,7 @@ # 
============================================================================ from typing import Type, Union -from twinkle_client.http import http_post, heartbeat_manager +from twinkle_client.http import http_post from twinkle.dataset import Dataset from twinkle.dataset import DatasetMeta from twinkle.template import Template @@ -33,13 +33,6 @@ def __init__(self, dataset_meta: DatasetMeta, packing_interval: int = 128, packi ) response.raise_for_status() self.processor_id = response.json()['processor_id'] - heartbeat_manager.register_processor(self.processor_id) - - def __del__(self): - try: - heartbeat_manager.unregister_processor(self.processor_id) - except: - pass def set_template(self, template_cls: Union[Type[Template], str, Template], **kwargs): diff --git a/src/twinkle_client/dataset/lazy_dataset.py b/src/twinkle_client/dataset/lazy_dataset.py index 52109a8e..62b13dea 100644 --- a/src/twinkle_client/dataset/lazy_dataset.py +++ b/src/twinkle_client/dataset/lazy_dataset.py @@ -9,7 +9,7 @@ # 2. Run: python client_tools/client_generator.py # ============================================================================ -from twinkle_client.http import http_post, heartbeat_manager +from twinkle_client.http import http_post from twinkle.dataset import Dataset from twinkle.dataset import DatasetMeta from .base import Dataset @@ -31,13 +31,6 @@ def __init__(self, dataset_meta: DatasetMeta, **kwargs): ) response.raise_for_status() self.processor_id = response.json()['processor_id'] - heartbeat_manager.register_processor(self.processor_id) - - def __del__(self): - try: - heartbeat_manager.unregister_processor(self.processor_id) - except: - pass def encode(self, **kwargs): diff --git a/src/twinkle_client/dataset/packing_dataset.py b/src/twinkle_client/dataset/packing_dataset.py index 175b8cd1..dd901d1d 100644 --- a/src/twinkle_client/dataset/packing_dataset.py +++ b/src/twinkle_client/dataset/packing_dataset.py @@ -9,7 +9,7 @@ # 2. 
Run: python client_tools/client_generator.py # ============================================================================ -from twinkle_client.http import http_post, heartbeat_manager +from twinkle_client.http import http_post from twinkle.dataset import Dataset from twinkle.dataset import DatasetMeta from .base import Dataset @@ -31,13 +31,6 @@ def __init__(self, dataset_meta: DatasetMeta, packing_num_proc: int = 1, **kwarg ) response.raise_for_status() self.processor_id = response.json()['processor_id'] - heartbeat_manager.register_processor(self.processor_id) - - def __del__(self): - try: - heartbeat_manager.unregister_processor(self.processor_id) - except: - pass def pack_dataset(self): diff --git a/src/twinkle_client/http/__init__.py b/src/twinkle_client/http/__init__.py index 39bedf71..2e6388b7 100644 --- a/src/twinkle_client/http/__init__.py +++ b/src/twinkle_client/http/__init__.py @@ -1,7 +1,8 @@ from .heartbeat import heartbeat_manager from .http_utils import http_delete, http_get, http_post from .utils import (TWINKLE_SERVER_TOKEN, TWINKLE_SERVER_URL, clear_api_key, clear_base_url, clear_request_id, - get_api_key, get_base_url, get_request_id, set_api_key, set_base_url, set_request_id) + clear_session_id, get_api_key, get_base_url, get_request_id, get_session_id, set_api_key, + set_base_url, set_request_id, set_session_id) __all__ = [ 'http_get', @@ -16,6 +17,9 @@ 'set_api_key', 'get_api_key', 'clear_api_key', + 'set_session_id', + 'get_session_id', + 'clear_session_id', 'set_request_id', 'get_request_id', 'clear_request_id', diff --git a/src/twinkle_client/http/heartbeat.py b/src/twinkle_client/http/heartbeat.py index c348a4f6..5194d75b 100644 --- a/src/twinkle_client/http/heartbeat.py +++ b/src/twinkle_client/http/heartbeat.py @@ -51,7 +51,7 @@ def __init__(self): def processor_heartbeat_func(self, processor_id_list: str): response = http_post( - url=f'{get_base_url()}/processors/twinkle/heartbeat', json_data={'processor_id': processor_id_list}) + 
url=f'{get_base_url()}/processor/twinkle/heartbeat', json_data={'processor_id': processor_id_list}) response.raise_for_status() def register_processor(self, processor_id: str): diff --git a/src/twinkle_client/http/http_utils.py b/src/twinkle_client/http/http_utils.py index 74f3c3bd..490337bb 100644 --- a/src/twinkle_client/http/http_utils.py +++ b/src/twinkle_client/http/http_utils.py @@ -1,7 +1,7 @@ import requests from typing import Any, Callable, Dict, Mapping, Optional -from .utils import get_api_key, get_base_url, get_request_id +from .utils import get_api_key, get_base_url, get_request_id, get_session_id def _build_headers(additional_headers: Optional[Dict[str, str]] = None) -> Dict[str, str]: @@ -16,10 +16,14 @@ def _build_headers(additional_headers: Optional[Dict[str, str]] = None) -> Dict[ """ headers = { 'X-Ray-Serve-Request-Id': get_request_id(), + 'serve_multiplexed_model_id': get_request_id(), # For model multiplexing 'Authorization': 'Bearer ' + get_api_key(), 'Twinkle-Authorization': 'Bearer ' + get_api_key(), # For server compatibility } + if session_id := get_session_id(): + headers['X-Twinkle-Session-Id'] = session_id + if additional_headers: headers.update(additional_headers) diff --git a/src/twinkle_client/http/utils.py b/src/twinkle_client/http/utils.py index 4ec1230b..e45b0360 100644 --- a/src/twinkle_client/http/utils.py +++ b/src/twinkle_client/http/utils.py @@ -10,6 +10,7 @@ # Context variables for flexible configuration _base_url_context: ContextVar[Optional[str]] = ContextVar('base_url', default=None) _api_key_context: ContextVar[Optional[str]] = ContextVar('api_key', default=None) +_session_id_context: ContextVar[Optional[str]] = ContextVar('session_id', default=None) # Global static request ID shared across all threads # This ensures heartbeat threads use the same request ID as the main training thread @@ -50,6 +51,21 @@ def clear_api_key(): _api_key_context.set(None) +def set_session_id(session_id: str): + """Set the session ID for the 
current context.""" + _session_id_context.set(session_id) + + +def get_session_id() -> Optional[str]: + """Get the current session ID from context.""" + return _session_id_context.get() + + +def clear_session_id(): + """Clear the session ID context.""" + _session_id_context.set(None) + + def set_request_id(request_id: str): """Set the global request ID for HTTP requests (shared across all threads).""" global _global_request_id diff --git a/src/twinkle_client/manager.py b/src/twinkle_client/manager.py index 2c3c0e96..394714e8 100644 --- a/src/twinkle_client/manager.py +++ b/src/twinkle_client/manager.py @@ -1,11 +1,16 @@ # Copyright (c) ModelScope Contributors. All rights reserved. from __future__ import annotations -from typing import Any, Dict, List, Optional +import atexit +import threading +from typing import Any, Dict, List, Optional, Tuple -# Shared Pydantic models -from twinkle_client.types.training import Checkpoint, Cursor, TrainingRun -from .http import http_get, http_post, get_base_url +from twinkle_client.types.server import DeleteCheckpointResponse +from twinkle_client.types.session import (CreateSessionRequest, CreateSessionResponse, SessionHeartbeatRequest, + SessionHeartbeatResponse) +from twinkle_client.types.training import (Checkpoint, Cursor, ParsedCheckpointTwinklePath, TrainingRun, + TrainingRunsResponse, WeightsInfoResponse) +from .http import get_api_key, get_base_url, http_delete, http_get, http_post, set_api_key, set_base_url, set_session_id, clear_session_id class TwinkleClientError(Exception): @@ -17,28 +22,68 @@ class TwinkleClient: """ Client manager for interacting with Twinkle REST API. - This manager provides methods to: - - List training runs owned by the current user - - Get details of specific training runs - - List checkpoints for a training run - - Get checkpoint file paths for resume training - - Delete checkpoints - - All operations respect user permissions - users can only access - and modify their own resources. 
+ On initialization this client: + - Sets the base_url and api_key into the shared context so that all other + client objects (MultiLoraTransformersModel, vLLMSampler, processor clients) + automatically pick up the same configuration. + - Creates a server-side session and stores the session_id in context so that + every outgoing HTTP request carries it in the ``X-Twinkle-Session-Id`` header. + - Starts a lightweight background thread that touches the session every + ``session_heartbeat_interval`` seconds to keep it alive. Args: - base_url: Base URL of the Twinkle server (e.g., "http://localhost:8000"). - api_key: API key for authentication. If not provided, uses - TWINKLE_SERVER_TOKEN environment variable - route_prefix: API route prefix (default: "/twinkle") + base_url: Base URL of the Twinkle server (e.g. "http://localhost:8000"). + Falls back to the ``TWINKLE_SERVER_URL`` environment variable. + api_key: API key for authentication. Falls back to the + ``TWINKLE_SERVER_TOKEN`` environment variable. + route_prefix: API route prefix (default: "/twinkle"). + session_heartbeat_interval: Seconds between session touch calls (default: 30). + session_metadata: Optional metadata dict stored with the session on the server. """ - def __init__(self, base_url: str = None, api_key: str = None, route_prefix: str | None = '/twinkle'): - self.base_url = base_url or get_base_url() - self.api_key = api_key + def __init__( + self, + base_url: Optional[str] = None, + api_key: Optional[str] = None, + route_prefix: Optional[str] = '/twinkle', + session_heartbeat_interval: int = 30, + session_metadata: Optional[Dict[str, Any]] = None, + ): + # Resolve and store config, then propagate to context so all generated + # client objects that call get_base_url() / get_api_key() get these values. 
+ if base_url: + set_base_url(base_url) + if api_key: + set_api_key(api_key) + + self.base_url = get_base_url() + self.api_key = get_api_key() self.route_prefix = route_prefix.rstrip('/') if route_prefix else '' + # Create a server-side session. + resp = http_post( + self._get_url('/create_session'), + json_data=CreateSessionRequest(metadata=session_metadata).model_dump(), + ) + resp.raise_for_status() + self._session_id: str = CreateSessionResponse(**resp.json()).session_id + set_session_id(self._session_id) + + # Start background session-touch thread. + self._heartbeat_interval = session_heartbeat_interval + self._stop_event = threading.Event() + self._heartbeat_thread = threading.Thread( + target=self._touch_session_loop, + daemon=True, + name='TwinkleSessionHeartbeat', + ) + self._heartbeat_thread.start() + atexit.register(self.close) + + # ------------------------------------------------------------------ + # Internal helpers + # ------------------------------------------------------------------ + def _get_url(self, endpoint: str) -> str: """Construct full URL for an endpoint.""" return f'{self.base_url}{self.route_prefix}{endpoint}' @@ -54,14 +99,36 @@ def _handle_response(self, response, expected_code: int = 200) -> dict[str, Any] raise TwinkleClientError(f'Request failed with status {response.status_code}: {detail}') return response.json() - # ----- Health Check ----- + def _touch_session_loop(self) -> None: + """Background loop: touch the session every N seconds.""" + while not self._stop_event.wait(timeout=self._heartbeat_interval): + try: + resp = http_post( + self._get_url('/session_heartbeat'), + json_data=SessionHeartbeatRequest(session_id=self._session_id).model_dump(), + ) + resp.raise_for_status() + except Exception as e: + # Do not crash the background thread on transient errors. 
+ print(f'[TwinkleClient] Session heartbeat error: {e}') + + def close(self) -> None: + """Stop the background heartbeat thread and clear session context.""" + self._stop_event.set() + if self._heartbeat_thread.is_alive(): + self._heartbeat_thread.join(timeout=2) + clear_session_id() + + # ------------------------------------------------------------------ + # Health Check + # ------------------------------------------------------------------ def health_check(self) -> bool: """ Check if the Twinkle server is healthy. Returns: - True if server is healthy, False otherwise + True if server is healthy, False otherwise. """ try: response = http_get(self._get_url('/healthz')) @@ -69,66 +136,64 @@ def health_check(self) -> bool: except Exception: return False - # ----- Training Runs ----- + # ------------------------------------------------------------------ + # Training Runs + # ------------------------------------------------------------------ - def list_training_runs(self, limit: int = 20, offset: int = 0, all_users: bool = False) -> list[TrainingRun]: + def list_training_runs(self, limit: int = 20, offset: int = 0, all_users: bool = False) -> List[TrainingRun]: """ List training runs. By default, only returns training runs owned by the current user. Args: - limit: Maximum number of results (default: 20) - offset: Offset for pagination (default: 0) - all_users: If True, return all runs (if permission allows) + limit: Maximum number of results (default: 20). + offset: Offset for pagination (default: 0). + all_users: If True, return all runs (if permission allows). Returns: - List of TrainingRun objects + List of :class:`~twinkle_client.types.training.TrainingRun` objects. Raises: - TwinkleManagerError: If the request fails + TwinkleClientError: If the request fails. 
""" - params = {'limit': limit, 'offset': offset} + params: Dict[str, Any] = {'limit': limit, 'offset': offset} if all_users: params['all_users'] = 'true' response = http_get(self._get_url('/training_runs'), params=params) data = self._handle_response(response) - runs = [] - for run_data in data.get('training_runs', []): - runs.append(TrainingRun(**run_data)) - return runs + return [TrainingRun(**r) for r in data.get('training_runs', [])] - def list_training_runs_with_cursor(self, - limit: int = 20, - offset: int = 0, - all_users: bool = False) -> tuple[list[TrainingRun], Cursor]: + def list_training_runs_with_cursor( + self, + limit: int = 20, + offset: int = 0, + all_users: bool = False, + ) -> Tuple[List[TrainingRun], Cursor]: """ List training runs with pagination info. Args: - limit: Maximum number of results (default: 20) - offset: Offset for pagination (default: 0) - all_users: If True, return all runs (if permission allows) + limit: Maximum number of results (default: 20). + offset: Offset for pagination (default: 0). + all_users: If True, return all runs (if permission allows). Returns: - Tuple of (list of TrainingRun, Cursor with pagination info) + Tuple of (list of TrainingRun, Cursor with pagination info). Raises: - TwinkleManagerError: If the request fails + TwinkleClientError: If the request fails. """ - params = {'limit': limit, 'offset': offset} + params: Dict[str, Any] = {'limit': limit, 'offset': offset} if all_users: params['all_users'] = 'true' response = http_get(self._get_url('/training_runs'), params=params) data = self._handle_response(response) - runs = [] - for run_data in data.get('training_runs', []): - runs.append(TrainingRun(**run_data)) - + runs = [TrainingRun(**r) for r in data.get('training_runs', [])] cursor = Cursor(**data.get('cursor', {})) return runs, cursor @@ -137,158 +202,156 @@ def get_training_run(self, run_id: str) -> TrainingRun: Get details of a specific training run. 
Args: - run_id: The training run identifier + run_id: The training run identifier. Returns: - TrainingRun object with run details + :class:`~twinkle_client.types.training.TrainingRun` object with run details. Raises: - TwinkleManagerError: If run not found or access denied + TwinkleClientError: If run not found or access denied. """ response = http_get(self._get_url(f'/training_runs/{run_id}')) data = self._handle_response(response) return TrainingRun(**data) - # ----- Checkpoints ----- + # ------------------------------------------------------------------ + # Checkpoints + # ------------------------------------------------------------------ - def list_checkpoints(self, run_id: str) -> list[Checkpoint]: + def list_checkpoints(self, run_id: str) -> List[Checkpoint]: """ List checkpoints for a training run. Args: - run_id: The training run identifier + run_id: The training run identifier. Returns: - List of Checkpoint objects + List of :class:`~twinkle_client.types.training.Checkpoint` objects. Raises: - TwinkleManagerError: If run not found or access denied + TwinkleClientError: If run not found or access denied. """ response = http_get(self._get_url(f'/training_runs/{run_id}/checkpoints')) data = self._handle_response(response) + return [Checkpoint(**c) for c in data.get('checkpoints', [])] - checkpoints = [] - for ckpt_data in data.get('checkpoints', []): - checkpoints.append(Checkpoint(**ckpt_data)) - return checkpoints - - def get_checkpoint_path(self, run_id: str, checkpoint_id: str) -> str: + def get_checkpoint_path(self, run_id: str, checkpoint_id: str) -> ParsedCheckpointTwinklePath: """ - Get the filesystem path for a checkpoint. - - This path can be used to load weights for resume training. + Get the filesystem path and twinkle:// path for a checkpoint. Args: - run_id: The training run identifier - checkpoint_id: The checkpoint identifier (e.g., "weights/20240101_120000") + run_id: The training run identifier. 
+ checkpoint_id: The checkpoint identifier (e.g. "weights/20240101_120000"). Returns: - Filesystem path to the checkpoint directory + :class:`~twinkle_client.types.training.ParsedCheckpointTwinklePath` with + ``path`` (filesystem) and ``twinkle_path`` fields. Raises: - TwinkleManagerError: If checkpoint not found or access denied + TwinkleClientError: If checkpoint not found or access denied. """ response = http_get(self._get_url(f'/checkpoint_path/{run_id}/{checkpoint_id}')) data = self._handle_response(response) - return data.get('path', '') + return ParsedCheckpointTwinklePath( + path=data.get('path', ''), + twinkle_path=data.get('twinkle_path', ''), + training_run_id=run_id, + checkpoint_type=checkpoint_id.split('/')[0] if '/' in checkpoint_id else '', + checkpoint_id=checkpoint_id, + ) def get_checkpoint_twinkle_path(self, run_id: str, checkpoint_id: str) -> str: """ Get the twinkle:// path for a checkpoint. Args: - run_id: The training run identifier - checkpoint_id: The checkpoint identifier + run_id: The training run identifier. + checkpoint_id: The checkpoint identifier. Returns: - Twinkle path (e.g., "twinkle://run_id/weights/checkpoint_name") + Twinkle path string (e.g. "twinkle://run_id/weights/checkpoint_name"). Raises: - TwinkleManagerError: If checkpoint not found or access denied + TwinkleClientError: If checkpoint not found or access denied. """ - response = http_get(self._get_url(f'/checkpoint_path/{run_id}/{checkpoint_id}')) - data = self._handle_response(response) - return data.get('twinkle_path', '') + return self.get_checkpoint_path(run_id, checkpoint_id).twinkle_path - def delete_checkpoint(self, run_id: str, checkpoint_id: str) -> bool: + def delete_checkpoint(self, run_id: str, checkpoint_id: str) -> DeleteCheckpointResponse: """ Delete a checkpoint. Args: - run_id: The training run identifier - checkpoint_id: The checkpoint identifier + run_id: The training run identifier. + checkpoint_id: The checkpoint identifier. 
Returns: - True if deletion was successful + :class:`~twinkle_client.types.server.DeleteCheckpointResponse` indicating success. Raises: - TwinkleManagerError: If checkpoint not found or access denied + TwinkleClientError: If checkpoint not found or access denied. """ - from .http import http_delete - url = self._get_url(f'/training_runs/{run_id}/checkpoints/{checkpoint_id}') response = http_delete(url) data = self._handle_response(response) - return data.get('success', False) + return DeleteCheckpointResponse(**data) - # ----- Weights Info ----- + # ------------------------------------------------------------------ + # Weights Info + # ------------------------------------------------------------------ - def get_weights_info(self, twinkle_path: str) -> dict[str, Any]: + def get_weights_info(self, twinkle_path: str) -> WeightsInfoResponse: """ Get information about saved weights. Args: - twinkle_path: The twinkle:// path to the weights + twinkle_path: The twinkle:// path to the weights. Returns: - Dictionary with weight information including: - - training_run_id - - base_model - - model_owner - - is_lora - - lora_rank + :class:`~twinkle_client.types.training.WeightsInfoResponse` with fields: + ``training_run_id``, ``base_model``, ``model_owner``, ``is_lora``, ``lora_rank``. Raises: - TwinkleManagerError: If weights not found or access denied + TwinkleClientError: If weights not found or access denied. 
""" response = http_post(self._get_url('/weights_info'), json_data={'twinkle_path': twinkle_path}) - return self._handle_response(response) + data = self._handle_response(response) + return WeightsInfoResponse(**data) - # ----- Convenience Methods for Resume Training ----- + # ------------------------------------------------------------------ + # Convenience Methods + # ------------------------------------------------------------------ - def get_latest_checkpoint_path(self, run_id: str) -> str | None: + def get_latest_checkpoint_path(self, run_id: str) -> Optional[str]: """ - Get the path to the latest checkpoint for a training run. + Get the filesystem path to the latest checkpoint for a training run. - This is useful for resume training - it returns the path to the - most recent checkpoint that can be loaded. + Useful for resume training — returns the path to the most recent checkpoint. Args: - run_id: The training run identifier + run_id: The training run identifier. Returns: - Filesystem path to the latest checkpoint, or None if no checkpoints exist + Filesystem path string to the latest checkpoint, or ``None`` if none exist. Raises: - TwinkleManagerError: If run not found or access denied + TwinkleClientError: If run not found or access denied. """ checkpoints = self.list_checkpoints(run_id) if not checkpoints: return None - - # Checkpoints are sorted by time, so last one is the latest latest = checkpoints[-1] - return self.get_checkpoint_path(run_id, latest.checkpoint_id) + return self.get_checkpoint_path(run_id, latest.checkpoint_id).path - def find_training_run_by_model(self, base_model: str) -> list[TrainingRun]: + def find_training_run_by_model(self, base_model: str) -> List[TrainingRun]: """ Find training runs for a specific base model. Args: - base_model: The base model name to search for + base_model: The base model name to search for. 
Returns: - List of TrainingRun objects matching the base model + List of :class:`~twinkle_client.types.training.TrainingRun` objects + matching the base model. """ all_runs = self.list_training_runs(limit=100) return [run for run in all_runs if run.base_model == base_model] diff --git a/src/twinkle_client/model/multi_lora_transformers.py b/src/twinkle_client/model/multi_lora_transformers.py index bb91ad3a..992cce64 100644 --- a/src/twinkle_client/model/multi_lora_transformers.py +++ b/src/twinkle_client/model/multi_lora_transformers.py @@ -8,18 +8,36 @@ # 1. Modify the source files in src/twinkle/ # 2. Run: python client_tools/client_generator.py # ============================================================================ -from typing import Any, Optional, Union, Type, Dict, Literal, List -import uuid -from twinkle_client.http import http_post, heartbeat_manager -from twinkle import DeviceMesh -from twinkle.data_format import InputFeature, Trajectory +from typing import Any, Dict, Optional +from twinkle_client.http import http_post +from twinkle_client.types.model import ( + BackwardResponse, + CalculateLossResponse, + CalculateMetricResponse, + ClipGradNormResponse, + ForwardBackwardResponse, + ForwardResponse, + GetStateDictResponse, + GetTrainConfigsResponse, + LoadResponse, + LrStepResponse, + SaveResponse, + SetLossResponse, + SetLrSchedulerResponse, + SetOptimizerResponse, + SetProcessorResponse, + SetTemplateResponse, + StepResponse, + UploadToHubResponse, + ZeroGradResponse, +) class MultiLoraTransformersModel: """Client wrapper for TwinkleModel that calls server HTTP endpoints. This client manages adapters and sends training/inference requests to the model server. - Each adapter has its own lifecycle managed through automatic heartbeats. + The server-side session (managed by TwinkleClient) keeps the model alive. 
""" def __init__(self, model_id: str, **kwargs): @@ -37,208 +55,193 @@ def __init__(self, model_id: str, **kwargs): ) response.raise_for_status() - def _send_adapter_heartbeat(self): - """Internal method to send adapter heartbeat.""" - response = http_post( - url=f'{self.server_url}/heartbeat', - json_data={'adapter_name': self.adapter_name} - ) - response.raise_for_status() - - def add_adapter_to_model(self, adapter_name: str, config: Dict[str, Any], **kwargs): - """Add a new adapter to the model and start automatic heartbeat.""" + def add_adapter_to_model(self, adapter_name: str, config: Dict[str, Any], **kwargs) -> None: + """Add a new adapter to the model.""" response = http_post( url=f'{self.server_url}/add_adapter_to_model', json_data={'adapter_name': adapter_name, 'config': config, **kwargs} ) response.raise_for_status() - - # Register adapter for automatic heartbeat after successful creation self.adapter_name = adapter_name - heartbeat_manager.register_adapter( - self.adapter_name, - self._send_adapter_heartbeat - ) - def __del__(self): - """Cleanup: unregister adapter from heartbeat manager.""" - try: - heartbeat_manager.unregister_adapter(self.adapter_name) - except: - pass - - def forward(self, inputs: Any, **kwargs): + def forward(self, inputs: Any, **kwargs) -> ForwardResponse: """Execute forward pass on the model.""" response = http_post( url=f'{self.server_url}/forward', json_data={'inputs': inputs, 'adapter_name': self.adapter_name, **kwargs} ) response.raise_for_status() - return response.json()['result'] + return ForwardResponse(**response.json()) - def forward_only(self, inputs: Any, **kwargs): + def forward_only(self, inputs: Any, **kwargs) -> ForwardResponse: """Execute forward pass without gradient computation.""" response = http_post( url=f'{self.server_url}/forward_only', json_data={'inputs': inputs, 'adapter_name': self.adapter_name, **kwargs} ) response.raise_for_status() - return response.json()['result'] + return 
ForwardResponse(**response.json()) - def calculate_loss(self, **kwargs): + def calculate_loss(self, **kwargs) -> CalculateLossResponse: """Calculate loss from model outputs.""" response = http_post( url=f'{self.server_url}/calculate_loss', json_data={'adapter_name': self.adapter_name, **kwargs} ) response.raise_for_status() - return response.json()['result'] + return CalculateLossResponse(**response.json()) - def get_train_configs(self, **kwargs): - """Get training configs""" + def get_train_configs(self, **kwargs) -> GetTrainConfigsResponse: + """Get training configs.""" response = http_post( url=f'{self.server_url}/get_train_configs', json_data={'adapter_name': self.adapter_name, **kwargs} ) response.raise_for_status() - return response.json()['result'] + return GetTrainConfigsResponse(**response.json()) - def backward(self, **kwargs): + def backward(self, **kwargs) -> BackwardResponse: """Execute backward pass.""" response = http_post( url=f'{self.server_url}/backward', json_data={'adapter_name': self.adapter_name, **kwargs} ) response.raise_for_status() - return response.json()['result'] + return BackwardResponse(**response.json()) - def forward_backward(self, inputs: Any, **kwargs): + def forward_backward(self, inputs: Any, **kwargs) -> ForwardBackwardResponse: """Execute combined forward and backward pass.""" response = http_post( url=f'{self.server_url}/forward_backward', json_data={'inputs': inputs, 'adapter_name': self.adapter_name, **kwargs} ) response.raise_for_status() - return response.json()['result'] + return ForwardBackwardResponse(**response.json()) - def step(self, **kwargs): + def step(self, **kwargs) -> StepResponse: """Execute optimizer step.""" response = http_post( url=f'{self.server_url}/step', json_data={'adapter_name': self.adapter_name, **kwargs} ) response.raise_for_status() - return response.json()['result'] + return StepResponse(**response.json()) - def zero_grad(self, **kwargs): + def zero_grad(self, **kwargs) -> ZeroGradResponse: 
"""Zero out gradients.""" response = http_post( url=f'{self.server_url}/zero_grad', json_data={'adapter_name': self.adapter_name, **kwargs} ) response.raise_for_status() - return response.json()['result'] + return ZeroGradResponse(**response.json()) - def lr_step(self, **kwargs): + def lr_step(self, **kwargs) -> LrStepResponse: """Execute learning rate scheduler step.""" response = http_post( url=f'{self.server_url}/lr_step', json_data={'adapter_name': self.adapter_name, **kwargs} ) response.raise_for_status() - return response.json()['result'] + return LrStepResponse(**response.json()) - def set_loss(self, loss_cls: str, **kwargs): + def set_loss(self, loss_cls: str, **kwargs) -> SetLossResponse: """Set the loss function.""" response = http_post( url=f'{self.server_url}/set_loss', json_data={'loss_cls': loss_cls, 'adapter_name': self.adapter_name, **kwargs} ) response.raise_for_status() - return response.json()['result'] + return SetLossResponse(**response.json()) - def clip_grad_norm(self, max_grad_norm: float=1.0, norm_type=2, **kwargs): - """Set the loss function.""" + def clip_grad_norm(self, max_grad_norm: float = 1.0, norm_type: int = 2, **kwargs) -> ClipGradNormResponse: + """Clip gradient norm.""" response = http_post( url=f'{self.server_url}/clip_grad_norm', json_data={'max_grad_norm': max_grad_norm, 'norm_type': norm_type, 'adapter_name': self.adapter_name, **kwargs} ) response.raise_for_status() - return response.json()['result'] + return ClipGradNormResponse(**response.json()) - def set_optimizer(self, optimizer_cls: str, **kwargs): + def set_optimizer(self, optimizer_cls: str, **kwargs) -> SetOptimizerResponse: """Set the optimizer.""" response = http_post( url=f'{self.server_url}/set_optimizer', json_data={'optimizer_cls': optimizer_cls, 'adapter_name': self.adapter_name, **kwargs} ) response.raise_for_status() - return response.json()['result'] + return SetOptimizerResponse(**response.json()) - def set_lr_scheduler(self, scheduler_cls: str, 
**kwargs): + def set_lr_scheduler(self, scheduler_cls: str, **kwargs) -> SetLrSchedulerResponse: """Set the learning rate scheduler.""" response = http_post( url=f'{self.server_url}/set_lr_scheduler', json_data={'scheduler_cls': scheduler_cls, 'adapter_name': self.adapter_name, **kwargs} ) response.raise_for_status() - return response.json()['result'] + return SetLrSchedulerResponse(**response.json()) - def save(self, name: str, **kwargs): + def save(self, name: str, **kwargs) -> SaveResponse: """Save model checkpoint.""" response = http_post( url=f'{self.server_url}/save', json_data={'name': name, 'adapter_name': self.adapter_name, **kwargs} ) response.raise_for_status() - return response.json()['result'] + return SaveResponse(**response.json()) - def load(self, name: str, **kwargs): + def load(self, name: str, **kwargs) -> LoadResponse: """Load model checkpoint.""" response = http_post( url=f'{self.server_url}/load', json_data={'name': name, 'adapter_name': self.adapter_name, **kwargs} ) response.raise_for_status() - return response.json()['result'] + return LoadResponse(**response.json()) - def set_template(self, template_cls: str, **kwargs): + def set_template(self, template_cls: str, **kwargs) -> SetTemplateResponse: """Set the template for data processing.""" response = http_post( url=f'{self.server_url}/set_template', json_data={'template_cls': template_cls, 'adapter_name': self.adapter_name, 'model_id': self.model_id, **kwargs} ) response.raise_for_status() - return response.json()['result'] + return SetTemplateResponse(**response.json()) - def set_processor(self, processor_cls: str, **kwargs): + def set_processor(self, processor_cls: str, **kwargs) -> SetProcessorResponse: """Set the input processor.""" response = http_post( url=f'{self.server_url}/set_processor', json_data={'processor_cls': processor_cls, 'adapter_name': self.adapter_name, **kwargs} ) response.raise_for_status() - return response.json()['result'] + return 
SetProcessorResponse(**response.json()) - def calculate_metric(self, is_training: bool = True, **kwargs): + def calculate_metric(self, is_training: bool = True, **kwargs) -> CalculateMetricResponse: """Calculate metrics from model outputs.""" response = http_post( url=f'{self.server_url}/calculate_metric', json_data={'is_training': is_training, 'adapter_name': self.adapter_name, **kwargs} ) response.raise_for_status() - return response.json()['result'] + return CalculateMetricResponse(**response.json()) - def get_state_dict(self, **kwargs): + def get_state_dict(self, **kwargs) -> GetStateDictResponse: """Get model state dictionary.""" response = http_post( url=f'{self.server_url}/get_state_dict', json_data={'adapter_name': self.adapter_name, **kwargs} ) response.raise_for_status() - return response.json()['result'] + return GetStateDictResponse(**response.json()) - def upload_to_hub(self, checkpoint_dir: str, hub_model_id: str, hub_token: Optional[str] = None, async_upload: bool = True): + def upload_to_hub( + self, + checkpoint_dir: str, + hub_model_id: str, + hub_token: Optional[str] = None, + async_upload: bool = True, + ) -> UploadToHubResponse: """Upload model checkpoint to hub. 
Args: @@ -253,8 +256,8 @@ def upload_to_hub(self, checkpoint_dir: str, hub_model_id: str, hub_token: Optio 'checkpoint_dir': checkpoint_dir, 'hub_model_id': hub_model_id, 'hub_token': hub_token, - 'async_upload': async_upload + 'async_upload': async_upload, } ) response.raise_for_status() - return response.json() + return UploadToHubResponse(**response.json()) diff --git a/src/twinkle_client/processor/base.py b/src/twinkle_client/processor/base.py index 63aaa9c4..048ace5e 100644 --- a/src/twinkle_client/processor/base.py +++ b/src/twinkle_client/processor/base.py @@ -10,7 +10,7 @@ # ============================================================================ from typing import List, Literal, Optional, Union -from twinkle_client.http import http_post, heartbeat_manager +from twinkle_client.http import http_post from twinkle import DeviceMesh from twinkle.data_format import InputFeature @@ -31,13 +31,6 @@ def __init__(self, device_mesh: Optional[DeviceMesh] = None, padding_free: bool ) response.raise_for_status() self.processor_id = response.json()['processor_id'] - heartbeat_manager.register_processor(self.processor_id) - - def __del__(self): - try: - heartbeat_manager.unregister_processor(self.processor_id) - except: - pass def __call__(self, inputs: Union[InputFeature, List[InputFeature]], **kwargs): diff --git a/src/twinkle_client/sampler/vllm_sampler.py b/src/twinkle_client/sampler/vllm_sampler.py index 3c703569..a19984c3 100644 --- a/src/twinkle_client/sampler/vllm_sampler.py +++ b/src/twinkle_client/sampler/vllm_sampler.py @@ -8,9 +8,10 @@ # 1. Modify the source files in src/twinkle/ # 2. 
Run: python client_tools/client_generator.py # ============================================================================ -from typing import Any, Optional, List, Dict, Union -from twinkle_client.http import http_post, heartbeat_manager +from typing import Any, Dict, List, Optional, Union +from twinkle_client.http import http_post from twinkle.sampler.base import Sampler +from twinkle_client.types.sampler import AddAdapterResponse, SampleResponseModel, SetTemplateResponse from peft import PeftConfig from twinkle.data_format import Trajectory, InputFeature @@ -19,7 +20,7 @@ class vLLMSampler(Sampler): """Client wrapper for Sampler that calls server HTTP endpoints. This client manages sampling operations and adapter synchronization with the sampler server. - Each adapter has its own lifecycle managed through automatic heartbeats. + The server-side session (managed by TwinkleClient) keeps the sampler alive. """ def __init__(self, model_id: str, **kwargs): @@ -37,18 +38,8 @@ def __init__(self, model_id: str, **kwargs): ) response.raise_for_status() - def _send_adapter_heartbeat(self): - """Internal method to send adapter heartbeat.""" - if not self.adapter_name: - return - response = http_post( - url=f'{self.server_url}/heartbeat', - json_data={'adapter_name': self.adapter_name} - ) - response.raise_for_status() - - def add_adapter_to_sampler(self, adapter_name: str, config: PeftConfig, **kwargs): - """Add a new adapter to the sampler and start automatic heartbeat.""" + def add_adapter_to_sampler(self, adapter_name: str, config: PeftConfig, **kwargs) -> AddAdapterResponse: + """Add a new adapter to the sampler.""" if isinstance(config, PeftConfig): config = config.__dict__ response = http_post( @@ -56,23 +47,8 @@ def add_adapter_to_sampler(self, adapter_name: str, config: PeftConfig, **kwargs json_data={'adapter_name': adapter_name, 'config': config, **kwargs} ) response.raise_for_status() - - # Register adapter for automatic heartbeat after successful creation 
self.adapter_name = adapter_name - heartbeat_manager.register_adapter( - self.adapter_name, - self._send_adapter_heartbeat - ) - - return response.json() - - def __del__(self): - """Cleanup: unregister adapter from heartbeat manager.""" - try: - if self.adapter_name: - heartbeat_manager.unregister_adapter(self.adapter_name) - except: - pass + return AddAdapterResponse(**response.json()) def sample( self, @@ -81,7 +57,7 @@ def sample( adapter_name: str = '', adapter_uri: Optional[str] = None, num_samples: int = 1, - ) -> Dict[str, Any]: + ) -> SampleResponseModel: """Sample from the model. Args: @@ -92,7 +68,7 @@ def sample( num_samples: Number of completions to generate per prompt. Returns: - Dict with 'sequences' list, each containing tokens, logprobs, stop_reason. + SampleResponseModel with 'sequences' list, each containing tokens, logprobs, stop_reason. """ json_data = { 'inputs': inputs, @@ -108,13 +84,13 @@ def sample( json_data=json_data ) response.raise_for_status() - return response.json() + return SampleResponseModel(**response.json()) - def set_template(self, template_cls: str, adapter_name: str = '', **kwargs): + def set_template(self, template_cls: str, adapter_name: str = '', **kwargs) -> SetTemplateResponse: """Set the template for encoding trajectories.""" response = http_post( url=f'{self.server_url}/set_template', json_data={'template_cls': template_cls, 'adapter_name': adapter_name, **kwargs} ) response.raise_for_status() - return response.json() + return SetTemplateResponse(**response.json()) diff --git a/src/twinkle_client/types/__init__.py b/src/twinkle_client/types/__init__.py index 85b3e739..c72ca1dd 100644 --- a/src/twinkle_client/types/__init__.py +++ b/src/twinkle_client/types/__init__.py @@ -1 +1,37 @@ # Copyright (c) ModelScope Contributors. All rights reserved. 
+from .model import ( + BackwardResponse, + CalculateLossResponse, + CalculateMetricResponse, + ClipGradNormResponse, + ForwardBackwardResponse, + ForwardResponse, + GetStateDictResponse, + GetTrainConfigsResponse, + LoadResponse, + LrStepResponse, + ModelResult, + SaveResponse, + SetLossResponse, + SetLrSchedulerResponse, + SetOptimizerResponse, + SetProcessorResponse, + SetTemplateResponse, + StepResponse, + UploadToHubResponse, + ZeroGradResponse, +) +from .sampler import AddAdapterResponse, SampleResponseModel, SetTemplateResponse as SamplerSetTemplateResponse +from .server import DeleteCheckpointResponse, ErrorResponse, HealthResponse, WeightsInfoRequest +from .session import CreateSessionRequest, CreateSessionResponse, SessionHeartbeatRequest, SessionHeartbeatResponse +from .training import ( + Checkpoint, + CheckpointsListResponse, + CreateModelRequest, + Cursor, + LoraConfig, + ParsedCheckpointTwinklePath, + TrainingRun, + TrainingRunsResponse, + WeightsInfoResponse, +) diff --git a/src/twinkle_client/types/model.py b/src/twinkle_client/types/model.py index b3b9b6c4..7ae2e923 100644 --- a/src/twinkle_client/types/model.py +++ b/src/twinkle_client/types/model.py @@ -5,7 +5,7 @@ These models are used by both the server-side handler and the twinkle client. 
""" from pydantic import BaseModel -from typing import Any, Optional +from typing import Any, Dict, List, Optional class CreateRequest(BaseModel): @@ -130,3 +130,112 @@ class GetStateDictRequest(BaseModel): class Config: extra = 'allow' + + +# --------------------------------------------------------------------------- +# Response models +# --------------------------------------------------------------------------- + + +class ModelResult(BaseModel): + """Generic single-value result wrapper returned by most model endpoints.""" + result: Any + + +class ForwardResponse(ModelResult): + """Response for /forward and /forward_only endpoints.""" + pass + + +class ForwardBackwardResponse(ModelResult): + """Response for /forward_backward endpoint.""" + pass + + +class BackwardResponse(ModelResult): + """Response for /backward endpoint.""" + pass + + +class StepResponse(ModelResult): + """Response for /step (optimizer step) endpoint.""" + pass + + +class ZeroGradResponse(ModelResult): + """Response for /zero_grad endpoint.""" + pass + + +class LrStepResponse(ModelResult): + """Response for /lr_step endpoint.""" + pass + + +class SetLossResponse(ModelResult): + """Response for /set_loss endpoint.""" + pass + + +class ClipGradNormResponse(ModelResult): + """Response for /clip_grad_norm endpoint.""" + pass + + +class SetOptimizerResponse(ModelResult): + """Response for /set_optimizer endpoint.""" + pass + + +class SetLrSchedulerResponse(ModelResult): + """Response for /set_lr_scheduler endpoint.""" + pass + + +class SaveResponse(ModelResult): + """Response for /save endpoint.""" + pass + + +class LoadResponse(ModelResult): + """Response for /load endpoint.""" + pass + + +class SetTemplateResponse(ModelResult): + """Response for /set_template endpoint.""" + pass + + +class SetProcessorResponse(ModelResult): + """Response for /set_processor endpoint.""" + pass + + +class CalculateLossResponse(ModelResult): + """Response for /calculate_loss endpoint.""" + pass + + +class 
CalculateMetricResponse(ModelResult): + """Response for /calculate_metric endpoint.""" + pass + + +class GetTrainConfigsResponse(ModelResult): + """Response for /get_train_configs endpoint.""" + pass + + +class GetStateDictResponse(ModelResult): + """Response for /get_state_dict endpoint.""" + pass + + +class UploadToHubResponse(BaseModel): + """Response for /upload_to_hub endpoint.""" + status: Optional[str] = None + message: Optional[str] = None + + class Config: + extra = 'allow' diff --git a/src/twinkle_client/types/session.py b/src/twinkle_client/types/session.py new file mode 100644 index 00000000..f6b1adb7 --- /dev/null +++ b/src/twinkle_client/types/session.py @@ -0,0 +1,24 @@ +# Copyright (c) ModelScope Contributors. All rights reserved. +"""Pydantic models for twinkle session management endpoints.""" +from pydantic import BaseModel +from typing import Any, Dict, Optional + + +class CreateSessionRequest(BaseModel): + """Request body for POST /twinkle/create_session.""" + metadata: Optional[Dict[str, Any]] = None + + +class CreateSessionResponse(BaseModel): + """Response body for POST /twinkle/create_session.""" + session_id: str + + +class SessionHeartbeatRequest(BaseModel): + """Request body for POST /twinkle/session_heartbeat.""" + session_id: str + + +class SessionHeartbeatResponse(BaseModel): + """Response body for POST /twinkle/session_heartbeat.""" + pass From 1a51335709c29ec7e5d6caa222b6a24a289ad780 Mon Sep 17 00:00:00 2001 From: Yunnglin Date: Wed, 11 Mar 2026 20:57:31 +0800 Subject: [PATCH 10/24] update twinkle --- .../gateway/twinkle_gateway_handlers.py | 56 +++++----- src/twinkle/server/model/twinkle_handlers.py | 104 +++++++++--------- src/twinkle/server/processor/app.py | 24 ++-- .../server/sampler/twinkle_handlers.py | 42 ++++--- src/twinkle_client/types/__init__.py | 48 +++++++- src/twinkle_client/types/model.py | 16 +++ src/twinkle_client/types/processor.py | 16 +++ src/twinkle_client/types/server.py | 12 ++ 8 files changed, 198 
insertions(+), 120 deletions(-) diff --git a/src/twinkle/server/gateway/twinkle_gateway_handlers.py b/src/twinkle/server/gateway/twinkle_gateway_handlers.py index 3d6ca314..8a159be3 100644 --- a/src/twinkle/server/gateway/twinkle_gateway_handlers.py +++ b/src/twinkle/server/gateway/twinkle_gateway_handlers.py @@ -16,11 +16,7 @@ from twinkle.server.utils.checkpoint_base import validate_user_path from twinkle.server.utils.validation import get_token_from_request from twinkle.utils.logger import get_logger -from twinkle_client.types.server import DeleteCheckpointResponse, HealthResponse, WeightsInfoRequest -from twinkle_client.types.session import (CreateSessionRequest, CreateSessionResponse, SessionHeartbeatRequest, - SessionHeartbeatResponse) -from twinkle_client.types.training import (CheckpointsListResponse, TrainingRun, TrainingRunsResponse, - WeightsInfoResponse) +import twinkle_client.types as types logger = get_logger() @@ -28,38 +24,38 @@ def _register_twinkle_routes(app: FastAPI, self_fn: Callable[[], GatewayServer]) -> None: """Register all /twinkle/* routes on the given FastAPI app.""" - @app.get('/twinkle/healthz', response_model=HealthResponse) - async def healthz(request: Request) -> HealthResponse: - return HealthResponse(status='ok') + @app.get('/twinkle/healthz', response_model=types.HealthResponse) + async def healthz(request: Request) -> types.HealthResponse: + return types.HealthResponse(status='ok') - @app.post('/twinkle/create_session', response_model=CreateSessionResponse) + @app.post('/twinkle/create_session', response_model=types.CreateSessionResponse) async def create_session( request: Request, - body: CreateSessionRequest, + body: types.CreateSessionRequest, self: GatewayServer = Depends(self_fn), - ) -> CreateSessionResponse: + ) -> types.CreateSessionResponse: session_id = self.state.create_session(body.model_dump()) - return CreateSessionResponse(session_id=session_id) + return types.CreateSessionResponse(session_id=session_id) - 
@app.post('/twinkle/session_heartbeat', response_model=SessionHeartbeatResponse) + @app.post('/twinkle/session_heartbeat', response_model=types.SessionHeartbeatResponse) async def session_heartbeat( request: Request, - body: SessionHeartbeatRequest, + body: types.SessionHeartbeatRequest, self: GatewayServer = Depends(self_fn), - ) -> SessionHeartbeatResponse: + ) -> types.SessionHeartbeatResponse: alive = self.state.touch_session(body.session_id) if not alive: raise HTTPException(status_code=404, detail='Unknown session') - return SessionHeartbeatResponse() + return types.SessionHeartbeatResponse() - @app.get('/twinkle/training_runs', response_model=TrainingRunsResponse) - async def get_training_runs(request: Request, limit: int = 20, offset: int = 0) -> TrainingRunsResponse: + @app.get('/twinkle/training_runs', response_model=types.TrainingRunsResponse) + async def get_training_runs(request: Request, limit: int = 20, offset: int = 0) -> types.TrainingRunsResponse: token = get_token_from_request(request) training_run_manager = create_training_run_manager(token, client_type='twinkle') return training_run_manager.list_runs(limit=limit, offset=offset) - @app.get('/twinkle/training_runs/{run_id}', response_model=TrainingRun) - async def get_training_run(request: Request, run_id: str) -> TrainingRun: + @app.get('/twinkle/training_runs/{run_id}', response_model=types.TrainingRun) + async def get_training_run(request: Request, run_id: str) -> types.TrainingRun: token = get_token_from_request(request) training_run_manager = create_training_run_manager(token, client_type='twinkle') run = training_run_manager.get_with_permission(run_id) @@ -67,8 +63,8 @@ async def get_training_run(request: Request, run_id: str) -> TrainingRun: raise HTTPException(status_code=404, detail=f'Training run {run_id} not found or access denied') return run - @app.get('/twinkle/training_runs/{run_id}/checkpoints', response_model=CheckpointsListResponse) - async def get_run_checkpoints(request: 
Request, run_id: str) -> CheckpointsListResponse: + @app.get('/twinkle/training_runs/{run_id}/checkpoints', response_model=types.CheckpointsListResponse) + async def get_run_checkpoints(request: Request, run_id: str) -> types.CheckpointsListResponse: token = get_token_from_request(request) checkpoint_manager = create_checkpoint_manager(token, client_type='twinkle') response = checkpoint_manager.list_checkpoints(run_id) @@ -76,8 +72,8 @@ async def get_run_checkpoints(request: Request, run_id: str) -> CheckpointsListR raise HTTPException(status_code=404, detail=f'Training run {run_id} not found or access denied') return response - @app.delete('/twinkle/training_runs/{run_id}/checkpoints/{checkpoint_id:path}') - async def delete_run_checkpoint(request: Request, run_id: str, checkpoint_id: str) -> DeleteCheckpointResponse: + @app.delete('/twinkle/training_runs/{run_id}/checkpoints/{checkpoint_id:path}', response_model=types.DeleteCheckpointResponse) + async def delete_run_checkpoint(request: Request, run_id: str, checkpoint_id: str) -> types.DeleteCheckpointResponse: token = get_token_from_request(request) if not validate_user_path(token, checkpoint_id): @@ -88,10 +84,10 @@ async def delete_run_checkpoint(request: Request, run_id: str, checkpoint_id: st if not success: raise HTTPException(status_code=404, detail=f'Checkpoint {checkpoint_id} not found or access denied') - return DeleteCheckpointResponse(success=True, message=f'Checkpoint {checkpoint_id} deleted successfully') + return types.DeleteCheckpointResponse(success=True, message=f'Checkpoint {checkpoint_id} deleted successfully') - @app.post('/twinkle/weights_info', response_model=WeightsInfoResponse) - async def weights_info(request: Request, body: WeightsInfoRequest) -> WeightsInfoResponse: + @app.post('/twinkle/weights_info', response_model=types.WeightsInfoResponse) + async def weights_info(request: Request, body: types.WeightsInfoRequest) -> types.WeightsInfoResponse: token = get_token_from_request(request) 
checkpoint_manager = create_checkpoint_manager(token, client_type='twinkle') response = checkpoint_manager.get_weights_info(body.twinkle_path) @@ -99,8 +95,8 @@ async def weights_info(request: Request, body: WeightsInfoRequest) -> WeightsInf raise HTTPException(status_code=404, detail=f'Weights at {body.twinkle_path} not found or access denied') return response - @app.get('/twinkle/checkpoint_path/{run_id}/{checkpoint_id:path}') - async def get_checkpoint_path(request: Request, run_id: str, checkpoint_id: str) -> dict[str, str]: + @app.get('/twinkle/checkpoint_path/{run_id}/{checkpoint_id:path}', response_model=types.CheckpointPathResponse) + async def get_checkpoint_path(request: Request, run_id: str, checkpoint_id: str) -> types.CheckpointPathResponse: token = get_token_from_request(request) if not validate_user_path(token, checkpoint_id): @@ -118,4 +114,4 @@ async def get_checkpoint_path(request: Request, run_id: str, checkpoint_id: str) raise HTTPException(status_code=404, detail=f'Checkpoint {checkpoint_id} not found') ckpt_dir = checkpoint_manager.get_ckpt_dir(run_id, checkpoint_id) - return {'path': str(ckpt_dir), 'twinkle_path': checkpoint.twinkle_path} + return types.CheckpointPathResponse(path=str(ckpt_dir), twinkle_path=checkpoint.twinkle_path) diff --git a/src/twinkle/server/model/twinkle_handlers.py b/src/twinkle/server/model/twinkle_handlers.py index 43dc8c73..4aad734a 100644 --- a/src/twinkle/server/model/twinkle_handlers.py +++ b/src/twinkle/server/model/twinkle_handlers.py @@ -20,11 +20,7 @@ from twinkle.server.common.checkpoint_factory import create_checkpoint_manager, create_training_run_manager from twinkle.server.common.serialize import deserialize_object from twinkle.utils.logger import get_logger -from twinkle_client.types.model import (AdapterRequest, AddAdapterRequest, CalculateMetricRequest, CreateRequest, - ForwardOnlyRequest, ForwardRequest, GetStateDictRequest, HeartbeatRequest, - LoadRequest, SaveRequest, SetLossRequest, 
SetLrSchedulerRequest, - SetOptimizerRequest, SetProcessorRequest, SetTemplateRequest, - UploadToHubRequest) +import twinkle_client.types as types logger = get_logger() @@ -59,12 +55,12 @@ def _register_twinkle_routes(app: FastAPI, self_fn: Callable[[], ModelManagement replica instance. It is wired in via Depends so it is resolved lazily at request time. """ - @app.post('/twinkle/create') - async def create(request: Request, body: CreateRequest, self: ModelManagement = Depends(self_fn)): - return {'status': 'ok'} + @app.post('/twinkle/create', response_model=types.CreateResponse) + async def create(request: Request, body: types.CreateRequest, self: ModelManagement = Depends(self_fn)) -> types.CreateResponse: + return types.CreateResponse() - @app.post('/twinkle/forward') - async def forward(request: Request, body: ForwardRequest, self: ModelManagement = Depends(self_fn)): + @app.post('/twinkle/forward', response_model=types.ForwardResponse) + async def forward(request: Request, body: types.ForwardRequest, self: ModelManagement = Depends(self_fn)) -> types.ForwardResponse: adapter_name = _get_twinkle_adapter_name(request, body.adapter_name) async def _task(): @@ -76,8 +72,8 @@ async def _task(): return await self.schedule_task_and_wait(_task, task_type='forward') - @app.post('/twinkle/forward_only') - async def forward_only(request: Request, body: ForwardOnlyRequest, self: ModelManagement = Depends(self_fn)): + @app.post('/twinkle/forward_only', response_model=types.ForwardResponse) + async def forward_only(request: Request, body: types.ForwardOnlyRequest, self: ModelManagement = Depends(self_fn)) -> types.ForwardResponse: adapter_name = _get_twinkle_adapter_name(request, body.adapter_name) async def _task(): @@ -89,8 +85,8 @@ async def _task(): return await self.schedule_task_and_wait(_task, task_type='forward_only') - @app.post('/twinkle/calculate_loss') - async def calculate_loss(request: Request, body: AdapterRequest, self: ModelManagement = Depends(self_fn)): + 
@app.post('/twinkle/calculate_loss', response_model=types.CalculateLossResponse) + async def calculate_loss(request: Request, body: types.AdapterRequest, self: ModelManagement = Depends(self_fn)) -> types.CalculateLossResponse: adapter_name = _get_twinkle_adapter_name(request, body.adapter_name) async def _task(): @@ -101,8 +97,8 @@ async def _task(): return await self.schedule_task_and_wait(_task, task_type='calculate_loss') - @app.post('/twinkle/backward') - async def backward(request: Request, body: AdapterRequest, self: ModelManagement = Depends(self_fn)): + @app.post('/twinkle/backward', response_model=types.BackwardResponse) + async def backward(request: Request, body: types.AdapterRequest, self: ModelManagement = Depends(self_fn)) -> types.BackwardResponse: adapter_name = _get_twinkle_adapter_name(request, body.adapter_name) async def _task(): @@ -113,8 +109,8 @@ async def _task(): return await self.schedule_task_and_wait(_task, task_type='backward') - @app.post('/twinkle/forward_backward') - async def forward_backward(request: Request, body: ForwardRequest, self: ModelManagement = Depends(self_fn)): + @app.post('/twinkle/forward_backward', response_model=types.ForwardBackwardResponse) + async def forward_backward(request: Request, body: types.ForwardRequest, self: ModelManagement = Depends(self_fn)) -> types.ForwardBackwardResponse: adapter_name = _get_twinkle_adapter_name(request, body.adapter_name) async def _task(): @@ -126,8 +122,8 @@ async def _task(): return await self.schedule_task_and_wait(_task, task_type='forward_backward') - @app.post('/twinkle/clip_grad_norm') - async def clip_grad_norm(request: Request, body: AdapterRequest, self: ModelManagement = Depends(self_fn)): + @app.post('/twinkle/clip_grad_norm', response_model=types.ClipGradNormResponse) + async def clip_grad_norm(request: Request, body: types.AdapterRequest, self: ModelManagement = Depends(self_fn)) -> types.ClipGradNormResponse: adapter_name = _get_twinkle_adapter_name(request, 
body.adapter_name) async def _task(): @@ -138,8 +134,8 @@ async def _task(): return await self.schedule_task_and_wait(_task, task_type='clip_grad_norm') - @app.post('/twinkle/step') - async def step(request: Request, body: AdapterRequest, self: ModelManagement = Depends(self_fn)): + @app.post('/twinkle/step', response_model=types.StepResponse) + async def step(request: Request, body: types.AdapterRequest, self: ModelManagement = Depends(self_fn)) -> types.StepResponse: adapter_name = _get_twinkle_adapter_name(request, body.adapter_name) async def _task(): @@ -150,8 +146,8 @@ async def _task(): return await self.schedule_task_and_wait(_task, task_type='step') - @app.post('/twinkle/zero_grad') - async def zero_grad(request: Request, body: AdapterRequest, self: ModelManagement = Depends(self_fn)): + @app.post('/twinkle/zero_grad', response_model=types.ZeroGradResponse) + async def zero_grad(request: Request, body: types.AdapterRequest, self: ModelManagement = Depends(self_fn)) -> types.ZeroGradResponse: adapter_name = _get_twinkle_adapter_name(request, body.adapter_name) async def _task(): @@ -162,8 +158,8 @@ async def _task(): return await self.schedule_task_and_wait(_task, task_type='zero_grad') - @app.post('/twinkle/lr_step') - async def lr_step(request: Request, body: AdapterRequest, self: ModelManagement = Depends(self_fn)): + @app.post('/twinkle/lr_step', response_model=types.LrStepResponse) + async def lr_step(request: Request, body: types.AdapterRequest, self: ModelManagement = Depends(self_fn)) -> types.LrStepResponse: adapter_name = _get_twinkle_adapter_name(request, body.adapter_name) async def _task(): @@ -174,8 +170,8 @@ async def _task(): return await self.schedule_task_and_wait(_task, task_type='lr_step') - @app.post('/twinkle/get_train_configs') - async def get_train_configs(request: Request, body: AdapterRequest, self: ModelManagement = Depends(self_fn)): + @app.post('/twinkle/get_train_configs', response_model=types.GetTrainConfigsResponse) + async 
def get_train_configs(request: Request, body: types.AdapterRequest, self: ModelManagement = Depends(self_fn)) -> types.GetTrainConfigsResponse: adapter_name = _get_twinkle_adapter_name(request, body.adapter_name) async def _task(): @@ -186,8 +182,8 @@ async def _task(): return await self.schedule_task_and_wait(_task, task_type='get_train_configs') - @app.post('/twinkle/set_loss') - async def set_loss(request: Request, body: SetLossRequest, self: ModelManagement = Depends(self_fn)): + @app.post('/twinkle/set_loss', response_model=types.SetLossResponse) + async def set_loss(request: Request, body: types.SetLossRequest, self: ModelManagement = Depends(self_fn)) -> types.SetLossResponse: adapter_name = _get_twinkle_adapter_name(request, body.adapter_name) async def _task(): @@ -198,8 +194,8 @@ async def _task(): return await self.schedule_task_and_wait(_task, task_type='set_loss') - @app.post('/twinkle/set_optimizer') - async def set_optimizer(request: Request, body: SetOptimizerRequest, self: ModelManagement = Depends(self_fn)): + @app.post('/twinkle/set_optimizer', response_model=types.SetOptimizerResponse) + async def set_optimizer(request: Request, body: types.SetOptimizerRequest, self: ModelManagement = Depends(self_fn)) -> types.SetOptimizerResponse: adapter_name = _get_twinkle_adapter_name(request, body.adapter_name) async def _task(): @@ -210,8 +206,8 @@ async def _task(): return await self.schedule_task_and_wait(_task, task_type='set_optimizer') - @app.post('/twinkle/set_lr_scheduler') - async def set_lr_scheduler(request: Request, body: SetLrSchedulerRequest, self: ModelManagement = Depends(self_fn)): + @app.post('/twinkle/set_lr_scheduler', response_model=types.SetLrSchedulerResponse) + async def set_lr_scheduler(request: Request, body: types.SetLrSchedulerRequest, self: ModelManagement = Depends(self_fn)) -> types.SetLrSchedulerResponse: adapter_name = _get_twinkle_adapter_name(request, body.adapter_name) async def _task(): @@ -222,8 +218,8 @@ async def 
_task(): return await self.schedule_task_and_wait(_task, task_type='set_lr_scheduler') - @app.post('/twinkle/save') - async def save(request: Request, body: SaveRequest, self: ModelManagement = Depends(self_fn)): + @app.post('/twinkle/save', response_model=types.SaveResponse) + async def save(request: Request, body: types.SaveRequest, self: ModelManagement = Depends(self_fn)) -> types.SaveResponse: token = await self._on_request_start(request) adapter_name = _get_twinkle_adapter_name(request, body.adapter_name) @@ -244,8 +240,8 @@ async def _task(): return await self.schedule_task_and_wait(_task, task_type='save') - @app.post('/twinkle/load') - async def load(request: Request, body: LoadRequest, self: ModelManagement = Depends(self_fn)): + @app.post('/twinkle/load', response_model=types.LoadResponse) + async def load(request: Request, body: types.LoadRequest, self: ModelManagement = Depends(self_fn)) -> types.LoadResponse: token = await self._on_request_start(request) adapter_name = _get_twinkle_adapter_name(request, body.adapter_name) @@ -265,8 +261,8 @@ async def _task(): return await self.schedule_task_and_wait(_task, task_type='load') - @app.post('/twinkle/upload_to_hub') - async def upload_to_hub(request: Request, body: UploadToHubRequest, self: ModelManagement = Depends(self_fn)): + @app.post('/twinkle/upload_to_hub', response_model=types.UploadToHubResponse) + async def upload_to_hub(request: Request, body: types.UploadToHubRequest, self: ModelManagement = Depends(self_fn)) -> types.UploadToHubResponse: token = await self._on_request_start(request) async def _task(): @@ -293,8 +289,8 @@ async def _task(): return await self.schedule_task_and_wait(_task, task_type='upload_to_hub') - @app.post('/twinkle/add_adapter_to_model') - async def add_adapter_to_model(request: Request, body: AddAdapterRequest, self: ModelManagement = Depends(self_fn)): + @app.post('/twinkle/add_adapter_to_model', response_model=types.AddAdapterResponse) + async def 
add_adapter_to_model(request: Request, body: types.AddAdapterRequest, self: ModelManagement = Depends(self_fn)) -> types.AddAdapterResponse: assert body.adapter_name, 'You need to specify a valid `adapter_name`' token = await self._on_request_start(request) adapter_name = _get_twinkle_adapter_name(request, body.adapter_name) @@ -317,8 +313,8 @@ async def _task(): return await self.schedule_task_and_wait(_task, task_type='add_adapter_to_model') - @app.post('/twinkle/set_template') - async def set_template(request: Request, body: SetTemplateRequest, self: ModelManagement = Depends(self_fn)): + @app.post('/twinkle/set_template', response_model=types.SetTemplateResponse) + async def set_template(request: Request, body: types.SetTemplateRequest, self: ModelManagement = Depends(self_fn)) -> types.SetTemplateResponse: adapter_name = _get_twinkle_adapter_name(request, body.adapter_name) async def _task(): @@ -329,8 +325,8 @@ async def _task(): return await self.schedule_task_and_wait(_task, task_type='set_template') - @app.post('/twinkle/set_processor') - async def set_processor(request: Request, body: SetProcessorRequest, self: ModelManagement = Depends(self_fn)): + @app.post('/twinkle/set_processor', response_model=types.SetProcessorResponse) + async def set_processor(request: Request, body: types.SetProcessorRequest, self: ModelManagement = Depends(self_fn)) -> types.SetProcessorResponse: adapter_name = _get_twinkle_adapter_name(request, body.adapter_name) async def _task(): @@ -341,16 +337,16 @@ async def _task(): return await self.schedule_task_and_wait(_task, task_type='set_processor') - @app.post('/twinkle/heartbeat') - async def heartbeat(request: Request, body: HeartbeatRequest, self: ModelManagement = Depends(self_fn)): + @app.post('/twinkle/heartbeat', response_model=types.HeartbeatResponse) + async def heartbeat(request: Request, body: types.HeartbeatRequest, self: ModelManagement = Depends(self_fn)) -> types.HeartbeatResponse: adapter_name = 
_get_twinkle_adapter_name(request, body.adapter_name) self.assert_adapter_exists(adapter_name=adapter_name) self.touch_adapter(adapter_name) - return {'status': 'ok'} + return types.HeartbeatResponse() - @app.post('/twinkle/calculate_metric') - async def calculate_metric(request: Request, body: CalculateMetricRequest, - self: ModelManagement = Depends(self_fn)): + @app.post('/twinkle/calculate_metric', response_model=types.CalculateMetricResponse) + async def calculate_metric(request: Request, body: types.CalculateMetricRequest, + self: ModelManagement = Depends(self_fn)) -> types.CalculateMetricResponse: adapter_name = _get_twinkle_adapter_name(request, body.adapter_name) async def _task(): @@ -361,8 +357,8 @@ async def _task(): return await self.schedule_task_and_wait(_task, task_type='calculate_metric') - @app.post('/twinkle/get_state_dict') - async def get_state_dict(request: Request, body: GetStateDictRequest, self: ModelManagement = Depends(self_fn)): + @app.post('/twinkle/get_state_dict', response_model=types.GetStateDictResponse) + async def get_state_dict(request: Request, body: types.GetStateDictRequest, self: ModelManagement = Depends(self_fn)) -> types.GetStateDictResponse: adapter_name = _get_twinkle_adapter_name(request, body.adapter_name) async def _task(): diff --git a/src/twinkle/server/processor/app.py b/src/twinkle/server/processor/app.py index fb55b453..14c79556 100644 --- a/src/twinkle/server/processor/app.py +++ b/src/twinkle/server/processor/app.py @@ -18,7 +18,7 @@ from twinkle.server.common.serialize import deserialize_object from twinkle.server.utils.state import ServerStateProxy, get_server_state from twinkle.server.utils.validation import verify_request_token -from twinkle_client.types.processor import ProcessorCallRequest, ProcessorCreateRequest, ProcessorHeartbeatRequest +import twinkle_client.types as types logger = get_logger() @@ -108,8 +108,8 @@ def handle_processor_count(self, token: str, add: bool): if cur_count <= 0: 
self.state.pop_config(user_key) - @app.post('/twinkle/create') - def create(self, request: Request, body: ProcessorCreateRequest): + @app.post('/twinkle/create', response_model=types.ProcessorCreateResponse) + def create(self, request: Request, body: types.ProcessorCreateRequest) -> types.ProcessorCreateResponse: processor_type_name = body.processor_type class_type = body.class_type _kwargs = body.model_extra or {} @@ -140,18 +140,18 @@ def create(self, request: Request, body: ProcessorCreateRequest): **resolved_kwargs) self.resource_dict[processor_id] = processor self.resource_records[processor_id] = 0 - return {'processor_id': 'pid:' + processor_id} + return types.ProcessorCreateResponse(processor_id='pid:' + processor_id) - @app.post('/twinkle/heartbeat') - def heartbeat(self, body: ProcessorHeartbeatRequest): + @app.post('/twinkle/heartbeat', response_model=types.ProcessorHeartbeatResponse) + def heartbeat(self, body: types.ProcessorHeartbeatRequest) -> types.ProcessorHeartbeatResponse: processor_ids = body.processor_id.split(',') for _id in processor_ids: if _id and _id in self.resource_dict: self.resource_records[_id] = 0 - return {'status': 'ok'} + return types.ProcessorHeartbeatResponse() - @app.post('/twinkle/call') - def call(self, body: ProcessorCallRequest): + @app.post('/twinkle/call', response_model=types.ProcessorCallResponse) + def call(self, body: types.ProcessorCallRequest) -> types.ProcessorCallResponse: processor_id = body.processor_id function_name = body.function _kwargs = body.model_extra or {} @@ -176,16 +176,16 @@ def call(self, body: ProcessorCallRequest): if function_name == '__next__': try: result = function(**resolved_kwargs) - return {'result': result} + return types.ProcessorCallResponse(result=result) except StopIteration: # HTTP 410 Gone signals iterator exhausted raise HTTPException(status_code=410, detail='Iterator exhausted') result = function(**resolved_kwargs) if function_name == '__iter__': - return {'result': 'ok'} + return 
types.ProcessorCallResponse(result='ok') else: - return {'result': result} + return types.ProcessorCallResponse(result=result) return ProcessorManagement.options(**deploy_options).bind(nproc_per_node, ncpu_proc_per_node, device_group, device_mesh) diff --git a/src/twinkle/server/sampler/twinkle_handlers.py b/src/twinkle/server/sampler/twinkle_handlers.py index f2aea1cb..fed9361d 100644 --- a/src/twinkle/server/sampler/twinkle_handlers.py +++ b/src/twinkle/server/sampler/twinkle_handlers.py @@ -15,9 +15,7 @@ from twinkle.data_format import InputFeature, SamplingParams, Trajectory from twinkle.utils.logger import get_logger -from twinkle_client.types.sampler import (AddAdapterRequest, AddAdapterResponse, CreateResponse, HeartbeatRequest, - HeartbeatResponse, SampleRequest, SampleResponseModel, SetTemplateRequest, - SetTemplateResponse) +import twinkle_client.types as types logger = get_logger() @@ -36,14 +34,14 @@ def _register_twinkle_sampler_routes(app: FastAPI, self_fn: Callable[[], Sampler It is wired in via Depends so it is resolved lazily at request time. """ - @app.post('/twinkle/create', response_model=CreateResponse) - def create(request: Request, self: SamplerManagement = Depends(self_fn)) -> CreateResponse: + @app.post('/twinkle/create', response_model=types.CreateResponse) + def create(request: Request, self: SamplerManagement = Depends(self_fn)) -> types.CreateResponse: """Health check / session creation endpoint.""" - return CreateResponse() + return types.CreateResponse() - @app.post('/twinkle/sample', response_model=SampleResponseModel) - def sample(request: Request, body: SampleRequest, - self: SamplerManagement = Depends(self_fn)) -> SampleResponseModel: + @app.post('/twinkle/sample', response_model=types.SampleResponseModel) + def sample(request: Request, body: types.SampleRequest, + self: SamplerManagement = Depends(self_fn)) -> types.SampleResponseModel: """Sample completions from the model. 
Supports Trajectory or InputFeature inputs, with optional LoRA adapter. @@ -99,7 +97,7 @@ def sample(request: Request, body: SampleRequest, 'logprobs': list(seq.logprobs) if seq.logprobs is not None else None, }) - return SampleResponseModel( + return types.SampleResponseModel( sequences=sequences, prompt_logprobs=response.prompt_logprobs, topk_prompt_logprobs=response.topk_prompt_logprobs, @@ -108,20 +106,20 @@ def sample(request: Request, body: SampleRequest, logger.error(traceback.format_exc()) raise - @app.post('/twinkle/set_template', response_model=SetTemplateResponse) - def set_template(request: Request, body: SetTemplateRequest, - self: SamplerManagement = Depends(self_fn)) -> SetTemplateResponse: + @app.post('/twinkle/set_template', response_model=types.SetTemplateResponse) + def set_template(request: Request, body: types.SetTemplateRequest, + self: SamplerManagement = Depends(self_fn)) -> types.SetTemplateResponse: """Set the chat template for encoding Trajectory inputs.""" extra_kwargs = body.model_extra or {} self.sampler.set_template(body.template_cls, **extra_kwargs) - return SetTemplateResponse() + return types.SetTemplateResponse() - @app.post('/twinkle/add_adapter_to_sampler', response_model=AddAdapterResponse) + @app.post('/twinkle/add_adapter_to_sampler', response_model=types.AddAdapterResponse) def add_adapter_to_sampler( request: Request, - body: AddAdapterRequest, + body: types.AddAdapterRequest, self: SamplerManagement = Depends(self_fn), - ) -> AddAdapterResponse: + ) -> types.AddAdapterResponse: """Add a LoRA adapter to the sampler.""" assert body.adapter_name, 'You need to specify a valid `adapter_name`' full_adapter_name = _get_twinkle_sampler_adapter_name(request, body.adapter_name) @@ -134,13 +132,13 @@ def add_adapter_to_sampler( self.register_adapter(full_adapter_name, token) self.sampler.add_adapter_to_sampler(full_adapter_name, config) - return AddAdapterResponse(adapter_name=full_adapter_name) + return 
types.AddAdapterResponse(adapter_name=full_adapter_name) - @app.post('/twinkle/heartbeat', response_model=HeartbeatResponse) - def heartbeat(request: Request, body: HeartbeatRequest, - self: SamplerManagement = Depends(self_fn)) -> HeartbeatResponse: + @app.post('/twinkle/heartbeat', response_model=types.HeartbeatResponse) + def heartbeat(request: Request, body: types.HeartbeatRequest, + self: SamplerManagement = Depends(self_fn)) -> types.HeartbeatResponse: """Keep an adapter alive by resetting its inactivity timer.""" full_adapter_name = _get_twinkle_sampler_adapter_name(request, body.adapter_name) self.assert_adapter_exists(adapter_name=full_adapter_name) self.touch_adapter(full_adapter_name) - return HeartbeatResponse() + return types.HeartbeatResponse() diff --git a/src/twinkle_client/types/__init__.py b/src/twinkle_client/types/__init__.py index c72ca1dd..7a38edbf 100644 --- a/src/twinkle_client/types/__init__.py +++ b/src/twinkle_client/types/__init__.py @@ -1,28 +1,72 @@ # Copyright (c) ModelScope Contributors. All rights reserved. 
from .model import ( + AddAdapterRequest, + AddAdapterResponse, + AdapterRequest, BackwardResponse, CalculateLossResponse, + CalculateMetricRequest, CalculateMetricResponse, ClipGradNormResponse, + CreateRequest, + CreateResponse, ForwardBackwardResponse, + ForwardOnlyRequest, + ForwardRequest, ForwardResponse, + GetStateDictRequest, GetStateDictResponse, GetTrainConfigsResponse, + HeartbeatRequest, + HeartbeatResponse, + LoadRequest, LoadResponse, LrStepResponse, ModelResult, + SaveRequest, SaveResponse, + SetLossRequest, SetLossResponse, + SetLrSchedulerRequest, SetLrSchedulerResponse, + SetOptimizerRequest, SetOptimizerResponse, + SetProcessorRequest, SetProcessorResponse, + SetTemplateRequest, SetTemplateResponse, StepResponse, + UploadToHubRequest, UploadToHubResponse, ZeroGradResponse, ) -from .sampler import AddAdapterResponse, SampleResponseModel, SetTemplateResponse as SamplerSetTemplateResponse -from .server import DeleteCheckpointResponse, ErrorResponse, HealthResponse, WeightsInfoRequest +from .processor import ( + ProcessorCallRequest, + ProcessorCallResponse, + ProcessorCreateRequest, + ProcessorCreateResponse, + ProcessorHeartbeatRequest, + ProcessorHeartbeatResponse, +) +from .sampler import ( + AddAdapterRequest as SamplerAddAdapterRequest, + AddAdapterResponse, + CreateResponse as SamplerCreateResponse, + HeartbeatRequest as SamplerHeartbeatRequest, + HeartbeatResponse as SamplerHeartbeatResponse, + SampleRequest, + SampleResponseModel, + SetTemplateRequest as SamplerSetTemplateRequest, + SetTemplateResponse as SamplerSetTemplateResponse, +) +from .server import ( + CheckpointPathResponse, + DeleteCheckpointResponse, + ErrorResponse, + HealthResponse, + WeightsInfoRequest, + WeightsInfoResponse as ServerWeightsInfoResponse, +) from .session import CreateSessionRequest, CreateSessionResponse, SessionHeartbeatRequest, SessionHeartbeatResponse from .training import ( Checkpoint, diff --git a/src/twinkle_client/types/model.py 
b/src/twinkle_client/types/model.py index 7ae2e923..e10479c3 100644 --- a/src/twinkle_client/types/model.py +++ b/src/twinkle_client/types/model.py @@ -239,3 +239,19 @@ class UploadToHubResponse(BaseModel): class Config: extra = 'allow' + + +class CreateResponse(BaseModel): + """Response for /create endpoint.""" + status: str = 'ok' + + +class AddAdapterResponse(BaseModel): + """Response for /add_adapter_to_model endpoint.""" + status: str = 'ok' + adapter_name: str + + +class HeartbeatResponse(BaseModel): + """Response for /heartbeat endpoint.""" + status: str = 'ok' diff --git a/src/twinkle_client/types/processor.py b/src/twinkle_client/types/processor.py index feac393e..fe8674ce 100644 --- a/src/twinkle_client/types/processor.py +++ b/src/twinkle_client/types/processor.py @@ -8,6 +8,7 @@ importing from twinkle_client.types alongside model.py classes. """ from pydantic import BaseModel +from typing import Any class ProcessorCreateRequest(BaseModel): @@ -28,3 +29,18 @@ class ProcessorCallRequest(BaseModel): class Config: extra = 'allow' + + +class ProcessorCreateResponse(BaseModel): + """Response body for the /create endpoint.""" + processor_id: str + + +class ProcessorHeartbeatResponse(BaseModel): + """Response body for the /heartbeat endpoint.""" + status: str = 'ok' + + +class ProcessorCallResponse(BaseModel): + """Response body for the /call endpoint.""" + result: Any diff --git a/src/twinkle_client/types/server.py b/src/twinkle_client/types/server.py index 058da8d8..df7ed58a 100644 --- a/src/twinkle_client/types/server.py +++ b/src/twinkle_client/types/server.py @@ -1,6 +1,7 @@ # Copyright (c) ModelScope Contributors. All rights reserved. 
"""Shared Pydantic response models for the twinkle server health/error endpoints.""" from pydantic import BaseModel +from typing import Any class HealthResponse(BaseModel): @@ -18,3 +19,14 @@ class ErrorResponse(BaseModel): class WeightsInfoRequest(BaseModel): twinkle_path: str + + +class WeightsInfoResponse(BaseModel): + """Response body for the /weights_info endpoint.""" + weights_info: Any + + +class CheckpointPathResponse(BaseModel): + """Response body for the /checkpoint_path endpoint.""" + path: str + twinkle_path: str From 6a4d6ed40962ff44b382f5534e02d2de2ed8288d Mon Sep 17 00:00:00 2001 From: Yunnglin Date: Thu, 12 Mar 2026 10:00:14 +0800 Subject: [PATCH 11/24] update twinkle --- .../gateway/twinkle_gateway_handlers.py | 9 +- src/twinkle/server/model/twinkle_handlers.py | 114 ++++++++++++++---- src/twinkle/server/processor/app.py | 2 +- .../server/sampler/twinkle_handlers.py | 9 +- 4 files changed, 102 insertions(+), 32 deletions(-) diff --git a/src/twinkle/server/gateway/twinkle_gateway_handlers.py b/src/twinkle/server/gateway/twinkle_gateway_handlers.py index 8a159be3..a7323446 100644 --- a/src/twinkle/server/gateway/twinkle_gateway_handlers.py +++ b/src/twinkle/server/gateway/twinkle_gateway_handlers.py @@ -12,11 +12,11 @@ if TYPE_CHECKING: from .server import GatewayServer +import twinkle_client.types as types from twinkle.server.common.checkpoint_factory import create_checkpoint_manager, create_training_run_manager from twinkle.server.utils.checkpoint_base import validate_user_path from twinkle.server.utils.validation import get_token_from_request from twinkle.utils.logger import get_logger -import twinkle_client.types as types logger = get_logger() @@ -72,8 +72,11 @@ async def get_run_checkpoints(request: Request, run_id: str) -> types.Checkpoint raise HTTPException(status_code=404, detail=f'Training run {run_id} not found or access denied') return response - @app.delete('/twinkle/training_runs/{run_id}/checkpoints/{checkpoint_id:path}', 
response_model=types.DeleteCheckpointResponse) - async def delete_run_checkpoint(request: Request, run_id: str, checkpoint_id: str) -> types.DeleteCheckpointResponse: + @app.delete( + '/twinkle/training_runs/{run_id}/checkpoints/{checkpoint_id:path}', + response_model=types.DeleteCheckpointResponse) + async def delete_run_checkpoint(request: Request, run_id: str, + checkpoint_id: str) -> types.DeleteCheckpointResponse: token = get_token_from_request(request) if not validate_user_path(token, checkpoint_id): diff --git a/src/twinkle/server/model/twinkle_handlers.py b/src/twinkle/server/model/twinkle_handlers.py index 4aad734a..660d68b2 100644 --- a/src/twinkle/server/model/twinkle_handlers.py +++ b/src/twinkle/server/model/twinkle_handlers.py @@ -16,11 +16,11 @@ if TYPE_CHECKING: from .app import ModelManagement +import twinkle_client.types as types from twinkle.data_format import InputFeature, Trajectory from twinkle.server.common.checkpoint_factory import create_checkpoint_manager, create_training_run_manager from twinkle.server.common.serialize import deserialize_object from twinkle.utils.logger import get_logger -import twinkle_client.types as types logger = get_logger() @@ -56,11 +56,13 @@ def _register_twinkle_routes(app: FastAPI, self_fn: Callable[[], ModelManagement """ @app.post('/twinkle/create', response_model=types.CreateResponse) - async def create(request: Request, body: types.CreateRequest, self: ModelManagement = Depends(self_fn)) -> types.CreateResponse: + async def create(request: Request, body: types.CreateRequest, + self: ModelManagement = Depends(self_fn)) -> types.CreateResponse: return types.CreateResponse() @app.post('/twinkle/forward', response_model=types.ForwardResponse) - async def forward(request: Request, body: types.ForwardRequest, self: ModelManagement = Depends(self_fn)) -> types.ForwardResponse: + async def forward(request: Request, body: types.ForwardRequest, + self: ModelManagement = Depends(self_fn)) -> types.ForwardResponse: 
adapter_name = _get_twinkle_adapter_name(request, body.adapter_name) async def _task(): @@ -73,7 +75,11 @@ async def _task(): return await self.schedule_task_and_wait(_task, task_type='forward') @app.post('/twinkle/forward_only', response_model=types.ForwardResponse) - async def forward_only(request: Request, body: types.ForwardOnlyRequest, self: ModelManagement = Depends(self_fn)) -> types.ForwardResponse: + async def forward_only( + request: Request, + body: types.ForwardOnlyRequest, + self: ModelManagement = Depends(self_fn), + ) -> types.ForwardResponse: adapter_name = _get_twinkle_adapter_name(request, body.adapter_name) async def _task(): @@ -86,7 +92,11 @@ async def _task(): return await self.schedule_task_and_wait(_task, task_type='forward_only') @app.post('/twinkle/calculate_loss', response_model=types.CalculateLossResponse) - async def calculate_loss(request: Request, body: types.AdapterRequest, self: ModelManagement = Depends(self_fn)) -> types.CalculateLossResponse: + async def calculate_loss( + request: Request, + body: types.AdapterRequest, + self: ModelManagement = Depends(self_fn), + ) -> types.CalculateLossResponse: adapter_name = _get_twinkle_adapter_name(request, body.adapter_name) async def _task(): @@ -98,7 +108,8 @@ async def _task(): return await self.schedule_task_and_wait(_task, task_type='calculate_loss') @app.post('/twinkle/backward', response_model=types.BackwardResponse) - async def backward(request: Request, body: types.AdapterRequest, self: ModelManagement = Depends(self_fn)) -> types.BackwardResponse: + async def backward(request: Request, body: types.AdapterRequest, + self: ModelManagement = Depends(self_fn)) -> types.BackwardResponse: adapter_name = _get_twinkle_adapter_name(request, body.adapter_name) async def _task(): @@ -110,7 +121,11 @@ async def _task(): return await self.schedule_task_and_wait(_task, task_type='backward') @app.post('/twinkle/forward_backward', response_model=types.ForwardBackwardResponse) - async def 
forward_backward(request: Request, body: types.ForwardRequest, self: ModelManagement = Depends(self_fn)) -> types.ForwardBackwardResponse: + async def forward_backward( + request: Request, + body: types.ForwardRequest, + self: ModelManagement = Depends(self_fn), + ) -> types.ForwardBackwardResponse: adapter_name = _get_twinkle_adapter_name(request, body.adapter_name) async def _task(): @@ -123,7 +138,11 @@ async def _task(): return await self.schedule_task_and_wait(_task, task_type='forward_backward') @app.post('/twinkle/clip_grad_norm', response_model=types.ClipGradNormResponse) - async def clip_grad_norm(request: Request, body: types.AdapterRequest, self: ModelManagement = Depends(self_fn)) -> types.ClipGradNormResponse: + async def clip_grad_norm( + request: Request, + body: types.AdapterRequest, + self: ModelManagement = Depends(self_fn), + ) -> types.ClipGradNormResponse: adapter_name = _get_twinkle_adapter_name(request, body.adapter_name) async def _task(): @@ -135,7 +154,8 @@ async def _task(): return await self.schedule_task_and_wait(_task, task_type='clip_grad_norm') @app.post('/twinkle/step', response_model=types.StepResponse) - async def step(request: Request, body: types.AdapterRequest, self: ModelManagement = Depends(self_fn)) -> types.StepResponse: + async def step(request: Request, body: types.AdapterRequest, + self: ModelManagement = Depends(self_fn)) -> types.StepResponse: adapter_name = _get_twinkle_adapter_name(request, body.adapter_name) async def _task(): @@ -147,7 +167,8 @@ async def _task(): return await self.schedule_task_and_wait(_task, task_type='step') @app.post('/twinkle/zero_grad', response_model=types.ZeroGradResponse) - async def zero_grad(request: Request, body: types.AdapterRequest, self: ModelManagement = Depends(self_fn)) -> types.ZeroGradResponse: + async def zero_grad(request: Request, body: types.AdapterRequest, + self: ModelManagement = Depends(self_fn)) -> types.ZeroGradResponse: adapter_name = 
_get_twinkle_adapter_name(request, body.adapter_name) async def _task(): @@ -159,7 +180,8 @@ async def _task(): return await self.schedule_task_and_wait(_task, task_type='zero_grad') @app.post('/twinkle/lr_step', response_model=types.LrStepResponse) - async def lr_step(request: Request, body: types.AdapterRequest, self: ModelManagement = Depends(self_fn)) -> types.LrStepResponse: + async def lr_step(request: Request, body: types.AdapterRequest, + self: ModelManagement = Depends(self_fn)) -> types.LrStepResponse: adapter_name = _get_twinkle_adapter_name(request, body.adapter_name) async def _task(): @@ -171,7 +193,11 @@ async def _task(): return await self.schedule_task_and_wait(_task, task_type='lr_step') @app.post('/twinkle/get_train_configs', response_model=types.GetTrainConfigsResponse) - async def get_train_configs(request: Request, body: types.AdapterRequest, self: ModelManagement = Depends(self_fn)) -> types.GetTrainConfigsResponse: + async def get_train_configs( + request: Request, + body: types.AdapterRequest, + self: ModelManagement = Depends(self_fn), + ) -> types.GetTrainConfigsResponse: adapter_name = _get_twinkle_adapter_name(request, body.adapter_name) async def _task(): @@ -183,7 +209,8 @@ async def _task(): return await self.schedule_task_and_wait(_task, task_type='get_train_configs') @app.post('/twinkle/set_loss', response_model=types.SetLossResponse) - async def set_loss(request: Request, body: types.SetLossRequest, self: ModelManagement = Depends(self_fn)) -> types.SetLossResponse: + async def set_loss(request: Request, body: types.SetLossRequest, + self: ModelManagement = Depends(self_fn)) -> types.SetLossResponse: adapter_name = _get_twinkle_adapter_name(request, body.adapter_name) async def _task(): @@ -195,7 +222,11 @@ async def _task(): return await self.schedule_task_and_wait(_task, task_type='set_loss') @app.post('/twinkle/set_optimizer', response_model=types.SetOptimizerResponse) - async def set_optimizer(request: Request, body: 
types.SetOptimizerRequest, self: ModelManagement = Depends(self_fn)) -> types.SetOptimizerResponse: + async def set_optimizer( + request: Request, + body: types.SetOptimizerRequest, + self: ModelManagement = Depends(self_fn), + ) -> types.SetOptimizerResponse: adapter_name = _get_twinkle_adapter_name(request, body.adapter_name) async def _task(): @@ -207,7 +238,11 @@ async def _task(): return await self.schedule_task_and_wait(_task, task_type='set_optimizer') @app.post('/twinkle/set_lr_scheduler', response_model=types.SetLrSchedulerResponse) - async def set_lr_scheduler(request: Request, body: types.SetLrSchedulerRequest, self: ModelManagement = Depends(self_fn)) -> types.SetLrSchedulerResponse: + async def set_lr_scheduler( + request: Request, + body: types.SetLrSchedulerRequest, + self: ModelManagement = Depends(self_fn), + ) -> types.SetLrSchedulerResponse: adapter_name = _get_twinkle_adapter_name(request, body.adapter_name) async def _task(): @@ -219,7 +254,8 @@ async def _task(): return await self.schedule_task_and_wait(_task, task_type='set_lr_scheduler') @app.post('/twinkle/save', response_model=types.SaveResponse) - async def save(request: Request, body: types.SaveRequest, self: ModelManagement = Depends(self_fn)) -> types.SaveResponse: + async def save(request: Request, body: types.SaveRequest, + self: ModelManagement = Depends(self_fn)) -> types.SaveResponse: token = await self._on_request_start(request) adapter_name = _get_twinkle_adapter_name(request, body.adapter_name) @@ -241,7 +277,8 @@ async def _task(): return await self.schedule_task_and_wait(_task, task_type='save') @app.post('/twinkle/load', response_model=types.LoadResponse) - async def load(request: Request, body: types.LoadRequest, self: ModelManagement = Depends(self_fn)) -> types.LoadResponse: + async def load(request: Request, body: types.LoadRequest, + self: ModelManagement = Depends(self_fn)) -> types.LoadResponse: token = await self._on_request_start(request) adapter_name = 
_get_twinkle_adapter_name(request, body.adapter_name) @@ -262,7 +299,11 @@ async def _task(): return await self.schedule_task_and_wait(_task, task_type='load') @app.post('/twinkle/upload_to_hub', response_model=types.UploadToHubResponse) - async def upload_to_hub(request: Request, body: types.UploadToHubRequest, self: ModelManagement = Depends(self_fn)) -> types.UploadToHubResponse: + async def upload_to_hub( + request: Request, + body: types.UploadToHubRequest, + self: ModelManagement = Depends(self_fn), + ) -> types.UploadToHubResponse: token = await self._on_request_start(request) async def _task(): @@ -290,7 +331,11 @@ async def _task(): return await self.schedule_task_and_wait(_task, task_type='upload_to_hub') @app.post('/twinkle/add_adapter_to_model', response_model=types.AddAdapterResponse) - async def add_adapter_to_model(request: Request, body: types.AddAdapterRequest, self: ModelManagement = Depends(self_fn)) -> types.AddAdapterResponse: + async def add_adapter_to_model( + request: Request, + body: types.AddAdapterRequest, + self: ModelManagement = Depends(self_fn), + ) -> types.AddAdapterResponse: assert body.adapter_name, 'You need to specify a valid `adapter_name`' token = await self._on_request_start(request) adapter_name = _get_twinkle_adapter_name(request, body.adapter_name) @@ -314,7 +359,11 @@ async def _task(): return await self.schedule_task_and_wait(_task, task_type='add_adapter_to_model') @app.post('/twinkle/set_template', response_model=types.SetTemplateResponse) - async def set_template(request: Request, body: types.SetTemplateRequest, self: ModelManagement = Depends(self_fn)) -> types.SetTemplateResponse: + async def set_template( + request: Request, + body: types.SetTemplateRequest, + self: ModelManagement = Depends(self_fn), + ) -> types.SetTemplateResponse: adapter_name = _get_twinkle_adapter_name(request, body.adapter_name) async def _task(): @@ -326,7 +375,11 @@ async def _task(): return await self.schedule_task_and_wait(_task, 
task_type='set_template') @app.post('/twinkle/set_processor', response_model=types.SetProcessorResponse) - async def set_processor(request: Request, body: types.SetProcessorRequest, self: ModelManagement = Depends(self_fn)) -> types.SetProcessorResponse: + async def set_processor( + request: Request, + body: types.SetProcessorRequest, + self: ModelManagement = Depends(self_fn), + ) -> types.SetProcessorResponse: adapter_name = _get_twinkle_adapter_name(request, body.adapter_name) async def _task(): @@ -338,15 +391,22 @@ async def _task(): return await self.schedule_task_and_wait(_task, task_type='set_processor') @app.post('/twinkle/heartbeat', response_model=types.HeartbeatResponse) - async def heartbeat(request: Request, body: types.HeartbeatRequest, self: ModelManagement = Depends(self_fn)) -> types.HeartbeatResponse: + async def heartbeat( + request: Request, + body: types.HeartbeatRequest, + self: ModelManagement = Depends(self_fn), + ) -> types.HeartbeatResponse: adapter_name = _get_twinkle_adapter_name(request, body.adapter_name) self.assert_adapter_exists(adapter_name=adapter_name) self.touch_adapter(adapter_name) return types.HeartbeatResponse() @app.post('/twinkle/calculate_metric', response_model=types.CalculateMetricResponse) - async def calculate_metric(request: Request, body: types.CalculateMetricRequest, - self: ModelManagement = Depends(self_fn)) -> types.CalculateMetricResponse: + async def calculate_metric( + request: Request, + body: types.CalculateMetricRequest, + self: ModelManagement = Depends(self_fn), + ) -> types.CalculateMetricResponse: adapter_name = _get_twinkle_adapter_name(request, body.adapter_name) async def _task(): @@ -358,7 +418,11 @@ async def _task(): return await self.schedule_task_and_wait(_task, task_type='calculate_metric') @app.post('/twinkle/get_state_dict', response_model=types.GetStateDictResponse) - async def get_state_dict(request: Request, body: types.GetStateDictRequest, self: ModelManagement = Depends(self_fn)) -> 
types.GetStateDictResponse: + async def get_state_dict( + request: Request, + body: types.GetStateDictRequest, + self: ModelManagement = Depends(self_fn), + ) -> types.GetStateDictResponse: adapter_name = _get_twinkle_adapter_name(request, body.adapter_name) async def _task(): diff --git a/src/twinkle/server/processor/app.py b/src/twinkle/server/processor/app.py index 14c79556..64adb571 100644 --- a/src/twinkle/server/processor/app.py +++ b/src/twinkle/server/processor/app.py @@ -14,11 +14,11 @@ from typing import Any, Dict import twinkle +import twinkle_client.types as types from twinkle import DeviceGroup, DeviceMesh, get_logger from twinkle.server.common.serialize import deserialize_object from twinkle.server.utils.state import ServerStateProxy, get_server_state from twinkle.server.utils.validation import verify_request_token -import twinkle_client.types as types logger = get_logger() diff --git a/src/twinkle/server/sampler/twinkle_handlers.py b/src/twinkle/server/sampler/twinkle_handlers.py index fed9361d..7f63fc08 100644 --- a/src/twinkle/server/sampler/twinkle_handlers.py +++ b/src/twinkle/server/sampler/twinkle_handlers.py @@ -13,9 +13,9 @@ if TYPE_CHECKING: from .app import SamplerManagement +import twinkle_client.types as types from twinkle.data_format import InputFeature, SamplingParams, Trajectory from twinkle.utils.logger import get_logger -import twinkle_client.types as types logger = get_logger() @@ -107,8 +107,11 @@ def sample(request: Request, body: types.SampleRequest, raise @app.post('/twinkle/set_template', response_model=types.SetTemplateResponse) - def set_template(request: Request, body: types.SetTemplateRequest, - self: SamplerManagement = Depends(self_fn)) -> types.SetTemplateResponse: + def set_template( + request: Request, + body: types.SetTemplateRequest, + self: SamplerManagement = Depends(self_fn), + ) -> types.SetTemplateResponse: """Set the chat template for encoding Trajectory inputs.""" extra_kwargs = body.model_extra or {} 
self.sampler.set_template(body.template_cls, **extra_kwargs) From bb4586a55205b0abe28efd2f6137663eebd48f00 Mon Sep 17 00:00:00 2001 From: Yunnglin Date: Thu, 12 Mar 2026 12:46:41 +0800 Subject: [PATCH 12/24] update twinkle processor --- .../server/transformer/server_config.yaml | 63 +++++++-------- src/twinkle/server/processor/app.py | 48 ++++------- src/twinkle_client/http/__init__.py | 9 +-- src/twinkle_client/http/utils.py | 80 +++++++------------ src/twinkle_client/manager.py | 34 +++++--- 5 files changed, 99 insertions(+), 135 deletions(-) diff --git a/cookbook/client/server/transformer/server_config.yaml b/cookbook/client/server/transformer/server_config.yaml index c5c584ef..f507f9d4 100644 --- a/cookbook/client/server/transformer/server_config.yaml +++ b/cookbook/client/server/transformer/server_config.yaml @@ -33,37 +33,37 @@ applications: # 2. Model Service (commented out) - Would host the base model for training. # Uncomment and configure if you need a training model worker. - - name: models-Qwen3.5-4B - route_prefix: /api/v1/model/Qwen/Qwen3.5-4B - import_path: model - args: - use_megatron: false # Use HuggingFace Transformers backend - model_id: "ms://Qwen/Qwen3.5-4B" # ModelScope model identifier - max_length: 10240 - nproc_per_node: 2 # Number of GPU processes per node - device_group: - name: model - ranks: 2 - device_type: cuda - device_mesh: - device_type: cuda - dp_size: 2 - queue_config: - rps_limit: 100 # Max requests per second - tps_limit: 100000 # Max tokens per second - adapter_config: - adapter_timeout: 30 # Seconds before idle adapter unload - deployments: - - name: ModelManagement - autoscaling_config: - min_replicas: 1 - max_replicas: 1 - target_ongoing_requests: 16 - ray_actor_options: - num_cpus: 0.1 - runtime_env: - env_vars: - TWINKLE_TRUST_REMOTE_CODE: "0" + # - name: models-Qwen3.5-4B + # route_prefix: /api/v1/model/Qwen/Qwen3.5-4B + # import_path: model + # args: + # use_megatron: false # Use HuggingFace Transformers backend + # 
model_id: "ms://Qwen/Qwen3.5-4B" # ModelScope model identifier + # max_length: 10240 + # nproc_per_node: 1 # Number of GPU processes per node + # device_group: + # name: model + # ranks: 1 + # device_type: cuda + # device_mesh: + # device_type: cuda + # dp_size: 1 + # queue_config: + # rps_limit: 100 # Max requests per second + # tps_limit: 100000 # Max tokens per second + # adapter_config: + # adapter_timeout: 30 # Seconds before idle adapter unload + # deployments: + # - name: ModelManagement + # autoscaling_config: + # min_replicas: 1 + # max_replicas: 1 + # target_ongoing_requests: 16 + # ray_actor_options: + # num_cpus: 0.1 + # runtime_env: + # env_vars: + # TWINKLE_TRUST_REMOTE_CODE: "0" # 3. Sampler Service - Runs inference / sampling using vLLM engine # Used for generating text from the model (e.g., evaluating LoRA results). @@ -106,7 +106,6 @@ applications: route_prefix: /api/v1/processor import_path: processor args: - nproc_per_node: 2 # 每节点处理器 worker 数 ncpu_proc_per_node: 2 # 每节点 CPU 进程数 device_group: name: model diff --git a/src/twinkle/server/processor/app.py b/src/twinkle/server/processor/app.py index 64adb571..649b3768 100644 --- a/src/twinkle/server/processor/app.py +++ b/src/twinkle/server/processor/app.py @@ -7,7 +7,6 @@ """ import importlib import os -import threading import uuid from fastapi import FastAPI, HTTPException, Request from ray import serve @@ -23,16 +22,20 @@ logger = get_logger() -def build_processor_app(nproc_per_node: int, ncpu_proc_per_node: int, device_group: Dict[str, Any], - device_mesh: Dict[str, Any], deploy_options: Dict[str, Any], **kwargs): +def build_processor_app(ncpu_proc_per_node: int, + device_group: Dict[str, Any], + device_mesh: Dict[str, Any], + deploy_options: Dict[str, Any], + nproc_per_node: int = 1, + **kwargs): """Build the processor management application. 
Args: - nproc_per_node: Number of GPU processes per node ncpu_proc_per_node: Number of CPU processes per node device_group: Device group configuration dict device_mesh: Device mesh configuration dict deploy_options: Ray Serve deployment options + nproc_per_node: Number of GPU processes per node (default 1, not used for CPU-only tasks) **kwargs: Additional arguments Returns: @@ -55,10 +58,11 @@ class ProcessorManagement: (datasets, dataloaders, rewards, templates, etc.). """ - COUNT_DOWN = 60 * 30 - - def __init__(self, nproc_per_node: int, ncpu_proc_per_node: int, device_group: Dict[str, Any], - device_mesh: Dict[str, Any]): + def __init__(self, + ncpu_proc_per_node: int, + device_group: Dict[str, Any], + device_mesh: Dict[str, Any], + nproc_per_node: int = 1): self.device_group = DeviceGroup(**device_group) twinkle.initialize( mode='ray', @@ -71,25 +75,10 @@ def __init__(self, nproc_per_node: int, ncpu_proc_per_node: int, device_group: D else: self.device_mesh = DeviceMesh.from_sizes(**device_mesh) self.resource_dict = {} - self.resource_records: Dict[str, int] = {} - self.hb_thread = threading.Thread(target=self.countdown, daemon=True) - self.hb_thread.start() self.state: ServerStateProxy = get_server_state() self.per_token_processor_limit = int(os.environ.get('TWINKLE_PER_USER_PROCESSOR_LIMIT', 20)) self.key_token_dict = {} - def countdown(self): - import time - while True: - time.sleep(1) - for key in list(self.resource_records.keys()): - self.resource_records[key] += 1 - if self.resource_records[key] > self.COUNT_DOWN: - self.resource_records.pop(key, None) - self.resource_dict.pop(key, None) - if key in self.key_token_dict: - self.handle_processor_count(self.key_token_dict.pop(key), False) - def assert_processor_exists(self, processor_id: str): assert processor_id and processor_id in self.resource_dict, f'Processor {processor_id} not found' @@ -139,17 +128,8 @@ def create(self, request: Request, body: types.ProcessorCreateRequest) -> types. 
instance_id=processor_id, **resolved_kwargs) self.resource_dict[processor_id] = processor - self.resource_records[processor_id] = 0 return types.ProcessorCreateResponse(processor_id='pid:' + processor_id) - @app.post('/twinkle/heartbeat', response_model=types.ProcessorHeartbeatResponse) - def heartbeat(self, body: types.ProcessorHeartbeatRequest) -> types.ProcessorHeartbeatResponse: - processor_ids = body.processor_id.split(',') - for _id in processor_ids: - if _id and _id in self.resource_dict: - self.resource_records[_id] = 0 - return types.ProcessorHeartbeatResponse() - @app.post('/twinkle/call', response_model=types.ProcessorCallResponse) def call(self, body: types.ProcessorCallRequest) -> types.ProcessorCallResponse: processor_id = body.processor_id @@ -187,5 +167,5 @@ def call(self, body: types.ProcessorCallRequest) -> types.ProcessorCallResponse: else: return types.ProcessorCallResponse(result=result) - return ProcessorManagement.options(**deploy_options).bind(nproc_per_node, ncpu_proc_per_node, device_group, - device_mesh) + return ProcessorManagement.options(**deploy_options).bind( + ncpu_proc_per_node, device_group, device_mesh, nproc_per_node=nproc_per_node) diff --git a/src/twinkle_client/http/__init__.py b/src/twinkle_client/http/__init__.py index 2e6388b7..63880a7f 100644 --- a/src/twinkle_client/http/__init__.py +++ b/src/twinkle_client/http/__init__.py @@ -1,8 +1,7 @@ from .heartbeat import heartbeat_manager from .http_utils import http_delete, http_get, http_post -from .utils import (TWINKLE_SERVER_TOKEN, TWINKLE_SERVER_URL, clear_api_key, clear_base_url, clear_request_id, - clear_session_id, get_api_key, get_base_url, get_request_id, get_session_id, set_api_key, - set_base_url, set_request_id, set_session_id) +from .utils import (TWINKLE_SERVER_TOKEN, TWINKLE_SERVER_URL, get_api_key, get_base_url, get_request_id, + get_session_id, set_api_key, set_base_url, set_request_id, set_session_id) __all__ = [ 'http_get', @@ -13,14 +12,10 @@ 
'TWINKLE_SERVER_TOKEN', 'set_base_url', 'get_base_url', - 'clear_base_url', 'set_api_key', 'get_api_key', - 'clear_api_key', 'set_session_id', 'get_session_id', - 'clear_session_id', 'set_request_id', 'get_request_id', - 'clear_request_id', ] diff --git a/src/twinkle_client/http/utils.py b/src/twinkle_client/http/utils.py index e45b0360..f5b34835 100644 --- a/src/twinkle_client/http/utils.py +++ b/src/twinkle_client/http/utils.py @@ -1,88 +1,64 @@ import os import uuid -from contextvars import ContextVar from datetime import datetime from typing import Optional TWINKLE_SERVER_URL = os.environ.get('TWINKLE_SERVER_URL', 'http://127.0.0.1:8000') TWINKLE_SERVER_TOKEN = os.environ.get('TWINKLE_SERVER_TOKEN', 'EMPTY_TOKEN') -# Context variables for flexible configuration -_base_url_context: ContextVar[Optional[str]] = ContextVar('base_url', default=None) -_api_key_context: ContextVar[Optional[str]] = ContextVar('api_key', default=None) -_session_id_context: ContextVar[Optional[str]] = ContextVar('session_id', default=None) - -# Global static request ID shared across all threads -# This ensures heartbeat threads use the same request ID as the main training thread -_global_request_id: Optional[str] = None +# Global variables for configuration +_base_url: Optional[str] = None +_api_key: Optional[str] = None +_session_id: Optional[str] = None +_request_id: Optional[str] = None def set_base_url(url: str): - """Set the base URL for HTTP requests in the current context.""" - _base_url_context.set(url.rstrip('/')) + """Set the base URL for HTTP requests.""" + global _base_url + _base_url = url.rstrip('/') -def get_base_url() -> Optional[str]: - """Get the current base URL from context or environment variable.""" - base_url = _base_url_context.get() or TWINKLE_SERVER_URL - # if not ends with '/api/v1' then append it +def get_base_url() -> str: + """Get the current base URL.""" + base_url = _base_url or TWINKLE_SERVER_URL if not base_url.endswith('/api/v1'): base_url += '/api/v1' 
return base_url -def clear_base_url(): - """Clear the base URL context, falling back to environment variable.""" - _base_url_context.set(None) - - def set_api_key(api_key: str): - """Set the API key for HTTP requests in the current context.""" - _api_key_context.set(api_key) + """Set the API key for HTTP requests.""" + global _api_key + _api_key = api_key def get_api_key() -> str: - """Get the current API key from context or environment variable.""" - return _api_key_context.get() or TWINKLE_SERVER_TOKEN - - -def clear_api_key(): - """Clear the API key context, falling back to environment variable.""" - _api_key_context.set(None) + """Get the current API key.""" + return _api_key or TWINKLE_SERVER_TOKEN def set_session_id(session_id: str): - """Set the session ID for the current context.""" - _session_id_context.set(session_id) + """Set the session ID.""" + global _session_id + _session_id = session_id def get_session_id() -> Optional[str]: - """Get the current session ID from context.""" - return _session_id_context.get() - - -def clear_session_id(): - """Clear the session ID context.""" - _session_id_context.set(None) + """Get the current session ID.""" + return _session_id def set_request_id(request_id: str): """Set the global request ID for HTTP requests (shared across all threads).""" - global _global_request_id - _global_request_id = request_id + global _request_id + _request_id = request_id def get_request_id() -> str: """Get the global request ID or generate and cache a new one.""" - global _global_request_id - if _global_request_id is not None: - return _global_request_id - # Generate a new request ID and cache it globally for consistency across threads - _global_request_id = datetime.now().strftime('%Y%m%d_%H%M%S') + '-' + str(uuid.uuid4().hex)[0:8] - return _global_request_id - - -def clear_request_id(): - """Clear the global request ID.""" - global _global_request_id - _global_request_id = None + global _request_id + if _request_id is not None: + return 
_request_id + _request_id = datetime.now().strftime('%Y%m%d_%H%M%S') + '-' + str(uuid.uuid4().hex)[0:8] + return _request_id diff --git a/src/twinkle_client/manager.py b/src/twinkle_client/manager.py index 394714e8..df097083 100644 --- a/src/twinkle_client/manager.py +++ b/src/twinkle_client/manager.py @@ -10,7 +10,7 @@ SessionHeartbeatResponse) from twinkle_client.types.training import (Checkpoint, Cursor, ParsedCheckpointTwinklePath, TrainingRun, TrainingRunsResponse, WeightsInfoResponse) -from .http import get_api_key, get_base_url, http_delete, http_get, http_post, set_api_key, set_base_url, set_session_id, clear_session_id +from .http import get_api_key, get_base_url, http_delete, http_get, http_post, set_api_key, set_base_url, set_session_id class TwinkleClientError(Exception): @@ -46,7 +46,7 @@ def __init__( base_url: Optional[str] = None, api_key: Optional[str] = None, route_prefix: Optional[str] = '/twinkle', - session_heartbeat_interval: int = 30, + session_heartbeat_interval: int = 10, session_metadata: Optional[Dict[str, Any]] = None, ): # Resolve and store config, then propagate to context so all generated @@ -55,18 +55,13 @@ def __init__( set_base_url(base_url) if api_key: set_api_key(api_key) - + self.base_url = get_base_url() self.api_key = get_api_key() self.route_prefix = route_prefix.rstrip('/') if route_prefix else '' # Create a server-side session. - resp = http_post( - self._get_url('/create_session'), - json_data=CreateSessionRequest(metadata=session_metadata).model_dump(), - ) - resp.raise_for_status() - self._session_id: str = CreateSessionResponse(**resp.json()).session_id + self._session_id: str = self.create_session(session_metadata) set_session_id(self._session_id) # Start background session-touch thread. 
@@ -99,6 +94,26 @@ def _handle_response(self, response, expected_code: int = 200) -> dict[str, Any] raise TwinkleClientError(f'Request failed with status {response.status_code}: {detail}') return response.json() + def create_session(self, metadata: Optional[Dict[str, Any]] = None) -> str: + """ + Create a server-side session. + + Args: + metadata: Optional metadata dict stored with the session on the server. + + Returns: + The session ID string. + + Raises: + TwinkleClientError: If the session creation request fails. + """ + resp = http_post( + self._get_url('/create_session'), + json_data=CreateSessionRequest(metadata=metadata).model_dump(), + ) + resp.raise_for_status() + return CreateSessionResponse(**resp.json()).session_id + def _touch_session_loop(self) -> None: """Background loop: touch the session every N seconds.""" while not self._stop_event.wait(timeout=self._heartbeat_interval): @@ -117,7 +132,6 @@ def close(self) -> None: self._stop_event.set() if self._heartbeat_thread.is_alive(): self._heartbeat_thread.join(timeout=2) - clear_session_id() # ------------------------------------------------------------------ # Health Check From 22f0f1739617bda972b4aa9da2eb79e17bb704d5 Mon Sep 17 00:00:00 2001 From: Yunnglin Date: Thu, 12 Mar 2026 14:20:00 +0800 Subject: [PATCH 13/24] update twinkle model --- client_tools/client_generator.py | 82 ++++----- .../server/transformer/server_config.yaml | 62 +++---- .../twinkle/self_host/self_congnition.py | 31 ++-- .../server/model/backends/megatron_model.py | 10 +- .../model/backends/transformers_model.py | 14 +- src/twinkle/server/model/tinker_handlers.py | 10 +- src/twinkle/server/model/twinkle_handlers.py | 165 +++++++++++------- src/twinkle/server/types/__init__.py | 1 - src/twinkle/server/utils/checkpoint_base.py | 2 +- .../model/multi_lora_transformers.py | 82 ++++----- src/twinkle_client/types/__init__.py | 9 + .../types/checkpoint.py | 0 src/twinkle_client/types/model.py | 135 +++++++++----- 13 files changed, 349 
insertions(+), 254 deletions(-) delete mode 100644 src/twinkle/server/types/__init__.py rename src/{twinkle/server => twinkle_client}/types/checkpoint.py (100%) diff --git a/client_tools/client_generator.py b/client_tools/client_generator.py index 60276a4e..c0df54d3 100644 --- a/client_tools/client_generator.py +++ b/client_tools/client_generator.py @@ -440,7 +440,6 @@ def generate_models(): model_code = AUTO_GEN_WARNING + '''from typing import Any, Dict, Optional from twinkle_client.http import http_post from twinkle_client.types.model import ( - BackwardResponse, CalculateLossResponse, CalculateMetricResponse, ClipGradNormResponse, @@ -448,17 +447,7 @@ def generate_models(): ForwardResponse, GetStateDictResponse, GetTrainConfigsResponse, - LoadResponse, - LrStepResponse, SaveResponse, - SetLossResponse, - SetLrSchedulerResponse, - SetOptimizerResponse, - SetProcessorResponse, - SetTemplateResponse, - StepResponse, - UploadToHubResponse, - ZeroGradResponse, ) @@ -529,14 +518,13 @@ def get_train_configs(self, **kwargs) -> GetTrainConfigsResponse: response.raise_for_status() return GetTrainConfigsResponse(**response.json()) - def backward(self, **kwargs) -> BackwardResponse: + def backward(self, **kwargs) -> None: """Execute backward pass.""" response = http_post( url=f'{self.server_url}/backward', json_data={'adapter_name': self.adapter_name, **kwargs} ) response.raise_for_status() - return BackwardResponse(**response.json()) def forward_backward(self, inputs: Any, **kwargs) -> ForwardBackwardResponse: """Execute combined forward and backward pass.""" @@ -547,41 +535,29 @@ def forward_backward(self, inputs: Any, **kwargs) -> ForwardBackwardResponse: response.raise_for_status() return ForwardBackwardResponse(**response.json()) - def step(self, **kwargs) -> StepResponse: + def step(self, **kwargs) -> None: """Execute optimizer step.""" response = http_post( url=f'{self.server_url}/step', json_data={'adapter_name': self.adapter_name, **kwargs} ) 
response.raise_for_status() - return StepResponse(**response.json()) - def zero_grad(self, **kwargs) -> ZeroGradResponse: + def zero_grad(self, **kwargs) -> None: """Zero out gradients.""" response = http_post( url=f'{self.server_url}/zero_grad', json_data={'adapter_name': self.adapter_name, **kwargs} ) response.raise_for_status() - return ZeroGradResponse(**response.json()) - def lr_step(self, **kwargs) -> LrStepResponse: + def lr_step(self, **kwargs) -> None: """Execute learning rate scheduler step.""" response = http_post( url=f'{self.server_url}/lr_step', json_data={'adapter_name': self.adapter_name, **kwargs} ) response.raise_for_status() - return LrStepResponse(**response.json()) - - def set_loss(self, loss_cls: str, **kwargs) -> SetLossResponse: - """Set the loss function.""" - response = http_post( - url=f'{self.server_url}/set_loss', - json_data={'loss_cls': loss_cls, 'adapter_name': self.adapter_name, **kwargs} - ) - response.raise_for_status() - return SetLossResponse(**response.json()) def clip_grad_norm(self, max_grad_norm: float = 1.0, norm_type: int = 2, **kwargs) -> ClipGradNormResponse: """Clip gradient norm.""" @@ -592,23 +568,37 @@ def clip_grad_norm(self, max_grad_norm: float = 1.0, norm_type: int = 2, **kwarg response.raise_for_status() return ClipGradNormResponse(**response.json()) - def set_optimizer(self, optimizer_cls: str, **kwargs) -> SetOptimizerResponse: + def clip_grad_and_step(self, max_grad_norm: float = 1.0, norm_type: int = 2, **kwargs) -> None: + """Clip gradient norm and execute optimizer step in one call.""" + response = http_post( + url=f'{self.server_url}/clip_grad_and_step', + json_data={'max_grad_norm': max_grad_norm, 'norm_type': norm_type, 'adapter_name': self.adapter_name, **kwargs} + ) + response.raise_for_status() + + def set_loss(self, loss_cls: str, **kwargs) -> None: + """Set the loss function.""" + response = http_post( + url=f'{self.server_url}/set_loss', + json_data={'loss_cls': loss_cls, 'adapter_name': 
self.adapter_name, **kwargs} + ) + response.raise_for_status() + + def set_optimizer(self, optimizer_cls: str, **kwargs) -> None: """Set the optimizer.""" response = http_post( url=f'{self.server_url}/set_optimizer', json_data={'optimizer_cls': optimizer_cls, 'adapter_name': self.adapter_name, **kwargs} ) response.raise_for_status() - return SetOptimizerResponse(**response.json()) - def set_lr_scheduler(self, scheduler_cls: str, **kwargs) -> SetLrSchedulerResponse: + def set_lr_scheduler(self, scheduler_cls: str, **kwargs) -> None: """Set the learning rate scheduler.""" response = http_post( url=f'{self.server_url}/set_lr_scheduler', json_data={'scheduler_cls': scheduler_cls, 'adapter_name': self.adapter_name, **kwargs} ) response.raise_for_status() - return SetLrSchedulerResponse(**response.json()) def save(self, name: str, **kwargs) -> SaveResponse: """Save model checkpoint.""" @@ -619,32 +609,45 @@ def save(self, name: str, **kwargs) -> SaveResponse: response.raise_for_status() return SaveResponse(**response.json()) - def load(self, name: str, **kwargs) -> LoadResponse: + def load(self, name: str, **kwargs) -> None: """Load model checkpoint.""" response = http_post( url=f'{self.server_url}/load', json_data={'name': name, 'adapter_name': self.adapter_name, **kwargs} ) response.raise_for_status() - return LoadResponse(**response.json()) - def set_template(self, template_cls: str, **kwargs) -> SetTemplateResponse: + def apply_patch(self, patch_cls: str, **kwargs) -> None: + """Apply a patch to the model.""" + response = http_post( + url=f'{self.server_url}/apply_patch', + json_data={'patch_cls': patch_cls, 'adapter_name': self.adapter_name, **kwargs} + ) + response.raise_for_status() + + def add_metric(self, metric_cls: str, is_training: Optional[bool] = None, **kwargs) -> None: + """Add a metric to the model.""" + response = http_post( + url=f'{self.server_url}/add_metric', + json_data={'metric_cls': metric_cls, 'is_training': is_training, 'adapter_name': 
self.adapter_name, **kwargs} + ) + response.raise_for_status() + + def set_template(self, template_cls: str, **kwargs) -> None: """Set the template for data processing.""" response = http_post( url=f'{self.server_url}/set_template', json_data={'template_cls': template_cls, 'adapter_name': self.adapter_name, 'model_id': self.model_id, **kwargs} ) response.raise_for_status() - return SetTemplateResponse(**response.json()) - def set_processor(self, processor_cls: str, **kwargs) -> SetProcessorResponse: + def set_processor(self, processor_cls: str, **kwargs) -> None: """Set the input processor.""" response = http_post( url=f'{self.server_url}/set_processor', json_data={'processor_cls': processor_cls, 'adapter_name': self.adapter_name, **kwargs} ) response.raise_for_status() - return SetProcessorResponse(**response.json()) def calculate_metric(self, is_training: bool = True, **kwargs) -> CalculateMetricResponse: """Calculate metrics from model outputs.""" @@ -670,7 +673,7 @@ def upload_to_hub( hub_model_id: str, hub_token: Optional[str] = None, async_upload: bool = True, - ) -> UploadToHubResponse: + ) -> None: """Upload model checkpoint to hub. Args: @@ -689,7 +692,6 @@ def upload_to_hub( } ) response.raise_for_status() - return UploadToHubResponse(**response.json()) ''' # Write the model client file diff --git a/cookbook/client/server/transformer/server_config.yaml b/cookbook/client/server/transformer/server_config.yaml index f507f9d4..ea310520 100644 --- a/cookbook/client/server/transformer/server_config.yaml +++ b/cookbook/client/server/transformer/server_config.yaml @@ -33,37 +33,37 @@ applications: # 2. Model Service (commented out) - Would host the base model for training. # Uncomment and configure if you need a training model worker. 
- # - name: models-Qwen3.5-4B - # route_prefix: /api/v1/model/Qwen/Qwen3.5-4B - # import_path: model - # args: - # use_megatron: false # Use HuggingFace Transformers backend - # model_id: "ms://Qwen/Qwen3.5-4B" # ModelScope model identifier - # max_length: 10240 - # nproc_per_node: 1 # Number of GPU processes per node - # device_group: - # name: model - # ranks: 1 - # device_type: cuda - # device_mesh: - # device_type: cuda - # dp_size: 1 - # queue_config: - # rps_limit: 100 # Max requests per second - # tps_limit: 100000 # Max tokens per second - # adapter_config: - # adapter_timeout: 30 # Seconds before idle adapter unload - # deployments: - # - name: ModelManagement - # autoscaling_config: - # min_replicas: 1 - # max_replicas: 1 - # target_ongoing_requests: 16 - # ray_actor_options: - # num_cpus: 0.1 - # runtime_env: - # env_vars: - # TWINKLE_TRUST_REMOTE_CODE: "0" + - name: models-Qwen3.5-4B + route_prefix: /api/v1/model/Qwen/Qwen3.5-4B + import_path: model + args: + use_megatron: false # Use HuggingFace Transformers backend + model_id: "ms://Qwen/Qwen3.5-4B" # ModelScope model identifier + max_length: 10240 + nproc_per_node: 1 # Number of GPU processes per node + device_group: + name: model + ranks: 1 + device_type: cuda + device_mesh: + device_type: cuda + dp_size: 1 + queue_config: + rps_limit: 100 # Max requests per second + tps_limit: 100000 # Max tokens per second + adapter_config: + adapter_timeout: 30 # Seconds before idle adapter unload + deployments: + - name: ModelManagement + autoscaling_config: + min_replicas: 1 + max_replicas: 1 + target_ongoing_requests: 16 + ray_actor_options: + num_cpus: 0.1 + runtime_env: + env_vars: + TWINKLE_TRUST_REMOTE_CODE: "0" # 3. Sampler Service - Runs inference / sampling using vLLM engine # Used for generating text from the model (e.g., evaluating LoRA results). 
diff --git a/cookbook/client/twinkle/self_host/self_congnition.py b/cookbook/client/twinkle/self_host/self_congnition.py index f9e56dd1..7aa56e0f 100644 --- a/cookbook/client/twinkle/self_host/self_congnition.py +++ b/cookbook/client/twinkle/self_host/self_congnition.py @@ -101,24 +101,25 @@ def train(): logger.info(f'Starting epoch {epoch}') for step, batch in enumerate(dataloader): # Forward pass + backward pass (computes gradients) - output = model.forward_backward(inputs=batch) - loss=output.get('loss', 'N/A') + model.forward_backward(inputs=batch) + + # Step + model.clip_grad_and_step() + # Equal to the following steps: + # # Clip gradients to prevent exploding gradients (max norm = 1.0) + # model.clip_grad_norm(1.0) + # # Perform one optimizer step (update model weights) + # model.step() + # # Reset gradients to zero for the next iteration + # model.zero_grad() + # # Advance the learning rate scheduler by one step + # model.lr_step() # Log the loss every 2 steps (aligned with gradient accumulation) if step % 2 == 0: - logger.info(f'Current is step {step // 2}, loss: {loss}') - - # Clip gradients to prevent exploding gradients (max norm = 1.0) - model.clip_grad_norm(1.0) - - # Perform one optimizer step (update model weights) - model.step() - - # Reset gradients to zero for the next iteration - model.zero_grad() - - # Advance the learning rate scheduler by one step - model.lr_step() + # Print metric + metric = model.calculate_metric(is_training=True) + logger.info(f'Current is step {step} of {len(dataloader)}, metric: {metric.model_dump()}') # Step 8: Save the trained checkpoint twinkle_path = model.save(name=f'twinkle-epoch-{epoch}', save_optimizer=True) diff --git a/src/twinkle/server/model/backends/megatron_model.py b/src/twinkle/server/model/backends/megatron_model.py index ae47cfc1..432bd2ba 100644 --- a/src/twinkle/server/model/backends/megatron_model.py +++ b/src/twinkle/server/model/backends/megatron_model.py @@ -33,7 +33,7 @@ class 
TwinkleCompatMegatronModel(_MegatronBase, TwinkleCompatModelBase): """ @remote_function(dispatch='slice_dp', collect=collect_forward_backward_results, sync=True) - def forward_backward(self, *, inputs: List[types.Datum], adapter_name: str, loss_fn: str, **kwargs): + def tinker_forward_backward(self, *, inputs: List[types.Datum], adapter_name: str, loss_fn: str, **kwargs): """Combined forward and backward pass.""" if loss_fn == 'importance_sampling': super().set_loss('GRPOLoss', adapter_name=adapter_name, epsilon=0.2, beta=0.0) @@ -66,7 +66,7 @@ def forward_backward(self, *, inputs: List[types.Datum], adapter_name: str, loss return [results, loss] @remote_function(dispatch='slice_dp', collect='flatten') - def forward_only(self, *, inputs: List[types.Datum], **kwargs): + def tinker_forward_only(self, *, inputs: List[types.Datum], **kwargs): """Forward pass without gradient computation.""" template = self.get_template(**kwargs) input_features = datum_to_input_feature(inputs, template) @@ -86,7 +86,7 @@ def forward_only(self, *, inputs: List[types.Datum], **kwargs): return results @remote_function(dispatch='all') - def step(self, *, adam_params: types.AdamParams, **kwargs): + def tinker_step(self, *, adam_params: types.AdamParams, **kwargs): """Optimizer step with AdamParams configuration.""" adapter_name = kwargs.get('adapter_name') optimizer_config = self.optimizer_group.get(adapter_name) @@ -108,12 +108,12 @@ def step(self, *, adam_params: types.AdamParams, **kwargs): super().zero_grad(**kwargs) @remote_function(collect='first', lazy_collect=False) - def calculate_metric(self, is_training, **kwargs): + def tinker_calculate_metric(self, is_training, **kwargs): metric = super().calculate_metric(is_training, **kwargs) return clean_metrics(metric) @remote_function(dispatch='all', sync=True) - def load(self, checkpoint_dir: str, **kwargs): + def tinker_load(self, checkpoint_dir: str, **kwargs): """Load checkpoint with token-based isolation support.""" token = 
kwargs.pop('token', None) if not token: diff --git a/src/twinkle/server/model/backends/transformers_model.py b/src/twinkle/server/model/backends/transformers_model.py index 6cbc401b..20d6b75b 100644 --- a/src/twinkle/server/model/backends/transformers_model.py +++ b/src/twinkle/server/model/backends/transformers_model.py @@ -59,7 +59,7 @@ def _to_cpu_safe_output(obj: Any) -> Any: # ------------------------------------------------------------------ @remote_function(dispatch='slice_dp', collect='flatten') - def forward_only(self, *, inputs: List[types.Datum], **kwargs): + def tinker_forward_only(self, *, inputs: List[types.Datum], **kwargs): template = self.get_template(**kwargs) input_features = datum_to_input_feature(inputs, template) outputs = super().forward_only(inputs=input_features, **kwargs) @@ -71,7 +71,7 @@ def forward_only(self, *, inputs: List[types.Datum], **kwargs): return results @remote_function(dispatch='slice_dp', collect=collect_forward_backward_results) - def forward_backward(self, *, inputs: List[types.Datum], adapter_name: str, loss_fn: str, **kwargs): + def tinker_forward_backward(self, *, inputs: List[types.Datum], adapter_name: str, loss_fn: str, **kwargs): if loss_fn == 'cross_entropy': super().set_loss('CrossEntropyLoss', adapter_name=adapter_name) elif loss_fn == 'importance_sampling': @@ -94,7 +94,7 @@ def forward_backward(self, *, inputs: List[types.Datum], adapter_name: str, loss return [results, loss] @remote_function() - def step(self, *, adam_params: types.AdamParams, **kwargs): + def tinker_step(self, *, adam_params: types.AdamParams, **kwargs): grad_clip_norm = adam_params.grad_clip_norm if grad_clip_norm > 0.0: self.clip_grad_norm(max_grad_norm=grad_clip_norm, norm_type=2, **kwargs) @@ -108,12 +108,12 @@ def step(self, *, adam_params: types.AdamParams, **kwargs): super().zero_grad(**kwargs) @remote_function(collect='first', lazy_collect=False) - def calculate_metric(self, is_training, **kwargs): + def tinker_calculate_metric(self, 
is_training, **kwargs): metric = super().calculate_metric(is_training, **kwargs) return clean_metrics(metric) @remote_function() - def load(self, checkpoint_dir: str, **kwargs): + def tinker_load(self, checkpoint_dir: str, **kwargs): """Load checkpoint with token-based isolation support.""" token = kwargs.pop('token', None) if not token: @@ -131,8 +131,8 @@ def load(self, checkpoint_dir: str, **kwargs): # ------------------------------------------------------------------ @remote_function(dispatch='slice_dp', collect='mean') - def twinkle_forward_backward(self, *, inputs: Union[InputFeature, List[InputFeature], Trajectory, List[Trajectory]], - **kwargs): + def forward_backward(self, *, inputs: Union[InputFeature, List[InputFeature], Trajectory, List[Trajectory]], + **kwargs): """Forward+backward for twinkle-native clients (InputFeature/Trajectory I/O).""" output = super().forward_backward(inputs=inputs, **kwargs) return self._to_cpu_safe_output(output) diff --git a/src/twinkle/server/model/tinker_handlers.py b/src/twinkle/server/model/tinker_handlers.py index b89a5936..cdb55669 100644 --- a/src/twinkle/server/model/tinker_handlers.py +++ b/src/twinkle/server/model/tinker_handlers.py @@ -112,7 +112,7 @@ async def _do_forward(): self.touch_adapter(adapter_name) datum_list = body.forward_input.data loss_fn_config = body.forward_input.loss_fn_config or {} - output = self.model.forward_only(inputs=datum_list, adapter_name=adapter_name) + output = self.model.tinker_forward_only(inputs=datum_list, adapter_name=adapter_name) loss = self.model.calculate_loss(adapter_name=adapter_name, **loss_fn_config) return types.ForwardBackwardOutput( loss_fn_output_type='CrossEntropyLossReturn', @@ -155,7 +155,7 @@ async def _do_forward_backward(): datum_list = body.forward_backward_input.data loss_fn = body.forward_backward_input.loss_fn loss_fn_config = body.forward_backward_input.loss_fn_config or {} - output, loss = self.model.forward_backward( + output, loss = 
self.model.tinker_forward_backward( inputs=datum_list, adapter_name=adapter_name, loss_fn=loss_fn, **loss_fn_config) output_type = ('ImportanceSamplingLossReturn' if loss_fn == 'importance_sampling' else 'CrossEntropyLossReturn') @@ -201,9 +201,9 @@ async def _do_optim(): raise RuntimeError(f'No accumulated gradients for adapter={adapter_name}; ' 'call forward_backward before optim_step') self.touch_adapter(adapter_name) - self.model.step(adam_params=body.adam_params, adapter_name=adapter_name) + self.model.tinker_step(adam_params=body.adam_params, adapter_name=adapter_name) self.set_adapter_state(adapter_name, 'grad_ready', False) - metrics = self.model.calculate_metric(is_training=True, adapter_name=adapter_name) + metrics = self.model.tinker_calculate_metric(is_training=True, adapter_name=adapter_name) return types.OptimStepResponse(metrics=metrics) except Exception: logger.error(traceback.format_exc()) @@ -294,7 +294,7 @@ async def _do_load(): adapter_name = self.get_adapter_name(adapter_name=body.model_id) self.assert_adapter_exists(adapter_name=adapter_name) self.touch_adapter(adapter_name) - self.model.load( + self.model.tinker_load( checkpoint_dir=body.path, load_optimizer=body.optimizer, adapter_name=adapter_name, token=token) self.set_adapter_state(adapter_name, 'grad_ready', False) return types.LoadWeightsResponse(path=body.path, type='load_weights') diff --git a/src/twinkle/server/model/twinkle_handlers.py b/src/twinkle/server/model/twinkle_handlers.py index 660d68b2..814bf9f7 100644 --- a/src/twinkle/server/model/twinkle_handlers.py +++ b/src/twinkle/server/model/twinkle_handlers.py @@ -8,7 +8,6 @@ """ from __future__ import annotations -import traceback from fastapi import Depends, FastAPI, Request from peft import LoraConfig from typing import TYPE_CHECKING, Any, Callable, Optional @@ -107,18 +106,16 @@ async def _task(): return await self.schedule_task_and_wait(_task, task_type='calculate_loss') - @app.post('/twinkle/backward', 
response_model=types.BackwardResponse) - async def backward(request: Request, body: types.AdapterRequest, - self: ModelManagement = Depends(self_fn)) -> types.BackwardResponse: + @app.post('/twinkle/backward') + async def backward(request: Request, body: types.AdapterRequest, self: ModelManagement = Depends(self_fn)) -> None: adapter_name = _get_twinkle_adapter_name(request, body.adapter_name) async def _task(): self.assert_adapter_exists(adapter_name=adapter_name) extra_kwargs = body.model_extra or {} - ret = self.model.backward(adapter_name=adapter_name, **extra_kwargs) - return {'result': ret} + self.model.backward(adapter_name=adapter_name, **extra_kwargs) - return await self.schedule_task_and_wait(_task, task_type='backward') + await self.schedule_task_and_wait(_task, task_type='backward') @app.post('/twinkle/forward_backward', response_model=types.ForwardBackwardResponse) async def forward_backward( @@ -132,7 +129,7 @@ async def _task(): self.assert_adapter_exists(adapter_name=adapter_name) extra_kwargs = body.model_extra or {} inputs = _parse_inputs(body.inputs) - ret = self.model.twinkle_forward_backward(inputs=inputs, adapter_name=adapter_name, **extra_kwargs) + ret = self.model.forward_backward(inputs=inputs, adapter_name=adapter_name, **extra_kwargs) return {'result': ret} return await self.schedule_task_and_wait(_task, task_type='forward_backward') @@ -153,44 +150,58 @@ async def _task(): return await self.schedule_task_and_wait(_task, task_type='clip_grad_norm') - @app.post('/twinkle/step', response_model=types.StepResponse) - async def step(request: Request, body: types.AdapterRequest, - self: ModelManagement = Depends(self_fn)) -> types.StepResponse: + @app.post('/twinkle/step') + async def step(request: Request, body: types.AdapterRequest, self: ModelManagement = Depends(self_fn)) -> None: adapter_name = _get_twinkle_adapter_name(request, body.adapter_name) async def _task(): self.assert_adapter_exists(adapter_name=adapter_name) extra_kwargs = 
body.model_extra or {} - ret = self.model.step(adapter_name=adapter_name, **extra_kwargs) - return {'result': ret} + self.model.step(adapter_name=adapter_name, **extra_kwargs) - return await self.schedule_task_and_wait(_task, task_type='step') + await self.schedule_task_and_wait(_task, task_type='step') - @app.post('/twinkle/zero_grad', response_model=types.ZeroGradResponse) - async def zero_grad(request: Request, body: types.AdapterRequest, - self: ModelManagement = Depends(self_fn)) -> types.ZeroGradResponse: + @app.post('/twinkle/zero_grad') + async def zero_grad(request: Request, body: types.AdapterRequest, self: ModelManagement = Depends(self_fn)) -> None: adapter_name = _get_twinkle_adapter_name(request, body.adapter_name) async def _task(): self.assert_adapter_exists(adapter_name=adapter_name) extra_kwargs = body.model_extra or {} - ret = self.model.zero_grad(adapter_name=adapter_name, **extra_kwargs) - return {'result': ret} + self.model.zero_grad(adapter_name=adapter_name, **extra_kwargs) - return await self.schedule_task_and_wait(_task, task_type='zero_grad') + await self.schedule_task_and_wait(_task, task_type='zero_grad') - @app.post('/twinkle/lr_step', response_model=types.LrStepResponse) - async def lr_step(request: Request, body: types.AdapterRequest, - self: ModelManagement = Depends(self_fn)) -> types.LrStepResponse: + @app.post('/twinkle/lr_step') + async def lr_step(request: Request, body: types.AdapterRequest, self: ModelManagement = Depends(self_fn)) -> None: adapter_name = _get_twinkle_adapter_name(request, body.adapter_name) async def _task(): self.assert_adapter_exists(adapter_name=adapter_name) extra_kwargs = body.model_extra or {} - ret = self.model.lr_step(adapter_name=adapter_name, **extra_kwargs) - return {'result': ret} + self.model.lr_step(adapter_name=adapter_name, **extra_kwargs) + + await self.schedule_task_and_wait(_task, task_type='lr_step') + + @app.post('/twinkle/clip_grad_and_step') + async def clip_grad_and_step( + request: 
Request, + body: types.ClipGradAndStepRequest, + self: ModelManagement = Depends(self_fn), + ) -> None: + adapter_name = _get_twinkle_adapter_name(request, body.adapter_name) + + async def _task(): + self.assert_adapter_exists(adapter_name=adapter_name) + extra_kwargs = body.model_extra or {} + self.model.clip_grad_and_step( + max_grad_norm=body.max_grad_norm, + norm_type=body.norm_type, + adapter_name=adapter_name, + **extra_kwargs, + ) - return await self.schedule_task_and_wait(_task, task_type='lr_step') + await self.schedule_task_and_wait(_task, task_type='clip_grad_and_step') @app.post('/twinkle/get_train_configs', response_model=types.GetTrainConfigsResponse) async def get_train_configs( @@ -208,50 +219,46 @@ async def _task(): return await self.schedule_task_and_wait(_task, task_type='get_train_configs') - @app.post('/twinkle/set_loss', response_model=types.SetLossResponse) - async def set_loss(request: Request, body: types.SetLossRequest, - self: ModelManagement = Depends(self_fn)) -> types.SetLossResponse: + @app.post('/twinkle/set_loss') + async def set_loss(request: Request, body: types.SetLossRequest, self: ModelManagement = Depends(self_fn)) -> None: adapter_name = _get_twinkle_adapter_name(request, body.adapter_name) async def _task(): self.assert_adapter_exists(adapter_name=adapter_name) extra_kwargs = body.model_extra or {} - ret = self.model.set_loss(body.loss_cls, adapter_name=adapter_name, **extra_kwargs) - return {'result': ret} + self.model.set_loss(body.loss_cls, adapter_name=adapter_name, **extra_kwargs) - return await self.schedule_task_and_wait(_task, task_type='set_loss') + await self.schedule_task_and_wait(_task, task_type='set_loss') - @app.post('/twinkle/set_optimizer', response_model=types.SetOptimizerResponse) + @app.post('/twinkle/set_optimizer') async def set_optimizer( request: Request, body: types.SetOptimizerRequest, self: ModelManagement = Depends(self_fn), - ) -> types.SetOptimizerResponse: + ) -> None: adapter_name = 
_get_twinkle_adapter_name(request, body.adapter_name) async def _task(): self.assert_adapter_exists(adapter_name=adapter_name) extra_kwargs = body.model_extra or {} - ret = self.model.set_optimizer(body.optimizer_cls, adapter_name=adapter_name, **extra_kwargs) - return {'result': ret} + self.model.set_optimizer(body.optimizer_cls, adapter_name=adapter_name, **extra_kwargs) - return await self.schedule_task_and_wait(_task, task_type='set_optimizer') + await self.schedule_task_and_wait(_task, task_type='set_optimizer') - @app.post('/twinkle/set_lr_scheduler', response_model=types.SetLrSchedulerResponse) + @app.post('/twinkle/set_lr_scheduler') async def set_lr_scheduler( request: Request, body: types.SetLrSchedulerRequest, self: ModelManagement = Depends(self_fn), - ) -> types.SetLrSchedulerResponse: + ) -> None: adapter_name = _get_twinkle_adapter_name(request, body.adapter_name) async def _task(): self.assert_adapter_exists(adapter_name=adapter_name) extra_kwargs = body.model_extra or {} - ret = self.model.set_lr_scheduler(body.scheduler_cls, adapter_name=adapter_name, **extra_kwargs) - return {'result': ret} + self.model.set_lr_scheduler(body.scheduler_cls, adapter_name=adapter_name, **extra_kwargs) - return await self.schedule_task_and_wait(_task, task_type='set_lr_scheduler') + await self.schedule_task_and_wait(_task, task_type='set_lr_scheduler') @app.post('/twinkle/save', response_model=types.SaveResponse) async def save(request: Request, body: types.SaveRequest, @@ -276,9 +283,8 @@ async def _task(): return await self.schedule_task_and_wait(_task, task_type='save') - @app.post('/twinkle/load', response_model=types.LoadResponse) - async def load(request: Request, body: types.LoadRequest, - self: ModelManagement = Depends(self_fn)) -> types.LoadResponse: + @app.post('/twinkle/load') + async def load(request: Request, body: types.LoadRequest, self: ModelManagement = Depends(self_fn)) -> None: token = await self._on_request_start(request) adapter_name = 
_get_twinkle_adapter_name(request, body.adapter_name) @@ -287,23 +293,22 @@ async def _task(): extra_kwargs = body.model_extra or {} checkpoint_manager = create_checkpoint_manager(token, client_type='twinkle') resolved = checkpoint_manager.resolve_load_path(body.name) - ret = self.model.load( + self.model.load( name=resolved.checkpoint_name, output_dir=resolved.checkpoint_dir, adapter_name=adapter_name, load_optimizer=body.load_optimizer, token=token, **extra_kwargs) - return {'result': ret} - return await self.schedule_task_and_wait(_task, task_type='load') + await self.schedule_task_and_wait(_task, task_type='load') - @app.post('/twinkle/upload_to_hub', response_model=types.UploadToHubResponse) + @app.post('/twinkle/upload_to_hub') async def upload_to_hub( request: Request, body: types.UploadToHubRequest, self: ModelManagement = Depends(self_fn), - ) -> types.UploadToHubResponse: + ) -> None: token = await self._on_request_start(request) async def _task(): @@ -326,9 +331,8 @@ async def _task(): hub_model_id=body.hub_model_id, hub_token=body.hub_token or token, async_upload=body.async_upload) - return {'result': body.hub_model_id} - return await self.schedule_task_and_wait(_task, task_type='upload_to_hub') + await self.schedule_task_and_wait(_task, task_type='upload_to_hub') @app.post('/twinkle/add_adapter_to_model', response_model=types.AddAdapterResponse) async def add_adapter_to_model( @@ -346,49 +350,78 @@ async def _task(): training_run_manager = create_training_run_manager(token, client_type='twinkle') self.register_adapter(adapter_name, token) self.model.add_adapter_to_model(adapter_name, config, **extra_kwargs) - from twinkle_client.types.training import CreateModelRequest - from twinkle_client.types.training import LoraConfig as IoLoraConfig + lora_config = None if isinstance(config, LoraConfig): - lora_config = IoLoraConfig(rank=config.r, train_unembed=False, train_mlp=True, train_attn=True) - run_config = CreateModelRequest( + lora_config = 
types.LoraConfig(rank=config.r, train_unembed=False, train_mlp=True, train_attn=True) + run_config = types.CreateModelRequest( base_model=self.base_model, lora_config=lora_config, user_metadata={'adapter_name': body.adapter_name}) training_run_manager.save(adapter_name, run_config) return {'status': 'ok', 'adapter_name': adapter_name} return await self.schedule_task_and_wait(_task, task_type='add_adapter_to_model') - @app.post('/twinkle/set_template', response_model=types.SetTemplateResponse) + @app.post('/twinkle/apply_patch') + async def apply_patch( + request: Request, + body: types.ApplyPatchRequest, + self: ModelManagement = Depends(self_fn), + ) -> None: + adapter_name = _get_twinkle_adapter_name(request, body.adapter_name) + + async def _task(): + self.assert_adapter_exists(adapter_name=adapter_name) + extra_kwargs = body.model_extra or {} + patch_cls = deserialize_object(body.patch_cls) + self.model.apply_patch(patch_cls, adapter_name=adapter_name, **extra_kwargs) + + await self.schedule_task_and_wait(_task, task_type='apply_patch') + + @app.post('/twinkle/add_metric') + async def add_metric( + request: Request, + body: types.AddMetricRequest, + self: ModelManagement = Depends(self_fn), + ) -> None: + adapter_name = _get_twinkle_adapter_name(request, body.adapter_name) + + async def _task(): + self.assert_adapter_exists(adapter_name=adapter_name) + extra_kwargs = body.model_extra or {} + metric_cls = deserialize_object(body.metric_cls) + self.model.add_metric(metric_cls, is_training=body.is_training, adapter_name=adapter_name, **extra_kwargs) + + await self.schedule_task_and_wait(_task, task_type='add_metric') + + @app.post('/twinkle/set_template') async def set_template( request: Request, body: types.SetTemplateRequest, self: ModelManagement = Depends(self_fn), - ) -> types.SetTemplateResponse: + ) -> None: adapter_name = _get_twinkle_adapter_name(request, body.adapter_name) async def _task(): self.assert_adapter_exists(adapter_name=adapter_name) 
extra_kwargs = body.model_extra or {} - ret = self.model.set_template(body.template_cls, adapter_name=adapter_name, **extra_kwargs) - return {'result': ret} + self.model.set_template(body.template_cls, adapter_name=adapter_name, **extra_kwargs) - return await self.schedule_task_and_wait(_task, task_type='set_template') + await self.schedule_task_and_wait(_task, task_type='set_template') - @app.post('/twinkle/set_processor', response_model=types.SetProcessorResponse) + @app.post('/twinkle/set_processor') async def set_processor( request: Request, body: types.SetProcessorRequest, self: ModelManagement = Depends(self_fn), - ) -> types.SetProcessorResponse: + ) -> None: adapter_name = _get_twinkle_adapter_name(request, body.adapter_name) async def _task(): self.assert_adapter_exists(adapter_name=adapter_name) extra_kwargs = body.model_extra or {} - ret = self.model.set_processor(body.processor_cls, adapter_name=adapter_name, **extra_kwargs) - return {'result': ret} + self.model.set_processor(body.processor_cls, adapter_name=adapter_name, **extra_kwargs) - return await self.schedule_task_and_wait(_task, task_type='set_processor') + await self.schedule_task_and_wait(_task, task_type='set_processor') @app.post('/twinkle/heartbeat', response_model=types.HeartbeatResponse) async def heartbeat( diff --git a/src/twinkle/server/types/__init__.py b/src/twinkle/server/types/__init__.py deleted file mode 100644 index 85b3e739..00000000 --- a/src/twinkle/server/types/__init__.py +++ /dev/null @@ -1 +0,0 @@ -# Copyright (c) ModelScope Contributors. All rights reserved. 
diff --git a/src/twinkle/server/utils/checkpoint_base.py b/src/twinkle/server/utils/checkpoint_base.py index 1cf49b97..cbe05602 100644 --- a/src/twinkle/server/utils/checkpoint_base.py +++ b/src/twinkle/server/utils/checkpoint_base.py @@ -26,7 +26,7 @@ from twinkle import get_logger from twinkle.hub import HubOperation -from twinkle.server.types.checkpoint import ResolvedLoadPath +from twinkle_client.types import ResolvedLoadPath logger = get_logger() diff --git a/src/twinkle_client/model/multi_lora_transformers.py b/src/twinkle_client/model/multi_lora_transformers.py index 992cce64..743125d9 100644 --- a/src/twinkle_client/model/multi_lora_transformers.py +++ b/src/twinkle_client/model/multi_lora_transformers.py @@ -11,7 +11,6 @@ from typing import Any, Dict, Optional from twinkle_client.http import http_post from twinkle_client.types.model import ( - BackwardResponse, CalculateLossResponse, CalculateMetricResponse, ClipGradNormResponse, @@ -19,17 +18,7 @@ ForwardResponse, GetStateDictResponse, GetTrainConfigsResponse, - LoadResponse, - LrStepResponse, SaveResponse, - SetLossResponse, - SetLrSchedulerResponse, - SetOptimizerResponse, - SetProcessorResponse, - SetTemplateResponse, - StepResponse, - UploadToHubResponse, - ZeroGradResponse, ) @@ -100,14 +89,13 @@ def get_train_configs(self, **kwargs) -> GetTrainConfigsResponse: response.raise_for_status() return GetTrainConfigsResponse(**response.json()) - def backward(self, **kwargs) -> BackwardResponse: + def backward(self, **kwargs) -> None: """Execute backward pass.""" response = http_post( url=f'{self.server_url}/backward', json_data={'adapter_name': self.adapter_name, **kwargs} ) response.raise_for_status() - return BackwardResponse(**response.json()) def forward_backward(self, inputs: Any, **kwargs) -> ForwardBackwardResponse: """Execute combined forward and backward pass.""" @@ -118,41 +106,29 @@ def forward_backward(self, inputs: Any, **kwargs) -> ForwardBackwardResponse: response.raise_for_status() return 
ForwardBackwardResponse(**response.json()) - def step(self, **kwargs) -> StepResponse: + def step(self, **kwargs) -> None: """Execute optimizer step.""" response = http_post( url=f'{self.server_url}/step', json_data={'adapter_name': self.adapter_name, **kwargs} ) response.raise_for_status() - return StepResponse(**response.json()) - def zero_grad(self, **kwargs) -> ZeroGradResponse: + def zero_grad(self, **kwargs) -> None: """Zero out gradients.""" response = http_post( url=f'{self.server_url}/zero_grad', json_data={'adapter_name': self.adapter_name, **kwargs} ) response.raise_for_status() - return ZeroGradResponse(**response.json()) - def lr_step(self, **kwargs) -> LrStepResponse: + def lr_step(self, **kwargs) -> None: """Execute learning rate scheduler step.""" response = http_post( url=f'{self.server_url}/lr_step', json_data={'adapter_name': self.adapter_name, **kwargs} ) response.raise_for_status() - return LrStepResponse(**response.json()) - - def set_loss(self, loss_cls: str, **kwargs) -> SetLossResponse: - """Set the loss function.""" - response = http_post( - url=f'{self.server_url}/set_loss', - json_data={'loss_cls': loss_cls, 'adapter_name': self.adapter_name, **kwargs} - ) - response.raise_for_status() - return SetLossResponse(**response.json()) def clip_grad_norm(self, max_grad_norm: float = 1.0, norm_type: int = 2, **kwargs) -> ClipGradNormResponse: """Clip gradient norm.""" @@ -163,23 +139,37 @@ def clip_grad_norm(self, max_grad_norm: float = 1.0, norm_type: int = 2, **kwarg response.raise_for_status() return ClipGradNormResponse(**response.json()) - def set_optimizer(self, optimizer_cls: str, **kwargs) -> SetOptimizerResponse: + def clip_grad_and_step(self, max_grad_norm: float = 1.0, norm_type: int = 2, **kwargs) -> None: + """Clip gradient norm and execute optimizer step in one call.""" + response = http_post( + url=f'{self.server_url}/clip_grad_and_step', + json_data={'max_grad_norm': max_grad_norm, 'norm_type': norm_type, 'adapter_name': 
self.adapter_name, **kwargs} + ) + response.raise_for_status() + + def set_loss(self, loss_cls: str, **kwargs) -> None: + """Set the loss function.""" + response = http_post( + url=f'{self.server_url}/set_loss', + json_data={'loss_cls': loss_cls, 'adapter_name': self.adapter_name, **kwargs} + ) + response.raise_for_status() + + def set_optimizer(self, optimizer_cls: str, **kwargs) -> None: """Set the optimizer.""" response = http_post( url=f'{self.server_url}/set_optimizer', json_data={'optimizer_cls': optimizer_cls, 'adapter_name': self.adapter_name, **kwargs} ) response.raise_for_status() - return SetOptimizerResponse(**response.json()) - def set_lr_scheduler(self, scheduler_cls: str, **kwargs) -> SetLrSchedulerResponse: + def set_lr_scheduler(self, scheduler_cls: str, **kwargs) -> None: """Set the learning rate scheduler.""" response = http_post( url=f'{self.server_url}/set_lr_scheduler', json_data={'scheduler_cls': scheduler_cls, 'adapter_name': self.adapter_name, **kwargs} ) response.raise_for_status() - return SetLrSchedulerResponse(**response.json()) def save(self, name: str, **kwargs) -> SaveResponse: """Save model checkpoint.""" @@ -190,32 +180,45 @@ def save(self, name: str, **kwargs) -> SaveResponse: response.raise_for_status() return SaveResponse(**response.json()) - def load(self, name: str, **kwargs) -> LoadResponse: + def load(self, name: str, **kwargs) -> None: """Load model checkpoint.""" response = http_post( url=f'{self.server_url}/load', json_data={'name': name, 'adapter_name': self.adapter_name, **kwargs} ) response.raise_for_status() - return LoadResponse(**response.json()) - def set_template(self, template_cls: str, **kwargs) -> SetTemplateResponse: + def apply_patch(self, patch_cls: str, **kwargs) -> None: + """Apply a patch to the model.""" + response = http_post( + url=f'{self.server_url}/apply_patch', + json_data={'patch_cls': patch_cls, 'adapter_name': self.adapter_name, **kwargs} + ) + response.raise_for_status() + + def 
add_metric(self, metric_cls: str, is_training: Optional[bool] = None, **kwargs) -> None: + """Add a metric to the model.""" + response = http_post( + url=f'{self.server_url}/add_metric', + json_data={'metric_cls': metric_cls, 'is_training': is_training, 'adapter_name': self.adapter_name, **kwargs} + ) + response.raise_for_status() + + def set_template(self, template_cls: str, **kwargs) -> None: """Set the template for data processing.""" response = http_post( url=f'{self.server_url}/set_template', json_data={'template_cls': template_cls, 'adapter_name': self.adapter_name, 'model_id': self.model_id, **kwargs} ) response.raise_for_status() - return SetTemplateResponse(**response.json()) - def set_processor(self, processor_cls: str, **kwargs) -> SetProcessorResponse: + def set_processor(self, processor_cls: str, **kwargs) -> None: """Set the input processor.""" response = http_post( url=f'{self.server_url}/set_processor', json_data={'processor_cls': processor_cls, 'adapter_name': self.adapter_name, **kwargs} ) response.raise_for_status() - return SetProcessorResponse(**response.json()) def calculate_metric(self, is_training: bool = True, **kwargs) -> CalculateMetricResponse: """Calculate metrics from model outputs.""" @@ -241,7 +244,7 @@ def upload_to_hub( hub_model_id: str, hub_token: Optional[str] = None, async_upload: bool = True, - ) -> UploadToHubResponse: + ) -> None: """Upload model checkpoint to hub. 
Args: @@ -260,4 +263,3 @@ def upload_to_hub( } ) response.raise_for_status() - return UploadToHubResponse(**response.json()) diff --git a/src/twinkle_client/types/__init__.py b/src/twinkle_client/types/__init__.py index 7a38edbf..d91d8240 100644 --- a/src/twinkle_client/types/__init__.py +++ b/src/twinkle_client/types/__init__.py @@ -2,11 +2,17 @@ from .model import ( AddAdapterRequest, AddAdapterResponse, + AddMetricRequest, + AddMetricResponse, AdapterRequest, + ApplyPatchRequest, + ApplyPatchResponse, BackwardResponse, CalculateLossResponse, CalculateMetricRequest, CalculateMetricResponse, + ClipGradAndStepRequest, + ClipGradAndStepResponse, ClipGradNormResponse, CreateRequest, CreateResponse, @@ -23,6 +29,7 @@ LoadResponse, LrStepResponse, ModelResult, + OkResponse, SaveRequest, SaveResponse, SetLossRequest, @@ -79,3 +86,5 @@ TrainingRunsResponse, WeightsInfoResponse, ) + +from .checkpoint import ResolvedLoadPath diff --git a/src/twinkle/server/types/checkpoint.py b/src/twinkle_client/types/checkpoint.py similarity index 100% rename from src/twinkle/server/types/checkpoint.py rename to src/twinkle_client/types/checkpoint.py diff --git a/src/twinkle_client/types/model.py b/src/twinkle_client/types/model.py index e10479c3..c5abf641 100644 --- a/src/twinkle_client/types/model.py +++ b/src/twinkle_client/types/model.py @@ -132,114 +132,163 @@ class Config: extra = 'allow' +class ClipGradAndStepRequest(BaseModel): + adapter_name: str + max_grad_norm: float = 1.0 + norm_type: int = 2 + + class Config: + extra = 'allow' + + +class ApplyPatchRequest(BaseModel): + patch_cls: str + adapter_name: str + + class Config: + extra = 'allow' + + +class AddMetricRequest(BaseModel): + metric_cls: str + adapter_name: str + is_training: Optional[bool] = None + + class Config: + extra = 'allow' + + # --------------------------------------------------------------------------- # Response models # --------------------------------------------------------------------------- +class 
OkResponse(BaseModel): + """Response for endpoints whose underlying method returns None.""" + status: str = 'ok' + + class ModelResult(BaseModel): - """Generic single-value result wrapper returned by most model endpoints.""" + """Generic single-value result wrapper returned by result-bearing endpoints.""" result: Any -class ForwardResponse(ModelResult): - """Response for /forward and /forward_only endpoints.""" - pass +# --- Result-bearing responses --- + +class ForwardResponse(BaseModel): + """Response for /forward and /forward_only endpoints (returns ModelOutput).""" + result: Any -class ForwardBackwardResponse(ModelResult): - """Response for /forward_backward endpoint.""" - pass +class ForwardBackwardResponse(BaseModel): + """Response for /forward_backward endpoint (returns ModelOutput).""" + result: Any -class BackwardResponse(ModelResult): +class CalculateLossResponse(BaseModel): + """Response for /calculate_loss endpoint (returns float).""" + result: float + + +class ClipGradNormResponse(BaseModel): + """Response for /clip_grad_norm endpoint (returns float as str).""" + result: str + + +class GetTrainConfigsResponse(BaseModel): + """Response for /get_train_configs endpoint (returns str).""" + result: str + + +class GetStateDictResponse(BaseModel): + """Response for /get_state_dict endpoint (returns Dict).""" + result: Dict[str, Any] + + +class CalculateMetricResponse(BaseModel): + """Response for /calculate_metric endpoint (returns Dict).""" + result: Dict[str, Any] + + +class SaveResponse(BaseModel): + """Response for /save endpoint (returns twinkle path + checkpoint dir).""" + result: str + checkpoint_dir: Optional[str] = None + + +# --- Void responses (return None → OkResponse) --- + +class BackwardResponse(OkResponse): """Response for /backward endpoint.""" pass -class StepResponse(ModelResult): +class StepResponse(OkResponse): """Response for /step (optimizer step) endpoint.""" pass -class ZeroGradResponse(ModelResult): +class 
ZeroGradResponse(OkResponse): """Response for /zero_grad endpoint.""" pass -class LrStepResponse(ModelResult): +class LrStepResponse(OkResponse): """Response for /lr_step endpoint.""" pass -class SetLossResponse(ModelResult): +class SetLossResponse(OkResponse): """Response for /set_loss endpoint.""" pass -class ClipGradNormResponse(ModelResult): - """Response for /clip_grad_norm endpoint.""" - pass - - -class SetOptimizerResponse(ModelResult): +class SetOptimizerResponse(OkResponse): """Response for /set_optimizer endpoint.""" pass -class SetLrSchedulerResponse(ModelResult): +class SetLrSchedulerResponse(OkResponse): """Response for /set_lr_scheduler endpoint.""" pass -class SaveResponse(ModelResult): - """Response for /save endpoint.""" - pass - - -class LoadResponse(ModelResult): +class LoadResponse(OkResponse): """Response for /load endpoint.""" pass -class SetTemplateResponse(ModelResult): +class SetTemplateResponse(OkResponse): """Response for /set_template endpoint.""" pass -class SetProcessorResponse(ModelResult): +class SetProcessorResponse(OkResponse): """Response for /set_processor endpoint.""" pass -class CalculateLossResponse(ModelResult): - """Response for /calculate_loss endpoint.""" +class UploadToHubResponse(OkResponse): + """Response for /upload_to_hub endpoint.""" pass -class CalculateMetricResponse(ModelResult): - """Response for /calculate_metric endpoint.""" +class ClipGradAndStepResponse(OkResponse): + """Response for /clip_grad_and_step endpoint.""" pass -class GetTrainConfigsResponse(ModelResult): - """Response for /get_train_configs endpoint.""" +class ApplyPatchResponse(OkResponse): + """Response for /apply_patch endpoint.""" pass -class GetStateDictResponse(ModelResult): - """Response for /get_state_dict endpoint.""" +class AddMetricResponse(OkResponse): + """Response for /add_metric endpoint.""" pass -class UploadToHubResponse(BaseModel): - """Response for /upload_to_hub endpoint.""" - status: Optional[str] = None - message: 
Optional[str] = None - - class Config: - extra = 'allow' - +# --- Other responses --- class CreateResponse(BaseModel): """Response for /create endpoint.""" From 2f7006210fa679fac4e32bdb1ddd8969914d1c8f Mon Sep 17 00:00:00 2001 From: Yunnglin Date: Thu, 12 Mar 2026 14:48:27 +0800 Subject: [PATCH 14/24] update twinkle model --- src/twinkle/server/model/tinker_handlers.py | 6 - src/twinkle/server/model/twinkle_handlers.py | 15 +-- .../server/sampler/twinkle_handlers.py | 9 -- src/twinkle/server/utils/adapter_manager.py | 110 +++++------------- src/twinkle/server/utils/validation.py | 14 +++ src/twinkle_client/types/__init__.py | 4 - src/twinkle_client/types/model.py | 9 -- src/twinkle_client/types/sampler.py | 10 -- 8 files changed, 47 insertions(+), 130 deletions(-) diff --git a/src/twinkle/server/model/tinker_handlers.py b/src/twinkle/server/model/tinker_handlers.py index cdb55669..ccd0daed 100644 --- a/src/twinkle/server/model/tinker_handlers.py +++ b/src/twinkle/server/model/tinker_handlers.py @@ -109,7 +109,6 @@ async def _do_forward(): try: adapter_name = self.get_adapter_name(adapter_name=body.model_id) self.assert_adapter_exists(adapter_name=adapter_name) - self.touch_adapter(adapter_name) datum_list = body.forward_input.data loss_fn_config = body.forward_input.loss_fn_config or {} output = self.model.tinker_forward_only(inputs=datum_list, adapter_name=adapter_name) @@ -151,7 +150,6 @@ async def _do_forward_backward(): try: adapter_name = self.get_adapter_name(adapter_name=body.model_id) self.assert_adapter_exists(adapter_name=adapter_name) - self.touch_adapter(adapter_name) datum_list = body.forward_backward_input.data loss_fn = body.forward_backward_input.loss_fn loss_fn_config = body.forward_backward_input.loss_fn_config or {} @@ -200,7 +198,6 @@ async def _do_optim(): if not self.get_adapter_state(adapter_name, 'grad_ready', False): raise RuntimeError(f'No accumulated gradients for adapter={adapter_name}; ' 'call forward_backward before optim_step') - 
self.touch_adapter(adapter_name) self.model.tinker_step(adam_params=body.adam_params, adapter_name=adapter_name) self.set_adapter_state(adapter_name, 'grad_ready', False) metrics = self.model.tinker_calculate_metric(is_training=True, adapter_name=adapter_name) @@ -226,7 +223,6 @@ async def _do_save(): try: adapter_name = self.get_adapter_name(adapter_name=body.model_id) self.assert_adapter_exists(adapter_name=adapter_name) - self.touch_adapter(adapter_name) checkpoint_manager = create_checkpoint_manager(token, client_type='tinker') checkpoint_name = checkpoint_manager.get_ckpt_name(body.path) save_dir = checkpoint_manager.get_save_dir(model_id=body.model_id, is_sampler=False) @@ -255,7 +251,6 @@ async def _do_save_for_sampler(): try: adapter_name = self.get_adapter_name(adapter_name=body.model_id) self.assert_adapter_exists(adapter_name=adapter_name) - self.touch_adapter(adapter_name) checkpoint_manager = create_checkpoint_manager(token, client_type='tinker') checkpoint_name = checkpoint_manager.get_ckpt_name(body.path) save_dir = checkpoint_manager.get_save_dir(model_id=body.model_id, is_sampler=True) @@ -293,7 +288,6 @@ async def _do_load(): assert self.model is not None, 'Model not loaded, please load model first' adapter_name = self.get_adapter_name(adapter_name=body.model_id) self.assert_adapter_exists(adapter_name=adapter_name) - self.touch_adapter(adapter_name) self.model.tinker_load( checkpoint_dir=body.path, load_optimizer=body.optimizer, adapter_name=adapter_name, token=token) self.set_adapter_state(adapter_name, 'grad_ready', False) diff --git a/src/twinkle/server/model/twinkle_handlers.py b/src/twinkle/server/model/twinkle_handlers.py index 814bf9f7..eda462a6 100644 --- a/src/twinkle/server/model/twinkle_handlers.py +++ b/src/twinkle/server/model/twinkle_handlers.py @@ -19,6 +19,7 @@ from twinkle.data_format import InputFeature, Trajectory from twinkle.server.common.checkpoint_factory import create_checkpoint_manager, create_training_run_manager from 
twinkle.server.common.serialize import deserialize_object +from twinkle.server.utils.validation import get_session_id_from_request from twinkle.utils.logger import get_logger logger = get_logger() @@ -343,12 +344,13 @@ async def add_adapter_to_model( assert body.adapter_name, 'You need to specify a valid `adapter_name`' token = await self._on_request_start(request) adapter_name = _get_twinkle_adapter_name(request, body.adapter_name) + session_id = get_session_id_from_request(request) async def _task(): config = deserialize_object(body.config) extra_kwargs = body.model_extra or {} training_run_manager = create_training_run_manager(token, client_type='twinkle') - self.register_adapter(adapter_name, token) + self.register_adapter(adapter_name, token, session_id=session_id) self.model.add_adapter_to_model(adapter_name, config, **extra_kwargs) lora_config = None @@ -423,17 +425,6 @@ async def _task(): await self.schedule_task_and_wait(_task, task_type='set_processor') - @app.post('/twinkle/heartbeat', response_model=types.HeartbeatResponse) - async def heartbeat( - request: Request, - body: types.HeartbeatRequest, - self: ModelManagement = Depends(self_fn), - ) -> types.HeartbeatResponse: - adapter_name = _get_twinkle_adapter_name(request, body.adapter_name) - self.assert_adapter_exists(adapter_name=adapter_name) - self.touch_adapter(adapter_name) - return types.HeartbeatResponse() - @app.post('/twinkle/calculate_metric', response_model=types.CalculateMetricResponse) async def calculate_metric( request: Request, diff --git a/src/twinkle/server/sampler/twinkle_handlers.py b/src/twinkle/server/sampler/twinkle_handlers.py index 7f63fc08..2e477736 100644 --- a/src/twinkle/server/sampler/twinkle_handlers.py +++ b/src/twinkle/server/sampler/twinkle_handlers.py @@ -136,12 +136,3 @@ def add_adapter_to_sampler( self.sampler.add_adapter_to_sampler(full_adapter_name, config) return types.AddAdapterResponse(adapter_name=full_adapter_name) - - @app.post('/twinkle/heartbeat', 
response_model=types.HeartbeatResponse) - def heartbeat(request: Request, body: types.HeartbeatRequest, - self: SamplerManagement = Depends(self_fn)) -> types.HeartbeatResponse: - """Keep an adapter alive by resetting its inactivity timer.""" - full_adapter_name = _get_twinkle_sampler_adapter_name(request, body.adapter_name) - self.assert_adapter_exists(adapter_name=full_adapter_name) - self.touch_adapter(full_adapter_name) - return types.HeartbeatResponse() diff --git a/src/twinkle/server/utils/adapter_manager.py b/src/twinkle/server/utils/adapter_manager.py index 8337ed6b..c1706b65 100644 --- a/src/twinkle/server/utils/adapter_manager.py +++ b/src/twinkle/server/utils/adapter_manager.py @@ -24,17 +24,17 @@ class AdapterManagerMixin: - """Mixin for adapter lifecycle management with automatic timeout. + """Mixin for adapter lifecycle management with session-based expiration. This mixin tracks adapter activity and automatically expires adapters - that have been inactive for longer than the configured timeout period. + when their associated session expires. Inheriting classes should: 1. Call _init_adapter_manager() in __init__ 2. Override _on_adapter_expired() to customize expiration handling Attributes: - _adapter_timeout: Timeout in seconds for inactive adapters. + _adapter_timeout: Session inactivity timeout in seconds used to determine if a session is alive. """ # Type hint for state attribute that inheriting classes must provide @@ -43,52 +43,51 @@ class AdapterManagerMixin: def _init_adapter_manager( self, adapter_timeout: float = 1800.0, - adapter_max_lifetime: float = 12 * 60 * 60, ) -> None: """Initialize the adapter manager. This should be called in the __init__ of the inheriting class. Args: - adapter_timeout: Timeout in seconds for inactive adapters and session-based expiration. - Default is 1800.0 (30 minutes). Adapters linked to sessions will expire - when their session hasn't been touched for this duration. 
- adapter_max_lifetime: Maximum lifetime in seconds for an adapter since creation. - Default is 43200.0 (12 hours). If <= 0, lifetime enforcement is disabled. + adapter_timeout: Timeout in seconds used to check whether a session is still alive. + Default is 1800.0 (30 minutes). """ self._adapter_timeout = adapter_timeout - self._adapter_max_lifetime = adapter_max_lifetime # Adapter lifecycle tracking # Dict mapping adapter_name -> - # {'token': str, 'session_id': str, 'last_activity': float, 'created_at': float, 'inactivity_counter': int} + # {'token': str, 'session_id': str, 'created_at': float, 'state': dict, 'expiring': bool} self._adapter_records: dict[str, dict[str, Any]] = {} # Countdown thread self._adapter_countdown_thread: threading.Thread | None = None self._adapter_countdown_running = False - def register_adapter(self, adapter_name: str, token: str, session_id: str | None = None) -> None: + def register_adapter(self, adapter_name: str, token: str, session_id: str) -> None: """Register a new adapter for lifecycle tracking. + The adapter will expire when its associated session expires. + Args: adapter_name: Name of the adapter to register. token: User token that owns this adapter. - session_id: Optional session ID to associate with this adapter. - If provided, adapter will expire when the session expires. + session_id: Session ID to associate with this adapter. Must be a non-empty string. + + Raises: + ValueError: If session_id is None or empty. """ + if not session_id: + raise ValueError(f'session_id must be provided when registering adapter {adapter_name}') current_time = time.time() self._adapter_records[adapter_name] = { 'token': token, 'session_id': session_id, - 'last_activity': current_time, 'created_at': current_time, - 'inactivity_counter': 0, 'state': {}, 'expiring': False, } - logger.debug(f'[AdapterManager] Registered adapter {adapter_name} for token {token[:8]}...' 
- + (f' (session: {session_id})' if session_id else '')) + logger.debug( + f'[AdapterManager] Registered adapter {adapter_name} for token {token[:8]}... (session: {session_id})') def _is_session_alive(self, session_id: str) -> bool: """Check if a session is still alive via state proxy. @@ -166,24 +165,6 @@ def clear_adapter_state(self, adapter_name: str) -> None: return info['state'] = {} - def touch_adapter(self, adapter_name: str) -> bool: - """Update adapter activity timestamp to prevent timeout. - - Args: - adapter_name: Name of the adapter to touch. - - Returns: - True if adapter was found and touched, False otherwise. - """ - info = self._adapter_records.get(adapter_name) - if not info: - return False - if info.get('expiring'): - return False - info['last_activity'] = time.time() - info['inactivity_counter'] = 0 - return True - def get_adapter_info(self, adapter_name: str) -> dict[str, Any] | None: """Get information about a registered adapter. @@ -230,18 +211,17 @@ def assert_adapter_exists(self, adapter_name: str) -> None: f'Adapter {adapter_name} not found' def _adapter_countdown_loop(self) -> None: - """Background thread that monitors and handles inactive adapters. + """Background thread that monitors and handles adapters whose session has expired. This thread runs continuously and: - 1. Increments inactivity counters for all adapters every second - 2. Calls _on_adapter_expired() for adapters that exceed timeout + 1. Checks session liveness for all registered adapters every second + 2. Calls _on_adapter_expired() for adapters whose session has expired 3. 
Removes expired adapters from tracking """ - logger.debug(f'[AdapterManager] Countdown thread started (timeout={self._adapter_timeout}s)') + logger.debug(f'[AdapterManager] Countdown thread started (session_timeout={self._adapter_timeout}s)') while self._adapter_countdown_running: try: time.sleep(1) - now = time.time() expired_adapters: list[tuple[str, str | None]] = [] # Create snapshot to avoid modification during iteration @@ -251,54 +231,24 @@ def _adapter_countdown_loop(self) -> None: continue session_id = info.get('session_id') - created_at = info.get('created_at') - - # Check TTL for both cases - exceeded_ttl = ( - self._adapter_max_lifetime and self._adapter_max_lifetime > 0 - and (now - created_at) > self._adapter_max_lifetime) - - # Different logic based on session association - if session_id: - # Has session: check session expiration and TTL - session_expired = not self._is_session_alive(session_id) - should_expire = session_expired or exceeded_ttl - logger.debug( - f'[AdapterManager] Adapter {adapter_name} session expiration check ' - f'(session_id={session_id}, session_alive={not session_expired}, should_expire={should_expire})' # noqa:E501 - ) - expiration_reasons = [] - if exceeded_ttl: - expiration_reasons.append('ttl_exceeded') - if session_expired: - expiration_reasons.append('session_expired') - else: - # No session: check inactivity timeout and TTL - info['inactivity_counter'] = info.get('inactivity_counter', 0) + 1 - exceeded_inactivity = info['inactivity_counter'] > self._adapter_timeout - should_expire = exceeded_ttl or exceeded_inactivity - logger.debug( - f'[AdapterManager] Adapter {adapter_name} inactivity check ' - f'(inactivity_counter={info["inactivity_counter"]}, timeout={self._adapter_timeout}, should_expire={should_expire})' # noqa:E501 - ) - expiration_reasons = [] - if exceeded_ttl: - expiration_reasons.append('ttl_exceeded') - if exceeded_inactivity: - expiration_reasons.append('inactivity_timeout') - - if should_expire: + 
session_expired = not self._is_session_alive(session_id) + logger.debug(f'[AdapterManager] Adapter {adapter_name} session check ' + f'(session_id={session_id}, session_alive={not session_expired})') + + if session_expired: info['expiring'] = True info['state'] = {} # best-effort clear token = info.get('token') expired_adapters.append((adapter_name, token)) - for adapter_name, token in expired_adapters: + for adapter_name, _token in expired_adapters: success = False try: self._on_adapter_expired(adapter_name) + info = self._adapter_records.get(adapter_name, {}) + session_id = info.get('session_id') logger.info(f'[AdapterManager] Adapter {adapter_name} expired ' - f"(reasons={','.join(expiration_reasons)}, session={session_id})") + f'(reason=session_expired, session={session_id})') success = True except Exception as e: logger.warning(f'[AdapterManager] Error while expiring adapter {adapter_name}: {e}') diff --git a/src/twinkle/server/utils/validation.py b/src/twinkle/server/utils/validation.py index 23539ed8..96a1f33a 100644 --- a/src/twinkle/server/utils/validation.py +++ b/src/twinkle/server/utils/validation.py @@ -32,6 +32,7 @@ async def verify_request_token(request: Request, call_next): status_code=400, content={'detail': 'Missing X-Ray-Serve-Request-Id header, required for sticky session'}) request.state.request_id = request_id request.state.token = token + request.state.session_id = request.headers.get('X-Twinkle-Session-Id') or '' response = await call_next(request) return response @@ -63,3 +64,16 @@ def get_token_from_request(request: Request) -> str: The extracted token or empty string if not found """ return getattr(request.state, 'token', '') or '' + + +def get_session_id_from_request(request: Request) -> str: + """ + Extract session ID from request. 
+ + Args: + request: The FastAPI Request object + + Returns: + The extracted session ID or empty string if not found + """ + return getattr(request.state, 'session_id', '') or '' diff --git a/src/twinkle_client/types/__init__.py b/src/twinkle_client/types/__init__.py index d91d8240..57485dfd 100644 --- a/src/twinkle_client/types/__init__.py +++ b/src/twinkle_client/types/__init__.py @@ -23,8 +23,6 @@ GetStateDictRequest, GetStateDictResponse, GetTrainConfigsResponse, - HeartbeatRequest, - HeartbeatResponse, LoadRequest, LoadResponse, LrStepResponse, @@ -59,8 +57,6 @@ AddAdapterRequest as SamplerAddAdapterRequest, AddAdapterResponse, CreateResponse as SamplerCreateResponse, - HeartbeatRequest as SamplerHeartbeatRequest, - HeartbeatResponse as SamplerHeartbeatResponse, SampleRequest, SampleResponseModel, SetTemplateRequest as SamplerSetTemplateRequest, diff --git a/src/twinkle_client/types/model.py b/src/twinkle_client/types/model.py index c5abf641..a349e024 100644 --- a/src/twinkle_client/types/model.py +++ b/src/twinkle_client/types/model.py @@ -113,10 +113,6 @@ class Config: extra = 'allow' -class HeartbeatRequest(BaseModel): - adapter_name: str - - class CalculateMetricRequest(BaseModel): adapter_name: str is_training: bool = True @@ -299,8 +295,3 @@ class AddAdapterResponse(BaseModel): """Response for /add_adapter_to_model endpoint.""" status: str = 'ok' adapter_name: str - - -class HeartbeatResponse(BaseModel): - """Response for /heartbeat endpoint.""" - status: str = 'ok' diff --git a/src/twinkle_client/types/sampler.py b/src/twinkle_client/types/sampler.py index 303316a9..c78f5d55 100644 --- a/src/twinkle_client/types/sampler.py +++ b/src/twinkle_client/types/sampler.py @@ -53,16 +53,6 @@ class AddAdapterResponse(BaseModel): adapter_name: str -class HeartbeatRequest(BaseModel): - """Request body for the /heartbeat endpoint.""" - adapter_name: str = Field(..., description='Adapter name to keep alive') - - -class HeartbeatResponse(BaseModel): - """Response body 
for the /heartbeat endpoint.""" - status: str = 'ok' - - class CreateResponse(BaseModel): """Response body for the /create endpoint.""" status: str = 'ok' From ba05fc2444f9759724fad8caff1a056fdb901ae3 Mon Sep 17 00:00:00 2001 From: Yunnglin Date: Thu, 12 Mar 2026 15:45:24 +0800 Subject: [PATCH 15/24] update twinkle model --- src/twinkle/server/model/twinkle_handlers.py | 61 ++++++++++++-------- src/twinkle/server/utils/adapter_manager.py | 6 +- src/twinkle_client/http/http_utils.py | 14 +++++ src/twinkle_client/manager.py | 24 +++++++- 4 files changed, 73 insertions(+), 32 deletions(-) diff --git a/src/twinkle/server/model/twinkle_handlers.py b/src/twinkle/server/model/twinkle_handlers.py index eda462a6..7265f2b8 100644 --- a/src/twinkle/server/model/twinkle_handlers.py +++ b/src/twinkle/server/model/twinkle_handlers.py @@ -8,7 +8,8 @@ """ from __future__ import annotations -from fastapi import Depends, FastAPI, Request +import traceback +from fastapi import Depends, FastAPI, HTTPException, Request from peft import LoraConfig from typing import TYPE_CHECKING, Any, Callable, Optional @@ -55,6 +56,16 @@ def _register_twinkle_routes(app: FastAPI, self_fn: Callable[[], ModelManagement replica instance. It is wired in via Depends so it is resolved lazily at request time. 
""" + async def run_task(coro): + """Await a schedule_task_and_wait coroutine and surface any exception as a + structured HTTP 500 response so the client receives the full traceback instead + of an opaque connection-level error.""" + try: + return await coro + except Exception: + logger.error(traceback.format_exc()) + raise HTTPException(status_code=500, detail=traceback.format_exc()) + @app.post('/twinkle/create', response_model=types.CreateResponse) async def create(request: Request, body: types.CreateRequest, self: ModelManagement = Depends(self_fn)) -> types.CreateResponse: @@ -72,7 +83,7 @@ async def _task(): ret = self.model.forward(inputs=inputs, adapter_name=adapter_name, **extra_kwargs) return {'result': ret} - return await self.schedule_task_and_wait(_task, task_type='forward') + return await run_task(self.schedule_task_and_wait(_task, task_type='forward')) @app.post('/twinkle/forward_only', response_model=types.ForwardResponse) async def forward_only( @@ -89,7 +100,7 @@ async def _task(): ret = self.model.forward_only(inputs=inputs, adapter_name=adapter_name, **extra_kwargs) return {'result': ret} - return await self.schedule_task_and_wait(_task, task_type='forward_only') + return await run_task(self.schedule_task_and_wait(_task, task_type='forward_only')) @app.post('/twinkle/calculate_loss', response_model=types.CalculateLossResponse) async def calculate_loss( @@ -105,7 +116,7 @@ async def _task(): ret = self.model.calculate_loss(adapter_name=adapter_name, **extra_kwargs) return {'result': ret} - return await self.schedule_task_and_wait(_task, task_type='calculate_loss') + return await run_task(self.schedule_task_and_wait(_task, task_type='calculate_loss')) @app.post('/twinkle/backward') async def backward(request: Request, body: types.AdapterRequest, self: ModelManagement = Depends(self_fn)) -> None: @@ -116,7 +127,7 @@ async def _task(): extra_kwargs = body.model_extra or {} self.model.backward(adapter_name=adapter_name, **extra_kwargs) - await 
self.schedule_task_and_wait(_task, task_type='backward') + await run_task(self.schedule_task_and_wait(_task, task_type='backward')) @app.post('/twinkle/forward_backward', response_model=types.ForwardBackwardResponse) async def forward_backward( @@ -133,7 +144,7 @@ async def _task(): ret = self.model.forward_backward(inputs=inputs, adapter_name=adapter_name, **extra_kwargs) return {'result': ret} - return await self.schedule_task_and_wait(_task, task_type='forward_backward') + return await run_task(self.schedule_task_and_wait(_task, task_type='forward_backward')) @app.post('/twinkle/clip_grad_norm', response_model=types.ClipGradNormResponse) async def clip_grad_norm( @@ -149,7 +160,7 @@ async def _task(): ret = self.model.clip_grad_norm(adapter_name=adapter_name, **extra_kwargs) return {'result': str(ret)} - return await self.schedule_task_and_wait(_task, task_type='clip_grad_norm') + return await run_task(self.schedule_task_and_wait(_task, task_type='clip_grad_norm')) @app.post('/twinkle/step') async def step(request: Request, body: types.AdapterRequest, self: ModelManagement = Depends(self_fn)) -> None: @@ -160,7 +171,7 @@ async def _task(): extra_kwargs = body.model_extra or {} self.model.step(adapter_name=adapter_name, **extra_kwargs) - await self.schedule_task_and_wait(_task, task_type='step') + await run_task(self.schedule_task_and_wait(_task, task_type='step')) @app.post('/twinkle/zero_grad') async def zero_grad(request: Request, body: types.AdapterRequest, self: ModelManagement = Depends(self_fn)) -> None: @@ -171,7 +182,7 @@ async def _task(): extra_kwargs = body.model_extra or {} self.model.zero_grad(adapter_name=adapter_name, **extra_kwargs) - await self.schedule_task_and_wait(_task, task_type='zero_grad') + await run_task(self.schedule_task_and_wait(_task, task_type='zero_grad')) @app.post('/twinkle/lr_step') async def lr_step(request: Request, body: types.AdapterRequest, self: ModelManagement = Depends(self_fn)) -> None: @@ -182,7 +193,7 @@ async def 
_task(): extra_kwargs = body.model_extra or {} self.model.lr_step(adapter_name=adapter_name, **extra_kwargs) - await self.schedule_task_and_wait(_task, task_type='lr_step') + await run_task(self.schedule_task_and_wait(_task, task_type='lr_step')) @app.post('/twinkle/clip_grad_and_step') async def clip_grad_and_step( @@ -202,7 +213,7 @@ async def _task(): **extra_kwargs, ) - await self.schedule_task_and_wait(_task, task_type='clip_grad_and_step') + await run_task(self.schedule_task_and_wait(_task, task_type='clip_grad_and_step')) @app.post('/twinkle/get_train_configs', response_model=types.GetTrainConfigsResponse) async def get_train_configs( @@ -218,7 +229,7 @@ async def _task(): ret = self.model.get_train_configs(adapter_name=adapter_name, **extra_kwargs) return {'result': ret} - return await self.schedule_task_and_wait(_task, task_type='get_train_configs') + return await run_task(self.schedule_task_and_wait(_task, task_type='get_train_configs')) @app.post('/twinkle/set_loss') async def set_loss(request: Request, body: types.SetLossRequest, self: ModelManagement = Depends(self_fn)) -> None: @@ -229,7 +240,7 @@ async def _task(): extra_kwargs = body.model_extra or {} self.model.set_loss(body.loss_cls, adapter_name=adapter_name, **extra_kwargs) - await self.schedule_task_and_wait(_task, task_type='set_loss') + await run_task(self.schedule_task_and_wait(_task, task_type='set_loss')) @app.post('/twinkle/set_optimizer') async def set_optimizer( @@ -244,7 +255,7 @@ async def _task(): extra_kwargs = body.model_extra or {} self.model.set_optimizer(body.optimizer_cls, adapter_name=adapter_name, **extra_kwargs) - await self.schedule_task_and_wait(_task, task_type='set_optimizer') + await run_task(self.schedule_task_and_wait(_task, task_type='set_optimizer')) @app.post('/twinkle/set_lr_scheduler') async def set_lr_scheduler( @@ -259,7 +270,7 @@ async def _task(): extra_kwargs = body.model_extra or {} self.model.set_lr_scheduler(body.scheduler_cls, adapter_name=adapter_name, 
**extra_kwargs) - await self.schedule_task_and_wait(_task, task_type='set_lr_scheduler') + await run_task(self.schedule_task_and_wait(_task, task_type='set_lr_scheduler')) @app.post('/twinkle/save', response_model=types.SaveResponse) async def save(request: Request, body: types.SaveRequest, @@ -282,7 +293,7 @@ async def _task(): twinkle_path = checkpoint_manager.save(model_id=adapter_name, name=checkpoint_name, is_sampler=False) return {'result': twinkle_path, 'checkpoint_dir': checkpoint_dir} - return await self.schedule_task_and_wait(_task, task_type='save') + return await run_task(self.schedule_task_and_wait(_task, task_type='save')) @app.post('/twinkle/load') async def load(request: Request, body: types.LoadRequest, self: ModelManagement = Depends(self_fn)) -> None: @@ -302,7 +313,7 @@ async def _task(): token=token, **extra_kwargs) - await self.schedule_task_and_wait(_task, task_type='load') + await run_task(self.schedule_task_and_wait(_task, task_type='load')) @app.post('/twinkle/upload_to_hub') async def upload_to_hub( @@ -333,7 +344,7 @@ async def _task(): hub_token=body.hub_token or token, async_upload=body.async_upload) - await self.schedule_task_and_wait(_task, task_type='upload_to_hub') + await run_task(self.schedule_task_and_wait(_task, task_type='upload_to_hub')) @app.post('/twinkle/add_adapter_to_model', response_model=types.AddAdapterResponse) async def add_adapter_to_model( @@ -361,7 +372,7 @@ async def _task(): training_run_manager.save(adapter_name, run_config) return {'status': 'ok', 'adapter_name': adapter_name} - return await self.schedule_task_and_wait(_task, task_type='add_adapter_to_model') + return await run_task(self.schedule_task_and_wait(_task, task_type='add_adapter_to_model')) @app.post('/twinkle/apply_patch') async def apply_patch( @@ -377,7 +388,7 @@ async def _task(): patch_cls = deserialize_object(body.patch_cls) self.model.apply_patch(patch_cls, adapter_name=adapter_name, **extra_kwargs) - await self.schedule_task_and_wait(_task, 
task_type='apply_patch') + await run_task(self.schedule_task_and_wait(_task, task_type='apply_patch')) @app.post('/twinkle/add_metric') async def add_metric( @@ -393,7 +404,7 @@ async def _task(): metric_cls = deserialize_object(body.metric_cls) self.model.add_metric(metric_cls, is_training=body.is_training, adapter_name=adapter_name, **extra_kwargs) - await self.schedule_task_and_wait(_task, task_type='add_metric') + await run_task(self.schedule_task_and_wait(_task, task_type='add_metric')) @app.post('/twinkle/set_template') async def set_template( @@ -408,7 +419,7 @@ async def _task(): extra_kwargs = body.model_extra or {} self.model.set_template(body.template_cls, adapter_name=adapter_name, **extra_kwargs) - await self.schedule_task_and_wait(_task, task_type='set_template') + await run_task(self.schedule_task_and_wait(_task, task_type='set_template')) @app.post('/twinkle/set_processor') async def set_processor( @@ -423,7 +434,7 @@ async def _task(): extra_kwargs = body.model_extra or {} self.model.set_processor(body.processor_cls, adapter_name=adapter_name, **extra_kwargs) - await self.schedule_task_and_wait(_task, task_type='set_processor') + await run_task(self.schedule_task_and_wait(_task, task_type='set_processor')) @app.post('/twinkle/calculate_metric', response_model=types.CalculateMetricResponse) async def calculate_metric( @@ -439,7 +450,7 @@ async def _task(): ret = self.model.calculate_metric(is_training=body.is_training, adapter_name=adapter_name, **extra_kwargs) return {'result': ret} - return await self.schedule_task_and_wait(_task, task_type='calculate_metric') + return await run_task(self.schedule_task_and_wait(_task, task_type='calculate_metric')) @app.post('/twinkle/get_state_dict', response_model=types.GetStateDictResponse) async def get_state_dict( @@ -455,4 +466,4 @@ async def _task(): ret = self.model.get_state_dict(adapter_name=adapter_name, **extra_kwargs) return {'result': ret} - return await self.schedule_task_and_wait(_task, 
task_type='get_state_dict') + return await run_task(self.schedule_task_and_wait(_task, task_type='get_state_dict')) diff --git a/src/twinkle/server/utils/adapter_manager.py b/src/twinkle/server/utils/adapter_manager.py index c1706b65..9a461b32 100644 --- a/src/twinkle/server/utils/adapter_manager.py +++ b/src/twinkle/server/utils/adapter_manager.py @@ -239,14 +239,12 @@ def _adapter_countdown_loop(self) -> None: info['expiring'] = True info['state'] = {} # best-effort clear token = info.get('token') - expired_adapters.append((adapter_name, token)) + expired_adapters.append((adapter_name, token, session_id)) - for adapter_name, _token in expired_adapters: + for adapter_name, _token, session_id in expired_adapters: success = False try: self._on_adapter_expired(adapter_name) - info = self._adapter_records.get(adapter_name, {}) - session_id = info.get('session_id') logger.info(f'[AdapterManager] Adapter {adapter_name} expired ' f'(reason=session_expired, session={session_id})') success = True diff --git a/src/twinkle_client/http/http_utils.py b/src/twinkle_client/http/http_utils.py index 490337bb..70001f7e 100644 --- a/src/twinkle_client/http/http_utils.py +++ b/src/twinkle_client/http/http_utils.py @@ -64,12 +64,26 @@ def _handle_response(response: requests.Response) -> requests.Response: Raises: StopIteration: When server returns HTTP 410 (iterator exhausted) + requests.HTTPError: When server returns a 4xx/5xx error, with the + server-side ``detail`` field (full traceback) included in the + exception message so callers don't need to inspect the response body. 
""" # Convert HTTP 410 Gone to StopIteration # This indicates an iterator has been exhausted if response.status_code == 410: raise StopIteration(response.json().get('detail', 'Iterator exhausted')) + if not response.ok: + try: + detail = response.json().get('detail', response.text) + except Exception: + detail = response.text + http_error_msg = ( + f'{response.status_code} Error for url: {response.url}\n' + f'Server detail:\n{detail}' + ) + raise requests.HTTPError(http_error_msg, response=response) + return response diff --git a/src/twinkle_client/manager.py b/src/twinkle_client/manager.py index df097083..9a5b41ff 100644 --- a/src/twinkle_client/manager.py +++ b/src/twinkle_client/manager.py @@ -4,7 +4,7 @@ import atexit import threading from typing import Any, Dict, List, Optional, Tuple - +from twinkle import get_logger from twinkle_client.types.server import DeleteCheckpointResponse from twinkle_client.types.session import (CreateSessionRequest, CreateSessionResponse, SessionHeartbeatRequest, SessionHeartbeatResponse) @@ -12,6 +12,7 @@ TrainingRunsResponse, WeightsInfoResponse) from .http import get_api_key, get_base_url, http_delete, http_get, http_post, set_api_key, set_base_url, set_session_id +logger = get_logger() class TwinkleClientError(Exception): """Base exception for TwinkleManager errors.""" @@ -115,17 +116,34 @@ def create_session(self, metadata: Optional[Dict[str, Any]] = None) -> str: return CreateSessionResponse(**resp.json()).session_id def _touch_session_loop(self) -> None: - """Background loop: touch the session every N seconds.""" + """Background loop: touch the session every N seconds. + + The sleep interval is corrected for the time spent in the request so + that the + server-visible gap stays close to ``_heartbeat_interval`` + even when the HTTP call itself takes a few seconds. 
+ """ + import time while not self._stop_event.wait(timeout=self._heartbeat_interval): + t0 = time.monotonic() try: + logger.debug('[TwinkleClient] Touching session...') resp = http_post( self._get_url('/session_heartbeat'), json_data=SessionHeartbeatRequest(session_id=self._session_id).model_dump(), + timeout=min(self._heartbeat_interval, 10), ) + logger.debug(f'[TwinkleClient] Session heartbeat response: {resp.status_code}') resp.raise_for_status() except Exception as e: # Do not crash the background thread on transient errors. - print(f'[TwinkleClient] Session heartbeat error: {e}') + logger.error(f'[TwinkleClient] Session heartbeat error: {e}') + elapsed = time.monotonic() - t0 + # If the request consumed most of the interval, skip the residual + # wait so we don't fall further behind; next full sleep follows. + residual = self._heartbeat_interval - elapsed + if residual > 0: + self._stop_event.wait(timeout=residual) def close(self) -> None: """Stop the background heartbeat thread and clear session context.""" From 641302efe576ac16ce26efb14c291a7cd2bfae66 Mon Sep 17 00:00:00 2001 From: Yunnglin Date: Thu, 12 Mar 2026 16:40:05 +0800 Subject: [PATCH 16/24] fix heartbeat --- .../server/transformer/server_config.yaml | 7 ++-- .../twinkle/self_host/self_congnition.py | 2 +- src/twinkle/server/processor/app.py | 16 -------- .../server/utils/state/server_state.py | 39 ------------------- src/twinkle_client/__init__.py | 4 +- src/twinkle_client/manager.py | 27 +++++++------ 6 files changed, 20 insertions(+), 75 deletions(-) diff --git a/cookbook/client/server/transformer/server_config.yaml b/cookbook/client/server/transformer/server_config.yaml index ea310520..9bf6dc03 100644 --- a/cookbook/client/server/transformer/server_config.yaml +++ b/cookbook/client/server/transformer/server_config.yaml @@ -24,6 +24,7 @@ applications: - Qwen/Qwen3.5-4B deployments: - name: TinkerCompatServer + max_ongoing_requests: 10 autoscaling_config: min_replicas: 1 # Minimum number of 
replicas max_replicas: 1 # Maximum number of replicas @@ -106,14 +107,14 @@ applications: route_prefix: /api/v1/processor import_path: processor args: - ncpu_proc_per_node: 2 # 每节点 CPU 进程数 + ncpu_proc_per_node: 1 # 每节点 CPU 进程数 device_group: name: model - ranks: 2 + ranks: 1 device_type: CPU device_mesh: device_type: CPU - dp_size: 2 # 数据并行大小 + dp_size: 1 # 数据并行大小 deployments: - name: ProcessorManagement autoscaling_config: diff --git a/cookbook/client/twinkle/self_host/self_congnition.py b/cookbook/client/twinkle/self_host/self_congnition.py index 7aa56e0f..b70000ed 100644 --- a/cookbook/client/twinkle/self_host/self_congnition.py +++ b/cookbook/client/twinkle/self_host/self_congnition.py @@ -95,7 +95,7 @@ def train(): model.load(resume_path, load_optimizer=True) # Step 7: Run the training loop - logger.info(model.get_train_configs()) + logger.info(model.get_train_configs().model_dump()) for epoch in range(3): logger.info(f'Starting epoch {epoch}') diff --git a/src/twinkle/server/processor/app.py b/src/twinkle/server/processor/app.py index 649b3768..764b21a3 100644 --- a/src/twinkle/server/processor/app.py +++ b/src/twinkle/server/processor/app.py @@ -82,21 +82,6 @@ def __init__(self, def assert_processor_exists(self, processor_id: str): assert processor_id and processor_id in self.resource_dict, f'Processor {processor_id} not found' - def handle_processor_count(self, token: str, add: bool): - user_key = token + '_' + 'processor' - cur_count = self.state.get_config(user_key) or 0 - if add: - if cur_count < self.per_token_processor_limit: - self.state.add_config(user_key, cur_count + 1) - else: - raise RuntimeError(f'Processor count limitation reached: {self.per_token_processor_limit}') - else: - if cur_count > 0: - cur_count -= 1 - self.state.add_config(user_key, cur_count) - if cur_count <= 0: - self.state.pop_config(user_key) - @app.post('/twinkle/create', response_model=types.ProcessorCreateResponse) def create(self, request: Request, body: 
types.ProcessorCreateRequest) -> types.ProcessorCreateResponse: processor_type_name = body.processor_type @@ -106,7 +91,6 @@ def create(self, request: Request, body: types.ProcessorCreateRequest) -> types. assert processor_type_name in processors, f'Invalid processor type: {processor_type_name}' processor_module = importlib.import_module(f'twinkle.{processor_type_name}') assert hasattr(processor_module, class_type), f'Class {class_type} not found in {processor_type_name}' - self.handle_processor_count(request.state.token, True) processor_id = str(uuid.uuid4().hex) self.key_token_dict[processor_id] = request.state.token diff --git a/src/twinkle/server/utils/state/server_state.py b/src/twinkle/server/utils/state/server_state.py index 7588c65d..70dcfe9c 100644 --- a/src/twinkle/server/utils/state/server_state.py +++ b/src/twinkle/server/utils/state/server_state.py @@ -239,28 +239,6 @@ def store_future_status( queue_state_reason=queue_state_reason, ) - # ----- Config Management ----- - - def add_config(self, key: str, value: Any) -> None: - """Add or update a configuration value.""" - self._config_mgr.add(key, value) - - def add_or_get(self, key: str, value: Any) -> Any: - """Add a config value if the key does not exist; otherwise return the existing value.""" - return self._config_mgr.add_or_get(key, value) - - def get_config(self, key: str) -> Any | None: - """Get a configuration value by key.""" - return self._config_mgr.get(key) - - def pop_config(self, key: str) -> Any | None: - """Remove and return a configuration value.""" - return self._config_mgr.pop(key) - - def clear_config(self) -> None: - """Clear all configuration values.""" - self._config_mgr.clear() - # ----- Resource Cleanup ----- def cleanup_expired_resources(self) -> dict[str, int]: @@ -432,23 +410,6 @@ def store_future_status( self._actor.store_future_status.remote(request_id, status, model_id, reason, result, queue_state, queue_state_reason)) - # ----- Config Management ----- - - def 
add_config(self, key: str, value: Any): - return ray.get(self._actor.add_config.remote(key, value)) - - def add_or_get(self, key: str, value: Any) -> Any: - return ray.get(self._actor.add_or_get.remote(key, value)) - - def get_config(self, key: str) -> Any | None: - return ray.get(self._actor.get_config.remote(key)) - - def pop_config(self, key: str) -> Any | None: - return ray.get(self._actor.pop_config.remote(key)) - - def clear_config(self): - return ray.get(self._actor.clear_config.remote()) - # ----- Resource Cleanup ----- def cleanup_expired_resources(self) -> dict[str, int]: diff --git a/src/twinkle_client/__init__.py b/src/twinkle_client/__init__.py index a87eba5d..f41a83ce 100644 --- a/src/twinkle_client/__init__.py +++ b/src/twinkle_client/__init__.py @@ -34,7 +34,7 @@ def init_tinker_client(**kwargs) -> None: def init_twinkle_client( base_url: Optional[str] = None, api_key: Optional[str] = None, - session_heartbeat_interval: int = 30, + session_heartbeat_interval: int = 10, **kwargs, ) -> 'TwinkleClient': """ @@ -54,7 +54,7 @@ def init_twinkle_client( Args: base_url: Twinkle server base URL. Falls back to ``TWINKLE_SERVER_URL``. api_key: Authentication token. Falls back to ``TWINKLE_SERVER_TOKEN``. - session_heartbeat_interval: Seconds between session touch calls (default: 30). + session_heartbeat_interval: Seconds between session touch calls (default: 10). **kwargs: Additional keyword arguments forwarded to :class:`TwinkleClient`. Returns: diff --git a/src/twinkle_client/manager.py b/src/twinkle_client/manager.py index 9a5b41ff..108465ec 100644 --- a/src/twinkle_client/manager.py +++ b/src/twinkle_client/manager.py @@ -116,34 +116,33 @@ def create_session(self, metadata: Optional[Dict[str, Any]] = None) -> str: return CreateSessionResponse(**resp.json()).session_id def _touch_session_loop(self) -> None: - """Background loop: touch the session every N seconds. + """Background loop: touch the session every ``_heartbeat_interval`` seconds. 
- The sleep interval is corrected for the time spent in the request so - that the - server-visible gap stays close to ``_heartbeat_interval`` - even when the HTTP call itself takes a few seconds. + Uses a fixed-rate design: the wall-clock period between successive + server-side heartbeats stays close to ``_heartbeat_interval`` regardless + of how long the HTTP call takes, by subtracting elapsed time from the + subsequent sleep. """ import time - while not self._stop_event.wait(timeout=self._heartbeat_interval): + while not self._stop_event.is_set(): t0 = time.monotonic() + success = False try: - logger.debug('[TwinkleClient] Touching session...') + logger.debug(f'[TwinkleClient] Touching session (session={self._session_id})...') resp = http_post( self._get_url('/session_heartbeat'), json_data=SessionHeartbeatRequest(session_id=self._session_id).model_dump(), timeout=min(self._heartbeat_interval, 10), ) - logger.debug(f'[TwinkleClient] Session heartbeat response: {resp.status_code}') resp.raise_for_status() + success = True except Exception as e: - # Do not crash the background thread on transient errors. logger.error(f'[TwinkleClient] Session heartbeat error: {e}') elapsed = time.monotonic() - t0 - # If the request consumed most of the interval, skip the residual - # wait so we don't fall further behind; next full sleep follows. 
- residual = self._heartbeat_interval - elapsed - if residual > 0: - self._stop_event.wait(timeout=residual) + if success: + logger.debug(f'[TwinkleClient] Session heartbeat OK (elapsed={elapsed:.2f}s)') + sleep_time = max(0.0, self._heartbeat_interval - elapsed) + self._stop_event.wait(timeout=sleep_time) def close(self) -> None: """Stop the background heartbeat thread and clear session context.""" From 415da4e24dd6e0384eca8ccff13044787481b49d Mon Sep 17 00:00:00 2001 From: Yunnglin Date: Thu, 12 Mar 2026 18:26:07 +0800 Subject: [PATCH 17/24] fix processor --- .../server/transformer/server_config.yaml | 8 +- setup.cfg | 2 +- src/twinkle/hub/hub.py | 4 +- src/twinkle/server/common/datum.py | 1 - src/twinkle/server/common/router.py | 2 +- .../server/gateway/tinker_gateway_handlers.py | 8 +- .../gateway/twinkle_gateway_handlers.py | 2 +- src/twinkle/server/processor/app.py | 224 ++++++++---------- .../server/processor/twinkle_handlers.py | 130 ++++++++++ src/twinkle/server/utils/__init__.py | 1 + src/twinkle/server/utils/adapter_manager.py | 27 ++- src/twinkle/server/utils/processor_manager.py | 209 ++++++++++++++++ .../server/utils/state/server_state.py | 33 ++- src/twinkle/server/utils/task_queue.py | 22 +- src/twinkle_client/types/model.py | 2 +- 15 files changed, 507 insertions(+), 168 deletions(-) create mode 100644 src/twinkle/server/processor/twinkle_handlers.py create mode 100644 src/twinkle/server/utils/processor_manager.py diff --git a/cookbook/client/server/transformer/server_config.yaml b/cookbook/client/server/transformer/server_config.yaml index 9bf6dc03..91dff9b2 100644 --- a/cookbook/client/server/transformer/server_config.yaml +++ b/cookbook/client/server/transformer/server_config.yaml @@ -24,7 +24,7 @@ applications: - Qwen/Qwen3.5-4B deployments: - name: TinkerCompatServer - max_ongoing_requests: 10 + max_ongoing_requests: 50 autoscaling_config: min_replicas: 1 # Minimum number of replicas max_replicas: 1 # Maximum number of replicas @@ -107,14 
+107,14 @@ applications: route_prefix: /api/v1/processor import_path: processor args: - ncpu_proc_per_node: 1 # 每节点 CPU 进程数 + ncpu_proc_per_node: 2 # 每节点 CPU 进程数 device_group: name: model - ranks: 1 + ranks: 2 device_type: CPU device_mesh: device_type: CPU - dp_size: 1 # 数据并行大小 + dp_size: 2 # 数据并行大小 deployments: - name: ProcessorManagement autoscaling_config: diff --git a/setup.cfg b/setup.cfg index 3ca70ce3..811fd55c 100644 --- a/setup.cfg +++ b/setup.cfg @@ -22,7 +22,7 @@ ignore-words-list = patten,nd,ty,mot,hist,formating,winn,gool,datas,wan,confids [flake8] max-line-length = 120 select = B,E,F,P,T4,W,B9 -ignore = F401,F403,F405,F821,W503,E251,W504,E126 +ignore = F401,F403,F405,F821,W503,E251,W504,E126,E125 exclude = docs/src,*.pyi,.git,peft.py [darglint] diff --git a/src/twinkle/hub/hub.py b/src/twinkle/hub/hub.py index 6e1653e1..916a42b2 100644 --- a/src/twinkle/hub/hub.py +++ b/src/twinkle/hub/hub.py @@ -374,7 +374,7 @@ def push_to_hub(cls, ignore_patterns = [] if revision is None or revision == 'main': revision = 'master' - return push_to_hub( + result = push_to_hub( repo_id, folder_path, token or cls.ms_token, @@ -383,6 +383,8 @@ def push_to_hub(cls, ignore_file_pattern=ignore_patterns, revision=revision, tag=path_in_repo) + if not result: + raise Exception('Failed to push to hub') @classmethod def load_dataset(cls, diff --git a/src/twinkle/server/common/datum.py b/src/twinkle/server/common/datum.py index 7dd0ae1c..9091f388 100644 --- a/src/twinkle/server/common/datum.py +++ b/src/twinkle/server/common/datum.py @@ -1,5 +1,4 @@ # Copyright (c) ModelScope Contributors. All rights reserved. -# Moved from tinker/common/datum.py — logic unchanged. 
from __future__ import annotations import numpy as np diff --git a/src/twinkle/server/common/router.py b/src/twinkle/server/common/router.py index 27abbfbd..dee1bd36 100644 --- a/src/twinkle/server/common/router.py +++ b/src/twinkle/server/common/router.py @@ -56,7 +56,7 @@ async def choose_replicas( # Filter out replicas that exceed max lora count (query from server state) candidate_ids = [r.replica_id.unique_id for r in top_ranked_replicas.values()] - available_ids = set(self.state.get_available_replica_ids(candidate_ids)) + available_ids = set(await self.state.get_available_replica_ids(candidate_ids)) if available_ids: top_ranked_replicas = { rid: r diff --git a/src/twinkle/server/gateway/tinker_gateway_handlers.py b/src/twinkle/server/gateway/tinker_gateway_handlers.py index b545ad63..71c0654f 100644 --- a/src/twinkle/server/gateway/tinker_gateway_handlers.py +++ b/src/twinkle/server/gateway/tinker_gateway_handlers.py @@ -61,7 +61,7 @@ async def create_session( async def session_heartbeat( request: Request, body: types.SessionHeartbeatRequest, self: GatewayServer = Depends(self_fn) ) -> types.SessionHeartbeatResponse: # noqa: E125 - alive = self.state.touch_session(body.session_id) + alive = await self.state.touch_session(body.session_id) if not alive: raise HTTPException(status_code=404, detail='Unknown session') return types.SessionHeartbeatResponse() @@ -84,7 +84,7 @@ async def retrieve_future(request: Request, start = asyncio.get_event_loop().time() while True: - record = self.state.get_future(request_id) + record = await self.state.get_future(request_id) if record is None: return {'type': 'try_again'} @@ -103,7 +103,7 @@ async def retrieve_future(request: Request, await asyncio.sleep(poll_interval) - record = self.state.get_future(request_id) + record = await self.state.get_future(request_id) if not record: return {'type': 'try_again'} @@ -207,7 +207,7 @@ async def publish_checkpoint(request: Request, checkpoint_name = checkpoint_id.split('/')[-1] 
hub_model_id = f'{username}/{run_id}_{checkpoint_name}' - HubOperation.async_push_to_hub(repo_id=hub_model_id, folder_path=checkpoint_dir, token=token, private=True) + HubOperation.push_to_hub(repo_id=hub_model_id, folder_path=checkpoint_dir, token=token, private=True) return Response(status_code=204) diff --git a/src/twinkle/server/gateway/twinkle_gateway_handlers.py b/src/twinkle/server/gateway/twinkle_gateway_handlers.py index a7323446..9c0a3ba7 100644 --- a/src/twinkle/server/gateway/twinkle_gateway_handlers.py +++ b/src/twinkle/server/gateway/twinkle_gateway_handlers.py @@ -43,7 +43,7 @@ async def session_heartbeat( body: types.SessionHeartbeatRequest, self: GatewayServer = Depends(self_fn), ) -> types.SessionHeartbeatResponse: - alive = self.state.touch_session(body.session_id) + alive = await self.state.touch_session(body.session_id) if not alive: raise HTTPException(status_code=404, detail='Unknown session') return types.SessionHeartbeatResponse() diff --git a/src/twinkle/server/processor/app.py b/src/twinkle/server/processor/app.py index 764b21a3..4b03af86 100644 --- a/src/twinkle/server/processor/app.py +++ b/src/twinkle/server/processor/app.py @@ -1,155 +1,133 @@ # Copyright (c) ModelScope Contributors. All rights reserved. """ -Processor management application (moved from twinkle/processor.py). +Processor management application. Provides a Ray Serve deployment for managing distributed processors (datasets, dataloaders, preprocessors, rewards, templates, weight loaders, etc.). 
+ +Follows the same structural pattern as model/app.py: +- ProcessorManagement is a top-level class inheriting ProcessorManagerMixin +- Routes are registered in build_processor_app() via _register_processor_routes() +- serve.ingress(app)(ProcessorManagement) applied before deployment +- Sticky session routing via @serve.multiplexed keyed on session ID """ -import importlib +from __future__ import annotations + import os -import uuid -from fastapi import FastAPI, HTTPException, Request +from fastapi import FastAPI, Request from ray import serve -from typing import Any, Dict +from typing import Any, Dict, Optional import twinkle -import twinkle_client.types as types from twinkle import DeviceGroup, DeviceMesh, get_logger -from twinkle.server.common.serialize import deserialize_object +from twinkle.server.utils.processor_manager import ProcessorManagerMixin from twinkle.server.utils.state import ServerStateProxy, get_server_state from twinkle.server.utils.validation import verify_request_token +from .twinkle_handlers import _register_processor_routes logger = get_logger() +class ProcessorManagement(ProcessorManagerMixin): + """Processor management service. + + Manages lifecycle and invocation of distributed processor objects + (datasets, dataloaders, rewards, templates, etc.). + + Lifecycle is handled by ProcessorManagerMixin: + - Processors are registered with a session ID on creation. + - A background thread expires processors whose session has timed out. + - Per-user processor limit is enforced at registration. + - Sticky session routing ensures session requests hit the same replica. 
+ """ + + def __init__(self, + ncpu_proc_per_node: int, + device_group: dict[str, Any], + device_mesh: dict[str, Any], + nproc_per_node: int = 1, + processor_config: dict[str, Any] | None = None): + self.device_group = DeviceGroup(**device_group) + twinkle.initialize( + mode='ray', + nproc_per_node=nproc_per_node, + groups=[self.device_group], + lazy_collect=False, + ncpu_proc_per_node=ncpu_proc_per_node) + if 'mesh_dim_names' in device_mesh: + self.device_mesh = DeviceMesh(**device_mesh) + else: + self.device_mesh = DeviceMesh.from_sizes(**device_mesh) + + # processor objects keyed by processor_id + self.resource_dict: dict[str, Any] = {} + self.state: ServerStateProxy = get_server_state() + + _cfg = processor_config or {} + _env_limit = int(os.environ.get('TWINKLE_PER_USER_PROCESSOR_LIMIT', 20)) + self._init_processor_manager( + processor_timeout=float(_cfg.get('processor_timeout', 1800.0)), + per_token_processor_limit=int(_cfg.get('per_token_processor_limit', _env_limit)), + ) + self.start_processor_countdown() + + @serve.multiplexed(max_num_models_per_replica=100) + async def _sticky_entry(self, sticky_key: str): + return sticky_key + + async def _ensure_sticky(self): + sticky_key = serve.get_multiplexed_model_id() + await self._sticky_entry(sticky_key) + + def _on_processor_expired(self, processor_id: str) -> None: + """Called by the countdown thread when a processor's session expires.""" + self.resource_dict.pop(processor_id, None) + self.unregister_processor(processor_id) + + def build_processor_app(ncpu_proc_per_node: int, - device_group: Dict[str, Any], - device_mesh: Dict[str, Any], - deploy_options: Dict[str, Any], + device_group: dict[str, Any], + device_mesh: dict[str, Any], + deploy_options: dict[str, Any], nproc_per_node: int = 1, + processor_config: dict[str, Any] | None = None, **kwargs): """Build the processor management application. 
+ Follows the same pattern as build_model_app(): FastAPI app and routes are + built here BEFORE serve.ingress so that the frozen app contains the full + route table visible to ProxyActor. + Args: - ncpu_proc_per_node: Number of CPU processes per node - device_group: Device group configuration dict - device_mesh: Device mesh configuration dict - deploy_options: Ray Serve deployment options - nproc_per_node: Number of GPU processes per node (default 1, not used for CPU-only tasks) - **kwargs: Additional arguments + ncpu_proc_per_node: Number of CPU processes per node. + device_group: Device group configuration dict. + device_mesh: Device mesh configuration dict. + deploy_options: Ray Serve deployment options. + nproc_per_node: Number of GPU processes per node (default 1). + processor_config: Optional lifecycle configuration dict. + Supported keys: + - ``processor_timeout`` (float): Session inactivity timeout seconds. Default 1800.0. + - ``per_token_processor_limit`` (int): Max processors per user. + Overrides ``TWINKLE_PER_USER_PROCESSOR_LIMIT`` env var when provided. + **kwargs: Additional arguments. Returns: - Ray Serve deployment bound with configuration + Ray Serve deployment bound with configuration. """ + # Build the FastAPI app and register all routes BEFORE serve.ingress so that + # the frozen app contains the complete route table (visible to ProxyActor). app = FastAPI() @app.middleware('http') async def verify_token(request: Request, call_next): return await verify_request_token(request=request, call_next=call_next) - processors = ['dataset', 'dataloader', 'preprocessor', 'processor', 'reward', 'template', 'weight_loader'] - - @serve.deployment(name='ProcessorManagement') - @serve.ingress(app) - class ProcessorManagement: - """Processor management service. - - Manages lifecycle and invocation of distributed processor objects - (datasets, dataloaders, rewards, templates, etc.). 
- """ - - def __init__(self, - ncpu_proc_per_node: int, - device_group: Dict[str, Any], - device_mesh: Dict[str, Any], - nproc_per_node: int = 1): - self.device_group = DeviceGroup(**device_group) - twinkle.initialize( - mode='ray', - nproc_per_node=nproc_per_node, - groups=[self.device_group], - lazy_collect=False, - ncpu_proc_per_node=ncpu_proc_per_node) - if 'mesh_dim_names' in device_mesh: - self.device_mesh = DeviceMesh(**device_mesh) - else: - self.device_mesh = DeviceMesh.from_sizes(**device_mesh) - self.resource_dict = {} - self.state: ServerStateProxy = get_server_state() - self.per_token_processor_limit = int(os.environ.get('TWINKLE_PER_USER_PROCESSOR_LIMIT', 20)) - self.key_token_dict = {} - - def assert_processor_exists(self, processor_id: str): - assert processor_id and processor_id in self.resource_dict, f'Processor {processor_id} not found' - - @app.post('/twinkle/create', response_model=types.ProcessorCreateResponse) - def create(self, request: Request, body: types.ProcessorCreateRequest) -> types.ProcessorCreateResponse: - processor_type_name = body.processor_type - class_type = body.class_type - _kwargs = body.model_extra or {} - - assert processor_type_name in processors, f'Invalid processor type: {processor_type_name}' - processor_module = importlib.import_module(f'twinkle.{processor_type_name}') - assert hasattr(processor_module, class_type), f'Class {class_type} not found in {processor_type_name}' - processor_id = str(uuid.uuid4().hex) - self.key_token_dict[processor_id] = request.state.token - - _kwargs.pop('remote_group', None) - _kwargs.pop('device_mesh', None) - - resolved_kwargs = {} - for key, value in _kwargs.items(): - if isinstance(value, str) and value.startswith('pid:'): - ref_id = value[4:] - resolved_kwargs[key] = self.resource_dict[ref_id] - else: - value = deserialize_object(value) - resolved_kwargs[key] = value - - processor = getattr(processor_module, class_type)( - remote_group=self.device_group.name, - 
device_mesh=self.device_mesh, - instance_id=processor_id, - **resolved_kwargs) - self.resource_dict[processor_id] = processor - return types.ProcessorCreateResponse(processor_id='pid:' + processor_id) - - @app.post('/twinkle/call', response_model=types.ProcessorCallResponse) - def call(self, body: types.ProcessorCallRequest) -> types.ProcessorCallResponse: - processor_id = body.processor_id - function_name = body.function - _kwargs = body.model_extra or {} - processor_id = processor_id[4:] - self.assert_processor_exists(processor_id=processor_id) - processor = self.resource_dict.get(processor_id) - function = getattr(processor, function_name, None) - - assert function is not None, f'`{function_name}` not found in {processor.__class__}' - assert hasattr(function, '_execute'), f'Cannot call inner method of {processor.__class__}' - - resolved_kwargs = {} - for key, value in _kwargs.items(): - if isinstance(value, str) and value.startswith('pid:'): - ref_id = value[4:] - resolved_kwargs[key] = self.resource_dict[ref_id] - else: - value = deserialize_object(value) - resolved_kwargs[key] = value - - # Special handling for __next__ to catch StopIteration - if function_name == '__next__': - try: - result = function(**resolved_kwargs) - return types.ProcessorCallResponse(result=result) - except StopIteration: - # HTTP 410 Gone signals iterator exhausted - raise HTTPException(status_code=410, detail='Iterator exhausted') - - result = function(**resolved_kwargs) - if function_name == '__iter__': - return types.ProcessorCallResponse(result='ok') - else: - return types.ProcessorCallResponse(result=result) - - return ProcessorManagement.options(**deploy_options).bind( - ncpu_proc_per_node, device_group, device_mesh, nproc_per_node=nproc_per_node) + def get_self() -> ProcessorManagement: + return serve.get_replica_context().servable_object + + _register_processor_routes(app, get_self) + + ProcessorManagementWithIngress = serve.ingress(app)(ProcessorManagement) + DeploymentClass = 
serve.deployment(name='ProcessorManagement')(ProcessorManagementWithIngress) + return DeploymentClass.options(**deploy_options).bind(ncpu_proc_per_node, device_group, device_mesh, nproc_per_node, + processor_config) diff --git a/src/twinkle/server/processor/twinkle_handlers.py b/src/twinkle/server/processor/twinkle_handlers.py new file mode 100644 index 00000000..86e35f86 --- /dev/null +++ b/src/twinkle/server/processor/twinkle_handlers.py @@ -0,0 +1,130 @@ +# Copyright (c) ModelScope Contributors. All rights reserved. +""" +Processor management handler mixin. + +All endpoints are prefixed /twinkle/... and handle processor lifecycle +(create, call). self_fn is injected via FastAPI Depends to obtain the +ProcessorManagement instance at request time. +""" +from __future__ import annotations + +import asyncio +import importlib +import uuid +from fastapi import Depends, FastAPI, HTTPException, Request +from typing import TYPE_CHECKING, Callable + +if TYPE_CHECKING: + from .app import ProcessorManagement + +import twinkle_client.types as types +from twinkle.server.common.serialize import deserialize_object +from twinkle.server.utils.validation import get_session_id_from_request, get_token_from_request +from twinkle.utils.logger import get_logger + +logger = get_logger() + +_PROCESSOR_TYPES = ['dataset', 'dataloader', 'preprocessor', 'processor', 'reward', 'template', 'weight_loader'] + + +def _register_processor_routes(app: FastAPI, self_fn: Callable[[], ProcessorManagement]) -> None: + """Register all /twinkle/* processor routes on the given FastAPI app. + + self_fn is a zero-argument callable that returns the current ProcessorManagement + replica instance. It is wired in via Depends so it is resolved lazily at request time. 
+ """ + + @app.post('/twinkle/create', response_model=types.ProcessorCreateResponse) + async def create( + request: Request, body: types.ProcessorCreateRequest, + self: ProcessorManagement = Depends(self_fn)) -> types.ProcessorCreateResponse: + await self._ensure_sticky() + + processor_type_name = body.processor_type + class_type = body.class_type + _kwargs = body.model_extra or {} + + assert processor_type_name in _PROCESSOR_TYPES, f'Invalid processor type: {processor_type_name}' + processor_module = importlib.import_module(f'twinkle.{processor_type_name}') + assert hasattr(processor_module, class_type), f'Class {class_type} not found in {processor_type_name}' + + token = get_token_from_request(request) + session_id = get_session_id_from_request(request) + processor_id = str(uuid.uuid4().hex) + + # Register for lifecycle tracking (enforces per-user limit) + self.register_processor(processor_id, token, session_id) + + _kwargs.pop('remote_group', None) + _kwargs.pop('device_mesh', None) + + resolved_kwargs = {} + for key, value in _kwargs.items(): + if isinstance(value, str) and value.startswith('pid:'): + ref_id = value[4:] + resolved_kwargs[key] = self.resource_dict[ref_id] + else: + value = deserialize_object(value) + resolved_kwargs[key] = value + + # Run processor instantiation in a thread to avoid blocking the event loop, + # which would starve the session-liveness coroutines submitted by the + # countdown thread via asyncio.run_coroutine_threadsafe. 
+ _remote_group = self.device_group.name + _device_mesh = self.device_mesh + + def _do_create(): + return getattr(processor_module, class_type)( + remote_group=_remote_group, device_mesh=_device_mesh, instance_id=processor_id, **resolved_kwargs) + + processor = await asyncio.get_event_loop().run_in_executor(None, _do_create) + self.resource_dict[processor_id] = processor + return types.ProcessorCreateResponse(processor_id='pid:' + processor_id) + + @app.post('/twinkle/call', response_model=types.ProcessorCallResponse) + async def call( + request: Request, body: types.ProcessorCallRequest, + self: ProcessorManagement = Depends(self_fn)) -> types.ProcessorCallResponse: + await self._ensure_sticky() + + processor_id = body.processor_id + function_name = body.function + _kwargs = body.model_extra or {} + processor_id = processor_id[4:] + self.assert_processor_exists(processor_id=processor_id) + processor = self.resource_dict.get(processor_id) + function = getattr(processor, function_name, None) + + assert function is not None, f'`{function_name}` not found in {processor.__class__}' + assert hasattr(function, '_execute'), f'Cannot call inner method of {processor.__class__}' + + resolved_kwargs = {} + for key, value in _kwargs.items(): + if isinstance(value, str) and value.startswith('pid:'): + ref_id = value[4:] + resolved_kwargs[key] = self.resource_dict[ref_id] + else: + value = deserialize_object(value) + resolved_kwargs[key] = value + + # Run the processor function in a thread to avoid blocking the event loop. + # StopIteration cannot propagate through asyncio coroutine boundaries + # (Python 3.7+ converts it to RuntimeError), so capture it as a sentinel tuple. 
+ def _do_call(): + try: + result = function(**resolved_kwargs) + return False, result + except StopIteration: + return True, None + + is_exhausted, result = await asyncio.get_event_loop().run_in_executor(None, _do_call) + + if function_name == '__next__': + if is_exhausted: + # HTTP 410 Gone signals iterator exhausted + raise HTTPException(status_code=410, detail='Iterator exhausted') + return types.ProcessorCallResponse(result=result) + + if function_name == '__iter__': + return types.ProcessorCallResponse(result='ok') + return types.ProcessorCallResponse(result=result) diff --git a/src/twinkle/server/utils/__init__.py b/src/twinkle/server/utils/__init__.py index f9855709..d19d34d0 100644 --- a/src/twinkle/server/utils/__init__.py +++ b/src/twinkle/server/utils/__init__.py @@ -3,5 +3,6 @@ from .checkpoint_base import (TRAIN_RUN_INFO_FILENAME, TWINKLE_DEFAULT_SAVE_DIR, BaseCheckpointManager, BaseFileManager, BaseTrainingRunManager) from .device_utils import auto_fill_device_group_visible_devices, wrap_builder_with_device_group_env +from .processor_manager import ProcessorManagerMixin from .rate_limiter import RateLimiter from .task_queue import QueueState, TaskQueueConfig, TaskQueueMixin, TaskStatus diff --git a/src/twinkle/server/utils/adapter_manager.py b/src/twinkle/server/utils/adapter_manager.py index 9a461b32..ce37fe83 100644 --- a/src/twinkle/server/utils/adapter_manager.py +++ b/src/twinkle/server/utils/adapter_manager.py @@ -11,6 +11,8 @@ """ from __future__ import annotations +import asyncio +import concurrent.futures import threading import time from typing import TYPE_CHECKING, Any @@ -63,6 +65,9 @@ def _init_adapter_manager( self._adapter_countdown_thread: threading.Thread | None = None self._adapter_countdown_running = False + # Event loop reference used to bridge sync thread → async state calls + self._adapter_event_loop: asyncio.AbstractEventLoop | None = None + def register_adapter(self, adapter_name: str, token: str, session_id: str) -> None: 
"""Register a new adapter for lifecycle tracking. @@ -89,7 +94,7 @@ def register_adapter(self, adapter_name: str, token: str, session_id: str) -> No logger.debug( f'[AdapterManager] Registered adapter {adapter_name} for token {token[:8]}... (session: {session_id})') - def _is_session_alive(self, session_id: str) -> bool: + async def _is_session_alive(self, session_id: str) -> bool: """Check if a session is still alive via state proxy. Args: @@ -102,7 +107,7 @@ def _is_session_alive(self, session_id: str) -> bool: return True # No session association means always alive # Get session last heartbeat through proxy - last_heartbeat = self.state.get_session_last_heartbeat(session_id) + last_heartbeat = await self.state.get_session_last_heartbeat(session_id) if last_heartbeat is None: return False # Session doesn't exist @@ -231,7 +236,19 @@ def _adapter_countdown_loop(self) -> None: continue session_id = info.get('session_id') - session_expired = not self._is_session_alive(session_id) + if self._adapter_event_loop is not None and self._adapter_event_loop.is_running(): + future: concurrent.futures.Future = asyncio.run_coroutine_threadsafe( + self._is_session_alive(session_id), self._adapter_event_loop) + try: + session_alive = future.result(timeout=5.0) + except Exception as e: + logger.warning(f'[AdapterManager] Failed to check session liveness for {adapter_name}: {e}') + continue + else: + logger.warning( + f'[AdapterManager] No event loop available to check session {session_id}, skipping') + continue + session_expired = not session_alive logger.debug(f'[AdapterManager] Adapter {adapter_name} session check ' f'(session_id={session_id}, session_alive={not session_expired})') @@ -272,6 +289,10 @@ def start_adapter_countdown(self) -> None: """ if not self._adapter_countdown_running: self._adapter_countdown_running = True + try: + self._adapter_event_loop = asyncio.get_running_loop() + except RuntimeError: + self._adapter_event_loop = None self._adapter_countdown_thread = 
threading.Thread(target=self._adapter_countdown_loop, daemon=True) self._adapter_countdown_thread.start() logger.debug('[AdapterManager] Countdown thread started') diff --git a/src/twinkle/server/utils/processor_manager.py b/src/twinkle/server/utils/processor_manager.py new file mode 100644 index 00000000..9e1cef42 --- /dev/null +++ b/src/twinkle/server/utils/processor_manager.py @@ -0,0 +1,209 @@ +# Copyright (c) ModelScope Contributors. All rights reserved. +""" +Processor Lifecycle Manager Mixin for Twinkle Server. + +Mirrors AdapterManagerMixin but adds a global per-token processor limit. +Sessions are tracked via session ID; processors expire when their session expires. +""" +from __future__ import annotations + +import asyncio +import concurrent.futures +import threading +import time +from typing import TYPE_CHECKING, Any + +if TYPE_CHECKING: + from twinkle.server.utils.state import ServerStateProxy + +from twinkle.utils.logger import get_logger + +logger = get_logger() + + +class ProcessorManagerMixin: + """Mixin for processor lifecycle management with session-based expiration. + + Mirrors AdapterManagerMixin with an additional per-token processor limit. + + Inheriting classes should: + 1. Call _init_processor_manager() in __init__ + 2. Override _on_processor_expired() to handle cleanup + + Attributes: + _processor_timeout: Session inactivity timeout in seconds. + _per_token_processor_limit: Maximum active processors per user token. + """ + + # Type hint for state attribute that inheriting classes must provide + state: ServerStateProxy + + def _init_processor_manager( + self, + processor_timeout: float = 1800.0, + per_token_processor_limit: int = 20, + ) -> None: + """Initialize the processor manager. + + Args: + processor_timeout: Timeout in seconds to determine if a session is alive. + Default is 1800.0 (30 minutes). + per_token_processor_limit: Maximum active processors per user token. + Default is 20. 
+ """ + self._processor_timeout = processor_timeout + self._per_token_processor_limit = per_token_processor_limit + + # processor_id -> {'token': str, 'session_id': str, 'created_at': float, 'expiring': bool} + self._processor_records: dict[str, dict[str, Any]] = {} + + self._processor_countdown_thread: threading.Thread | None = None + self._processor_countdown_running = False + self._processor_event_loop: asyncio.AbstractEventLoop | None = None + + def register_processor(self, processor_id: str, token: str, session_id: str) -> None: + """Register a new processor for lifecycle tracking. + + Args: + processor_id: Unique identifier of the processor. + token: User token that owns this processor. + session_id: Session ID to associate with this processor. Must be non-empty. + + Raises: + ValueError: If session_id is None or empty. + RuntimeError: If the per-token processor limit has been reached. + """ + if not session_id: + raise ValueError(f'session_id must be provided when registering processor {processor_id}') + + current_count = sum(1 for info in self._processor_records.values() if info.get('token') == token) + if current_count >= self._per_token_processor_limit: + raise RuntimeError(f'Per-user processor limit ({self._per_token_processor_limit}) reached ' + f'for token {token[:8]}...') + + self._processor_records[processor_id] = { + 'token': token, + 'session_id': session_id, + 'created_at': time.time(), + 'expiring': False, + } + logger.debug(f'[ProcessorManager] Registered processor {processor_id} ' + f'for token {token[:8]}... (session: {session_id})') + + def unregister_processor(self, processor_id: str) -> bool: + """Unregister a processor from lifecycle tracking. + + Returns: + True if found and removed, False otherwise. 
+ """ + if processor_id in self._processor_records: + info = self._processor_records.pop(processor_id) + token = info.get('token', '') + logger.debug(f'[ProcessorManager] Unregistered processor {processor_id} ' + f'for token {token[:8] if token else "unknown"}...') + return True + return False + + def get_processor_info(self, processor_id: str) -> dict[str, Any] | None: + """Get tracking info for a registered processor, or None if not found.""" + return self._processor_records.get(processor_id) + + def assert_processor_exists(self, processor_id: str) -> None: + """Assert a processor exists and is not expiring.""" + info = self._processor_records.get(processor_id) + assert processor_id and info is not None and not info.get('expiring'), \ + f'Processor {processor_id} not found' + + def _on_processor_expired(self, processor_id: str) -> None: + """Hook called when a processor's session expires. + + Must be overridden by inheriting classes. + + Raises: + NotImplementedError: If not overridden. 
+ """ + raise NotImplementedError(f'_on_processor_expired must be implemented by {self.__class__.__name__}') + + async def _is_session_alive(self, session_id: str) -> bool: + """Check if a session is still alive via state proxy.""" + if not session_id: + return True + last_heartbeat = await self.state.get_session_last_heartbeat(session_id) + if last_heartbeat is None: + return False + return (time.time() - last_heartbeat) < self._processor_timeout + + def _processor_countdown_loop(self) -> None: + """Background thread: checks session liveness and expires stale processors.""" + logger.debug(f'[ProcessorManager] Countdown thread started (session_timeout={self._processor_timeout}s)') + while self._processor_countdown_running: + try: + time.sleep(1) + + expired: list[tuple[str, str | None]] = [] + for processor_id, info in list(self._processor_records.items()): + if info.get('expiring'): + continue + session_id = info.get('session_id') + if self._processor_event_loop is not None and self._processor_event_loop.is_running(): + future: concurrent.futures.Future = asyncio.run_coroutine_threadsafe( + self._is_session_alive(session_id), self._processor_event_loop) + try: + session_alive = future.result(timeout=5.0) + except Exception as e: + logger.warning(f'[ProcessorManager] Failed to check session liveness ' + f'for {processor_id}: {type(e).__name__}: {e}') + continue + else: + logger.warning(f'[ProcessorManager] No event loop available to check ' + f'session {session_id}, skipping') + continue + + logger.debug(f'[ProcessorManager] Processor {processor_id} session check ' + f'(session_id={session_id}, session_alive={session_alive})') + if not session_alive: + info['expiring'] = True + expired.append((processor_id, session_id)) + + for processor_id, session_id in expired: + success = False + try: + self._on_processor_expired(processor_id) + logger.info(f'[ProcessorManager] Processor {processor_id} expired ' + f'(reason=session_expired, session={session_id})') + success = 
True + except Exception as e: + logger.warning(f'[ProcessorManager] Error while expiring processor {processor_id}: {e}') + finally: + if success: + self._processor_records.pop(processor_id, None) + else: + info = self._processor_records.get(processor_id) + if info is not None: + info['expiring'] = False + + except Exception as e: + logger.warning(f'[ProcessorManager] Error in countdown loop: {e}') + continue + + logger.debug('[ProcessorManager] Countdown thread stopped') + + def start_processor_countdown(self) -> None: + """Start the background countdown thread. Safe to call multiple times.""" + if not self._processor_countdown_running: + self._processor_countdown_running = True + try: + self._processor_event_loop = asyncio.get_running_loop() + except RuntimeError: + self._processor_event_loop = None + self._processor_countdown_thread = threading.Thread(target=self._processor_countdown_loop, daemon=True) + self._processor_countdown_thread.start() + logger.debug('[ProcessorManager] Countdown thread started') + + def stop_processor_countdown(self) -> None: + """Stop the background countdown thread.""" + if self._processor_countdown_running: + self._processor_countdown_running = False + if self._processor_countdown_thread: + self._processor_countdown_thread.join(timeout=2.0) + logger.debug('[ProcessorManager] Countdown thread stopped') diff --git a/src/twinkle/server/utils/state/server_state.py b/src/twinkle/server/utils/state/server_state.py index 70dcfe9c..d04c877f 100644 --- a/src/twinkle/server/utils/state/server_state.py +++ b/src/twinkle/server/utils/state/server_state.py @@ -71,7 +71,7 @@ def create_session(self, payload: dict[str, Any]) -> str: self._session_mgr.add(session_id, record) return session_id - def touch_session(self, session_id: str) -> bool: + async def touch_session(self, session_id: str) -> bool: """Update session heartbeat timestamp. 
Returns: @@ -79,7 +79,7 @@ def touch_session(self, session_id: str) -> bool: """ return self._session_mgr.touch(session_id) - def get_session_last_heartbeat(self, session_id: str) -> float | None: + async def get_session_last_heartbeat(self, session_id: str) -> float | None: """Get the last heartbeat timestamp for a session. Returns: @@ -154,7 +154,7 @@ def unregister_replica(self, replica_id: str) -> None: """ self._model_mgr.unregister_replica(replica_id) - def get_available_replica_ids(self, candidate_ids: list[str]) -> list[str]: + async def get_available_replica_ids(self, candidate_ids: list[str]) -> list[str]: """Return candidate replica IDs that have not reached their max_loras limit. Args: @@ -195,12 +195,12 @@ def get_sampling_session(self, sampling_session_id: str) -> dict[str, Any] | Non # ----- Future Management ----- - def get_future(self, request_id: str) -> dict[str, Any] | None: + async def get_future(self, request_id: str) -> dict[str, Any] | None: """Retrieve a stored future result as a plain dict.""" record = self._future_mgr.get(request_id) return record.model_dump() if record is not None else None - def store_future_status( + async def store_future_status( self, request_id: str, status: str, @@ -350,11 +350,11 @@ def __init__(self, actor_handle) -> None: def create_session(self, payload: dict[str, Any]) -> str: return ray.get(self._actor.create_session.remote(payload)) - def touch_session(self, session_id: str) -> bool: - return ray.get(self._actor.touch_session.remote(session_id)) + async def touch_session(self, session_id: str) -> bool: + return await self._actor.touch_session.remote(session_id) - def get_session_last_heartbeat(self, session_id: str) -> float | None: - return ray.get(self._actor.get_session_last_heartbeat.remote(session_id)) + async def get_session_last_heartbeat(self, session_id: str) -> float | None: + return await self._actor.get_session_last_heartbeat.remote(session_id) # ----- Model Registration ----- @@ -379,8 +379,8 @@ 
def register_replica(self, replica_id: str, max_loras: int) -> None: def unregister_replica(self, replica_id: str) -> None: ray.get(self._actor.unregister_replica.remote(replica_id)) - def get_available_replica_ids(self, candidate_ids: list[str]) -> list[str]: - return ray.get(self._actor.get_available_replica_ids.remote(candidate_ids)) + async def get_available_replica_ids(self, candidate_ids: list[str]) -> list[str]: + return await self._actor.get_available_replica_ids.remote(candidate_ids) # ----- Sampling Session Management ----- @@ -392,10 +392,10 @@ def get_sampling_session(self, sampling_session_id: str) -> dict[str, Any] | Non # ----- Future Management ----- - def get_future(self, request_id: str) -> dict[str, Any] | None: - return ray.get(self._actor.get_future.remote(request_id)) + async def get_future(self, request_id: str) -> dict[str, Any] | None: + return await self._actor.get_future.remote(request_id) - def store_future_status( + async def store_future_status( self, request_id: str, status: str, @@ -406,9 +406,8 @@ def store_future_status( queue_state_reason: str | None = None, ) -> None: """Store task status with optional result (synchronous).""" - ray.get( - self._actor.store_future_status.remote(request_id, status, model_id, reason, result, queue_state, - queue_state_reason)) + await self._actor.store_future_status.remote(request_id, status, model_id, reason, result, queue_state, + queue_state_reason) # ----- Resource Cleanup ----- diff --git a/src/twinkle/server/utils/task_queue.py b/src/twinkle/server/utils/task_queue.py index 33025324..d0985c15 100644 --- a/src/twinkle/server/utils/task_queue.py +++ b/src/twinkle/server/utils/task_queue.py @@ -255,7 +255,7 @@ async def _queue_worker(self) -> None: 'error': f'Queue timeout exceeded: waited {now - task.created_at:.2f}s', 'category': 'Server' } - self.state.store_future_status( + await self.state.store_future_status( task.request_id, TaskStatus.FAILED.value, task.model_id, @@ -270,13 +270,13 @@ 
async def _queue_worker(self) -> None: # Execute executed_any = True - self.state.store_future_status( + await self.state.store_future_status( task.request_id, TaskStatus.RUNNING.value, task.model_id, queue_state=QueueState.ACTIVE.value) try: coro = task.coro_factory() result = await coro - self.state.store_future_status( + await self.state.store_future_status( task.request_id, TaskStatus.COMPLETED.value, task.model_id, @@ -284,7 +284,7 @@ async def _queue_worker(self) -> None: queue_state=QueueState.ACTIVE.value) except Exception: error_payload = {'error': traceback.format_exc(), 'category': 'Server'} - self.state.store_future_status( + await self.state.store_future_status( task.request_id, TaskStatus.FAILED.value, task.model_id, @@ -321,7 +321,7 @@ async def _fail_queue_tasks_async(self, queue_key: str, reason: str) -> None: for task in drained: error_payload = {'error': reason, 'category': 'Server'} - self.state.store_future_status( + await self.state.store_future_status( task.request_id, TaskStatus.FAILED.value, task.model_id, @@ -381,7 +381,7 @@ async def _perform_preflight_checks( if input_tokens > self._task_queue_config.max_input_tokens: error_msg = f'Input tokens ({input_tokens}) exceed maximum allowed ({self._task_queue_config.max_input_tokens})' # noqa: E501 error_payload = {'error': error_msg, 'category': 'User'} - self.state.store_future_status( + await self.state.store_future_status( request_id, TaskStatus.FAILED.value, model_id, @@ -396,7 +396,7 @@ async def _perform_preflight_checks( if batch_size < data_world_size: error_msg = f'Batch size {batch_size} must be greater than or equal to data world size {data_world_size}' # noqa: E501 error_payload = {'error': error_msg, 'category': 'User'} - self.state.store_future_status( + await self.state.store_future_status( request_id, TaskStatus.FAILED.value, model_id, @@ -411,7 +411,7 @@ async def _perform_preflight_checks( if not allowed: error_msg = f'Rate limit exceeded: {reason}' error_payload = {'error': 
error_msg, 'category': 'User'} - self.state.store_future_status( + await self.state.store_future_status( request_id, TaskStatus.FAILED.value, model_id, @@ -475,7 +475,7 @@ async def schedule_task( ) # 2. Register PENDING status FIRST - self.state.store_future_status( + await self.state.store_future_status( request_id, TaskStatus.PENDING.value, model_id, queue_state=QueueState.ACTIVE.value) # 3. Route to per-model/per-token queue @@ -500,7 +500,7 @@ async def schedule_task( task_type=task_type, created_at=time.monotonic(), )) - self.state.store_future_status( + await self.state.store_future_status( request_id, TaskStatus.QUEUED.value, model_id, queue_state=QueueState.ACTIVE.value) logger.debug(f'[TaskQueue] Task {request_id} queued, new queue size: {q.qsize()} key={queue_key}') @@ -585,7 +585,7 @@ async def schedule_task_and_wait( raise RuntimeError(f'Task scheduling failed: {future_ref}') while True: - record = self.state.get_future(request_id) + record = await self.state.get_future(request_id) if record and record.get('status') not in ('pending', 'queued', 'running'): break await asyncio.sleep(0.05) diff --git a/src/twinkle_client/types/model.py b/src/twinkle_client/types/model.py index a349e024..e620d866 100644 --- a/src/twinkle_client/types/model.py +++ b/src/twinkle_client/types/model.py @@ -74,7 +74,7 @@ class UploadToHubRequest(BaseModel): checkpoint_dir: str hub_model_id: str hub_token: Optional[str] = None - async_upload: bool = True + async_upload: bool = False class Config: extra = 'allow' From df8af42f903640d078a6d822fc6e483a66b61b23 Mon Sep 17 00:00:00 2001 From: Yunnglin Date: Thu, 12 Mar 2026 18:31:38 +0800 Subject: [PATCH 18/24] update --- .../server/transformer/server_config.yaml | 68 +++++++++---------- 1 file changed, 34 insertions(+), 34 deletions(-) diff --git a/cookbook/client/server/transformer/server_config.yaml b/cookbook/client/server/transformer/server_config.yaml index 91dff9b2..41f7ca88 100644 --- 
a/cookbook/client/server/transformer/server_config.yaml +++ b/cookbook/client/server/transformer/server_config.yaml @@ -68,39 +68,39 @@ applications: # 3. Sampler Service - Runs inference / sampling using vLLM engine # Used for generating text from the model (e.g., evaluating LoRA results). - # - name: sampler-Qwen3.5-4B - # route_prefix: /api/v1/sampler/Qwen/Qwen3.5-4B - # import_path: sampler - # args: - # model_id: "ms://Qwen/Qwen3.5-4B" # ModelScope model identifier - # nproc_per_node: 2 # Number of GPU processes per node - # sampler_type: vllm # Inference engine: 'vllm' (fast) or 'torch' (TorchSampler) - # engine_args: # vLLM engine-specific settings - # max_model_len: 4096 # Maximum sequence length the engine supports - # gpu_memory_utilization: 0.5 # Fraction of GPU memory to use (0.0-1.0) - # enable_lora: true # Allow loading LoRA adapters during inference - # logprobs_mode: processed_logprobs # Logprobs mode for sampling results - # device_group: # Logical device group for the sampler - # name: sampler - # ranks: 1 # Number of GPUs to use - # device_type: cuda - # device_mesh: - # device_type: cuda - # dp_size: 1 - # queue_config: - # rps_limit: 100 # Max requests per second - # tps_limit: 100000 # Max tokens per second - # deployments: - # - name: SamplerManagement - # autoscaling_config: - # min_replicas: 1 - # max_replicas: 1 - # target_ongoing_requests: 16 - # ray_actor_options: - # num_cpus: 0.1 - # runtime_env: - # env_vars: - # TWINKLE_TRUST_REMOTE_CODE: "0" + - name: sampler-Qwen3.5-4B + route_prefix: /api/v1/sampler/Qwen/Qwen3.5-4B + import_path: sampler + args: + model_id: "ms://Qwen/Qwen3.5-4B" # ModelScope model identifier + nproc_per_node: 2 # Number of GPU processes per node + sampler_type: vllm # Inference engine: 'vllm' (fast) or 'torch' (TorchSampler) + engine_args: # vLLM engine-specific settings + max_model_len: 4096 # Maximum sequence length the engine supports + gpu_memory_utilization: 0.5 # Fraction of GPU memory to use (0.0-1.0) + 
enable_lora: true # Allow loading LoRA adapters during inference + logprobs_mode: processed_logprobs # Logprobs mode for sampling results + device_group: # Logical device group for the sampler + name: sampler + ranks: 1 # Number of GPUs to use + device_type: cuda + device_mesh: + device_type: cuda + dp_size: 1 + queue_config: + rps_limit: 100 # Max requests per second + tps_limit: 100000 # Max tokens per second + deployments: + - name: SamplerManagement + autoscaling_config: + min_replicas: 1 + max_replicas: 1 + target_ongoing_requests: 16 + ray_actor_options: + num_cpus: 0.1 + runtime_env: + env_vars: + TWINKLE_TRUST_REMOTE_CODE: "0" # 4. Processor Service - Runs inference / sampling using vLLM engine - name: processor @@ -114,7 +114,7 @@ applications: device_type: CPU device_mesh: device_type: CPU - dp_size: 2 # 数据并行大小 + dp_size: # 数据并行大小 deployments: - name: ProcessorManagement autoscaling_config: From 6ba417c6daf957e3a008a0a476631cbd472848e3 Mon Sep 17 00:00:00 2001 From: Yunnglin Date: Thu, 12 Mar 2026 18:56:48 +0800 Subject: [PATCH 19/24] update --- cookbook/client/server/transformer/server_config.yaml | 6 +++--- src/twinkle/infra/_ray/ray_helper.py | 2 +- 2 files changed, 4 insertions(+), 4 deletions(-) diff --git a/cookbook/client/server/transformer/server_config.yaml b/cookbook/client/server/transformer/server_config.yaml index 41f7ca88..570142af 100644 --- a/cookbook/client/server/transformer/server_config.yaml +++ b/cookbook/client/server/transformer/server_config.yaml @@ -102,19 +102,19 @@ applications: env_vars: TWINKLE_TRUST_REMOTE_CODE: "0" - # 4. Processor Service - Runs inference / sampling using vLLM engine + # 4. 
Processor Service - name: processor route_prefix: /api/v1/processor import_path: processor args: - ncpu_proc_per_node: 2 # 每节点 CPU 进程数 + ncpu_proc_per_node: 2 device_group: name: model ranks: 2 device_type: CPU device_mesh: device_type: CPU - dp_size: # 数据并行大小 + dp_size: 2 deployments: - name: ProcessorManagement autoscaling_config: diff --git a/src/twinkle/infra/_ray/ray_helper.py b/src/twinkle/infra/_ray/ray_helper.py index f0a4011b..0d8908a3 100644 --- a/src/twinkle/infra/_ray/ray_helper.py +++ b/src/twinkle/infra/_ray/ray_helper.py @@ -157,7 +157,7 @@ def get_master_id_port(placement_group): def get_node_address(): return find_node_ip(), find_free_port() - ip, port = ray.get(get_node_address.options(placement_group=placement_group).remote()) + ip, port = ray.get(get_node_address.options(placement_group=placement_group, num_cpus=0.01).remote()) return ip, port @staticmethod From 66df30b0499cfef220f794dd4e37c8b092c29417 Mon Sep 17 00:00:00 2001 From: Yunnglin Date: Thu, 12 Mar 2026 20:41:04 +0800 Subject: [PATCH 20/24] fix processor --- .../client/server/megatron/server_config.yaml | 22 +++ .../server/megatron/server_config_4b.yaml | 72 ++++++---- .../server/transformer/server_config.yaml | 66 ++++----- cookbook/client/tinker/self_host/sample.py | 2 +- .../client/tinker/self_host/self_cognition.py | 9 +- cookbook/client/twinkle/self_host/grpo.py | 128 ++++++------------ cookbook/client/twinkle/self_host/sample.py | 11 +- .../twinkle/self_host/self_congnition.py | 6 +- cookbook/transformers/fsdp2.py | 6 +- cookbook/transformers/sp_fsdp_dense.py | 2 +- src/twinkle/preprocessor/__init__.py | 2 +- src/twinkle/server/model/twinkle_handlers.py | 2 +- .../server/sampler/twinkle_handlers.py | 41 ++++-- src/twinkle/server/utils/adapter_manager.py | 58 ++++---- src/twinkle/server/utils/processor_manager.py | 28 +--- .../server/utils/state/server_state.py | 6 +- src/twinkle_client/types/__init__.py | 1 + src/twinkle_client/types/model.py | 2 +- 
src/twinkle_client/types/sampler.py | 22 ++- 19 files changed, 254 insertions(+), 232 deletions(-) diff --git a/cookbook/client/server/megatron/server_config.yaml b/cookbook/client/server/megatron/server_config.yaml index 7b2c9768..2687c910 100644 --- a/cookbook/client/server/megatron/server_config.yaml +++ b/cookbook/client/server/megatron/server_config.yaml @@ -107,3 +107,25 @@ applications: runtime_env: env_vars: TWINKLE_TRUST_REMOTE_CODE: "0" + + # 4. Processor Service + - name: processor + route_prefix: /api/v1/processor + import_path: processor + args: + ncpu_proc_per_node: 2 + device_group: + name: model + ranks: 2 + device_type: CPU + device_mesh: + device_type: CPU + dp_size: 2 + deployments: + - name: ProcessorManagement + autoscaling_config: + min_replicas: 1 + max_replicas: 1 + target_ongoing_requests: 128 + ray_actor_options: + num_cpus: 0.1 diff --git a/cookbook/client/server/megatron/server_config_4b.yaml b/cookbook/client/server/megatron/server_config_4b.yaml index 12dcc68f..dae992a3 100644 --- a/cookbook/client/server/megatron/server_config_4b.yaml +++ b/cookbook/client/server/megatron/server_config_4b.yaml @@ -70,36 +70,58 @@ applications: # 3. Sampler Service - Runs inference / sampling using vLLM engine # Used for generating text from the model (e.g., evaluating LoRA results). 
- - name: sampler-Qwen3.5-4B - route_prefix: /api/v1/sampler/Qwen/Qwen3.5-4B - import_path: sampler + # - name: sampler-Qwen3.5-4B + # route_prefix: /api/v1/sampler/Qwen/Qwen3.5-4B + # import_path: sampler + # args: + # model_id: "ms://Qwen/Qwen3.5-4B" # ModelScope model identifier + # nproc_per_node: 2 # Number of GPU processes per node + # sampler_type: vllm # Inference engine: 'vllm' (fast) or 'torch' (TorchSampler) + # engine_args: # vLLM engine-specific settings + # max_model_len: 4096 # Maximum sequence length the engine supports + # gpu_memory_utilization: 0.5 # Fraction of GPU memory to use (0.0-1.0) + # enable_lora: true # Allow loading LoRA adapters during inference + # logprobs_mode: processed_logprobs # Logprobs mode for sampling results + # device_group: # Logical device group for the sampler + # name: sampler + # ranks: 1 # Number of GPUs to use + # device_type: cuda + # device_mesh: + # device_type: cuda + # dp_size: 1 + # queue_config: + # rps_limit: 100 # Max requests per second + # tps_limit: 100000 # Max tokens per second + # deployments: + # - name: SamplerManagement + # autoscaling_config: + # min_replicas: 1 + # max_replicas: 1 + # target_ongoing_requests: 16 + # ray_actor_options: + # num_cpus: 0.1 + # runtime_env: + # env_vars: + # TWINKLE_TRUST_REMOTE_CODE: "0" + + # 4. 
Processor Service + - name: processor + route_prefix: /api/v1/processor + import_path: processor args: - model_id: "ms://Qwen/Qwen3.5-4B" # ModelScope model identifier - nproc_per_node: 2 # Number of GPU processes per node - sampler_type: vllm # Inference engine: 'vllm' (fast) or 'torch' (TorchSampler) - engine_args: # vLLM engine-specific settings - max_model_len: 4096 # Maximum sequence length the engine supports - gpu_memory_utilization: 0.5 # Fraction of GPU memory to use (0.0-1.0) - enable_lora: true # Allow loading LoRA adapters during inference - logprobs_mode: processed_logprobs # Logprobs mode for sampling results - device_group: # Logical device group for the sampler - name: sampler - ranks: 1 # Number of GPUs to use - device_type: cuda + ncpu_proc_per_node: 2 + device_group: + name: model + ranks: 2 + device_type: CPU device_mesh: - device_type: cuda - dp_size: 1 - queue_config: - rps_limit: 100 # Max requests per second - tps_limit: 100000 # Max tokens per second + device_type: CPU + dp_size: 2 deployments: - - name: SamplerManagement + - name: ProcessorManagement autoscaling_config: min_replicas: 1 max_replicas: 1 - target_ongoing_requests: 16 + target_ongoing_requests: 128 ray_actor_options: num_cpus: 0.1 - runtime_env: - env_vars: - TWINKLE_TRUST_REMOTE_CODE: "0" diff --git a/cookbook/client/server/transformer/server_config.yaml b/cookbook/client/server/transformer/server_config.yaml index 570142af..77657dcb 100644 --- a/cookbook/client/server/transformer/server_config.yaml +++ b/cookbook/client/server/transformer/server_config.yaml @@ -68,39 +68,39 @@ applications: # 3. Sampler Service - Runs inference / sampling using vLLM engine # Used for generating text from the model (e.g., evaluating LoRA results). 
- - name: sampler-Qwen3.5-4B - route_prefix: /api/v1/sampler/Qwen/Qwen3.5-4B - import_path: sampler - args: - model_id: "ms://Qwen/Qwen3.5-4B" # ModelScope model identifier - nproc_per_node: 2 # Number of GPU processes per node - sampler_type: vllm # Inference engine: 'vllm' (fast) or 'torch' (TorchSampler) - engine_args: # vLLM engine-specific settings - max_model_len: 4096 # Maximum sequence length the engine supports - gpu_memory_utilization: 0.5 # Fraction of GPU memory to use (0.0-1.0) - enable_lora: true # Allow loading LoRA adapters during inference - logprobs_mode: processed_logprobs # Logprobs mode for sampling results - device_group: # Logical device group for the sampler - name: sampler - ranks: 1 # Number of GPUs to use - device_type: cuda - device_mesh: - device_type: cuda - dp_size: 1 - queue_config: - rps_limit: 100 # Max requests per second - tps_limit: 100000 # Max tokens per second - deployments: - - name: SamplerManagement - autoscaling_config: - min_replicas: 1 - max_replicas: 1 - target_ongoing_requests: 16 - ray_actor_options: - num_cpus: 0.1 - runtime_env: - env_vars: - TWINKLE_TRUST_REMOTE_CODE: "0" + # - name: sampler-Qwen3.5-4B + # route_prefix: /api/v1/sampler/Qwen/Qwen3.5-4B + # import_path: sampler + # args: + # model_id: "ms://Qwen/Qwen3.5-4B" # ModelScope model identifier + # nproc_per_node: 2 # Number of GPU processes per node + # sampler_type: vllm # Inference engine: 'vllm' (fast) or 'torch' (TorchSampler) + # engine_args: # vLLM engine-specific settings + # max_model_len: 4096 # Maximum sequence length the engine supports + # gpu_memory_utilization: 0.5 # Fraction of GPU memory to use (0.0-1.0) + # enable_lora: true # Allow loading LoRA adapters during inference + # logprobs_mode: processed_logprobs # Logprobs mode for sampling results + # device_group: # Logical device group for the sampler + # name: sampler + # ranks: 1 # Number of GPUs to use + # device_type: cuda + # device_mesh: + # device_type: cuda + # dp_size: 1 + # 
queue_config: + # rps_limit: 100 # Max requests per second + # tps_limit: 100000 # Max tokens per second + # deployments: + # - name: SamplerManagement + # autoscaling_config: + # min_replicas: 1 + # max_replicas: 1 + # target_ongoing_requests: 16 + # ray_actor_options: + # num_cpus: 0.1 + # runtime_env: + # env_vars: + # TWINKLE_TRUST_REMOTE_CODE: "0" # 4. Processor Service - name: processor diff --git a/cookbook/client/tinker/self_host/sample.py b/cookbook/client/tinker/self_host/sample.py index 132eb63a..278f24bf 100644 --- a/cookbook/client/tinker/self_host/sample.py +++ b/cookbook/client/tinker/self_host/sample.py @@ -17,7 +17,7 @@ from tinker import ServiceClient # Step 2: Define the base model and connect to the server -base_model = 'Qwen/Qwen3.5-4B' +base_model = 'Qwen/Qwen3-4B' service_client = ServiceClient( base_url='http://localhost:8000', api_key='EMPTY-TOKEN' diff --git a/cookbook/client/tinker/self_host/self_cognition.py b/cookbook/client/tinker/self_host/self_cognition.py index 81125e53..1097afd8 100644 --- a/cookbook/client/tinker/self_host/self_cognition.py +++ b/cookbook/client/tinker/self_host/self_cognition.py @@ -92,9 +92,9 @@ def eval(): # Step 1: Load the trained LoRA checkpoint for inference # Path to a previously saved LoRA checkpoint (twinkle:// URI) - weight_path = 'twinkle://20260212_174205-Qwen_Qwen2_5-7B-Instruct-51edc9ed/weights/twinkle-lora-2' + weight_path = 'twinkle://20260301_142318-Qwen_Qwen3-4B-199d2cdb/weights/twinkle-lora-0' - service_client = ServiceClient(base_url=base_url, api_key=os.environ.get('MODELSCOPE_TOKEN')) + service_client = ServiceClient(base_url=base_url, api_key=api_key) sampling_client = service_client.create_sampling_client(model_path=weight_path, base_model=base_model) # Step 2: Prepare the chat prompt @@ -119,7 +119,6 @@ def eval(): params = types.SamplingParams( max_tokens=50, # Maximum tokens to generate temperature=0.2, # Low temperature for more focused responses - stop=['\n'] # Stop at newline ) # 
Sample 8 independent completions @@ -134,5 +133,5 @@ def eval(): if __name__ == '__main__': - train() # Uncomment to run training - # eval() # Run evaluation / inference + # train() # Uncomment to run training + eval() # Run evaluation / inference diff --git a/cookbook/client/twinkle/self_host/grpo.py b/cookbook/client/twinkle/self_host/grpo.py index 1f7c0553..883b2323 100644 --- a/cookbook/client/twinkle/self_host/grpo.py +++ b/cookbook/client/twinkle/self_host/grpo.py @@ -22,16 +22,14 @@ import dotenv dotenv.load_dotenv('.env') -import re -from twinkle.data_format import Trajectory -from twinkle.reward.base import Reward import gc import os from peft import LoraConfig -from typing import List, Tuple +from typing import List, Tuple, Dict, Any from twinkle import get_logger +from twinkle.reward import GSM8KAccuracyReward, GSM8KFormatReward from twinkle.advantage import GRPOAdvantage from twinkle.dataset import DatasetMeta from twinkle.metric import CompletionRewardMetric @@ -40,6 +38,7 @@ from twinkle_client.dataset import Dataset from twinkle_client.model import MultiLoraTransformersModel from twinkle_client.sampler import vLLMSampler +from twinkle.preprocessor.llm import GSM8KProcessor logger = get_logger() @@ -55,62 +54,22 @@ GRADIENT_ACCUMULATION_STEPS = 4 -def create_countdown_dataset(): - """Create Countdown Game dataset for GRPO training.""" - - dataset = Dataset(dataset_meta=DatasetMeta('ms://zouxuhong/Countdown-Tasks-3to4', data_slice=range(500))) - dataset.set_template('Template', model_id=MODEL_ID, max_length=8192) - dataset.map('CountdownProcessor') - dataset.encode(add_generation_prompt=True, batched=True) +def create_gsm8k_dataset(): + dataset = Dataset(DatasetMeta('ms://modelscope/gsm8k', subset_name='main', split='train')) + dataset.set_template('Template', model_id=MODEL_ID, max_length=2048) + dataset.map('GSM8KProcessor') + dataset.encode(add_generation_prompt=True) return dataset +def compute_rewards( + trajectories: List[Dict[str, Any]], +) -> 
Tuple[List[float], List[float], List[float]]: + accuracy_reward_fn = GSM8KAccuracyReward() + format_reward_fn = GSM8KFormatReward() -class CountDownAccuracy(Reward): - - @staticmethod - def countdown_accuracy_reward(completion: str, target: int, nums: List[int]) -> float: - """Accuracy reward: checks if equation is correct.""" - try: - match = re.search(r'(.*?)<\/answer>', completion) - if match is None: - return 0.0 - equation = match.group(1).strip() - if '=' in equation: - equation = equation.split('=')[0] - used_numbers = [int(n) for n in re.findall(r'\d+', equation)] - if sorted(used_numbers) != sorted(nums): - return 0.0 - if not re.match(r'^[\d+\-*/().\s]+$', equation): - return 0.0 - result = eval(equation, {'__builtins__': None}, {}) - return 1.0 if abs(float(result) - float(target)) < 1e-5 else 0.0 - except Exception: # noqa - return 0.0 - - def __call__(self, trajectories: List[Trajectory], ground_truths: List[Trajectory]): - rewards = [] - for trajectory in trajectories: - messages = trajectory.get('messages', []) - completion = '' - for msg in reversed(messages): - if msg.get('role') == 'assistant': - completion = msg.get('content', '') - break - user_data = trajectory.get('user_data', [{}]) - data = user_data[0] if isinstance(user_data, list) and user_data else {} - target = data.get('target', 0) - nums = data.get('nums', []) - acc_reward = self.countdown_accuracy_reward(completion, target, nums) - rewards.append(acc_reward) - return rewards - - -def compute_rewards(trajectories: List[dict], ) -> Tuple[List[float], List[float], List[float]]: - """Compute format and accuracy rewards for Countdown game.""" - from twinkle.reward import FormatReward - format_rewards = FormatReward()(trajectories, []) - accuracy_rewards = CountDownAccuracy()(trajectories, []) - total_rewards = [a + b for a, b in zip(accuracy_rewards, format_rewards)] + accuracy_rewards = accuracy_reward_fn(trajectories) + format_rewards = format_reward_fn(trajectories) + total_rewards = [a 
+ f for a, f in zip(accuracy_rewards, format_rewards)] return total_rewards, format_rewards, accuracy_rewards @@ -122,7 +81,7 @@ def train(): ) # Step 2: Prepare dataset and dataloader - dataset = create_countdown_dataset() + dataset = create_gsm8k_dataset() dataloader = DataLoader(dataset=dataset, batch_size=BATCH_SIZE) # Step 3: Configure the training model @@ -185,11 +144,11 @@ def train(): # the resulting path to the sampler as adapter_uri if step % SYNC_INTERVAL == 0: logger.info(f'Step {step}: Saving weights for sampler...') - twinkle_path = model.save( + result = model.save( name=f'grpo-sampler-step-{step}', save_optimizer=False, ) - current_adapter_uri = twinkle_path + current_adapter_uri = result.twinkle_path logger.info(f'Step {step}: Saved weights to {current_adapter_uri}') # ========== 2. Sample completions ========== @@ -200,32 +159,29 @@ def train(): num_samples=NUM_GENERATIONS, ) - input_features = [] - old_logps_list = [] - completion_lengths = [] + all_input_data: List[Dict[str, Any]] = [] + all_old_logps: List[List[float]] = [] + all_completion_lengths: List[int] = [] - sequences = sample_response.get('sequences', []) - for seq in sequences: - input_features.append(seq.get('new_input_feature', seq)) - old_logps_list.append(seq.get('logprobs', [])) - completion_lengths.append(len(seq.get('tokens', []))) - - if not input_features: - logger.warning(f'Step {step}: No valid samples, skipping') - step += 1 - continue + for sequence in sample_response.sequences: + all_input_data.append(sequence.new_input_feature) + all_old_logps.append(sequence.logprobs) + all_completion_lengths.append(len(sequence.tokens)) # ========== 3. 
Compute rewards ========== - total_rewards, format_rewards, accuracy_rewards = compute_rewards(input_features) + + total_rewards, format_rewards, accuracy_rewards = compute_rewards( + all_input_data + ) metrics.accumulate( - None, - None, - completion_lengths=completion_lengths, + completion_lengths=all_completion_lengths, rewards={ 'total': total_rewards, 'format': format_rewards, 'accuracy': accuracy_rewards, - }) + }, + ) + # ========== 4. Compute advantages ========== advantages = advantage_fn( @@ -244,29 +200,27 @@ def train(): # forward_backward with GRPO loss: passes advantages and old_logps # to the server-side GRPOLoss for proper policy optimization model.forward_backward( - inputs=input_features, + inputs=all_input_data, advantages=advantages, - old_logps=old_logps_list, + old_logps=all_old_logps, ) # Gradient clipping and optimizer step - model.clip_grad_norm(1.0) - model.step() - model.zero_grad() - model.lr_step() + model.clip_grad_and_step() gc.collect() # ========== 6. Log ========== log_dict = metrics.calculate() - log_dict.update(model.calculate_metric()) + log_dict.update(model.calculate_metric(is_training=True).result) log_dict['train/frac_reward_zero_std'] = frac_zero_std logger.info(f'Step {step}: {log_dict}') step += 1 + metrics.reset() # Save final checkpoint - twinkle_path = model.save(name='grpo-countdown-final', save_optimizer=True) - logger.info(f'Saved final checkpoint: {twinkle_path}') + result = model.save(name='grpo-countdown-final', save_optimizer=True) + logger.info(f'Saved final checkpoint: {result}') if __name__ == '__main__': diff --git a/cookbook/client/twinkle/self_host/sample.py b/cookbook/client/twinkle/self_host/sample.py index 9437bb36..d800b635 100644 --- a/cookbook/client/twinkle/self_host/sample.py +++ b/cookbook/client/twinkle/self_host/sample.py @@ -29,14 +29,13 @@ # or None to use the base model # ADAPTER_URI = None # Example: -ADAPTER_URI = 'twinkle://20260208_224851-fa3cdd11-default/weights/twinkle-epoch-2' - 
+ADAPTER_URI = 'twinkle://20260301_142318-Qwen_Qwen3-4B-199d2cdb/weights/twinkle-lora-0' def sample(): # Step 2: Initialize the Twinkle client to communicate with the remote server. client = init_twinkle_client( base_url='http://127.0.0.1:8000', - api_key=os.environ.get('MODELSCOPE_TOKEN'), + api_key='EMPTY_API_KEY', ) # Step 3: Create the sampler client pointing to the model on the server @@ -84,11 +83,11 @@ def sample(): # Step 8: Decode and print the results tokenizer = AutoTokenizer.from_pretrained(MODEL_ID, trust_remote_code=True) - logger.info(f"Generated {len(response['sequences'])} sequences " + logger.info(f'Generated {len(response.sequences)} sequences ' f'({num_prompts} prompts x {num_samples} samples)') - for i, seq in enumerate(response['sequences']): - text = tokenizer.decode(seq['tokens'], skip_special_tokens=True) + for i, seq in enumerate(response.sequences): + text = tokenizer.decode(seq.tokens, skip_special_tokens=True) logger.info(f'Sequence {i}:\n {text}\n') diff --git a/cookbook/client/twinkle/self_host/self_congnition.py b/cookbook/client/twinkle/self_host/self_congnition.py index b70000ed..0b853944 100644 --- a/cookbook/client/twinkle/self_host/self_congnition.py +++ b/cookbook/client/twinkle/self_host/self_congnition.py @@ -50,7 +50,7 @@ def train(): dataset = Dataset(dataset_meta=DatasetMeta('ms://swift/self-cognition', data_slice=range(500))) # Apply a chat template so the data matches the model's expected input format - dataset.set_template('Template', model_id='ms://Qwen/Qwen3.5-4B', max_length=512) + dataset.set_template('Template', model_id='ms://Qwen/Qwen3-4B', max_length=512) # Replace placeholder names in the dataset with custom model/author names dataset.map('SelfCognitionProcessor', init_args={'model_name': 'twinkle模型', 'model_author': 'ModelScope社区'}) @@ -64,7 +64,7 @@ def train(): # Step 5: Configure the model # Create a multi-LoRA Transformers model pointing to the base model on ModelScope - model = 
MultiLoraTransformersModel(model_id='ms://Qwen/Qwen3.5-4B') + model = MultiLoraTransformersModel(model_id='ms://Qwen/Qwen3-4B') # Define LoRA configuration: apply low-rank adapters to all linear layers lora_config = LoraConfig(target_modules='all-linear') @@ -119,7 +119,7 @@ def train(): if step % 2 == 0: # Print metric metric = model.calculate_metric(is_training=True) - logger.info(f'Current is step {step} of {len(dataloader)}, metric: {metric.model_dump()}') + logger.info(f'Current is step {step} of {len(dataloader)}, metric: {metric.result}') # Step 8: Save the trained checkpoint twinkle_path = model.save(name=f'twinkle-epoch-{epoch}', save_optimizer=True) diff --git a/cookbook/transformers/fsdp2.py b/cookbook/transformers/fsdp2.py index ca37d724..a9c60c82 100644 --- a/cookbook/transformers/fsdp2.py +++ b/cookbook/transformers/fsdp2.py @@ -20,7 +20,7 @@ def eval(model): # 100 Samples dataset = Dataset(dataset_meta=DatasetMeta('ms://swift/self-cognition', data_slice=range(100))) - dataset.set_template('Template', model_id='ms://Qwen/Qwen3-4B') + dataset.set_template('Template', model_id='ms://Qwen/Qwen3.5-4B') dataset.map(SelfCognitionProcessor('twinkle大模型', 'ModelScope社区')) dataset.encode() dataloader = DataLoader(dataset=dataset, batch_size=8) @@ -35,7 +35,7 @@ def train(): # 1000 samples dataset = Dataset(dataset_meta=DatasetMeta('ms://swift/self-cognition', data_slice=range(1000))) # Set template to prepare encoding - dataset.set_template('Template', model_id='ms://Qwen/Qwen3-4B') + dataset.set_template('Template', model_id='ms://Qwen/Qwen3.5-4B') # Preprocess the dataset to standard format dataset.map(SelfCognitionProcessor('twinkle大模型', 'ModelScope社区')) # Encode dataset @@ -43,7 +43,7 @@ def train(): # Global batch size = 8, for GPUs, so 1 sample per GPU dataloader = DataLoader(dataset=dataset, batch_size=8) # Use a TransformersModel - model = TransformersModel(model_id='ms://Qwen/Qwen3-4B') + model = TransformersModel(model_id='ms://Qwen/Qwen3.5-4B') 
lora_config = LoraConfig(r=8, lora_alpha=32, target_modules='all-linear') diff --git a/cookbook/transformers/sp_fsdp_dense.py b/cookbook/transformers/sp_fsdp_dense.py index da6e2d28..868b61c0 100644 --- a/cookbook/transformers/sp_fsdp_dense.py +++ b/cookbook/transformers/sp_fsdp_dense.py @@ -10,7 +10,7 @@ from twinkle.preprocessor import SelfCognitionProcessor logger = get_logger() -MODEL_ID = 'ms://Qwen/Qwen3-4B' +MODEL_ID = 'ms://Qwen/Qwen3.5-4B' DATASETS = 'ms://swift/self-cognition' device_group = [DeviceGroup( diff --git a/src/twinkle/preprocessor/__init__.py b/src/twinkle/preprocessor/__init__.py index 1c19815e..13b52d99 100644 --- a/src/twinkle/preprocessor/__init__.py +++ b/src/twinkle/preprocessor/__init__.py @@ -1,4 +1,4 @@ # Copyright (c) ModelScope Contributors. All rights reserved. from .base import DataFilter, Preprocessor from .llm import (AlpacaProcessor, CompetitionMathGRPOProcessor, CompetitionMathProcessor, CountdownProcessor, - SelfCognitionProcessor) + GSM8KProcessor, SelfCognitionProcessor) diff --git a/src/twinkle/server/model/twinkle_handlers.py b/src/twinkle/server/model/twinkle_handlers.py index 7265f2b8..35c87441 100644 --- a/src/twinkle/server/model/twinkle_handlers.py +++ b/src/twinkle/server/model/twinkle_handlers.py @@ -291,7 +291,7 @@ async def _task(): save_optimizer=body.save_optimizer, **extra_kwargs) twinkle_path = checkpoint_manager.save(model_id=adapter_name, name=checkpoint_name, is_sampler=False) - return {'result': twinkle_path, 'checkpoint_dir': checkpoint_dir} + return {'twinkle_path': twinkle_path, 'checkpoint_dir': checkpoint_dir} return await run_task(self.schedule_task_and_wait(_task, task_type='save')) diff --git a/src/twinkle/server/sampler/twinkle_handlers.py b/src/twinkle/server/sampler/twinkle_handlers.py index 2e477736..a31f4046 100644 --- a/src/twinkle/server/sampler/twinkle_handlers.py +++ b/src/twinkle/server/sampler/twinkle_handlers.py @@ -7,12 +7,14 @@ from __future__ import annotations import traceback 
-from fastapi import Depends, FastAPI, Request +from fastapi import Depends, FastAPI, HTTPException, Request from typing import TYPE_CHECKING, Callable, Optional if TYPE_CHECKING: from .app import SamplerManagement +import numpy as np + import twinkle_client.types as types from twinkle.data_format import InputFeature, SamplingParams, Trajectory from twinkle.utils.logger import get_logger @@ -20,6 +22,24 @@ logger = get_logger() +def _serialize_input_feature(feature: dict) -> dict: + """Convert numpy arrays / torch tensors in an InputFeature to plain Python lists.""" + result = {} + for k, v in feature.items(): + if isinstance(v, np.ndarray): + result[k] = v.tolist() + else: + try: + import torch + if isinstance(v, torch.Tensor): + result[k] = v.tolist() + continue + except ImportError: + pass + result[k] = v + return result + + def _get_twinkle_sampler_adapter_name(request: Request, adapter_name: str | None) -> str | None: """Prefix the adapter name with the request ID for per-request isolation.""" if adapter_name is None or adapter_name == '': @@ -89,13 +109,16 @@ def sample(request: Request, body: types.SampleRequest, if callable(response): response = response() - sequences = [] - for seq in response.sequences: - sequences.append({ - 'stop_reason': seq.stop_reason, - 'tokens': list(seq.tokens), - 'logprobs': list(seq.logprobs) if seq.logprobs is not None else None, - }) + sequences = [ + types.SampledSequenceModel( + stop_reason=seq.stop_reason, + tokens=list(seq.tokens), + logprobs=list(seq.logprobs) if seq.logprobs is not None else None, + decoded=seq.decoded, + new_input_feature=_serialize_input_feature(seq.new_input_feature) + if seq.new_input_feature is not None else None, + ) for seq in response.sequences + ] return types.SampleResponseModel( sequences=sequences, @@ -104,7 +127,7 @@ def sample(request: Request, body: types.SampleRequest, ) except Exception: logger.error(traceback.format_exc()) - raise + raise HTTPException(status_code=500, 
detail=traceback.format_exc()) @app.post('/twinkle/set_template', response_model=types.SetTemplateResponse) def set_template( diff --git a/src/twinkle/server/utils/adapter_manager.py b/src/twinkle/server/utils/adapter_manager.py index ce37fe83..844ccfd1 100644 --- a/src/twinkle/server/utils/adapter_manager.py +++ b/src/twinkle/server/utils/adapter_manager.py @@ -11,8 +11,6 @@ """ from __future__ import annotations -import asyncio -import concurrent.futures import threading import time from typing import TYPE_CHECKING, Any @@ -37,6 +35,7 @@ class AdapterManagerMixin: Attributes: _adapter_timeout: Session inactivity timeout in seconds used to determine if a session is alive. + _adapter_max_lifetime: Maximum lifetime in seconds for any adapter, regardless of session liveness. """ # Type hint for state attribute that inheriting classes must provide @@ -45,6 +44,7 @@ class AdapterManagerMixin: def _init_adapter_manager( self, adapter_timeout: float = 1800.0, + adapter_max_lifetime: float = 36000.0, ) -> None: """Initialize the adapter manager. @@ -53,8 +53,11 @@ def _init_adapter_manager( Args: adapter_timeout: Timeout in seconds used to check whether a session is still alive. Default is 1800.0 (30 minutes). + adapter_max_lifetime: Maximum lifetime in seconds for an adapter regardless of session + liveness. Adapters older than this are treated as expired. Default is 36000.0 (10 hours). """ self._adapter_timeout = adapter_timeout + self._adapter_max_lifetime = adapter_max_lifetime # Adapter lifecycle tracking # Dict mapping adapter_name -> @@ -65,9 +68,6 @@ def _init_adapter_manager( self._adapter_countdown_thread: threading.Thread | None = None self._adapter_countdown_running = False - # Event loop reference used to bridge sync thread → async state calls - self._adapter_event_loop: asyncio.AbstractEventLoop | None = None - def register_adapter(self, adapter_name: str, token: str, session_id: str) -> None: """Register a new adapter for lifecycle tracking. 
@@ -94,7 +94,7 @@ def register_adapter(self, adapter_name: str, token: str, session_id: str) -> No logger.debug( f'[AdapterManager] Registered adapter {adapter_name} for token {token[:8]}... (session: {session_id})') - async def _is_session_alive(self, session_id: str) -> bool: + def _is_session_alive(self, session_id: str) -> bool: """Check if a session is still alive via state proxy. Args: @@ -107,7 +107,7 @@ async def _is_session_alive(self, session_id: str) -> bool: return True # No session association means always alive # Get session last heartbeat through proxy - last_heartbeat = await self.state.get_session_last_heartbeat(session_id) + last_heartbeat = self.state.get_session_last_heartbeat(session_id) if last_heartbeat is None: return False # Session doesn't exist @@ -216,17 +216,18 @@ def assert_adapter_exists(self, adapter_name: str) -> None: f'Adapter {adapter_name} not found' def _adapter_countdown_loop(self) -> None: - """Background thread that monitors and handles adapters whose session has expired. + """Background thread that monitors and handles adapters whose session has expired or exceeded max lifetime. This thread runs continuously and: - 1. Checks session liveness for all registered adapters every second - 2. Calls _on_adapter_expired() for adapters whose session has expired - 3. Removes expired adapters from tracking + 1. Checks whether an adapter has exceeded `_adapter_max_lifetime` (sync, no async call) + 2. Checks session liveness for remaining adapters every second + 3. Calls _on_adapter_expired() for adapters that have expired + 4. 
Removes expired adapters from tracking """ logger.debug(f'[AdapterManager] Countdown thread started (session_timeout={self._adapter_timeout}s)') while self._adapter_countdown_running: try: - time.sleep(1) + time.sleep(10) expired_adapters: list[tuple[str, str | None]] = [] # Create snapshot to avoid modification during iteration @@ -236,17 +237,24 @@ def _adapter_countdown_loop(self) -> None: continue session_id = info.get('session_id') - if self._adapter_event_loop is not None and self._adapter_event_loop.is_running(): - future: concurrent.futures.Future = asyncio.run_coroutine_threadsafe( - self._is_session_alive(session_id), self._adapter_event_loop) - try: - session_alive = future.result(timeout=5.0) - except Exception as e: - logger.warning(f'[AdapterManager] Failed to check session liveness for {adapter_name}: {e}') - continue - else: - logger.warning( - f'[AdapterManager] No event loop available to check session {session_id}, skipping') + created_at = info.get('created_at', 0.0) + now = time.time() + + # Check max lifetime first (no async call needed) + if now - created_at >= self._adapter_max_lifetime: + logger.debug(f'[AdapterManager] Adapter {adapter_name} exceeded max lifetime ' + f'({self._adapter_max_lifetime}s), marking as expired') + info['expiring'] = True + info['state'] = {} + token = info.get('token') + expired_adapters.append((adapter_name, token, session_id)) + continue + + try: + session_alive = self._is_session_alive(session_id) + except Exception as e: + logger.warning(f'[AdapterManager] Failed to check session liveness for {adapter_name}: ' + f'{type(e).__name__}: {e}') continue session_expired = not session_alive logger.debug(f'[AdapterManager] Adapter {adapter_name} session check ' @@ -289,10 +297,6 @@ def start_adapter_countdown(self) -> None: """ if not self._adapter_countdown_running: self._adapter_countdown_running = True - try: - self._adapter_event_loop = asyncio.get_running_loop() - except RuntimeError: - self._adapter_event_loop = 
None self._adapter_countdown_thread = threading.Thread(target=self._adapter_countdown_loop, daemon=True) self._adapter_countdown_thread.start() logger.debug('[AdapterManager] Countdown thread started') diff --git a/src/twinkle/server/utils/processor_manager.py b/src/twinkle/server/utils/processor_manager.py index 9e1cef42..df289b39 100644 --- a/src/twinkle/server/utils/processor_manager.py +++ b/src/twinkle/server/utils/processor_manager.py @@ -7,8 +7,6 @@ """ from __future__ import annotations -import asyncio -import concurrent.futures import threading import time from typing import TYPE_CHECKING, Any @@ -59,7 +57,6 @@ def _init_processor_manager( self._processor_countdown_thread: threading.Thread | None = None self._processor_countdown_running = False - self._processor_event_loop: asyncio.AbstractEventLoop | None = None def register_processor(self, processor_id: str, token: str, session_id: str) -> None: """Register a new processor for lifecycle tracking. @@ -124,11 +121,11 @@ def _on_processor_expired(self, processor_id: str) -> None: """ raise NotImplementedError(f'_on_processor_expired must be implemented by {self.__class__.__name__}') - async def _is_session_alive(self, session_id: str) -> bool: + def _is_session_alive(self, session_id: str) -> bool: """Check if a session is still alive via state proxy.""" if not session_id: return True - last_heartbeat = await self.state.get_session_last_heartbeat(session_id) + last_heartbeat = self.state.get_session_last_heartbeat(session_id) if last_heartbeat is None: return False return (time.time() - last_heartbeat) < self._processor_timeout @@ -145,18 +142,11 @@ def _processor_countdown_loop(self) -> None: if info.get('expiring'): continue session_id = info.get('session_id') - if self._processor_event_loop is not None and self._processor_event_loop.is_running(): - future: concurrent.futures.Future = asyncio.run_coroutine_threadsafe( - self._is_session_alive(session_id), self._processor_event_loop) - try: - session_alive 
= future.result(timeout=5.0) - except Exception as e: - logger.warning(f'[ProcessorManager] Failed to check session liveness ' - f'for {processor_id}: {type(e).__name__}: {e}') - continue - else: - logger.warning(f'[ProcessorManager] No event loop available to check ' - f'session {session_id}, skipping') + try: + session_alive = self._is_session_alive(session_id) + except Exception as e: + logger.warning(f'[ProcessorManager] Failed to check session liveness ' + f'for {processor_id}: {type(e).__name__}: {e}') continue logger.debug(f'[ProcessorManager] Processor {processor_id} session check ' @@ -192,10 +182,6 @@ def start_processor_countdown(self) -> None: """Start the background countdown thread. Safe to call multiple times.""" if not self._processor_countdown_running: self._processor_countdown_running = True - try: - self._processor_event_loop = asyncio.get_running_loop() - except RuntimeError: - self._processor_event_loop = None self._processor_countdown_thread = threading.Thread(target=self._processor_countdown_loop, daemon=True) self._processor_countdown_thread.start() logger.debug('[ProcessorManager] Countdown thread started') diff --git a/src/twinkle/server/utils/state/server_state.py b/src/twinkle/server/utils/state/server_state.py index d04c877f..a70fdac5 100644 --- a/src/twinkle/server/utils/state/server_state.py +++ b/src/twinkle/server/utils/state/server_state.py @@ -79,7 +79,7 @@ async def touch_session(self, session_id: str) -> bool: """ return self._session_mgr.touch(session_id) - async def get_session_last_heartbeat(self, session_id: str) -> float | None: + def get_session_last_heartbeat(self, session_id: str) -> float | None: """Get the last heartbeat timestamp for a session. 
Returns: @@ -353,8 +353,8 @@ def create_session(self, payload: dict[str, Any]) -> str: async def touch_session(self, session_id: str) -> bool: return await self._actor.touch_session.remote(session_id) - async def get_session_last_heartbeat(self, session_id: str) -> float | None: - return await self._actor.get_session_last_heartbeat.remote(session_id) + def get_session_last_heartbeat(self, session_id: str) -> float | None: + return ray.get(self._actor.get_session_last_heartbeat.remote(session_id)) # ----- Model Registration ----- diff --git a/src/twinkle_client/types/__init__.py b/src/twinkle_client/types/__init__.py index 57485dfd..b6650a28 100644 --- a/src/twinkle_client/types/__init__.py +++ b/src/twinkle_client/types/__init__.py @@ -57,6 +57,7 @@ AddAdapterRequest as SamplerAddAdapterRequest, AddAdapterResponse, CreateResponse as SamplerCreateResponse, + SampledSequenceModel, SampleRequest, SampleResponseModel, SetTemplateRequest as SamplerSetTemplateRequest, diff --git a/src/twinkle_client/types/model.py b/src/twinkle_client/types/model.py index e620d866..e594bae4 100644 --- a/src/twinkle_client/types/model.py +++ b/src/twinkle_client/types/model.py @@ -208,7 +208,7 @@ class CalculateMetricResponse(BaseModel): class SaveResponse(BaseModel): """Response for /save endpoint (returns twinkle path + checkpoint dir).""" - result: str + twinkle_path: str checkpoint_dir: Optional[str] = None diff --git a/src/twinkle_client/types/sampler.py b/src/twinkle_client/types/sampler.py index c78f5d55..cf370330 100644 --- a/src/twinkle_client/types/sampler.py +++ b/src/twinkle_client/types/sampler.py @@ -5,7 +5,9 @@ These models are used by both the server-side handler and the twinkle client. 
""" from pydantic import BaseModel, Field -from typing import Any, Dict, List, Optional +from typing import Any, Dict, List, Literal, Optional, Tuple + +StopReason = Literal['length', 'stop'] class SampleRequest(BaseModel): @@ -19,12 +21,22 @@ class SampleRequest(BaseModel): num_samples: int = Field(1, description='Number of completions to generate per prompt') +class SampledSequenceModel(BaseModel): + """A single sampled sequence, mirroring twinkle.data_format.SampledSequence.""" + stop_reason: StopReason = Field(..., description="Stop reason: 'length' or 'stop'") + tokens: List[int] = Field(..., description='Token IDs of the sampled sequence') + logprobs: Optional[List[float]] = Field(None, description='Per-token log-probabilities') + decoded: Optional[str] = Field(None, description='Decoded text of the sampled sequence') + new_input_feature: Optional[Dict[str, Any]] = Field( + None, description='Updated InputFeature after sampling (input_ids, labels, etc.)') + + class SampleResponseModel(BaseModel): - """Response body for the /sample endpoint.""" - sequences: List[Dict[str, Any]] = Field( - ..., description='List of sampled sequences, each with tokens, logprobs, stop_reason') + """Response body for the /sample endpoint, mirroring twinkle.data_format.SampleResponse.""" + sequences: List[SampledSequenceModel] = Field( + ..., description='List of sampled sequences') prompt_logprobs: Optional[List[Optional[float]]] = None - topk_prompt_logprobs: Optional[List[Optional[List]]] = None + topk_prompt_logprobs: Optional[List[Optional[List[Tuple[int, float]]]]] = None class SetTemplateRequest(BaseModel): From 8a2e68144dd7d81c9d8dcb96ff4fef4c11d485b3 Mon Sep 17 00:00:00 2001 From: Yunnglin Date: Thu, 12 Mar 2026 21:20:48 +0800 Subject: [PATCH 21/24] fix processor --- .../client/server/megatron/server_config.yaml | 1 + .../server/megatron/server_config_4b.yaml | 67 ++++++++++--------- .../server/transformer/server_config.yaml | 66 +++++++++--------- 
cookbook/client/tinker/self_host/sample.py | 2 +- .../client/tinker/self_host/self_cognition.py | 4 +- .../twinkle/self_host/self_congnition.py | 6 +- src/twinkle/server/model/app.py | 1 + .../server/model/backends/megatron_model.py | 17 +---- src/twinkle/server/model/tinker_handlers.py | 1 + 9 files changed, 78 insertions(+), 87 deletions(-) diff --git a/cookbook/client/server/megatron/server_config.yaml b/cookbook/client/server/megatron/server_config.yaml index 2687c910..becda8b0 100644 --- a/cookbook/client/server/megatron/server_config.yaml +++ b/cookbook/client/server/megatron/server_config.yaml @@ -23,6 +23,7 @@ applications: deployments: - name: TinkerCompatServer + max_ongoing_requests: 50 autoscaling_config: min_replicas: 1 # Minimum number of replicas max_replicas: 1 # Maximum number of replicas diff --git a/cookbook/client/server/megatron/server_config_4b.yaml b/cookbook/client/server/megatron/server_config_4b.yaml index dae992a3..0ea99551 100644 --- a/cookbook/client/server/megatron/server_config_4b.yaml +++ b/cookbook/client/server/megatron/server_config_4b.yaml @@ -24,6 +24,7 @@ applications: - Qwen/Qwen3.5-4B deployments: - name: TinkerCompatServer + max_ongoing_requests: 50 autoscaling_config: min_replicas: 1 # Minimum number of replicas max_replicas: 1 # Maximum number of replicas @@ -70,39 +71,39 @@ applications: # 3. Sampler Service - Runs inference / sampling using vLLM engine # Used for generating text from the model (e.g., evaluating LoRA results). 
- # - name: sampler-Qwen3.5-4B - # route_prefix: /api/v1/sampler/Qwen/Qwen3.5-4B - # import_path: sampler - # args: - # model_id: "ms://Qwen/Qwen3.5-4B" # ModelScope model identifier - # nproc_per_node: 2 # Number of GPU processes per node - # sampler_type: vllm # Inference engine: 'vllm' (fast) or 'torch' (TorchSampler) - # engine_args: # vLLM engine-specific settings - # max_model_len: 4096 # Maximum sequence length the engine supports - # gpu_memory_utilization: 0.5 # Fraction of GPU memory to use (0.0-1.0) - # enable_lora: true # Allow loading LoRA adapters during inference - # logprobs_mode: processed_logprobs # Logprobs mode for sampling results - # device_group: # Logical device group for the sampler - # name: sampler - # ranks: 1 # Number of GPUs to use - # device_type: cuda - # device_mesh: - # device_type: cuda - # dp_size: 1 - # queue_config: - # rps_limit: 100 # Max requests per second - # tps_limit: 100000 # Max tokens per second - # deployments: - # - name: SamplerManagement - # autoscaling_config: - # min_replicas: 1 - # max_replicas: 1 - # target_ongoing_requests: 16 - # ray_actor_options: - # num_cpus: 0.1 - # runtime_env: - # env_vars: - # TWINKLE_TRUST_REMOTE_CODE: "0" + - name: sampler-Qwen3.5-4B + route_prefix: /api/v1/sampler/Qwen/Qwen3.5-4B + import_path: sampler + args: + model_id: "ms://Qwen/Qwen3.5-4B" # ModelScope model identifier + nproc_per_node: 2 # Number of GPU processes per node + sampler_type: vllm # Inference engine: 'vllm' (fast) or 'torch' (TorchSampler) + engine_args: # vLLM engine-specific settings + max_model_len: 4096 # Maximum sequence length the engine supports + gpu_memory_utilization: 0.5 # Fraction of GPU memory to use (0.0-1.0) + enable_lora: true # Allow loading LoRA adapters during inference + logprobs_mode: processed_logprobs # Logprobs mode for sampling results + device_group: # Logical device group for the sampler + name: sampler + ranks: 1 # Number of GPUs to use + device_type: cuda + device_mesh: + device_type: 
cuda + dp_size: 1 + queue_config: + rps_limit: 100 # Max requests per second + tps_limit: 100000 # Max tokens per second + deployments: + - name: SamplerManagement + autoscaling_config: + min_replicas: 1 + max_replicas: 1 + target_ongoing_requests: 16 + ray_actor_options: + num_cpus: 0.1 + runtime_env: + env_vars: + TWINKLE_TRUST_REMOTE_CODE: "0" # 4. Processor Service - name: processor diff --git a/cookbook/client/server/transformer/server_config.yaml b/cookbook/client/server/transformer/server_config.yaml index 77657dcb..570142af 100644 --- a/cookbook/client/server/transformer/server_config.yaml +++ b/cookbook/client/server/transformer/server_config.yaml @@ -68,39 +68,39 @@ applications: # 3. Sampler Service - Runs inference / sampling using vLLM engine # Used for generating text from the model (e.g., evaluating LoRA results). - # - name: sampler-Qwen3.5-4B - # route_prefix: /api/v1/sampler/Qwen/Qwen3.5-4B - # import_path: sampler - # args: - # model_id: "ms://Qwen/Qwen3.5-4B" # ModelScope model identifier - # nproc_per_node: 2 # Number of GPU processes per node - # sampler_type: vllm # Inference engine: 'vllm' (fast) or 'torch' (TorchSampler) - # engine_args: # vLLM engine-specific settings - # max_model_len: 4096 # Maximum sequence length the engine supports - # gpu_memory_utilization: 0.5 # Fraction of GPU memory to use (0.0-1.0) - # enable_lora: true # Allow loading LoRA adapters during inference - # logprobs_mode: processed_logprobs # Logprobs mode for sampling results - # device_group: # Logical device group for the sampler - # name: sampler - # ranks: 1 # Number of GPUs to use - # device_type: cuda - # device_mesh: - # device_type: cuda - # dp_size: 1 - # queue_config: - # rps_limit: 100 # Max requests per second - # tps_limit: 100000 # Max tokens per second - # deployments: - # - name: SamplerManagement - # autoscaling_config: - # min_replicas: 1 - # max_replicas: 1 - # target_ongoing_requests: 16 - # ray_actor_options: - # num_cpus: 0.1 - # runtime_env: 
- # env_vars: - # TWINKLE_TRUST_REMOTE_CODE: "0" + - name: sampler-Qwen3.5-4B + route_prefix: /api/v1/sampler/Qwen/Qwen3.5-4B + import_path: sampler + args: + model_id: "ms://Qwen/Qwen3.5-4B" # ModelScope model identifier + nproc_per_node: 2 # Number of GPU processes per node + sampler_type: vllm # Inference engine: 'vllm' (fast) or 'torch' (TorchSampler) + engine_args: # vLLM engine-specific settings + max_model_len: 4096 # Maximum sequence length the engine supports + gpu_memory_utilization: 0.5 # Fraction of GPU memory to use (0.0-1.0) + enable_lora: true # Allow loading LoRA adapters during inference + logprobs_mode: processed_logprobs # Logprobs mode for sampling results + device_group: # Logical device group for the sampler + name: sampler + ranks: 1 # Number of GPUs to use + device_type: cuda + device_mesh: + device_type: cuda + dp_size: 1 + queue_config: + rps_limit: 100 # Max requests per second + tps_limit: 100000 # Max tokens per second + deployments: + - name: SamplerManagement + autoscaling_config: + min_replicas: 1 + max_replicas: 1 + target_ongoing_requests: 16 + ray_actor_options: + num_cpus: 0.1 + runtime_env: + env_vars: + TWINKLE_TRUST_REMOTE_CODE: "0" # 4. 
Processor Service - name: processor diff --git a/cookbook/client/tinker/self_host/sample.py b/cookbook/client/tinker/self_host/sample.py index 278f24bf..132eb63a 100644 --- a/cookbook/client/tinker/self_host/sample.py +++ b/cookbook/client/tinker/self_host/sample.py @@ -17,7 +17,7 @@ from tinker import ServiceClient # Step 2: Define the base model and connect to the server -base_model = 'Qwen/Qwen3-4B' +base_model = 'Qwen/Qwen3.5-4B' service_client = ServiceClient( base_url='http://localhost:8000', api_key='EMPTY-TOKEN' diff --git a/cookbook/client/tinker/self_host/self_cognition.py b/cookbook/client/tinker/self_host/self_cognition.py index 1097afd8..6951760d 100644 --- a/cookbook/client/tinker/self_host/self_cognition.py +++ b/cookbook/client/tinker/self_host/self_cognition.py @@ -133,5 +133,5 @@ def eval(): if __name__ == '__main__': - # train() # Uncomment to run training - eval() # Run evaluation / inference + train() # Uncomment to run training + # eval() # Run evaluation / inference diff --git a/cookbook/client/twinkle/self_host/self_congnition.py b/cookbook/client/twinkle/self_host/self_congnition.py index 0b853944..6bf6afce 100644 --- a/cookbook/client/twinkle/self_host/self_congnition.py +++ b/cookbook/client/twinkle/self_host/self_congnition.py @@ -50,7 +50,7 @@ def train(): dataset = Dataset(dataset_meta=DatasetMeta('ms://swift/self-cognition', data_slice=range(500))) # Apply a chat template so the data matches the model's expected input format - dataset.set_template('Template', model_id='ms://Qwen/Qwen3-4B', max_length=512) + dataset.set_template('Template', model_id='ms://Qwen/Qwen3.5-4B', max_length=512) # Replace placeholder names in the dataset with custom model/author names dataset.map('SelfCognitionProcessor', init_args={'model_name': 'twinkle模型', 'model_author': 'ModelScope社区'}) @@ -64,7 +64,7 @@ def train(): # Step 5: Configure the model # Create a multi-LoRA Transformers model pointing to the base model on ModelScope - model = 
MultiLoraTransformersModel(model_id='ms://Qwen/Qwen3-4B') + model = MultiLoraTransformersModel(model_id='ms://Qwen/Qwen3.5-4B') # Define LoRA configuration: apply low-rank adapters to all linear layers lora_config = LoraConfig(target_modules='all-linear') @@ -87,7 +87,7 @@ def train(): model.set_optimizer('Adam', lr=1e-4) # Use a linear learning rate scheduler (Do not support LR scheduler if server use megatron) - model.set_lr_scheduler('LinearLR') + # model.set_lr_scheduler('LinearLR') # Step 6: Optionally resume from a previous checkpoint if resume_path: diff --git a/src/twinkle/server/model/app.py b/src/twinkle/server/model/app.py index b9294fd6..49692e7f 100644 --- a/src/twinkle/server/model/app.py +++ b/src/twinkle/server/model/app.py @@ -61,6 +61,7 @@ def __init__(self, # Choose model backend if use_megatron: from ..model.backends.megatron_model import TwinkleCompatMegatronModel + self.model = TwinkleCompatMegatronModel( model_id=model_id, device_mesh=self.device_mesh, diff --git a/src/twinkle/server/model/backends/megatron_model.py b/src/twinkle/server/model/backends/megatron_model.py index 432bd2ba..b471cdb1 100644 --- a/src/twinkle/server/model/backends/megatron_model.py +++ b/src/twinkle/server/model/backends/megatron_model.py @@ -1,32 +1,19 @@ # Copyright (c) ModelScope Contributors. All rights reserved. """ Megatron backend model for the unified model deployment. -Moved from tinker/common/megatron_model.py — imports updated. 
""" import torch from tinker import types from typing import TYPE_CHECKING, Any, List, Optional, Tuple from twinkle import remote_class, remote_function +from twinkle.model.megatron import MultiLoraMegatronModel from twinkle.server.common.datum import datum_to_input_feature, extract_rl_feature from twinkle.server.model.backends.common import TwinkleCompatModelBase, clean_metrics, collect_forward_backward_results -from twinkle.utils import exists, requires - -if TYPE_CHECKING: - from twinkle.model.megatron import MultiLoraMegatronModel as _MegatronBase -elif exists('megatron_core'): - import twinkle.model.megatron as megatron_module - _MegatronBase = megatron_module.MultiLoraMegatronModel -else: - - class _MegatronBase: - - def __init__(self, *args, **kwargs): - requires('megatron_core') @remote_class(execute='all') -class TwinkleCompatMegatronModel(_MegatronBase, TwinkleCompatModelBase): +class TwinkleCompatMegatronModel(MultiLoraMegatronModel, TwinkleCompatModelBase): """Compatibility wrapper around MultiLoraMegatronModel for Twinkle/Tinker. Moved from tinker/common/megatron_model.py — logic unchanged. 
diff --git a/src/twinkle/server/model/tinker_handlers.py b/src/twinkle/server/model/tinker_handlers.py index ccd0daed..6f458d8f 100644 --- a/src/twinkle/server/model/tinker_handlers.py +++ b/src/twinkle/server/model/tinker_handlers.py @@ -40,6 +40,7 @@ async def create_model( async def _create_adapter(): _model_id = None try: + _model_id = self.state.register_model(body.model_dump(), token=token, replica_id=self.replica_id) if body.lora_config: lora_cfg = LoraConfig(r=body.lora_config.rank, target_modules='all-linear') From 09ca8f4385a088286406f42acf3a14162ce3af0b Mon Sep 17 00:00:00 2001 From: Yunnglin Date: Fri, 13 Mar 2026 23:52:21 +0800 Subject: [PATCH 22/24] update --- src/twinkle_client/http/__init__.py | 2 - src/twinkle_client/http/heartbeat.py | 176 ----------------------- src/twinkle_client/processor/grpo.py | 48 ------- src/twinkle_client/reward/__init__.py | 11 -- src/twinkle_client/reward/math_reward.py | 56 -------- 5 files changed, 293 deletions(-) delete mode 100644 src/twinkle_client/http/heartbeat.py delete mode 100644 src/twinkle_client/processor/grpo.py delete mode 100644 src/twinkle_client/reward/__init__.py delete mode 100644 src/twinkle_client/reward/math_reward.py diff --git a/src/twinkle_client/http/__init__.py b/src/twinkle_client/http/__init__.py index 63880a7f..e36ce1e2 100644 --- a/src/twinkle_client/http/__init__.py +++ b/src/twinkle_client/http/__init__.py @@ -1,4 +1,3 @@ -from .heartbeat import heartbeat_manager from .http_utils import http_delete, http_get, http_post from .utils import (TWINKLE_SERVER_TOKEN, TWINKLE_SERVER_URL, get_api_key, get_base_url, get_request_id, get_session_id, set_api_key, set_base_url, set_request_id, set_session_id) @@ -7,7 +6,6 @@ 'http_get', 'http_post', 'http_delete', - 'heartbeat_manager', 'TWINKLE_SERVER_URL', 'TWINKLE_SERVER_TOKEN', 'set_base_url', diff --git a/src/twinkle_client/http/heartbeat.py b/src/twinkle_client/http/heartbeat.py deleted file mode 100644 index 5194d75b..00000000 --- 
a/src/twinkle_client/http/heartbeat.py +++ /dev/null @@ -1,176 +0,0 @@ -import atexit -import threading -from threading import Lock -from typing import Callable, Dict, Optional, Set - -from .http_utils import http_post -from .utils import get_base_url - - -class HeartbeatManager: - """Manages heartbeat threads for processors, models, and samplers. - - This class provides automatic heartbeat management with these features: - - Global thread for processor heartbeats (sent every 30 seconds) - - Per-adapter threads for model/sampler heartbeats (sent every 30 seconds) - - Batch processor heartbeats to reduce network load - - Automatic cleanup on object destruction - """ - - _instance = None - _lock = Lock() - - def __new__(cls): - if cls._instance is None: - with cls._lock: - if cls._instance is None: - cls._instance = super().__new__(cls) - cls._instance._initialized = False - return cls._instance - - def __init__(self): - if self._initialized: - return - - self._initialized = True - - # Processor heartbeat management - self.processor_ids: Set[str] = set() - self.processor_lock = Lock() - self.processor_thread: Optional[threading.Thread] = None - self.processor_stop_event = threading.Event() - - # Adapter heartbeat management (for models/samplers) - self.adapter_threads: Dict[str, threading.Thread] = {} - self.adapter_stop_events: Dict[str, threading.Event] = {} - self.adapter_heartbeat_funcs: Dict[str, Callable] = {} - self.adapter_lock = Lock() - - # Register cleanup on exit - atexit.register(self.shutdown_all) - - def processor_heartbeat_func(self, processor_id_list: str): - response = http_post( - url=f'{get_base_url()}/processor/twinkle/heartbeat', json_data={'processor_id': processor_id_list}) - response.raise_for_status() - - def register_processor(self, processor_id: str): - """Register a processor for heartbeat monitoring. 
- - Args: - processor_id: The processor ID to monitor - """ - with self.processor_lock: - self.processor_ids.add(processor_id) - - # Start processor heartbeat thread if not running - if self.processor_thread is None or not self.processor_thread.is_alive(): - self.processor_stop_event.clear() - self.processor_thread = threading.Thread( - target=self._processor_heartbeat_loop, daemon=True, name='ProcessorHeartbeatThread') - self.processor_thread.start() - - def unregister_processor(self, processor_id: str): - """Unregister a processor from heartbeat monitoring. - - Args: - processor_id: The processor ID to remove - """ - with self.processor_lock: - self.processor_ids.discard(processor_id) - - # Stop thread if no more processors - if not self.processor_ids and self.processor_thread: - self.processor_stop_event.set() - - def register_adapter(self, adapter_key: str, heartbeat_func: Callable): - """Register an adapter for heartbeat monitoring. - - Args: - adapter_key: Unique key for the adapter (e.g., "model:adapter_name") - heartbeat_func: Function to call for heartbeat (no arguments) - """ - with self.adapter_lock: - # Stop existing thread if any - if adapter_key in self.adapter_threads: - self.adapter_stop_events[adapter_key].set() - self.adapter_threads[adapter_key].join(timeout=1) - - # Create new thread - self.adapter_heartbeat_funcs[adapter_key] = heartbeat_func - stop_event = threading.Event() - self.adapter_stop_events[adapter_key] = stop_event - - thread = threading.Thread( - target=self._adapter_heartbeat_loop, - args=(adapter_key, stop_event), - daemon=True, - name=f'AdapterHeartbeat-{adapter_key}') - self.adapter_threads[adapter_key] = thread - thread.start() - - def unregister_adapter(self, adapter_key: str): - """Unregister an adapter from heartbeat monitoring. 
- - Args: - adapter_key: Unique key for the adapter - """ - with self.adapter_lock: - if adapter_key in self.adapter_stop_events: - self.adapter_stop_events[adapter_key].set() - - if adapter_key in self.adapter_threads: - self.adapter_threads[adapter_key].join(timeout=1) - del self.adapter_threads[adapter_key] - - self.adapter_stop_events.pop(adapter_key, None) - self.adapter_heartbeat_funcs.pop(adapter_key, None) - - def _processor_heartbeat_loop(self): - """Heartbeat loop for processors (runs every 30 seconds).""" - while not self.processor_stop_event.wait(timeout=30): - with self.processor_lock: - if not self.processor_ids or not self.processor_heartbeat_func: - continue - - # Batch send processor IDs as comma-separated string - processor_id_list = ','.join(self.processor_ids) - - try: - self.processor_heartbeat_func(processor_id_list) - except Exception as e: - print(f'Processor heartbeat error: {e}') - - def _adapter_heartbeat_loop(self, adapter_key: str, stop_event: threading.Event): - """Heartbeat loop for a specific adapter (runs every 30 seconds). 
- - Args: - adapter_key: Unique key for the adapter - stop_event: Event to signal thread shutdown - """ - while not stop_event.wait(timeout=30): - heartbeat_func = self.adapter_heartbeat_funcs.get(adapter_key) - if heartbeat_func: - try: - heartbeat_func() - except Exception as e: - print(f'Adapter heartbeat error for {adapter_key}: {e}') - - def shutdown_all(self): - """Shutdown all heartbeat threads.""" - # Stop processor thread - if self.processor_thread: - self.processor_stop_event.set() - self.processor_thread.join(timeout=1) - - # Stop all adapter threads - with self.adapter_lock: - for stop_event in self.adapter_stop_events.values(): - stop_event.set() - - for thread in self.adapter_threads.values(): - thread.join(timeout=1) - - -# Global heartbeat manager instance -heartbeat_manager = HeartbeatManager() diff --git a/src/twinkle_client/processor/grpo.py b/src/twinkle_client/processor/grpo.py deleted file mode 100644 index 10d32d20..00000000 --- a/src/twinkle_client/processor/grpo.py +++ /dev/null @@ -1,48 +0,0 @@ -from typing import Optional - -from twinkle import DeviceMesh -from twinkle.data_format import InputFeature -from twinkle_client.http import TWINKLE_SERVER_URL, heartbeat_manager, http_post -from .base import InputProcessor - - -class GRPOLossProcessor(InputProcessor): - """Client wrapper for GRPOLossProcessor that calls server HTTP endpoints.""" - - def __init__(self, device_mesh: Optional[DeviceMesh] = None, ignore_index: int = -100, **kwargs): - from twinkle_client.http import get_base_url - self.server_url = get_base_url() - - response = http_post( - url=f'{self.server_url}/processors/create', - json_data={ - 'processor_type': 'processor', - 'class_type': 'GRPOLossProcessor', - **{ - 'device_mesh': device_mesh, - 'ignore_index': ignore_index - }, - **kwargs - }) - response.raise_for_status() - self.processor_id = response.json()['processor_id'] - heartbeat_manager.register_processor(self.processor_id) - - def __del__(self): - try: - 
heartbeat_manager.unregister_processor(self.processor_id) - except: - pass - - def prepare_inputs(self, inputs: InputFeature): - response = http_post( - url=f'{self.server_url}/processors/call', - json_data={ - 'processor_id': self.processor_id, - 'function': 'prepare_inputs', - **{ - 'inputs': inputs - }, - }) - response.raise_for_status() - return response.json()['result'] diff --git a/src/twinkle_client/reward/__init__.py b/src/twinkle_client/reward/__init__.py deleted file mode 100644 index e632b263..00000000 --- a/src/twinkle_client/reward/__init__.py +++ /dev/null @@ -1,11 +0,0 @@ -# ============================================================================ -# WARNING: AUTO-GENERATED FILE - DO NOT MODIFY MANUALLY! -# ============================================================================ -# This file is automatically generated by client_tools/client_generator.py -# Any manual changes will be overwritten when the generator runs again. -# -# To update this file: -# 1. Modify the source files in src/twinkle/ -# 2. Run: python client_tools/client_generator.py -# ============================================================================ -from .math_reward import MathReward diff --git a/src/twinkle_client/reward/math_reward.py b/src/twinkle_client/reward/math_reward.py deleted file mode 100644 index f0a8e180..00000000 --- a/src/twinkle_client/reward/math_reward.py +++ /dev/null @@ -1,56 +0,0 @@ -# ============================================================================ -# WARNING: AUTO-GENERATED FILE - DO NOT MODIFY MANUALLY! -# ============================================================================ -# This file is automatically generated by client_tools/client_generator.py -# Any manual changes will be overwritten when the generator runs again. -# -# To update this file: -# 1. Modify the source files in src/twinkle/ -# 2. 
Run: python client_tools/client_generator.py -# ============================================================================ - -from typing import List - -from twinkle.data_format import Trajectory -from twinkle_client.http import TWINKLE_SERVER_URL, heartbeat_manager, http_post - - -class MathReward: - """Client wrapper for MathReward that calls server HTTP endpoints.""" - - def __init__(self, ground_truth_key: str = 'solution'): - from twinkle_client.http import get_base_url - self.server_url = get_base_url() - - response = http_post( - url=f'{self.server_url}/processors/create', - json_data={ - 'processor_type': 'reward', - 'class_type': 'MathReward', - **{ - 'ground_truth_key': ground_truth_key - } - }) - response.raise_for_status() - self.processor_id = response.json()['processor_id'] - heartbeat_manager.register_processor(self.processor_id) - - def __del__(self): - try: - heartbeat_manager.unregister_processor(self.processor_id) - except: - pass - - def __call__(self, trajectories: List[Trajectory], ground_truths: List[Trajectory]): - response = http_post( - url=f'{self.server_url}/processors/call', - json_data={ - 'processor_id': self.processor_id, - 'function': '__call__', - **{ - 'trajectories': trajectories, - 'ground_truths': ground_truths - }, - }) - response.raise_for_status() - return response.json()['result'] From 2ca4e60091fb074bf57be95664ac6a18bdd1f17f Mon Sep 17 00:00:00 2001 From: Yunnglin Date: Sat, 14 Mar 2026 00:43:16 +0800 Subject: [PATCH 23/24] update --- src/twinkle/model/megatron/multi_lora_megatron.py | 2 +- src/twinkle/model/transformers/multi_lora_transformers.py | 2 +- src/twinkle/model/transformers/transformers.py | 2 +- 3 files changed, 3 insertions(+), 3 deletions(-) diff --git a/src/twinkle/model/megatron/multi_lora_megatron.py b/src/twinkle/model/megatron/multi_lora_megatron.py index 531f5acd..9afbc056 100644 --- a/src/twinkle/model/megatron/multi_lora_megatron.py +++ b/src/twinkle/model/megatron/multi_lora_megatron.py @@ -170,7 
+170,7 @@ def set_lr_scheduler(self, scheduler_cls: Union[LRScheduler, Type[LRScheduler], self._check_adapter_valid(kwargs.get('adapter_name')) return super().set_lr_scheduler(scheduler_cls, **kwargs) - @remote_function(dispatch='all', sync=True) + @remote_function(dispatch='all', collect='first', sync=True) def save(self, name, output_dir: Optional[str] = None, interval=1, **kwargs): self._check_adapter_valid(kwargs.get('adapter_name')) optimizer_config = self.optimizer_group[kwargs.get('adapter_name')] diff --git a/src/twinkle/model/transformers/multi_lora_transformers.py b/src/twinkle/model/transformers/multi_lora_transformers.py index f4bccc53..6033d943 100644 --- a/src/twinkle/model/transformers/multi_lora_transformers.py +++ b/src/twinkle/model/transformers/multi_lora_transformers.py @@ -209,7 +209,7 @@ def get_state_dict(self, **kwargs): self._check_adapter_valid(kwargs.get('adapter_name')) return self.multi_adapter.get_state_dict(kwargs.get('adapter_name')) - @remote_function() + @remote_function(collect='first') def save(self, name, output_dir: Optional[str] = None, interval=1, **kwargs): self._check_adapter_valid(kwargs.get('adapter_name')) with self.multi_adapter.save_context(kwargs.get('adapter_name')): diff --git a/src/twinkle/model/transformers/transformers.py b/src/twinkle/model/transformers/transformers.py index 0d92bd24..7f73efe6 100644 --- a/src/twinkle/model/transformers/transformers.py +++ b/src/twinkle/model/transformers/transformers.py @@ -781,7 +781,7 @@ def apply_patch(self, patch_cls: Union[Patch, Type[Patch], str], **kwargs): def __del__(self): HubOperation.wait_for() - @remote_function() + @remote_function(collect='first') def save(self, name: Optional[str] = None, output_dir: Optional[str] = None, interval: int = 1, **kwargs): """Save model. 
From 51ebbbf8a734b43708bb51eb8bc439ff205ec862 Mon Sep 17 00:00:00 2001 From: Yunnglin Date: Sat, 14 Mar 2026 01:03:53 +0800 Subject: [PATCH 24/24] update --- .../tinker/self_host/short_math_grpo.py | 2 +- cookbook/client/twinkle/self_host/grpo.py | 13 +++---- src/twinkle/server/model/backends/common.py | 30 +++++++++++++++- .../server/model/backends/megatron_model.py | 17 +++++++-- .../model/backends/transformers_model.py | 36 +++---------------- 5 files changed, 56 insertions(+), 42 deletions(-) diff --git a/cookbook/client/tinker/self_host/short_math_grpo.py b/cookbook/client/tinker/self_host/short_math_grpo.py index 6e34f899..35b4d96d 100644 --- a/cookbook/client/tinker/self_host/short_math_grpo.py +++ b/cookbook/client/tinker/self_host/short_math_grpo.py @@ -217,7 +217,7 @@ def main(): from tinker import ServiceClient service_client = ServiceClient( base_url='http://localhost:8000', - api_key=os.environ.get('MODELSCOPE_TOKEN') + api_key='EMPTY_TOKEN' ) logger.info('Creating LoRA training client...') diff --git a/cookbook/client/twinkle/self_host/grpo.py b/cookbook/client/twinkle/self_host/grpo.py index 883b2323..8291fb91 100644 --- a/cookbook/client/twinkle/self_host/grpo.py +++ b/cookbook/client/twinkle/self_host/grpo.py @@ -103,12 +103,13 @@ def train(): model.set_loss('GRPOLoss', epsilon=0.2, beta=0.0) # Set optimizer and LR scheduler - model.set_optimizer('AdamW', lr=LEARNING_RATE) - model.set_lr_scheduler( - 'CosineWarmupScheduler', - num_warmup_steps=500, - num_training_steps=MAX_STEPS, - ) + model.set_optimizer('Adam', lr=LEARNING_RATE) + # Set LR scheduler (if server use megatron, don't support set self.lr_scheduler) + # model.set_lr_scheduler( + # 'CosineWarmupScheduler', + # num_warmup_steps=500, + # num_training_steps=MAX_STEPS, + # ) # Set processor and template for encoding inputs model.set_processor('InputProcessor') diff --git a/src/twinkle/server/model/backends/common.py b/src/twinkle/server/model/backends/common.py index e1f62e23..6b8da701 
100644 --- a/src/twinkle/server/model/backends/common.py +++ b/src/twinkle/server/model/backends/common.py @@ -5,9 +5,10 @@ import numpy as np import re import torch +from collections.abc import Mapping from numbers import Number from tinker import types -from typing import List +from typing import Any, List from twinkle import DeviceMesh from twinkle.template import Template @@ -58,6 +59,33 @@ def collect_forward_backward_results(results, device_mesh: DeviceMesh): return [all_outputs, avg_loss] +def to_cpu_safe_output(obj: Any) -> Any: + """Convert nested model outputs into CPU-safe Python objects for HTTP transport. + + Recursively walks tensors, numpy arrays, mappings and sequences, + converting each tensor/array to a plain Python scalar or list so + Ray can serialise the result without requiring CUDA on the driver. + """ + from twinkle.utils import torch_util + + if isinstance(obj, torch.Tensor): + tensor = torch_util.to_local_tensor(obj).detach().cpu() + if tensor.numel() == 1: + return tensor.item() + return tensor.tolist() + if isinstance(obj, np.ndarray): + if obj.size == 1: + return obj.item() + return obj.tolist() + if isinstance(obj, np.generic): + return obj.item() + if isinstance(obj, Mapping): + return {key: to_cpu_safe_output(value) for key, value in obj.items()} + if isinstance(obj, (list, tuple)): + return [to_cpu_safe_output(value) for value in obj] + return obj + + def clean_metrics(metrics: dict) -> dict: def _to_float(v): diff --git a/src/twinkle/server/model/backends/megatron_model.py b/src/twinkle/server/model/backends/megatron_model.py index b471cdb1..831b9468 100644 --- a/src/twinkle/server/model/backends/megatron_model.py +++ b/src/twinkle/server/model/backends/megatron_model.py @@ -4,12 +4,14 @@ """ import torch from tinker import types -from typing import TYPE_CHECKING, Any, List, Optional, Tuple +from typing import TYPE_CHECKING, Any, List, Optional, Tuple, Union from twinkle import remote_class, remote_function +from 
twinkle.data_format import InputFeature, Trajectory from twinkle.model.megatron import MultiLoraMegatronModel from twinkle.server.common.datum import datum_to_input_feature, extract_rl_feature -from twinkle.server.model.backends.common import TwinkleCompatModelBase, clean_metrics, collect_forward_backward_results +from twinkle.server.model.backends.common import (TwinkleCompatModelBase, clean_metrics, + collect_forward_backward_results, to_cpu_safe_output) @remote_class(execute='all') @@ -112,3 +114,14 @@ def tinker_load(self, checkpoint_dir: str, **kwargs): return super().load(name=resolved.checkpoint_name, output_dir=resolved.checkpoint_dir, **kwargs) else: return super().load(name=resolved.checkpoint_name, **kwargs) + + # ------------------------------------------------------------------ + # Twinkle-native methods (InputFeature/Trajectory-based I/O) + # ------------------------------------------------------------------ + + @remote_function(dispatch='slice_dp', collect='mean') + def forward_backward(self, *, inputs: Union[InputFeature, List[InputFeature], Trajectory, List[Trajectory]], + **kwargs): + """Forward+backward for twinkle-native clients (InputFeature/Trajectory I/O).""" + output = super().forward_backward(inputs=inputs, **kwargs) + return to_cpu_safe_output(output) diff --git a/src/twinkle/server/model/backends/transformers_model.py b/src/twinkle/server/model/backends/transformers_model.py index 20d6b75b..2a895e02 100644 --- a/src/twinkle/server/model/backends/transformers_model.py +++ b/src/twinkle/server/model/backends/transformers_model.py @@ -6,17 +6,15 @@ - TwinkleCompatTransformersModel: handles both tinker (Datum-based I/O) via /tinker/* endpoints and twinkle-native (InputFeature/Trajectory-based I/O) via /twinkle/* endpoints. 
""" -import numpy as np -import torch -from collections.abc import Mapping from tinker import types -from typing import Any, List, Union +from typing import List, Union from twinkle import remote_class, remote_function from twinkle.data_format import InputFeature, Trajectory from twinkle.model import MultiLoraTransformersModel from twinkle.server.common.datum import datum_to_input_feature, extract_rl_feature -from twinkle.server.model.backends.common import TwinkleCompatModelBase, clean_metrics, collect_forward_backward_results +from twinkle.server.model.backends.common import (TwinkleCompatModelBase, clean_metrics, + collect_forward_backward_results, to_cpu_safe_output) @remote_class() @@ -28,32 +26,6 @@ class TwinkleCompatTransformersModel(MultiLoraTransformersModel, TwinkleCompatMo - Twinkle-native I/O (InputFeature / Trajectory) via /twinkle/* endpoints. """ - # ------------------------------------------------------------------ - # Shared helper: CPU-safe serialisation for HTTP transport - # ------------------------------------------------------------------ - - @staticmethod - def _to_cpu_safe_output(obj: Any) -> Any: - """Convert nested outputs into CPU-safe Python objects for HTTP transport.""" - from twinkle.utils import torch_util - - if isinstance(obj, torch.Tensor): - tensor = torch_util.to_local_tensor(obj).detach().cpu() - if tensor.numel() == 1: - return tensor.item() - return tensor.tolist() - if isinstance(obj, np.ndarray): - if obj.size == 1: - return obj.item() - return obj.tolist() - if isinstance(obj, np.generic): - return obj.item() - if isinstance(obj, Mapping): - return {key: TwinkleCompatTransformersModel._to_cpu_safe_output(value) for key, value in obj.items()} - if isinstance(obj, (list, tuple)): - return [TwinkleCompatTransformersModel._to_cpu_safe_output(value) for value in obj] - return obj - # ------------------------------------------------------------------ # Tinker-compat methods (Datum-based I/O) # 
------------------------------------------------------------------ @@ -135,4 +107,4 @@ def forward_backward(self, *, inputs: Union[InputFeature, List[InputFeature], Tr **kwargs): """Forward+backward for twinkle-native clients (InputFeature/Trajectory I/O).""" output = super().forward_backward(inputs=inputs, **kwargs) - return self._to_cpu_safe_output(output) + return to_cpu_safe_output(output)