mirror of
https://github.com/Tencent/WeKnora.git
synced 2026-06-04 13:30:32 +08:00
380 lines
11 KiB
Go
380 lines
11 KiB
Go
package chatpipline
|
|
|
|
import (
|
|
"context"
|
|
"encoding/json"
|
|
"fmt"
|
|
"math"
|
|
"strings"
|
|
|
|
"github.com/Tencent/WeKnora/internal/models/rerank"
|
|
"github.com/Tencent/WeKnora/internal/searchutil"
|
|
"github.com/Tencent/WeKnora/internal/types"
|
|
"github.com/Tencent/WeKnora/internal/types/interfaces"
|
|
)
|
|
|
|
// PluginRerank implements reranking functionality for chat pipeline
|
|
type PluginRerank struct {
|
|
modelService interfaces.ModelService // Service to access rerank models
|
|
}
|
|
|
|
// NewPluginRerank creates a new rerank plugin instance
|
|
func NewPluginRerank(eventManager *EventManager, modelService interfaces.ModelService) *PluginRerank {
|
|
res := &PluginRerank{
|
|
modelService: modelService,
|
|
}
|
|
eventManager.Register(res)
|
|
return res
|
|
}
|
|
|
|
// ActivationEvents returns the event types this plugin handles
|
|
func (p *PluginRerank) ActivationEvents() []types.EventType {
|
|
return []types.EventType{types.CHUNK_RERANK}
|
|
}
|
|
|
|
// OnEvent handles reranking events in the chat pipeline
|
|
func (p *PluginRerank) OnEvent(ctx context.Context,
|
|
eventType types.EventType, chatManage *types.ChatManage, next func() *PluginError,
|
|
) *PluginError {
|
|
pipelineInfo(ctx, "Rerank", "input", map[string]interface{}{
|
|
"session_id": chatManage.SessionID,
|
|
"candidate_cnt": len(chatManage.SearchResult),
|
|
"rerank_model": chatManage.RerankModelID,
|
|
"rerank_thresh": chatManage.RerankThreshold,
|
|
"rewrite_query": chatManage.RewriteQuery,
|
|
})
|
|
if len(chatManage.SearchResult) == 0 {
|
|
pipelineInfo(ctx, "Rerank", "skip", map[string]interface{}{
|
|
"reason": "empty_search_result",
|
|
})
|
|
return next()
|
|
}
|
|
if chatManage.RerankModelID == "" {
|
|
pipelineWarn(ctx, "Rerank", "skip", map[string]interface{}{
|
|
"reason": "empty_model_id",
|
|
})
|
|
return next()
|
|
}
|
|
|
|
// Get rerank model from service
|
|
rerankModel, err := p.modelService.GetRerankModel(ctx, chatManage.RerankModelID)
|
|
if err != nil {
|
|
pipelineError(ctx, "Rerank", "get_model", map[string]interface{}{
|
|
"model_id": chatManage.RerankModelID,
|
|
"error": err.Error(),
|
|
})
|
|
return ErrGetRerankModel.WithError(err)
|
|
}
|
|
|
|
// Prepare passages for reranking
|
|
pipelineInfo(ctx, "Rerank", "build_passages", map[string]interface{}{
|
|
"candidate_cnt": len(chatManage.SearchResult),
|
|
})
|
|
var passages []string
|
|
for _, result := range chatManage.SearchResult {
|
|
// 合并Content和ImageInfo的文本内容
|
|
passage := getEnrichedPassage(ctx, result)
|
|
passages = append(passages, passage)
|
|
}
|
|
|
|
// Single rerank call with RewriteQuery, use threshold degradation if no results
|
|
originalThreshold := chatManage.RerankThreshold
|
|
rerankResp := p.rerank(ctx, chatManage, rerankModel, chatManage.RewriteQuery, passages)
|
|
|
|
// If no results and threshold is high enough, try with lower threshold
|
|
if len(rerankResp) == 0 && originalThreshold > 0.3 {
|
|
degradedThreshold := originalThreshold * 0.7
|
|
if degradedThreshold < 0.3 {
|
|
degradedThreshold = 0.3
|
|
}
|
|
pipelineInfo(ctx, "Rerank", "threshold_degrade", map[string]interface{}{
|
|
"original": originalThreshold,
|
|
"degraded": degradedThreshold,
|
|
})
|
|
chatManage.RerankThreshold = degradedThreshold
|
|
rerankResp = p.rerank(ctx, chatManage, rerankModel, chatManage.RewriteQuery, passages)
|
|
// Restore original threshold
|
|
chatManage.RerankThreshold = originalThreshold
|
|
}
|
|
|
|
pipelineInfo(ctx, "Rerank", "model_response", map[string]interface{}{
|
|
"result_cnt": len(rerankResp),
|
|
})
|
|
|
|
// Log input scores before reranking for debugging
|
|
for i, sr := range chatManage.SearchResult {
|
|
pipelineInfo(ctx, "Rerank", "input_score", map[string]interface{}{
|
|
"index": i,
|
|
"chunk_id": sr.ID,
|
|
"score": fmt.Sprintf("%.4f", sr.Score),
|
|
"match_type": sr.MatchType,
|
|
})
|
|
}
|
|
|
|
for i := range chatManage.SearchResult {
|
|
chatManage.SearchResult[i].Metadata = ensureMetadata(chatManage.SearchResult[i].Metadata)
|
|
}
|
|
reranked := make([]*types.SearchResult, 0, len(rerankResp))
|
|
for _, rr := range rerankResp {
|
|
sr := chatManage.SearchResult[rr.Index]
|
|
base := sr.Score
|
|
sr.Metadata["base_score"] = fmt.Sprintf("%.4f", base)
|
|
modelScore := rr.RelevanceScore
|
|
sr.Score = compositeScore(sr, modelScore, base)
|
|
pipelineInfo(ctx, "Rerank", "composite_calc", map[string]interface{}{
|
|
"chunk_id": sr.ID,
|
|
"base_score": fmt.Sprintf("%.4f", base),
|
|
"model_score": fmt.Sprintf("%.4f", modelScore),
|
|
"final_score": fmt.Sprintf("%.4f", sr.Score),
|
|
"match_type": sr.MatchType,
|
|
})
|
|
reranked = append(reranked, sr)
|
|
}
|
|
final := applyMMR(ctx, reranked, chatManage, min(len(reranked), max(1, chatManage.RerankTopK)), 0.7)
|
|
chatManage.RerankResult = final
|
|
|
|
// Log composite top scores and MMR selection summary
|
|
topN := min(3, len(reranked))
|
|
for i := 0; i < topN; i++ {
|
|
pipelineInfo(ctx, "Rerank", "composite_top", map[string]interface{}{
|
|
"rank": i + 1,
|
|
"chunk_id": reranked[i].ID,
|
|
"base_score": reranked[i].Metadata["base_score"],
|
|
"final_score": fmt.Sprintf("%.4f", reranked[i].Score),
|
|
})
|
|
}
|
|
|
|
if len(chatManage.RerankResult) == 0 {
|
|
pipelineWarn(ctx, "Rerank", "output", map[string]interface{}{
|
|
"filtered_cnt": 0,
|
|
})
|
|
return ErrSearchNothing
|
|
}
|
|
|
|
pipelineInfo(ctx, "Rerank", "output", map[string]interface{}{
|
|
"filtered_cnt": len(chatManage.RerankResult),
|
|
})
|
|
return next()
|
|
}
|
|
|
|
// rerank performs the actual reranking operation with given query and passages
|
|
func (p *PluginRerank) rerank(ctx context.Context,
|
|
chatManage *types.ChatManage, rerankModel rerank.Reranker, query string, passages []string,
|
|
) []rerank.RankResult {
|
|
pipelineInfo(ctx, "Rerank", "model_call", map[string]interface{}{
|
|
"query_variant": query,
|
|
"passages": len(passages),
|
|
})
|
|
rerankResp, err := rerankModel.Rerank(ctx, query, passages)
|
|
if err != nil {
|
|
pipelineError(ctx, "Rerank", "model_call", map[string]interface{}{
|
|
"query_variant": query,
|
|
"error": err.Error(),
|
|
})
|
|
return nil
|
|
}
|
|
|
|
// Log top scores for debugging
|
|
pipelineInfo(ctx, "Rerank", "threshold", map[string]interface{}{
|
|
"threshold": chatManage.RerankThreshold,
|
|
})
|
|
for i := range min(5, len(rerankResp)) {
|
|
pipelineInfo(ctx, "Rerank", "top_score", map[string]interface{}{
|
|
"rank": i + 1,
|
|
"score": rerankResp[i].RelevanceScore,
|
|
"chunk_id": chatManage.SearchResult[rerankResp[i].Index].ID,
|
|
"match_type": chatManage.SearchResult[rerankResp[i].Index].MatchType,
|
|
"chunk_type": chatManage.SearchResult[rerankResp[i].Index].ChunkType,
|
|
"content": chatManage.SearchResult[rerankResp[i].Index].Content,
|
|
})
|
|
}
|
|
|
|
// Filter results based on threshold with special handling for history matches
|
|
rankFilter := []rerank.RankResult{}
|
|
for _, result := range rerankResp {
|
|
th := chatManage.RerankThreshold
|
|
matchType := chatManage.SearchResult[result.Index].MatchType
|
|
if matchType == types.MatchTypeHistory {
|
|
th = math.Max(th-0.1, 0.5) // Lower threshold for history matches
|
|
}
|
|
if result.RelevanceScore > th {
|
|
rankFilter = append(rankFilter, result)
|
|
}
|
|
}
|
|
return rankFilter
|
|
}
|
|
|
|
// ensureMetadata returns m itself when it is non-nil, and a fresh empty map
// otherwise, so the caller can always write to the result.
func ensureMetadata(m map[string]string) map[string]string {
	if m != nil {
		return m
	}
	return map[string]string{}
}
|
|
|
|
// compositeScore calculates the composite score for a search result
|
|
func compositeScore(sr *types.SearchResult, modelScore, baseScore float64) float64 {
|
|
sourceWeight := 1.0
|
|
switch strings.ToLower(sr.KnowledgeSource) {
|
|
case "web_search":
|
|
sourceWeight = 0.95
|
|
default:
|
|
sourceWeight = 1.0
|
|
}
|
|
positionPrior := 1.0
|
|
if sr.StartAt >= 0 {
|
|
positionPrior += searchutil.ClampFloat(1.0-float64(sr.StartAt)/float64(sr.EndAt+1), -0.05, 0.05)
|
|
}
|
|
composite := 0.6*modelScore + 0.3*baseScore + 0.1*sourceWeight
|
|
composite *= positionPrior
|
|
if composite < 0 {
|
|
composite = 0
|
|
}
|
|
if composite > 1 {
|
|
composite = 1
|
|
}
|
|
return composite
|
|
}
|
|
|
|
// applyMMR applies the MMR algorithm to the search results with pre-computed token sets
|
|
func applyMMR(
|
|
ctx context.Context,
|
|
results []*types.SearchResult,
|
|
chatManage *types.ChatManage,
|
|
k int,
|
|
lambda float64,
|
|
) []*types.SearchResult {
|
|
if k <= 0 || len(results) == 0 {
|
|
return nil
|
|
}
|
|
pipelineInfo(ctx, "Rerank", "mmr_start", map[string]interface{}{
|
|
"lambda": lambda,
|
|
"k": k,
|
|
"candidates": len(results),
|
|
})
|
|
|
|
// Pre-compute all token sets upfront (optimization)
|
|
allTokenSets := make([]map[string]struct{}, len(results))
|
|
for i, r := range results {
|
|
allTokenSets[i] = searchutil.TokenizeSimple(getEnrichedPassage(ctx, r))
|
|
}
|
|
|
|
selected := make([]*types.SearchResult, 0, k)
|
|
selectedTokenSets := make([]map[string]struct{}, 0, k)
|
|
selectedIndices := make(map[int]struct{})
|
|
|
|
for len(selected) < k && len(selectedIndices) < len(results) {
|
|
bestIdx := -1
|
|
bestScore := -1.0
|
|
|
|
for i, r := range results {
|
|
if _, isSelected := selectedIndices[i]; isSelected {
|
|
continue
|
|
}
|
|
|
|
relevance := r.Score
|
|
redundancy := 0.0
|
|
|
|
// Use pre-computed token sets for redundancy calculation
|
|
for _, selTokens := range selectedTokenSets {
|
|
sim := searchutil.Jaccard(allTokenSets[i], selTokens)
|
|
if sim > redundancy {
|
|
redundancy = sim
|
|
}
|
|
}
|
|
|
|
mmr := lambda*relevance - (1.0-lambda)*redundancy
|
|
if mmr > bestScore {
|
|
bestScore = mmr
|
|
bestIdx = i
|
|
}
|
|
}
|
|
|
|
if bestIdx < 0 {
|
|
break
|
|
}
|
|
|
|
selected = append(selected, results[bestIdx])
|
|
selectedTokenSets = append(selectedTokenSets, allTokenSets[bestIdx])
|
|
selectedIndices[bestIdx] = struct{}{}
|
|
}
|
|
|
|
// Compute average redundancy among selected using pre-computed token sets
|
|
avgRed := 0.0
|
|
if len(selected) > 1 {
|
|
pairs := 0
|
|
for i := 0; i < len(selectedTokenSets); i++ {
|
|
for j := i + 1; j < len(selectedTokenSets); j++ {
|
|
avgRed += searchutil.Jaccard(selectedTokenSets[i], selectedTokenSets[j])
|
|
pairs++
|
|
}
|
|
}
|
|
if pairs > 0 {
|
|
avgRed /= float64(pairs)
|
|
}
|
|
}
|
|
pipelineInfo(ctx, "Rerank", "mmr_done", map[string]interface{}{
|
|
"selected": len(selected),
|
|
"avg_redundancy": fmt.Sprintf("%.4f", avgRed),
|
|
})
|
|
return selected
|
|
}
|
|
|
|
// getEnrichedPassage 合并Content、ImageInfo和GeneratedQuestions的文本内容
|
|
func getEnrichedPassage(ctx context.Context, result *types.SearchResult) string {
|
|
combinedText := result.Content
|
|
var enrichments []string
|
|
|
|
// 解析ImageInfo
|
|
if result.ImageInfo != "" {
|
|
var imageInfos []types.ImageInfo
|
|
err := json.Unmarshal([]byte(result.ImageInfo), &imageInfos)
|
|
if err != nil {
|
|
pipelineWarn(ctx, "Rerank", "image_info_parse", map[string]interface{}{
|
|
"error": err.Error(),
|
|
})
|
|
} else {
|
|
// 提取所有图片的描述和OCR文本
|
|
for _, img := range imageInfos {
|
|
if img.Caption != "" {
|
|
enrichments = append(enrichments, fmt.Sprintf("图片描述: %s", img.Caption))
|
|
}
|
|
if img.OCRText != "" {
|
|
enrichments = append(enrichments, fmt.Sprintf("图片文本: %s", img.OCRText))
|
|
}
|
|
}
|
|
}
|
|
}
|
|
|
|
// 解析ChunkMetadata中的GeneratedQuestions
|
|
if len(result.ChunkMetadata) > 0 {
|
|
var docMeta types.DocumentChunkMetadata
|
|
err := json.Unmarshal(result.ChunkMetadata, &docMeta)
|
|
if err != nil {
|
|
pipelineWarn(ctx, "Rerank", "chunk_metadata_parse", map[string]interface{}{
|
|
"error": err.Error(),
|
|
})
|
|
} else if len(docMeta.GeneratedQuestions) > 0 {
|
|
enrichments = append(enrichments, fmt.Sprintf("相关问题: %s", strings.Join(docMeta.GeneratedQuestions, "; ")))
|
|
}
|
|
}
|
|
|
|
if len(enrichments) == 0 {
|
|
return combinedText
|
|
}
|
|
|
|
// 组合内容和增强信息
|
|
if combinedText != "" {
|
|
combinedText += "\n\n"
|
|
}
|
|
combinedText += strings.Join(enrichments, "\n")
|
|
|
|
pipelineInfo(ctx, "Rerank", "passage_enrich", map[string]interface{}{
|
|
"content_len": len(result.Content),
|
|
"enrichment": strings.Join(enrichments, "\n"),
|
|
"enrichment_len": len(strings.Join(enrichments, "\n")),
|
|
})
|
|
|
|
return combinedText
|
|
}
|