Files
WeKnora/internal/application/service/chat_pipline/rerank.go

380 lines
11 KiB
Go

package chatpipline
import (
"context"
"encoding/json"
"fmt"
"math"
"strings"
"github.com/Tencent/WeKnora/internal/models/rerank"
"github.com/Tencent/WeKnora/internal/searchutil"
"github.com/Tencent/WeKnora/internal/types"
"github.com/Tencent/WeKnora/internal/types/interfaces"
)
// PluginRerank is the chat-pipeline plugin that handles the CHUNK_RERANK
// event: it scores the retrieved search results with a rerank model obtained
// from modelService and stores the selected results on the ChatManage.
type PluginRerank struct {
modelService interfaces.ModelService // Service to access rerank models
}
// NewPluginRerank builds a PluginRerank backed by modelService, registers it
// with eventManager, and returns it.
func NewPluginRerank(eventManager *EventManager, modelService interfaces.ModelService) *PluginRerank {
	plugin := &PluginRerank{modelService: modelService}
	eventManager.Register(plugin)
	return plugin
}
// ActivationEvents lists the pipeline events that trigger this plugin; only
// the chunk-rerank event is handled.
func (p *PluginRerank) ActivationEvents() []types.EventType {
	handled := []types.EventType{types.CHUNK_RERANK}
	return handled
}
// OnEvent reranks chatManage.SearchResult against chatManage.RewriteQuery and
// stores the selected chunks in chatManage.RerankResult.
//
// Flow:
//  1. Skip straight to next() when there are no search results or no rerank
//     model ID is configured.
//  2. Load the rerank model; a failure returns ErrGetRerankModel wrapping the
//     underlying error.
//  3. Build one enriched passage per search result and call p.rerank. If that
//     yields nothing and the threshold is above 0.3, retry once with the
//     threshold lowered to max(0.7*threshold, 0.3); the configured threshold
//     is restored right after the retry.
//  4. For every surviving result, record the previous score in
//     Metadata["base_score"] and overwrite Score with compositeScore.
//  5. Select up to max(1, RerankTopK) results via applyMMR (lambda 0.7).
//
// It returns ErrSearchNothing when nothing is selected, otherwise the result
// of next().
func (p *PluginRerank) OnEvent(ctx context.Context,
	eventType types.EventType, chatManage *types.ChatManage, next func() *PluginError,
) *PluginError {
	candidates := chatManage.SearchResult
	pipelineInfo(ctx, "Rerank", "input", map[string]any{
		"session_id":    chatManage.SessionID,
		"candidate_cnt": len(candidates),
		"rerank_model":  chatManage.RerankModelID,
		"rerank_thresh": chatManage.RerankThreshold,
		"rewrite_query": chatManage.RewriteQuery,
	})

	// Nothing to do without candidates or without a configured model.
	switch {
	case len(candidates) == 0:
		pipelineInfo(ctx, "Rerank", "skip", map[string]any{
			"reason": "empty_search_result",
		})
		return next()
	case chatManage.RerankModelID == "":
		pipelineWarn(ctx, "Rerank", "skip", map[string]any{
			"reason": "empty_model_id",
		})
		return next()
	}

	model, err := p.modelService.GetRerankModel(ctx, chatManage.RerankModelID)
	if err != nil {
		pipelineError(ctx, "Rerank", "get_model", map[string]any{
			"model_id": chatManage.RerankModelID,
			"error":    err.Error(),
		})
		return ErrGetRerankModel.WithError(err)
	}

	// One passage per candidate: content merged with image and question text.
	pipelineInfo(ctx, "Rerank", "build_passages", map[string]any{
		"candidate_cnt": len(candidates),
	})
	passages := make([]string, len(candidates))
	for i, candidate := range candidates {
		passages[i] = getEnrichedPassage(ctx, candidate)
	}

	configured := chatManage.RerankThreshold
	ranked := p.rerank(ctx, chatManage, model, chatManage.RewriteQuery, passages)
	if len(ranked) == 0 && configured > 0.3 {
		// Nothing passed: retry once with a relaxed threshold, never below 0.3.
		// NOTE(review): this invokes the model a second time with identical
		// inputs; only the filtering threshold differs.
		relaxed := max(configured*0.7, 0.3)
		pipelineInfo(ctx, "Rerank", "threshold_degrade", map[string]any{
			"original": configured,
			"degraded": relaxed,
		})
		chatManage.RerankThreshold = relaxed
		ranked = p.rerank(ctx, chatManage, model, chatManage.RewriteQuery, passages)
		// Later pipeline stages must see the configured value again.
		chatManage.RerankThreshold = configured
	}
	pipelineInfo(ctx, "Rerank", "model_response", map[string]any{
		"result_cnt": len(ranked),
	})

	// Dump the retrieval scores before they are overwritten, for debugging.
	for i, candidate := range candidates {
		pipelineInfo(ctx, "Rerank", "input_score", map[string]any{
			"index":      i,
			"chunk_id":   candidate.ID,
			"score":      fmt.Sprintf("%.4f", candidate.Score),
			"match_type": candidate.MatchType,
		})
	}
	// Make sure every candidate has a writable metadata map.
	for _, candidate := range candidates {
		candidate.Metadata = ensureMetadata(candidate.Metadata)
	}

	reranked := make([]*types.SearchResult, 0, len(ranked))
	for _, item := range ranked {
		hit := candidates[item.Index]
		baseScore, modelScore := hit.Score, item.RelevanceScore
		hit.Metadata["base_score"] = fmt.Sprintf("%.4f", baseScore)
		hit.Score = compositeScore(hit, modelScore, baseScore)
		pipelineInfo(ctx, "Rerank", "composite_calc", map[string]any{
			"chunk_id":    hit.ID,
			"base_score":  fmt.Sprintf("%.4f", baseScore),
			"model_score": fmt.Sprintf("%.4f", modelScore),
			"final_score": fmt.Sprintf("%.4f", hit.Score),
			"match_type":  hit.MatchType,
		})
		reranked = append(reranked, hit)
	}

	// Keep at least one result even when RerankTopK is zero or negative.
	topK := min(len(reranked), max(1, chatManage.RerankTopK))
	chatManage.RerankResult = applyMMR(ctx, reranked, chatManage, topK, 0.7)

	// Log the first few composite scores (in model order, not score order).
	for rank, hit := range reranked[:min(3, len(reranked))] {
		pipelineInfo(ctx, "Rerank", "composite_top", map[string]any{
			"rank":        rank + 1,
			"chunk_id":    hit.ID,
			"base_score":  hit.Metadata["base_score"],
			"final_score": fmt.Sprintf("%.4f", hit.Score),
		})
	}

	selectedCnt := len(chatManage.RerankResult)
	if selectedCnt == 0 {
		pipelineWarn(ctx, "Rerank", "output", map[string]any{
			"filtered_cnt": 0,
		})
		return ErrSearchNothing
	}
	pipelineInfo(ctx, "Rerank", "output", map[string]any{
		"filtered_cnt": selectedCnt,
	})
	return next()
}
// rerank scores passages against query with rerankModel and returns the
// results whose relevance score exceeds the active threshold.
//
// Parameters:
//   - chatManage: supplies RerankThreshold and the SearchResult slice that the
//     model's result indices refer to.
//   - rerankModel: the model used for scoring.
//   - query, passages: forwarded unchanged to rerankModel.Rerank.
//
// It returns nil when the model call fails (the error is logged, not
// propagated) and an empty slice when nothing passes the threshold.
//
// Results whose Index does not address an entry of chatManage.SearchResult are
// logged and dropped instead of causing an index-out-of-range panic.
func (p *PluginRerank) rerank(ctx context.Context,
	chatManage *types.ChatManage, rerankModel rerank.Reranker, query string, passages []string,
) []rerank.RankResult {
	pipelineInfo(ctx, "Rerank", "model_call", map[string]any{
		"query_variant": query,
		"passages":      len(passages),
	})
	rerankResp, err := rerankModel.Rerank(ctx, query, passages)
	if err != nil {
		pipelineError(ctx, "Rerank", "model_call", map[string]any{
			"query_variant": query,
			"error":         err.Error(),
		})
		return nil
	}

	pipelineInfo(ctx, "Rerank", "threshold", map[string]any{
		"threshold": chatManage.RerankThreshold,
	})

	candidates := chatManage.SearchResult
	// The indices come from an external model, so validate before indexing.
	inRange := func(idx int) bool { return idx >= 0 && idx < len(candidates) }

	// Log the first few scores for debugging.
	for i := range min(5, len(rerankResp)) {
		idx := rerankResp[i].Index
		if !inRange(idx) {
			continue // reported once in the filter loop below
		}
		hit := candidates[idx]
		pipelineInfo(ctx, "Rerank", "top_score", map[string]any{
			"rank":       i + 1,
			"score":      rerankResp[i].RelevanceScore,
			"chunk_id":   hit.ID,
			"match_type": hit.MatchType,
			"chunk_type": hit.ChunkType,
			"content":    hit.Content,
		})
	}

	// Keep only results above the threshold; history matches get a relaxed one.
	rankFilter := make([]rerank.RankResult, 0, len(rerankResp))
	for _, result := range rerankResp {
		if !inRange(result.Index) {
			pipelineWarn(ctx, "Rerank", "invalid_index", map[string]any{
				"index":         result.Index,
				"candidate_cnt": len(candidates),
			})
			continue
		}
		th := chatManage.RerankThreshold
		if candidates[result.Index].MatchType == types.MatchTypeHistory {
			// Lower the bar by 0.1 but not below 0.5. The outer Min keeps the
			// relaxed value from ever exceeding the configured threshold:
			// math.Max alone raised any threshold below 0.5 up to 0.5, making
			// history matches harder to keep instead of easier.
			th = math.Min(th, math.Max(th-0.1, 0.5))
		}
		if result.RelevanceScore > th {
			rankFilter = append(rankFilter, result)
		}
	}
	return rankFilter
}
// ensureMetadata returns m unchanged when it is non-nil, and a fresh empty
// map otherwise, so callers can always write to the result.
func ensureMetadata(m map[string]string) map[string]string {
	if m != nil {
		return m
	}
	return map[string]string{}
}
// compositeScore blends the rerank model score with the retrieval score and a
// source weight, applies a position prior, and limits the result to [0, 1].
//
// The blend is 0.6*modelScore + 0.3*baseScore + 0.1*sourceWeight, where
// sourceWeight is 0.95 when the lower-cased sr.KnowledgeSource equals
// "web_search" and 1.0 otherwise. When sr.StartAt is non-negative the blend is
// multiplied by 1 + ClampFloat(1 - StartAt/(EndAt+1), -0.05, 0.05).
//
// NOTE(review): assuming ClampFloat clamps into [-0.05, 0.05], any chunk with
// 0 <= StartAt <= EndAt produces a term in (0, 1] and therefore nearly always
// receives the same +0.05 boost - confirm this is the intended prior.
func compositeScore(sr *types.SearchResult, modelScore, baseScore float64) float64 {
	sourceWeight := 1.0
	if strings.ToLower(sr.KnowledgeSource) == "web_search" {
		sourceWeight = 0.95
	}

	blended := 0.6*modelScore + 0.3*baseScore + 0.1*sourceWeight
	if sr.StartAt >= 0 {
		offset := 1.0 - float64(sr.StartAt)/float64(sr.EndAt+1)
		blended *= 1.0 + searchutil.ClampFloat(offset, -0.05, 0.05)
	}

	switch {
	case blended < 0:
		return 0
	case blended > 1:
		return 1
	}
	return blended
}
// applyMMR greedily selects up to k entries from results, trading relevance
// against redundancy with already selected entries.
//
// In each round the unselected candidate with the highest
// lambda*Score - (1-lambda)*redundancy is picked, where redundancy is the
// largest Jaccard similarity between the candidate's token set and the token
// set of any entry picked so far. Token sets are computed once up front from
// each result's enriched passage. A candidate is only picked when its value is
// greater than -1; selection stops early if no candidate qualifies.
//
// It returns nil when k <= 0 or results is empty; otherwise the picked entries
// in selection order. chatManage is accepted but not used here.
func applyMMR(
	ctx context.Context,
	results []*types.SearchResult,
	chatManage *types.ChatManage,
	k int,
	lambda float64,
) []*types.SearchResult {
	if len(results) == 0 || k <= 0 {
		return nil
	}
	pipelineInfo(ctx, "Rerank", "mmr_start", map[string]any{
		"lambda":     lambda,
		"k":          k,
		"candidates": len(results),
	})

	// Tokenize every candidate once instead of once per comparison.
	tokenSets := make([]map[string]struct{}, 0, len(results))
	for _, candidate := range results {
		tokenSets = append(tokenSets, searchutil.TokenizeSimple(getEnrichedPassage(ctx, candidate)))
	}

	limit := min(k, len(results))
	picked := make([]*types.SearchResult, 0, k)
	pickedTokens := make([]map[string]struct{}, 0, k)
	taken := make([]bool, len(results))

	for len(picked) < limit {
		winner, winnerScore := -1, -1.0
		for i, candidate := range results {
			if taken[i] {
				continue
			}
			// Redundancy is the closest match among what is already picked.
			overlap := 0.0
			for _, tokens := range pickedTokens {
				if sim := searchutil.Jaccard(tokenSets[i], tokens); sim > overlap {
					overlap = sim
				}
			}
			if score := lambda*candidate.Score - (1.0-lambda)*overlap; score > winnerScore {
				winner, winnerScore = i, score
			}
		}
		if winner < 0 {
			break
		}
		taken[winner] = true
		picked = append(picked, results[winner])
		pickedTokens = append(pickedTokens, tokenSets[winner])
	}

	// Mean pairwise similarity of the picked set, reported for diagnostics.
	avgRed := 0.0
	pairCnt := 0
	for i, left := range pickedTokens {
		for _, right := range pickedTokens[i+1:] {
			avgRed += searchutil.Jaccard(left, right)
			pairCnt++
		}
	}
	if pairCnt > 0 {
		avgRed /= float64(pairCnt)
	}

	pipelineInfo(ctx, "Rerank", "mmr_done", map[string]any{
		"selected":       len(picked),
		"avg_redundancy": fmt.Sprintf("%.4f", avgRed),
	})
	return picked
}
// getEnrichedPassage returns result.Content extended with text derived from
// the chunk's image info and generated questions.
//
// result.ImageInfo, when non-empty, is decoded as a JSON array of
// types.ImageInfo; each non-empty Caption and OCRText contributes one line.
// result.ChunkMetadata, when non-empty, is decoded as
// types.DocumentChunkMetadata; its GeneratedQuestions contribute one line,
// joined with "; ". A decode failure is logged as a warning and that source is
// skipped.
//
// When no enrichment is produced the content is returned unchanged. Otherwise
// the enrichment lines are joined with newlines and appended to the content,
// separated from it by a blank line if the content is non-empty.
func getEnrichedPassage(ctx context.Context, result *types.SearchResult) string {
	var enrichments []string

	// Collect caption and OCR text of every image attached to the chunk.
	if result.ImageInfo != "" {
		var imageInfos []types.ImageInfo
		if err := json.Unmarshal([]byte(result.ImageInfo), &imageInfos); err != nil {
			pipelineWarn(ctx, "Rerank", "image_info_parse", map[string]any{
				"error": err.Error(),
			})
		} else {
			for _, img := range imageInfos {
				if img.Caption != "" {
					enrichments = append(enrichments, fmt.Sprintf("图片描述: %s", img.Caption))
				}
				if img.OCRText != "" {
					enrichments = append(enrichments, fmt.Sprintf("图片文本: %s", img.OCRText))
				}
			}
		}
	}

	// Collect the generated questions stored in the chunk metadata.
	if len(result.ChunkMetadata) > 0 {
		var docMeta types.DocumentChunkMetadata
		if err := json.Unmarshal(result.ChunkMetadata, &docMeta); err != nil {
			pipelineWarn(ctx, "Rerank", "chunk_metadata_parse", map[string]any{
				"error": err.Error(),
			})
		} else if len(docMeta.GeneratedQuestions) > 0 {
			enrichments = append(enrichments, fmt.Sprintf("相关问题: %s", strings.Join(docMeta.GeneratedQuestions, "; ")))
		}
	}

	if len(enrichments) == 0 {
		return result.Content
	}

	// Join once and reuse the text for both the passage and the log entry.
	enrichment := strings.Join(enrichments, "\n")
	combinedText := result.Content
	if combinedText != "" {
		combinedText += "\n\n"
	}
	combinedText += enrichment

	pipelineInfo(ctx, "Rerank", "passage_enrich", map[string]any{
		"content_len":    len(result.Content), // byte length
		"enrichment":     enrichment,
		"enrichment_len": len(enrichment), // byte length
	})
	return combinedText
}