mirror of
https://github.com/Tencent/WeKnora.git
synced 2026-06-04 13:30:32 +08:00
feat: Implement parallel search functionality combining chunk and entity searches, enhancing retrieval efficiency and result accuracy
This commit is contained in:
6
go.mod
6
go.mod
@@ -29,7 +29,6 @@ require (
|
||||
github.com/spf13/viper v1.20.1
|
||||
github.com/stretchr/testify v1.11.1
|
||||
github.com/tencentyun/cos-go-sdk-v5 v0.7.65
|
||||
github.com/xuri/excelize/v2 v2.10.0
|
||||
github.com/yanyiwu/gojieba v1.4.5
|
||||
go.opentelemetry.io/otel v1.37.0
|
||||
go.opentelemetry.io/otel/exporters/otlp/otlptrace v1.37.0
|
||||
@@ -106,8 +105,6 @@ require (
|
||||
github.com/pelletier/go-toml/v2 v2.2.3 // indirect
|
||||
github.com/pierrec/lz4/v4 v4.1.21 // indirect
|
||||
github.com/pmezard/go-difflib v1.0.0 // indirect
|
||||
github.com/richardlehane/mscfb v1.0.4 // indirect
|
||||
github.com/richardlehane/msoleps v1.0.4 // indirect
|
||||
github.com/rivo/uniseg v0.4.7 // indirect
|
||||
github.com/robfig/cron/v3 v3.0.1 // indirect
|
||||
github.com/rs/xid v1.6.0 // indirect
|
||||
@@ -117,12 +114,9 @@ require (
|
||||
github.com/spf13/cast v1.10.0 // indirect
|
||||
github.com/spf13/pflag v1.0.6 // indirect
|
||||
github.com/subosito/gotenv v1.6.0 // indirect
|
||||
github.com/tiendc/go-deepcopy v1.7.1 // indirect
|
||||
github.com/twitchyliquid64/golang-asm v0.15.1 // indirect
|
||||
github.com/ugorji/go/codec v1.2.12 // indirect
|
||||
github.com/wk8/go-ordered-map/v2 v2.1.8 // indirect
|
||||
github.com/xuri/efp v0.0.1 // indirect
|
||||
github.com/xuri/nfp v0.0.2-0.20250530014748-2ddeb826f9a9 // indirect
|
||||
github.com/yosida95/uritemplate/v3 v3.0.2 // indirect
|
||||
go.opentelemetry.io/auto/sdk v1.1.0 // indirect
|
||||
go.opentelemetry.io/otel/metric v1.37.0 // indirect
|
||||
|
||||
15
go.sum
15
go.sum
@@ -235,11 +235,6 @@ github.com/pmezard/go-difflib v1.0.0 h1:4DBwDE0NGyQoBHbLQYPwSUPoCMWR5BEzIk/f1lZb
|
||||
github.com/pmezard/go-difflib v1.0.0/go.mod h1:iKH77koFhYxTK1pcRnkKkqfTogsbg7gZNVY4sRDYZ/4=
|
||||
github.com/redis/go-redis/v9 v9.14.0 h1:u4tNCjXOyzfgeLN+vAZaW1xUooqWDqVEsZN0U01jfAE=
|
||||
github.com/redis/go-redis/v9 v9.14.0/go.mod h1:huWgSWd8mW6+m0VPhJjSSQ+d6Nh1VICQ6Q5lHuCH/Iw=
|
||||
github.com/richardlehane/mscfb v1.0.4 h1:WULscsljNPConisD5hR0+OyZjwK46Pfyr6mPu5ZawpM=
|
||||
github.com/richardlehane/mscfb v1.0.4/go.mod h1:YzVpcZg9czvAuhk9T+a3avCpcFPMUWm7gK3DypaEsUk=
|
||||
github.com/richardlehane/msoleps v1.0.1/go.mod h1:BWev5JBpU9Ko2WAgmZEuiz4/u3ZYTKbjLycmwiWUfWg=
|
||||
github.com/richardlehane/msoleps v1.0.4 h1:WuESlvhX3gH2IHcd8UqyCuFY5yiq/GR/yqaSM/9/g00=
|
||||
github.com/richardlehane/msoleps v1.0.4/go.mod h1:BWev5JBpU9Ko2WAgmZEuiz4/u3ZYTKbjLycmwiWUfWg=
|
||||
github.com/rivo/uniseg v0.2.0/go.mod h1:J6wj4VEh+S6ZtnVlnTBMWIodfgj8LQOQFoIToxlJtxc=
|
||||
github.com/rivo/uniseg v0.4.7 h1:WUdvkW8uEhrYfLC4ZzdpI2ztxP1I582+49Oc5Mq64VQ=
|
||||
github.com/rivo/uniseg v0.4.7/go.mod h1:FN3SvrM+Zdj16jyLfmOkMNblXMcoc8DfTHruCPUcx88=
|
||||
@@ -284,8 +279,6 @@ github.com/tencentcloud/tencentcloud-sdk-go/tencentcloud/common v1.0.563/go.mod
|
||||
github.com/tencentcloud/tencentcloud-sdk-go/tencentcloud/kms v1.0.563/go.mod h1:uom4Nvi9W+Qkom0exYiJ9VWJjXwyxtPYTkKkaLMlfE0=
|
||||
github.com/tencentyun/cos-go-sdk-v5 v0.7.65 h1:+WBbfwThfZSbxpf1Dw6fyMwyzVtWBBExqfDJ5giiR2s=
|
||||
github.com/tencentyun/cos-go-sdk-v5 v0.7.65/go.mod h1:8+hG+mQMuRP/OIS9d83syAvXvrMj9HhkND6Q1fLghw0=
|
||||
github.com/tiendc/go-deepcopy v1.7.1 h1:LnubftI6nYaaMOcaz0LphzwraqN8jiWTwm416sitff4=
|
||||
github.com/tiendc/go-deepcopy v1.7.1/go.mod h1:4bKjNC2r7boYOkD2IOuZpYjmlDdzjbpTRyCx+goBCJQ=
|
||||
github.com/tmthrgd/go-hex v0.0.0-20190904060850-447a3041c3bc h1:9lRDQMhESg+zvGYmW5DyG0UqvY96Bu5QYsTLvCHdrgo=
|
||||
github.com/tmthrgd/go-hex v0.0.0-20190904060850-447a3041c3bc/go.mod h1:bciPuU6GHm1iF1pBvUfxfsH0Wmnc2VbpgvbI9ZWuIRs=
|
||||
github.com/twitchyliquid64/golang-asm v0.15.1 h1:SU5vSMR7hnwNxj24w34ZyCi/FmDZTkS4MhqMhdFk5YI=
|
||||
@@ -310,12 +303,6 @@ github.com/wk8/go-ordered-map/v2 v2.1.8 h1:5h/BUHu93oj4gIdvHHHGsScSTMijfx5PeYkE/
|
||||
github.com/wk8/go-ordered-map/v2 v2.1.8/go.mod h1:5nJHM5DyteebpVlHnWMV0rPz6Zp7+xBAnxjb1X5vnTw=
|
||||
github.com/x448/float16 v0.8.4 h1:qLwI1I70+NjRFUR3zs1JPUCgaCXSh3SW62uAKT1mSBM=
|
||||
github.com/x448/float16 v0.8.4/go.mod h1:14CWIYCyZA/cWjXOioeEpHeN/83MdbZDRQHoFcYsOfg=
|
||||
github.com/xuri/efp v0.0.1 h1:fws5Rv3myXyYni8uwj2qKjVaRP30PdjeYe2Y6FDsCL8=
|
||||
github.com/xuri/efp v0.0.1/go.mod h1:ybY/Jr0T0GTCnYjKqmdwxyxn2BQf2RcQIIvex5QldPI=
|
||||
github.com/xuri/excelize/v2 v2.10.0 h1:8aKsP7JD39iKLc6dH5Tw3dgV3sPRh8uRVXu/fMstfW4=
|
||||
github.com/xuri/excelize/v2 v2.10.0/go.mod h1:SC5TzhQkaOsTWpANfm+7bJCldzcnU/jrhqkTi/iBHBU=
|
||||
github.com/xuri/nfp v0.0.2-0.20250530014748-2ddeb826f9a9 h1:+C0TIdyyYmzadGaL/HBLbf3WdLgC29pgyhTjAT/0nuE=
|
||||
github.com/xuri/nfp v0.0.2-0.20250530014748-2ddeb826f9a9/go.mod h1:WwHg+CVyzlv/TX9xqBFXEZAuxOPxn2k1GNHwG41IIUQ=
|
||||
github.com/yanyiwu/gojieba v1.4.5 h1:VyZogGtdFSnJbACHvDRvDreXPPVPCg8axKFUdblU/JI=
|
||||
github.com/yanyiwu/gojieba v1.4.5/go.mod h1:JUq4DddFVGdHXJHxxepxRmhrKlDpaBxR8O28v6fKYLY=
|
||||
github.com/yosida95/uritemplate/v3 v3.0.2 h1:Ed3Oyj9yrmi9087+NczuL5BwkIc4wvTb5zIM+UJPGz4=
|
||||
@@ -359,8 +346,6 @@ golang.org/x/crypto v0.23.0/go.mod h1:CKFgDieR+mRhux2Lsu27y0fO304Db0wZe70UKqHu0v
|
||||
golang.org/x/crypto v0.31.0/go.mod h1:kDsLvtWBEx7MV9tJOj9bnXsPbxwJQ6csT/x4KIN4Ssk=
|
||||
golang.org/x/crypto v0.43.0 h1:dduJYIi3A3KOfdGOHX8AVZ/jGiyPa3IbBozJ5kNuE04=
|
||||
golang.org/x/crypto v0.43.0/go.mod h1:BFbav4mRNlXJL4wNeejLpWxB7wMbc79PdRGhWKncxR0=
|
||||
golang.org/x/image v0.25.0 h1:Y6uW6rH1y5y/LK1J8BPWZtr6yZ7hrsy6hFrXjgsc2fQ=
|
||||
golang.org/x/image v0.25.0/go.mod h1:tCAmOEGthTtkalusGp1g3xa2gke8J6c2N565dTyl9Rs=
|
||||
golang.org/x/mod v0.6.0-dev.0.20220419223038-86c51ed26bb4/go.mod h1:jJ57K6gSWd91VN4djpZkiMVwK6gcyfeH4XE8wZrZaV4=
|
||||
golang.org/x/mod v0.8.0/go.mod h1:iBbtSCu2XBx23ZKBPSOrRkjjQPZFPuis4dIYUhu/chs=
|
||||
golang.org/x/mod v0.12.0/go.mod h1:iBbtSCu2XBx23ZKBPSOrRkjjQPZFPuis4dIYUhu/chs=
|
||||
|
||||
@@ -8,6 +8,7 @@ import (
|
||||
"strings"
|
||||
|
||||
"github.com/Tencent/WeKnora/internal/logger"
|
||||
"github.com/Tencent/WeKnora/internal/searchutil"
|
||||
"github.com/Tencent/WeKnora/internal/types"
|
||||
"gorm.io/gorm"
|
||||
)
|
||||
@@ -546,17 +547,7 @@ func (t *GrepChunksTool) deduplicateChunks(ctx context.Context, results []chunkW
|
||||
|
||||
// buildContentSignature creates a normalized signature for content to detect near-duplicates
|
||||
func (t *GrepChunksTool) buildContentSignature(content string) string {
|
||||
c := strings.ToLower(strings.TrimSpace(content))
|
||||
if c == "" {
|
||||
return ""
|
||||
}
|
||||
// Normalize whitespace
|
||||
c = strings.Join(strings.Fields(c), " ")
|
||||
// Use first 128 characters as signature
|
||||
if len(c) > 128 {
|
||||
c = c[:128]
|
||||
}
|
||||
return c
|
||||
return searchutil.BuildContentSignature(content)
|
||||
}
|
||||
|
||||
// scoreChunks calculates match scores for chunks based on pattern matches
|
||||
@@ -693,36 +684,10 @@ func (t *GrepChunksTool) applyMMR(
|
||||
|
||||
// tokenizeSimple tokenizes text into a set of words (simple whitespace-based)
|
||||
func (t *GrepChunksTool) tokenizeSimple(text string) map[string]struct{} {
|
||||
text = strings.ToLower(text)
|
||||
fields := strings.Fields(text)
|
||||
set := make(map[string]struct{}, len(fields))
|
||||
for _, f := range fields {
|
||||
if len(f) > 1 {
|
||||
set[f] = struct{}{}
|
||||
}
|
||||
}
|
||||
return set
|
||||
return searchutil.TokenizeSimple(text)
|
||||
}
|
||||
|
||||
// jaccard calculates Jaccard similarity between two token sets
|
||||
func (t *GrepChunksTool) jaccard(a, b map[string]struct{}) float64 {
|
||||
if len(a) == 0 && len(b) == 0 {
|
||||
return 0
|
||||
}
|
||||
|
||||
// Calculate intersection
|
||||
inter := 0
|
||||
for k := range a {
|
||||
if _, ok := b[k]; ok {
|
||||
inter++
|
||||
}
|
||||
}
|
||||
|
||||
// Calculate union
|
||||
union := len(a) + len(b) - inter
|
||||
if union == 0 {
|
||||
return 0
|
||||
}
|
||||
|
||||
return float64(inter) / float64(union)
|
||||
return searchutil.Jaccard(a, b)
|
||||
}
|
||||
|
||||
@@ -264,15 +264,12 @@ func (t *KnowledgeSearchTool) Execute(ctx context.Context, args map[string]inter
|
||||
topK, vectorThreshold, keywordThreshold, kbTypeMap)
|
||||
logger.Infof(ctx, "[Tool][KnowledgeSearch] Concurrent search completed: %d raw results", len(allResults))
|
||||
|
||||
// Normalize keyword search results to ensure fair comparison across knowledge bases
|
||||
logger.Debugf(ctx, "[Tool][KnowledgeSearch] Normalizing keyword search results...")
|
||||
t.normalizeKeywordSearchResults(ctx, allResults)
|
||||
logger.Infof(ctx, "[Tool][KnowledgeSearch] After keyword normalization: %d results", len(allResults))
|
||||
// Note: HybridSearch now uses RRF (Reciprocal Rank Fusion) which produces normalized scores
|
||||
// RRF scores are in range [0, ~0.033] (max when rank=1 on both sides: 2/(60+1))
|
||||
// Threshold filtering is already done inside HybridSearch before RRF, so we skip it here
|
||||
|
||||
// Filter by threshold first
|
||||
filteredResults := t.filterByThreshold(allResults, vectorThreshold, keywordThreshold)
|
||||
// Deduplicate before reranking to reduce processing overhead
|
||||
deduplicatedBeforeRerank := t.deduplicateResults(filteredResults)
|
||||
deduplicatedBeforeRerank := t.deduplicateResults(allResults)
|
||||
|
||||
// Apply ReRank if model is configured
|
||||
// Prefer chatModel (LLM-based reranking) over rerankModel if both are available
|
||||
@@ -286,6 +283,9 @@ func (t *KnowledgeSearchTool) Execute(ctx context.Context, args map[string]inter
|
||||
}
|
||||
}
|
||||
|
||||
// Variable to hold results through reranking and MMR stages
|
||||
var filteredResults []*searchResultWithMeta
|
||||
|
||||
if t.chatModel != nil && len(deduplicatedBeforeRerank) > 0 && rerankQuery != "" {
|
||||
logger.Infof(
|
||||
ctx,
|
||||
@@ -347,10 +347,9 @@ func (t *KnowledgeSearchTool) Execute(ctx context.Context, args map[string]inter
|
||||
}
|
||||
}
|
||||
|
||||
// Apply absolute minimum score filter to remove very low quality chunks
|
||||
logger.Debugf(ctx, "[Tool][KnowledgeSearch] Applying min_score filter (%.2f)...", minScore)
|
||||
filteredResults = t.filterByMinScore(filteredResults, minScore)
|
||||
logger.Infof(ctx, "[Tool][KnowledgeSearch] After min_score filter: %d results", len(filteredResults))
|
||||
// Note: minScore filter is skipped because HybridSearch now uses RRF scores
|
||||
// RRF scores are in range [0, ~0.033], not [0, 1], so old thresholds don't apply
|
||||
// Threshold filtering is already done inside HybridSearch before RRF fusion
|
||||
|
||||
// Final deduplication after rerank (in case rerank changed scores/order but duplicates remain)
|
||||
logger.Debugf(ctx, "[Tool][KnowledgeSearch] Final deduplication after rerank...")
|
||||
@@ -465,45 +464,6 @@ func (t *KnowledgeSearchTool) concurrentSearch(
|
||||
return allResults
|
||||
}
|
||||
|
||||
// filterByThreshold filters results based on match type and threshold
|
||||
// Special handling for history matches: uses lower threshold (reduced by 0.1, minimum 0.5)
|
||||
func (t *KnowledgeSearchTool) filterByThreshold(
|
||||
results []*searchResultWithMeta,
|
||||
vectorThreshold, keywordThreshold float64,
|
||||
) []*searchResultWithMeta {
|
||||
filtered := make([]*searchResultWithMeta, 0)
|
||||
for _, r := range results {
|
||||
var threshold float64
|
||||
|
||||
// Special handling for history matches: use lower threshold
|
||||
switch r.MatchType {
|
||||
case types.MatchTypeHistory:
|
||||
// Use the lower of the two thresholds, then reduce by 0.1 (minimum 0.5)
|
||||
th := vectorThreshold
|
||||
if keywordThreshold < th {
|
||||
th = keywordThreshold
|
||||
}
|
||||
threshold = math.Max(th-0.1, 0.5)
|
||||
case types.MatchTypeEmbedding:
|
||||
threshold = vectorThreshold
|
||||
case types.MatchTypeKeywords:
|
||||
threshold = keywordThreshold
|
||||
default:
|
||||
// For other match types (graph, nearby chunk, etc.), use the lower threshold
|
||||
threshold = vectorThreshold
|
||||
if keywordThreshold < threshold {
|
||||
threshold = keywordThreshold
|
||||
}
|
||||
}
|
||||
|
||||
// Check if result meets threshold
|
||||
if r.Score >= threshold {
|
||||
filtered = append(filtered, r)
|
||||
}
|
||||
}
|
||||
return filtered
|
||||
}
|
||||
|
||||
// rerankResults applies reranking to search results using LLM prompt scoring or rerank model
|
||||
func (t *KnowledgeSearchTool) rerankResults(
|
||||
ctx context.Context,
|
||||
@@ -566,15 +526,13 @@ func (t *KnowledgeSearchTool) rerankResults(
|
||||
}
|
||||
|
||||
// Apply composite scoring to reranked results
|
||||
// Get query intent from context if available (optional)
|
||||
queryIntent := t.getQueryIntentFromContext(ctx)
|
||||
logger.Debugf(ctx, "[Tool][KnowledgeSearch] Applying composite scoring with query_intent=%s", queryIntent)
|
||||
logger.Debugf(ctx, "[Tool][KnowledgeSearch] Applying composite scoring")
|
||||
|
||||
// Store base scores before composite scoring
|
||||
for _, result := range rerankedNonFAQ {
|
||||
baseScore := result.Score
|
||||
// Apply composite score
|
||||
result.Score = t.compositeScore(result, result.Score, baseScore, queryIntent)
|
||||
result.Score = t.compositeScore(result, result.Score, baseScore)
|
||||
}
|
||||
|
||||
// Combine FAQ results (with original order) and reranked non-FAQ results
|
||||
@@ -899,20 +857,6 @@ func (t *KnowledgeSearchTool) rerankWithModel(
|
||||
return reranked, nil
|
||||
}
|
||||
|
||||
// filterByMinScore filters results by absolute minimum score
|
||||
func (t *KnowledgeSearchTool) filterByMinScore(
|
||||
results []*searchResultWithMeta,
|
||||
minScore float64,
|
||||
) []*searchResultWithMeta {
|
||||
filtered := make([]*searchResultWithMeta, 0)
|
||||
for _, r := range results {
|
||||
if r.Score >= minScore {
|
||||
filtered = append(filtered, r)
|
||||
}
|
||||
}
|
||||
return filtered
|
||||
}
|
||||
|
||||
// deduplicateResults removes duplicate chunks, keeping the highest score
|
||||
// Uses multiple keys (ID, parent chunk ID, knowledge+index) and content signature for deduplication
|
||||
func (t *KnowledgeSearchTool) deduplicateResults(results []*searchResultWithMeta) []*searchResultWithMeta {
|
||||
@@ -984,17 +928,7 @@ func (t *KnowledgeSearchTool) deduplicateResults(results []*searchResultWithMeta
|
||||
|
||||
// buildContentSignature creates a normalized signature for content to detect near-duplicates
|
||||
func (t *KnowledgeSearchTool) buildContentSignature(content string) string {
|
||||
c := strings.ToLower(strings.TrimSpace(content))
|
||||
if c == "" {
|
||||
return ""
|
||||
}
|
||||
// Normalize whitespace
|
||||
c = strings.Join(strings.Fields(c), " ")
|
||||
// Use first 128 characters as signature
|
||||
if len(c) > 128 {
|
||||
c = c[:128]
|
||||
}
|
||||
return c
|
||||
return searchutil.BuildContentSignature(content)
|
||||
}
|
||||
|
||||
// formatOutput formats the search results for display
|
||||
@@ -1215,47 +1149,6 @@ type chunkRange struct {
|
||||
end int
|
||||
}
|
||||
|
||||
// normalizeKeywordSearchResults normalizes keyword search result scores into [0,1] globally across all knowledge bases
|
||||
// Improvements:
|
||||
// 1. Uses robust normalization with percentile-based bounds to handle outliers
|
||||
// 2. Handles edge cases: single result, no variance, negative scores
|
||||
// 3. Global normalization ensures fair comparison across different knowledge bases
|
||||
func (t *KnowledgeSearchTool) normalizeKeywordSearchResults(ctx context.Context, results []*searchResultWithMeta) {
|
||||
searchutil.NormalizeKeywordScores[*searchResultWithMeta](
|
||||
results,
|
||||
func(r *searchResultWithMeta) bool {
|
||||
return r.MatchType == types.MatchTypeKeywords
|
||||
},
|
||||
func(r *searchResultWithMeta) float64 {
|
||||
return r.Score
|
||||
},
|
||||
func(r *searchResultWithMeta, score float64) {
|
||||
r.Score = score
|
||||
},
|
||||
searchutil.KeywordScoreCallbacks{
|
||||
OnNoVariance: func(count int, score float64) {
|
||||
logger.Infof(
|
||||
ctx,
|
||||
"[Tool][KnowledgeSearch] Keyword scores have no variance, all set to 1.0: count=%d, score=%.3f",
|
||||
count,
|
||||
score,
|
||||
)
|
||||
},
|
||||
OnNormalized: func(count int, rawMin, rawMax, normalizeMin, normalizeMax float64) {
|
||||
logger.Infof(
|
||||
ctx,
|
||||
"[Tool][KnowledgeSearch] Normalized keyword scores: count=%d, raw_min=%.3f, raw_max=%.3f, normalize_min=%.3f, normalize_max=%.3f",
|
||||
count,
|
||||
rawMin,
|
||||
rawMax,
|
||||
normalizeMin,
|
||||
normalizeMax,
|
||||
)
|
||||
},
|
||||
},
|
||||
)
|
||||
}
|
||||
|
||||
// getEnrichedPassage 合并Content和ImageInfo的文本内容
|
||||
func (t *KnowledgeSearchTool) getEnrichedPassage(ctx context.Context, result *types.SearchResult) string {
|
||||
if result.ImageInfo == "" {
|
||||
@@ -1302,19 +1195,10 @@ func (t *KnowledgeSearchTool) getEnrichedPassage(ctx context.Context, result *ty
|
||||
return combinedText
|
||||
}
|
||||
|
||||
// getQueryIntentFromContext attempts to extract query intent from context (optional)
|
||||
func (t *KnowledgeSearchTool) getQueryIntentFromContext(ctx context.Context) string {
|
||||
// Try to get query intent from context if available
|
||||
// This is optional and may not always be present in agent tool context
|
||||
// Return empty string if not available
|
||||
return ""
|
||||
}
|
||||
|
||||
// compositeScore calculates a composite score considering multiple factors
|
||||
func (t *KnowledgeSearchTool) compositeScore(
|
||||
result *searchResultWithMeta,
|
||||
modelScore, baseScore float64,
|
||||
queryIntent string,
|
||||
) float64 {
|
||||
// Source weight: web_search results get slightly lower weight
|
||||
sourceWeight := 1.0
|
||||
@@ -1322,26 +1206,6 @@ func (t *KnowledgeSearchTool) compositeScore(
|
||||
sourceWeight = 0.95
|
||||
}
|
||||
|
||||
// Intent boost: adjust score based on query intent and chunk characteristics
|
||||
intentBoost := 1.0
|
||||
if queryIntent != "" {
|
||||
switch queryIntent {
|
||||
case "definition":
|
||||
// Boost summary chunks for definition queries
|
||||
if result.ChunkType == string(types.ChunkTypeSummary) {
|
||||
intentBoost = 1.05
|
||||
}
|
||||
case "howto":
|
||||
// Boost longer chunks for howto queries
|
||||
if result.EndAt-result.StartAt > 300 {
|
||||
intentBoost = 1.03
|
||||
}
|
||||
case "compare":
|
||||
// No boost for compare queries
|
||||
intentBoost = 1.0
|
||||
}
|
||||
}
|
||||
|
||||
// Position prior: slightly favor chunks earlier in the document
|
||||
positionPrior := 1.0
|
||||
if result.StartAt >= 0 && result.EndAt > result.StartAt {
|
||||
@@ -1352,7 +1216,6 @@ func (t *KnowledgeSearchTool) compositeScore(
|
||||
|
||||
// Composite formula: weighted combination of model score, base score, and source weight
|
||||
composite := 0.6*modelScore + 0.3*baseScore + 0.1*sourceWeight
|
||||
composite *= intentBoost
|
||||
composite *= positionPrior
|
||||
|
||||
// Clamp to [0, 1]
|
||||
@@ -1368,13 +1231,7 @@ func (t *KnowledgeSearchTool) compositeScore(
|
||||
|
||||
// clampFloat clamps a float value to the specified range
|
||||
func (t *KnowledgeSearchTool) clampFloat(v, minV, maxV float64) float64 {
|
||||
if v < minV {
|
||||
return minV
|
||||
}
|
||||
if v > maxV {
|
||||
return maxV
|
||||
}
|
||||
return v
|
||||
return searchutil.ClampFloat(v, minV, maxV)
|
||||
}
|
||||
|
||||
// applyMMR applies Maximal Marginal Relevance algorithm to reduce redundancy
|
||||
@@ -1456,36 +1313,10 @@ func (t *KnowledgeSearchTool) applyMMR(
|
||||
|
||||
// tokenizeSimple tokenizes text into a set of words (simple whitespace-based)
|
||||
func (t *KnowledgeSearchTool) tokenizeSimple(text string) map[string]struct{} {
|
||||
text = strings.ToLower(text)
|
||||
fields := strings.Fields(text)
|
||||
set := make(map[string]struct{}, len(fields))
|
||||
for _, f := range fields {
|
||||
if len(f) > 1 {
|
||||
set[f] = struct{}{}
|
||||
}
|
||||
}
|
||||
return set
|
||||
return searchutil.TokenizeSimple(text)
|
||||
}
|
||||
|
||||
// jaccard calculates Jaccard similarity between two token sets
|
||||
func (t *KnowledgeSearchTool) jaccard(a, b map[string]struct{}) float64 {
|
||||
if len(a) == 0 && len(b) == 0 {
|
||||
return 0
|
||||
}
|
||||
|
||||
// Calculate intersection
|
||||
inter := 0
|
||||
for k := range a {
|
||||
if _, ok := b[k]; ok {
|
||||
inter++
|
||||
}
|
||||
}
|
||||
|
||||
// Calculate union
|
||||
union := len(a) + len(b) - inter
|
||||
if union == 0 {
|
||||
return 0
|
||||
}
|
||||
|
||||
return float64(inter) / float64(union)
|
||||
return searchutil.Jaccard(a, b)
|
||||
}
|
||||
|
||||
@@ -1,166 +0,0 @@
|
||||
package chatpipline
|
||||
|
||||
import (
|
||||
"context"
|
||||
"encoding/json"
|
||||
"regexp"
|
||||
"strings"
|
||||
|
||||
"github.com/Tencent/WeKnora/internal/config"
|
||||
"github.com/Tencent/WeKnora/internal/models/chat"
|
||||
"github.com/Tencent/WeKnora/internal/types"
|
||||
"github.com/Tencent/WeKnora/internal/types/interfaces"
|
||||
)
|
||||
|
||||
// PluginPreprocess Query preprocessing plugin
|
||||
type PluginPreprocess struct {
|
||||
config *config.Config
|
||||
modelService interfaces.ModelService
|
||||
}
|
||||
|
||||
// Regular expressions for text cleaning
|
||||
var (
|
||||
multiSpaceRegex = regexp.MustCompile(`\s+`) // Multiple spaces
|
||||
)
|
||||
|
||||
// NewPluginPreprocess Creates a new query preprocessing plugin
|
||||
func NewPluginPreprocess(
|
||||
eventManager *EventManager,
|
||||
config *config.Config,
|
||||
cleaner interfaces.ResourceCleaner,
|
||||
modelService interfaces.ModelService,
|
||||
) *PluginPreprocess {
|
||||
res := &PluginPreprocess{
|
||||
config: config,
|
||||
modelService: modelService,
|
||||
}
|
||||
|
||||
eventManager.Register(res)
|
||||
return res
|
||||
}
|
||||
|
||||
// ActivationEvents Register activation events
|
||||
func (p *PluginPreprocess) ActivationEvents() []types.EventType {
|
||||
return []types.EventType{types.PREPROCESS_QUERY}
|
||||
}
|
||||
|
||||
// OnEvent Process events
|
||||
func (p *PluginPreprocess) OnEvent(
|
||||
ctx context.Context,
|
||||
eventType types.EventType,
|
||||
chatManage *types.ChatManage,
|
||||
next func() *PluginError,
|
||||
) *PluginError {
|
||||
rawQuery := strings.TrimSpace(chatManage.RewriteQuery)
|
||||
if rawQuery == "" {
|
||||
return next()
|
||||
}
|
||||
|
||||
pipelineInfo(ctx, "Preprocess", "input", map[string]interface{}{
|
||||
"session_id": chatManage.SessionID,
|
||||
"rewrite_query": rawQuery,
|
||||
})
|
||||
|
||||
// Lightweight normalization: just collapse multiple spaces
|
||||
processed := multiSpaceRegex.ReplaceAllString(rawQuery, " ")
|
||||
processed = strings.TrimSpace(processed)
|
||||
|
||||
chatManage.ProcessedQuery = processed
|
||||
chatManage.QueryIntent = p.detectIntentLLM(ctx, chatManage, processed)
|
||||
|
||||
pipelineInfo(ctx, "Preprocess", "output", map[string]interface{}{
|
||||
"session_id": chatManage.SessionID,
|
||||
"processed_query": processed,
|
||||
"query_intent": chatManage.QueryIntent,
|
||||
})
|
||||
|
||||
return next()
|
||||
}
|
||||
|
||||
// intentResp is a response for intent detection
|
||||
type intentResp struct {
|
||||
Intent string `json:"intent"`
|
||||
Confidence float64 `json:"confidence"`
|
||||
}
|
||||
|
||||
// detectIntentLLM detects the intent of a query using an LLM
|
||||
func (p *PluginPreprocess) detectIntentLLM(ctx context.Context, chatManage *types.ChatManage, text string) string {
|
||||
if p.modelService == nil || chatManage.ChatModelID == "" {
|
||||
pipelineWarn(
|
||||
ctx,
|
||||
"IntentDetect",
|
||||
"skip",
|
||||
map[string]interface{}{"reason": "no_model", "session_id": chatManage.SessionID},
|
||||
)
|
||||
return "general"
|
||||
}
|
||||
chatModel, err := p.modelService.GetChatModel(ctx, chatManage.ChatModelID)
|
||||
if err != nil {
|
||||
pipelineWarn(
|
||||
ctx,
|
||||
"IntentDetect",
|
||||
"get_model_failed",
|
||||
map[string]interface{}{"error": err.Error(), "model_id": chatManage.ChatModelID},
|
||||
)
|
||||
return "general"
|
||||
}
|
||||
pipelineInfo(
|
||||
ctx,
|
||||
"IntentDetect",
|
||||
"start",
|
||||
map[string]interface{}{"session_id": chatManage.SessionID, "model_id": chatManage.ChatModelID},
|
||||
)
|
||||
sys := "You are a query intent classifier. Classify the user's query into one of: definition, howto, compare, qa, general. Respond ONLY with a JSON object {\"intent\": \"...\", \"confidence\": 0.0 } inside a markdown fenced block."
|
||||
usr := text
|
||||
think := false
|
||||
resp, err := chatModel.Chat(ctx, []chat.Message{
|
||||
{Role: "system", Content: sys},
|
||||
{Role: "user", Content: usr},
|
||||
}, &chat.ChatOptions{Temperature: 0.0, MaxCompletionTokens: 64, Thinking: &think})
|
||||
if err != nil || resp.Content == "" {
|
||||
pipelineWarn(ctx, "IntentDetect", "model_call_failed", map[string]interface{}{"error": err})
|
||||
return "general"
|
||||
}
|
||||
body := extractJSONBody(resp.Content)
|
||||
var ir intentResp
|
||||
if err := json.Unmarshal([]byte(body), &ir); err != nil {
|
||||
pipelineWarn(ctx, "IntentDetect", "parse_failed", map[string]interface{}{"body": body, "error": err.Error()})
|
||||
return "general"
|
||||
}
|
||||
pipelineInfo(
|
||||
ctx,
|
||||
"IntentDetect",
|
||||
"result",
|
||||
map[string]interface{}{"intent": ir.Intent, "confidence": ir.Confidence},
|
||||
)
|
||||
switch strings.ToLower(strings.TrimSpace(ir.Intent)) {
|
||||
case "definition", "howto", "compare", "qa", "general":
|
||||
return strings.ToLower(ir.Intent)
|
||||
default:
|
||||
return "general"
|
||||
}
|
||||
}
|
||||
|
||||
// extractJSONBody extracts a JSON body from a string
|
||||
func extractJSONBody(text string) string {
|
||||
t := strings.TrimSpace(text)
|
||||
// Try fenced block first
|
||||
if i := strings.Index(t, "{"); i >= 0 {
|
||||
j := strings.LastIndex(t, "}")
|
||||
if j > i {
|
||||
return t[i : j+1]
|
||||
}
|
||||
}
|
||||
return "{}"
|
||||
}
|
||||
|
||||
// Close Releases resources
|
||||
func (p *PluginPreprocess) Close() {
|
||||
}
|
||||
|
||||
// ShutdownHandler Returns shutdown function
|
||||
func (p *PluginPreprocess) ShutdownHandler() func() {
|
||||
return func() {
|
||||
p.Close()
|
||||
}
|
||||
}
|
||||
@@ -8,6 +8,7 @@ import (
|
||||
"strings"
|
||||
|
||||
"github.com/Tencent/WeKnora/internal/models/rerank"
|
||||
"github.com/Tencent/WeKnora/internal/searchutil"
|
||||
"github.com/Tencent/WeKnora/internal/types"
|
||||
"github.com/Tencent/WeKnora/internal/types/interfaces"
|
||||
)
|
||||
@@ -36,12 +37,11 @@ func (p *PluginRerank) OnEvent(ctx context.Context,
|
||||
eventType types.EventType, chatManage *types.ChatManage, next func() *PluginError,
|
||||
) *PluginError {
|
||||
pipelineInfo(ctx, "Rerank", "input", map[string]interface{}{
|
||||
"session_id": chatManage.SessionID,
|
||||
"candidate_cnt": len(chatManage.SearchResult),
|
||||
"rerank_model": chatManage.RerankModelID,
|
||||
"rerank_thresh": chatManage.RerankThreshold,
|
||||
"rewrite_query": chatManage.RewriteQuery,
|
||||
"processed_query": chatManage.ProcessedQuery,
|
||||
"session_id": chatManage.SessionID,
|
||||
"candidate_cnt": len(chatManage.SearchResult),
|
||||
"rerank_model": chatManage.RerankModelID,
|
||||
"rerank_thresh": chatManage.RerankThreshold,
|
||||
"rewrite_query": chatManage.RewriteQuery,
|
||||
})
|
||||
if len(chatManage.SearchResult) == 0 {
|
||||
pipelineInfo(ctx, "Rerank", "skip", map[string]interface{}{
|
||||
@@ -77,18 +77,40 @@ func (p *PluginRerank) OnEvent(ctx context.Context,
|
||||
passages = append(passages, passage)
|
||||
}
|
||||
|
||||
// Try reranking with different query variants in priority order
|
||||
// Single rerank call with RewriteQuery, use threshold degradation if no results
|
||||
originalThreshold := chatManage.RerankThreshold
|
||||
rerankResp := p.rerank(ctx, chatManage, rerankModel, chatManage.RewriteQuery, passages)
|
||||
if len(rerankResp) == 0 {
|
||||
rerankResp = p.rerank(ctx, chatManage, rerankModel, chatManage.ProcessedQuery, passages)
|
||||
if len(rerankResp) == 0 {
|
||||
rerankResp = p.rerank(ctx, chatManage, rerankModel, chatManage.Query, passages)
|
||||
|
||||
// If no results and threshold is high enough, try with lower threshold
|
||||
if len(rerankResp) == 0 && originalThreshold > 0.3 {
|
||||
degradedThreshold := originalThreshold * 0.7
|
||||
if degradedThreshold < 0.3 {
|
||||
degradedThreshold = 0.3
|
||||
}
|
||||
pipelineInfo(ctx, "Rerank", "threshold_degrade", map[string]interface{}{
|
||||
"original": originalThreshold,
|
||||
"degraded": degradedThreshold,
|
||||
})
|
||||
chatManage.RerankThreshold = degradedThreshold
|
||||
rerankResp = p.rerank(ctx, chatManage, rerankModel, chatManage.RewriteQuery, passages)
|
||||
// Restore original threshold
|
||||
chatManage.RerankThreshold = originalThreshold
|
||||
}
|
||||
|
||||
pipelineInfo(ctx, "Rerank", "model_response", map[string]interface{}{
|
||||
"result_cnt": len(rerankResp),
|
||||
})
|
||||
|
||||
// Log input scores before reranking for debugging
|
||||
for i, sr := range chatManage.SearchResult {
|
||||
pipelineInfo(ctx, "Rerank", "input_score", map[string]interface{}{
|
||||
"index": i,
|
||||
"chunk_id": sr.ID,
|
||||
"score": fmt.Sprintf("%.4f", sr.Score),
|
||||
"match_type": sr.MatchType,
|
||||
})
|
||||
}
|
||||
|
||||
for i := range chatManage.SearchResult {
|
||||
chatManage.SearchResult[i].Metadata = ensureMetadata(chatManage.SearchResult[i].Metadata)
|
||||
}
|
||||
@@ -97,8 +119,15 @@ func (p *PluginRerank) OnEvent(ctx context.Context,
|
||||
sr := chatManage.SearchResult[rr.Index]
|
||||
base := sr.Score
|
||||
sr.Metadata["base_score"] = fmt.Sprintf("%.4f", base)
|
||||
sr.Score = rr.RelevanceScore
|
||||
sr.Score = compositeScore(sr, rr.RelevanceScore, base, chatManage)
|
||||
modelScore := rr.RelevanceScore
|
||||
sr.Score = compositeScore(sr, modelScore, base)
|
||||
pipelineInfo(ctx, "Rerank", "composite_calc", map[string]interface{}{
|
||||
"chunk_id": sr.ID,
|
||||
"base_score": fmt.Sprintf("%.4f", base),
|
||||
"model_score": fmt.Sprintf("%.4f", modelScore),
|
||||
"final_score": fmt.Sprintf("%.4f", sr.Score),
|
||||
"match_type": sr.MatchType,
|
||||
})
|
||||
reranked = append(reranked, sr)
|
||||
}
|
||||
final := applyMMR(ctx, reranked, chatManage, min(len(reranked), max(1, chatManage.RerankTopK)), 0.7)
|
||||
@@ -112,7 +141,6 @@ func (p *PluginRerank) OnEvent(ctx context.Context,
|
||||
"chunk_id": reranked[i].ID,
|
||||
"base_score": reranked[i].Metadata["base_score"],
|
||||
"final_score": fmt.Sprintf("%.4f", reranked[i].Score),
|
||||
"intent": chatManage.QueryIntent,
|
||||
})
|
||||
}
|
||||
|
||||
@@ -185,7 +213,7 @@ func ensureMetadata(m map[string]string) map[string]string {
|
||||
}
|
||||
|
||||
// compositeScore calculates the composite score for a search result
|
||||
func compositeScore(sr *types.SearchResult, modelScore, baseScore float64, chatManage *types.ChatManage) float64 {
|
||||
func compositeScore(sr *types.SearchResult, modelScore, baseScore float64) float64 {
|
||||
sourceWeight := 1.0
|
||||
switch strings.ToLower(sr.KnowledgeSource) {
|
||||
case "web_search":
|
||||
@@ -193,27 +221,11 @@ func compositeScore(sr *types.SearchResult, modelScore, baseScore float64, chatM
|
||||
default:
|
||||
sourceWeight = 1.0
|
||||
}
|
||||
intentBoost := 1.0
|
||||
if chatManage.QueryIntent != "" {
|
||||
switch chatManage.QueryIntent {
|
||||
case "definition":
|
||||
if sr.ChunkType == string(types.ChunkTypeSummary) {
|
||||
intentBoost = 1.05
|
||||
}
|
||||
case "howto":
|
||||
if sr.EndAt-sr.StartAt > 300 {
|
||||
intentBoost = 1.03
|
||||
}
|
||||
case "compare":
|
||||
intentBoost = 1.0
|
||||
}
|
||||
}
|
||||
positionPrior := 1.0
|
||||
if sr.StartAt >= 0 {
|
||||
positionPrior += clampFloat(1.0-float64(sr.StartAt)/float64(sr.EndAt+1), -0.05, 0.05)
|
||||
positionPrior += searchutil.ClampFloat(1.0-float64(sr.StartAt)/float64(sr.EndAt+1), -0.05, 0.05)
|
||||
}
|
||||
composite := 0.6*modelScore + 0.3*baseScore + 0.1*sourceWeight
|
||||
composite *= intentBoost
|
||||
composite *= positionPrior
|
||||
if composite < 0 {
|
||||
composite = 0
|
||||
@@ -224,7 +236,7 @@ func compositeScore(sr *types.SearchResult, modelScore, baseScore float64, chatM
|
||||
return composite
|
||||
}
|
||||
|
||||
// applyMMR applies the MMR algorithm to the search results
|
||||
// applyMMR applies the MMR algorithm to the search results with pre-computed token sets
|
||||
func applyMMR(
|
||||
ctx context.Context,
|
||||
results []*types.SearchResult,
|
||||
@@ -240,40 +252,60 @@ func applyMMR(
|
||||
"k": k,
|
||||
"candidates": len(results),
|
||||
})
|
||||
selected := make([]*types.SearchResult, 0, k)
|
||||
candidates := make([]*types.SearchResult, len(results))
|
||||
copy(candidates, results)
|
||||
tokenSets := make([]map[string]struct{}, len(candidates))
|
||||
for i, r := range candidates {
|
||||
tokenSets[i] = tokenizeSimple(getEnrichedPassage(ctx, r))
|
||||
|
||||
// Pre-compute all token sets upfront (optimization)
|
||||
allTokenSets := make([]map[string]struct{}, len(results))
|
||||
for i, r := range results {
|
||||
allTokenSets[i] = searchutil.TokenizeSimple(getEnrichedPassage(ctx, r))
|
||||
}
|
||||
for len(selected) < k && len(candidates) > 0 {
|
||||
bestIdx := 0
|
||||
|
||||
selected := make([]*types.SearchResult, 0, k)
|
||||
selectedTokenSets := make([]map[string]struct{}, 0, k)
|
||||
selectedIndices := make(map[int]struct{})
|
||||
|
||||
for len(selected) < k && len(selectedIndices) < len(results) {
|
||||
bestIdx := -1
|
||||
bestScore := -1.0
|
||||
for i, r := range candidates {
|
||||
|
||||
for i, r := range results {
|
||||
if _, isSelected := selectedIndices[i]; isSelected {
|
||||
continue
|
||||
}
|
||||
|
||||
relevance := r.Score
|
||||
redundancy := 0.0
|
||||
for _, s := range selected {
|
||||
redundancy = math.Max(redundancy, jaccard(tokenSets[i], tokenizeSimple(getEnrichedPassage(ctx, s))))
|
||||
|
||||
// Use pre-computed token sets for redundancy calculation
|
||||
for _, selTokens := range selectedTokenSets {
|
||||
sim := searchutil.Jaccard(allTokenSets[i], selTokens)
|
||||
if sim > redundancy {
|
||||
redundancy = sim
|
||||
}
|
||||
}
|
||||
|
||||
mmr := lambda*relevance - (1.0-lambda)*redundancy
|
||||
if mmr > bestScore {
|
||||
bestScore = mmr
|
||||
bestIdx = i
|
||||
}
|
||||
}
|
||||
selected = append(selected, candidates[bestIdx])
|
||||
candidates = append(candidates[:bestIdx], candidates[bestIdx+1:]...)
|
||||
|
||||
if bestIdx < 0 {
|
||||
break
|
||||
}
|
||||
|
||||
selected = append(selected, results[bestIdx])
|
||||
selectedTokenSets = append(selectedTokenSets, allTokenSets[bestIdx])
|
||||
selectedIndices[bestIdx] = struct{}{}
|
||||
}
|
||||
// Compute average redundancy among selected
|
||||
|
||||
// Compute average redundancy among selected using pre-computed token sets
|
||||
avgRed := 0.0
|
||||
if len(selected) > 1 {
|
||||
pairs := 0
|
||||
for i := 0; i < len(selected); i++ {
|
||||
for j := i + 1; j < len(selected); j++ {
|
||||
si := tokenizeSimple(getEnrichedPassage(ctx, selected[i]))
|
||||
sj := tokenizeSimple(getEnrichedPassage(ctx, selected[j]))
|
||||
avgRed += jaccard(si, sj)
|
||||
for i := 0; i < len(selectedTokenSets); i++ {
|
||||
for j := i + 1; j < len(selectedTokenSets); j++ {
|
||||
avgRed += searchutil.Jaccard(selectedTokenSets[i], selectedTokenSets[j])
|
||||
pairs++
|
||||
}
|
||||
}
|
||||
@@ -288,93 +320,59 @@ func applyMMR(
|
||||
return selected
|
||||
}
|
||||
|
||||
// tokenizeSimple tokenizes a text into a set of tokens
|
||||
func tokenizeSimple(text string) map[string]struct{} {
|
||||
text = strings.ToLower(text)
|
||||
fields := strings.Fields(text)
|
||||
set := make(map[string]struct{}, len(fields))
|
||||
for _, f := range fields {
|
||||
if len(f) > 1 {
|
||||
set[f] = struct{}{}
|
||||
}
|
||||
}
|
||||
return set
|
||||
}
|
||||
|
||||
// jaccard calculates the Jaccard similarity between two sets of tokens
|
||||
func jaccard(a, b map[string]struct{}) float64 {
|
||||
if len(a) == 0 && len(b) == 0 {
|
||||
return 0
|
||||
}
|
||||
inter := 0
|
||||
for k := range a {
|
||||
if _, ok := b[k]; ok {
|
||||
inter++
|
||||
}
|
||||
}
|
||||
union := len(a) + len(b) - inter
|
||||
if union == 0 {
|
||||
return 0
|
||||
}
|
||||
return float64(inter) / float64(union)
|
||||
}
|
||||
|
||||
// clampFloat clamps a float value between a minimum and maximum value
|
||||
func clampFloat(v, minV, maxV float64) float64 {
|
||||
if v < minV {
|
||||
return minV
|
||||
}
|
||||
if v > maxV {
|
||||
return maxV
|
||||
}
|
||||
return v
|
||||
}
|
||||
|
||||
// getEnrichedPassage 合并Content和ImageInfo的文本内容
|
||||
// getEnrichedPassage 合并Content、ImageInfo和GeneratedQuestions的文本内容
|
||||
func getEnrichedPassage(ctx context.Context, result *types.SearchResult) string {
|
||||
if result.ImageInfo == "" {
|
||||
return result.Content
|
||||
}
|
||||
combinedText := result.Content
|
||||
var enrichments []string
|
||||
|
||||
// 解析ImageInfo
|
||||
var imageInfos []types.ImageInfo
|
||||
err := json.Unmarshal([]byte(result.ImageInfo), &imageInfos)
|
||||
if err != nil {
|
||||
pipelineWarn(ctx, "Rerank", "image_info_parse", map[string]interface{}{
|
||||
"error": err.Error(),
|
||||
})
|
||||
return result.Content
|
||||
}
|
||||
|
||||
if len(imageInfos) == 0 {
|
||||
return result.Content
|
||||
}
|
||||
|
||||
// 提取所有图片的描述和OCR文本
|
||||
var imageTexts []string
|
||||
for _, img := range imageInfos {
|
||||
if img.Caption != "" {
|
||||
imageTexts = append(imageTexts, fmt.Sprintf("图片描述: %s", img.Caption))
|
||||
}
|
||||
if img.OCRText != "" {
|
||||
imageTexts = append(imageTexts, fmt.Sprintf("图片文本: %s", img.OCRText))
|
||||
if result.ImageInfo != "" {
|
||||
var imageInfos []types.ImageInfo
|
||||
err := json.Unmarshal([]byte(result.ImageInfo), &imageInfos)
|
||||
if err != nil {
|
||||
pipelineWarn(ctx, "Rerank", "image_info_parse", map[string]interface{}{
|
||||
"error": err.Error(),
|
||||
})
|
||||
} else {
|
||||
// 提取所有图片的描述和OCR文本
|
||||
for _, img := range imageInfos {
|
||||
if img.Caption != "" {
|
||||
enrichments = append(enrichments, fmt.Sprintf("图片描述: %s", img.Caption))
|
||||
}
|
||||
if img.OCRText != "" {
|
||||
enrichments = append(enrichments, fmt.Sprintf("图片文本: %s", img.OCRText))
|
||||
}
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
if len(imageTexts) == 0 {
|
||||
return result.Content
|
||||
// 解析ChunkMetadata中的GeneratedQuestions
|
||||
if len(result.ChunkMetadata) > 0 {
|
||||
var docMeta types.DocumentChunkMetadata
|
||||
err := json.Unmarshal(result.ChunkMetadata, &docMeta)
|
||||
if err != nil {
|
||||
pipelineWarn(ctx, "Rerank", "chunk_metadata_parse", map[string]interface{}{
|
||||
"error": err.Error(),
|
||||
})
|
||||
} else if len(docMeta.GeneratedQuestions) > 0 {
|
||||
enrichments = append(enrichments, fmt.Sprintf("相关问题: %s", strings.Join(docMeta.GeneratedQuestions, "; ")))
|
||||
}
|
||||
}
|
||||
|
||||
// 组合内容和图片信息
|
||||
combinedText := result.Content
|
||||
if len(enrichments) == 0 {
|
||||
return combinedText
|
||||
}
|
||||
|
||||
// 组合内容和增强信息
|
||||
if combinedText != "" {
|
||||
combinedText += "\n\n"
|
||||
}
|
||||
combinedText += strings.Join(imageTexts, "\n")
|
||||
combinedText += strings.Join(enrichments, "\n")
|
||||
|
||||
pipelineInfo(ctx, "Rerank", "image_info_merge", map[string]interface{}{
|
||||
"content_len": len(result.Content),
|
||||
"image_len": len(strings.Join(imageTexts, "\n")),
|
||||
pipelineInfo(ctx, "Rerank", "passage_enrich", map[string]interface{}{
|
||||
"content_len": len(result.Content),
|
||||
"enrichment": strings.Join(enrichments, "\n"),
|
||||
"enrichment_len": len(strings.Join(enrichments, "\n")),
|
||||
})
|
||||
|
||||
return combinedText
|
||||
|
||||
@@ -2,13 +2,14 @@ package chatpipline
|
||||
|
||||
import (
|
||||
"context"
|
||||
"encoding/json"
|
||||
"fmt"
|
||||
"regexp"
|
||||
"strings"
|
||||
"sync"
|
||||
"unicode"
|
||||
|
||||
"github.com/Tencent/WeKnora/internal/config"
|
||||
"github.com/Tencent/WeKnora/internal/models/chat"
|
||||
"github.com/Tencent/WeKnora/internal/logger"
|
||||
"github.com/Tencent/WeKnora/internal/searchutil"
|
||||
"github.com/Tencent/WeKnora/internal/types"
|
||||
"github.com/Tencent/WeKnora/internal/types/interfaces"
|
||||
@@ -18,7 +19,6 @@ import (
|
||||
type PluginSearch struct {
|
||||
knowledgeBaseService interfaces.KnowledgeBaseService
|
||||
knowledgeService interfaces.KnowledgeService
|
||||
modelService interfaces.ModelService
|
||||
config *config.Config
|
||||
webSearchService interfaces.WebSearchService
|
||||
tenantService interfaces.TenantService
|
||||
@@ -28,7 +28,6 @@ type PluginSearch struct {
|
||||
func NewPluginSearch(eventManager *EventManager,
|
||||
knowledgeBaseService interfaces.KnowledgeBaseService,
|
||||
knowledgeService interfaces.KnowledgeService,
|
||||
modelService interfaces.ModelService,
|
||||
config *config.Config,
|
||||
webSearchService interfaces.WebSearchService,
|
||||
tenantService interfaces.TenantService,
|
||||
@@ -37,7 +36,6 @@ func NewPluginSearch(eventManager *EventManager,
|
||||
res := &PluginSearch{
|
||||
knowledgeBaseService: knowledgeBaseService,
|
||||
knowledgeService: knowledgeService,
|
||||
modelService: modelService,
|
||||
config: config,
|
||||
webSearchService: webSearchService,
|
||||
tenantService: tenantService,
|
||||
@@ -75,12 +73,11 @@ func (p *PluginSearch) OnEvent(ctx context.Context,
|
||||
}
|
||||
|
||||
pipelineInfo(ctx, "Search", "input", map[string]interface{}{
|
||||
"session_id": chatManage.SessionID,
|
||||
"rewrite_query": chatManage.RewriteQuery,
|
||||
"processed_query": chatManage.ProcessedQuery,
|
||||
"kb_ids": strings.Join(knowledgeBaseIDs, ","),
|
||||
"tenant_id": chatManage.TenantID,
|
||||
"web_enabled": chatManage.WebSearchEnabled,
|
||||
"session_id": chatManage.SessionID,
|
||||
"rewrite_query": chatManage.RewriteQuery,
|
||||
"kb_ids": strings.Join(knowledgeBaseIDs, ","),
|
||||
"tenant_id": chatManage.TenantID,
|
||||
"web_enabled": chatManage.WebSearchEnabled,
|
||||
})
|
||||
|
||||
// Run KB search and web search concurrently
|
||||
@@ -121,6 +118,16 @@ func (p *PluginSearch) OnEvent(ctx context.Context,
|
||||
|
||||
chatManage.SearchResult = allResults
|
||||
|
||||
// Log all search results with scores before any processing
|
||||
for i, r := range chatManage.SearchResult {
|
||||
pipelineInfo(ctx, "Search", "result_score_before_normalize", map[string]interface{}{
|
||||
"index": i,
|
||||
"chunk_id": r.ID,
|
||||
"score": fmt.Sprintf("%.4f", r.Score),
|
||||
"match_type": r.MatchType,
|
||||
})
|
||||
}
|
||||
|
||||
// If recall is low, attempt query expansion with keyword-focused search
|
||||
if chatManage.EnableQueryExpansion && len(chatManage.SearchResult) < max(1, chatManage.EmbeddingTopK/2) {
|
||||
pipelineInfo(ctx, "Search", "recall_low", map[string]interface{}{
|
||||
@@ -189,6 +196,7 @@ func (p *PluginSearch) OnEvent(ctx context.Context,
|
||||
}
|
||||
wgExp.Wait()
|
||||
if len(expResults) > 0 {
|
||||
// Scores already normalized in HybridSearch
|
||||
pipelineInfo(ctx, "Search", "expansion_done", map[string]interface{}{
|
||||
"added": len(expResults),
|
||||
})
|
||||
@@ -215,6 +223,16 @@ func (p *PluginSearch) OnEvent(ctx context.Context,
|
||||
"after": len(chatManage.SearchResult),
|
||||
})
|
||||
|
||||
// Log final scores after all processing
|
||||
for i, r := range chatManage.SearchResult {
|
||||
pipelineInfo(ctx, "Search", "final_score", map[string]interface{}{
|
||||
"index": i,
|
||||
"chunk_id": r.ID,
|
||||
"score": fmt.Sprintf("%.4f", r.Score),
|
||||
"match_type": r.MatchType,
|
||||
})
|
||||
}
|
||||
|
||||
// Return if we have results
|
||||
if len(chatManage.SearchResult) != 0 {
|
||||
pipelineInfo(ctx, "Search", "output", map[string]interface{}{
|
||||
@@ -250,32 +268,33 @@ func (p *PluginSearch) getSearchResultFromHistory(chatManage *types.ChatManage)
|
||||
|
||||
func removeDuplicateResults(results []*types.SearchResult) []*types.SearchResult {
|
||||
seen := make(map[string]bool)
|
||||
contentSig := make(map[string]bool)
|
||||
contentSig := make(map[string]string) // sig -> first chunk ID
|
||||
var uniqueResults []*types.SearchResult
|
||||
for _, r := range results {
|
||||
keys := []string{r.ID}
|
||||
if r.ParentChunkID != "" {
|
||||
keys = append(keys, "parent:"+r.ParentChunkID)
|
||||
}
|
||||
if r.KnowledgeID != "" {
|
||||
keys = append(keys, fmt.Sprintf("kb:%s#%d", r.KnowledgeID, r.ChunkIndex))
|
||||
}
|
||||
dup := false
|
||||
dupKey := ""
|
||||
for _, k := range keys {
|
||||
if seen[k] {
|
||||
dup = true
|
||||
dupKey = k
|
||||
break
|
||||
}
|
||||
}
|
||||
if dup {
|
||||
logger.Debugf(context.Background(), "Dedup: chunk %s removed due to key: %s", r.ID, dupKey)
|
||||
continue
|
||||
}
|
||||
sig := buildContentSignature(r.Content)
|
||||
if sig != "" {
|
||||
if contentSig[sig] {
|
||||
if firstChunk, exists := contentSig[sig]; exists {
|
||||
logger.Debugf(context.Background(), "Dedup: chunk %s removed due to content signature (dup of %s, sig prefix: %.50s...)", r.ID, firstChunk, sig)
|
||||
continue
|
||||
}
|
||||
contentSig[sig] = true
|
||||
contentSig[sig] = r.ID
|
||||
}
|
||||
for _, k := range keys {
|
||||
seen[k] = true
|
||||
@@ -286,24 +305,16 @@ func removeDuplicateResults(results []*types.SearchResult) []*types.SearchResult
|
||||
}
|
||||
|
||||
func buildContentSignature(content string) string {
|
||||
c := strings.ToLower(strings.TrimSpace(content))
|
||||
if c == "" {
|
||||
return ""
|
||||
}
|
||||
c = strings.Join(strings.Fields(c), " ")
|
||||
if len(c) > 128 {
|
||||
c = c[:128]
|
||||
}
|
||||
return c
|
||||
return searchutil.BuildContentSignature(content)
|
||||
}
|
||||
|
||||
// searchKnowledgeBases performs KB searches for rewrite and processed queries across KB IDs
|
||||
// searchKnowledgeBases performs KB searches across KB IDs using RewriteQuery only
|
||||
func (p *PluginSearch) searchKnowledgeBases(
|
||||
ctx context.Context,
|
||||
knowledgeBaseIDs []string,
|
||||
chatManage *types.ChatManage,
|
||||
) []*types.SearchResult {
|
||||
// Build base params for rewrite query
|
||||
// Build params for rewrite query
|
||||
baseParams := types.SearchParams{
|
||||
QueryText: strings.TrimSpace(chatManage.RewriteQuery),
|
||||
VectorThreshold: chatManage.VectorThreshold,
|
||||
@@ -315,7 +326,7 @@ func (p *PluginSearch) searchKnowledgeBases(
|
||||
var mu sync.Mutex
|
||||
var results []*types.SearchResult
|
||||
|
||||
// Search with rewrite query
|
||||
// Search with rewrite query only (removed duplicate ProcessedQuery search)
|
||||
for _, kbID := range knowledgeBaseIDs {
|
||||
wg.Add(1)
|
||||
go func(knowledgeBaseID string) {
|
||||
@@ -323,16 +334,14 @@ func (p *PluginSearch) searchKnowledgeBases(
|
||||
res, err := p.knowledgeBaseService.HybridSearch(ctx, knowledgeBaseID, baseParams)
|
||||
if err != nil {
|
||||
pipelineWarn(ctx, "Search", "kb_search_error", map[string]interface{}{
|
||||
"kb_id": knowledgeBaseID,
|
||||
"query": baseParams.QueryText,
|
||||
"error": err.Error(),
|
||||
"query_ty": "rewrite",
|
||||
"kb_id": knowledgeBaseID,
|
||||
"query": baseParams.QueryText,
|
||||
"error": err.Error(),
|
||||
})
|
||||
return
|
||||
}
|
||||
pipelineInfo(ctx, "Search", "kb_result", map[string]interface{}{
|
||||
"kb_id": knowledgeBaseID,
|
||||
"query_ty": "rewrite",
|
||||
"hit_count": len(res),
|
||||
})
|
||||
mu.Lock()
|
||||
@@ -343,45 +352,6 @@ func (p *PluginSearch) searchKnowledgeBases(
|
||||
|
||||
wg.Wait()
|
||||
|
||||
// If processed query differs, search again
|
||||
if chatManage.RewriteQuery != chatManage.ProcessedQuery {
|
||||
paramsProcessed := baseParams
|
||||
paramsProcessed.QueryText = strings.TrimSpace(chatManage.ProcessedQuery)
|
||||
pipelineInfo(ctx, "Search", "processed_query_search", map[string]interface{}{
|
||||
"query": paramsProcessed.QueryText,
|
||||
})
|
||||
|
||||
wg = sync.WaitGroup{}
|
||||
for _, kbID := range knowledgeBaseIDs {
|
||||
wg.Add(1)
|
||||
go func(knowledgeBaseID string) {
|
||||
defer wg.Done()
|
||||
res, err := p.knowledgeBaseService.HybridSearch(ctx, knowledgeBaseID, paramsProcessed)
|
||||
if err != nil {
|
||||
pipelineWarn(ctx, "Search", "kb_search_error", map[string]interface{}{
|
||||
"kb_id": knowledgeBaseID,
|
||||
"query": paramsProcessed.QueryText,
|
||||
"error": err.Error(),
|
||||
"query_ty": "processed",
|
||||
})
|
||||
return
|
||||
}
|
||||
pipelineInfo(ctx, "Search", "kb_result", map[string]interface{}{
|
||||
"kb_id": knowledgeBaseID,
|
||||
"query_ty": "processed",
|
||||
"hit_count": len(res),
|
||||
})
|
||||
mu.Lock()
|
||||
results = append(results, res...)
|
||||
mu.Unlock()
|
||||
}(kbID)
|
||||
}
|
||||
wg.Wait()
|
||||
}
|
||||
|
||||
// Normalize keyword retriever scores after collecting all results from multiple knowledge bases
|
||||
normalizeKeywordSearchResults(ctx, results)
|
||||
|
||||
pipelineInfo(ctx, "Search", "kb_result_summary", map[string]interface{}{
|
||||
"total_hits": len(results),
|
||||
})
|
||||
@@ -413,11 +383,8 @@ func (p *PluginSearch) searchWebIfEnabled(ctx context.Context, chatManage *types
|
||||
})
|
||||
return nil
|
||||
}
|
||||
// Build questions (rewrite + processed if different)
|
||||
// Build questions using RewriteQuery only
|
||||
questions := []string{strings.TrimSpace(chatManage.RewriteQuery)}
|
||||
if chatManage.ProcessedQuery != "" && chatManage.ProcessedQuery != chatManage.RewriteQuery {
|
||||
questions = append(questions, strings.TrimSpace(chatManage.ProcessedQuery))
|
||||
}
|
||||
// Load session-scoped temp KB state from Redis using SessionService
|
||||
tempKBID, seen, ids := p.sessionService.GetWebSearchTempKBState(ctx, chatManage.SessionID)
|
||||
compressed, kbID, newSeen, newIDs, err := p.webSearchService.CompressWithRAG(
|
||||
@@ -440,119 +407,155 @@ func (p *PluginSearch) searchWebIfEnabled(ctx context.Context, chatManage *types
|
||||
return res
|
||||
}
|
||||
|
||||
// expandQueries generates paraphrases and synonyms using chat model to improve keyword recall
|
||||
// expandQueries generates query variants locally without LLM to improve keyword recall
|
||||
// Uses simple techniques: word reordering, stopword removal, key phrase extraction
|
||||
func (p *PluginSearch) expandQueries(ctx context.Context, chatManage *types.ChatManage) []string {
|
||||
if p.modelService == nil || chatManage.ChatModelID == "" {
|
||||
pipelineWarn(ctx, "Search", "expansion_skip", map[string]interface{}{
|
||||
"reason": "no_model",
|
||||
})
|
||||
query := strings.TrimSpace(chatManage.RewriteQuery)
|
||||
if query == "" {
|
||||
return nil
|
||||
}
|
||||
model, err := p.modelService.GetChatModel(ctx, chatManage.ChatModelID)
|
||||
if err != nil {
|
||||
pipelineWarn(ctx, "Search", "expansion_get_model_failed", map[string]interface{}{
|
||||
"error": err.Error(),
|
||||
})
|
||||
return nil
|
||||
|
||||
expansions := make([]string, 0, 5)
|
||||
seen := make(map[string]struct{})
|
||||
seen[strings.ToLower(query)] = struct{}{}
|
||||
if q := strings.ToLower(chatManage.Query); q != "" {
|
||||
seen[q] = struct{}{}
|
||||
}
|
||||
sys := "Generate up to 5 diverse paraphrases or keyword variants for the user query to improve keyword-based search recall. Respond ONLY with a JSON array of strings inside a fenced code block."
|
||||
usr := chatManage.RewriteQuery
|
||||
think := false
|
||||
resp, err := model.Chat(ctx, []chat.Message{
|
||||
{Role: "system", Content: sys},
|
||||
{Role: "user", Content: usr},
|
||||
}, &chat.ChatOptions{Temperature: 0.2, MaxCompletionTokens: 200, Thinking: &think})
|
||||
if err != nil || resp.Content == "" {
|
||||
pipelineWarn(ctx, "Search", "expansion_model_call_failed", map[string]interface{}{
|
||||
"error": err,
|
||||
})
|
||||
return nil
|
||||
}
|
||||
body := extractJSONBlock(resp.Content)
|
||||
var arr []string
|
||||
if err := json.Unmarshal([]byte(body), &arr); err != nil || len(arr) == 0 {
|
||||
// Fallback: split lines
|
||||
lines := strings.Split(resp.Content, "\n")
|
||||
for _, l := range lines {
|
||||
l = strings.TrimSpace(l)
|
||||
if l != "" {
|
||||
arr = append(arr, l)
|
||||
}
|
||||
}
|
||||
}
|
||||
uniq := make(map[string]struct{})
|
||||
base := []string{chatManage.Query, chatManage.RewriteQuery, chatManage.ProcessedQuery}
|
||||
for _, b := range base {
|
||||
if s := strings.TrimSpace(b); s != "" {
|
||||
uniq[strings.ToLower(s)] = struct{}{}
|
||||
}
|
||||
}
|
||||
expansions := make([]string, 0, len(arr))
|
||||
for _, a := range arr {
|
||||
s := strings.TrimSpace(a)
|
||||
if s == "" {
|
||||
continue
|
||||
|
||||
addIfNew := func(s string) {
|
||||
s = strings.TrimSpace(s)
|
||||
if s == "" || len(s) < 3 {
|
||||
return
|
||||
}
|
||||
key := strings.ToLower(s)
|
||||
if _, ok := uniq[key]; ok {
|
||||
continue
|
||||
if _, ok := seen[key]; ok {
|
||||
return
|
||||
}
|
||||
uniq[key] = struct{}{}
|
||||
seen[key] = struct{}{}
|
||||
expansions = append(expansions, s)
|
||||
if len(expansions) >= 5 {
|
||||
break
|
||||
}
|
||||
|
||||
// 1. Remove common stopwords and create keyword-only variant
|
||||
keywords := extractKeywords(query)
|
||||
if len(keywords) >= 2 {
|
||||
addIfNew(strings.Join(keywords, " "))
|
||||
}
|
||||
|
||||
// 2. Extract quoted phrases or key segments
|
||||
phrases := extractPhrases(query)
|
||||
for _, phrase := range phrases {
|
||||
addIfNew(phrase)
|
||||
}
|
||||
|
||||
// 3. Split by common delimiters and use longest segment
|
||||
segments := splitByDelimiters(query)
|
||||
for _, seg := range segments {
|
||||
if len(seg) > 5 {
|
||||
addIfNew(seg)
|
||||
}
|
||||
}
|
||||
pipelineInfo(ctx, "Search", "expansion_result", map[string]interface{}{
|
||||
|
||||
// 4. Remove question words (什么/如何/怎么/为什么/哪个 etc.)
|
||||
cleaned := removeQuestionWords(query)
|
||||
if cleaned != query {
|
||||
addIfNew(cleaned)
|
||||
}
|
||||
|
||||
// Limit to 5 expansions
|
||||
if len(expansions) > 5 {
|
||||
expansions = expansions[:5]
|
||||
}
|
||||
|
||||
pipelineInfo(ctx, "Search", "local_expansion_result", map[string]interface{}{
|
||||
"variants": len(expansions),
|
||||
})
|
||||
return expansions
|
||||
}
|
||||
|
||||
func extractJSONBlock(text string) string {
|
||||
t := strings.TrimSpace(text)
|
||||
if i := strings.Index(t, "["); i >= 0 {
|
||||
j := strings.LastIndex(t, "]")
|
||||
if j > i {
|
||||
return t[i : j+1]
|
||||
}
|
||||
}
|
||||
return "[]"
|
||||
// Common Chinese and English stopwords
|
||||
var stopwords = map[string]struct{}{
|
||||
"的": {}, "是": {}, "在": {}, "了": {}, "和": {}, "与": {}, "或": {},
|
||||
"a": {}, "an": {}, "the": {}, "is": {}, "are": {}, "was": {}, "were": {},
|
||||
"be": {}, "been": {}, "being": {}, "have": {}, "has": {}, "had": {},
|
||||
"do": {}, "does": {}, "did": {}, "will": {}, "would": {}, "could": {},
|
||||
"should": {}, "may": {}, "might": {}, "must": {}, "can": {},
|
||||
"to": {}, "of": {}, "in": {}, "for": {}, "on": {}, "with": {}, "at": {},
|
||||
"by": {}, "from": {}, "as": {}, "into": {}, "through": {}, "about": {},
|
||||
"what": {}, "how": {}, "why": {}, "when": {}, "where": {}, "which": {},
|
||||
"who": {}, "whom": {}, "whose": {},
|
||||
}
|
||||
|
||||
// normalizeKeywordSearchResults normalizes keyword search result scores into [0,1] globally across all knowledge bases
|
||||
// Improvements:
|
||||
// 1. Uses robust normalization with percentile-based bounds to handle outliers
|
||||
// 2. Handles edge cases: single result, no variance, negative scores
|
||||
// 3. Global normalization ensures fair comparison across different knowledge bases
|
||||
func normalizeKeywordSearchResults(ctx context.Context, results []*types.SearchResult) {
|
||||
searchutil.NormalizeKeywordScores[*types.SearchResult](
|
||||
results,
|
||||
func(r *types.SearchResult) bool {
|
||||
return r.MatchType == types.MatchTypeKeywords
|
||||
},
|
||||
func(r *types.SearchResult) float64 {
|
||||
return r.Score
|
||||
},
|
||||
func(r *types.SearchResult, score float64) {
|
||||
r.Score = score
|
||||
},
|
||||
searchutil.KeywordScoreCallbacks{
|
||||
OnNoVariance: func(count int, score float64) {
|
||||
pipelineInfo(ctx, "Search", "keyword_scores_no_variance", map[string]interface{}{
|
||||
"count": count,
|
||||
"score": score,
|
||||
})
|
||||
},
|
||||
OnNormalized: func(count int, rawMin, rawMax, normalizeMin, normalizeMax float64) {
|
||||
pipelineInfo(ctx, "Search", "normalize_keyword_scores", map[string]interface{}{
|
||||
"count": count,
|
||||
"raw_min": rawMin,
|
||||
"raw_max": rawMax,
|
||||
"normalize_min": normalizeMin,
|
||||
"normalize_max": normalizeMax,
|
||||
})
|
||||
},
|
||||
},
|
||||
)
|
||||
// Question words in Chinese
|
||||
var questionWords = regexp.MustCompile(`^(什么是|什么|如何|怎么|怎样|为什么|为何|哪个|哪些|谁|何时|何地|请问|请告诉我|帮我|我想知道|我想了解)`)
|
||||
|
||||
func extractKeywords(text string) []string {
|
||||
words := tokenize(text)
|
||||
keywords := make([]string, 0, len(words))
|
||||
for _, w := range words {
|
||||
lower := strings.ToLower(w)
|
||||
if _, isStop := stopwords[lower]; !isStop && len(w) > 1 {
|
||||
keywords = append(keywords, w)
|
||||
}
|
||||
}
|
||||
return keywords
|
||||
}
|
||||
|
||||
func extractPhrases(text string) []string {
|
||||
// Extract quoted content
|
||||
var phrases []string
|
||||
re := regexp.MustCompile(`["'"'「」『』]([^"'"'「」『』]+)["'"'「」『』]`)
|
||||
matches := re.FindAllStringSubmatch(text, -1)
|
||||
for _, m := range matches {
|
||||
if len(m) > 1 && len(m[1]) > 2 {
|
||||
phrases = append(phrases, m[1])
|
||||
}
|
||||
}
|
||||
return phrases
|
||||
}
|
||||
|
||||
func splitByDelimiters(text string) []string {
|
||||
// Split by common delimiters
|
||||
re := regexp.MustCompile(`[,,;;、。!?!?\s]+`)
|
||||
parts := re.Split(text, -1)
|
||||
var result []string
|
||||
for _, p := range parts {
|
||||
p = strings.TrimSpace(p)
|
||||
if p != "" {
|
||||
result = append(result, p)
|
||||
}
|
||||
}
|
||||
return result
|
||||
}
|
||||
|
||||
func removeQuestionWords(text string) string {
|
||||
return strings.TrimSpace(questionWords.ReplaceAllString(text, ""))
|
||||
}
|
||||
|
||||
func tokenize(text string) []string {
|
||||
var tokens []string
|
||||
var current strings.Builder
|
||||
|
||||
for _, r := range text {
|
||||
if unicode.IsLetter(r) || unicode.IsDigit(r) {
|
||||
current.WriteRune(r)
|
||||
} else if unicode.Is(unicode.Han, r) {
|
||||
// Flush current token
|
||||
if current.Len() > 0 {
|
||||
tokens = append(tokens, current.String())
|
||||
current.Reset()
|
||||
}
|
||||
// Chinese character as single token
|
||||
tokens = append(tokens, string(r))
|
||||
} else {
|
||||
// Delimiter
|
||||
if current.Len() > 0 {
|
||||
tokens = append(tokens, current.String())
|
||||
current.Reset()
|
||||
}
|
||||
}
|
||||
}
|
||||
if current.Len() > 0 {
|
||||
tokens = append(tokens, current.String())
|
||||
}
|
||||
return tokens
|
||||
}
|
||||
|
||||
@@ -190,5 +190,6 @@ func chunk2SearchResult(chunk *types.Chunk, knowledge *types.Knowledge) *types.S
|
||||
ImageInfo: chunk.ImageInfo,
|
||||
KnowledgeFilename: knowledge.FileName,
|
||||
KnowledgeSource: knowledge.Source,
|
||||
ChunkMetadata: chunk.Metadata,
|
||||
}
|
||||
}
|
||||
|
||||
181
internal/application/service/chat_pipline/search_parallel.go
Normal file
181
internal/application/service/chat_pipline/search_parallel.go
Normal file
@@ -0,0 +1,181 @@
|
||||
package chatpipline
|
||||
|
||||
import (
|
||||
"context"
|
||||
"sync"
|
||||
|
||||
"github.com/Tencent/WeKnora/internal/config"
|
||||
"github.com/Tencent/WeKnora/internal/logger"
|
||||
"github.com/Tencent/WeKnora/internal/types"
|
||||
"github.com/Tencent/WeKnora/internal/types/interfaces"
|
||||
)
|
||||
|
||||
// PluginSearchParallel implements parallel search functionality combining chunk search and entity search
|
||||
type PluginSearchParallel struct {
|
||||
// Chunk search dependencies
|
||||
knowledgeBaseService interfaces.KnowledgeBaseService
|
||||
knowledgeService interfaces.KnowledgeService
|
||||
config *config.Config
|
||||
webSearchService interfaces.WebSearchService
|
||||
tenantService interfaces.TenantService
|
||||
sessionService interfaces.SessionService
|
||||
|
||||
// Entity search dependencies
|
||||
graphRepo interfaces.RetrieveGraphRepository
|
||||
chunkRepo interfaces.ChunkRepository
|
||||
knowledgeRepo interfaces.KnowledgeRepository
|
||||
|
||||
// Internal plugins
|
||||
searchPlugin *PluginSearch
|
||||
searchEntityPlugin *PluginSearchEntity
|
||||
}
|
||||
|
||||
// NewPluginSearchParallel creates a new parallel search plugin
|
||||
func NewPluginSearchParallel(
|
||||
eventManager *EventManager,
|
||||
knowledgeBaseService interfaces.KnowledgeBaseService,
|
||||
knowledgeService interfaces.KnowledgeService,
|
||||
config *config.Config,
|
||||
webSearchService interfaces.WebSearchService,
|
||||
tenantService interfaces.TenantService,
|
||||
sessionService interfaces.SessionService,
|
||||
graphRepository interfaces.RetrieveGraphRepository,
|
||||
chunkRepository interfaces.ChunkRepository,
|
||||
knowledgeRepository interfaces.KnowledgeRepository,
|
||||
) *PluginSearchParallel {
|
||||
// Create internal plugins without registering them
|
||||
searchPlugin := &PluginSearch{
|
||||
knowledgeBaseService: knowledgeBaseService,
|
||||
knowledgeService: knowledgeService,
|
||||
config: config,
|
||||
webSearchService: webSearchService,
|
||||
tenantService: tenantService,
|
||||
sessionService: sessionService,
|
||||
}
|
||||
|
||||
searchEntityPlugin := &PluginSearchEntity{
|
||||
graphRepo: graphRepository,
|
||||
chunkRepo: chunkRepository,
|
||||
knowledgeRepo: knowledgeRepository,
|
||||
}
|
||||
|
||||
res := &PluginSearchParallel{
|
||||
knowledgeBaseService: knowledgeBaseService,
|
||||
knowledgeService: knowledgeService,
|
||||
config: config,
|
||||
webSearchService: webSearchService,
|
||||
tenantService: tenantService,
|
||||
sessionService: sessionService,
|
||||
graphRepo: graphRepository,
|
||||
chunkRepo: chunkRepository,
|
||||
knowledgeRepo: knowledgeRepository,
|
||||
searchPlugin: searchPlugin,
|
||||
searchEntityPlugin: searchEntityPlugin,
|
||||
}
|
||||
eventManager.Register(res)
|
||||
return res
|
||||
}
|
||||
|
||||
// ActivationEvents returns the event types this plugin handles
|
||||
func (p *PluginSearchParallel) ActivationEvents() []types.EventType {
|
||||
return []types.EventType{types.CHUNK_SEARCH_PARALLEL}
|
||||
}
|
||||
|
||||
// OnEvent handles parallel search events - runs chunk search and entity search concurrently
|
||||
func (p *PluginSearchParallel) OnEvent(ctx context.Context,
|
||||
eventType types.EventType, chatManage *types.ChatManage, next func() *PluginError,
|
||||
) *PluginError {
|
||||
pipelineInfo(ctx, "SearchParallel", "start", map[string]interface{}{
|
||||
"session_id": chatManage.SessionID,
|
||||
"has_entities": len(chatManage.Entity) > 0,
|
||||
"rewrite_query": chatManage.RewriteQuery,
|
||||
})
|
||||
|
||||
var wg sync.WaitGroup
|
||||
var mu sync.Mutex
|
||||
var chunkSearchErr *PluginError
|
||||
var entitySearchErr *PluginError
|
||||
|
||||
// Use separate ChatManage copies to avoid concurrent write conflicts
|
||||
chunkChatManage := *chatManage
|
||||
chunkChatManage.SearchResult = nil
|
||||
|
||||
entityChatManage := *chatManage
|
||||
entityChatManage.SearchResult = nil
|
||||
|
||||
// Run chunk search and entity search in parallel
|
||||
wg.Add(2)
|
||||
|
||||
// Goroutine 1: Chunk Search
|
||||
go func() {
|
||||
defer wg.Done()
|
||||
err := p.searchPlugin.OnEvent(ctx, types.CHUNK_SEARCH, &chunkChatManage, func() *PluginError {
|
||||
return nil
|
||||
})
|
||||
if err != nil && err != ErrSearchNothing {
|
||||
mu.Lock()
|
||||
chunkSearchErr = err
|
||||
mu.Unlock()
|
||||
}
|
||||
pipelineInfo(ctx, "SearchParallel", "chunk_search_done", map[string]interface{}{
|
||||
"result_count": len(chunkChatManage.SearchResult),
|
||||
"has_error": err != nil && err != ErrSearchNothing,
|
||||
})
|
||||
}()
|
||||
|
||||
// Goroutine 2: Entity Search (only if entities are available)
|
||||
go func() {
|
||||
defer wg.Done()
|
||||
if len(chatManage.Entity) == 0 {
|
||||
pipelineInfo(ctx, "SearchParallel", "entity_search_skip", map[string]interface{}{
|
||||
"reason": "no_entities",
|
||||
})
|
||||
return
|
||||
}
|
||||
err := p.searchEntityPlugin.OnEvent(ctx, types.ENTITY_SEARCH, &entityChatManage, func() *PluginError {
|
||||
return nil
|
||||
})
|
||||
if err != nil && err != ErrSearchNothing {
|
||||
mu.Lock()
|
||||
entitySearchErr = err
|
||||
mu.Unlock()
|
||||
}
|
||||
pipelineInfo(ctx, "SearchParallel", "entity_search_done", map[string]interface{}{
|
||||
"result_count": len(entityChatManage.SearchResult),
|
||||
"has_error": err != nil && err != ErrSearchNothing,
|
||||
})
|
||||
}()
|
||||
|
||||
wg.Wait()
|
||||
|
||||
// Merge results from both searches (no concurrent access now)
|
||||
chatManage.SearchResult = append(chunkChatManage.SearchResult, entityChatManage.SearchResult...)
|
||||
chatManage.SearchResult = removeDuplicateResults(chatManage.SearchResult)
|
||||
|
||||
// Log any errors but don't fail the pipeline if at least one search succeeded
|
||||
if chunkSearchErr != nil {
|
||||
logger.Warnf(ctx, "[SearchParallel] Chunk search error: %v", chunkSearchErr.Err)
|
||||
}
|
||||
if entitySearchErr != nil {
|
||||
logger.Warnf(ctx, "[SearchParallel] Entity search error: %v", entitySearchErr.Err)
|
||||
}
|
||||
|
||||
pipelineInfo(ctx, "SearchParallel", "complete", map[string]interface{}{
|
||||
"session_id": chatManage.SessionID,
|
||||
"chunk_results": len(chunkChatManage.SearchResult),
|
||||
"entity_results": len(entityChatManage.SearchResult),
|
||||
"total_results": len(chatManage.SearchResult),
|
||||
"chunk_search_error": chunkSearchErr != nil,
|
||||
"entity_search_error": entitySearchErr != nil,
|
||||
})
|
||||
|
||||
// Return error only if both searches failed and we have no results
|
||||
if len(chatManage.SearchResult) == 0 {
|
||||
if chunkSearchErr != nil {
|
||||
return chunkSearchErr
|
||||
}
|
||||
return ErrSearchNothing
|
||||
}
|
||||
|
||||
return next()
|
||||
}
|
||||
@@ -34,7 +34,7 @@ func (p *PluginTracing) ActivationEvents() []types.EventType {
|
||||
types.CHAT_COMPLETION_STREAM,
|
||||
types.FILTER_TOP_K,
|
||||
types.REWRITE_QUERY,
|
||||
types.PREPROCESS_QUERY,
|
||||
types.CHUNK_SEARCH_PARALLEL,
|
||||
}
|
||||
}
|
||||
|
||||
@@ -69,8 +69,8 @@ func (p *PluginTracing) OnEvent(ctx context.Context,
|
||||
return p.FilterTopK(ctx, eventType, chatManage, next)
|
||||
case types.REWRITE_QUERY:
|
||||
return p.RewriteQuery(ctx, eventType, chatManage, next)
|
||||
case types.PREPROCESS_QUERY:
|
||||
return p.PreprocessQuery(ctx, eventType, chatManage, next)
|
||||
case types.CHUNK_SEARCH_PARALLEL:
|
||||
return p.SearchParallel(ctx, eventType, chatManage, next)
|
||||
}
|
||||
return next()
|
||||
}
|
||||
@@ -95,7 +95,6 @@ func (p *PluginTracing) Search(ctx context.Context,
|
||||
}
|
||||
span.SetAttributes(
|
||||
attribute.String("hybrid_search", string(searchResultJson)),
|
||||
attribute.String("processed_query", chatManage.ProcessedQuery),
|
||||
attribute.Int("search_unique_count", len(unique)),
|
||||
)
|
||||
return err
|
||||
@@ -119,7 +118,6 @@ func (p *PluginTracing) Rerank(ctx context.Context,
|
||||
span.SetAttributes(
|
||||
attribute.Int("rerank_resp_count", len(chatManage.RerankResult)),
|
||||
attribute.String("rerank_resp_results", string(resultJson)),
|
||||
attribute.String("query_intent", chatManage.QueryIntent),
|
||||
)
|
||||
return err
|
||||
}
|
||||
@@ -266,22 +264,20 @@ func (p *PluginTracing) RewriteQuery(ctx context.Context,
|
||||
return err
|
||||
}
|
||||
|
||||
// PreprocessQuery traces query preprocessing operations
|
||||
func (p *PluginTracing) PreprocessQuery(ctx context.Context,
|
||||
// SearchParallel traces parallel search operations (chunk + entity)
|
||||
func (p *PluginTracing) SearchParallel(ctx context.Context,
|
||||
eventType types.EventType, chatManage *types.ChatManage, next func() *PluginError,
|
||||
) *PluginError {
|
||||
_, span := tracing.ContextWithSpan(ctx, "PluginTracing.PreprocessQuery")
|
||||
_, span := tracing.ContextWithSpan(ctx, "PluginTracing.SearchParallel")
|
||||
defer span.End()
|
||||
|
||||
span.SetAttributes(
|
||||
attribute.String("query", chatManage.Query),
|
||||
attribute.String("rewrite_query", chatManage.RewriteQuery),
|
||||
attribute.Int("entity_count", len(chatManage.Entity)),
|
||||
)
|
||||
|
||||
err := next()
|
||||
|
||||
span.SetAttributes(
|
||||
attribute.String("processed_query", chatManage.ProcessedQuery),
|
||||
attribute.Int("search_result_count", len(chatManage.SearchResult)),
|
||||
)
|
||||
|
||||
return err
|
||||
}
|
||||
|
||||
@@ -9,7 +9,6 @@ import (
|
||||
"time"
|
||||
|
||||
"github.com/Tencent/WeKnora/internal/application/service/retriever"
|
||||
"github.com/Tencent/WeKnora/internal/common"
|
||||
"github.com/Tencent/WeKnora/internal/logger"
|
||||
"github.com/Tencent/WeKnora/internal/models/embedding"
|
||||
"github.com/Tencent/WeKnora/internal/types"
|
||||
@@ -523,35 +522,113 @@ func (s *knowledgeBaseService) HybridSearch(ctx context.Context,
|
||||
|
||||
// Collect all results from different retrievers and deduplicate by chunk ID
|
||||
logger.Infof(ctx, "Processing retrieval results")
|
||||
matchResults := []*types.IndexWithScore{}
|
||||
|
||||
// Separate results by retriever type for RRF fusion
|
||||
var vectorResults []*types.IndexWithScore
|
||||
var keywordResults []*types.IndexWithScore
|
||||
for _, retrieveResult := range retrieveResults {
|
||||
logger.Infof(ctx, "Retrieval results, engine: %v, retriever: %v, count: %v",
|
||||
retrieveResult.RetrieverEngineType,
|
||||
retrieveResult.RetrieverType,
|
||||
len(retrieveResult.Results),
|
||||
)
|
||||
matchResults = append(matchResults, retrieveResult.Results...)
|
||||
if retrieveResult.RetrieverType == types.VectorRetrieverType {
|
||||
vectorResults = append(vectorResults, retrieveResult.Results...)
|
||||
} else {
|
||||
keywordResults = append(keywordResults, retrieveResult.Results...)
|
||||
}
|
||||
}
|
||||
|
||||
// Early return if no results
|
||||
if len(matchResults) == 0 {
|
||||
if len(vectorResults) == 0 && len(keywordResults) == 0 {
|
||||
logger.Info(ctx, "No search results found")
|
||||
return nil, nil
|
||||
}
|
||||
logger.Infof(ctx, "Result count before deduplication: %d", len(matchResults))
|
||||
logger.Infof(ctx, "Result count before RRF fusion: vector=%d, keyword=%d", len(vectorResults), len(keywordResults))
|
||||
|
||||
// First, try standard deduplication
|
||||
deduplicatedChunks := common.DeduplicateWithScore(
|
||||
func(r *types.IndexWithScore) string { return r.ChunkID },
|
||||
matchResults...)
|
||||
logger.Infof(ctx, "Result count after deduplication: %d", len(deduplicatedChunks))
|
||||
// Use RRF (Reciprocal Rank Fusion) to merge results
|
||||
// RRF score = sum(1 / (k + rank)) for each retriever where the chunk appears
|
||||
// k=60 is a common choice that works well in practice
|
||||
const rrfK = 60
|
||||
|
||||
// Build rank maps for each retriever (already sorted by score from retriever)
|
||||
vectorRanks := make(map[string]int)
|
||||
for i, r := range vectorResults {
|
||||
if _, exists := vectorRanks[r.ChunkID]; !exists {
|
||||
vectorRanks[r.ChunkID] = i + 1 // 1-indexed rank
|
||||
}
|
||||
}
|
||||
keywordRanks := make(map[string]int)
|
||||
for i, r := range keywordResults {
|
||||
if _, exists := keywordRanks[r.ChunkID]; !exists {
|
||||
keywordRanks[r.ChunkID] = i + 1 // 1-indexed rank
|
||||
}
|
||||
}
|
||||
|
||||
// Collect all unique chunks and compute RRF scores
|
||||
chunkInfoMap := make(map[string]*types.IndexWithScore)
|
||||
rrfScores := make(map[string]float64)
|
||||
|
||||
// Process vector results
|
||||
for _, r := range vectorResults {
|
||||
if _, exists := chunkInfoMap[r.ChunkID]; !exists {
|
||||
chunkInfoMap[r.ChunkID] = r
|
||||
}
|
||||
}
|
||||
// Process keyword results
|
||||
for _, r := range keywordResults {
|
||||
if _, exists := chunkInfoMap[r.ChunkID]; !exists {
|
||||
chunkInfoMap[r.ChunkID] = r
|
||||
}
|
||||
}
|
||||
|
||||
// Compute RRF scores
|
||||
for chunkID := range chunkInfoMap {
|
||||
rrfScore := 0.0
|
||||
if rank, ok := vectorRanks[chunkID]; ok {
|
||||
rrfScore += 1.0 / float64(rrfK+rank)
|
||||
}
|
||||
if rank, ok := keywordRanks[chunkID]; ok {
|
||||
rrfScore += 1.0 / float64(rrfK+rank)
|
||||
}
|
||||
rrfScores[chunkID] = rrfScore
|
||||
}
|
||||
|
||||
// Convert to slice and sort by RRF score
|
||||
deduplicatedChunks := make([]*types.IndexWithScore, 0, len(chunkInfoMap))
|
||||
for chunkID, info := range chunkInfoMap {
|
||||
// Store RRF score in the Score field for downstream processing
|
||||
info.Score = rrfScores[chunkID]
|
||||
deduplicatedChunks = append(deduplicatedChunks, info)
|
||||
}
|
||||
slices.SortFunc(deduplicatedChunks, func(a, b *types.IndexWithScore) int {
|
||||
if a.Score > b.Score {
|
||||
return -1
|
||||
} else if a.Score < b.Score {
|
||||
return 1
|
||||
}
|
||||
return 0
|
||||
})
|
||||
|
||||
logger.Infof(ctx, "Result count after RRF fusion: %d", len(deduplicatedChunks))
|
||||
|
||||
// Log top results after RRF fusion for debugging
|
||||
for i, chunk := range deduplicatedChunks {
|
||||
if i < 15 {
|
||||
vRank, vOk := vectorRanks[chunk.ChunkID]
|
||||
kRank, kOk := keywordRanks[chunk.ChunkID]
|
||||
logger.Debugf(ctx, "RRF rank %d: chunk_id=%s, rrf_score=%.6f, vector_rank=%v(%v), keyword_rank=%v(%v)",
|
||||
i, chunk.ChunkID, chunk.Score, vRank, vOk, kRank, kOk)
|
||||
}
|
||||
}
|
||||
|
||||
kb.EnsureDefaults()
|
||||
|
||||
// Check if we need iterative retrieval for FAQ with separate indexing
|
||||
// Only use iterative retrieval if we don't have enough unique chunks after first deduplication
|
||||
totalRetrieved := len(vectorResults) + len(keywordResults)
|
||||
needsIterativeRetrieval := len(deduplicatedChunks) < params.MatchCount &&
|
||||
kb.Type == types.KnowledgeBaseTypeFAQ && len(matchResults) == matchCount
|
||||
kb.Type == types.KnowledgeBaseTypeFAQ && totalRetrieved == matchCount*2
|
||||
if needsIterativeRetrieval {
|
||||
logger.Info(ctx, "Not enough unique chunks, using iterative retrieval for FAQ")
|
||||
// Use iterative retrieval to get more unique chunks (with negative question filtering inside)
|
||||
@@ -891,10 +968,38 @@ func (s *knowledgeBaseService) processSearchResults(ctx context.Context,
|
||||
}
|
||||
}
|
||||
|
||||
// Build final search results
|
||||
// Build final search results - preserve original order from input chunks
|
||||
var searchResults []*types.SearchResult
|
||||
for chunkID, chunk := range chunkMap {
|
||||
addedChunkIDs := make(map[string]bool)
|
||||
|
||||
// First pass: Add results in the original order from input chunks
|
||||
for _, inputChunk := range chunks {
|
||||
chunk, exists := chunkMap[inputChunk.ChunkID]
|
||||
if !exists {
|
||||
logger.Debugf(ctx, "Chunk not found in chunkMap: %s", inputChunk.ChunkID)
|
||||
continue
|
||||
}
|
||||
if !s.isValidTextChunk(chunk) {
|
||||
logger.Debugf(ctx, "Chunk is not valid text chunk: %s, type: %s", chunk.ID, chunk.ChunkType)
|
||||
continue
|
||||
}
|
||||
if addedChunkIDs[chunk.ID] {
|
||||
continue
|
||||
}
|
||||
|
||||
score := chunkScores[chunk.ID]
|
||||
if knowledge, ok := knowledgeMap[chunk.KnowledgeID]; ok {
|
||||
matchType := chunkMatchTypes[chunk.ID]
|
||||
searchResults = append(searchResults, s.buildSearchResult(chunk, knowledge, score, matchType))
|
||||
addedChunkIDs[chunk.ID] = true
|
||||
} else {
|
||||
logger.Warnf(ctx, "Knowledge not found for chunk: %s, knowledge_id: %s", chunk.ID, chunk.KnowledgeID)
|
||||
}
|
||||
}
|
||||
|
||||
// Second pass: Add additional chunks (parent, nearby, relation) that weren't in original input
|
||||
for chunkID, chunk := range chunkMap {
|
||||
if addedChunkIDs[chunkID] || !s.isValidTextChunk(chunk) {
|
||||
continue
|
||||
}
|
||||
|
||||
@@ -959,6 +1064,7 @@ func (s *knowledgeBaseService) buildSearchResult(chunk *types.Chunk,
|
||||
ImageInfo: chunk.ImageInfo,
|
||||
KnowledgeFilename: knowledge.FileName,
|
||||
KnowledgeSource: knowledge.Source,
|
||||
ChunkMetadata: chunk.Metadata,
|
||||
}
|
||||
}
|
||||
|
||||
|
||||
@@ -800,11 +800,10 @@ func (s *sessionService) SearchKnowledge(ctx context.Context,
|
||||
|
||||
// Use specific event list, only including retrieval-related events, not LLM summarization
|
||||
searchEvents := []types.EventType{
|
||||
types.PREPROCESS_QUERY, // Preprocess query
|
||||
types.CHUNK_SEARCH, // Vector search
|
||||
types.CHUNK_RERANK, // Rerank search results
|
||||
types.CHUNK_MERGE, // Merge search results
|
||||
types.FILTER_TOP_K, // Filter top K results
|
||||
types.CHUNK_SEARCH, // Vector search
|
||||
types.CHUNK_RERANK, // Rerank search results
|
||||
types.CHUNK_MERGE, // Merge search results
|
||||
types.FILTER_TOP_K, // Filter top K results
|
||||
}
|
||||
|
||||
ctx, span := tracing.ContextWithSpan(ctx, "SessionService.SearchKnowledge")
|
||||
|
||||
@@ -140,10 +140,10 @@ func BuildContainer(container *dig.Container) *dig.Container {
|
||||
must(container.Invoke(chatpipline.NewPluginChatCompletionStream))
|
||||
must(container.Invoke(chatpipline.NewPluginStreamFilter))
|
||||
must(container.Invoke(chatpipline.NewPluginFilterTopK))
|
||||
must(container.Invoke(chatpipline.NewPluginPreprocess))
|
||||
must(container.Invoke(chatpipline.NewPluginRewrite))
|
||||
must(container.Invoke(chatpipline.NewPluginExtractEntity))
|
||||
must(container.Invoke(chatpipline.NewPluginSearchEntity))
|
||||
must(container.Invoke(chatpipline.NewPluginSearchParallel))
|
||||
|
||||
// HTTP handlers layer
|
||||
must(container.Provide(handler.NewTenantHandler))
|
||||
|
||||
@@ -5,7 +5,6 @@ package event
|
||||
// QueryData represents query-related event data
|
||||
type QueryData struct {
|
||||
OriginalQuery string `json:"original_query"`
|
||||
ProcessedQuery string `json:"processed_query,omitempty"`
|
||||
RewrittenQuery string `json:"rewritten_query,omitempty"`
|
||||
SessionID string `json:"session_id"`
|
||||
UserID string `json:"user_id,omitempty"`
|
||||
|
||||
70
internal/searchutil/textutil.go
Normal file
70
internal/searchutil/textutil.go
Normal file
@@ -0,0 +1,70 @@
|
||||
package searchutil
|
||||
|
||||
import (
|
||||
"crypto/md5"
|
||||
"encoding/hex"
|
||||
"strings"
|
||||
)
|
||||
|
||||
// BuildContentSignature creates a normalized MD5 signature for content to detect duplicates.
|
||||
// It normalizes the content by lowercasing, trimming whitespace, and collapsing multiple spaces.
|
||||
func BuildContentSignature(content string) string {
|
||||
c := strings.ToLower(strings.TrimSpace(content))
|
||||
if c == "" {
|
||||
return ""
|
||||
}
|
||||
// Normalize whitespace
|
||||
c = strings.Join(strings.Fields(c), " ")
|
||||
// Use MD5 hash of full content
|
||||
hash := md5.Sum([]byte(c))
|
||||
return hex.EncodeToString(hash[:])
|
||||
}
|
||||
|
||||
// TokenizeSimple tokenizes text into a set of words (simple whitespace-based).
|
||||
// Returns a map where keys are lowercase tokens with length > 1.
|
||||
func TokenizeSimple(text string) map[string]struct{} {
|
||||
text = strings.ToLower(text)
|
||||
fields := strings.Fields(text)
|
||||
set := make(map[string]struct{}, len(fields))
|
||||
for _, f := range fields {
|
||||
if len(f) > 1 {
|
||||
set[f] = struct{}{}
|
||||
}
|
||||
}
|
||||
return set
|
||||
}
|
||||
|
||||
// Jaccard calculates Jaccard similarity between two token sets.
|
||||
// Returns a value between 0 and 1, where 1 means identical sets.
|
||||
func Jaccard(a, b map[string]struct{}) float64 {
|
||||
if len(a) == 0 && len(b) == 0 {
|
||||
return 0
|
||||
}
|
||||
|
||||
// Calculate intersection
|
||||
inter := 0
|
||||
for k := range a {
|
||||
if _, ok := b[k]; ok {
|
||||
inter++
|
||||
}
|
||||
}
|
||||
|
||||
// Calculate union
|
||||
union := len(a) + len(b) - inter
|
||||
if union == 0 {
|
||||
return 0
|
||||
}
|
||||
|
||||
return float64(inter) / float64(union)
|
||||
}
|
||||
|
||||
// ClampFloat clamps a float value to the specified range [minV, maxV].
|
||||
func ClampFloat(v, minV, maxV float64) float64 {
|
||||
if v < minV {
|
||||
return minV
|
||||
}
|
||||
if v > maxV {
|
||||
return maxV
|
||||
}
|
||||
return v
|
||||
}
|
||||
@@ -3,12 +3,10 @@ package types
|
||||
// ChatManage represents the configuration and state for a chat session
|
||||
// including query processing, search parameters, and model configurations
|
||||
type ChatManage struct {
|
||||
SessionID string `json:"session_id"` // Unique identifier for the chat session
|
||||
Query string `json:"query,omitempty"` // Original user query
|
||||
ProcessedQuery string `json:"processed_query,omitempty"` // Query after preprocessing
|
||||
RewriteQuery string `json:"rewrite_query,omitempty"` // Query after rewriting for better retrieval
|
||||
QueryIntent string `json:"query_intent,omitempty"` // Parsed intent: definition/howto/compare/qa/general
|
||||
History []*History `json:"history,omitempty"` // Chat history for context
|
||||
SessionID string `json:"session_id"` // Unique identifier for the chat session
|
||||
Query string `json:"query,omitempty"` // Original user query
|
||||
RewriteQuery string `json:"rewrite_query,omitempty"` // Query after rewriting for better retrieval
|
||||
History []*History `json:"history,omitempty"` // Chat history for context
|
||||
|
||||
KnowledgeBaseID string `json:"knowledge_base_id"` // ID of the knowledge base to search against (deprecated, use KnowledgeBaseIDs)
|
||||
KnowledgeBaseIDs []string `json:"knowledge_base_ids"` // IDs of knowledge bases to search (multi-KB support)
|
||||
@@ -60,9 +58,7 @@ func (c *ChatManage) Clone() *ChatManage {
|
||||
|
||||
return &ChatManage{
|
||||
Query: c.Query,
|
||||
ProcessedQuery: c.ProcessedQuery,
|
||||
RewriteQuery: c.RewriteQuery,
|
||||
QueryIntent: c.QueryIntent,
|
||||
SessionID: c.SessionID,
|
||||
KnowledgeBaseID: c.KnowledgeBaseID,
|
||||
KnowledgeBaseIDs: knowledgeBaseIDs,
|
||||
@@ -103,9 +99,9 @@ func (c *ChatManage) Clone() *ChatManage {
|
||||
type EventType string
|
||||
|
||||
const (
|
||||
PREPROCESS_QUERY EventType = "preprocess_query" // Query preprocessing stage
|
||||
REWRITE_QUERY EventType = "rewrite_query" // Query rewriting for better retrieval
|
||||
CHUNK_SEARCH EventType = "chunk_search" // Search for relevant chunks
|
||||
CHUNK_SEARCH_PARALLEL EventType = "chunk_search_parallel" // Parallel search: chunks + entities
|
||||
ENTITY_SEARCH EventType = "entity_search" // Search for relevant entities
|
||||
CHUNK_RERANK EventType = "chunk_rerank" // Rerank search results
|
||||
CHUNK_MERGE EventType = "chunk_merge" // Merge similar chunks
|
||||
@@ -134,9 +130,7 @@ var Pipline = map[string][]EventType{
|
||||
},
|
||||
"rag_stream": { // Streaming Retrieval Augmented Generation
|
||||
REWRITE_QUERY,
|
||||
PREPROCESS_QUERY,
|
||||
CHUNK_SEARCH,
|
||||
ENTITY_SEARCH,
|
||||
CHUNK_SEARCH_PARALLEL, // Parallel: CHUNK_SEARCH + ENTITY_SEARCH
|
||||
CHUNK_RERANK,
|
||||
CHUNK_MERGE,
|
||||
FILTER_TOP_K,
|
||||
|
||||
@@ -46,6 +46,9 @@ type SearchResult struct {
|
||||
// Knowledge source
|
||||
// Used to indicate the source of the knowledge, such as "url"
|
||||
KnowledgeSource string `json:"knowledge_source"`
|
||||
|
||||
// ChunkMetadata stores chunk-level metadata (e.g., generated questions)
|
||||
ChunkMetadata JSON `json:"chunk_metadata,omitempty"`
|
||||
}
|
||||
|
||||
// SearchParams represents the search parameters
|
||||
|
||||
Reference in New Issue
Block a user