162 changes: 160 additions & 2 deletions app/api/v1/admin.py
@@ -462,6 +462,121 @@ def _normalize_admin_token_item(pool_name: str, item: Any) -> dict | None:
}


TOKEN_PAGE_DEFAULT = 30
TOKEN_PAGE_ALLOWED = {30, 50, 200}
TOKEN_PAGE_ALL_LIMIT = 10000


def _parse_token_page(page: Any) -> int:
try:
n = int(page)
except Exception:
n = 1
return max(1, n)


def _parse_token_per_page(per_page: Any) -> tuple[int, bool]:
v = str(per_page if per_page is not None else "").strip().lower()
if v in ("all", "全部"):
return TOKEN_PAGE_ALL_LIMIT, True
try:
n = int(v or TOKEN_PAGE_DEFAULT)
except Exception:
return TOKEN_PAGE_DEFAULT, False
if n not in TOKEN_PAGE_ALLOWED:
return TOKEN_PAGE_DEFAULT, False
return n, False
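
Note: the per-page parser accepts only the fixed set {30, 50, 200} plus `all` (or its Chinese UI alias `全部`); anything else falls back silently to the default. A few spot checks of the two helpers above, as a minimal sketch:

```python
# Illustrative spot checks for the paging helpers defined in this hunk.
assert _parse_token_page("3") == 3
assert _parse_token_page("abc") == 1                  # unparsable -> page 1
assert _parse_token_page(-5) == 1                     # clamped to >= 1
assert _parse_token_per_page("all") == (10000, True)  # TOKEN_PAGE_ALL_LIMIT, all-mode
assert _parse_token_per_page("50") == (50, False)
assert _parse_token_per_page("37") == (30, False)     # not in TOKEN_PAGE_ALLOWED -> default
assert _parse_token_per_page(None) == (30, False)     # empty string -> default
```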


def _is_token_invalid(item: dict) -> bool:
return str(item.get("status") or "").strip().lower() in ("invalid", "expired", "disabled")


def _is_token_exhausted(item: dict) -> bool:
status = str(item.get("status") or "").strip().lower()
if status == "cooling":
return True
try:
quota_known = bool(item.get("quota_known"))
quota = int(item.get("quota"))
except Exception:
quota_known = False
quota = -1
if quota_known and quota <= 0:
return True

token_type = str(item.get("token_type") or "sso")
try:
heavy_known = bool(item.get("heavy_quota_known"))
heavy_quota = int(item.get("heavy_quota"))
except Exception:
heavy_known = False
heavy_quota = -1
if token_type == "ssoSuper" and heavy_known and heavy_quota <= 0:
return True
return False


def _is_token_active(item: dict) -> bool:
return (not _is_token_invalid(item)) and (not _is_token_exhausted(item))


def _match_token_status(item: dict, status: str) -> bool:
s = str(status or "").strip().lower()
if not s:
return True
if s in ("invalid", "失效"):
return _is_token_invalid(item)
if s in ("active", "正常"):
return _is_token_active(item)
if s in ("exhausted", "额度耗尽", "limited", "限流中"):
return _is_token_exhausted(item)
if s in ("cooling", "冷却中"):
return str(item.get("status") or "").strip().lower() == "cooling"
if s in ("unused", "未使用"):
try:
quota = int(item.get("quota"))
except Exception:
quota = -2
return quota == -1
return True
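
The status helpers form three mutually exclusive buckets: invalid, exhausted (which covers `cooling` plus zero remaining quota, and zero heavy quota for `ssoSuper`), and active. The `unused` filter is a separate view keyed on the `quota == -1` sentinel. An illustrative classification, with invented item values matching the normalized shape used above:

```python
# Illustrative items only; field names follow _normalize_admin_token_item's output.
cooling = {"status": "cooling", "quota_known": True, "quota": 5, "token_type": "sso"}
assert _is_token_exhausted(cooling)            # "cooling" counts as exhausted
assert _match_token_status(cooling, "冷却中")   # Chinese UI aliases hit the same branch
assert not _is_token_active(cooling)

fresh = {"status": "ok", "quota_known": False, "quota": -1, "token_type": "sso"}
assert _is_token_active(fresh)
assert _match_token_status(fresh, "unused")    # quota == -1 is the never-used sentinel
```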


def _match_token_nsfw(item: dict, nsfw: str) -> bool:
v = str(nsfw or "").strip().lower()
if not v:
return True
note = str(item.get("note") or "").lower()
has_nsfw = "nsfw" in note
if v in ("1", "true", "yes", "on", "enabled"):
return has_nsfw
if v in ("0", "false", "no", "off", "disabled"):
return not has_nsfw
return True


def _filter_admin_tokens(items: list[dict], *, token_type: str, status: str, nsfw: str, search: str) -> list[dict]:
token_type_norm = str(token_type or "all").strip()
search_norm = str(search or "").strip().lower()

out: list[dict] = []
for item in items:
cur_type = str(item.get("token_type") or "sso")
if token_type_norm in ("sso", "ssoSuper") and cur_type != token_type_norm:
continue
if not _match_token_status(item, status):
continue
if not _match_token_nsfw(item, nsfw):
continue
if search_norm:
token = str(item.get("token") or "").lower()
note = str(item.get("note") or "").lower()
if search_norm not in token and search_norm not in note:
continue
out.append(item)
return out
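
For reference, this mirrors how the endpoint below drives the filter; all four constraints are conjunctive, and empty strings mean "no constraint" (parameter values here are invented):

```python
# Hypothetical call, mirroring the endpoint below.
filtered = _filter_admin_tokens(
    normalized_items,
    token_type="ssoSuper",  # anything other than "sso"/"ssoSuper" disables the type filter
    status="active",        # neither invalid nor exhausted
    nsfw="1",               # keep only items whose note mentions "nsfw"
    search="team-a",        # case-insensitive substring over token and note
)
```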


def _collect_tokens_from_pool_payload(payload: Any) -> list[str]:
if not isinstance(payload, dict):
return []
@@ -675,21 +790,64 @@ async def get_storage_info():
return {"type": storage_type or "local"}

@router.get("/api/v1/admin/tokens", dependencies=[Depends(verify_api_key)])
async def get_tokens_api():
async def get_tokens_api(
page: int = Query(default=1),
per_page: str = Query(default="30"),
token_type: str = Query(default="all"),
status: str = Query(default=""),
nsfw: str = Query(default=""),
search: str = Query(default=""),
):
"""获取所有 Token"""
storage = get_storage()
tokens = await storage.load_tokens()
data = tokens if isinstance(tokens, dict) else {}
out: dict[str, list[dict]] = {}
normalized_items: list[dict] = []
for pool_name, raw_items in data.items():
arr = raw_items if isinstance(raw_items, list) else []
normalized: list[dict] = []
for item in arr:
obj = _normalize_admin_token_item(pool_name, item)
if obj:
normalized.append(obj)
normalized_items.append({**obj, "pool": str(pool_name)})
out[str(pool_name)] = normalized
return out

current_page = _parse_token_page(page)
page_size, is_all = _parse_token_per_page(per_page)
filtered = _filter_admin_tokens(
normalized_items,
token_type=token_type,
status=status,
nsfw=nsfw,
search=search,
)

total = len(filtered)
pages = max(1, (total + page_size - 1) // page_size)
if current_page > pages:
current_page = pages
start = (current_page - 1) * page_size
end = start + page_size
page_items = filtered[start:end]

page_pools: dict[str, list[dict]] = {"ssoBasic": [], "ssoSuper": []}
for item in page_items:
pool = str(item.get("pool") or "ssoBasic")
obj = dict(item)
obj.pop("pool", None)
page_pools.setdefault(pool, []).append(obj)

return {
"items": page_items,
"total": total,
"page": current_page,
"per_page": "all" if is_all else page_size,
"pages": pages,
"ssoBasic": page_pools.get("ssoBasic", []),
"ssoSuper": page_pools.get("ssoSuper", []),
}
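
Putting the paging pieces together: a request like `GET /api/v1/admin/tokens?page=2&per_page=50&status=active` now returns a payload of roughly the shape below. The numbers are invented; with 73 matches and `per_page=50`, `pages = ceil(73/50) = 2`, so page 2 carries the remaining 23 items.

```python
# Illustrative response body (values invented). "pool" stays on the flat "items"
# entries but is stripped from the per-pool lists.
{
    "items": [{"pool": "ssoBasic", "token": "sso-...", "status": "ok"}],  # 23 entries
    "total": 73,
    "page": 2,        # requested pages beyond the last are clamped to `pages`
    "per_page": 50,   # echoed as the literal string "all" in all-mode
    "pages": 2,
    "ssoBasic": [{"token": "sso-...", "status": "ok"}],
    "ssoSuper": [],
}
```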

@router.post("/api/v1/admin/tokens", dependencies=[Depends(verify_api_key)])
async def update_tokens_api(data: dict):
66 changes: 60 additions & 6 deletions app/api/v1/image.py
@@ -36,6 +36,7 @@
from app.services.grok.processor import ImageCollectProcessor, ImageStreamProcessor
from app.services.quota import enforce_daily_quota
from app.services.request_stats import request_stats
from app.services.token_usage import build_image_usage
from app.services.token import get_token_manager


@@ -510,6 +511,23 @@ async def _record_request(model_id: str, success: bool):
pass


async def _record_request_with_usage(model_id: str, success: bool, prompt: str, success_count: int = 1):
try:
usage = build_image_usage(prompt, success_count=success_count)
raw = usage.get("_raw") or {}
await request_stats.record_request(
model_id,
success=success,
total_tokens=int(usage.get("total_tokens", 0) or 0),
input_tokens=int(usage.get("input_tokens", 0) or 0),
output_tokens=int(usage.get("output_tokens", 0) or 0),
reasoning_tokens=int(raw.get("reasoning_tokens", 0) or 0),
cached_tokens=int(raw.get("cached_tokens", 0) or 0),
)
except Exception:
pass
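
`build_image_usage` itself is outside this diff (it lives in `app/services/token_usage`); the wrapper above only relies on its return shape. A hedged sketch of that assumed contract — `estimate_prompt_tokens` and the per-image output cost are invented placeholders, not the real implementation:

```python
# Assumed contract only -- the real build_image_usage is in app/services/token_usage.
def build_image_usage(prompt: str, success_count: int = 1) -> dict:
    input_tokens = estimate_prompt_tokens(prompt)  # hypothetical helper
    output_tokens = 256 * success_count            # invented per-image cost
    return {
        "input_tokens": input_tokens,
        "output_tokens": output_tokens,
        "total_tokens": input_tokens + output_tokens,
        "_raw": {"reasoning_tokens": 0, "cached_tokens": 0},
    }
```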


async def _get_token_for_model(model_id: str):
"""获取指定模型可用 token,失败时抛出统一异常"""
try:
@@ -659,7 +677,12 @@ async def _wrapped_experimental_stream():
consume_on_fail=True,
is_usage=True,
)
await _record_request(model_info.model_id, True)
await _record_request_with_usage(
model_info.model_id,
True,
f"Image Generation: {request.prompt}",
success_count=n,
)
else:
await _record_request(model_info.model_id, False)
except Exception:
@@ -707,7 +730,12 @@ async def _wrapped_stream():
consume_on_fail=True,
is_usage=True,
)
await _record_request(model_info.model_id, True)
await _record_request_with_usage(
model_info.model_id,
True,
f"Image Generation: {request.prompt}",
success_count=n,
)
else:
await _record_request(model_info.model_id, False)
except Exception:
@@ -766,7 +794,15 @@ async def _wrapped_stream():
consume_on_fail=True,
is_usage=True,
)
await _record_request(model_info.model_id, bool(success))
if success:
await _record_request_with_usage(
model_info.model_id,
True,
f"Image Generation: {request.prompt}",
success_count=n,
)
else:
await _record_request(model_info.model_id, False)
except Exception:
pass

@@ -919,7 +955,12 @@ async def _wrapped_experimental_stream():
consume_on_fail=True,
is_usage=True,
)
await _record_request(model_info.model_id, True)
await _record_request_with_usage(
model_info.model_id,
True,
f"Image Edit: {edit_request.prompt}",
success_count=n,
)
else:
await _record_request(model_info.model_id, False)
except Exception:
@@ -970,7 +1011,12 @@ async def _wrapped_stream():
consume_on_fail=True,
is_usage=True,
)
await _record_request(model_info.model_id, True)
await _record_request_with_usage(
model_info.model_id,
True,
f"Image Edit: {edit_request.prompt}",
success_count=n,
)
else:
await _record_request(model_info.model_id, False)
except Exception:
@@ -1055,7 +1101,15 @@ async def _wrapped_stream():
consume_on_fail=True,
is_usage=True,
)
await _record_request(model_info.model_id, bool(success))
if success:
await _record_request_with_usage(
model_info.model_id,
True,
f"Image Edit: {edit_request.prompt}",
success_count=n,
)
else:
await _record_request(model_info.model_id, False)
except Exception:
pass

33 changes: 29 additions & 4 deletions app/services/grok/chat.py
@@ -532,7 +532,9 @@ async def completions(

# Process the response
if is_stream:
processor = StreamProcessor(model_name, token, think).process(response)
stream_processor = StreamProcessor(model_name, token, think)
processor = stream_processor.process(response)
prompt_messages = [msg.model_dump() for msg in messages]

async def _wrapped_stream():
completed = False
@@ -544,19 +546,42 @@ async def _wrapped_stream():
# Only count as "success" when the stream ends naturally.
try:
if completed:
usage = stream_processor.build_usage(prompt_messages)
raw = usage.get("_raw") or {}
await token_mgr.sync_usage(token, model_name, consume_on_fail=True, is_usage=True)
await request_stats.record_request(model_name, success=True)
await request_stats.record_request(
model_name,
success=True,
total_tokens=int(usage.get("total_tokens", 0) or 0),
input_tokens=int(usage.get("prompt_tokens", 0) or 0),
output_tokens=int(usage.get("completion_tokens", 0) or 0),
reasoning_tokens=int(raw.get("reasoning_tokens", 0) or 0),
cached_tokens=int(raw.get("cached_tokens", 0) or 0),
)
else:
await request_stats.record_request(model_name, success=False)
except Exception:
pass

return _wrapped_stream()

result = await CollectProcessor(model_name, token).process(response)
result = await CollectProcessor(model_name, token).process(
response,
prompt_messages=[msg.model_dump() for msg in messages],
)
try:
usage = result.get("usage") or {}
raw = usage.get("_raw") or {}
await token_mgr.sync_usage(token, model_name, consume_on_fail=True, is_usage=True)
await request_stats.record_request(model_name, success=True)
await request_stats.record_request(
model_name,
success=True,
total_tokens=int(usage.get("total_tokens", 0) or 0),
input_tokens=int(usage.get("prompt_tokens", 0) or 0),
output_tokens=int(usage.get("completion_tokens", 0) or 0),
reasoning_tokens=int(raw.get("reasoning_tokens", 0) or 0),
cached_tokens=int(raw.get("cached_tokens", 0) or 0),
)
except Exception:
pass
return result
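
One naming wrinkle worth flagging: the chat path reads OpenAI-style `prompt_tokens`/`completion_tokens` out of `usage`, while the image path above reads `input_tokens`/`output_tokens` from `build_image_usage`; both funnel into the same `input_tokens`/`output_tokens` parameters of `request_stats.record_request`. A sketch of the shared sink call, with invented numbers:

```python
# Both code paths reduce to this one stats call (numbers invented).
await request_stats.record_request(
    model_name,
    success=True,
    total_tokens=1234,
    input_tokens=1000,     # chat: usage["prompt_tokens"]; image: usage["input_tokens"]
    output_tokens=234,     # chat: usage["completion_tokens"]; image: usage["output_tokens"]
    reasoning_tokens=180,  # from usage["_raw"]
    cached_tokens=0,
)
```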