From 8fc4a3c603d1d3d37b27cc9b72be004f8ae36d67 Mon Sep 17 00:00:00 2001
From: jingrow
Date: Wed, 8 Oct 2025 06:26:13 +0800
Subject: [PATCH] =?UTF-8?q?jembedding=E6=9B=B4=E6=96=B0=E4=B8=BA=E6=94=AF?=
 =?UTF-8?q?=E6=8C=81=E7=A6=BB=E7=BA=BF=E5=8A=A0=E8=BD=BD=E6=A8=A1=E5=9E=8B?=
MIME-Version: 1.0
Content-Type: text/plain; charset=UTF-8
Content-Transfer-Encoding: 8bit

---
 apps/jembedding/app.py     |  3 +--
 apps/jembedding/service.py | 41 ++++++++++++++++++++++++++++++--------
 2 files changed, 34 insertions(+), 10 deletions(-)

diff --git a/apps/jembedding/app.py b/apps/jembedding/app.py
index e184a03..7d5b977 100644
--- a/apps/jembedding/app.py
+++ b/apps/jembedding/app.py
@@ -1,7 +1,6 @@
 from fastapi import FastAPI
-from api import router
 from settings import settings
-
+from api import router
 
 app = FastAPI(
     title="JEmbedding",
diff --git a/apps/jembedding/service.py b/apps/jembedding/service.py
index 3d2f631..2b3b0df 100644
--- a/apps/jembedding/service.py
+++ b/apps/jembedding/service.py
@@ -1,27 +1,52 @@
 import asyncio
 from typing import List, Iterable, AsyncGenerator, Optional
-from sentence_transformers import SentenceTransformer
+from transformers import AutoTokenizer, AutoModelForCausalLM
+import torch
 
 
 class JEmbeddingService:
-    def __init__(self, model_name: str = "Qwen/Qwen3-Embedding-0.6B"):
-        self.model_name = model_name
-        self.model: Optional[SentenceTransformer] = None
+    def __init__(self, model_path: str = "Qwen/Qwen3-Embedding-0.6B"):
+        self.model_path = model_path
+        self.tokenizer = None
+        self.model = None
         self._load_model()
 
     def _load_model(self) -> None:
-        self.model = SentenceTransformer(self.model_name)
+        self.tokenizer = AutoTokenizer.from_pretrained(self.model_path, trust_remote_code=True)
+        self.model = AutoModelForCausalLM.from_pretrained(self.model_path, trust_remote_code=True)
+        self.model.eval()
 
     async def embed(self, texts: List[str]) -> List[List[float]]:
         if not isinstance(texts, list) or any(not isinstance(t, str) for t in texts):
             raise ValueError("texts必须是字符串列表")
+
+        def encode_texts():
+            embeddings = []
+            for text in texts:
+                # Tokenize
+                inputs = self.tokenizer(text, return_tensors="pt", padding=True, truncation=True)
+
+                # Get embeddings from last hidden state
+                with torch.no_grad():
+                    outputs = self.model(**inputs, output_hidden_states=True)
+                    embedding = outputs.hidden_states[-1].mean(dim=1).squeeze()
+                    embeddings.append(embedding.tolist())
+            return embeddings
+
         loop = asyncio.get_running_loop()
-        embeddings = await loop.run_in_executor(None, self.model.encode, texts)
-        return [vec.tolist() if hasattr(vec, 'tolist') else vec for vec in embeddings]
+        return await loop.run_in_executor(None, encode_texts)
 
     async def similarity(self, embeddings_a: List[List[float]], embeddings_b: List[List[float]]):
+        def compute_similarity():
+            import torch.nn.functional as F
+            # Convert to tensors
+            emb_a = torch.tensor(embeddings_a)
+            emb_b = torch.tensor(embeddings_b)
+            # Compute cosine similarity
+            return F.cosine_similarity(emb_a.unsqueeze(1), emb_b.unsqueeze(0), dim=2).tolist()
+
         loop = asyncio.get_running_loop()
-        return await loop.run_in_executor(None, self.model.similarity, embeddings_a, embeddings_b)
+        return await loop.run_in_executor(None, compute_similarity)
 
     async def process_batch(self, items: Iterable[str]) -> AsyncGenerator[dict, None]:
         texts: List[str] = []