jembedding启用GPU加速
This commit is contained in:
parent
feb598d72a
commit
c27fbb19c2
@@ -21,11 +21,25 @@ class JEmbeddingService:
|
||||
trust_remote_code=True,
|
||||
padding_side='left' # 左填充,适合指令模型
|
||||
)
|
||||
|
||||
# 检查GPU可用性并自动选择设备
|
||||
self.device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
|
||||
print(f"🚀 使用设备: {self.device}")
|
||||
|
||||
self.model = AutoModel.from_pretrained(
|
||||
self.model_path,
|
||||
trust_remote_code=True
|
||||
)
|
||||
|
||||
# 将模型移动到指定设备
|
||||
self.model = self.model.to(self.device)
|
||||
self.model.eval()
|
||||
|
||||
# 打印模型信息
|
||||
if torch.cuda.is_available():
|
||||
gpu_memory = torch.cuda.get_device_properties(0).total_memory / 1024**3
|
||||
print(f"🎯 GPU内存: {gpu_memory:.1f} GB")
|
||||
print(f"📊 模型已加载到GPU: {next(self.model.parameters()).device}")
|
||||
|
||||
def _last_token_pool(self, last_hidden_states: Tensor, attention_mask: Tensor) -> Tensor:
|
||||
"""
|
||||
@@ -65,7 +79,7 @@ class JEmbeddingService:
|
||||
)
|
||||
|
||||
# 移动到模型设备
|
||||
-batch_dict = {k: v.to(self.model.device) for k, v in batch_dict.items()}
|
||||
+batch_dict = {k: v.to(self.device) for k, v in batch_dict.items()}
|
||||
|
||||
# 获取嵌入
|
||||
with torch.no_grad():
|
||||
|
||||
Loading…
x
Reference in New Issue
Block a user