feat: Implement parallel search functionality combining chunk and entity searches, enhancing retrieval efficiency and result accuracy

This commit is contained in:
wizardchen
2025-12-03 17:45:21 +08:00
parent eafa0cd80b
commit 6b4d17ec70
18 changed files with 2799 additions and 2273 deletions

3595
LICENSE

File diff suppressed because it is too large Load Diff

6
go.mod
View File

@@ -29,7 +29,6 @@ require (
github.com/spf13/viper v1.20.1 github.com/spf13/viper v1.20.1
github.com/stretchr/testify v1.11.1 github.com/stretchr/testify v1.11.1
github.com/tencentyun/cos-go-sdk-v5 v0.7.65 github.com/tencentyun/cos-go-sdk-v5 v0.7.65
github.com/xuri/excelize/v2 v2.10.0
github.com/yanyiwu/gojieba v1.4.5 github.com/yanyiwu/gojieba v1.4.5
go.opentelemetry.io/otel v1.37.0 go.opentelemetry.io/otel v1.37.0
go.opentelemetry.io/otel/exporters/otlp/otlptrace v1.37.0 go.opentelemetry.io/otel/exporters/otlp/otlptrace v1.37.0
@@ -106,8 +105,6 @@ require (
github.com/pelletier/go-toml/v2 v2.2.3 // indirect github.com/pelletier/go-toml/v2 v2.2.3 // indirect
github.com/pierrec/lz4/v4 v4.1.21 // indirect github.com/pierrec/lz4/v4 v4.1.21 // indirect
github.com/pmezard/go-difflib v1.0.0 // indirect github.com/pmezard/go-difflib v1.0.0 // indirect
github.com/richardlehane/mscfb v1.0.4 // indirect
github.com/richardlehane/msoleps v1.0.4 // indirect
github.com/rivo/uniseg v0.4.7 // indirect github.com/rivo/uniseg v0.4.7 // indirect
github.com/robfig/cron/v3 v3.0.1 // indirect github.com/robfig/cron/v3 v3.0.1 // indirect
github.com/rs/xid v1.6.0 // indirect github.com/rs/xid v1.6.0 // indirect
@@ -117,12 +114,9 @@ require (
github.com/spf13/cast v1.10.0 // indirect github.com/spf13/cast v1.10.0 // indirect
github.com/spf13/pflag v1.0.6 // indirect github.com/spf13/pflag v1.0.6 // indirect
github.com/subosito/gotenv v1.6.0 // indirect github.com/subosito/gotenv v1.6.0 // indirect
github.com/tiendc/go-deepcopy v1.7.1 // indirect
github.com/twitchyliquid64/golang-asm v0.15.1 // indirect github.com/twitchyliquid64/golang-asm v0.15.1 // indirect
github.com/ugorji/go/codec v1.2.12 // indirect github.com/ugorji/go/codec v1.2.12 // indirect
github.com/wk8/go-ordered-map/v2 v2.1.8 // indirect github.com/wk8/go-ordered-map/v2 v2.1.8 // indirect
github.com/xuri/efp v0.0.1 // indirect
github.com/xuri/nfp v0.0.2-0.20250530014748-2ddeb826f9a9 // indirect
github.com/yosida95/uritemplate/v3 v3.0.2 // indirect github.com/yosida95/uritemplate/v3 v3.0.2 // indirect
go.opentelemetry.io/auto/sdk v1.1.0 // indirect go.opentelemetry.io/auto/sdk v1.1.0 // indirect
go.opentelemetry.io/otel/metric v1.37.0 // indirect go.opentelemetry.io/otel/metric v1.37.0 // indirect

15
go.sum
View File

@@ -235,11 +235,6 @@ github.com/pmezard/go-difflib v1.0.0 h1:4DBwDE0NGyQoBHbLQYPwSUPoCMWR5BEzIk/f1lZb
github.com/pmezard/go-difflib v1.0.0/go.mod h1:iKH77koFhYxTK1pcRnkKkqfTogsbg7gZNVY4sRDYZ/4= github.com/pmezard/go-difflib v1.0.0/go.mod h1:iKH77koFhYxTK1pcRnkKkqfTogsbg7gZNVY4sRDYZ/4=
github.com/redis/go-redis/v9 v9.14.0 h1:u4tNCjXOyzfgeLN+vAZaW1xUooqWDqVEsZN0U01jfAE= github.com/redis/go-redis/v9 v9.14.0 h1:u4tNCjXOyzfgeLN+vAZaW1xUooqWDqVEsZN0U01jfAE=
github.com/redis/go-redis/v9 v9.14.0/go.mod h1:huWgSWd8mW6+m0VPhJjSSQ+d6Nh1VICQ6Q5lHuCH/Iw= github.com/redis/go-redis/v9 v9.14.0/go.mod h1:huWgSWd8mW6+m0VPhJjSSQ+d6Nh1VICQ6Q5lHuCH/Iw=
github.com/richardlehane/mscfb v1.0.4 h1:WULscsljNPConisD5hR0+OyZjwK46Pfyr6mPu5ZawpM=
github.com/richardlehane/mscfb v1.0.4/go.mod h1:YzVpcZg9czvAuhk9T+a3avCpcFPMUWm7gK3DypaEsUk=
github.com/richardlehane/msoleps v1.0.1/go.mod h1:BWev5JBpU9Ko2WAgmZEuiz4/u3ZYTKbjLycmwiWUfWg=
github.com/richardlehane/msoleps v1.0.4 h1:WuESlvhX3gH2IHcd8UqyCuFY5yiq/GR/yqaSM/9/g00=
github.com/richardlehane/msoleps v1.0.4/go.mod h1:BWev5JBpU9Ko2WAgmZEuiz4/u3ZYTKbjLycmwiWUfWg=
github.com/rivo/uniseg v0.2.0/go.mod h1:J6wj4VEh+S6ZtnVlnTBMWIodfgj8LQOQFoIToxlJtxc= github.com/rivo/uniseg v0.2.0/go.mod h1:J6wj4VEh+S6ZtnVlnTBMWIodfgj8LQOQFoIToxlJtxc=
github.com/rivo/uniseg v0.4.7 h1:WUdvkW8uEhrYfLC4ZzdpI2ztxP1I582+49Oc5Mq64VQ= github.com/rivo/uniseg v0.4.7 h1:WUdvkW8uEhrYfLC4ZzdpI2ztxP1I582+49Oc5Mq64VQ=
github.com/rivo/uniseg v0.4.7/go.mod h1:FN3SvrM+Zdj16jyLfmOkMNblXMcoc8DfTHruCPUcx88= github.com/rivo/uniseg v0.4.7/go.mod h1:FN3SvrM+Zdj16jyLfmOkMNblXMcoc8DfTHruCPUcx88=
@@ -284,8 +279,6 @@ github.com/tencentcloud/tencentcloud-sdk-go/tencentcloud/common v1.0.563/go.mod
github.com/tencentcloud/tencentcloud-sdk-go/tencentcloud/kms v1.0.563/go.mod h1:uom4Nvi9W+Qkom0exYiJ9VWJjXwyxtPYTkKkaLMlfE0= github.com/tencentcloud/tencentcloud-sdk-go/tencentcloud/kms v1.0.563/go.mod h1:uom4Nvi9W+Qkom0exYiJ9VWJjXwyxtPYTkKkaLMlfE0=
github.com/tencentyun/cos-go-sdk-v5 v0.7.65 h1:+WBbfwThfZSbxpf1Dw6fyMwyzVtWBBExqfDJ5giiR2s= github.com/tencentyun/cos-go-sdk-v5 v0.7.65 h1:+WBbfwThfZSbxpf1Dw6fyMwyzVtWBBExqfDJ5giiR2s=
github.com/tencentyun/cos-go-sdk-v5 v0.7.65/go.mod h1:8+hG+mQMuRP/OIS9d83syAvXvrMj9HhkND6Q1fLghw0= github.com/tencentyun/cos-go-sdk-v5 v0.7.65/go.mod h1:8+hG+mQMuRP/OIS9d83syAvXvrMj9HhkND6Q1fLghw0=
github.com/tiendc/go-deepcopy v1.7.1 h1:LnubftI6nYaaMOcaz0LphzwraqN8jiWTwm416sitff4=
github.com/tiendc/go-deepcopy v1.7.1/go.mod h1:4bKjNC2r7boYOkD2IOuZpYjmlDdzjbpTRyCx+goBCJQ=
github.com/tmthrgd/go-hex v0.0.0-20190904060850-447a3041c3bc h1:9lRDQMhESg+zvGYmW5DyG0UqvY96Bu5QYsTLvCHdrgo= github.com/tmthrgd/go-hex v0.0.0-20190904060850-447a3041c3bc h1:9lRDQMhESg+zvGYmW5DyG0UqvY96Bu5QYsTLvCHdrgo=
github.com/tmthrgd/go-hex v0.0.0-20190904060850-447a3041c3bc/go.mod h1:bciPuU6GHm1iF1pBvUfxfsH0Wmnc2VbpgvbI9ZWuIRs= github.com/tmthrgd/go-hex v0.0.0-20190904060850-447a3041c3bc/go.mod h1:bciPuU6GHm1iF1pBvUfxfsH0Wmnc2VbpgvbI9ZWuIRs=
github.com/twitchyliquid64/golang-asm v0.15.1 h1:SU5vSMR7hnwNxj24w34ZyCi/FmDZTkS4MhqMhdFk5YI= github.com/twitchyliquid64/golang-asm v0.15.1 h1:SU5vSMR7hnwNxj24w34ZyCi/FmDZTkS4MhqMhdFk5YI=
@@ -310,12 +303,6 @@ github.com/wk8/go-ordered-map/v2 v2.1.8 h1:5h/BUHu93oj4gIdvHHHGsScSTMijfx5PeYkE/
github.com/wk8/go-ordered-map/v2 v2.1.8/go.mod h1:5nJHM5DyteebpVlHnWMV0rPz6Zp7+xBAnxjb1X5vnTw= github.com/wk8/go-ordered-map/v2 v2.1.8/go.mod h1:5nJHM5DyteebpVlHnWMV0rPz6Zp7+xBAnxjb1X5vnTw=
github.com/x448/float16 v0.8.4 h1:qLwI1I70+NjRFUR3zs1JPUCgaCXSh3SW62uAKT1mSBM= github.com/x448/float16 v0.8.4 h1:qLwI1I70+NjRFUR3zs1JPUCgaCXSh3SW62uAKT1mSBM=
github.com/x448/float16 v0.8.4/go.mod h1:14CWIYCyZA/cWjXOioeEpHeN/83MdbZDRQHoFcYsOfg= github.com/x448/float16 v0.8.4/go.mod h1:14CWIYCyZA/cWjXOioeEpHeN/83MdbZDRQHoFcYsOfg=
github.com/xuri/efp v0.0.1 h1:fws5Rv3myXyYni8uwj2qKjVaRP30PdjeYe2Y6FDsCL8=
github.com/xuri/efp v0.0.1/go.mod h1:ybY/Jr0T0GTCnYjKqmdwxyxn2BQf2RcQIIvex5QldPI=
github.com/xuri/excelize/v2 v2.10.0 h1:8aKsP7JD39iKLc6dH5Tw3dgV3sPRh8uRVXu/fMstfW4=
github.com/xuri/excelize/v2 v2.10.0/go.mod h1:SC5TzhQkaOsTWpANfm+7bJCldzcnU/jrhqkTi/iBHBU=
github.com/xuri/nfp v0.0.2-0.20250530014748-2ddeb826f9a9 h1:+C0TIdyyYmzadGaL/HBLbf3WdLgC29pgyhTjAT/0nuE=
github.com/xuri/nfp v0.0.2-0.20250530014748-2ddeb826f9a9/go.mod h1:WwHg+CVyzlv/TX9xqBFXEZAuxOPxn2k1GNHwG41IIUQ=
github.com/yanyiwu/gojieba v1.4.5 h1:VyZogGtdFSnJbACHvDRvDreXPPVPCg8axKFUdblU/JI= github.com/yanyiwu/gojieba v1.4.5 h1:VyZogGtdFSnJbACHvDRvDreXPPVPCg8axKFUdblU/JI=
github.com/yanyiwu/gojieba v1.4.5/go.mod h1:JUq4DddFVGdHXJHxxepxRmhrKlDpaBxR8O28v6fKYLY= github.com/yanyiwu/gojieba v1.4.5/go.mod h1:JUq4DddFVGdHXJHxxepxRmhrKlDpaBxR8O28v6fKYLY=
github.com/yosida95/uritemplate/v3 v3.0.2 h1:Ed3Oyj9yrmi9087+NczuL5BwkIc4wvTb5zIM+UJPGz4= github.com/yosida95/uritemplate/v3 v3.0.2 h1:Ed3Oyj9yrmi9087+NczuL5BwkIc4wvTb5zIM+UJPGz4=
@@ -359,8 +346,6 @@ golang.org/x/crypto v0.23.0/go.mod h1:CKFgDieR+mRhux2Lsu27y0fO304Db0wZe70UKqHu0v
golang.org/x/crypto v0.31.0/go.mod h1:kDsLvtWBEx7MV9tJOj9bnXsPbxwJQ6csT/x4KIN4Ssk= golang.org/x/crypto v0.31.0/go.mod h1:kDsLvtWBEx7MV9tJOj9bnXsPbxwJQ6csT/x4KIN4Ssk=
golang.org/x/crypto v0.43.0 h1:dduJYIi3A3KOfdGOHX8AVZ/jGiyPa3IbBozJ5kNuE04= golang.org/x/crypto v0.43.0 h1:dduJYIi3A3KOfdGOHX8AVZ/jGiyPa3IbBozJ5kNuE04=
golang.org/x/crypto v0.43.0/go.mod h1:BFbav4mRNlXJL4wNeejLpWxB7wMbc79PdRGhWKncxR0= golang.org/x/crypto v0.43.0/go.mod h1:BFbav4mRNlXJL4wNeejLpWxB7wMbc79PdRGhWKncxR0=
golang.org/x/image v0.25.0 h1:Y6uW6rH1y5y/LK1J8BPWZtr6yZ7hrsy6hFrXjgsc2fQ=
golang.org/x/image v0.25.0/go.mod h1:tCAmOEGthTtkalusGp1g3xa2gke8J6c2N565dTyl9Rs=
golang.org/x/mod v0.6.0-dev.0.20220419223038-86c51ed26bb4/go.mod h1:jJ57K6gSWd91VN4djpZkiMVwK6gcyfeH4XE8wZrZaV4= golang.org/x/mod v0.6.0-dev.0.20220419223038-86c51ed26bb4/go.mod h1:jJ57K6gSWd91VN4djpZkiMVwK6gcyfeH4XE8wZrZaV4=
golang.org/x/mod v0.8.0/go.mod h1:iBbtSCu2XBx23ZKBPSOrRkjjQPZFPuis4dIYUhu/chs= golang.org/x/mod v0.8.0/go.mod h1:iBbtSCu2XBx23ZKBPSOrRkjjQPZFPuis4dIYUhu/chs=
golang.org/x/mod v0.12.0/go.mod h1:iBbtSCu2XBx23ZKBPSOrRkjjQPZFPuis4dIYUhu/chs= golang.org/x/mod v0.12.0/go.mod h1:iBbtSCu2XBx23ZKBPSOrRkjjQPZFPuis4dIYUhu/chs=

View File

@@ -8,6 +8,7 @@ import (
"strings" "strings"
"github.com/Tencent/WeKnora/internal/logger" "github.com/Tencent/WeKnora/internal/logger"
"github.com/Tencent/WeKnora/internal/searchutil"
"github.com/Tencent/WeKnora/internal/types" "github.com/Tencent/WeKnora/internal/types"
"gorm.io/gorm" "gorm.io/gorm"
) )
@@ -546,17 +547,7 @@ func (t *GrepChunksTool) deduplicateChunks(ctx context.Context, results []chunkW
// buildContentSignature creates a normalized signature for content to detect near-duplicates // buildContentSignature creates a normalized signature for content to detect near-duplicates
func (t *GrepChunksTool) buildContentSignature(content string) string { func (t *GrepChunksTool) buildContentSignature(content string) string {
c := strings.ToLower(strings.TrimSpace(content)) return searchutil.BuildContentSignature(content)
if c == "" {
return ""
}
// Normalize whitespace
c = strings.Join(strings.Fields(c), " ")
// Use first 128 characters as signature
if len(c) > 128 {
c = c[:128]
}
return c
} }
// scoreChunks calculates match scores for chunks based on pattern matches // scoreChunks calculates match scores for chunks based on pattern matches
@@ -693,36 +684,10 @@ func (t *GrepChunksTool) applyMMR(
// tokenizeSimple tokenizes text into a set of words (simple whitespace-based) // tokenizeSimple tokenizes text into a set of words (simple whitespace-based)
func (t *GrepChunksTool) tokenizeSimple(text string) map[string]struct{} { func (t *GrepChunksTool) tokenizeSimple(text string) map[string]struct{} {
text = strings.ToLower(text) return searchutil.TokenizeSimple(text)
fields := strings.Fields(text)
set := make(map[string]struct{}, len(fields))
for _, f := range fields {
if len(f) > 1 {
set[f] = struct{}{}
}
}
return set
} }
// jaccard calculates Jaccard similarity between two token sets // jaccard calculates Jaccard similarity between two token sets
func (t *GrepChunksTool) jaccard(a, b map[string]struct{}) float64 { func (t *GrepChunksTool) jaccard(a, b map[string]struct{}) float64 {
if len(a) == 0 && len(b) == 0 { return searchutil.Jaccard(a, b)
return 0
}
// Calculate intersection
inter := 0
for k := range a {
if _, ok := b[k]; ok {
inter++
}
}
// Calculate union
union := len(a) + len(b) - inter
if union == 0 {
return 0
}
return float64(inter) / float64(union)
} }

View File

@@ -264,15 +264,12 @@ func (t *KnowledgeSearchTool) Execute(ctx context.Context, args map[string]inter
topK, vectorThreshold, keywordThreshold, kbTypeMap) topK, vectorThreshold, keywordThreshold, kbTypeMap)
logger.Infof(ctx, "[Tool][KnowledgeSearch] Concurrent search completed: %d raw results", len(allResults)) logger.Infof(ctx, "[Tool][KnowledgeSearch] Concurrent search completed: %d raw results", len(allResults))
// Normalize keyword search results to ensure fair comparison across knowledge bases // Note: HybridSearch now uses RRF (Reciprocal Rank Fusion) which produces normalized scores
logger.Debugf(ctx, "[Tool][KnowledgeSearch] Normalizing keyword search results...") // RRF scores are in range [0, ~0.033] (max when rank=1 on both sides: 2/(60+1))
t.normalizeKeywordSearchResults(ctx, allResults) // Threshold filtering is already done inside HybridSearch before RRF, so we skip it here
logger.Infof(ctx, "[Tool][KnowledgeSearch] After keyword normalization: %d results", len(allResults))
// Filter by threshold first
filteredResults := t.filterByThreshold(allResults, vectorThreshold, keywordThreshold)
// Deduplicate before reranking to reduce processing overhead // Deduplicate before reranking to reduce processing overhead
deduplicatedBeforeRerank := t.deduplicateResults(filteredResults) deduplicatedBeforeRerank := t.deduplicateResults(allResults)
// Apply ReRank if model is configured // Apply ReRank if model is configured
// Prefer chatModel (LLM-based reranking) over rerankModel if both are available // Prefer chatModel (LLM-based reranking) over rerankModel if both are available
@@ -286,6 +283,9 @@ func (t *KnowledgeSearchTool) Execute(ctx context.Context, args map[string]inter
} }
} }
// Variable to hold results through reranking and MMR stages
var filteredResults []*searchResultWithMeta
if t.chatModel != nil && len(deduplicatedBeforeRerank) > 0 && rerankQuery != "" { if t.chatModel != nil && len(deduplicatedBeforeRerank) > 0 && rerankQuery != "" {
logger.Infof( logger.Infof(
ctx, ctx,
@@ -347,10 +347,9 @@ func (t *KnowledgeSearchTool) Execute(ctx context.Context, args map[string]inter
} }
} }
// Apply absolute minimum score filter to remove very low quality chunks // Note: minScore filter is skipped because HybridSearch now uses RRF scores
logger.Debugf(ctx, "[Tool][KnowledgeSearch] Applying min_score filter (%.2f)...", minScore) // RRF scores are in range [0, ~0.033], not [0, 1], so old thresholds don't apply
filteredResults = t.filterByMinScore(filteredResults, minScore) // Threshold filtering is already done inside HybridSearch before RRF fusion
logger.Infof(ctx, "[Tool][KnowledgeSearch] After min_score filter: %d results", len(filteredResults))
// Final deduplication after rerank (in case rerank changed scores/order but duplicates remain) // Final deduplication after rerank (in case rerank changed scores/order but duplicates remain)
logger.Debugf(ctx, "[Tool][KnowledgeSearch] Final deduplication after rerank...") logger.Debugf(ctx, "[Tool][KnowledgeSearch] Final deduplication after rerank...")
@@ -465,45 +464,6 @@ func (t *KnowledgeSearchTool) concurrentSearch(
return allResults return allResults
} }
// filterByThreshold filters results based on match type and threshold
// Special handling for history matches: uses lower threshold (reduced by 0.1, minimum 0.5)
func (t *KnowledgeSearchTool) filterByThreshold(
results []*searchResultWithMeta,
vectorThreshold, keywordThreshold float64,
) []*searchResultWithMeta {
filtered := make([]*searchResultWithMeta, 0)
for _, r := range results {
var threshold float64
// Special handling for history matches: use lower threshold
switch r.MatchType {
case types.MatchTypeHistory:
// Use the lower of the two thresholds, then reduce by 0.1 (minimum 0.5)
th := vectorThreshold
if keywordThreshold < th {
th = keywordThreshold
}
threshold = math.Max(th-0.1, 0.5)
case types.MatchTypeEmbedding:
threshold = vectorThreshold
case types.MatchTypeKeywords:
threshold = keywordThreshold
default:
// For other match types (graph, nearby chunk, etc.), use the lower threshold
threshold = vectorThreshold
if keywordThreshold < threshold {
threshold = keywordThreshold
}
}
// Check if result meets threshold
if r.Score >= threshold {
filtered = append(filtered, r)
}
}
return filtered
}
// rerankResults applies reranking to search results using LLM prompt scoring or rerank model // rerankResults applies reranking to search results using LLM prompt scoring or rerank model
func (t *KnowledgeSearchTool) rerankResults( func (t *KnowledgeSearchTool) rerankResults(
ctx context.Context, ctx context.Context,
@@ -566,15 +526,13 @@ func (t *KnowledgeSearchTool) rerankResults(
} }
// Apply composite scoring to reranked results // Apply composite scoring to reranked results
// Get query intent from context if available (optional) logger.Debugf(ctx, "[Tool][KnowledgeSearch] Applying composite scoring")
queryIntent := t.getQueryIntentFromContext(ctx)
logger.Debugf(ctx, "[Tool][KnowledgeSearch] Applying composite scoring with query_intent=%s", queryIntent)
// Store base scores before composite scoring // Store base scores before composite scoring
for _, result := range rerankedNonFAQ { for _, result := range rerankedNonFAQ {
baseScore := result.Score baseScore := result.Score
// Apply composite score // Apply composite score
result.Score = t.compositeScore(result, result.Score, baseScore, queryIntent) result.Score = t.compositeScore(result, result.Score, baseScore)
} }
// Combine FAQ results (with original order) and reranked non-FAQ results // Combine FAQ results (with original order) and reranked non-FAQ results
@@ -899,20 +857,6 @@ func (t *KnowledgeSearchTool) rerankWithModel(
return reranked, nil return reranked, nil
} }
// filterByMinScore filters results by absolute minimum score
func (t *KnowledgeSearchTool) filterByMinScore(
results []*searchResultWithMeta,
minScore float64,
) []*searchResultWithMeta {
filtered := make([]*searchResultWithMeta, 0)
for _, r := range results {
if r.Score >= minScore {
filtered = append(filtered, r)
}
}
return filtered
}
// deduplicateResults removes duplicate chunks, keeping the highest score // deduplicateResults removes duplicate chunks, keeping the highest score
// Uses multiple keys (ID, parent chunk ID, knowledge+index) and content signature for deduplication // Uses multiple keys (ID, parent chunk ID, knowledge+index) and content signature for deduplication
func (t *KnowledgeSearchTool) deduplicateResults(results []*searchResultWithMeta) []*searchResultWithMeta { func (t *KnowledgeSearchTool) deduplicateResults(results []*searchResultWithMeta) []*searchResultWithMeta {
@@ -984,17 +928,7 @@ func (t *KnowledgeSearchTool) deduplicateResults(results []*searchResultWithMeta
// buildContentSignature creates a normalized signature for content to detect near-duplicates // buildContentSignature creates a normalized signature for content to detect near-duplicates
func (t *KnowledgeSearchTool) buildContentSignature(content string) string { func (t *KnowledgeSearchTool) buildContentSignature(content string) string {
c := strings.ToLower(strings.TrimSpace(content)) return searchutil.BuildContentSignature(content)
if c == "" {
return ""
}
// Normalize whitespace
c = strings.Join(strings.Fields(c), " ")
// Use first 128 characters as signature
if len(c) > 128 {
c = c[:128]
}
return c
} }
// formatOutput formats the search results for display // formatOutput formats the search results for display
@@ -1215,47 +1149,6 @@ type chunkRange struct {
end int end int
} }
// normalizeKeywordSearchResults normalizes keyword search result scores into [0,1] globally across all knowledge bases
// Improvements:
// 1. Uses robust normalization with percentile-based bounds to handle outliers
// 2. Handles edge cases: single result, no variance, negative scores
// 3. Global normalization ensures fair comparison across different knowledge bases
func (t *KnowledgeSearchTool) normalizeKeywordSearchResults(ctx context.Context, results []*searchResultWithMeta) {
searchutil.NormalizeKeywordScores[*searchResultWithMeta](
results,
func(r *searchResultWithMeta) bool {
return r.MatchType == types.MatchTypeKeywords
},
func(r *searchResultWithMeta) float64 {
return r.Score
},
func(r *searchResultWithMeta, score float64) {
r.Score = score
},
searchutil.KeywordScoreCallbacks{
OnNoVariance: func(count int, score float64) {
logger.Infof(
ctx,
"[Tool][KnowledgeSearch] Keyword scores have no variance, all set to 1.0: count=%d, score=%.3f",
count,
score,
)
},
OnNormalized: func(count int, rawMin, rawMax, normalizeMin, normalizeMax float64) {
logger.Infof(
ctx,
"[Tool][KnowledgeSearch] Normalized keyword scores: count=%d, raw_min=%.3f, raw_max=%.3f, normalize_min=%.3f, normalize_max=%.3f",
count,
rawMin,
rawMax,
normalizeMin,
normalizeMax,
)
},
},
)
}
// getEnrichedPassage 合并Content和ImageInfo的文本内容 // getEnrichedPassage 合并Content和ImageInfo的文本内容
func (t *KnowledgeSearchTool) getEnrichedPassage(ctx context.Context, result *types.SearchResult) string { func (t *KnowledgeSearchTool) getEnrichedPassage(ctx context.Context, result *types.SearchResult) string {
if result.ImageInfo == "" { if result.ImageInfo == "" {
@@ -1302,19 +1195,10 @@ func (t *KnowledgeSearchTool) getEnrichedPassage(ctx context.Context, result *ty
return combinedText return combinedText
} }
// getQueryIntentFromContext attempts to extract query intent from context (optional)
func (t *KnowledgeSearchTool) getQueryIntentFromContext(ctx context.Context) string {
// Try to get query intent from context if available
// This is optional and may not always be present in agent tool context
// Return empty string if not available
return ""
}
// compositeScore calculates a composite score considering multiple factors // compositeScore calculates a composite score considering multiple factors
func (t *KnowledgeSearchTool) compositeScore( func (t *KnowledgeSearchTool) compositeScore(
result *searchResultWithMeta, result *searchResultWithMeta,
modelScore, baseScore float64, modelScore, baseScore float64,
queryIntent string,
) float64 { ) float64 {
// Source weight: web_search results get slightly lower weight // Source weight: web_search results get slightly lower weight
sourceWeight := 1.0 sourceWeight := 1.0
@@ -1322,26 +1206,6 @@ func (t *KnowledgeSearchTool) compositeScore(
sourceWeight = 0.95 sourceWeight = 0.95
} }
// Intent boost: adjust score based on query intent and chunk characteristics
intentBoost := 1.0
if queryIntent != "" {
switch queryIntent {
case "definition":
// Boost summary chunks for definition queries
if result.ChunkType == string(types.ChunkTypeSummary) {
intentBoost = 1.05
}
case "howto":
// Boost longer chunks for howto queries
if result.EndAt-result.StartAt > 300 {
intentBoost = 1.03
}
case "compare":
// No boost for compare queries
intentBoost = 1.0
}
}
// Position prior: slightly favor chunks earlier in the document // Position prior: slightly favor chunks earlier in the document
positionPrior := 1.0 positionPrior := 1.0
if result.StartAt >= 0 && result.EndAt > result.StartAt { if result.StartAt >= 0 && result.EndAt > result.StartAt {
@@ -1352,7 +1216,6 @@ func (t *KnowledgeSearchTool) compositeScore(
// Composite formula: weighted combination of model score, base score, and source weight // Composite formula: weighted combination of model score, base score, and source weight
composite := 0.6*modelScore + 0.3*baseScore + 0.1*sourceWeight composite := 0.6*modelScore + 0.3*baseScore + 0.1*sourceWeight
composite *= intentBoost
composite *= positionPrior composite *= positionPrior
// Clamp to [0, 1] // Clamp to [0, 1]
@@ -1368,13 +1231,7 @@ func (t *KnowledgeSearchTool) compositeScore(
// clampFloat clamps a float value to the specified range // clampFloat clamps a float value to the specified range
func (t *KnowledgeSearchTool) clampFloat(v, minV, maxV float64) float64 { func (t *KnowledgeSearchTool) clampFloat(v, minV, maxV float64) float64 {
if v < minV { return searchutil.ClampFloat(v, minV, maxV)
return minV
}
if v > maxV {
return maxV
}
return v
} }
// applyMMR applies Maximal Marginal Relevance algorithm to reduce redundancy // applyMMR applies Maximal Marginal Relevance algorithm to reduce redundancy
@@ -1456,36 +1313,10 @@ func (t *KnowledgeSearchTool) applyMMR(
// tokenizeSimple tokenizes text into a set of words (simple whitespace-based) // tokenizeSimple tokenizes text into a set of words (simple whitespace-based)
func (t *KnowledgeSearchTool) tokenizeSimple(text string) map[string]struct{} { func (t *KnowledgeSearchTool) tokenizeSimple(text string) map[string]struct{} {
text = strings.ToLower(text) return searchutil.TokenizeSimple(text)
fields := strings.Fields(text)
set := make(map[string]struct{}, len(fields))
for _, f := range fields {
if len(f) > 1 {
set[f] = struct{}{}
}
}
return set
} }
// jaccard calculates Jaccard similarity between two token sets // jaccard calculates Jaccard similarity between two token sets
func (t *KnowledgeSearchTool) jaccard(a, b map[string]struct{}) float64 { func (t *KnowledgeSearchTool) jaccard(a, b map[string]struct{}) float64 {
if len(a) == 0 && len(b) == 0 { return searchutil.Jaccard(a, b)
return 0
}
// Calculate intersection
inter := 0
for k := range a {
if _, ok := b[k]; ok {
inter++
}
}
// Calculate union
union := len(a) + len(b) - inter
if union == 0 {
return 0
}
return float64(inter) / float64(union)
} }

View File

@@ -1,166 +0,0 @@
package chatpipline
import (
"context"
"encoding/json"
"regexp"
"strings"
"github.com/Tencent/WeKnora/internal/config"
"github.com/Tencent/WeKnora/internal/models/chat"
"github.com/Tencent/WeKnora/internal/types"
"github.com/Tencent/WeKnora/internal/types/interfaces"
)
// PluginPreprocess Query preprocessing plugin
type PluginPreprocess struct {
config *config.Config
modelService interfaces.ModelService
}
// Regular expressions for text cleaning
var (
multiSpaceRegex = regexp.MustCompile(`\s+`) // Multiple spaces
)
// NewPluginPreprocess Creates a new query preprocessing plugin
func NewPluginPreprocess(
eventManager *EventManager,
config *config.Config,
cleaner interfaces.ResourceCleaner,
modelService interfaces.ModelService,
) *PluginPreprocess {
res := &PluginPreprocess{
config: config,
modelService: modelService,
}
eventManager.Register(res)
return res
}
// ActivationEvents Register activation events
func (p *PluginPreprocess) ActivationEvents() []types.EventType {
return []types.EventType{types.PREPROCESS_QUERY}
}
// OnEvent Process events
func (p *PluginPreprocess) OnEvent(
ctx context.Context,
eventType types.EventType,
chatManage *types.ChatManage,
next func() *PluginError,
) *PluginError {
rawQuery := strings.TrimSpace(chatManage.RewriteQuery)
if rawQuery == "" {
return next()
}
pipelineInfo(ctx, "Preprocess", "input", map[string]interface{}{
"session_id": chatManage.SessionID,
"rewrite_query": rawQuery,
})
// Lightweight normalization: just collapse multiple spaces
processed := multiSpaceRegex.ReplaceAllString(rawQuery, " ")
processed = strings.TrimSpace(processed)
chatManage.ProcessedQuery = processed
chatManage.QueryIntent = p.detectIntentLLM(ctx, chatManage, processed)
pipelineInfo(ctx, "Preprocess", "output", map[string]interface{}{
"session_id": chatManage.SessionID,
"processed_query": processed,
"query_intent": chatManage.QueryIntent,
})
return next()
}
// intentResp is a response for intent detection
type intentResp struct {
Intent string `json:"intent"`
Confidence float64 `json:"confidence"`
}
// detectIntentLLM detects the intent of a query using an LLM
func (p *PluginPreprocess) detectIntentLLM(ctx context.Context, chatManage *types.ChatManage, text string) string {
if p.modelService == nil || chatManage.ChatModelID == "" {
pipelineWarn(
ctx,
"IntentDetect",
"skip",
map[string]interface{}{"reason": "no_model", "session_id": chatManage.SessionID},
)
return "general"
}
chatModel, err := p.modelService.GetChatModel(ctx, chatManage.ChatModelID)
if err != nil {
pipelineWarn(
ctx,
"IntentDetect",
"get_model_failed",
map[string]interface{}{"error": err.Error(), "model_id": chatManage.ChatModelID},
)
return "general"
}
pipelineInfo(
ctx,
"IntentDetect",
"start",
map[string]interface{}{"session_id": chatManage.SessionID, "model_id": chatManage.ChatModelID},
)
sys := "You are a query intent classifier. Classify the user's query into one of: definition, howto, compare, qa, general. Respond ONLY with a JSON object {\"intent\": \"...\", \"confidence\": 0.0 } inside a markdown fenced block."
usr := text
think := false
resp, err := chatModel.Chat(ctx, []chat.Message{
{Role: "system", Content: sys},
{Role: "user", Content: usr},
}, &chat.ChatOptions{Temperature: 0.0, MaxCompletionTokens: 64, Thinking: &think})
if err != nil || resp.Content == "" {
pipelineWarn(ctx, "IntentDetect", "model_call_failed", map[string]interface{}{"error": err})
return "general"
}
body := extractJSONBody(resp.Content)
var ir intentResp
if err := json.Unmarshal([]byte(body), &ir); err != nil {
pipelineWarn(ctx, "IntentDetect", "parse_failed", map[string]interface{}{"body": body, "error": err.Error()})
return "general"
}
pipelineInfo(
ctx,
"IntentDetect",
"result",
map[string]interface{}{"intent": ir.Intent, "confidence": ir.Confidence},
)
switch strings.ToLower(strings.TrimSpace(ir.Intent)) {
case "definition", "howto", "compare", "qa", "general":
return strings.ToLower(ir.Intent)
default:
return "general"
}
}
// extractJSONBody extracts a JSON body from a string
func extractJSONBody(text string) string {
t := strings.TrimSpace(text)
// Try fenced block first
if i := strings.Index(t, "{"); i >= 0 {
j := strings.LastIndex(t, "}")
if j > i {
return t[i : j+1]
}
}
return "{}"
}
// Close Releases resources
func (p *PluginPreprocess) Close() {
}
// ShutdownHandler Returns shutdown function
func (p *PluginPreprocess) ShutdownHandler() func() {
return func() {
p.Close()
}
}

View File

@@ -8,6 +8,7 @@ import (
"strings" "strings"
"github.com/Tencent/WeKnora/internal/models/rerank" "github.com/Tencent/WeKnora/internal/models/rerank"
"github.com/Tencent/WeKnora/internal/searchutil"
"github.com/Tencent/WeKnora/internal/types" "github.com/Tencent/WeKnora/internal/types"
"github.com/Tencent/WeKnora/internal/types/interfaces" "github.com/Tencent/WeKnora/internal/types/interfaces"
) )
@@ -41,7 +42,6 @@ func (p *PluginRerank) OnEvent(ctx context.Context,
"rerank_model": chatManage.RerankModelID, "rerank_model": chatManage.RerankModelID,
"rerank_thresh": chatManage.RerankThreshold, "rerank_thresh": chatManage.RerankThreshold,
"rewrite_query": chatManage.RewriteQuery, "rewrite_query": chatManage.RewriteQuery,
"processed_query": chatManage.ProcessedQuery,
}) })
if len(chatManage.SearchResult) == 0 { if len(chatManage.SearchResult) == 0 {
pipelineInfo(ctx, "Rerank", "skip", map[string]interface{}{ pipelineInfo(ctx, "Rerank", "skip", map[string]interface{}{
@@ -77,18 +77,40 @@ func (p *PluginRerank) OnEvent(ctx context.Context,
passages = append(passages, passage) passages = append(passages, passage)
} }
// Try reranking with different query variants in priority order // Single rerank call with RewriteQuery, use threshold degradation if no results
originalThreshold := chatManage.RerankThreshold
rerankResp := p.rerank(ctx, chatManage, rerankModel, chatManage.RewriteQuery, passages) rerankResp := p.rerank(ctx, chatManage, rerankModel, chatManage.RewriteQuery, passages)
if len(rerankResp) == 0 {
rerankResp = p.rerank(ctx, chatManage, rerankModel, chatManage.ProcessedQuery, passages) // If no results and threshold is high enough, try with lower threshold
if len(rerankResp) == 0 { if len(rerankResp) == 0 && originalThreshold > 0.3 {
rerankResp = p.rerank(ctx, chatManage, rerankModel, chatManage.Query, passages) degradedThreshold := originalThreshold * 0.7
if degradedThreshold < 0.3 {
degradedThreshold = 0.3
} }
pipelineInfo(ctx, "Rerank", "threshold_degrade", map[string]interface{}{
"original": originalThreshold,
"degraded": degradedThreshold,
})
chatManage.RerankThreshold = degradedThreshold
rerankResp = p.rerank(ctx, chatManage, rerankModel, chatManage.RewriteQuery, passages)
// Restore original threshold
chatManage.RerankThreshold = originalThreshold
} }
pipelineInfo(ctx, "Rerank", "model_response", map[string]interface{}{ pipelineInfo(ctx, "Rerank", "model_response", map[string]interface{}{
"result_cnt": len(rerankResp), "result_cnt": len(rerankResp),
}) })
// Log input scores before reranking for debugging
for i, sr := range chatManage.SearchResult {
pipelineInfo(ctx, "Rerank", "input_score", map[string]interface{}{
"index": i,
"chunk_id": sr.ID,
"score": fmt.Sprintf("%.4f", sr.Score),
"match_type": sr.MatchType,
})
}
for i := range chatManage.SearchResult { for i := range chatManage.SearchResult {
chatManage.SearchResult[i].Metadata = ensureMetadata(chatManage.SearchResult[i].Metadata) chatManage.SearchResult[i].Metadata = ensureMetadata(chatManage.SearchResult[i].Metadata)
} }
@@ -97,8 +119,15 @@ func (p *PluginRerank) OnEvent(ctx context.Context,
sr := chatManage.SearchResult[rr.Index] sr := chatManage.SearchResult[rr.Index]
base := sr.Score base := sr.Score
sr.Metadata["base_score"] = fmt.Sprintf("%.4f", base) sr.Metadata["base_score"] = fmt.Sprintf("%.4f", base)
sr.Score = rr.RelevanceScore modelScore := rr.RelevanceScore
sr.Score = compositeScore(sr, rr.RelevanceScore, base, chatManage) sr.Score = compositeScore(sr, modelScore, base)
pipelineInfo(ctx, "Rerank", "composite_calc", map[string]interface{}{
"chunk_id": sr.ID,
"base_score": fmt.Sprintf("%.4f", base),
"model_score": fmt.Sprintf("%.4f", modelScore),
"final_score": fmt.Sprintf("%.4f", sr.Score),
"match_type": sr.MatchType,
})
reranked = append(reranked, sr) reranked = append(reranked, sr)
} }
final := applyMMR(ctx, reranked, chatManage, min(len(reranked), max(1, chatManage.RerankTopK)), 0.7) final := applyMMR(ctx, reranked, chatManage, min(len(reranked), max(1, chatManage.RerankTopK)), 0.7)
@@ -112,7 +141,6 @@ func (p *PluginRerank) OnEvent(ctx context.Context,
"chunk_id": reranked[i].ID, "chunk_id": reranked[i].ID,
"base_score": reranked[i].Metadata["base_score"], "base_score": reranked[i].Metadata["base_score"],
"final_score": fmt.Sprintf("%.4f", reranked[i].Score), "final_score": fmt.Sprintf("%.4f", reranked[i].Score),
"intent": chatManage.QueryIntent,
}) })
} }
@@ -185,7 +213,7 @@ func ensureMetadata(m map[string]string) map[string]string {
} }
// compositeScore calculates the composite score for a search result // compositeScore calculates the composite score for a search result
func compositeScore(sr *types.SearchResult, modelScore, baseScore float64, chatManage *types.ChatManage) float64 { func compositeScore(sr *types.SearchResult, modelScore, baseScore float64) float64 {
sourceWeight := 1.0 sourceWeight := 1.0
switch strings.ToLower(sr.KnowledgeSource) { switch strings.ToLower(sr.KnowledgeSource) {
case "web_search": case "web_search":
@@ -193,27 +221,11 @@ func compositeScore(sr *types.SearchResult, modelScore, baseScore float64, chatM
default: default:
sourceWeight = 1.0 sourceWeight = 1.0
} }
intentBoost := 1.0
if chatManage.QueryIntent != "" {
switch chatManage.QueryIntent {
case "definition":
if sr.ChunkType == string(types.ChunkTypeSummary) {
intentBoost = 1.05
}
case "howto":
if sr.EndAt-sr.StartAt > 300 {
intentBoost = 1.03
}
case "compare":
intentBoost = 1.0
}
}
positionPrior := 1.0 positionPrior := 1.0
if sr.StartAt >= 0 { if sr.StartAt >= 0 {
positionPrior += clampFloat(1.0-float64(sr.StartAt)/float64(sr.EndAt+1), -0.05, 0.05) positionPrior += searchutil.ClampFloat(1.0-float64(sr.StartAt)/float64(sr.EndAt+1), -0.05, 0.05)
} }
composite := 0.6*modelScore + 0.3*baseScore + 0.1*sourceWeight composite := 0.6*modelScore + 0.3*baseScore + 0.1*sourceWeight
composite *= intentBoost
composite *= positionPrior composite *= positionPrior
if composite < 0 { if composite < 0 {
composite = 0 composite = 0
@@ -224,7 +236,7 @@ func compositeScore(sr *types.SearchResult, modelScore, baseScore float64, chatM
return composite return composite
} }
// applyMMR applies the MMR algorithm to the search results // applyMMR applies the MMR algorithm to the search results with pre-computed token sets
func applyMMR( func applyMMR(
ctx context.Context, ctx context.Context,
results []*types.SearchResult, results []*types.SearchResult,
@@ -240,40 +252,60 @@ func applyMMR(
"k": k, "k": k,
"candidates": len(results), "candidates": len(results),
}) })
selected := make([]*types.SearchResult, 0, k)
candidates := make([]*types.SearchResult, len(results)) // Pre-compute all token sets upfront (optimization)
copy(candidates, results) allTokenSets := make([]map[string]struct{}, len(results))
tokenSets := make([]map[string]struct{}, len(candidates)) for i, r := range results {
for i, r := range candidates { allTokenSets[i] = searchutil.TokenizeSimple(getEnrichedPassage(ctx, r))
tokenSets[i] = tokenizeSimple(getEnrichedPassage(ctx, r))
} }
for len(selected) < k && len(candidates) > 0 {
bestIdx := 0 selected := make([]*types.SearchResult, 0, k)
selectedTokenSets := make([]map[string]struct{}, 0, k)
selectedIndices := make(map[int]struct{})
for len(selected) < k && len(selectedIndices) < len(results) {
bestIdx := -1
bestScore := -1.0 bestScore := -1.0
for i, r := range candidates {
for i, r := range results {
if _, isSelected := selectedIndices[i]; isSelected {
continue
}
relevance := r.Score relevance := r.Score
redundancy := 0.0 redundancy := 0.0
for _, s := range selected {
redundancy = math.Max(redundancy, jaccard(tokenSets[i], tokenizeSimple(getEnrichedPassage(ctx, s)))) // Use pre-computed token sets for redundancy calculation
for _, selTokens := range selectedTokenSets {
sim := searchutil.Jaccard(allTokenSets[i], selTokens)
if sim > redundancy {
redundancy = sim
} }
}
mmr := lambda*relevance - (1.0-lambda)*redundancy mmr := lambda*relevance - (1.0-lambda)*redundancy
if mmr > bestScore { if mmr > bestScore {
bestScore = mmr bestScore = mmr
bestIdx = i bestIdx = i
} }
} }
selected = append(selected, candidates[bestIdx])
candidates = append(candidates[:bestIdx], candidates[bestIdx+1:]...) if bestIdx < 0 {
break
} }
// Compute average redundancy among selected
selected = append(selected, results[bestIdx])
selectedTokenSets = append(selectedTokenSets, allTokenSets[bestIdx])
selectedIndices[bestIdx] = struct{}{}
}
// Compute average redundancy among selected using pre-computed token sets
avgRed := 0.0 avgRed := 0.0
if len(selected) > 1 { if len(selected) > 1 {
pairs := 0 pairs := 0
for i := 0; i < len(selected); i++ { for i := 0; i < len(selectedTokenSets); i++ {
for j := i + 1; j < len(selected); j++ { for j := i + 1; j < len(selectedTokenSets); j++ {
si := tokenizeSimple(getEnrichedPassage(ctx, selected[i])) avgRed += searchutil.Jaccard(selectedTokenSets[i], selectedTokenSets[j])
sj := tokenizeSimple(getEnrichedPassage(ctx, selected[j]))
avgRed += jaccard(si, sj)
pairs++ pairs++
} }
} }
@@ -288,93 +320,59 @@ func applyMMR(
return selected return selected
} }
// tokenizeSimple tokenizes a text into a set of tokens // getEnrichedPassage 合并Content、ImageInfo和GeneratedQuestions的文本内容
func tokenizeSimple(text string) map[string]struct{} {
text = strings.ToLower(text)
fields := strings.Fields(text)
set := make(map[string]struct{}, len(fields))
for _, f := range fields {
if len(f) > 1 {
set[f] = struct{}{}
}
}
return set
}
// jaccard calculates the Jaccard similarity between two sets of tokens
func jaccard(a, b map[string]struct{}) float64 {
if len(a) == 0 && len(b) == 0 {
return 0
}
inter := 0
for k := range a {
if _, ok := b[k]; ok {
inter++
}
}
union := len(a) + len(b) - inter
if union == 0 {
return 0
}
return float64(inter) / float64(union)
}
// clampFloat clamps a float value between a minimum and maximum value
func clampFloat(v, minV, maxV float64) float64 {
if v < minV {
return minV
}
if v > maxV {
return maxV
}
return v
}
// getEnrichedPassage 合并Content和ImageInfo的文本内容
func getEnrichedPassage(ctx context.Context, result *types.SearchResult) string { func getEnrichedPassage(ctx context.Context, result *types.SearchResult) string {
if result.ImageInfo == "" { combinedText := result.Content
return result.Content var enrichments []string
}
// 解析ImageInfo // 解析ImageInfo
if result.ImageInfo != "" {
var imageInfos []types.ImageInfo var imageInfos []types.ImageInfo
err := json.Unmarshal([]byte(result.ImageInfo), &imageInfos) err := json.Unmarshal([]byte(result.ImageInfo), &imageInfos)
if err != nil { if err != nil {
pipelineWarn(ctx, "Rerank", "image_info_parse", map[string]interface{}{ pipelineWarn(ctx, "Rerank", "image_info_parse", map[string]interface{}{
"error": err.Error(), "error": err.Error(),
}) })
return result.Content } else {
}
if len(imageInfos) == 0 {
return result.Content
}
// 提取所有图片的描述和OCR文本 // 提取所有图片的描述和OCR文本
var imageTexts []string
for _, img := range imageInfos { for _, img := range imageInfos {
if img.Caption != "" { if img.Caption != "" {
imageTexts = append(imageTexts, fmt.Sprintf("图片描述: %s", img.Caption)) enrichments = append(enrichments, fmt.Sprintf("图片描述: %s", img.Caption))
} }
if img.OCRText != "" { if img.OCRText != "" {
imageTexts = append(imageTexts, fmt.Sprintf("图片文本: %s", img.OCRText)) enrichments = append(enrichments, fmt.Sprintf("图片文本: %s", img.OCRText))
}
}
} }
} }
if len(imageTexts) == 0 { // 解析ChunkMetadata中的GeneratedQuestions
return result.Content if len(result.ChunkMetadata) > 0 {
var docMeta types.DocumentChunkMetadata
err := json.Unmarshal(result.ChunkMetadata, &docMeta)
if err != nil {
pipelineWarn(ctx, "Rerank", "chunk_metadata_parse", map[string]interface{}{
"error": err.Error(),
})
} else if len(docMeta.GeneratedQuestions) > 0 {
enrichments = append(enrichments, fmt.Sprintf("相关问题: %s", strings.Join(docMeta.GeneratedQuestions, "; ")))
}
} }
// 组合内容和图片信息 if len(enrichments) == 0 {
combinedText := result.Content return combinedText
}
// 组合内容和增强信息
if combinedText != "" { if combinedText != "" {
combinedText += "\n\n" combinedText += "\n\n"
} }
combinedText += strings.Join(imageTexts, "\n") combinedText += strings.Join(enrichments, "\n")
pipelineInfo(ctx, "Rerank", "image_info_merge", map[string]interface{}{ pipelineInfo(ctx, "Rerank", "passage_enrich", map[string]interface{}{
"content_len": len(result.Content), "content_len": len(result.Content),
"image_len": len(strings.Join(imageTexts, "\n")), "enrichment": strings.Join(enrichments, "\n"),
"enrichment_len": len(strings.Join(enrichments, "\n")),
}) })
return combinedText return combinedText

View File

@@ -2,13 +2,14 @@ package chatpipline
import ( import (
"context" "context"
"encoding/json"
"fmt" "fmt"
"regexp"
"strings" "strings"
"sync" "sync"
"unicode"
"github.com/Tencent/WeKnora/internal/config" "github.com/Tencent/WeKnora/internal/config"
"github.com/Tencent/WeKnora/internal/models/chat" "github.com/Tencent/WeKnora/internal/logger"
"github.com/Tencent/WeKnora/internal/searchutil" "github.com/Tencent/WeKnora/internal/searchutil"
"github.com/Tencent/WeKnora/internal/types" "github.com/Tencent/WeKnora/internal/types"
"github.com/Tencent/WeKnora/internal/types/interfaces" "github.com/Tencent/WeKnora/internal/types/interfaces"
@@ -18,7 +19,6 @@ import (
type PluginSearch struct { type PluginSearch struct {
knowledgeBaseService interfaces.KnowledgeBaseService knowledgeBaseService interfaces.KnowledgeBaseService
knowledgeService interfaces.KnowledgeService knowledgeService interfaces.KnowledgeService
modelService interfaces.ModelService
config *config.Config config *config.Config
webSearchService interfaces.WebSearchService webSearchService interfaces.WebSearchService
tenantService interfaces.TenantService tenantService interfaces.TenantService
@@ -28,7 +28,6 @@ type PluginSearch struct {
func NewPluginSearch(eventManager *EventManager, func NewPluginSearch(eventManager *EventManager,
knowledgeBaseService interfaces.KnowledgeBaseService, knowledgeBaseService interfaces.KnowledgeBaseService,
knowledgeService interfaces.KnowledgeService, knowledgeService interfaces.KnowledgeService,
modelService interfaces.ModelService,
config *config.Config, config *config.Config,
webSearchService interfaces.WebSearchService, webSearchService interfaces.WebSearchService,
tenantService interfaces.TenantService, tenantService interfaces.TenantService,
@@ -37,7 +36,6 @@ func NewPluginSearch(eventManager *EventManager,
res := &PluginSearch{ res := &PluginSearch{
knowledgeBaseService: knowledgeBaseService, knowledgeBaseService: knowledgeBaseService,
knowledgeService: knowledgeService, knowledgeService: knowledgeService,
modelService: modelService,
config: config, config: config,
webSearchService: webSearchService, webSearchService: webSearchService,
tenantService: tenantService, tenantService: tenantService,
@@ -77,7 +75,6 @@ func (p *PluginSearch) OnEvent(ctx context.Context,
pipelineInfo(ctx, "Search", "input", map[string]interface{}{ pipelineInfo(ctx, "Search", "input", map[string]interface{}{
"session_id": chatManage.SessionID, "session_id": chatManage.SessionID,
"rewrite_query": chatManage.RewriteQuery, "rewrite_query": chatManage.RewriteQuery,
"processed_query": chatManage.ProcessedQuery,
"kb_ids": strings.Join(knowledgeBaseIDs, ","), "kb_ids": strings.Join(knowledgeBaseIDs, ","),
"tenant_id": chatManage.TenantID, "tenant_id": chatManage.TenantID,
"web_enabled": chatManage.WebSearchEnabled, "web_enabled": chatManage.WebSearchEnabled,
@@ -121,6 +118,16 @@ func (p *PluginSearch) OnEvent(ctx context.Context,
chatManage.SearchResult = allResults chatManage.SearchResult = allResults
// Log all search results with scores before any processing
for i, r := range chatManage.SearchResult {
pipelineInfo(ctx, "Search", "result_score_before_normalize", map[string]interface{}{
"index": i,
"chunk_id": r.ID,
"score": fmt.Sprintf("%.4f", r.Score),
"match_type": r.MatchType,
})
}
// If recall is low, attempt query expansion with keyword-focused search // If recall is low, attempt query expansion with keyword-focused search
if chatManage.EnableQueryExpansion && len(chatManage.SearchResult) < max(1, chatManage.EmbeddingTopK/2) { if chatManage.EnableQueryExpansion && len(chatManage.SearchResult) < max(1, chatManage.EmbeddingTopK/2) {
pipelineInfo(ctx, "Search", "recall_low", map[string]interface{}{ pipelineInfo(ctx, "Search", "recall_low", map[string]interface{}{
@@ -189,6 +196,7 @@ func (p *PluginSearch) OnEvent(ctx context.Context,
} }
wgExp.Wait() wgExp.Wait()
if len(expResults) > 0 { if len(expResults) > 0 {
// Scores already normalized in HybridSearch
pipelineInfo(ctx, "Search", "expansion_done", map[string]interface{}{ pipelineInfo(ctx, "Search", "expansion_done", map[string]interface{}{
"added": len(expResults), "added": len(expResults),
}) })
@@ -215,6 +223,16 @@ func (p *PluginSearch) OnEvent(ctx context.Context,
"after": len(chatManage.SearchResult), "after": len(chatManage.SearchResult),
}) })
// Log final scores after all processing
for i, r := range chatManage.SearchResult {
pipelineInfo(ctx, "Search", "final_score", map[string]interface{}{
"index": i,
"chunk_id": r.ID,
"score": fmt.Sprintf("%.4f", r.Score),
"match_type": r.MatchType,
})
}
// Return if we have results // Return if we have results
if len(chatManage.SearchResult) != 0 { if len(chatManage.SearchResult) != 0 {
pipelineInfo(ctx, "Search", "output", map[string]interface{}{ pipelineInfo(ctx, "Search", "output", map[string]interface{}{
@@ -250,32 +268,33 @@ func (p *PluginSearch) getSearchResultFromHistory(chatManage *types.ChatManage)
func removeDuplicateResults(results []*types.SearchResult) []*types.SearchResult { func removeDuplicateResults(results []*types.SearchResult) []*types.SearchResult {
seen := make(map[string]bool) seen := make(map[string]bool)
contentSig := make(map[string]bool) contentSig := make(map[string]string) // sig -> first chunk ID
var uniqueResults []*types.SearchResult var uniqueResults []*types.SearchResult
for _, r := range results { for _, r := range results {
keys := []string{r.ID} keys := []string{r.ID}
if r.ParentChunkID != "" { if r.ParentChunkID != "" {
keys = append(keys, "parent:"+r.ParentChunkID) keys = append(keys, "parent:"+r.ParentChunkID)
} }
if r.KnowledgeID != "" {
keys = append(keys, fmt.Sprintf("kb:%s#%d", r.KnowledgeID, r.ChunkIndex))
}
dup := false dup := false
dupKey := ""
for _, k := range keys { for _, k := range keys {
if seen[k] { if seen[k] {
dup = true dup = true
dupKey = k
break break
} }
} }
if dup { if dup {
logger.Debugf(context.Background(), "Dedup: chunk %s removed due to key: %s", r.ID, dupKey)
continue continue
} }
sig := buildContentSignature(r.Content) sig := buildContentSignature(r.Content)
if sig != "" { if sig != "" {
if contentSig[sig] { if firstChunk, exists := contentSig[sig]; exists {
logger.Debugf(context.Background(), "Dedup: chunk %s removed due to content signature (dup of %s, sig prefix: %.50s...)", r.ID, firstChunk, sig)
continue continue
} }
contentSig[sig] = true contentSig[sig] = r.ID
} }
for _, k := range keys { for _, k := range keys {
seen[k] = true seen[k] = true
@@ -286,24 +305,16 @@ func removeDuplicateResults(results []*types.SearchResult) []*types.SearchResult
} }
func buildContentSignature(content string) string { func buildContentSignature(content string) string {
c := strings.ToLower(strings.TrimSpace(content)) return searchutil.BuildContentSignature(content)
if c == "" {
return ""
}
c = strings.Join(strings.Fields(c), " ")
if len(c) > 128 {
c = c[:128]
}
return c
} }
// searchKnowledgeBases performs KB searches for rewrite and processed queries across KB IDs // searchKnowledgeBases performs KB searches across KB IDs using RewriteQuery only
func (p *PluginSearch) searchKnowledgeBases( func (p *PluginSearch) searchKnowledgeBases(
ctx context.Context, ctx context.Context,
knowledgeBaseIDs []string, knowledgeBaseIDs []string,
chatManage *types.ChatManage, chatManage *types.ChatManage,
) []*types.SearchResult { ) []*types.SearchResult {
// Build base params for rewrite query // Build params for rewrite query
baseParams := types.SearchParams{ baseParams := types.SearchParams{
QueryText: strings.TrimSpace(chatManage.RewriteQuery), QueryText: strings.TrimSpace(chatManage.RewriteQuery),
VectorThreshold: chatManage.VectorThreshold, VectorThreshold: chatManage.VectorThreshold,
@@ -315,7 +326,7 @@ func (p *PluginSearch) searchKnowledgeBases(
var mu sync.Mutex var mu sync.Mutex
var results []*types.SearchResult var results []*types.SearchResult
// Search with rewrite query // Search with rewrite query only (removed duplicate ProcessedQuery search)
for _, kbID := range knowledgeBaseIDs { for _, kbID := range knowledgeBaseIDs {
wg.Add(1) wg.Add(1)
go func(knowledgeBaseID string) { go func(knowledgeBaseID string) {
@@ -326,13 +337,11 @@ func (p *PluginSearch) searchKnowledgeBases(
"kb_id": knowledgeBaseID, "kb_id": knowledgeBaseID,
"query": baseParams.QueryText, "query": baseParams.QueryText,
"error": err.Error(), "error": err.Error(),
"query_ty": "rewrite",
}) })
return return
} }
pipelineInfo(ctx, "Search", "kb_result", map[string]interface{}{ pipelineInfo(ctx, "Search", "kb_result", map[string]interface{}{
"kb_id": knowledgeBaseID, "kb_id": knowledgeBaseID,
"query_ty": "rewrite",
"hit_count": len(res), "hit_count": len(res),
}) })
mu.Lock() mu.Lock()
@@ -343,45 +352,6 @@ func (p *PluginSearch) searchKnowledgeBases(
wg.Wait() wg.Wait()
// If processed query differs, search again
if chatManage.RewriteQuery != chatManage.ProcessedQuery {
paramsProcessed := baseParams
paramsProcessed.QueryText = strings.TrimSpace(chatManage.ProcessedQuery)
pipelineInfo(ctx, "Search", "processed_query_search", map[string]interface{}{
"query": paramsProcessed.QueryText,
})
wg = sync.WaitGroup{}
for _, kbID := range knowledgeBaseIDs {
wg.Add(1)
go func(knowledgeBaseID string) {
defer wg.Done()
res, err := p.knowledgeBaseService.HybridSearch(ctx, knowledgeBaseID, paramsProcessed)
if err != nil {
pipelineWarn(ctx, "Search", "kb_search_error", map[string]interface{}{
"kb_id": knowledgeBaseID,
"query": paramsProcessed.QueryText,
"error": err.Error(),
"query_ty": "processed",
})
return
}
pipelineInfo(ctx, "Search", "kb_result", map[string]interface{}{
"kb_id": knowledgeBaseID,
"query_ty": "processed",
"hit_count": len(res),
})
mu.Lock()
results = append(results, res...)
mu.Unlock()
}(kbID)
}
wg.Wait()
}
// Normalize keyword retriever scores after collecting all results from multiple knowledge bases
normalizeKeywordSearchResults(ctx, results)
pipelineInfo(ctx, "Search", "kb_result_summary", map[string]interface{}{ pipelineInfo(ctx, "Search", "kb_result_summary", map[string]interface{}{
"total_hits": len(results), "total_hits": len(results),
}) })
@@ -413,11 +383,8 @@ func (p *PluginSearch) searchWebIfEnabled(ctx context.Context, chatManage *types
}) })
return nil return nil
} }
// Build questions (rewrite + processed if different) // Build questions using RewriteQuery only
questions := []string{strings.TrimSpace(chatManage.RewriteQuery)} questions := []string{strings.TrimSpace(chatManage.RewriteQuery)}
if chatManage.ProcessedQuery != "" && chatManage.ProcessedQuery != chatManage.RewriteQuery {
questions = append(questions, strings.TrimSpace(chatManage.ProcessedQuery))
}
// Load session-scoped temp KB state from Redis using SessionService // Load session-scoped temp KB state from Redis using SessionService
tempKBID, seen, ids := p.sessionService.GetWebSearchTempKBState(ctx, chatManage.SessionID) tempKBID, seen, ids := p.sessionService.GetWebSearchTempKBState(ctx, chatManage.SessionID)
compressed, kbID, newSeen, newIDs, err := p.webSearchService.CompressWithRAG( compressed, kbID, newSeen, newIDs, err := p.webSearchService.CompressWithRAG(
@@ -440,119 +407,155 @@ func (p *PluginSearch) searchWebIfEnabled(ctx context.Context, chatManage *types
return res return res
} }
// expandQueries generates paraphrases and synonyms using chat model to improve keyword recall // expandQueries generates query variants locally without LLM to improve keyword recall
// Uses simple techniques: word reordering, stopword removal, key phrase extraction
func (p *PluginSearch) expandQueries(ctx context.Context, chatManage *types.ChatManage) []string { func (p *PluginSearch) expandQueries(ctx context.Context, chatManage *types.ChatManage) []string {
if p.modelService == nil || chatManage.ChatModelID == "" { query := strings.TrimSpace(chatManage.RewriteQuery)
pipelineWarn(ctx, "Search", "expansion_skip", map[string]interface{}{ if query == "" {
"reason": "no_model",
})
return nil return nil
} }
model, err := p.modelService.GetChatModel(ctx, chatManage.ChatModelID)
if err != nil { expansions := make([]string, 0, 5)
pipelineWarn(ctx, "Search", "expansion_get_model_failed", map[string]interface{}{ seen := make(map[string]struct{})
"error": err.Error(), seen[strings.ToLower(query)] = struct{}{}
}) if q := strings.ToLower(chatManage.Query); q != "" {
return nil seen[q] = struct{}{}
} }
sys := "Generate up to 5 diverse paraphrases or keyword variants for the user query to improve keyword-based search recall. Respond ONLY with a JSON array of strings inside a fenced code block."
usr := chatManage.RewriteQuery addIfNew := func(s string) {
think := false s = strings.TrimSpace(s)
resp, err := model.Chat(ctx, []chat.Message{ if s == "" || len(s) < 3 {
{Role: "system", Content: sys}, return
{Role: "user", Content: usr},
}, &chat.ChatOptions{Temperature: 0.2, MaxCompletionTokens: 200, Thinking: &think})
if err != nil || resp.Content == "" {
pipelineWarn(ctx, "Search", "expansion_model_call_failed", map[string]interface{}{
"error": err,
})
return nil
}
body := extractJSONBlock(resp.Content)
var arr []string
if err := json.Unmarshal([]byte(body), &arr); err != nil || len(arr) == 0 {
// Fallback: split lines
lines := strings.Split(resp.Content, "\n")
for _, l := range lines {
l = strings.TrimSpace(l)
if l != "" {
arr = append(arr, l)
}
}
}
uniq := make(map[string]struct{})
base := []string{chatManage.Query, chatManage.RewriteQuery, chatManage.ProcessedQuery}
for _, b := range base {
if s := strings.TrimSpace(b); s != "" {
uniq[strings.ToLower(s)] = struct{}{}
}
}
expansions := make([]string, 0, len(arr))
for _, a := range arr {
s := strings.TrimSpace(a)
if s == "" {
continue
} }
key := strings.ToLower(s) key := strings.ToLower(s)
if _, ok := uniq[key]; ok { if _, ok := seen[key]; ok {
continue return
} }
uniq[key] = struct{}{} seen[key] = struct{}{}
expansions = append(expansions, s) expansions = append(expansions, s)
if len(expansions) >= 5 { }
break
// 1. Remove common stopwords and create keyword-only variant
keywords := extractKeywords(query)
if len(keywords) >= 2 {
addIfNew(strings.Join(keywords, " "))
}
// 2. Extract quoted phrases or key segments
phrases := extractPhrases(query)
for _, phrase := range phrases {
addIfNew(phrase)
}
// 3. Split by common delimiters and use longest segment
segments := splitByDelimiters(query)
for _, seg := range segments {
if len(seg) > 5 {
addIfNew(seg)
} }
} }
pipelineInfo(ctx, "Search", "expansion_result", map[string]interface{}{
// 4. Remove question words (什么/如何/怎么/为什么/哪个 etc.)
cleaned := removeQuestionWords(query)
if cleaned != query {
addIfNew(cleaned)
}
// Limit to 5 expansions
if len(expansions) > 5 {
expansions = expansions[:5]
}
pipelineInfo(ctx, "Search", "local_expansion_result", map[string]interface{}{
"variants": len(expansions), "variants": len(expansions),
}) })
return expansions return expansions
} }
func extractJSONBlock(text string) string { // Common Chinese and English stopwords
t := strings.TrimSpace(text) var stopwords = map[string]struct{}{
if i := strings.Index(t, "["); i >= 0 { "的": {}, "是": {}, "在": {}, "了": {}, "": {}, "与": {}, "或": {},
j := strings.LastIndex(t, "]") "a": {}, "an": {}, "the": {}, "is": {}, "are": {}, "was": {}, "were": {},
if j > i { "be": {}, "been": {}, "being": {}, "have": {}, "has": {}, "had": {},
return t[i : j+1] "do": {}, "does": {}, "did": {}, "will": {}, "would": {}, "could": {},
} "should": {}, "may": {}, "might": {}, "must": {}, "can": {},
} "to": {}, "of": {}, "in": {}, "for": {}, "on": {}, "with": {}, "at": {},
return "[]" "by": {}, "from": {}, "as": {}, "into": {}, "through": {}, "about": {},
"what": {}, "how": {}, "why": {}, "when": {}, "where": {}, "which": {},
"who": {}, "whom": {}, "whose": {},
} }
// normalizeKeywordSearchResults normalizes keyword search result scores into [0,1] globally across all knowledge bases // Question words in Chinese
// Improvements: var questionWords = regexp.MustCompile(`^(什么是|什么|如何|怎么|怎样|为什么|为何|哪个|哪些|谁|何时|何地|请问|请告诉我|帮我|我想知道|我想了解)`)
// 1. Uses robust normalization with percentile-based bounds to handle outliers
// 2. Handles edge cases: single result, no variance, negative scores func extractKeywords(text string) []string {
// 3. Global normalization ensures fair comparison across different knowledge bases words := tokenize(text)
func normalizeKeywordSearchResults(ctx context.Context, results []*types.SearchResult) { keywords := make([]string, 0, len(words))
searchutil.NormalizeKeywordScores[*types.SearchResult]( for _, w := range words {
results, lower := strings.ToLower(w)
func(r *types.SearchResult) bool { if _, isStop := stopwords[lower]; !isStop && len(w) > 1 {
return r.MatchType == types.MatchTypeKeywords keywords = append(keywords, w)
}, }
func(r *types.SearchResult) float64 { }
return r.Score return keywords
}, }
func(r *types.SearchResult, score float64) {
r.Score = score func extractPhrases(text string) []string {
}, // Extract quoted content
searchutil.KeywordScoreCallbacks{ var phrases []string
OnNoVariance: func(count int, score float64) { re := regexp.MustCompile(`["'"'「」『』]([^"'"'「」『』]+)["'"'「」『』]`)
pipelineInfo(ctx, "Search", "keyword_scores_no_variance", map[string]interface{}{ matches := re.FindAllStringSubmatch(text, -1)
"count": count, for _, m := range matches {
"score": score, if len(m) > 1 && len(m[1]) > 2 {
}) phrases = append(phrases, m[1])
}, }
OnNormalized: func(count int, rawMin, rawMax, normalizeMin, normalizeMax float64) { }
pipelineInfo(ctx, "Search", "normalize_keyword_scores", map[string]interface{}{ return phrases
"count": count, }
"raw_min": rawMin,
"raw_max": rawMax, func splitByDelimiters(text string) []string {
"normalize_min": normalizeMin, // Split by common delimiters
"normalize_max": normalizeMax, re := regexp.MustCompile(`[,;;、。!?!?\s]+`)
}) parts := re.Split(text, -1)
}, var result []string
}, for _, p := range parts {
) p = strings.TrimSpace(p)
if p != "" {
result = append(result, p)
}
}
return result
}
func removeQuestionWords(text string) string {
return strings.TrimSpace(questionWords.ReplaceAllString(text, ""))
}
func tokenize(text string) []string {
var tokens []string
var current strings.Builder
for _, r := range text {
if unicode.IsLetter(r) || unicode.IsDigit(r) {
current.WriteRune(r)
} else if unicode.Is(unicode.Han, r) {
// Flush current token
if current.Len() > 0 {
tokens = append(tokens, current.String())
current.Reset()
}
// Chinese character as single token
tokens = append(tokens, string(r))
} else {
// Delimiter
if current.Len() > 0 {
tokens = append(tokens, current.String())
current.Reset()
}
}
}
if current.Len() > 0 {
tokens = append(tokens, current.String())
}
return tokens
} }

View File

@@ -190,5 +190,6 @@ func chunk2SearchResult(chunk *types.Chunk, knowledge *types.Knowledge) *types.S
ImageInfo: chunk.ImageInfo, ImageInfo: chunk.ImageInfo,
KnowledgeFilename: knowledge.FileName, KnowledgeFilename: knowledge.FileName,
KnowledgeSource: knowledge.Source, KnowledgeSource: knowledge.Source,
ChunkMetadata: chunk.Metadata,
} }
} }

View File

@@ -0,0 +1,181 @@
package chatpipline
import (
"context"
"sync"
"github.com/Tencent/WeKnora/internal/config"
"github.com/Tencent/WeKnora/internal/logger"
"github.com/Tencent/WeKnora/internal/types"
"github.com/Tencent/WeKnora/internal/types/interfaces"
)
// PluginSearchParallel implements parallel search functionality combining chunk search and entity search
type PluginSearchParallel struct {
// Chunk search dependencies
knowledgeBaseService interfaces.KnowledgeBaseService
knowledgeService interfaces.KnowledgeService
config *config.Config
webSearchService interfaces.WebSearchService
tenantService interfaces.TenantService
sessionService interfaces.SessionService
// Entity search dependencies
graphRepo interfaces.RetrieveGraphRepository
chunkRepo interfaces.ChunkRepository
knowledgeRepo interfaces.KnowledgeRepository
// Internal plugins
searchPlugin *PluginSearch
searchEntityPlugin *PluginSearchEntity
}
// NewPluginSearchParallel creates a new parallel search plugin
func NewPluginSearchParallel(
eventManager *EventManager,
knowledgeBaseService interfaces.KnowledgeBaseService,
knowledgeService interfaces.KnowledgeService,
config *config.Config,
webSearchService interfaces.WebSearchService,
tenantService interfaces.TenantService,
sessionService interfaces.SessionService,
graphRepository interfaces.RetrieveGraphRepository,
chunkRepository interfaces.ChunkRepository,
knowledgeRepository interfaces.KnowledgeRepository,
) *PluginSearchParallel {
// Create internal plugins without registering them
searchPlugin := &PluginSearch{
knowledgeBaseService: knowledgeBaseService,
knowledgeService: knowledgeService,
config: config,
webSearchService: webSearchService,
tenantService: tenantService,
sessionService: sessionService,
}
searchEntityPlugin := &PluginSearchEntity{
graphRepo: graphRepository,
chunkRepo: chunkRepository,
knowledgeRepo: knowledgeRepository,
}
res := &PluginSearchParallel{
knowledgeBaseService: knowledgeBaseService,
knowledgeService: knowledgeService,
config: config,
webSearchService: webSearchService,
tenantService: tenantService,
sessionService: sessionService,
graphRepo: graphRepository,
chunkRepo: chunkRepository,
knowledgeRepo: knowledgeRepository,
searchPlugin: searchPlugin,
searchEntityPlugin: searchEntityPlugin,
}
eventManager.Register(res)
return res
}
// ActivationEvents returns the event types this plugin handles
func (p *PluginSearchParallel) ActivationEvents() []types.EventType {
return []types.EventType{types.CHUNK_SEARCH_PARALLEL}
}
// OnEvent handles parallel search events - runs chunk search and entity search concurrently
func (p *PluginSearchParallel) OnEvent(ctx context.Context,
eventType types.EventType, chatManage *types.ChatManage, next func() *PluginError,
) *PluginError {
pipelineInfo(ctx, "SearchParallel", "start", map[string]interface{}{
"session_id": chatManage.SessionID,
"has_entities": len(chatManage.Entity) > 0,
"rewrite_query": chatManage.RewriteQuery,
})
var wg sync.WaitGroup
var mu sync.Mutex
var chunkSearchErr *PluginError
var entitySearchErr *PluginError
// Use separate ChatManage copies to avoid concurrent write conflicts
chunkChatManage := *chatManage
chunkChatManage.SearchResult = nil
entityChatManage := *chatManage
entityChatManage.SearchResult = nil
// Run chunk search and entity search in parallel
wg.Add(2)
// Goroutine 1: Chunk Search
go func() {
defer wg.Done()
err := p.searchPlugin.OnEvent(ctx, types.CHUNK_SEARCH, &chunkChatManage, func() *PluginError {
return nil
})
if err != nil && err != ErrSearchNothing {
mu.Lock()
chunkSearchErr = err
mu.Unlock()
}
pipelineInfo(ctx, "SearchParallel", "chunk_search_done", map[string]interface{}{
"result_count": len(chunkChatManage.SearchResult),
"has_error": err != nil && err != ErrSearchNothing,
})
}()
// Goroutine 2: Entity Search (only if entities are available)
go func() {
defer wg.Done()
if len(chatManage.Entity) == 0 {
pipelineInfo(ctx, "SearchParallel", "entity_search_skip", map[string]interface{}{
"reason": "no_entities",
})
return
}
err := p.searchEntityPlugin.OnEvent(ctx, types.ENTITY_SEARCH, &entityChatManage, func() *PluginError {
return nil
})
if err != nil && err != ErrSearchNothing {
mu.Lock()
entitySearchErr = err
mu.Unlock()
}
pipelineInfo(ctx, "SearchParallel", "entity_search_done", map[string]interface{}{
"result_count": len(entityChatManage.SearchResult),
"has_error": err != nil && err != ErrSearchNothing,
})
}()
wg.Wait()
// Merge results from both searches (no concurrent access now)
chatManage.SearchResult = append(chunkChatManage.SearchResult, entityChatManage.SearchResult...)
chatManage.SearchResult = removeDuplicateResults(chatManage.SearchResult)
// Log any errors but don't fail the pipeline if at least one search succeeded
if chunkSearchErr != nil {
logger.Warnf(ctx, "[SearchParallel] Chunk search error: %v", chunkSearchErr.Err)
}
if entitySearchErr != nil {
logger.Warnf(ctx, "[SearchParallel] Entity search error: %v", entitySearchErr.Err)
}
pipelineInfo(ctx, "SearchParallel", "complete", map[string]interface{}{
"session_id": chatManage.SessionID,
"chunk_results": len(chunkChatManage.SearchResult),
"entity_results": len(entityChatManage.SearchResult),
"total_results": len(chatManage.SearchResult),
"chunk_search_error": chunkSearchErr != nil,
"entity_search_error": entitySearchErr != nil,
})
// Return error only if both searches failed and we have no results
if len(chatManage.SearchResult) == 0 {
if chunkSearchErr != nil {
return chunkSearchErr
}
return ErrSearchNothing
}
return next()
}

View File

@@ -34,7 +34,7 @@ func (p *PluginTracing) ActivationEvents() []types.EventType {
types.CHAT_COMPLETION_STREAM, types.CHAT_COMPLETION_STREAM,
types.FILTER_TOP_K, types.FILTER_TOP_K,
types.REWRITE_QUERY, types.REWRITE_QUERY,
types.PREPROCESS_QUERY, types.CHUNK_SEARCH_PARALLEL,
} }
} }
@@ -69,8 +69,8 @@ func (p *PluginTracing) OnEvent(ctx context.Context,
return p.FilterTopK(ctx, eventType, chatManage, next) return p.FilterTopK(ctx, eventType, chatManage, next)
case types.REWRITE_QUERY: case types.REWRITE_QUERY:
return p.RewriteQuery(ctx, eventType, chatManage, next) return p.RewriteQuery(ctx, eventType, chatManage, next)
case types.PREPROCESS_QUERY: case types.CHUNK_SEARCH_PARALLEL:
return p.PreprocessQuery(ctx, eventType, chatManage, next) return p.SearchParallel(ctx, eventType, chatManage, next)
} }
return next() return next()
} }
@@ -95,7 +95,6 @@ func (p *PluginTracing) Search(ctx context.Context,
} }
span.SetAttributes( span.SetAttributes(
attribute.String("hybrid_search", string(searchResultJson)), attribute.String("hybrid_search", string(searchResultJson)),
attribute.String("processed_query", chatManage.ProcessedQuery),
attribute.Int("search_unique_count", len(unique)), attribute.Int("search_unique_count", len(unique)),
) )
return err return err
@@ -119,7 +118,6 @@ func (p *PluginTracing) Rerank(ctx context.Context,
span.SetAttributes( span.SetAttributes(
attribute.Int("rerank_resp_count", len(chatManage.RerankResult)), attribute.Int("rerank_resp_count", len(chatManage.RerankResult)),
attribute.String("rerank_resp_results", string(resultJson)), attribute.String("rerank_resp_results", string(resultJson)),
attribute.String("query_intent", chatManage.QueryIntent),
) )
return err return err
} }
@@ -266,22 +264,20 @@ func (p *PluginTracing) RewriteQuery(ctx context.Context,
return err return err
} }
// PreprocessQuery traces query preprocessing operations // SearchParallel traces parallel search operations (chunk + entity)
func (p *PluginTracing) PreprocessQuery(ctx context.Context, func (p *PluginTracing) SearchParallel(ctx context.Context,
eventType types.EventType, chatManage *types.ChatManage, next func() *PluginError, eventType types.EventType, chatManage *types.ChatManage, next func() *PluginError,
) *PluginError { ) *PluginError {
_, span := tracing.ContextWithSpan(ctx, "PluginTracing.PreprocessQuery") _, span := tracing.ContextWithSpan(ctx, "PluginTracing.SearchParallel")
defer span.End() defer span.End()
span.SetAttributes( span.SetAttributes(
attribute.String("query", chatManage.Query), attribute.String("query", chatManage.Query),
attribute.String("rewrite_query", chatManage.RewriteQuery),
attribute.Int("entity_count", len(chatManage.Entity)),
) )
err := next() err := next()
span.SetAttributes( span.SetAttributes(
attribute.String("processed_query", chatManage.ProcessedQuery), attribute.Int("search_result_count", len(chatManage.SearchResult)),
) )
return err return err
} }

View File

@@ -9,7 +9,6 @@ import (
"time" "time"
"github.com/Tencent/WeKnora/internal/application/service/retriever" "github.com/Tencent/WeKnora/internal/application/service/retriever"
"github.com/Tencent/WeKnora/internal/common"
"github.com/Tencent/WeKnora/internal/logger" "github.com/Tencent/WeKnora/internal/logger"
"github.com/Tencent/WeKnora/internal/models/embedding" "github.com/Tencent/WeKnora/internal/models/embedding"
"github.com/Tencent/WeKnora/internal/types" "github.com/Tencent/WeKnora/internal/types"
@@ -523,35 +522,113 @@ func (s *knowledgeBaseService) HybridSearch(ctx context.Context,
// Collect all results from different retrievers and deduplicate by chunk ID // Collect all results from different retrievers and deduplicate by chunk ID
logger.Infof(ctx, "Processing retrieval results") logger.Infof(ctx, "Processing retrieval results")
matchResults := []*types.IndexWithScore{}
// Separate results by retriever type for RRF fusion
var vectorResults []*types.IndexWithScore
var keywordResults []*types.IndexWithScore
for _, retrieveResult := range retrieveResults { for _, retrieveResult := range retrieveResults {
logger.Infof(ctx, "Retrieval results, engine: %v, retriever: %v, count: %v", logger.Infof(ctx, "Retrieval results, engine: %v, retriever: %v, count: %v",
retrieveResult.RetrieverEngineType, retrieveResult.RetrieverEngineType,
retrieveResult.RetrieverType, retrieveResult.RetrieverType,
len(retrieveResult.Results), len(retrieveResult.Results),
) )
matchResults = append(matchResults, retrieveResult.Results...) if retrieveResult.RetrieverType == types.VectorRetrieverType {
vectorResults = append(vectorResults, retrieveResult.Results...)
} else {
keywordResults = append(keywordResults, retrieveResult.Results...)
}
} }
// Early return if no results // Early return if no results
if len(matchResults) == 0 { if len(vectorResults) == 0 && len(keywordResults) == 0 {
logger.Info(ctx, "No search results found") logger.Info(ctx, "No search results found")
return nil, nil return nil, nil
} }
logger.Infof(ctx, "Result count before deduplication: %d", len(matchResults)) logger.Infof(ctx, "Result count before RRF fusion: vector=%d, keyword=%d", len(vectorResults), len(keywordResults))
// First, try standard deduplication // Use RRF (Reciprocal Rank Fusion) to merge results
deduplicatedChunks := common.DeduplicateWithScore( // RRF score = sum(1 / (k + rank)) for each retriever where the chunk appears
func(r *types.IndexWithScore) string { return r.ChunkID }, // k=60 is a common choice that works well in practice
matchResults...) const rrfK = 60
logger.Infof(ctx, "Result count after deduplication: %d", len(deduplicatedChunks))
// Build rank maps for each retriever (already sorted by score from retriever)
vectorRanks := make(map[string]int)
for i, r := range vectorResults {
if _, exists := vectorRanks[r.ChunkID]; !exists {
vectorRanks[r.ChunkID] = i + 1 // 1-indexed rank
}
}
keywordRanks := make(map[string]int)
for i, r := range keywordResults {
if _, exists := keywordRanks[r.ChunkID]; !exists {
keywordRanks[r.ChunkID] = i + 1 // 1-indexed rank
}
}
// Collect all unique chunks and compute RRF scores
chunkInfoMap := make(map[string]*types.IndexWithScore)
rrfScores := make(map[string]float64)
// Process vector results
for _, r := range vectorResults {
if _, exists := chunkInfoMap[r.ChunkID]; !exists {
chunkInfoMap[r.ChunkID] = r
}
}
// Process keyword results
for _, r := range keywordResults {
if _, exists := chunkInfoMap[r.ChunkID]; !exists {
chunkInfoMap[r.ChunkID] = r
}
}
// Compute RRF scores
for chunkID := range chunkInfoMap {
rrfScore := 0.0
if rank, ok := vectorRanks[chunkID]; ok {
rrfScore += 1.0 / float64(rrfK+rank)
}
if rank, ok := keywordRanks[chunkID]; ok {
rrfScore += 1.0 / float64(rrfK+rank)
}
rrfScores[chunkID] = rrfScore
}
// Convert to slice and sort by RRF score
deduplicatedChunks := make([]*types.IndexWithScore, 0, len(chunkInfoMap))
for chunkID, info := range chunkInfoMap {
// Store RRF score in the Score field for downstream processing
info.Score = rrfScores[chunkID]
deduplicatedChunks = append(deduplicatedChunks, info)
}
slices.SortFunc(deduplicatedChunks, func(a, b *types.IndexWithScore) int {
if a.Score > b.Score {
return -1
} else if a.Score < b.Score {
return 1
}
return 0
})
logger.Infof(ctx, "Result count after RRF fusion: %d", len(deduplicatedChunks))
// Log top results after RRF fusion for debugging
for i, chunk := range deduplicatedChunks {
if i < 15 {
vRank, vOk := vectorRanks[chunk.ChunkID]
kRank, kOk := keywordRanks[chunk.ChunkID]
logger.Debugf(ctx, "RRF rank %d: chunk_id=%s, rrf_score=%.6f, vector_rank=%v(%v), keyword_rank=%v(%v)",
i, chunk.ChunkID, chunk.Score, vRank, vOk, kRank, kOk)
}
}
kb.EnsureDefaults() kb.EnsureDefaults()
// Check if we need iterative retrieval for FAQ with separate indexing // Check if we need iterative retrieval for FAQ with separate indexing
// Only use iterative retrieval if we don't have enough unique chunks after first deduplication // Only use iterative retrieval if we don't have enough unique chunks after first deduplication
totalRetrieved := len(vectorResults) + len(keywordResults)
needsIterativeRetrieval := len(deduplicatedChunks) < params.MatchCount && needsIterativeRetrieval := len(deduplicatedChunks) < params.MatchCount &&
kb.Type == types.KnowledgeBaseTypeFAQ && len(matchResults) == matchCount kb.Type == types.KnowledgeBaseTypeFAQ && totalRetrieved == matchCount*2
if needsIterativeRetrieval { if needsIterativeRetrieval {
logger.Info(ctx, "Not enough unique chunks, using iterative retrieval for FAQ") logger.Info(ctx, "Not enough unique chunks, using iterative retrieval for FAQ")
// Use iterative retrieval to get more unique chunks (with negative question filtering inside) // Use iterative retrieval to get more unique chunks (with negative question filtering inside)
@@ -891,10 +968,38 @@ func (s *knowledgeBaseService) processSearchResults(ctx context.Context,
} }
} }
// Build final search results // Build final search results - preserve original order from input chunks
var searchResults []*types.SearchResult var searchResults []*types.SearchResult
for chunkID, chunk := range chunkMap { addedChunkIDs := make(map[string]bool)
// First pass: Add results in the original order from input chunks
for _, inputChunk := range chunks {
chunk, exists := chunkMap[inputChunk.ChunkID]
if !exists {
logger.Debugf(ctx, "Chunk not found in chunkMap: %s", inputChunk.ChunkID)
continue
}
if !s.isValidTextChunk(chunk) { if !s.isValidTextChunk(chunk) {
logger.Debugf(ctx, "Chunk is not valid text chunk: %s, type: %s", chunk.ID, chunk.ChunkType)
continue
}
if addedChunkIDs[chunk.ID] {
continue
}
score := chunkScores[chunk.ID]
if knowledge, ok := knowledgeMap[chunk.KnowledgeID]; ok {
matchType := chunkMatchTypes[chunk.ID]
searchResults = append(searchResults, s.buildSearchResult(chunk, knowledge, score, matchType))
addedChunkIDs[chunk.ID] = true
} else {
logger.Warnf(ctx, "Knowledge not found for chunk: %s, knowledge_id: %s", chunk.ID, chunk.KnowledgeID)
}
}
// Second pass: Add additional chunks (parent, nearby, relation) that weren't in original input
for chunkID, chunk := range chunkMap {
if addedChunkIDs[chunkID] || !s.isValidTextChunk(chunk) {
continue continue
} }
@@ -959,6 +1064,7 @@ func (s *knowledgeBaseService) buildSearchResult(chunk *types.Chunk,
ImageInfo: chunk.ImageInfo, ImageInfo: chunk.ImageInfo,
KnowledgeFilename: knowledge.FileName, KnowledgeFilename: knowledge.FileName,
KnowledgeSource: knowledge.Source, KnowledgeSource: knowledge.Source,
ChunkMetadata: chunk.Metadata,
} }
} }

View File

@@ -800,7 +800,6 @@ func (s *sessionService) SearchKnowledge(ctx context.Context,
// Use specific event list, only including retrieval-related events, not LLM summarization // Use specific event list, only including retrieval-related events, not LLM summarization
searchEvents := []types.EventType{ searchEvents := []types.EventType{
types.PREPROCESS_QUERY, // Preprocess query
types.CHUNK_SEARCH, // Vector search types.CHUNK_SEARCH, // Vector search
types.CHUNK_RERANK, // Rerank search results types.CHUNK_RERANK, // Rerank search results
types.CHUNK_MERGE, // Merge search results types.CHUNK_MERGE, // Merge search results

View File

@@ -140,10 +140,10 @@ func BuildContainer(container *dig.Container) *dig.Container {
must(container.Invoke(chatpipline.NewPluginChatCompletionStream)) must(container.Invoke(chatpipline.NewPluginChatCompletionStream))
must(container.Invoke(chatpipline.NewPluginStreamFilter)) must(container.Invoke(chatpipline.NewPluginStreamFilter))
must(container.Invoke(chatpipline.NewPluginFilterTopK)) must(container.Invoke(chatpipline.NewPluginFilterTopK))
must(container.Invoke(chatpipline.NewPluginPreprocess))
must(container.Invoke(chatpipline.NewPluginRewrite)) must(container.Invoke(chatpipline.NewPluginRewrite))
must(container.Invoke(chatpipline.NewPluginExtractEntity)) must(container.Invoke(chatpipline.NewPluginExtractEntity))
must(container.Invoke(chatpipline.NewPluginSearchEntity)) must(container.Invoke(chatpipline.NewPluginSearchEntity))
must(container.Invoke(chatpipline.NewPluginSearchParallel))
// HTTP handlers layer // HTTP handlers layer
must(container.Provide(handler.NewTenantHandler)) must(container.Provide(handler.NewTenantHandler))

View File

@@ -5,7 +5,6 @@ package event
// QueryData represents query-related event data // QueryData represents query-related event data
type QueryData struct { type QueryData struct {
OriginalQuery string `json:"original_query"` OriginalQuery string `json:"original_query"`
ProcessedQuery string `json:"processed_query,omitempty"`
RewrittenQuery string `json:"rewritten_query,omitempty"` RewrittenQuery string `json:"rewritten_query,omitempty"`
SessionID string `json:"session_id"` SessionID string `json:"session_id"`
UserID string `json:"user_id,omitempty"` UserID string `json:"user_id,omitempty"`

View File

@@ -0,0 +1,70 @@
package searchutil
import (
"crypto/md5"
"encoding/hex"
"strings"
)
// BuildContentSignature creates a normalized MD5 signature for content to detect duplicates.
// It normalizes the content by lowercasing, trimming whitespace, and collapsing multiple spaces.
func BuildContentSignature(content string) string {
c := strings.ToLower(strings.TrimSpace(content))
if c == "" {
return ""
}
// Normalize whitespace
c = strings.Join(strings.Fields(c), " ")
// Use MD5 hash of full content
hash := md5.Sum([]byte(c))
return hex.EncodeToString(hash[:])
}
// TokenizeSimple tokenizes text into a set of words (simple whitespace-based).
// Returns a map where keys are lowercase tokens with length > 1.
func TokenizeSimple(text string) map[string]struct{} {
text = strings.ToLower(text)
fields := strings.Fields(text)
set := make(map[string]struct{}, len(fields))
for _, f := range fields {
if len(f) > 1 {
set[f] = struct{}{}
}
}
return set
}
// Jaccard calculates Jaccard similarity between two token sets.
// Returns a value between 0 and 1, where 1 means identical sets.
func Jaccard(a, b map[string]struct{}) float64 {
if len(a) == 0 && len(b) == 0 {
return 0
}
// Calculate intersection
inter := 0
for k := range a {
if _, ok := b[k]; ok {
inter++
}
}
// Calculate union
union := len(a) + len(b) - inter
if union == 0 {
return 0
}
return float64(inter) / float64(union)
}
// ClampFloat clamps a float value to the specified range [minV, maxV].
func ClampFloat(v, minV, maxV float64) float64 {
if v < minV {
return minV
}
if v > maxV {
return maxV
}
return v
}

View File

@@ -5,9 +5,7 @@ package types
type ChatManage struct { type ChatManage struct {
SessionID string `json:"session_id"` // Unique identifier for the chat session SessionID string `json:"session_id"` // Unique identifier for the chat session
Query string `json:"query,omitempty"` // Original user query Query string `json:"query,omitempty"` // Original user query
ProcessedQuery string `json:"processed_query,omitempty"` // Query after preprocessing
RewriteQuery string `json:"rewrite_query,omitempty"` // Query after rewriting for better retrieval RewriteQuery string `json:"rewrite_query,omitempty"` // Query after rewriting for better retrieval
QueryIntent string `json:"query_intent,omitempty"` // Parsed intent: definition/howto/compare/qa/general
History []*History `json:"history,omitempty"` // Chat history for context History []*History `json:"history,omitempty"` // Chat history for context
KnowledgeBaseID string `json:"knowledge_base_id"` // ID of the knowledge base to search against (deprecated, use KnowledgeBaseIDs) KnowledgeBaseID string `json:"knowledge_base_id"` // ID of the knowledge base to search against (deprecated, use KnowledgeBaseIDs)
@@ -60,9 +58,7 @@ func (c *ChatManage) Clone() *ChatManage {
return &ChatManage{ return &ChatManage{
Query: c.Query, Query: c.Query,
ProcessedQuery: c.ProcessedQuery,
RewriteQuery: c.RewriteQuery, RewriteQuery: c.RewriteQuery,
QueryIntent: c.QueryIntent,
SessionID: c.SessionID, SessionID: c.SessionID,
KnowledgeBaseID: c.KnowledgeBaseID, KnowledgeBaseID: c.KnowledgeBaseID,
KnowledgeBaseIDs: knowledgeBaseIDs, KnowledgeBaseIDs: knowledgeBaseIDs,
@@ -103,9 +99,9 @@ func (c *ChatManage) Clone() *ChatManage {
type EventType string type EventType string
const ( const (
PREPROCESS_QUERY EventType = "preprocess_query" // Query preprocessing stage
REWRITE_QUERY EventType = "rewrite_query" // Query rewriting for better retrieval REWRITE_QUERY EventType = "rewrite_query" // Query rewriting for better retrieval
CHUNK_SEARCH EventType = "chunk_search" // Search for relevant chunks CHUNK_SEARCH EventType = "chunk_search" // Search for relevant chunks
CHUNK_SEARCH_PARALLEL EventType = "chunk_search_parallel" // Parallel search: chunks + entities
ENTITY_SEARCH EventType = "entity_search" // Search for relevant entities ENTITY_SEARCH EventType = "entity_search" // Search for relevant entities
CHUNK_RERANK EventType = "chunk_rerank" // Rerank search results CHUNK_RERANK EventType = "chunk_rerank" // Rerank search results
CHUNK_MERGE EventType = "chunk_merge" // Merge similar chunks CHUNK_MERGE EventType = "chunk_merge" // Merge similar chunks
@@ -134,9 +130,7 @@ var Pipline = map[string][]EventType{
}, },
"rag_stream": { // Streaming Retrieval Augmented Generation "rag_stream": { // Streaming Retrieval Augmented Generation
REWRITE_QUERY, REWRITE_QUERY,
PREPROCESS_QUERY, CHUNK_SEARCH_PARALLEL, // Parallel: CHUNK_SEARCH + ENTITY_SEARCH
CHUNK_SEARCH,
ENTITY_SEARCH,
CHUNK_RERANK, CHUNK_RERANK,
CHUNK_MERGE, CHUNK_MERGE,
FILTER_TOP_K, FILTER_TOP_K,

View File

@@ -46,6 +46,9 @@ type SearchResult struct {
// Knowledge source // Knowledge source
// Used to indicate the source of the knowledge, such as "url" // Used to indicate the source of the knowledge, such as "url"
KnowledgeSource string `json:"knowledge_source"` KnowledgeSource string `json:"knowledge_source"`
// ChunkMetadata stores chunk-level metadata (e.g., generated questions)
ChunkMetadata JSON `json:"chunk_metadata,omitempty"`
} }
// SearchParams represents the search parameters // SearchParams represents the search parameters