Skip to content

Commit b44c70a

Browse files
committed
fix: enforce tier-based generation limits
1 parent 8dcb50e commit b44c70a

4 files changed

Lines changed: 166 additions & 13 deletions

File tree

src/core/account_tiers.py

Lines changed: 57 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,57 @@
1+
"""Account tier and model capability helpers."""
2+
3+
from typing import Optional
4+
5+
6+
PAYGATE_TIER_NOT_PAID = "PAYGATE_TIER_NOT_PAID"
7+
PAYGATE_TIER_ONE = "PAYGATE_TIER_ONE"
8+
PAYGATE_TIER_TWO = "PAYGATE_TIER_TWO"
9+
10+
11+
def normalize_user_paygate_tier(user_paygate_tier: Optional[str]) -> str:
    """Coerce a raw account-tier value to one of the known tier constants.

    Surrounding whitespace is stripped; unknown, empty, or None values
    fall back to the free tier (PAYGATE_TIER_NOT_PAID).
    """
    candidate = (user_paygate_tier or "").strip()
    known_tiers = (PAYGATE_TIER_NOT_PAID, PAYGATE_TIER_ONE, PAYGATE_TIER_TWO)
    return candidate if candidate in known_tiers else PAYGATE_TIER_NOT_PAID
17+
18+
19+
def get_paygate_tier_rank(user_paygate_tier: Optional[str]) -> int:
    """Translate an account tier into an integer rank for comparisons.

    Ranks: PAYGATE_TIER_TWO -> 2, PAYGATE_TIER_ONE -> 1, anything else
    (including the free tier) -> 0.
    """
    ranks = {PAYGATE_TIER_TWO: 2, PAYGATE_TIER_ONE: 1}
    return ranks.get(normalize_user_paygate_tier(user_paygate_tier), 0)
27+
28+
29+
def get_paygate_tier_label(user_paygate_tier: Optional[str]) -> str:
    """Return a human-readable label for an account tier.

    Labels: PAYGATE_TIER_TWO -> "Ult", PAYGATE_TIER_ONE -> "Pro",
    anything else -> "Normal".
    """
    labels = {PAYGATE_TIER_TWO: "Ult", PAYGATE_TIER_ONE: "Pro"}
    return labels.get(normalize_user_paygate_tier(user_paygate_tier), "Normal")
37+
38+
39+
def get_required_paygate_tier_for_model(model_name: Optional[str]) -> str:
    """Infer the minimum account tier a model name requires.

    4K / "ultra" variants need tier two, 2K / 1080p variants need tier
    one, and everything else (including an empty name) is free-tier.
    """
    name = (model_name or "").strip().lower()
    if name:
        if "_ultra" in name or name.endswith(("-4k", "_4k")):
            return PAYGATE_TIER_TWO
        if name.endswith(("-2k", "_1080p")):
            return PAYGATE_TIER_ONE
    return PAYGATE_TIER_NOT_PAID
52+
53+
54+
def supports_model_for_tier(model_name: Optional[str], user_paygate_tier: Optional[str]) -> bool:
    """Return True when the account's tier meets the model's minimum tier."""
    required_rank = get_paygate_tier_rank(get_required_paygate_tier_for_model(model_name))
    return get_paygate_tier_rank(user_paygate_tier) >= required_rank

src/services/generation_handler.py

Lines changed: 97 additions & 12 deletions
Original file line numberDiff line numberDiff line change
@@ -7,6 +7,13 @@
77
from ..core.logger import debug_logger
88
from ..core.config import config
99
from ..core.models import Task, RequestLog
10+
from ..core.account_tiers import (
11+
PAYGATE_TIER_NOT_PAID,
12+
get_paygate_tier_label,
13+
get_required_paygate_tier_for_model,
14+
normalize_user_paygate_tier,
15+
supports_model_for_tier,
16+
)
1017
from .file_cache import FileCache
1118

1219

