Update rerank_server_demo.py

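1. Added optional CUDA debugging support (a commented-out `CUDA_LAUNCH_BLOCKING=1` setting).
2. Improved memory management during model inference: inputs, outputs, and logits are released and the CUDA/MPS cache is emptied after each request.
3. Fixed the logits output: scores are now computed by applying a sigmoid to the logits.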

Test environment: Windows WSL with Ubuntu 20.04 (Python 3.10, CUDA, 5090 GPU with 32 GB of memory)
This commit is contained in:
Suroy
2026-03-18 15:37:44 +08:00
committed by lyingbug
parent 5506956e87
commit 5c6b710bae

View File

@@ -1,3 +1,4 @@
import gc
import torch
import uvicorn
from fastapi import FastAPI
@@ -5,6 +6,10 @@ from pydantic import BaseModel, Field
from transformers import AutoModelForSequenceClassification, AutoTokenizer
from typing import List
# 使能 CUDA 调试
# import os
# os.environ['CUDA_LAUNCH_BLOCKING']='1'
# --- 1. 定义API的请求和响应数据结构 ---
# 请求体结构保持不变
@@ -53,7 +58,7 @@ except Exception as e:
app = FastAPI(
title="Reranker API (Test Version)",
description="一个返回 'score' 字段以测试Go客户端兼容性的API服务",
version="1.0.1"
version="1.0.2"
)
# --- 4. 定义API端点 ---
@@ -65,8 +70,23 @@ def rerank_endpoint(request: RerankRequest):
pairs = [[request.query, doc] for doc in request.documents]
with torch.no_grad():
inputs = tokenizer(pairs, padding=True, truncation=True, return_tensors='pt', max_length=1024).to(device)
scores = model(**inputs, return_dict=True).logits.view(-1, ).float()
inputs = outputs = logits = None
try:
inputs = tokenizer(pairs, padding=True, truncation=True, return_tensors='pt', max_length=1024).to(device)
outputs = model(**inputs, return_dict=True)
logits = outputs.logits.view(-1, ).float()
scores = torch.sigmoid(logits)
finally:
# 释放 GPU 资源占用
del inputs, outputs, logits
gc.collect()
if torch.cuda.is_available():
torch.cuda.empty_cache()
elif hasattr(torch, "mps") and torch.mps.is_available():
torch.mps.empty_cache()
# --- 修改开始:按照测试用的结构来构建结果 ---
results = []
@@ -99,4 +119,3 @@ def read_root():
# --- 5. 启动服务 ---
if __name__ == "__main__":
uvicorn.run(app, host="0.0.0.0", port=8000)