@router.post(settings.chat_route + "/stream")
@dynamic_billing_wrapper
async def chat_stream_api(data: dict, request: Request):
    """Streaming text-chat endpoint that relays model output as SSE.

    Args:
        data: Request body with the following fields:
            - messages: non-empty list of messages, each carrying ``role``
              and ``content`` (required).
            - model: model to use (optional).
            - temperature: sampling temperature (optional).
            - top_p: nucleus-sampling parameter (optional).
            - max_tokens: maximum number of generated tokens (optional).
            Omitted optional fields leave the service's current values in
            effect (the original docs cite 0.7 / 0.9 / 2048 as defaults;
            those are defined outside this view -- TODO confirm).
        request: The FastAPI request object (not read by this handler).

    Returns:
        A ``StreamingResponse`` with media type ``text/event-stream`` whose
        body is produced by ``chat_stream``.

    Raises:
        HTTPException: 400 when ``messages`` is missing or is not a
            non-empty list; 500 when preparing the response fails.
    """
    import copy  # local import: only this handler needs it

    if "messages" not in data:
        raise HTTPException(status_code=400, detail="缺少messages参数")

    messages = data["messages"]
    if not isinstance(messages, list) or not messages:
        # Reject malformed input here, while a proper status code can still
        # be returned; once streaming starts the status line is already sent.
        raise HTTPException(status_code=400, detail="messages参数必须是非空列表")

    try:
        # Apply overrides to a per-request shallow copy instead of the shared
        # `service` object. chat_stream is an async generator, so its body
        # runs only after this handler has returned; mutating the shared
        # instance would let concurrent requests overwrite each other's
        # settings and would leak overrides into later requests.
        request_service = copy.copy(service)
        for option in ("model", "temperature", "top_p", "max_tokens"):
            if option in data:
                setattr(request_service, option, data[option])

        return StreamingResponse(
            request_service.chat_stream(messages),
            media_type="text/event-stream",
            headers={
                # Keep intermediaries from caching or buffering the stream.
                "Cache-Control": "no-cache",
                "X-Accel-Buffering": "no",
            },
        )
    except Exception as e:
        # Only covers setup; errors raised while streaming happen after this
        # handler has returned and cannot be turned into an HTTP 500 here.
        raise HTTPException(status_code=500, detail=str(e)) from e
async def chat_stream(self, messages: List[Dict]) -> AsyncIterator[bytes]:
    """Stream a chat completion from the upstream API as SSE frames.

    The upstream response is read line by line; every ``data:`` line is
    re-emitted as its own ``data: <payload>`` frame followed by a blank
    line. The stream always ends with a ``data: [DONE]`` frame.

    Failures (setup errors, transport errors, or an upstream HTTP status
    of 400 or above) are reported in-band as one frame carrying a JSON
    object with ``status`` and ``message`` keys, because the HTTP status
    line has already been sent by the time this generator runs.

    Args:
        messages: List of messages, each carrying ``role`` and ``content``.

    Yields:
        SSE-formatted byte chunks.
    """

    def _frame(data: str) -> bytes:
        # One SSE event: a single data field terminated by a blank line.
        return f"data: {data}\n\n".encode()

    def _error_frame(detail: str) -> bytes:
        # Same error shape and message prefix as the non-streaming path.
        body = {"status": "error", "message": f"聊天请求失败: {detail}"}
        return _frame(json.dumps(body, ensure_ascii=False))

    try:
        model_config = self._get_model_config(self.model or default_model)
        model_type = model_config["type"]

        payload = self._prepare_payload(messages, model_type, model_config["model"])
        payload["stream"] = True

        api_config = self._get_api_config(model_type)
        headers = {
            "Content-Type": "application/json",
            "Authorization": f"Bearer {api_config['key']}",
        }

        # Long overall timeout for slow generations, short connect timeout.
        timeout = httpx.Timeout(300.0, connect=10.0)
        async with httpx.AsyncClient(timeout=timeout) as client:
            async with client.stream(
                "POST", api_config["url"], headers=headers, json=payload
            ) as response:
                if response.status_code >= 400:
                    # Error bodies are not "data:" lines, so without this
                    # check the client would receive an empty stream.
                    raw = await response.aread()
                    detail = raw.decode("utf-8", errors="replace")[:500]
                    yield _error_frame(f"HTTP {response.status_code}: {detail}")
                else:
                    async for line in response.aiter_lines():
                        # The space after the colon is optional in SSE.
                        if not line.startswith("data:"):
                            continue
                        data = line[5:].strip()
                        if not data:
                            continue
                        if data == "[DONE]":
                            break
                        yield _frame(data)
    except Exception as e:
        # Boundary of the streaming response: anything escaping from here
        # would just cut the connection, so report it to the client instead.
        # Cancellation and generator close are BaseException and pass through.
        yield _error_frame(str(e))

    # Always terminate explicitly, even if the upstream never sent [DONE].
    yield b"data: [DONE]\n\n"