mirror of
https://github.com/Tencent/WeKnora.git
synced 2026-06-04 13:30:32 +08:00
Update rerank_server_demo.py
1. Added CUDA debugging support. 2. Improved memory management during model inference. 3. Fixed logits output. Test device: Windows WSL, Ubuntu 20.04 (Python 3.10, CUDA, 5090 32G).
This commit is contained in:
@@ -1,3 +1,4 @@
|
||||
import gc
|
||||
import torch
|
||||
import uvicorn
|
||||
from fastapi import FastAPI
|
||||
@@ -5,6 +6,10 @@ from pydantic import BaseModel, Field
|
||||
from transformers import AutoModelForSequenceClassification, AutoTokenizer
|
||||
from typing import List
|
||||
|
||||
# 使能 CUDA 调试
|
||||
# import os
|
||||
# os.environ['CUDA_LAUNCH_BLOCKING']='1'
|
||||
|
||||
# --- 1. 定义API的请求和响应数据结构 ---
|
||||
|
||||
# 请求体结构保持不变
|
||||
@@ -53,7 +58,7 @@ except Exception as e:
|
||||
app = FastAPI(
|
||||
title="Reranker API (Test Version)",
|
||||
description="一个返回 'score' 字段以测试Go客户端兼容性的API服务",
|
||||
version="1.0.1"
|
||||
version="1.0.2"
|
||||
)
|
||||
|
||||
# --- 4. 定义API端点 ---
|
||||
@@ -65,8 +70,23 @@ def rerank_endpoint(request: RerankRequest):
|
||||
pairs = [[request.query, doc] for doc in request.documents]
|
||||
|
||||
with torch.no_grad():
|
||||
inputs = tokenizer(pairs, padding=True, truncation=True, return_tensors='pt', max_length=1024).to(device)
|
||||
scores = model(**inputs, return_dict=True).logits.view(-1, ).float()
|
||||
inputs = outputs = logits = None
|
||||
|
||||
try:
|
||||
inputs = tokenizer(pairs, padding=True, truncation=True, return_tensors='pt', max_length=1024).to(device)
|
||||
outputs = model(**inputs, return_dict=True)
|
||||
logits = outputs.logits.view(-1, ).float()
|
||||
scores = torch.sigmoid(logits)
|
||||
finally:
|
||||
# 释放 GPU 资源占用
|
||||
del inputs, outputs, logits
|
||||
gc.collect()
|
||||
|
||||
if torch.cuda.is_available():
|
||||
torch.cuda.empty_cache()
|
||||
elif hasattr(torch, "mps") and torch.mps.is_available():
|
||||
torch.mps.empty_cache()
|
||||
|
||||
|
||||
# --- 修改开始:按照测试用的结构来构建结果 ---
|
||||
results = []
|
||||
@@ -99,4 +119,3 @@ def read_root():
|
||||
# --- 5. 启动服务 ---
|
||||
if __name__ == "__main__":
|
||||
uvicorn.run(app, host="0.0.0.0", port=8000)
|
||||
|
||||
Reference in New Issue
Block a user