Refactor jembedding's service.py to use AutoModel

jingrow 2025-10-08 15:07:52 +08:00
parent 8fc4a3c603
commit feb598d72a
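The key change: the service now loads the base transformer with AutoModel and pools the hidden state of the last valid token (the approach recommended for Qwen3-Embedding), instead of loading AutoModelForCausalLM and mean-pooling hidden states over all positions, padding included. A minimal sketch of the difference, using toy tensors that are illustrative only and not part of this commit:

import torch

# Toy batch: 2 sequences of length 5, hidden size 4 (illustrative values only).
last_hidden_state = torch.randn(2, 5, 4)
attention_mask = torch.tensor([[1, 1, 1, 0, 0],   # right-padded sequence
                               [1, 1, 1, 1, 1]])  # full-length sequence

# Old approach: mean over every position, padding vectors included.
mean_pooled = last_hidden_state.mean(dim=1)               # (2, 4)

# New approach: take the hidden state of the last valid token per sequence.
lengths = attention_mask.sum(dim=1) - 1                   # tensor([2, 4])
last_token = last_hidden_state[torch.arange(2), lengths]  # (2, 4)

With left padding the last position is always a valid token, which is why the refactor also sets padding_side='left' on the tokenizer.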

service.py

@@ -1,56 +1,103 @@
 import asyncio
 from typing import List, Iterable, AsyncGenerator, Optional
-from transformers import AutoTokenizer, AutoModelForCausalLM
+from transformers import AutoTokenizer, AutoModel
 import torch
+import torch.nn.functional as F
+from torch import Tensor
 
 
 class JEmbeddingService:
-    def __init__(self, model_path: str = "Qwen/Qwen3-Embedding-0.6B"):
+    def __init__(self, model_path: str = "Qwen/Qwen3-Embedding-0.6B", max_length: int = 8192):
         self.model_path = model_path
+        self.max_length = max_length
         self.tokenizer = None
         self.model = None
         self._load_model()
 
     def _load_model(self) -> None:
-        self.tokenizer = AutoTokenizer.from_pretrained(self.model_path, trust_remote_code=True)
-        self.model = AutoModelForCausalLM.from_pretrained(self.model_path, trust_remote_code=True)
+        """Load the model and tokenizer."""
+        self.tokenizer = AutoTokenizer.from_pretrained(
+            self.model_path,
+            trust_remote_code=True,
+            padding_side='left'  # left padding, suited to instruction-style models
+        )
+        self.model = AutoModel.from_pretrained(
+            self.model_path,
+            trust_remote_code=True
+        )
+        self.model.eval()
 
-    async def embed(self, texts: List[str]) -> List[List[float]]:
+    def _last_token_pool(self, last_hidden_states: Tensor, attention_mask: Tensor) -> Tensor:
+        """
+        Extract the hidden state of the last valid token.
+        This is the pooling method recommended for Qwen3-Embedding models.
+        """
+        left_padding = (attention_mask[:, -1].sum() == attention_mask.shape[0])
+        if left_padding:
+            return last_hidden_states[:, -1]
+        else:
+            sequence_lengths = attention_mask.sum(dim=1) - 1
+            batch_size = last_hidden_states.shape[0]
+            return last_hidden_states[torch.arange(batch_size, device=last_hidden_states.device), sequence_lengths]
+
+    async def embed(self, texts: List[str], normalize: bool = True) -> List[List[float]]:
+        """
+        Convert a list of texts into vector representations.
+
+        Args:
+            texts: list of texts
+            normalize: whether to L2-normalize the vectors (default True)
+
+        Returns:
+            A list of vectors, one per input text.
+        """
         if not isinstance(texts, list) or any(not isinstance(t, str) for t in texts):
             raise ValueError("texts must be a list of strings")
 
         def encode_texts():
-            embeddings = []
-            for text in texts:
-                # Tokenize
-                inputs = self.tokenizer(text, return_tensors="pt", padding=True, truncation=True)
+            # Tokenize the whole batch at once
+            batch_dict = self.tokenizer(
+                texts,
+                padding=True,
+                truncation=True,
+                max_length=self.max_length,
+                return_tensors="pt"
+            )
+            # Move inputs to the model's device
+            batch_dict = {k: v.to(self.model.device) for k, v in batch_dict.items()}
+
+            # Forward pass to get the embeddings
+            with torch.no_grad():
+                outputs = self.model(**batch_dict)
+
+            # Pool the last valid token's embedding via last_token_pool
+            embeddings = self._last_token_pool(outputs.last_hidden_state, batch_dict['attention_mask'])
-
-                # Get embeddings from last hidden state
-                with torch.no_grad():
-                    outputs = self.model(**inputs, output_hidden_states=True)
-                embedding = outputs.hidden_states[-1].mean(dim=1).squeeze()
-                embeddings.append(embedding.tolist())
-            return embeddings
+
+            # L2 normalization (recommended for similarity computations)
+            if normalize:
+                embeddings = F.normalize(embeddings, p=2, dim=1)
+            return embeddings.tolist()
 
         loop = asyncio.get_running_loop()
         return await loop.run_in_executor(None, encode_texts)
 
     async def similarity(self, embeddings_a: List[List[float]], embeddings_b: List[List[float]]):
         def compute_similarity():
             import torch.nn.functional as F
             # Convert to tensors
             emb_a = torch.tensor(embeddings_a)
             emb_b = torch.tensor(embeddings_b)
             # Compute cosine similarity
             return F.cosine_similarity(emb_a.unsqueeze(1), emb_b.unsqueeze(0), dim=2).tolist()
 
         loop = asyncio.get_running_loop()
         return await loop.run_in_executor(None, compute_similarity)
 
-    async def process_batch(self, items: Iterable[str]) -> AsyncGenerator[dict, None]:
+    async def process_batch(self, items: Iterable[str], normalize: bool = True) -> AsyncGenerator[dict, None]:
+        """
+        Batch-process texts into vector representations.
+
+        Args:
+            items: iterable of texts
+            normalize: whether to L2-normalize the vectors
+
+        Yields:
+            Result dicts containing index, status, and embedding or message.
+        """
         texts: List[str] = []
         indices: List[int] = []
+        # Collect the valid texts
         for idx, text in enumerate(items):
             try:
                 if not isinstance(text, str):
@@ -59,17 +106,19 @@ class JEmbeddingService:
                 indices.append(idx)
             except Exception as e:
                 yield {"index": idx, "status": "error", "message": str(e)}
-            await asyncio.sleep(0)
+            await asyncio.sleep(0)  # yield control back to the event loop
 
         if not texts:
             return
 
         try:
-            vectors = await self.embed(texts)
+            # Embed all collected texts in one batch
+            vectors = await self.embed(texts, normalize=normalize)
             for i, vec in zip(indices, vectors):
-                yield {"index": i + 1, "status": "success", "embedding": vec}
+                yield {"index": i, "status": "success", "embedding": vec}
         except Exception as e:
+            # If the batch fails, return an error for every text
             for i in indices:
-                yield {"index": i + 1, "status": "error", "message": str(e)}
+                yield {"index": i, "status": "error", "message": str(e)}