import asyncio
from typing import List, Iterable, AsyncGenerator, Optional

import torch
import torch.nn.functional as F
from torch import Tensor
from transformers import AutoTokenizer, AutoModel


class JEmbeddingService:
    """Async text-embedding service backed by a ``transformers`` model.

    The tokenizer and model are loaded eagerly in ``__init__``. Embeddings
    are produced by pooling the hidden state of the last real token of each
    sequence and, optionally, L2-normalizing the result.
    """

    def __init__(self, model_path: str = "Qwen/Qwen3-Embedding-0.6B", max_length: int = 8192):
        """Store the configuration and load the tokenizer and model.

        Args:
            model_path: Identifier or path passed to ``from_pretrained``.
            max_length: Maximum token count per text; longer inputs are
                truncated at tokenization time.
        """
        self.model_path = model_path
        self.max_length = max_length
        self.tokenizer = None
        self.model = None
        self._load_model()

    def _load_model(self) -> None:
        """Load the tokenizer and model and put the model in eval mode."""
        self.tokenizer = AutoTokenizer.from_pretrained(
            self.model_path,
            trust_remote_code=True,
            # Left padding keeps the last position of every row a real token,
            # which is what _last_token_pool's fast path relies on.
            padding_side='left',
        )
        self.model = AutoModel.from_pretrained(
            self.model_path,
            trust_remote_code=True,
        )
        self.model.eval()

    def _last_token_pool(self, last_hidden_states: Tensor, attention_mask: Tensor) -> Tensor:
        """Return the hidden state of the last non-padding token of each row.

        This is the pooling scheme recommended for Qwen3-Embedding models.

        Args:
            last_hidden_states: Model output; presumably shaped
                (batch, seq_len, hidden) -- indexed on its first two axes.
            attention_mask: Mask with 1 for real tokens and 0 for padding,
                indexed as (batch, seq_len).

        Returns:
            One pooled vector per row of the batch.
        """
        # If every row has a real token in the final position, the batch is
        # left-padded (or unpadded) and the last column is the answer.
        left_padding = (attention_mask[:, -1].sum() == attention_mask.shape[0])
        if left_padding:
            return last_hidden_states[:, -1]

        # Right-padded batch: pick, per row, the index of its last real token.
        sequence_lengths = attention_mask.sum(dim=1) - 1
        batch_size = last_hidden_states.shape[0]
        row_index = torch.arange(batch_size, device=last_hidden_states.device)
        return last_hidden_states[row_index, sequence_lengths]

    async def embed(self, texts: List[str], normalize: bool = True) -> List[List[float]]:
        """Convert a list of texts into embedding vectors.

        The tokenization and forward pass run in the event loop's default
        executor so the loop is not blocked while the model computes.

        Args:
            texts: List of strings to embed.
            normalize: Whether to L2-normalize each vector (default True).

        Returns:
            One vector per input text, in input order. An empty input list
            yields an empty result without invoking the model.

        Raises:
            ValueError: If ``texts`` is not a list of strings.
        """
        if not isinstance(texts, list) or any(not isinstance(t, str) for t in texts):
            raise ValueError("texts必须是字符串列表")
        if not texts:
            # Nothing to encode; avoid handing an empty batch to the
            # tokenizer and model.
            return []

        def encode_texts() -> List[List[float]]:
            # Tokenize the whole batch at once.
            batch_dict = self.tokenizer(
                texts,
                padding=True,
                truncation=True,
                max_length=self.max_length,
                return_tensors="pt",
            )
            # Move every tensor to the device the model lives on.
            batch_dict = {k: v.to(self.model.device) for k, v in batch_dict.items()}

            # Forward pass without gradient tracking.
            with torch.no_grad():
                outputs = self.model(**batch_dict)

            # Pool the hidden state of the last real token of each sequence.
            embeddings = self._last_token_pool(
                outputs.last_hidden_state, batch_dict['attention_mask']
            )

            # L2 normalization (recommended for similarity computation).
            if normalize:
                embeddings = F.normalize(embeddings, p=2, dim=1)

            return embeddings.tolist()

        loop = asyncio.get_running_loop()
        return await loop.run_in_executor(None, encode_texts)

    async def process_batch(self, items: Iterable[str], normalize: bool = True) -> AsyncGenerator[dict, None]:
        """Embed a batch of items, yielding one result dict per item.

        Non-string items are reported as errors immediately while the input
        is being scanned; all valid strings are then embedded in one call.

        Args:
            items: Iterable of texts.
            normalize: Whether to L2-normalize the embeddings.

        Yields:
            Dicts with ``index`` and ``status``, plus ``embedding`` on
            success or ``message`` on error.
        """
        texts: List[str] = []
        indices: List[int] = []

        # Collect the valid texts, remembering their original positions.
        for idx, text in enumerate(items):
            if isinstance(text, str):
                texts.append(text)
                indices.append(idx)
            else:
                yield {"index": idx, "status": "error", "message": "每个元素必须是字符串"}
            await asyncio.sleep(0)  # hand control back to the event loop

        if not texts:
            return

        # Only the embedding call is guarded, so exceptions thrown into this
        # generator at a yield point are not misreported as batch failures.
        try:
            vectors = await self.embed(texts, normalize=normalize)
        except Exception as e:
            # The whole batch failed: report the same error for every text.
            for i in indices:
                yield {"index": i, "status": "error", "message": str(e)}
            return

        for i, vec in zip(indices, vectors):
            yield {"index": i, "status": "success", "embedding": vec}