japi/apps/jembedding/service.py

125 lines
4.4 KiB
Python
Raw Blame History

This file contains ambiguous Unicode characters

This file contains Unicode characters that might be confused with other characters. If you think that this is intentional, you can safely ignore this warning. Use the Escape button to reveal them.

import asyncio
from typing import List, Iterable, AsyncGenerator, Optional
from transformers import AutoTokenizer, AutoModel
import torch
import torch.nn.functional as F
from torch import Tensor
class JEmbeddingService:
    """Text-embedding service built on a Hugging Face ``AutoModel``.

    The tokenizer and model are loaded eagerly in ``__init__``. ``embed``
    converts a list of strings into vectors, and ``process_batch`` wraps
    it in an async generator that reports a per-item result dictionary.
    """

    def __init__(self, model_path: str = "Qwen/Qwen3-Embedding-0.6B", max_length: int = 8192):
        """Store the configuration and load the tokenizer and model.

        Args:
            model_path: Identifier or path handed to ``from_pretrained``.
            max_length: Maximum token count per text; longer inputs are
                truncated by the tokenizer.
        """
        self.model_path = model_path
        self.max_length = max_length
        self.tokenizer = None
        self.model = None
        self._load_model()

    def _load_model(self) -> None:
        """Load the tokenizer and the model, then switch the model to eval mode."""
        self.tokenizer = AutoTokenizer.from_pretrained(
            self.model_path,
            trust_remote_code=True,
            padding_side='left',  # left padding, suited to instruction models
        )
        self.model = AutoModel.from_pretrained(
            self.model_path,
            trust_remote_code=True,
        )
        self.model.eval()

    def _last_token_pool(self, last_hidden_states: Tensor, attention_mask: Tensor) -> Tensor:
        """Return the hidden state of the last valid token of every sequence.

        This is the pooling scheme recommended for Qwen3-Embedding models.

        Args:
            last_hidden_states: Model output indexed as ``[batch, position]``
                (presumably ``(batch, seq_len, hidden)``).
            attention_mask: Mask indexed as ``[batch, position]``; assumed to
                hold 1 for real tokens and 0 for padding.

        Returns:
            One pooled row per sequence in the batch.
        """
        # If every sequence has a real token in the final position, the batch
        # is left-padded and the last position is the last valid token.
        left_padding = (attention_mask[:, -1].sum() == attention_mask.shape[0])
        if left_padding:
            return last_hidden_states[:, -1]

        # Right padding: the last valid token sits at (number of real tokens - 1).
        sequence_lengths = attention_mask.sum(dim=1) - 1
        batch_size = last_hidden_states.shape[0]
        rows = torch.arange(batch_size, device=last_hidden_states.device)
        return last_hidden_states[rows, sequence_lengths]

    async def embed(self, texts: List[str], normalize: bool = True) -> List[List[float]]:
        """Convert a list of texts into embedding vectors.

        Tokenization and the forward pass run in the event loop's default
        executor so the loop is not blocked while the model computes.

        Args:
            texts: List of strings to embed.
            normalize: Whether to L2-normalize each vector (default True).

        Returns:
            A list of vectors, one per input text, in input order. An empty
            input list yields an empty result.

        Raises:
            ValueError: If ``texts`` is not a list of strings.
        """
        if not isinstance(texts, list) or any(not isinstance(t, str) for t in texts):
            raise ValueError("texts必须是字符串列表")
        if not texts:
            # Nothing to encode; avoid handing an empty batch to the tokenizer.
            return []

        def encode_texts() -> List[List[float]]:
            # Tokenize the whole batch at once.
            batch_dict = self.tokenizer(
                texts,
                padding=True,
                truncation=True,
                max_length=self.max_length,
                return_tensors="pt",
            )
            # Move the inputs to the device the model lives on.
            batch_dict = {k: v.to(self.model.device) for k, v in batch_dict.items()}
            with torch.no_grad():
                outputs = self.model(**batch_dict)
                # Pool the hidden state of the last valid token of each text.
                embeddings = self._last_token_pool(
                    outputs.last_hidden_state, batch_dict['attention_mask']
                )
                # L2 normalization is recommended for similarity computation.
                if normalize:
                    embeddings = F.normalize(embeddings, p=2, dim=1)
            return embeddings.tolist()

        loop = asyncio.get_running_loop()
        return await loop.run_in_executor(None, encode_texts)

    async def process_batch(self, items: Iterable[str], normalize: bool = True) -> AsyncGenerator[dict, None]:
        """Embed a batch of items and yield one result dictionary per item.

        Non-string items are reported as errors while the input is being
        scanned; the remaining strings are embedded in a single ``embed``
        call, so results for valid items are yielded after all validation
        errors.

        Args:
            items: Iterable of candidate texts.
            normalize: Whether to L2-normalize each vector.

        Yields:
            ``{"index", "status": "success", "embedding"}`` for embedded
            texts, or ``{"index", "status": "error", "message"}`` for invalid
            items and for every text of a batch whose embedding failed.
        """
        texts: List[str] = []
        indices: List[int] = []

        # Collect the valid texts, reporting invalid items immediately.
        for idx, text in enumerate(items):
            if isinstance(text, str):
                texts.append(text)
                indices.append(idx)
            else:
                yield {"index": idx, "status": "error", "message": "每个元素必须是字符串"}
            await asyncio.sleep(0)  # hand control back to the event loop

        if not texts:
            return

        try:
            # Embed all valid texts in one batch.
            vectors = await self.embed(texts, normalize=normalize)
        except Exception as e:
            # Best-effort boundary: if the batch fails, report the error
            # against every text that was part of it.
            for i in indices:
                yield {"index": i, "status": "error", "message": str(e)}
            return

        for i, vec in zip(indices, vectors):
            yield {"index": i, "status": "success", "embedding": vec}