From d409eda2ca08ce25d976f2324c76ccfa1bf756c7 Mon Sep 17 00:00:00 2001
From: qwes5s5 <1522419171@qq.com>
Date: Mon, 30 Mar 2026 06:42:07 +0000
Subject: [PATCH 1/2] abort requests

---
 docs/online_serving/README.md                 |   1 +
 docs/online_serving/router.md                 |   1 +
 docs/zh/online_serving/README.md              |   1 +
 docs/zh/online_serving/router.md              |   1 +
 fastdeploy/engine/common_engine.py            | 135 ++++++++++++
 .../engine/sched/resource_manager_v1.py       |   4 +
 fastdeploy/entrypoints/openai/api_server.py   |  19 ++
 fastdeploy/entrypoints/openai/serving_chat.py |   5 +
 .../entrypoints/openai/serving_completion.py  |   4 +
 fastdeploy/router/router.py                   |  46 +++-
 tests/engine/test_common_engine.py            | 202 ++++++++++++++++++
 tests/entrypoints/openai/test_api_server.py   |  77 +++++++
 12 files changed, 494 insertions(+), 2 deletions(-)

diff --git a/docs/online_serving/README.md b/docs/online_serving/README.md
index c87e9d51ec5..2b447476020 100644
--- a/docs/online_serving/README.md
+++ b/docs/online_serving/README.md
@@ -577,3 +577,4 @@ DeltaFunctionCall:
 - `/v1/pause` - Pause generation (causes denial of service). Inflight requests are aborted and cache is reset.
 - `/v1/resume` - Resume generation.
 - `/v1/is_paused` - Check if generation is paused.
+- `/v1/abort_requests` - Abort inference requests to release GPU memory (KV Cache blocks) and compute resources. Accepts `req_ids` (list of request IDs) or `abort_all=true` (abort all requests). Returns the list of aborted requests with their generated token counts.
diff --git a/docs/online_serving/router.md b/docs/online_serving/router.md
index 367f276ec69..c3405fc361f 100644
--- a/docs/online_serving/router.md
+++ b/docs/online_serving/router.md
@@ -151,6 +151,7 @@ The Router exposes a set of HTTP services to provide unified request scheduling,
 |----------|------|------|
 | POST | `/v1/chat/completions` | Provide scheduling services for inference requests based on the Chat Completions API |
 | POST | `/v1/completions` | Provide scheduling services for general text completion inference requests |
+| POST | `/v1/abort_requests` | Abort inference requests to release GPU memory and compute resources. Accepts `req_ids` or `abort_all=true`. Returns aborted requests with their generated token counts |
 | POST | `/register` | Allow inference instances to register their metadata with the Router for scheduling |
 | GET | `/registered` | Query the list of currently registered inference instances |
 | GET | `/registered_number` | Query the number of currently registered inference instances |
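Editor's note: as a quick illustration of the endpoint contract documented above, a minimal client-side sketch in Python (the host/port and request IDs are placeholders, not part of this patch):

    import requests

    # Abort two specific requests by ID; IDs without the internal "_0" suffix
    # are also matched against their suffixed form by the engine.
    resp = requests.post(
        "http://localhost:8000/v1/abort_requests",
        json={"req_ids": ["req-1", "req-999"]},
    )
    # The response body carries a "result" with "aborted" and "not_found" lists.
    print(resp.json())

    # Or abort everything currently running or waiting.
    requests.post("http://localhost:8000/v1/abort_requests", json={"abort_all": True})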
diff --git a/docs/zh/online_serving/README.md b/docs/zh/online_serving/README.md
index 5c734daeb62..21f16d06e32 100644
--- a/docs/zh/online_serving/README.md
+++ b/docs/zh/online_serving/README.md
@@ -563,3 +563,4 @@ DeltaFunctionCall:
 /v1/pause - 暂停推理生成(会导致服务拒绝推理请求)。正在进行中的请求会被中止,缓存会被重置。
 /v1/resume - 恢复推理生成。
 /v1/is_paused - 检查推理生成是否已暂停。
+/v1/abort_requests - 中断推理请求,释放 GPU 显存(KV Cache blocks)和计算资源。支持传入 `req_ids`(请求 ID 列表)或 `abort_all=true`(中断所有请求)。返回已中断请求列表及其已生成的 token 数。
diff --git a/docs/zh/online_serving/router.md b/docs/zh/online_serving/router.md
index 434f55f1c6b..6202a3337e6 100644
--- a/docs/zh/online_serving/router.md
+++ b/docs/zh/online_serving/router.md
@@ -153,6 +153,7 @@ Router 通过 HTTP 接口对外提供统一的调度服务,同时支持运行
 |----------|------|------|
 | POST | `/v1/chat/completions` | 对外提供基于 Chat 接口的推理请求调度服务 |
 | POST | `/v1/completions` | 对外提供通用文本补全请求的调度服务 |
+| POST | `/v1/abort_requests` | 中断推理请求,释放 GPU 显存和计算资源。支持传入 `req_ids` 或 `abort_all=true`,返回已中断请求列表及其已生成的 token 数 |
 | POST | `/register` | 推理实例向 Router 注册自身信息,用于参与调度 |
 | GET | `/registered` | 查询当前已注册的推理实例列表 |
 | GET | `/registered_number` | 查询当前已注册的推理实例数量 |
diff --git a/fastdeploy/engine/common_engine.py b/fastdeploy/engine/common_engine.py
index 1c6452408ba..9b8717f3763 100644
--- a/fastdeploy/engine/common_engine.py
+++ b/fastdeploy/engine/common_engine.py
@@ -43,9 +43,11 @@
 from fastdeploy.cache_manager.cache_data import CacheStatus
 from fastdeploy.config import FDConfig
 from fastdeploy.engine.request import (
+    CompletionOutput,
     ControlRequest,
     ControlResponse,
     Request,
+    RequestMetrics,
     RequestOutput,
     RequestStatus,
     RequestType,
@@ -1413,6 +1415,139 @@ def _control_update_weights(self, control_request: ControlRequest) -> Optional[d
             raise Exception(error_msg)
         return self._call_worker(control_request, 60)
 
+    def _control_abort_requests(self, control_req: ControlRequest):
+        if not envs.ENABLE_V1_KVCACHE_SCHEDULER:
+            raise Exception("abort_requests only supported in ENABLE_V1_KVCACHE_SCHEDULER")
+        args = control_req.get_args()
+        abort_all = args.get("abort_all", False)
+        req_ids = args.get("req_ids", [])
+        matched_input_ids = set()
+        now_reqs = list(set(self.resource_manager.requests.keys()) | set(self.scheduler.requests.keys()))
+
+        # Step 1: Determine target request list
+        if abort_all:
+            # all requests in running + waiting
+            target_req_ids = now_reqs
+        else:
+            # keep only the requests that actually exist
+            target_req_ids = []
+            for rid in req_ids:
+                if rid in now_reqs:
+                    target_req_ids.append(rid)
+                    matched_input_ids.add(rid)
+                elif f"{rid}_0" in now_reqs:
+                    target_req_ids.append(f"{rid}_0")
+                    matched_input_ids.add(rid)
+
+        if not target_req_ids:
+            return {"aborted": [], "not_found": req_ids if not abort_all else []}
+
+        # Step 2: Collect partial results
+        aborted_info = []
+        results = []
+        for req_id in target_req_ids:
+            request = self.resource_manager.requests.get(req_id)
+            if request is None:
+                scheduled_req = self.scheduler.requests.get(req_id)
+                if scheduled_req is None:
+                    continue
+                request = scheduled_req.raw
+
+            partial_token_ids = list(request.output_token_ids)
+
+            # Construct finished response with partial results
+            now = time.time()
+            abort_metrics = RequestMetrics(
+                arrival_time=request.metrics.arrival_time if request.metrics else now,
+                inference_start_time=request.metrics.inference_start_time if request.metrics else now,
+                engine_recv_latest_token_time=now,
+                engine_recv_first_token_time=request.metrics.engine_recv_first_token_time if request.metrics else now,
+                request_start_time=request.metrics.arrival_time if request.metrics else now,
+            )
+            result = RequestOutput(
+                request_id=req_id,
+                finished=True,
+                outputs=CompletionOutput(
+                    index=0,
+                    send_idx=len(partial_token_ids),
+                    token_ids=[self.data_processor.eos_token_ids[0]],
+                ),
+                metrics=abort_metrics,
+                error_code=200,
+                error_msg="Aborted",
+            )
+            results.append(result)
+            aborted_info.append(
+                {
+                    "request_id": req_id,
+                    "output_token_count": len(partial_token_ids),
+                }
+            )
+
+        # Step 3: Execute abort: add all requests to waiting_abort_req_id_set
+        if envs.ENABLE_V1_KVCACHE_SCHEDULER:
+            for req_id in target_req_ids:
+                self.resource_manager.add_abort_req_ids(req_id)
+                time.sleep(0.0001)
+            if self.cfg.scheduler_config.splitwise_role != "prefill":
+                self._wait_abort_complete(target_req_ids)
+
+        # Add results to the scheduler; the engine has a thread calling get_results,
+        # which then cleans up and calls send_response to reply to the client.
+        # If the client has disconnected, send_response simply drops the results.
+        if self.cfg.scheduler_config.splitwise_role != "prefill":
+            try:
+                # self.send_response_server.send_response(req_id, [result])
+                self.scheduler.put_results(results)
+            except Exception:
+                pass  # client may have disconnected
+
+        not_found = [rid for rid in req_ids if rid not in matched_input_ids] if not abort_all else []
+
+        return {"aborted": aborted_info, "not_found": not_found}
+
+    def _wait_abort_complete(self, target_req_ids, stall_timeout=1):
+        """
+        Wait for all abort requests to complete.
+        - Keep monitoring as long as remaining is not empty, which means cleanup is not done yet
+        - If there is no progress within stall_timeout seconds, force-clean requests stuck in
+          to_be_aborted_req_id_set, reset the progress state, then continue monitoring
+        """
+        target_set = set(target_req_ids)
+        prev_remaining_count = len(target_set)
+        last_progress_time = time.time()
+        remaining = target_set & self.resource_manager.get_reqs_in_aborting()
+        while remaining:
+            remaining = target_set & self.resource_manager.get_reqs_in_aborting()
+            if not remaining:
+                self.llm_logger.info(f"all {len(target_set)} abort reqs cleaned")
+                return
+
+            current_count = len(remaining)
+            if current_count < prev_remaining_count:
+                # progress made: recycle_abort_task was called
+                self.llm_logger.info(f"abort progress: {prev_remaining_count} -> {current_count}")
+                last_progress_time = time.time()
+                prev_remaining_count = current_count
+
+            if time.time() - last_progress_time > stall_timeout:
+                # no-progress timeout: only clean up requests stuck in to_be_aborted (worker hasn't returned -9)
+                stuck = remaining & self.resource_manager.to_be_aborted_req_id_set
+                if stuck:
+                    self.llm_logger.warning(
+                        f"no abort progress for {stall_timeout}s, "
+                        f"force cleanup {len(stuck)} stuck requests (in to_be_aborted)"
+                    )
+                    for req_id in list(stuck):
+                        self.llm_logger.warning(f"force cleanup stuck req_id:{req_id}")
+                        self.resource_manager.recycle_abort_task(req_id)
+                    # reset progress state
+                    last_progress_time = time.time()
+                    prev_remaining_count = current_count - len(stuck)
+                # else: remaining are all in waiting_abort_req_id_set, waiting for the natural flow
+
+            time.sleep(0.005)
+
     def _parse_tags(self, control_request: ControlRequest):
         """
         Parse tags from control request.
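Editor's note: the stall watchdog in `_wait_abort_complete` above is easier to see in isolation. Below is a standalone sketch of the same pattern; `ResourceView` and its names are illustrative stand-ins, not FastDeploy APIs:

    import time

    class ResourceView:
        """Illustrative stand-in for the resource manager's two abort sets."""

        def __init__(self, waiting, to_be_aborted):
            self.waiting = set(waiting)                # worker not yet notified
            self.to_be_aborted = set(to_be_aborted)    # worker notified, cleanup pending

        def in_aborting(self):
            return self.waiting | self.to_be_aborted

        def force_recycle(self, req_id):
            self.to_be_aborted.discard(req_id)

    def wait_abort_complete(view, targets, stall_timeout=1.0):
        targets = set(targets)
        prev = len(targets)
        last_progress = time.time()
        while targets & view.in_aborting():
            remaining = targets & view.in_aborting()
            if len(remaining) < prev:
                # progress: reset the stall clock
                last_progress = time.time()
                prev = len(remaining)
            if time.time() - last_progress > stall_timeout:
                # only force-clean the "notified but stuck" half; entries still in
                # the waiting set are left to the normal scheduling flow
                for req_id in list(remaining & view.to_be_aborted):
                    view.force_recycle(req_id)
                last_progress = time.time()
                prev = len(targets & view.in_aborting())
            time.sleep(0.005)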
diff --git a/fastdeploy/engine/sched/resource_manager_v1.py b/fastdeploy/engine/sched/resource_manager_v1.py
index 6cc9363bf6b..27322473adf 100644
--- a/fastdeploy/engine/sched/resource_manager_v1.py
+++ b/fastdeploy/engine/sched/resource_manager_v1.py
@@ -279,6 +279,7 @@ def recycle_abort_task(self, request_id):
         del self.requests[request_id]
         del self.req_dict[request_id]
         self.to_be_aborted_req_id_set.remove(request_id)
+        self.update_metrics()
 
     def _trigger_abort(self, request_id, scheduled_reqs):
         if request_id in self.requests:
@@ -1120,6 +1121,9 @@ def download_bos_features(bos_client, features_urls):
                 return None
             inputs["audio_features"] = result
 
+    def get_reqs_in_aborting(self):
+        return self.waiting_abort_req_id_set | self.to_be_aborted_req_id_set
+
     def get_available_position(self) -> int:
         position = 0
         while position < self.max_num_seqs:
diff --git a/fastdeploy/entrypoints/openai/api_server.py b/fastdeploy/entrypoints/openai/api_server.py
index 2a608350db5..b5ae9287a02 100644
--- a/fastdeploy/entrypoints/openai/api_server.py
+++ b/fastdeploy/entrypoints/openai/api_server.py
@@ -478,6 +478,25 @@ async def update_weights(request: Request) -> Response:
     return control_response.to_api_json_response()
 
 
+@app.post("/v1/abort_requests")
+async def abort_requests(request: Request):
+    body = await request.json()
+    abort_all = body.get("abort_all", False)
+    req_ids = body.get("req_ids", None)
+
+    # Parameter validation
+    if not abort_all and not req_ids:
+        return JSONResponse(status_code=400, content={"error": "must provide abort_all=true or req_ids"})
+
+    control_request = ControlRequest(
+        request_id=f"control-{uuid.uuid4()}",
+        method="abort_requests",
+        args={"abort_all": abort_all, "req_ids": req_ids or []},
+    )
+    control_response = await app.state.engine_client.run_control_method(control_request)
+    return control_response.to_api_json_response()
+
+
 def wrap_streaming_generator(original_generator: AsyncGenerator):
     """
     Wrap an async generator to release the connection semaphore when the generator is finished.
diff --git a/fastdeploy/entrypoints/openai/serving_chat.py b/fastdeploy/entrypoints/openai/serving_chat.py
index f0fb528972c..a6699f1f2f9 100644
--- a/fastdeploy/entrypoints/openai/serving_chat.py
+++ b/fastdeploy/entrypoints/openai/serving_chat.py
@@ -463,6 +463,9 @@ async def chat_completion_stream_generator(
                     if res.get("error_msg") is not None and "Recover" in res["error_msg"]:
                         choice.finish_reason = "recover_stop"
 
+                    if res.get("error_msg") is not None and "Aborted" in res["error_msg"]:
+                        choice.finish_reason = "abort"
+
                     inference_start_time[idx] = 0
 
                     if request.collect_metrics:
@@ -795,6 +798,8 @@ async def _create_chat_completion_choice(
         if data.get("error_msg", None) is not None and "Recover" in data["error_msg"]:
             finish_reason = "recover_stop"
 
+        if data.get("error_msg", None) is not None and "Aborted" in data["error_msg"]:
+            finish_reason = "abort"
         return ChatCompletionResponseChoice(
             index=idx,
             message=message,
diff --git a/fastdeploy/entrypoints/openai/serving_completion.py b/fastdeploy/entrypoints/openai/serving_completion.py
index aea39c01e7b..7bd04f4ecab 100644
--- a/fastdeploy/entrypoints/openai/serving_completion.py
+++ b/fastdeploy/entrypoints/openai/serving_completion.py
@@ -582,6 +582,8 @@ async def completion_stream_generator(
                             output,
                             tool_called[idx],
                         )
+                        if res.get("error_msg") is not None and "Aborted" in res["error_msg"]:
+                            choices[-1].finish_reason = "abort"
 
                         inference_start_time[idx] = 0
                         send_idx = output.get("send_idx")
@@ -724,6 +726,8 @@ def request_output_to_completion_response(
                 output,
                 False,
             )
+            if final_res.get("error_msg", None) is not None and "Aborted" in final_res["error_msg"]:
+                finish_reason = "abort"
 
             choice_data = CompletionResponseChoice(
                 token_ids=token_ids,
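Editor's note: the same error_msg-to-finish_reason mapping now appears at four call sites across serving_chat.py and serving_completion.py. Distilled into one hypothetical helper for clarity (illustrative only, not part of this patch):

    from typing import Optional

    def map_finish_reason(error_msg: Optional[str], current: Optional[str]) -> Optional[str]:
        """Mirror of the inline checks: 'Recover' maps to recover_stop, 'Aborted' to abort."""
        if error_msg is not None and "Recover" in error_msg:
            return "recover_stop"
        if error_msg is not None and "Aborted" in error_msg:
            return "abort"
        return current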
diff --git a/fastdeploy/router/router.py b/fastdeploy/router/router.py
index f73144dc535..6f830ebe157 100644
--- a/fastdeploy/router/router.py
+++ b/fastdeploy/router/router.py
@@ -16,8 +16,8 @@
 
 import aiohttp
 import uvicorn
-from fastapi import FastAPI, HTTPException
-from fastapi.responses import ORJSONResponse, Response, StreamingResponse
+from fastapi import FastAPI, HTTPException, Request
+from fastapi.responses import JSONResponse, ORJSONResponse, Response, StreamingResponse
 
 from fastdeploy.router.utils import (
     InstanceInfo,
@@ -485,6 +485,48 @@ async def health_generate():
     return Response(status_code=200)
 
 
+@app.post("/v1/abort_requests")
+async def abort_requests(request: Request):
+    body = await request.json()
+    prefill_servers = app.state.router.prefill_servers
+    decode_servers = app.state.router.decode_servers
+    all_servers = prefill_servers + decode_servers
+
+    async with aiohttp.ClientSession() as session:
+        tasks = [session.post(f"{server.url()}/v1/abort_requests", json=body) for server in all_servers]
+        responses = await asyncio.gather(*tasks, return_exceptions=True)
+
+        # Aggregate results from decode (Node D) instances only
+        all_aborted = []
+        all_not_found = []
+        errors = []
+        decode_start = len(prefill_servers)
+        for i, (server, resp) in enumerate(zip(all_servers, responses)):
+            if i < decode_start:
+                continue
+            if isinstance(resp, Exception):
+                errors.append({"server": server.url(), "error": str(resp)})
+            elif resp.status == 200:
+                data = await resp.json()
+                result = data.get("result") or {}
+                all_aborted.extend(result.get("aborted", []))
+                all_not_found.extend(result.get("not_found", []))
+            else:
+                errors.append({"server": server.url(), "status": resp.status})
+
+    return JSONResponse(
+        content={
+            "request_id": f"router-{uuid4()}",
+            "status": "success" if not errors else "error",
+            "error_message": None if not errors else str(errors),
+            "result": {
+                "aborted": all_aborted,
+                "not_found": list(set(all_not_found)),
+            },
+        }
+    )
+
+
 def launch_router(router_args: RouterArgs):
     app.state.router_args = router_args
     print(f"Starting router with args: {router_args}")
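Editor's note: for reference, a successful aggregated router response then has roughly this shape (all values below are illustrative):

    example_router_response = {
        "request_id": "router-<uuid4>",
        "status": "success",  # "error" if any decode instance failed
        "error_message": None,
        "result": {
            "aborted": [{"request_id": "req-1_0", "output_token_count": 42}],
            "not_found": ["req-999"],
        },
    }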
diff --git a/tests/engine/test_common_engine.py b/tests/engine/test_common_engine.py
index f60898fb40f..0dffb8e3cd0 100644
--- a/tests/engine/test_common_engine.py
+++ b/tests/engine/test_common_engine.py
@@ -23,6 +23,7 @@
 
 from fastdeploy.engine.args_utils import EngineArgs
 from fastdeploy.engine.common_engine import EngineService
+from fastdeploy.engine.request import ControlRequest
 
 MODEL_NAME = os.getenv("MODEL_PATH", "/path/to/models") + "/ERNIE-4.5-0.3B-Paddle"
 
@@ -872,3 +873,204 @@ def test_get_scheduler_unhandled_request_num(self):
         )()
         self.assertEqual(eng._get_scheduler_unhandled_request_num(), 0)
         eng.llm_logger.debug.assert_called()
+
+    # ── _control_abort_requests / _wait_abort_complete ───────────────
+
+    def _make_abort_engine(self, splitwise_role="mixed"):
+        """Create an engine wired up for abort tests (no real init needed)."""
+        eng = EngineService.__new__(EngineService)
+        eng.cfg = MagicMock()
+        eng.cfg.scheduler_config.splitwise_role = splitwise_role
+        eng.llm_logger = MagicMock()
+
+        # data_processor with eos token
+        eng.data_processor = MagicMock()
+        eng.data_processor.eos_token_ids = [2]
+
+        # resource_manager with requests dict and abort sets
+        eng.resource_manager = MagicMock()
+        eng.resource_manager.requests = {}
+        eng.resource_manager.waiting_abort_req_id_set = set()
+        eng.resource_manager.to_be_aborted_req_id_set = set()
+        eng.resource_manager.get_reqs_in_aborting = lambda: (
+            eng.resource_manager.waiting_abort_req_id_set | eng.resource_manager.to_be_aborted_req_id_set
+        )
+
+        # scheduler with requests dict and put_results
+        eng.scheduler = MagicMock()
+        eng.scheduler.requests = {}
+        eng.scheduler.put_results = MagicMock()
+
+        return eng
+
+    def _make_fake_request(self, output_token_ids=None):
+        """Create a fake request object for abort tests."""
+        req = MagicMock()
+        req.output_token_ids = output_token_ids or [10, 20, 30]
+        req.metrics = MagicMock()
+        req.metrics.arrival_time = 1000.0
+        req.metrics.inference_start_time = 1000.1
+        req.metrics.engine_recv_first_token_time = 1000.2
+        return req
+
+    def test_control_abort_requests_not_v1_raises(self):
+        """abort_requests raises when ENABLE_V1_KVCACHE_SCHEDULER is off."""
+        eng = self._make_abort_engine()
+        control_req = ControlRequest("ctrl-1", "abort_requests", {"abort_all": True, "req_ids": []})
+        with patch("fastdeploy.engine.common_engine.envs.ENABLE_V1_KVCACHE_SCHEDULER", 0):
+            with self.assertRaises(Exception) as ctx:
+                eng._control_abort_requests(control_req)
+        self.assertIn("only supported", str(ctx.exception))
+
+    def test_control_abort_requests_abort_all(self):
+        """abort_all=True aborts all requests in resource_manager + scheduler."""
+        eng = self._make_abort_engine()
+        eng.resource_manager.requests = {"req-1_0": self._make_fake_request([10, 20])}
+        eng.scheduler.requests = {"req-2_0": MagicMock(raw=self._make_fake_request([30]))}
+
+        control_req = ControlRequest("ctrl-1", "abort_requests", {"abort_all": True, "req_ids": []})
+
+        def clear_abort_sets(req_id):
+            # Simulate immediate abort completion
+            eng.resource_manager.waiting_abort_req_id_set.discard(req_id)
+
+        eng.resource_manager.add_abort_req_ids = MagicMock(side_effect=clear_abort_sets)
+
+        with patch("fastdeploy.engine.common_engine.envs.ENABLE_V1_KVCACHE_SCHEDULER", 1):
+            result = eng._control_abort_requests(control_req)
+
+        self.assertEqual(len(result["aborted"]), 2)
+        self.assertEqual(result["not_found"], [])
+        ids = {a["request_id"] for a in result["aborted"]}
+        self.assertEqual(ids, {"req-1_0", "req-2_0"})
+        # put_results should have been called (not prefill)
+        eng.scheduler.put_results.assert_called_once()
+
+    def test_control_abort_requests_by_req_ids_with_suffix_match(self):
+        """req_ids match both exact and _0 suffix."""
+        eng = self._make_abort_engine()
+        eng.resource_manager.requests = {
+            "req-A_0": self._make_fake_request([1, 2, 3]),
+            "req-B": self._make_fake_request([4, 5]),
+        }
+
+        control_req = ControlRequest(
+            "ctrl-1",
+            "abort_requests",
+            {
+                "abort_all": False,
+                "req_ids": ["req-A", "req-B", "req-C"],
+            },
+        )
+
+        def clear_abort_sets(req_id):
+            eng.resource_manager.waiting_abort_req_id_set.discard(req_id)
+
+        eng.resource_manager.add_abort_req_ids = MagicMock(side_effect=clear_abort_sets)
+
+        with patch("fastdeploy.engine.common_engine.envs.ENABLE_V1_KVCACHE_SCHEDULER", 1):
+            result = eng._control_abort_requests(control_req)
+
+        aborted_ids = {a["request_id"] for a in result["aborted"]}
+        self.assertIn("req-A_0", aborted_ids)  # matched via _0 suffix
+        self.assertIn("req-B", aborted_ids)  # exact match
+        self.assertEqual(result["not_found"], ["req-C"])
+
+    def test_control_abort_requests_no_match(self):
+        """No requests found returns empty aborted and all in not_found."""
+        eng = self._make_abort_engine()
+        control_req = ControlRequest(
+            "ctrl-1",
+            "abort_requests",
+            {
+                "abort_all": False,
+                "req_ids": ["nonexistent"],
+            },
+        )
+
+        with patch("fastdeploy.engine.common_engine.envs.ENABLE_V1_KVCACHE_SCHEDULER", 1):
+            result = eng._control_abort_requests(control_req)
+
+        self.assertEqual(result["aborted"], [])
+        self.assertEqual(result["not_found"], ["nonexistent"])
+
+    def test_control_abort_requests_prefill_skips_wait_and_put(self):
+        """Prefill role skips _wait_abort_complete and put_results."""
+        eng = self._make_abort_engine(splitwise_role="prefill")
+        eng.resource_manager.requests = {"req-1_0": self._make_fake_request()}
+
+        control_req = ControlRequest("ctrl-1", "abort_requests", {"abort_all": True, "req_ids": []})
+        eng.resource_manager.add_abort_req_ids = MagicMock()
+
+        with patch("fastdeploy.engine.common_engine.envs.ENABLE_V1_KVCACHE_SCHEDULER", 1):
+            result = eng._control_abort_requests(control_req)
+
+        self.assertEqual(len(result["aborted"]), 1)
+        eng.scheduler.put_results.assert_not_called()
+
+    def test_control_abort_requests_output_token_count(self):
+        """output_token_count reflects partial_token_ids length."""
+        eng = self._make_abort_engine()
+        eng.resource_manager.requests = {"req-1_0": self._make_fake_request([10, 20, 30, 40, 50])}
+
+        control_req = ControlRequest("ctrl-1", "abort_requests", {"abort_all": True, "req_ids": []})
+
+        def clear_abort_sets(req_id):
+            eng.resource_manager.waiting_abort_req_id_set.discard(req_id)
+
+        eng.resource_manager.add_abort_req_ids = MagicMock(side_effect=clear_abort_sets)
+
+        with patch("fastdeploy.engine.common_engine.envs.ENABLE_V1_KVCACHE_SCHEDULER", 1):
+            result = eng._control_abort_requests(control_req)
+
+        self.assertEqual(result["aborted"][0]["output_token_count"], 5)
+
+    def test_wait_abort_complete_immediate(self):
+        """_wait_abort_complete returns immediately when all requests already cleaned."""
+        eng = self._make_abort_engine()
+        # Empty abort sets → remaining is empty → returns immediately
+        eng._wait_abort_complete(["req-1_0"])
+
+    def test_wait_abort_complete_progress(self):
+        """_wait_abort_complete exits when background thread cleans up."""
+        eng = self._make_abort_engine()
+        eng.resource_manager.waiting_abort_req_id_set = {"req-1_0"}
+
+        call_count = [0]
+
+        def fake_sleep(s):
+            call_count[0] += 1
+            # Simulate background thread cleaning up after first sleep
+            eng.resource_manager.waiting_abort_req_id_set.discard("req-1_0")
+
+        with patch("fastdeploy.engine.common_engine.time.sleep", fake_sleep):
+            eng._wait_abort_complete(["req-1_0"])
+
+        self.assertGreaterEqual(call_count[0], 1)
+
+    def test_wait_abort_complete_force_cleanup_stuck_in_to_be_aborted(self):
+        """Stall timeout triggers force cleanup for requests in to_be_aborted_req_id_set."""
+        eng = self._make_abort_engine()
+        eng.resource_manager.to_be_aborted_req_id_set = {"req-1_0"}
+
+        def mock_recycle(req_id):
+            eng.resource_manager.to_be_aborted_req_id_set.discard(req_id)
+
+        eng.resource_manager.recycle_abort_task = MagicMock(side_effect=mock_recycle)
+
+        # Make time.time() advance past stall_timeout
+        time_values = [100.0, 100.0, 102.0, 102.0, 102.0]
+        time_idx = [0]
+
+        def fake_time():
+            idx = min(time_idx[0], len(time_values) - 1)
+            time_idx[0] += 1
+            return time_values[idx]
+
+        with (
+            patch("fastdeploy.engine.common_engine.time.time", fake_time),
+            patch("fastdeploy.engine.common_engine.time.sleep", lambda s: None),
+        ):
+            eng._wait_abort_complete(["req-1_0"], stall_timeout=1)
+
+        eng.resource_manager.recycle_abort_task.assert_called_with("req-1_0")
diff --git a/tests/entrypoints/openai/test_api_server.py b/tests/entrypoints/openai/test_api_server.py
index 0cd57421701..8136dd3035c 100644
--- a/tests/entrypoints/openai/test_api_server.py
+++ b/tests/entrypoints/openai/test_api_server.py
@@ -809,3 +809,80 @@ def test_config_info():
     api_server = _reload_api_server(args)
     api_server.llm_engine = None
     assert api_server.config_info().status_code == 500
+
+
+# ── /v1/abort_requests ──────────────────────────────────────────────
+
+
+def _mock_abort_control_response(api_server, result, status_code=200):
+    mock_resp = MagicMock()
+    mock_resp.to_api_json_response.return_value = api_server.JSONResponse(
+        content={"request_id": "control-test", "status": "success", "error_message": None, "result": result},
+        status_code=status_code,
+    )
+    api_server.app.state.engine_client = MagicMock()
+    api_server.app.state.engine_client.run_control_method = AsyncMock(return_value=mock_resp)
+
+
+@pytest.mark.asyncio
+async def test_abort_requests_with_req_ids():
+    args = _build_args()
+    api_server = _reload_api_server(args)
+    _mock_abort_control_response(
+        api_server,
+        {
+            "aborted": [{"request_id": "req-1_0", "output_token_count": 10}],
+            "not_found": ["req-999"],
+        },
+    )
+    req = MagicMock()
+    req.json = AsyncMock(return_value={"req_ids": ["req-1", "req-999"]})
+    resp = await api_server.abort_requests(req)
+    assert resp.status_code == 200
+    control_req = api_server.app.state.engine_client.run_control_method.await_args.args[0]
+    assert control_req.method == "abort_requests"
+    assert control_req.args["req_ids"] == ["req-1", "req-999"]
+    assert control_req.args["abort_all"] is False
+
+
+@pytest.mark.asyncio
+async def test_abort_requests_with_abort_all():
+    args = _build_args()
+    api_server = _reload_api_server(args)
+    _mock_abort_control_response(
+        api_server,
+        {
+            "aborted": [
+                {"request_id": "req-1_0", "output_token_count": 5},
+                {"request_id": "req-2_0", "output_token_count": 12},
+            ],
+            "not_found": [],
+        },
+    )
+    req = MagicMock()
+    req.json = AsyncMock(return_value={"abort_all": True})
+    resp = await api_server.abort_requests(req)
+    assert resp.status_code == 200
+    control_req = api_server.app.state.engine_client.run_control_method.await_args.args[0]
+    assert control_req.args["abort_all"] is True
+    assert control_req.args["req_ids"] == []
+
+
+@pytest.mark.asyncio
+async def test_abort_requests_missing_params():
+    args = _build_args()
+    api_server = _reload_api_server(args)
+    req = MagicMock()
+    req.json = AsyncMock(return_value={})
+    resp = await api_server.abort_requests(req)
+    assert resp.status_code == 400
+
+
+@pytest.mark.asyncio
+async def test_abort_requests_empty_req_ids():
+    args = _build_args()
+    api_server = _reload_api_server(args)
+    req = MagicMock()
+    req.json = AsyncMock(return_value={"req_ids": []})
+    resp = await api_server.abort_requests(req)
+    assert resp.status_code == 400

From 2ef0db387951fa696d69febff30fc78ded17f122 Mon Sep 17 00:00:00 2001
From: qwes5s5 <1522419171@qq.com>
Date: Tue, 31 Mar 2026 16:44:28 +0000
Subject: [PATCH 2/2] add finish_reason

---
 fastdeploy/entrypoints/openai/protocol.py | 8 ++++----
 1 file changed, 4 insertions(+), 4 deletions(-)

diff --git a/fastdeploy/entrypoints/openai/protocol.py b/fastdeploy/entrypoints/openai/protocol.py
index deaa45a420d..d90145376f9 100644
--- a/fastdeploy/entrypoints/openai/protocol.py
+++ b/fastdeploy/entrypoints/openai/protocol.py
@@ -268,7 +268,7 @@ class ChatCompletionResponseChoice(BaseModel):
     logprobs: Optional[LogProbs] = None
     draft_logprobs: Optional[LogProbs] = None
     prompt_logprobs: Optional[PromptLogprobs] = None
-    finish_reason: Optional[Literal["stop", "length", "tool_calls", "recover_stop"]]
+    finish_reason: Optional[Literal["stop", "length", "tool_calls", "recover_stop", "abort"]]
     speculate_metrics: Optional[SpeculateMetrics] = None
 
 
@@ -333,7 +333,7 @@ class ChatCompletionResponseStreamChoice(BaseModel):
     logprobs: Optional[LogProbs] = None
     draft_logprobs: Optional[LogProbs] = None
     prompt_logprobs: Optional[PromptLogprobs] = None
-    finish_reason: Optional[Literal["stop", "length", "tool_calls", "recover_stop"]] = None
+    finish_reason: Optional[Literal["stop", "length", "tool_calls", "recover_stop", "abort"]] = None
     arrival_time: Optional[float] = None
     speculate_metrics: Optional[SpeculateMetrics] = None
 
@@ -369,7 +369,7 @@ class CompletionResponseChoice(BaseModel):
     draft_logprobs: Optional[CompletionLogprobs] = None
     prompt_logprobs: Optional[PromptLogprobs] = None
     reasoning_content: Optional[str] = None
-    finish_reason: Optional[Literal["stop", "length", "tool_calls", "recover_stop"]] = None
+    finish_reason: Optional[Literal["stop", "length", "tool_calls", "recover_stop", "abort"]] = None
     tool_calls: Optional[List[DeltaToolCall | ToolCall]] = None
     speculate_metrics: Optional[SpeculateMetrics] = None
 
@@ -415,7 +415,7 @@ class CompletionResponseStreamChoice(BaseModel):
     prompt_tokens: Optional[str] = None
     completion_tokens: Optional[str] = None
    reasoning_content: Optional[str] = None
-    finish_reason: Optional[Literal["stop", "length", "tool_calls", "recover_stop"]] = None
+    finish_reason: Optional[Literal["stop", "length", "tool_calls", "recover_stop", "abort"]] = None
     tool_calls: Optional[List[DeltaToolCall | ToolCall]] = None
     speculate_metrics: Optional[SpeculateMetrics] = None
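Editor's note: with the new Literal value in place, a streaming client can distinguish an abort from a normal stop. A minimal sketch using the OpenAI-compatible SDK (base_url, api_key, and model name are placeholders):

    from openai import OpenAI

    client = OpenAI(base_url="http://localhost:8000/v1", api_key="EMPTY")
    stream = client.chat.completions.create(
        model="default",
        messages=[{"role": "user", "content": "Write a long story."}],
        stream=True,
    )
    for chunk in stream:
        if chunk.choices and chunk.choices[0].finish_reason == "abort":
            # The request was cancelled via /v1/abort_requests; everything
            # received so far is the complete partial output.
            break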