@@ -680,6 +687,24 @@ def __init__(self, flow_client, token_manager, load_balancer, db, concurrency_ma
680687
self._last_generated_url = None
681688
self._last_generation_assets = None
682689

690+
def _create_generation_result(self) -> Dict[str, Any]:
691+
"""????????????????"""
692+
return dict(success=False, error_message=None, error_emitted=False)
693+
694+
def _mark_generation_failed(self, generation_result: Optional[Dict[str, Any]], error_message: str):
695+
"""????????????????????"""
696+
if isinstance(generation_result, dict):
697+
generation_result["success"] = False
698+
generation_result["error_message"] = error_message
699+
generation_result["error_emitted"] = True
700+
701+
def _mark_generation_succeeded(self, generation_result: Optional[Dict[str, Any]]):
702+
"""???????"""
703+
if isinstance(generation_result, dict):
704+
generation_result["success"] = True
705+
generation_result["error_message"] = None
706+
generation_result["error_emitted"] = False
707+
683708
async def check_token_availability(self, is_image: bool, is_video: bool) -> bool:
684709
"""检查Token可用性
685710
@@ -721,6 +746,7 @@ async def handle_generation(
721746
"model": model,
722747
"status": "processing",
723748
}
749+
generation_result = self._create_generation_result()
724750
self._last_generated_url = None
725751
self._last_generation_assets = None
726752

@@ -818,6 +844,15 @@ async def handle_generation(
818844
# 4. 确保Project存在
819845
debug_logger.log_info(f"[GENERATION] 检查/创建Project...")
820846

847+
if not supports_model_for_tier(model, token.user_paygate_tier):
848+
required_tier = get_required_paygate_tier_for_model(model)
849+
error_msg = "当前模型需要 " + get_paygate_tier_label(required_tier) + " 账号: " + model
850+
debug_logger.log_error(f"[GENERATION] {error_msg}")
851+
if stream:
852+
yield self._create_stream_chunk(f"❌ {error_msg}\n")
853+
yield self._create_error_response(error_msg)
854+
return
855+
821856
ensure_project_started_at = time.time()
822857
project_id = await self.token_manager.ensure_project_exists(token.id)
823858
perf_trace["ensure_project_ms"] = int((time.time() - ensure_project_started_at) * 1000)
@@ -830,6 +865,7 @@ async def handle_generation(
830865
async for chunk in self._handle_image_generation(
831866
token, project_id, model_config, prompt, images, stream,
832867
perf_trace=perf_trace,
868+
generation_result=generation_result,
833869
pending_token_state=pending_token_state
834870
):
835871
yield chunk
@@ -838,12 +874,39 @@ async def handle_generation(
838874
async for chunk in self._handle_video_generation(
839875
token, project_id, model_config, prompt, images, stream,
840876
perf_trace=perf_trace,
877+
generation_result=generation_result,
841878
pending_token_state=pending_token_state
842879
):
843880
yield chunk
844881
perf_trace["generation_pipeline_ms"] = int((time.time() - generation_pipeline_started_at) * 1000)
845882

846883
# 6. 记录使用
884+
if not generation_result.get("success"):
885+
error_msg = generation_result.get("error_message") or "生成未成功完成"
886+
debug_logger.log_warning(f"[GENERATION] 生成未成功,不扣次数: {error_msg}")
887+
if token:
888+
await self.token_manager.record_error(token.id)
889+
duration = time.time() - start_time
890+
perf_trace["status"] = "failed"
891+
perf_trace["total_ms"] = int(duration * 1000)
892+
perf_trace["error"] = error_msg
893+
prompt_for_log = prompt if len(prompt) <= 2000 else f"{prompt[:2000]}...(truncated)"
894+
await self._log_request(
895+
token.id if token else None,
896+
f"generate_{generation_type if model_config else 'unknown'}",
897+
{"model": model, "prompt": prompt_for_log, "has_images": images is not None and len(images) > 0},
898+
{"error": error_msg, "performance": perf_trace},
899+
500,
900+
duration
901+
)
902+
self._last_generated_url = None
903+
self._last_generation_assets = None
904+
if not generation_result.get("error_emitted"):
905+
if stream:
906+
yield self._create_stream_chunk(f"❌ {error_msg}\n")
907+
yield self._create_error_response(error_msg)
908+
return
909+
847910
is_video = (generation_type == "video")
848911
await self.token_manager.record_usage(token.id, is_video=is_video)
849912

@@ -949,6 +1012,7 @@ async def _handle_image_generation(
9491012
images: Optional[List[bytes]],
9501013
stream: bool,
9511014
perf_trace: Optional[Dict[str, Any]] = None,
1015+
generation_result: Optional[Dict[str, Any]] = None,
9521016
pending_token_state: Optional[Dict[str, bool]] = None
9531017
) -> AsyncGenerator:
9541018
"""处理图片生成 (同步返回)"""
@@ -959,6 +1023,8 @@ async def _handle_image_generation(
9591023
image_trace["input_image_count"] = len(images) if images else 0
9601024

9611025
# 不在本地等待图片硬并发槽位;请求一到就直接向上游提交。
1026+
normalized_tier = normalize_user_paygate_tier(token.user_paygate_tier)
1027+
9621028
if image_trace is not None:
9631029
image_trace["slot_wait_ms"] = 0
9641030

@@ -1014,6 +1080,7 @@ async def _handle_image_generation(
10141080
# 提取URL和mediaId
10151081
media = result.get("media", [])
10161082
if not media:
1083+
self._mark_generation_failed(generation_result, "\u751f\u6210\u7ed3\u679c\u4e3a\u7a7a")
10171084
yield self._create_error_response("生成结果为空")
10181085
return
10191086

@@ -1042,7 +1109,7 @@ async def _handle_image_generation(
10421109
project_id=project_id,
10431110
media_id=media_id,
10441111
target_resolution=upsample_resolution,
1045-
user_paygate_tier=token.user_paygate_tier or "PAYGATE_TIER_NOT_PAID",
1112+
user_paygate_tier=normalized_tier,
10461113
session_id=generation_session_id,
10471114
token_id=token.id
10481115
)
@@ -1073,6 +1140,7 @@ async def _handle_image_generation(
10731140
local_url = f"{self._get_base_url()}/tmp/{cached_filename}"
10741141
self._last_generation_assets["upscaled_image"]["local_url"] = local_url
10751142
self._last_generation_assets["upscaled_image"]["url"] = local_url
1143+
self._mark_generation_succeeded(generation_result)
10761144
if stream:
10771145
yield self._create_stream_chunk(f"✅ {resolution_name} 图片缓存成功\n")
10781146
yield self._create_stream_chunk(
@@ -1096,6 +1164,7 @@ async def _handle_image_generation(
10961164
base64_url = f"data:image/jpeg;base64,{encoded_image}"
10971165
self._last_generation_assets["upscaled_image"]["local_url"] = None
10981166
self._last_generation_assets["upscaled_image"]["url"] = base64_url
1167+
self._mark_generation_succeeded(generation_result)
10991168
if stream:
11001169
yield self._create_stream_chunk(
11011170
f"![Generated Image]({base64_url})",
@@ -1165,6 +1234,7 @@ async def _handle_image_generation(
11651234
"origin_image_url": image_url,
11661235
"final_image_url": local_url
11671236
}
1237+
self._mark_generation_succeeded(generation_result)
11681238

11691239
if stream:
11701240
yield self._create_stream_chunk(
@@ -1189,6 +1259,7 @@ async def _handle_video_generation(
11891259
images: Optional[List[bytes]],
11901260
stream: bool,
11911261
perf_trace: Optional[Dict[str, Any]] = None,
1262+
generation_result: Optional[Dict[str, Any]] = None,
11921263
pending_token_state: Optional[Dict[str, bool]] = None
11931264
) -> AsyncGenerator:
11941265
"""处理视频生成 (异步轮询)"""
@@ -1199,6 +1270,8 @@ async def _handle_video_generation(
11991270
video_trace["input_image_count"] = len(images) if images else 0
12001271

12011272
# 不在本地等待视频硬并发槽位;请求一到就直接向上游提交。
1273+
normalized_tier = normalize_user_paygate_tier(token.user_paygate_tier)
1274+
12021275
if video_trace is not None:
12031276
video_trace["slot_wait_ms"] = 0
12041277

@@ -1211,7 +1284,7 @@ async def _handle_video_generation(
12111284

12121285
# 根据账号tier自动调整模型 key
12131286
model_key = model_config["model_key"]
1214-
user_tier = token.user_paygate_tier or "PAYGATE_TIER_ONE"
1287+
user_tier = normalized_tier
12151288

12161289
# TIER_TWO 账号需要使用 ultra 版本的模型
12171290
if user_tier == "PAYGATE_TIER_TWO":
@@ -1269,6 +1342,7 @@ async def _handle_video_generation(
12691342
error_msg = f"❌ 首尾帧模型需要 {min_images}-{max_images} 张图片,当前提供了 {image_count} 张"
12701343
if stream:
12711344
yield self._create_stream_chunk(f"{error_msg}\n")
1345+
self._mark_generation_failed(generation_result, error_msg)
12721346
yield self._create_error_response(error_msg)
12731347
return
12741348

@@ -1278,6 +1352,7 @@ async def _handle_video_generation(
12781352
error_msg = f"❌ 多图视频模型最多支持 {max_images} 张参考图,当前提供了 {image_count} 张"
12791353
if stream:
12801354
yield self._create_stream_chunk(f"{error_msg}\n")
1355+
self._mark_generation_failed(generation_result, error_msg)
12811356
yield self._create_error_response(error_msg)
12821357
return
12831358

@@ -1341,7 +1416,7 @@ async def _handle_video_generation(
13411416
aspect_ratio=model_config["aspect_ratio"],
13421417
start_media_id=start_media_id,
13431418
end_media_id=end_media_id,
1344-
user_paygate_tier=token.user_paygate_tier or "PAYGATE_TIER_ONE",
1419+
user_paygate_tier=normalized_tier,
13451420
token_id=token.id,
13461421
token_video_concurrency=token.video_concurrency,
13471422
)
@@ -1360,7 +1435,7 @@ async def _handle_video_generation(
13601435
model_key=actual_model_key,
13611436
aspect_ratio=model_config["aspect_ratio"],
13621437
start_media_id=start_media_id,
1363-
user_paygate_tier=token.user_paygate_tier or "PAYGATE_TIER_ONE",
1438+
user_paygate_tier=normalized_tier,
13641439
token_id=token.id,
13651440
token_video_concurrency=token.video_concurrency,
13661441
)
@@ -1374,7 +1449,7 @@ async def _handle_video_generation(
13741449
model_key=model_config["model_key"],
13751450
aspect_ratio=model_config["aspect_ratio"],
13761451
reference_images=reference_images,
1377-
user_paygate_tier=token.user_paygate_tier or "PAYGATE_TIER_ONE",
1452+
user_paygate_tier=normalized_tier,
13781453
token_id=token.id,
13791454
token_video_concurrency=token.video_concurrency,
13801455
)
@@ -1387,7 +1462,7 @@ async def _handle_video_generation(
13871462
prompt=prompt,
13881463
model_key=model_config["model_key"],
13891464
aspect_ratio=model_config["aspect_ratio"],
1390-
user_paygate_tier=token.user_paygate_tier or "PAYGATE_TIER_ONE",
1465+
user_paygate_tier=normalized_tier,
13911466
token_id=token.id,
13921467
token_video_concurrency=token.video_concurrency,
13931468
)
@@ -1397,6 +1472,7 @@ async def _handle_video_generation(
13971472
# 获取task_id和operations
13981473
operations = result.get("operations", [])
13991474
if not operations:
1475+
self._mark_generation_failed(generation_result, "\u751f\u6210\u4efb\u52a1\u521b\u5efa\u5931\u8d25")
14001476
yield self._create_error_response("生成任务创建失败")
14011477
return
14021478

@@ -1422,7 +1498,7 @@ async def _handle_video_generation(
14221498
# 检查是否需要放大
14231499
upsample_config = model_config.get("upsample")
14241500

1425-
async for chunk in self._poll_video_result(token, project_id, operations, stream, upsample_config):
1501+
async for chunk in self._poll_video_result(token, project_id, operations, stream, upsample_config, generation_result):
14261502
yield chunk
14271503

14281504
finally:
@@ -1434,7 +1510,8 @@ async def _poll_video_result(
14341510
project_id: str,
14351511
operations: List[Dict],
14361512
stream: bool,
1437-
upsample_config: Optional[Dict] = None
1513+
upsample_config: Optional[Dict] = None,
1514+
generation_result: Optional[Dict[str, Any]] = None
14381515
) -> AsyncGenerator:
14391516
"""轮询视频生成结果
14401517
@@ -1478,6 +1555,7 @@ async def _poll_video_result(
14781555
aspect_ratio = video_info.get("aspectRatio", "VIDEO_ASPECT_RATIO_LANDSCAPE")
14791556

14801557
if not video_url:
1558+
self._mark_generation_failed(generation_result, "\u89c6\u9891URL\u4e3a\u7a7a")
14811559
yield self._create_error_response("视频URL为空")
14821560
return
14831561

@@ -1507,7 +1585,7 @@ async def _poll_video_result(
15071585

15081586
# 递归轮询放大结果(不再放大)
15091587
async for chunk in self._poll_video_result(
1510-
token, project_id, upsample_operations, stream, None
1588+
token, project_id, upsample_operations, stream, None, generation_result
15111589
):
15121590
yield chunk
15131591
return
@@ -1557,6 +1635,8 @@ async def _poll_video_result(
15571635
}
15581636

15591637
# 返回结果
1638+
self._mark_generation_succeeded(generation_result)
1639+
15601640
if stream:
15611641
yield self._create_stream_chunk(
15621642
f"<video src='{local_url}' controls style='max-width:100%'></video>",
@@ -1586,22 +1666,27 @@ async def _poll_video_result(
15861666

15871667
# 返回友好的错误消息,提示用户重试
15881668
friendly_error = f"视频生成失败: {error_message},请重试"
1669+
self._mark_generation_failed(generation_result, friendly_error)
15891670
if stream:
15901671
yield self._create_stream_chunk(f"❌ {friendly_error}\n")
15911672
yield self._create_error_response(friendly_error)
15921673
return
15931674

15941675
elif status.startswith("MEDIA_GENERATION_STATUS_ERROR"):
1595-
# 其他错误状态
1596-
yield self._create_error_response(f"视频生成失败: {status}")
1676+
# 其他错误状态
1677+
error_msg = f"视频生成失败: {status}"
1678+
self._mark_generation_failed(generation_result, error_msg)
1679+
yield self._create_error_response(error_msg)
15971680
return
15981681

15991682
except Exception as e:
16001683
debug_logger.log_error(f"Poll error: {str(e)}")
16011684
continue
16021685

16031686
# 超时
1604-
yield self._create_error_response(f"视频生成超时 (已轮询{max_attempts}次)")
1687+
error_msg = f"视频生成超时 (已轮询{max_attempts}次)"
1688+
self._mark_generation_failed(generation_result, error_msg)
1689+
yield self._create_error_response(error_msg)
16051690

16061691
# ========== 响应格式化 ==========
16071692

0 commit comments

Comments
 (0)