Skip to content
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
61 changes: 55 additions & 6 deletions src/proxy_app/main.py
Original file line number Diff line number Diff line change
Expand Up @@ -724,7 +724,8 @@ async def streaming_response_wrapper(
Wraps a streaming response to log the full response after completion
and ensures any errors during the stream are sent to the client.
"""
response_chunks = []
should_log_stream = logger is not None
response_chunks = [] if should_log_stream else None
full_response = {}

try:
Expand All @@ -733,14 +734,16 @@ async def streaming_response_wrapper(
logging.warning("Client disconnected, stopping stream.")
break
yield chunk_str
if not should_log_stream:
continue
if chunk_str.strip() and chunk_str.startswith("data:"):
content = chunk_str[len("data:") :].strip()
if content != "[DONE]":
try:
chunk_data = json.loads(content)
response_chunks.append(chunk_data)
if logger:
logger.log_stream_chunk(chunk_data)
if response_chunks is not None:
response_chunks.append(chunk_data)
logger.log_stream_chunk(chunk_data)
except json.JSONDecodeError:
pass
except Exception as e:
Expand All @@ -762,7 +765,7 @@ async def streaming_response_wrapper(
)
return # Stop further processing
finally:
if response_chunks:
if should_log_stream and response_chunks:
# --- Aggregation Logic ---
final_message = {"role": "assistant"}
aggregated_tool_calls = {}
Expand Down Expand Up @@ -878,14 +881,55 @@ async def streaming_response_wrapper(
"usage": usage_data,
}

if logger:
if should_log_stream:
logger.log_final_response(
status_code=200,
headers=None, # Headers are not available at this stage
body=full_response,
)


def _inject_iflow_metadata_from_incoming_headers(
request_data: dict[str, Any],
request_headers: dict[str, str],
) -> None:
"""Propagate iFlow-specific routing headers into request metadata."""
model_name = str(request_data.get("model", ""))
provider_name = model_name.split("/", 1)[0].lower() if "/" in model_name else ""
if provider_name != "iflow":
return

metadata = request_data.get("metadata")
if not isinstance(metadata, dict):
metadata = {}

def _pick(*header_names: str) -> str:
for key in header_names:
value = request_headers.get(key)
if value:
return value
return ""

mappings = [
("session_id", ("session-id", "x-litellm-session-id")),
("conversation_id", ("conversation-id", "x-litellm-conversation-id")),
("traceparent", ("traceparent",)),
("iflow_x_biz_info", ("x-biz-info",)),
("iflow_eagleeye_userdata", ("eagleeye-userdata",)),
("iflow_priority", ("priority",)),
]

for metadata_key, header_keys in mappings:
if metadata.get(metadata_key):
continue
picked = _pick(*header_keys)
if picked:
metadata[metadata_key] = picked

if metadata:
request_data["metadata"] = metadata


@app.post("/v1/chat/completions")
async def chat_completions(
request: Request,
Expand Down Expand Up @@ -933,6 +977,11 @@ async def chat_completions(
if raw_logger:
raw_logger.log_request(headers=request.headers, body=request_data)

_inject_iflow_metadata_from_incoming_headers(
request_data=request_data,
request_headers=dict(request.headers),
)

# Extract and log specific reasoning parameters for monitoring.
model = request_data.get("model")
generation_cfg = (
Expand Down
3 changes: 2 additions & 1 deletion src/proxy_app/provider_urls.py
Original file line number Diff line number Diff line change
Expand Up @@ -30,6 +30,7 @@
"cohere": "https://api.cohere.ai/v1",
"bedrock": "https://bedrock-runtime.us-east-1.amazonaws.com",
"openrouter": "https://openrouter.ai/api/v1",
"iflow": "https://apis.iflow.cn/v1",
}

def get_provider_endpoint(provider: str, model_name: str, incoming_path: str) -> Optional[str]:
Expand Down Expand Up @@ -73,4 +74,4 @@ def get_provider_endpoint(provider: str, model_name: str, incoming_path: str) ->
return f"{base_url}/{action}"

# Fallback for other cases
return f"{base_url}/v1/{action}"
return f"{base_url}/v1/{action}"
8 changes: 7 additions & 1 deletion src/rotator_library/error_handler.py
Original file line number Diff line number Diff line change
Expand Up @@ -935,7 +935,13 @@ def classify_error(e: Exception, provider: Optional[str] = None) -> ClassifiedEr
)

if isinstance(
e, (httpx.TimeoutException, httpx.ConnectError, httpx.NetworkError)
e,
(
httpx.TimeoutException,
httpx.ConnectError,
httpx.NetworkError,
httpx.RemoteProtocolError, # peer closed connection without complete message
),
): # [NEW]
return ClassifiedError(
error_type="api_connection", original_exception=e, status_code=status_code
Expand Down
100 changes: 88 additions & 12 deletions src/rotator_library/providers/iflow_auth_base.py
Original file line number Diff line number Diff line change
Expand Up @@ -42,6 +42,8 @@

# Cookie-based authentication endpoint
IFLOW_API_KEY_ENDPOINT = "https://platform.iflow.cn/api/openapi/apikey"
IFLOW_DEFAULT_API_BASE = "https://apis.iflow.cn/v1"
IFLOW_CLI_USER_AGENT = "iFlow-Cli"

# Client credentials provided by iFlow
IFLOW_CLIENT_ID = "10009311001"
Expand Down Expand Up @@ -331,6 +333,32 @@ def __init__(self):
self._refresh_interval_seconds: int = 30 # Delay between queue items
self._refresh_max_retries: int = 3 # Attempts before kicked out

def get_api_base_candidates(self) -> List[str]:
    """Build the ordered, de-duplicated list of iFlow API base URLs.

    Precedence:
      1. ``IFLOW_API_BASE`` — single override, highest priority.
      2. ``IFLOW_API_BASES`` — comma-separated list of fallbacks.
      3. The built-in default base, always present as the final entry.
    """
    ordered: List[str] = []

    def _add(raw: str) -> None:
        # Normalize (trim whitespace and trailing slash) before de-duping.
        base = raw.strip().rstrip("/")
        if base and base not in ordered:
            ordered.append(base)

    _add(os.environ.get("IFLOW_API_BASE", ""))
    for part in os.environ.get("IFLOW_API_BASES", "").split(","):
        _add(part)
    _add(IFLOW_DEFAULT_API_BASE)

    return ordered

def get_api_base(self) -> str:
    """Return the highest-priority iFlow API base URL."""
    candidates = self.get_api_base_candidates()
    return candidates[0]

def _parse_env_credential_path(self, path: str) -> Optional[str]:
"""
Parse a virtual env:// path and return the credential index.
Expand Down Expand Up @@ -611,11 +639,21 @@ async def _fetch_user_info(self, access_token: str) -> Dict[str, Any]:
if not access_token or not access_token.strip():
raise ValueError("Access token is empty")

url = f"{IFLOW_USER_INFO_ENDPOINT}?accessToken={access_token}"
headers = {"Accept": "application/json"}
headers = {
"Accept": "*/*",
"accessToken": access_token,
"User-Agent": "node",
"Accept-Language": "*",
"Sec-Fetch-Mode": "cors",
"Accept-Encoding": "br, gzip, deflate",
}

async with httpx.AsyncClient(timeout=30.0) as client:
response = await client.get(url, headers=headers)
response = await client.get(
IFLOW_USER_INFO_ENDPOINT,
headers=headers,
params={"accessToken": access_token},
)
response.raise_for_status()
result = response.json()

Expand Down Expand Up @@ -965,12 +1003,32 @@ async def _refresh_token(self, path: str, force: bool = False) -> Dict[str, Any]
if not force and cached_creds and not self._is_token_expired(cached_creds):
return cached_creds

# [ROTATING TOKEN FIX] Always read fresh from disk before refresh.
# iFlow may use rotating refresh tokens - each refresh could invalidate the previous token.
# If we use a stale cached token, refresh will fail.
# Reading fresh from disk ensures we have the latest token.
await self._read_creds_from_file(path)
creds_from_file = self._credentials_cache[path]
# For file-based credentials, read fresh from disk before refresh.
# For env-loaded credentials, refresh using cached/env values (no file IO).
creds_from_file: Optional[Dict[str, Any]] = None
if cached_creds and cached_creds.get("_proxy_metadata", {}).get(
"loaded_from_env"
):
creds_from_file = cached_creds
elif self._parse_env_credential_path(path) is not None:
credential_index = self._parse_env_credential_path(path)
env_creds = self._load_from_env(credential_index)
if env_creds:
self._credentials_cache[path] = env_creds
creds_from_file = env_creds
else:
raise ValueError(
f"No environment credentials found for iFlow path: {path}"
)
else:
# [ROTATING TOKEN FIX] Always read fresh from disk before refresh.
# iFlow may use rotating refresh tokens - each refresh could invalidate
# the previous token. Fresh disk read keeps us in sync.
await self._read_creds_from_file(path)
creds_from_file = self._credentials_cache[path]

if creds_from_file is None:
raise ValueError(f"No credentials available for iFlow refresh: {path}")

lib_logger.debug(f"Refreshing iFlow OAuth token for '{Path(path).name}'...")
refresh_token = creds_from_file.get("refresh_token")
Expand Down Expand Up @@ -1215,7 +1273,8 @@ async def get_api_details(self, credential_identifier: str) -> Tuple[str, str]:
- API Key: credential_identifier is the API key string itself
"""
# Detect credential type
if os.path.isfile(credential_identifier):
credential_index = self._parse_env_credential_path(credential_identifier)
if credential_index is not None or os.path.isfile(credential_identifier):
creds = await self._load_credentials(credential_identifier)

# Check if this is a cookie-based credential
Expand Down Expand Up @@ -1243,13 +1302,30 @@ async def get_api_details(self, credential_identifier: str) -> Tuple[str, str]:

api_key = creds.get("api_key")
if not api_key:
raise ValueError("Missing api_key in iFlow OAuth credentials")
access_token = creds.get("access_token", "")
if access_token:
try:
user_info = await self._fetch_user_info(access_token)
api_key = user_info.get("api_key", "")
if api_key:
creds["api_key"] = api_key
if credential_index is None:
await self._save_credentials(
credential_identifier,
creds,
)
except Exception as e:
lib_logger.warning(
f"Failed to recover iFlow api_key from OAuth user info: {e}"
)
if not api_key:
raise ValueError("Missing api_key in iFlow OAuth credentials")
else:
# Direct API key: use as-is
lib_logger.debug("Using direct API key for iFlow")
api_key = credential_identifier

base_url = "https://apis.iflow.cn/v1"
base_url = self.get_api_base()
return base_url, api_key

async def proactively_refresh(self, credential_identifier: str):
Expand Down
Loading