jembedding updated to support offline model loading
parent fea1df9990
commit 8fc4a3c603
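
As context for the diff below: since the constructor now takes a generic model_path, offline loading amounts to downloading the weights once and pointing the service at the local directory. A minimal sketch, assuming the class lives in a module named service and the directory was prepared ahead of time (neither is shown in this commit):

# Hypothetical usage sketch, not part of this commit. Download once while online, e.g.:
#   huggingface-cli download Qwen/Qwen3-Embedding-0.6B --local-dir ./models/qwen3-embedding
import os

# Optional: keep huggingface_hub from touching the network at runtime.
# Must be set before transformers is imported (which happens inside service).
os.environ["HF_HUB_OFFLINE"] = "1"

from service import JEmbeddingService  # module name assumed

# from_pretrained accepts a local directory, so no network access is needed.
svc = JEmbeddingService(model_path="./models/qwen3-embedding")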
@@ -1,7 +1,6 @@
 from fastapi import FastAPI
-from api import router
 from settings import settings
-
+from api import router
 
 app = FastAPI(
     title="JEmbedding",
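
For reference, a sketch of the main module these reordered imports feed into; the include_router call and any use of settings are assumptions, since the hunk is truncated at the title kwarg:

# Sketch of the resulting main module after the import reorder (assumed wiring).
from fastapi import FastAPI
from settings import settings
from api import router

app = FastAPI(
    title="JEmbedding",
    # ...remaining FastAPI kwargs truncated in this hunk...
)
app.include_router(router)  # assumed; the router import implies this wiring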
@@ -1,27 +1,52 @@
 import asyncio
 from typing import List, Iterable, AsyncGenerator, Optional
-from sentence_transformers import SentenceTransformer
+from transformers import AutoTokenizer, AutoModelForCausalLM
+import torch
+
 
 class JEmbeddingService:
-    def __init__(self, model_name: str = "Qwen/Qwen3-Embedding-0.6B"):
-        self.model_name = model_name
-        self.model: Optional[SentenceTransformer] = None
+    def __init__(self, model_path: str = "Qwen/Qwen3-Embedding-0.6B"):
+        self.model_path = model_path
+        self.tokenizer = None
+        self.model = None
         self._load_model()
 
     def _load_model(self) -> None:
-        self.model = SentenceTransformer(self.model_name)
+        self.tokenizer = AutoTokenizer.from_pretrained(self.model_path, trust_remote_code=True)
+        self.model = AutoModelForCausalLM.from_pretrained(self.model_path, trust_remote_code=True)
+        self.model.eval()
 
     async def embed(self, texts: List[str]) -> List[List[float]]:
         if not isinstance(texts, list) or any(not isinstance(t, str) for t in texts):
             raise ValueError("texts must be a list of strings")
 
+        def encode_texts():
+            embeddings = []
+            for text in texts:
+                # Tokenize
+                inputs = self.tokenizer(text, return_tensors="pt", padding=True, truncation=True)
+
+                # Get embeddings from last hidden state
+                with torch.no_grad():
+                    outputs = self.model(**inputs, output_hidden_states=True)
+                embedding = outputs.hidden_states[-1].mean(dim=1).squeeze()
+                embeddings.append(embedding.tolist())
+            return embeddings
+
         loop = asyncio.get_running_loop()
-        embeddings = await loop.run_in_executor(None, self.model.encode, texts)
-        return [vec.tolist() if hasattr(vec, 'tolist') else vec for vec in embeddings]
+        return await loop.run_in_executor(None, encode_texts)
 
     async def similarity(self, embeddings_a: List[List[float]], embeddings_b: List[List[float]]):
+        def compute_similarity():
+            import torch.nn.functional as F
+            # Convert to tensors
+            emb_a = torch.tensor(embeddings_a)
+            emb_b = torch.tensor(embeddings_b)
+            # Compute cosine similarity
+            return F.cosine_similarity(emb_a.unsqueeze(1), emb_b.unsqueeze(0), dim=2).tolist()
+
         loop = asyncio.get_running_loop()
-        return await loop.run_in_executor(None, self.model.similarity, embeddings_a, embeddings_b)
+        return await loop.run_in_executor(None, compute_similarity)
 
     async def process_batch(self, items: Iterable[str]) -> AsyncGenerator[dict, None]:
         texts: List[str] = []
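
A quick way to exercise the new code path end to end; a minimal async driver, again assuming a module named service and a locally available model:

# Minimal smoke test for the new transformers-based path.
# The module name "service" is an assumption; class and method names are from the diff.
import asyncio

from service import JEmbeddingService

async def main() -> None:
    svc = JEmbeddingService()  # or JEmbeddingService(model_path="/path/to/local/model")

    # embed() runs the blocking forward pass in the default thread executor.
    vecs = await svc.embed(["hello world", "greetings, planet"])
    print(len(vecs), len(vecs[0]))  # 2 vectors, each hidden_size floats

    # similarity() returns the full 2x2 cosine-similarity matrix.
    sims = await svc.similarity(vecs, vecs)
    print(sims[0][0])  # ~1.0 on the diagonal

asyncio.run(main())

One design note: mean-pooling the last hidden state is only one of several pooling strategies for this model, so vectors from this service may not match embeddings produced by the model card's recommended usage.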