package chatpipline import ( "context" "encoding/json" "fmt" "math" "strings" "github.com/Tencent/WeKnora/internal/models/rerank" "github.com/Tencent/WeKnora/internal/searchutil" "github.com/Tencent/WeKnora/internal/types" "github.com/Tencent/WeKnora/internal/types/interfaces" ) // PluginRerank implements reranking functionality for chat pipeline type PluginRerank struct { modelService interfaces.ModelService // Service to access rerank models } // NewPluginRerank creates a new rerank plugin instance func NewPluginRerank(eventManager *EventManager, modelService interfaces.ModelService) *PluginRerank { res := &PluginRerank{ modelService: modelService, } eventManager.Register(res) return res } // ActivationEvents returns the event types this plugin handles func (p *PluginRerank) ActivationEvents() []types.EventType { return []types.EventType{types.CHUNK_RERANK} } // OnEvent handles reranking events in the chat pipeline func (p *PluginRerank) OnEvent(ctx context.Context, eventType types.EventType, chatManage *types.ChatManage, next func() *PluginError, ) *PluginError { pipelineInfo(ctx, "Rerank", "input", map[string]interface{}{ "session_id": chatManage.SessionID, "candidate_cnt": len(chatManage.SearchResult), "rerank_model": chatManage.RerankModelID, "rerank_thresh": chatManage.RerankThreshold, "rewrite_query": chatManage.RewriteQuery, }) if len(chatManage.SearchResult) == 0 { pipelineInfo(ctx, "Rerank", "skip", map[string]interface{}{ "reason": "empty_search_result", }) return next() } if chatManage.RerankModelID == "" { pipelineWarn(ctx, "Rerank", "skip", map[string]interface{}{ "reason": "empty_model_id", }) return next() } // Get rerank model from service rerankModel, err := p.modelService.GetRerankModel(ctx, chatManage.RerankModelID) if err != nil { pipelineError(ctx, "Rerank", "get_model", map[string]interface{}{ "model_id": chatManage.RerankModelID, "error": err.Error(), }) return ErrGetRerankModel.WithError(err) } // Prepare passages for reranking pipelineInfo(ctx, "Rerank", "build_passages", map[string]interface{}{ "candidate_cnt": len(chatManage.SearchResult), }) var passages []string for _, result := range chatManage.SearchResult { // 合并Content和ImageInfo的文本内容 passage := getEnrichedPassage(ctx, result) passages = append(passages, passage) } // Single rerank call with RewriteQuery, use threshold degradation if no results originalThreshold := chatManage.RerankThreshold rerankResp := p.rerank(ctx, chatManage, rerankModel, chatManage.RewriteQuery, passages) // If no results and threshold is high enough, try with lower threshold if len(rerankResp) == 0 && originalThreshold > 0.3 { degradedThreshold := originalThreshold * 0.7 if degradedThreshold < 0.3 { degradedThreshold = 0.3 } pipelineInfo(ctx, "Rerank", "threshold_degrade", map[string]interface{}{ "original": originalThreshold, "degraded": degradedThreshold, }) chatManage.RerankThreshold = degradedThreshold rerankResp = p.rerank(ctx, chatManage, rerankModel, chatManage.RewriteQuery, passages) // Restore original threshold chatManage.RerankThreshold = originalThreshold } pipelineInfo(ctx, "Rerank", "model_response", map[string]interface{}{ "result_cnt": len(rerankResp), }) // Log input scores before reranking for debugging for i, sr := range chatManage.SearchResult { pipelineInfo(ctx, "Rerank", "input_score", map[string]interface{}{ "index": i, "chunk_id": sr.ID, "score": fmt.Sprintf("%.4f", sr.Score), "match_type": sr.MatchType, }) } for i := range chatManage.SearchResult { chatManage.SearchResult[i].Metadata = ensureMetadata(chatManage.SearchResult[i].Metadata) } reranked := make([]*types.SearchResult, 0, len(rerankResp)) for _, rr := range rerankResp { sr := chatManage.SearchResult[rr.Index] base := sr.Score sr.Metadata["base_score"] = fmt.Sprintf("%.4f", base) modelScore := rr.RelevanceScore sr.Score = compositeScore(sr, modelScore, base) pipelineInfo(ctx, "Rerank", "composite_calc", map[string]interface{}{ "chunk_id": sr.ID, "base_score": fmt.Sprintf("%.4f", base), "model_score": fmt.Sprintf("%.4f", modelScore), "final_score": fmt.Sprintf("%.4f", sr.Score), "match_type": sr.MatchType, }) reranked = append(reranked, sr) } final := applyMMR(ctx, reranked, chatManage, min(len(reranked), max(1, chatManage.RerankTopK)), 0.7) chatManage.RerankResult = final // Log composite top scores and MMR selection summary topN := min(3, len(reranked)) for i := 0; i < topN; i++ { pipelineInfo(ctx, "Rerank", "composite_top", map[string]interface{}{ "rank": i + 1, "chunk_id": reranked[i].ID, "base_score": reranked[i].Metadata["base_score"], "final_score": fmt.Sprintf("%.4f", reranked[i].Score), }) } if len(chatManage.RerankResult) == 0 { pipelineWarn(ctx, "Rerank", "output", map[string]interface{}{ "filtered_cnt": 0, }) return ErrSearchNothing } pipelineInfo(ctx, "Rerank", "output", map[string]interface{}{ "filtered_cnt": len(chatManage.RerankResult), }) return next() } // rerank performs the actual reranking operation with given query and passages func (p *PluginRerank) rerank(ctx context.Context, chatManage *types.ChatManage, rerankModel rerank.Reranker, query string, passages []string, ) []rerank.RankResult { pipelineInfo(ctx, "Rerank", "model_call", map[string]interface{}{ "query_variant": query, "passages": len(passages), }) rerankResp, err := rerankModel.Rerank(ctx, query, passages) if err != nil { pipelineError(ctx, "Rerank", "model_call", map[string]interface{}{ "query_variant": query, "error": err.Error(), }) return nil } // Log top scores for debugging pipelineInfo(ctx, "Rerank", "threshold", map[string]interface{}{ "threshold": chatManage.RerankThreshold, }) for i := range min(5, len(rerankResp)) { pipelineInfo(ctx, "Rerank", "top_score", map[string]interface{}{ "rank": i + 1, "score": rerankResp[i].RelevanceScore, "chunk_id": chatManage.SearchResult[rerankResp[i].Index].ID, "match_type": chatManage.SearchResult[rerankResp[i].Index].MatchType, "chunk_type": chatManage.SearchResult[rerankResp[i].Index].ChunkType, "content": chatManage.SearchResult[rerankResp[i].Index].Content, }) } // Filter results based on threshold with special handling for history matches rankFilter := []rerank.RankResult{} for _, result := range rerankResp { th := chatManage.RerankThreshold matchType := chatManage.SearchResult[result.Index].MatchType if matchType == types.MatchTypeHistory { th = math.Max(th-0.1, 0.5) // Lower threshold for history matches } if result.RelevanceScore > th { rankFilter = append(rankFilter, result) } } return rankFilter } // ensureMetadata ensures the metadata is not nil func ensureMetadata(m map[string]string) map[string]string { if m == nil { return make(map[string]string) } return m } // compositeScore calculates the composite score for a search result func compositeScore(sr *types.SearchResult, modelScore, baseScore float64) float64 { sourceWeight := 1.0 switch strings.ToLower(sr.KnowledgeSource) { case "web_search": sourceWeight = 0.95 default: sourceWeight = 1.0 } positionPrior := 1.0 if sr.StartAt >= 0 { positionPrior += searchutil.ClampFloat(1.0-float64(sr.StartAt)/float64(sr.EndAt+1), -0.05, 0.05) } composite := 0.6*modelScore + 0.3*baseScore + 0.1*sourceWeight composite *= positionPrior if composite < 0 { composite = 0 } if composite > 1 { composite = 1 } return composite } // applyMMR applies the MMR algorithm to the search results with pre-computed token sets func applyMMR( ctx context.Context, results []*types.SearchResult, chatManage *types.ChatManage, k int, lambda float64, ) []*types.SearchResult { if k <= 0 || len(results) == 0 { return nil } pipelineInfo(ctx, "Rerank", "mmr_start", map[string]interface{}{ "lambda": lambda, "k": k, "candidates": len(results), }) // Pre-compute all token sets upfront (optimization) allTokenSets := make([]map[string]struct{}, len(results)) for i, r := range results { allTokenSets[i] = searchutil.TokenizeSimple(getEnrichedPassage(ctx, r)) } selected := make([]*types.SearchResult, 0, k) selectedTokenSets := make([]map[string]struct{}, 0, k) selectedIndices := make(map[int]struct{}) for len(selected) < k && len(selectedIndices) < len(results) { bestIdx := -1 bestScore := -1.0 for i, r := range results { if _, isSelected := selectedIndices[i]; isSelected { continue } relevance := r.Score redundancy := 0.0 // Use pre-computed token sets for redundancy calculation for _, selTokens := range selectedTokenSets { sim := searchutil.Jaccard(allTokenSets[i], selTokens) if sim > redundancy { redundancy = sim } } mmr := lambda*relevance - (1.0-lambda)*redundancy if mmr > bestScore { bestScore = mmr bestIdx = i } } if bestIdx < 0 { break } selected = append(selected, results[bestIdx]) selectedTokenSets = append(selectedTokenSets, allTokenSets[bestIdx]) selectedIndices[bestIdx] = struct{}{} } // Compute average redundancy among selected using pre-computed token sets avgRed := 0.0 if len(selected) > 1 { pairs := 0 for i := 0; i < len(selectedTokenSets); i++ { for j := i + 1; j < len(selectedTokenSets); j++ { avgRed += searchutil.Jaccard(selectedTokenSets[i], selectedTokenSets[j]) pairs++ } } if pairs > 0 { avgRed /= float64(pairs) } } pipelineInfo(ctx, "Rerank", "mmr_done", map[string]interface{}{ "selected": len(selected), "avg_redundancy": fmt.Sprintf("%.4f", avgRed), }) return selected } // getEnrichedPassage 合并Content、ImageInfo和GeneratedQuestions的文本内容 func getEnrichedPassage(ctx context.Context, result *types.SearchResult) string { combinedText := result.Content var enrichments []string // 解析ImageInfo if result.ImageInfo != "" { var imageInfos []types.ImageInfo err := json.Unmarshal([]byte(result.ImageInfo), &imageInfos) if err != nil { pipelineWarn(ctx, "Rerank", "image_info_parse", map[string]interface{}{ "error": err.Error(), }) } else { // 提取所有图片的描述和OCR文本 for _, img := range imageInfos { if img.Caption != "" { enrichments = append(enrichments, fmt.Sprintf("图片描述: %s", img.Caption)) } if img.OCRText != "" { enrichments = append(enrichments, fmt.Sprintf("图片文本: %s", img.OCRText)) } } } } // 解析ChunkMetadata中的GeneratedQuestions if len(result.ChunkMetadata) > 0 { var docMeta types.DocumentChunkMetadata err := json.Unmarshal(result.ChunkMetadata, &docMeta) if err != nil { pipelineWarn(ctx, "Rerank", "chunk_metadata_parse", map[string]interface{}{ "error": err.Error(), }) } else if len(docMeta.GeneratedQuestions) > 0 { enrichments = append(enrichments, fmt.Sprintf("相关问题: %s", strings.Join(docMeta.GeneratedQuestions, "; "))) } } if len(enrichments) == 0 { return combinedText } // 组合内容和增强信息 if combinedText != "" { combinedText += "\n\n" } combinedText += strings.Join(enrichments, "\n") pipelineInfo(ctx, "Rerank", "passage_enrich", map[string]interface{}{ "content_len": len(result.Content), "enrichment": strings.Join(enrichments, "\n"), "enrichment_len": len(strings.Join(enrichments, "\n")), }) return combinedText }