51 lines
2.0 KiB
Python
51 lines
2.0 KiB
Python
import asyncio
|
|
from typing import List, Iterable, AsyncGenerator, Optional
|
|
from sentence_transformers import SentenceTransformer
|
|
|
|
|
|
class JEmbeddingService:
    """Async wrapper around a ``SentenceTransformer`` embedding model.

    The model is loaded eagerly when the service is constructed. Calls into
    the model are dispatched through the running event loop's default
    executor so that they do not block the loop.
    """

    def __init__(self, model_name: str = "Qwen/Qwen3-Embedding-0.6B"):
        """Store *model_name* and load the corresponding model immediately."""
        self.model_name = model_name
        # Populated by _load_model(); declared first so the attribute exists
        # even if loading raises.
        self.model: Optional[SentenceTransformer] = None
        self._load_model()

    def _load_model(self) -> None:
        """Instantiate the ``SentenceTransformer`` named by ``self.model_name``."""
        self.model = SentenceTransformer(self.model_name)

    async def embed(self, texts: List[str]) -> List[List[float]]:
        """Encode *texts* and return one embedding per input string.

        Args:
            texts: A ``list`` whose elements are all ``str``.

        Returns:
            One entry per input text. Entries exposing ``tolist()`` are
            converted with it; anything else is passed through unchanged.

        Raises:
            ValueError: If *texts* is not a list or contains a non-string.
        """
        if not isinstance(texts, list) or any(not isinstance(t, str) for t in texts):
            raise ValueError("texts必须是字符串列表")
        if not texts:
            # Nothing to encode: skip the executor round-trip entirely.
            return []

        loop = asyncio.get_running_loop()
        embeddings = await loop.run_in_executor(None, self.model.encode, texts)
        return [vec.tolist() if hasattr(vec, 'tolist') else vec for vec in embeddings]

    async def similarity(self, embeddings_a: List[List[float]], embeddings_b: List[List[float]]):
        """Return ``self.model.similarity(embeddings_a, embeddings_b)``.

        The call runs in the default executor; its result is returned as-is.
        """
        loop = asyncio.get_running_loop()
        return await loop.run_in_executor(None, self.model.similarity, embeddings_a, embeddings_b)

    async def process_batch(self, items: Iterable[str]) -> AsyncGenerator[dict, None]:
        """Validate *items*, embed the valid ones, and stream per-item results.

        Every yielded dict carries ``"index"``: the 0-based position of the
        item in *items*, for error and success results alike, so a result can
        always be mapped back to its input.

        Yields:
            * ``{"index", "status": "error", "message"}`` for each non-string
              item, as soon as it is encountered;
            * then, for the string items, either
              ``{"index", "status": "success", "embedding"}`` each, or, if
              the embedding call fails, an ``"error"`` dict each carrying the
              exception text.
        """
        texts: List[str] = []
        indices: List[int] = []

        for idx, text in enumerate(items):
            if isinstance(text, str):
                texts.append(text)
                indices.append(idx)
            else:
                yield {"index": idx, "status": "error", "message": "每个元素必须是字符串"}
            # Give other tasks a chance to run between items.
            await asyncio.sleep(0)

        if not texts:
            return

        # Only the embedding call is guarded: an exception thrown into the
        # generator while it is suspended at a success yield must not be
        # reported as an embedding failure for the whole batch.
        try:
            vectors = await self.embed(texts)
        except Exception as e:
            message = str(e)
            for i in indices:
                yield {"index": i, "status": "error", "message": message}
            return

        for i, vec in zip(indices, vectors):
            yield {"index": i, "status": "success", "embedding": vec}
|
|
|
|
|