From feb598d72af60559bd502be3232eb2d425ad445d Mon Sep 17 00:00:00 2001 From: jingrow Date: Wed, 8 Oct 2025 15:07:52 +0800 Subject: [PATCH] =?UTF-8?q?=E5=9F=BA=E4=BA=8EAutoModel=E9=87=8D=E6=9E=84je?= =?UTF-8?q?mbedding=E7=9A=84service.py?= MIME-Version: 1.0 Content-Type: text/plain; charset=UTF-8 Content-Transfer-Encoding: 8bit --- apps/jembedding/service.py | 111 ++++++++++++++++++++++++++----------- 1 file changed, 80 insertions(+), 31 deletions(-) diff --git a/apps/jembedding/service.py b/apps/jembedding/service.py index 2b3b0df..ace5625 100644 --- a/apps/jembedding/service.py +++ b/apps/jembedding/service.py @@ -1,56 +1,103 @@ import asyncio from typing import List, Iterable, AsyncGenerator, Optional -from transformers import AutoTokenizer, AutoModelForCausalLM +from transformers import AutoTokenizer, AutoModel import torch +import torch.nn.functional as F +from torch import Tensor class JEmbeddingService: - def __init__(self, model_path: str = "Qwen/Qwen3-Embedding-0.6B"): + def __init__(self, model_path: str = "Qwen/Qwen3-Embedding-0.6B", max_length: int = 8192): self.model_path = model_path + self.max_length = max_length self.tokenizer = None self.model = None self._load_model() def _load_model(self) -> None: - self.tokenizer = AutoTokenizer.from_pretrained(self.model_path, trust_remote_code=True) - self.model = AutoModelForCausalLM.from_pretrained(self.model_path, trust_remote_code=True) + """加载模型和分词器""" + self.tokenizer = AutoTokenizer.from_pretrained( + self.model_path, + trust_remote_code=True, + padding_side='left' # 左填充,适合指令模型 + ) + self.model = AutoModel.from_pretrained( + self.model_path, + trust_remote_code=True + ) self.model.eval() - async def embed(self, texts: List[str]) -> List[List[float]]: + def _last_token_pool(self, last_hidden_states: Tensor, attention_mask: Tensor) -> Tensor: + """ + 提取最后一个有效token的隐藏状态 + 这是Qwen3-Embedding模型推荐的池化方式 + """ + left_padding = (attention_mask[:, -1].sum() == attention_mask.shape[0]) + if left_padding: + return last_hidden_states[:, -1] + else: + sequence_lengths = attention_mask.sum(dim=1) - 1 + batch_size = last_hidden_states.shape[0] + return last_hidden_states[torch.arange(batch_size, device=last_hidden_states.device), sequence_lengths] + + async def embed(self, texts: List[str], normalize: bool = True) -> List[List[float]]: + """ + 将文本列表转换为向量表示 + + Args: + texts: 文本列表 + normalize: 是否进行L2归一化,默认True + + Returns: + 向量列表,每个向量对应一个输入文本 + """ if not isinstance(texts, list) or any(not isinstance(t, str) for t in texts): raise ValueError("texts必须是字符串列表") def encode_texts(): - embeddings = [] - for text in texts: - # Tokenize - inputs = self.tokenizer(text, return_tensors="pt", padding=True, truncation=True) + # 批量分词 + batch_dict = self.tokenizer( + texts, + padding=True, + truncation=True, + max_length=self.max_length, + return_tensors="pt" + ) + + # 移动到模型设备 + batch_dict = {k: v.to(self.model.device) for k, v in batch_dict.items()} + + # 获取嵌入 + with torch.no_grad(): + outputs = self.model(**batch_dict) + # 使用last_token_pool提取最后一个有效token的嵌入 + embeddings = self._last_token_pool(outputs.last_hidden_state, batch_dict['attention_mask']) - # Get embeddings from last hidden state - with torch.no_grad(): - outputs = self.model(**inputs, output_hidden_states=True) - embedding = outputs.hidden_states[-1].mean(dim=1).squeeze() - embeddings.append(embedding.tolist()) - return embeddings + # L2归一化(推荐用于相似度计算) + if normalize: + embeddings = F.normalize(embeddings, p=2, dim=1) + + return embeddings.tolist() loop = asyncio.get_running_loop() return await loop.run_in_executor(None, encode_texts) - async def similarity(self, embeddings_a: List[List[float]], embeddings_b: List[List[float]]): - def compute_similarity(): - import torch.nn.functional as F - # Convert to tensors - emb_a = torch.tensor(embeddings_a) - emb_b = torch.tensor(embeddings_b) - # Compute cosine similarity - return F.cosine_similarity(emb_a.unsqueeze(1), emb_b.unsqueeze(0), dim=2).tolist() - - loop = asyncio.get_running_loop() - return await loop.run_in_executor(None, compute_similarity) - async def process_batch(self, items: Iterable[str]) -> AsyncGenerator[dict, None]: + async def process_batch(self, items: Iterable[str], normalize: bool = True) -> AsyncGenerator[dict, None]: + """ + 批量处理文本,生成向量表示 + + Args: + items: 文本迭代器 + normalize: 是否进行L2归一化 + + Yields: + 处理结果字典,包含index、status、embedding或message + """ texts: List[str] = [] indices: List[int] = [] + + # 收集有效文本 for idx, text in enumerate(items): try: if not isinstance(text, str): @@ -59,17 +106,19 @@ class JEmbeddingService: indices.append(idx) except Exception as e: yield {"index": idx, "status": "error", "message": str(e)} - await asyncio.sleep(0) + await asyncio.sleep(0) # 让出控制权 if not texts: return try: - vectors = await self.embed(texts) + # 批量处理所有文本 + vectors = await self.embed(texts, normalize=normalize) for i, vec in zip(indices, vectors): - yield {"index": i + 1, "status": "success", "embedding": vec} + yield {"index": i, "status": "success", "embedding": vec} except Exception as e: + # 如果批量处理失败,为每个文本返回错误 for i in indices: - yield {"index": i + 1, "status": "error", "message": str(e)} + yield {"index": i, "status": "error", "message": str(e)}