japi/apps/jembedding/service.py

import asyncio
from typing import AsyncGenerator, Iterable, List

import torch
import torch.nn.functional as F
from transformers import AutoModelForCausalLM, AutoTokenizer


class JEmbeddingService:
    """Async text-embedding service backed by a Hugging Face model."""

    def __init__(self, model_path: str = "Qwen/Qwen3-Embedding-0.6B"):
        self.model_path = model_path
        self.tokenizer = None
        self.model = None
        self._load_model()

    def _load_model(self) -> None:
        self.tokenizer = AutoTokenizer.from_pretrained(self.model_path, trust_remote_code=True)
        self.model = AutoModelForCausalLM.from_pretrained(self.model_path, trust_remote_code=True)
        self.model.eval()

    async def embed(self, texts: List[str]) -> List[List[float]]:
        if not isinstance(texts, list) or any(not isinstance(t, str) for t in texts):
            raise ValueError("texts must be a list of strings")

        def encode_texts():
            embeddings = []
            for text in texts:
                # Tokenize
                inputs = self.tokenizer(text, return_tensors="pt", padding=True, truncation=True)
                # Mean-pool the last hidden state into one fixed-size vector
                with torch.no_grad():
                    outputs = self.model(**inputs, output_hidden_states=True)
                    embedding = outputs.hidden_states[-1].mean(dim=1).squeeze()
                embeddings.append(embedding.tolist())
            return embeddings

        # Run the blocking forward pass in a thread so the event loop stays free
        loop = asyncio.get_running_loop()
        return await loop.run_in_executor(None, encode_texts)

    async def similarity(
        self, embeddings_a: List[List[float]], embeddings_b: List[List[float]]
    ) -> List[List[float]]:
        def compute_similarity():
            # Convert to tensors
            emb_a = torch.tensor(embeddings_a)
            emb_b = torch.tensor(embeddings_b)
            # Pairwise cosine similarity: result[i][j] = cos(emb_a[i], emb_b[j])
            return F.cosine_similarity(emb_a.unsqueeze(1), emb_b.unsqueeze(0), dim=2).tolist()

        loop = asyncio.get_running_loop()
        return await loop.run_in_executor(None, compute_similarity)

    async def process_batch(self, items: Iterable[str]) -> AsyncGenerator[dict, None]:
        texts: List[str] = []
        indices: List[int] = []
        for idx, text in enumerate(items):
            if isinstance(text, str):
                texts.append(text)
                indices.append(idx)
            else:
                yield {"index": idx, "status": "error", "message": "each item must be a string"}
            # Yield control to the event loop between items
            await asyncio.sleep(0)
        if not texts:
            return
        try:
            vectors = await self.embed(texts)
            for i, vec in zip(indices, vectors):
                # Report the same zero-based index used for error results above
                yield {"index": i, "status": "success", "embedding": vec}
        except Exception as e:
            for i in indices:
                yield {"index": i, "status": "error", "message": str(e)}