基于AutoModel重构jembedding的service.py

This commit is contained in:
jingrow 2025-10-08 15:07:52 +08:00
parent 8fc4a3c603
commit feb598d72a

View File

@ -1,56 +1,103 @@
import asyncio import asyncio
from typing import List, Iterable, AsyncGenerator, Optional from typing import List, Iterable, AsyncGenerator, Optional
from transformers import AutoTokenizer, AutoModelForCausalLM from transformers import AutoTokenizer, AutoModel
import torch import torch
import torch.nn.functional as F
from torch import Tensor
class JEmbeddingService: class JEmbeddingService:
def __init__(self, model_path: str = "Qwen/Qwen3-Embedding-0.6B"): def __init__(self, model_path: str = "Qwen/Qwen3-Embedding-0.6B", max_length: int = 8192):
self.model_path = model_path self.model_path = model_path
self.max_length = max_length
self.tokenizer = None self.tokenizer = None
self.model = None self.model = None
self._load_model() self._load_model()
def _load_model(self) -> None: def _load_model(self) -> None:
self.tokenizer = AutoTokenizer.from_pretrained(self.model_path, trust_remote_code=True) """加载模型和分词器"""
self.model = AutoModelForCausalLM.from_pretrained(self.model_path, trust_remote_code=True) self.tokenizer = AutoTokenizer.from_pretrained(
self.model_path,
trust_remote_code=True,
padding_side='left' # 左填充,适合指令模型
)
self.model = AutoModel.from_pretrained(
self.model_path,
trust_remote_code=True
)
self.model.eval() self.model.eval()
async def embed(self, texts: List[str]) -> List[List[float]]: def _last_token_pool(self, last_hidden_states: Tensor, attention_mask: Tensor) -> Tensor:
"""
提取最后一个有效token的隐藏状态
这是Qwen3-Embedding模型推荐的池化方式
"""
left_padding = (attention_mask[:, -1].sum() == attention_mask.shape[0])
if left_padding:
return last_hidden_states[:, -1]
else:
sequence_lengths = attention_mask.sum(dim=1) - 1
batch_size = last_hidden_states.shape[0]
return last_hidden_states[torch.arange(batch_size, device=last_hidden_states.device), sequence_lengths]
async def embed(self, texts: List[str], normalize: bool = True) -> List[List[float]]:
"""
将文本列表转换为向量表示
Args:
texts: 文本列表
normalize: 是否进行L2归一化默认True
Returns:
向量列表每个向量对应一个输入文本
"""
if not isinstance(texts, list) or any(not isinstance(t, str) for t in texts): if not isinstance(texts, list) or any(not isinstance(t, str) for t in texts):
raise ValueError("texts必须是字符串列表") raise ValueError("texts必须是字符串列表")
def encode_texts(): def encode_texts():
embeddings = [] # 批量分词
for text in texts: batch_dict = self.tokenizer(
# Tokenize texts,
inputs = self.tokenizer(text, return_tensors="pt", padding=True, truncation=True) padding=True,
truncation=True,
max_length=self.max_length,
return_tensors="pt"
)
# Get embeddings from last hidden state # 移动到模型设备
with torch.no_grad(): batch_dict = {k: v.to(self.model.device) for k, v in batch_dict.items()}
outputs = self.model(**inputs, output_hidden_states=True)
embedding = outputs.hidden_states[-1].mean(dim=1).squeeze() # 获取嵌入
embeddings.append(embedding.tolist()) with torch.no_grad():
return embeddings outputs = self.model(**batch_dict)
# 使用last_token_pool提取最后一个有效token的嵌入
embeddings = self._last_token_pool(outputs.last_hidden_state, batch_dict['attention_mask'])
# L2归一化推荐用于相似度计算
if normalize:
embeddings = F.normalize(embeddings, p=2, dim=1)
return embeddings.tolist()
loop = asyncio.get_running_loop() loop = asyncio.get_running_loop()
return await loop.run_in_executor(None, encode_texts) return await loop.run_in_executor(None, encode_texts)
async def similarity(self, embeddings_a: List[List[float]], embeddings_b: List[List[float]]):
def compute_similarity():
import torch.nn.functional as F
# Convert to tensors
emb_a = torch.tensor(embeddings_a)
emb_b = torch.tensor(embeddings_b)
# Compute cosine similarity
return F.cosine_similarity(emb_a.unsqueeze(1), emb_b.unsqueeze(0), dim=2).tolist()
loop = asyncio.get_running_loop() async def process_batch(self, items: Iterable[str], normalize: bool = True) -> AsyncGenerator[dict, None]:
return await loop.run_in_executor(None, compute_similarity) """
批量处理文本生成向量表示
async def process_batch(self, items: Iterable[str]) -> AsyncGenerator[dict, None]: Args:
items: 文本迭代器
normalize: 是否进行L2归一化
Yields:
处理结果字典包含indexstatusembedding或message
"""
texts: List[str] = [] texts: List[str] = []
indices: List[int] = [] indices: List[int] = []
# 收集有效文本
for idx, text in enumerate(items): for idx, text in enumerate(items):
try: try:
if not isinstance(text, str): if not isinstance(text, str):
@ -59,17 +106,19 @@ class JEmbeddingService:
indices.append(idx) indices.append(idx)
except Exception as e: except Exception as e:
yield {"index": idx, "status": "error", "message": str(e)} yield {"index": idx, "status": "error", "message": str(e)}
await asyncio.sleep(0) await asyncio.sleep(0) # 让出控制权
if not texts: if not texts:
return return
try: try:
vectors = await self.embed(texts) # 批量处理所有文本
vectors = await self.embed(texts, normalize=normalize)
for i, vec in zip(indices, vectors): for i, vec in zip(indices, vectors):
yield {"index": i + 1, "status": "success", "embedding": vec} yield {"index": i, "status": "success", "embedding": vec}
except Exception as e: except Exception as e:
# 如果批量处理失败,为每个文本返回错误
for i in indices: for i in indices:
yield {"index": i + 1, "status": "error", "message": str(e)} yield {"index": i, "status": "error", "message": str(e)}