feat: Implement parallel search functionality combining chunk and entity searches, enhancing retrieval efficiency and result accuracy

This commit is contained in:
wizardchen
2025-12-03 17:45:21 +08:00
parent eafa0cd80b
commit 6b4d17ec70
18 changed files with 2799 additions and 2273 deletions

3595
LICENSE

File diff suppressed because it is too large Load Diff

6
go.mod
View File

@@ -29,7 +29,6 @@ require (
github.com/spf13/viper v1.20.1
github.com/stretchr/testify v1.11.1
github.com/tencentyun/cos-go-sdk-v5 v0.7.65
github.com/xuri/excelize/v2 v2.10.0
github.com/yanyiwu/gojieba v1.4.5
go.opentelemetry.io/otel v1.37.0
go.opentelemetry.io/otel/exporters/otlp/otlptrace v1.37.0
@@ -106,8 +105,6 @@ require (
github.com/pelletier/go-toml/v2 v2.2.3 // indirect
github.com/pierrec/lz4/v4 v4.1.21 // indirect
github.com/pmezard/go-difflib v1.0.0 // indirect
github.com/richardlehane/mscfb v1.0.4 // indirect
github.com/richardlehane/msoleps v1.0.4 // indirect
github.com/rivo/uniseg v0.4.7 // indirect
github.com/robfig/cron/v3 v3.0.1 // indirect
github.com/rs/xid v1.6.0 // indirect
@@ -117,12 +114,9 @@ require (
github.com/spf13/cast v1.10.0 // indirect
github.com/spf13/pflag v1.0.6 // indirect
github.com/subosito/gotenv v1.6.0 // indirect
github.com/tiendc/go-deepcopy v1.7.1 // indirect
github.com/twitchyliquid64/golang-asm v0.15.1 // indirect
github.com/ugorji/go/codec v1.2.12 // indirect
github.com/wk8/go-ordered-map/v2 v2.1.8 // indirect
github.com/xuri/efp v0.0.1 // indirect
github.com/xuri/nfp v0.0.2-0.20250530014748-2ddeb826f9a9 // indirect
github.com/yosida95/uritemplate/v3 v3.0.2 // indirect
go.opentelemetry.io/auto/sdk v1.1.0 // indirect
go.opentelemetry.io/otel/metric v1.37.0 // indirect

15
go.sum
View File

@@ -235,11 +235,6 @@ github.com/pmezard/go-difflib v1.0.0 h1:4DBwDE0NGyQoBHbLQYPwSUPoCMWR5BEzIk/f1lZb
github.com/pmezard/go-difflib v1.0.0/go.mod h1:iKH77koFhYxTK1pcRnkKkqfTogsbg7gZNVY4sRDYZ/4=
github.com/redis/go-redis/v9 v9.14.0 h1:u4tNCjXOyzfgeLN+vAZaW1xUooqWDqVEsZN0U01jfAE=
github.com/redis/go-redis/v9 v9.14.0/go.mod h1:huWgSWd8mW6+m0VPhJjSSQ+d6Nh1VICQ6Q5lHuCH/Iw=
github.com/richardlehane/mscfb v1.0.4 h1:WULscsljNPConisD5hR0+OyZjwK46Pfyr6mPu5ZawpM=
github.com/richardlehane/mscfb v1.0.4/go.mod h1:YzVpcZg9czvAuhk9T+a3avCpcFPMUWm7gK3DypaEsUk=
github.com/richardlehane/msoleps v1.0.1/go.mod h1:BWev5JBpU9Ko2WAgmZEuiz4/u3ZYTKbjLycmwiWUfWg=
github.com/richardlehane/msoleps v1.0.4 h1:WuESlvhX3gH2IHcd8UqyCuFY5yiq/GR/yqaSM/9/g00=
github.com/richardlehane/msoleps v1.0.4/go.mod h1:BWev5JBpU9Ko2WAgmZEuiz4/u3ZYTKbjLycmwiWUfWg=
github.com/rivo/uniseg v0.2.0/go.mod h1:J6wj4VEh+S6ZtnVlnTBMWIodfgj8LQOQFoIToxlJtxc=
github.com/rivo/uniseg v0.4.7 h1:WUdvkW8uEhrYfLC4ZzdpI2ztxP1I582+49Oc5Mq64VQ=
github.com/rivo/uniseg v0.4.7/go.mod h1:FN3SvrM+Zdj16jyLfmOkMNblXMcoc8DfTHruCPUcx88=
@@ -284,8 +279,6 @@ github.com/tencentcloud/tencentcloud-sdk-go/tencentcloud/common v1.0.563/go.mod
github.com/tencentcloud/tencentcloud-sdk-go/tencentcloud/kms v1.0.563/go.mod h1:uom4Nvi9W+Qkom0exYiJ9VWJjXwyxtPYTkKkaLMlfE0=
github.com/tencentyun/cos-go-sdk-v5 v0.7.65 h1:+WBbfwThfZSbxpf1Dw6fyMwyzVtWBBExqfDJ5giiR2s=
github.com/tencentyun/cos-go-sdk-v5 v0.7.65/go.mod h1:8+hG+mQMuRP/OIS9d83syAvXvrMj9HhkND6Q1fLghw0=
github.com/tiendc/go-deepcopy v1.7.1 h1:LnubftI6nYaaMOcaz0LphzwraqN8jiWTwm416sitff4=
github.com/tiendc/go-deepcopy v1.7.1/go.mod h1:4bKjNC2r7boYOkD2IOuZpYjmlDdzjbpTRyCx+goBCJQ=
github.com/tmthrgd/go-hex v0.0.0-20190904060850-447a3041c3bc h1:9lRDQMhESg+zvGYmW5DyG0UqvY96Bu5QYsTLvCHdrgo=
github.com/tmthrgd/go-hex v0.0.0-20190904060850-447a3041c3bc/go.mod h1:bciPuU6GHm1iF1pBvUfxfsH0Wmnc2VbpgvbI9ZWuIRs=
github.com/twitchyliquid64/golang-asm v0.15.1 h1:SU5vSMR7hnwNxj24w34ZyCi/FmDZTkS4MhqMhdFk5YI=
@@ -310,12 +303,6 @@ github.com/wk8/go-ordered-map/v2 v2.1.8 h1:5h/BUHu93oj4gIdvHHHGsScSTMijfx5PeYkE/
github.com/wk8/go-ordered-map/v2 v2.1.8/go.mod h1:5nJHM5DyteebpVlHnWMV0rPz6Zp7+xBAnxjb1X5vnTw=
github.com/x448/float16 v0.8.4 h1:qLwI1I70+NjRFUR3zs1JPUCgaCXSh3SW62uAKT1mSBM=
github.com/x448/float16 v0.8.4/go.mod h1:14CWIYCyZA/cWjXOioeEpHeN/83MdbZDRQHoFcYsOfg=
github.com/xuri/efp v0.0.1 h1:fws5Rv3myXyYni8uwj2qKjVaRP30PdjeYe2Y6FDsCL8=
github.com/xuri/efp v0.0.1/go.mod h1:ybY/Jr0T0GTCnYjKqmdwxyxn2BQf2RcQIIvex5QldPI=
github.com/xuri/excelize/v2 v2.10.0 h1:8aKsP7JD39iKLc6dH5Tw3dgV3sPRh8uRVXu/fMstfW4=
github.com/xuri/excelize/v2 v2.10.0/go.mod h1:SC5TzhQkaOsTWpANfm+7bJCldzcnU/jrhqkTi/iBHBU=
github.com/xuri/nfp v0.0.2-0.20250530014748-2ddeb826f9a9 h1:+C0TIdyyYmzadGaL/HBLbf3WdLgC29pgyhTjAT/0nuE=
github.com/xuri/nfp v0.0.2-0.20250530014748-2ddeb826f9a9/go.mod h1:WwHg+CVyzlv/TX9xqBFXEZAuxOPxn2k1GNHwG41IIUQ=
github.com/yanyiwu/gojieba v1.4.5 h1:VyZogGtdFSnJbACHvDRvDreXPPVPCg8axKFUdblU/JI=
github.com/yanyiwu/gojieba v1.4.5/go.mod h1:JUq4DddFVGdHXJHxxepxRmhrKlDpaBxR8O28v6fKYLY=
github.com/yosida95/uritemplate/v3 v3.0.2 h1:Ed3Oyj9yrmi9087+NczuL5BwkIc4wvTb5zIM+UJPGz4=
@@ -359,8 +346,6 @@ golang.org/x/crypto v0.23.0/go.mod h1:CKFgDieR+mRhux2Lsu27y0fO304Db0wZe70UKqHu0v
golang.org/x/crypto v0.31.0/go.mod h1:kDsLvtWBEx7MV9tJOj9bnXsPbxwJQ6csT/x4KIN4Ssk=
golang.org/x/crypto v0.43.0 h1:dduJYIi3A3KOfdGOHX8AVZ/jGiyPa3IbBozJ5kNuE04=
golang.org/x/crypto v0.43.0/go.mod h1:BFbav4mRNlXJL4wNeejLpWxB7wMbc79PdRGhWKncxR0=
golang.org/x/image v0.25.0 h1:Y6uW6rH1y5y/LK1J8BPWZtr6yZ7hrsy6hFrXjgsc2fQ=
golang.org/x/image v0.25.0/go.mod h1:tCAmOEGthTtkalusGp1g3xa2gke8J6c2N565dTyl9Rs=
golang.org/x/mod v0.6.0-dev.0.20220419223038-86c51ed26bb4/go.mod h1:jJ57K6gSWd91VN4djpZkiMVwK6gcyfeH4XE8wZrZaV4=
golang.org/x/mod v0.8.0/go.mod h1:iBbtSCu2XBx23ZKBPSOrRkjjQPZFPuis4dIYUhu/chs=
golang.org/x/mod v0.12.0/go.mod h1:iBbtSCu2XBx23ZKBPSOrRkjjQPZFPuis4dIYUhu/chs=

View File

@@ -8,6 +8,7 @@ import (
"strings"
"github.com/Tencent/WeKnora/internal/logger"
"github.com/Tencent/WeKnora/internal/searchutil"
"github.com/Tencent/WeKnora/internal/types"
"gorm.io/gorm"
)
@@ -546,17 +547,7 @@ func (t *GrepChunksTool) deduplicateChunks(ctx context.Context, results []chunkW
// buildContentSignature creates a normalized signature for content to detect near-duplicates
func (t *GrepChunksTool) buildContentSignature(content string) string {
c := strings.ToLower(strings.TrimSpace(content))
if c == "" {
return ""
}
// Normalize whitespace
c = strings.Join(strings.Fields(c), " ")
// Use first 128 characters as signature
if len(c) > 128 {
c = c[:128]
}
return c
return searchutil.BuildContentSignature(content)
}
// scoreChunks calculates match scores for chunks based on pattern matches
@@ -693,36 +684,10 @@ func (t *GrepChunksTool) applyMMR(
// tokenizeSimple tokenizes text into a set of words (simple whitespace-based)
func (t *GrepChunksTool) tokenizeSimple(text string) map[string]struct{} {
text = strings.ToLower(text)
fields := strings.Fields(text)
set := make(map[string]struct{}, len(fields))
for _, f := range fields {
if len(f) > 1 {
set[f] = struct{}{}
}
}
return set
return searchutil.TokenizeSimple(text)
}
// jaccard calculates Jaccard similarity between two token sets
func (t *GrepChunksTool) jaccard(a, b map[string]struct{}) float64 {
if len(a) == 0 && len(b) == 0 {
return 0
}
// Calculate intersection
inter := 0
for k := range a {
if _, ok := b[k]; ok {
inter++
}
}
// Calculate union
union := len(a) + len(b) - inter
if union == 0 {
return 0
}
return float64(inter) / float64(union)
return searchutil.Jaccard(a, b)
}

View File

@@ -264,15 +264,12 @@ func (t *KnowledgeSearchTool) Execute(ctx context.Context, args map[string]inter
topK, vectorThreshold, keywordThreshold, kbTypeMap)
logger.Infof(ctx, "[Tool][KnowledgeSearch] Concurrent search completed: %d raw results", len(allResults))
// Normalize keyword search results to ensure fair comparison across knowledge bases
logger.Debugf(ctx, "[Tool][KnowledgeSearch] Normalizing keyword search results...")
t.normalizeKeywordSearchResults(ctx, allResults)
logger.Infof(ctx, "[Tool][KnowledgeSearch] After keyword normalization: %d results", len(allResults))
// Note: HybridSearch now uses RRF (Reciprocal Rank Fusion) which produces normalized scores
// RRF scores are in range [0, ~0.033] (max when rank=1 on both sides: 2/(60+1))
// Threshold filtering is already done inside HybridSearch before RRF, so we skip it here
// Filter by threshold first
filteredResults := t.filterByThreshold(allResults, vectorThreshold, keywordThreshold)
// Deduplicate before reranking to reduce processing overhead
deduplicatedBeforeRerank := t.deduplicateResults(filteredResults)
deduplicatedBeforeRerank := t.deduplicateResults(allResults)
// Apply ReRank if model is configured
// Prefer chatModel (LLM-based reranking) over rerankModel if both are available
@@ -286,6 +283,9 @@ func (t *KnowledgeSearchTool) Execute(ctx context.Context, args map[string]inter
}
}
// Variable to hold results through reranking and MMR stages
var filteredResults []*searchResultWithMeta
if t.chatModel != nil && len(deduplicatedBeforeRerank) > 0 && rerankQuery != "" {
logger.Infof(
ctx,
@@ -347,10 +347,9 @@ func (t *KnowledgeSearchTool) Execute(ctx context.Context, args map[string]inter
}
}
// Apply absolute minimum score filter to remove very low quality chunks
logger.Debugf(ctx, "[Tool][KnowledgeSearch] Applying min_score filter (%.2f)...", minScore)
filteredResults = t.filterByMinScore(filteredResults, minScore)
logger.Infof(ctx, "[Tool][KnowledgeSearch] After min_score filter: %d results", len(filteredResults))
// Note: minScore filter is skipped because HybridSearch now uses RRF scores
// RRF scores are in range [0, ~0.033], not [0, 1], so old thresholds don't apply
// Threshold filtering is already done inside HybridSearch before RRF fusion
// Final deduplication after rerank (in case rerank changed scores/order but duplicates remain)
logger.Debugf(ctx, "[Tool][KnowledgeSearch] Final deduplication after rerank...")
@@ -465,45 +464,6 @@ func (t *KnowledgeSearchTool) concurrentSearch(
return allResults
}
// filterByThreshold filters results based on match type and threshold
// Special handling for history matches: uses lower threshold (reduced by 0.1, minimum 0.5)
func (t *KnowledgeSearchTool) filterByThreshold(
results []*searchResultWithMeta,
vectorThreshold, keywordThreshold float64,
) []*searchResultWithMeta {
filtered := make([]*searchResultWithMeta, 0)
for _, r := range results {
var threshold float64
// Special handling for history matches: use lower threshold
switch r.MatchType {
case types.MatchTypeHistory:
// Use the lower of the two thresholds, then reduce by 0.1 (minimum 0.5)
th := vectorThreshold
if keywordThreshold < th {
th = keywordThreshold
}
threshold = math.Max(th-0.1, 0.5)
case types.MatchTypeEmbedding:
threshold = vectorThreshold
case types.MatchTypeKeywords:
threshold = keywordThreshold
default:
// For other match types (graph, nearby chunk, etc.), use the lower threshold
threshold = vectorThreshold
if keywordThreshold < threshold {
threshold = keywordThreshold
}
}
// Check if result meets threshold
if r.Score >= threshold {
filtered = append(filtered, r)
}
}
return filtered
}
// rerankResults applies reranking to search results using LLM prompt scoring or rerank model
func (t *KnowledgeSearchTool) rerankResults(
ctx context.Context,
@@ -566,15 +526,13 @@ func (t *KnowledgeSearchTool) rerankResults(
}
// Apply composite scoring to reranked results
// Get query intent from context if available (optional)
queryIntent := t.getQueryIntentFromContext(ctx)
logger.Debugf(ctx, "[Tool][KnowledgeSearch] Applying composite scoring with query_intent=%s", queryIntent)
logger.Debugf(ctx, "[Tool][KnowledgeSearch] Applying composite scoring")
// Store base scores before composite scoring
for _, result := range rerankedNonFAQ {
baseScore := result.Score
// Apply composite score
result.Score = t.compositeScore(result, result.Score, baseScore, queryIntent)
result.Score = t.compositeScore(result, result.Score, baseScore)
}
// Combine FAQ results (with original order) and reranked non-FAQ results
@@ -899,20 +857,6 @@ func (t *KnowledgeSearchTool) rerankWithModel(
return reranked, nil
}
// filterByMinScore filters results by absolute minimum score
func (t *KnowledgeSearchTool) filterByMinScore(
results []*searchResultWithMeta,
minScore float64,
) []*searchResultWithMeta {
filtered := make([]*searchResultWithMeta, 0)
for _, r := range results {
if r.Score >= minScore {
filtered = append(filtered, r)
}
}
return filtered
}
// deduplicateResults removes duplicate chunks, keeping the highest score
// Uses multiple keys (ID, parent chunk ID, knowledge+index) and content signature for deduplication
func (t *KnowledgeSearchTool) deduplicateResults(results []*searchResultWithMeta) []*searchResultWithMeta {
@@ -984,17 +928,7 @@ func (t *KnowledgeSearchTool) deduplicateResults(results []*searchResultWithMeta
// buildContentSignature creates a normalized signature for content to detect near-duplicates
func (t *KnowledgeSearchTool) buildContentSignature(content string) string {
c := strings.ToLower(strings.TrimSpace(content))
if c == "" {
return ""
}
// Normalize whitespace
c = strings.Join(strings.Fields(c), " ")
// Use first 128 characters as signature
if len(c) > 128 {
c = c[:128]
}
return c
return searchutil.BuildContentSignature(content)
}
// formatOutput formats the search results for display
@@ -1215,47 +1149,6 @@ type chunkRange struct {
end int
}
// normalizeKeywordSearchResults normalizes keyword search result scores into [0,1] globally across all knowledge bases
// Improvements:
// 1. Uses robust normalization with percentile-based bounds to handle outliers
// 2. Handles edge cases: single result, no variance, negative scores
// 3. Global normalization ensures fair comparison across different knowledge bases
func (t *KnowledgeSearchTool) normalizeKeywordSearchResults(ctx context.Context, results []*searchResultWithMeta) {
searchutil.NormalizeKeywordScores[*searchResultWithMeta](
results,
func(r *searchResultWithMeta) bool {
return r.MatchType == types.MatchTypeKeywords
},
func(r *searchResultWithMeta) float64 {
return r.Score
},
func(r *searchResultWithMeta, score float64) {
r.Score = score
},
searchutil.KeywordScoreCallbacks{
OnNoVariance: func(count int, score float64) {
logger.Infof(
ctx,
"[Tool][KnowledgeSearch] Keyword scores have no variance, all set to 1.0: count=%d, score=%.3f",
count,
score,
)
},
OnNormalized: func(count int, rawMin, rawMax, normalizeMin, normalizeMax float64) {
logger.Infof(
ctx,
"[Tool][KnowledgeSearch] Normalized keyword scores: count=%d, raw_min=%.3f, raw_max=%.3f, normalize_min=%.3f, normalize_max=%.3f",
count,
rawMin,
rawMax,
normalizeMin,
normalizeMax,
)
},
},
)
}
// getEnrichedPassage 合并Content和ImageInfo的文本内容
func (t *KnowledgeSearchTool) getEnrichedPassage(ctx context.Context, result *types.SearchResult) string {
if result.ImageInfo == "" {
@@ -1302,19 +1195,10 @@ func (t *KnowledgeSearchTool) getEnrichedPassage(ctx context.Context, result *ty
return combinedText
}
// getQueryIntentFromContext attempts to extract query intent from context (optional)
func (t *KnowledgeSearchTool) getQueryIntentFromContext(ctx context.Context) string {
// Try to get query intent from context if available
// This is optional and may not always be present in agent tool context
// Return empty string if not available
return ""
}
// compositeScore calculates a composite score considering multiple factors
func (t *KnowledgeSearchTool) compositeScore(
result *searchResultWithMeta,
modelScore, baseScore float64,
queryIntent string,
) float64 {
// Source weight: web_search results get slightly lower weight
sourceWeight := 1.0
@@ -1322,26 +1206,6 @@ func (t *KnowledgeSearchTool) compositeScore(
sourceWeight = 0.95
}
// Intent boost: adjust score based on query intent and chunk characteristics
intentBoost := 1.0
if queryIntent != "" {
switch queryIntent {
case "definition":
// Boost summary chunks for definition queries
if result.ChunkType == string(types.ChunkTypeSummary) {
intentBoost = 1.05
}
case "howto":
// Boost longer chunks for howto queries
if result.EndAt-result.StartAt > 300 {
intentBoost = 1.03
}
case "compare":
// No boost for compare queries
intentBoost = 1.0
}
}
// Position prior: slightly favor chunks earlier in the document
positionPrior := 1.0
if result.StartAt >= 0 && result.EndAt > result.StartAt {
@@ -1352,7 +1216,6 @@ func (t *KnowledgeSearchTool) compositeScore(
// Composite formula: weighted combination of model score, base score, and source weight
composite := 0.6*modelScore + 0.3*baseScore + 0.1*sourceWeight
composite *= intentBoost
composite *= positionPrior
// Clamp to [0, 1]
@@ -1368,13 +1231,7 @@ func (t *KnowledgeSearchTool) compositeScore(
// clampFloat clamps a float value to the specified range
func (t *KnowledgeSearchTool) clampFloat(v, minV, maxV float64) float64 {
if v < minV {
return minV
}
if v > maxV {
return maxV
}
return v
return searchutil.ClampFloat(v, minV, maxV)
}
// applyMMR applies Maximal Marginal Relevance algorithm to reduce redundancy
@@ -1456,36 +1313,10 @@ func (t *KnowledgeSearchTool) applyMMR(
// tokenizeSimple tokenizes text into a set of words (simple whitespace-based)
func (t *KnowledgeSearchTool) tokenizeSimple(text string) map[string]struct{} {
text = strings.ToLower(text)
fields := strings.Fields(text)
set := make(map[string]struct{}, len(fields))
for _, f := range fields {
if len(f) > 1 {
set[f] = struct{}{}
}
}
return set
return searchutil.TokenizeSimple(text)
}
// jaccard calculates Jaccard similarity between two token sets
func (t *KnowledgeSearchTool) jaccard(a, b map[string]struct{}) float64 {
if len(a) == 0 && len(b) == 0 {
return 0
}
// Calculate intersection
inter := 0
for k := range a {
if _, ok := b[k]; ok {
inter++
}
}
// Calculate union
union := len(a) + len(b) - inter
if union == 0 {
return 0
}
return float64(inter) / float64(union)
return searchutil.Jaccard(a, b)
}

View File

@@ -1,166 +0,0 @@
package chatpipline
import (
"context"
"encoding/json"
"regexp"
"strings"
"github.com/Tencent/WeKnora/internal/config"
"github.com/Tencent/WeKnora/internal/models/chat"
"github.com/Tencent/WeKnora/internal/types"
"github.com/Tencent/WeKnora/internal/types/interfaces"
)
// PluginPreprocess Query preprocessing plugin
type PluginPreprocess struct {
config *config.Config
modelService interfaces.ModelService
}
// Regular expressions for text cleaning
var (
multiSpaceRegex = regexp.MustCompile(`\s+`) // Multiple spaces
)
// NewPluginPreprocess Creates a new query preprocessing plugin
func NewPluginPreprocess(
eventManager *EventManager,
config *config.Config,
cleaner interfaces.ResourceCleaner,
modelService interfaces.ModelService,
) *PluginPreprocess {
res := &PluginPreprocess{
config: config,
modelService: modelService,
}
eventManager.Register(res)
return res
}
// ActivationEvents Register activation events
func (p *PluginPreprocess) ActivationEvents() []types.EventType {
return []types.EventType{types.PREPROCESS_QUERY}
}
// OnEvent Process events
func (p *PluginPreprocess) OnEvent(
ctx context.Context,
eventType types.EventType,
chatManage *types.ChatManage,
next func() *PluginError,
) *PluginError {
rawQuery := strings.TrimSpace(chatManage.RewriteQuery)
if rawQuery == "" {
return next()
}
pipelineInfo(ctx, "Preprocess", "input", map[string]interface{}{
"session_id": chatManage.SessionID,
"rewrite_query": rawQuery,
})
// Lightweight normalization: just collapse multiple spaces
processed := multiSpaceRegex.ReplaceAllString(rawQuery, " ")
processed = strings.TrimSpace(processed)
chatManage.ProcessedQuery = processed
chatManage.QueryIntent = p.detectIntentLLM(ctx, chatManage, processed)
pipelineInfo(ctx, "Preprocess", "output", map[string]interface{}{
"session_id": chatManage.SessionID,
"processed_query": processed,
"query_intent": chatManage.QueryIntent,
})
return next()
}
// intentResp is a response for intent detection
type intentResp struct {
Intent string `json:"intent"`
Confidence float64 `json:"confidence"`
}
// detectIntentLLM detects the intent of a query using an LLM
func (p *PluginPreprocess) detectIntentLLM(ctx context.Context, chatManage *types.ChatManage, text string) string {
if p.modelService == nil || chatManage.ChatModelID == "" {
pipelineWarn(
ctx,
"IntentDetect",
"skip",
map[string]interface{}{"reason": "no_model", "session_id": chatManage.SessionID},
)
return "general"
}
chatModel, err := p.modelService.GetChatModel(ctx, chatManage.ChatModelID)
if err != nil {
pipelineWarn(
ctx,
"IntentDetect",
"get_model_failed",
map[string]interface{}{"error": err.Error(), "model_id": chatManage.ChatModelID},
)
return "general"
}
pipelineInfo(
ctx,
"IntentDetect",
"start",
map[string]interface{}{"session_id": chatManage.SessionID, "model_id": chatManage.ChatModelID},
)
sys := "You are a query intent classifier. Classify the user's query into one of: definition, howto, compare, qa, general. Respond ONLY with a JSON object {\"intent\": \"...\", \"confidence\": 0.0 } inside a markdown fenced block."
usr := text
think := false
resp, err := chatModel.Chat(ctx, []chat.Message{
{Role: "system", Content: sys},
{Role: "user", Content: usr},
}, &chat.ChatOptions{Temperature: 0.0, MaxCompletionTokens: 64, Thinking: &think})
if err != nil || resp.Content == "" {
pipelineWarn(ctx, "IntentDetect", "model_call_failed", map[string]interface{}{"error": err})
return "general"
}
body := extractJSONBody(resp.Content)
var ir intentResp
if err := json.Unmarshal([]byte(body), &ir); err != nil {
pipelineWarn(ctx, "IntentDetect", "parse_failed", map[string]interface{}{"body": body, "error": err.Error()})
return "general"
}
pipelineInfo(
ctx,
"IntentDetect",
"result",
map[string]interface{}{"intent": ir.Intent, "confidence": ir.Confidence},
)
switch strings.ToLower(strings.TrimSpace(ir.Intent)) {
case "definition", "howto", "compare", "qa", "general":
return strings.ToLower(ir.Intent)
default:
return "general"
}
}
// extractJSONBody extracts a JSON body from a string
func extractJSONBody(text string) string {
t := strings.TrimSpace(text)
// Try fenced block first
if i := strings.Index(t, "{"); i >= 0 {
j := strings.LastIndex(t, "}")
if j > i {
return t[i : j+1]
}
}
return "{}"
}
// Close Releases resources
func (p *PluginPreprocess) Close() {
}
// ShutdownHandler Returns shutdown function
func (p *PluginPreprocess) ShutdownHandler() func() {
return func() {
p.Close()
}
}

View File

@@ -8,6 +8,7 @@ import (
"strings"
"github.com/Tencent/WeKnora/internal/models/rerank"
"github.com/Tencent/WeKnora/internal/searchutil"
"github.com/Tencent/WeKnora/internal/types"
"github.com/Tencent/WeKnora/internal/types/interfaces"
)
@@ -41,7 +42,6 @@ func (p *PluginRerank) OnEvent(ctx context.Context,
"rerank_model": chatManage.RerankModelID,
"rerank_thresh": chatManage.RerankThreshold,
"rewrite_query": chatManage.RewriteQuery,
"processed_query": chatManage.ProcessedQuery,
})
if len(chatManage.SearchResult) == 0 {
pipelineInfo(ctx, "Rerank", "skip", map[string]interface{}{
@@ -77,18 +77,40 @@ func (p *PluginRerank) OnEvent(ctx context.Context,
passages = append(passages, passage)
}
// Try reranking with different query variants in priority order
// Single rerank call with RewriteQuery, use threshold degradation if no results
originalThreshold := chatManage.RerankThreshold
rerankResp := p.rerank(ctx, chatManage, rerankModel, chatManage.RewriteQuery, passages)
if len(rerankResp) == 0 {
rerankResp = p.rerank(ctx, chatManage, rerankModel, chatManage.ProcessedQuery, passages)
if len(rerankResp) == 0 {
rerankResp = p.rerank(ctx, chatManage, rerankModel, chatManage.Query, passages)
// If no results and threshold is high enough, try with lower threshold
if len(rerankResp) == 0 && originalThreshold > 0.3 {
degradedThreshold := originalThreshold * 0.7
if degradedThreshold < 0.3 {
degradedThreshold = 0.3
}
pipelineInfo(ctx, "Rerank", "threshold_degrade", map[string]interface{}{
"original": originalThreshold,
"degraded": degradedThreshold,
})
chatManage.RerankThreshold = degradedThreshold
rerankResp = p.rerank(ctx, chatManage, rerankModel, chatManage.RewriteQuery, passages)
// Restore original threshold
chatManage.RerankThreshold = originalThreshold
}
pipelineInfo(ctx, "Rerank", "model_response", map[string]interface{}{
"result_cnt": len(rerankResp),
})
// Log input scores before reranking for debugging
for i, sr := range chatManage.SearchResult {
pipelineInfo(ctx, "Rerank", "input_score", map[string]interface{}{
"index": i,
"chunk_id": sr.ID,
"score": fmt.Sprintf("%.4f", sr.Score),
"match_type": sr.MatchType,
})
}
for i := range chatManage.SearchResult {
chatManage.SearchResult[i].Metadata = ensureMetadata(chatManage.SearchResult[i].Metadata)
}
@@ -97,8 +119,15 @@ func (p *PluginRerank) OnEvent(ctx context.Context,
sr := chatManage.SearchResult[rr.Index]
base := sr.Score
sr.Metadata["base_score"] = fmt.Sprintf("%.4f", base)
sr.Score = rr.RelevanceScore
sr.Score = compositeScore(sr, rr.RelevanceScore, base, chatManage)
modelScore := rr.RelevanceScore
sr.Score = compositeScore(sr, modelScore, base)
pipelineInfo(ctx, "Rerank", "composite_calc", map[string]interface{}{
"chunk_id": sr.ID,
"base_score": fmt.Sprintf("%.4f", base),
"model_score": fmt.Sprintf("%.4f", modelScore),
"final_score": fmt.Sprintf("%.4f", sr.Score),
"match_type": sr.MatchType,
})
reranked = append(reranked, sr)
}
final := applyMMR(ctx, reranked, chatManage, min(len(reranked), max(1, chatManage.RerankTopK)), 0.7)
@@ -112,7 +141,6 @@ func (p *PluginRerank) OnEvent(ctx context.Context,
"chunk_id": reranked[i].ID,
"base_score": reranked[i].Metadata["base_score"],
"final_score": fmt.Sprintf("%.4f", reranked[i].Score),
"intent": chatManage.QueryIntent,
})
}
@@ -185,7 +213,7 @@ func ensureMetadata(m map[string]string) map[string]string {
}
// compositeScore calculates the composite score for a search result
func compositeScore(sr *types.SearchResult, modelScore, baseScore float64, chatManage *types.ChatManage) float64 {
func compositeScore(sr *types.SearchResult, modelScore, baseScore float64) float64 {
sourceWeight := 1.0
switch strings.ToLower(sr.KnowledgeSource) {
case "web_search":
@@ -193,27 +221,11 @@ func compositeScore(sr *types.SearchResult, modelScore, baseScore float64, chatM
default:
sourceWeight = 1.0
}
intentBoost := 1.0
if chatManage.QueryIntent != "" {
switch chatManage.QueryIntent {
case "definition":
if sr.ChunkType == string(types.ChunkTypeSummary) {
intentBoost = 1.05
}
case "howto":
if sr.EndAt-sr.StartAt > 300 {
intentBoost = 1.03
}
case "compare":
intentBoost = 1.0
}
}
positionPrior := 1.0
if sr.StartAt >= 0 {
positionPrior += clampFloat(1.0-float64(sr.StartAt)/float64(sr.EndAt+1), -0.05, 0.05)
positionPrior += searchutil.ClampFloat(1.0-float64(sr.StartAt)/float64(sr.EndAt+1), -0.05, 0.05)
}
composite := 0.6*modelScore + 0.3*baseScore + 0.1*sourceWeight
composite *= intentBoost
composite *= positionPrior
if composite < 0 {
composite = 0
@@ -224,7 +236,7 @@ func compositeScore(sr *types.SearchResult, modelScore, baseScore float64, chatM
return composite
}
// applyMMR applies the MMR algorithm to the search results
// applyMMR applies the MMR algorithm to the search results with pre-computed token sets
func applyMMR(
ctx context.Context,
results []*types.SearchResult,
@@ -240,40 +252,60 @@ func applyMMR(
"k": k,
"candidates": len(results),
})
selected := make([]*types.SearchResult, 0, k)
candidates := make([]*types.SearchResult, len(results))
copy(candidates, results)
tokenSets := make([]map[string]struct{}, len(candidates))
for i, r := range candidates {
tokenSets[i] = tokenizeSimple(getEnrichedPassage(ctx, r))
// Pre-compute all token sets upfront (optimization)
allTokenSets := make([]map[string]struct{}, len(results))
for i, r := range results {
allTokenSets[i] = searchutil.TokenizeSimple(getEnrichedPassage(ctx, r))
}
for len(selected) < k && len(candidates) > 0 {
bestIdx := 0
selected := make([]*types.SearchResult, 0, k)
selectedTokenSets := make([]map[string]struct{}, 0, k)
selectedIndices := make(map[int]struct{})
for len(selected) < k && len(selectedIndices) < len(results) {
bestIdx := -1
bestScore := -1.0
for i, r := range candidates {
for i, r := range results {
if _, isSelected := selectedIndices[i]; isSelected {
continue
}
relevance := r.Score
redundancy := 0.0
for _, s := range selected {
redundancy = math.Max(redundancy, jaccard(tokenSets[i], tokenizeSimple(getEnrichedPassage(ctx, s))))
// Use pre-computed token sets for redundancy calculation
for _, selTokens := range selectedTokenSets {
sim := searchutil.Jaccard(allTokenSets[i], selTokens)
if sim > redundancy {
redundancy = sim
}
}
mmr := lambda*relevance - (1.0-lambda)*redundancy
if mmr > bestScore {
bestScore = mmr
bestIdx = i
}
}
selected = append(selected, candidates[bestIdx])
candidates = append(candidates[:bestIdx], candidates[bestIdx+1:]...)
if bestIdx < 0 {
break
}
// Compute average redundancy among selected
selected = append(selected, results[bestIdx])
selectedTokenSets = append(selectedTokenSets, allTokenSets[bestIdx])
selectedIndices[bestIdx] = struct{}{}
}
// Compute average redundancy among selected using pre-computed token sets
avgRed := 0.0
if len(selected) > 1 {
pairs := 0
for i := 0; i < len(selected); i++ {
for j := i + 1; j < len(selected); j++ {
si := tokenizeSimple(getEnrichedPassage(ctx, selected[i]))
sj := tokenizeSimple(getEnrichedPassage(ctx, selected[j]))
avgRed += jaccard(si, sj)
for i := 0; i < len(selectedTokenSets); i++ {
for j := i + 1; j < len(selectedTokenSets); j++ {
avgRed += searchutil.Jaccard(selectedTokenSets[i], selectedTokenSets[j])
pairs++
}
}
@@ -288,93 +320,59 @@ func applyMMR(
return selected
}
// tokenizeSimple tokenizes a text into a set of tokens
func tokenizeSimple(text string) map[string]struct{} {
text = strings.ToLower(text)
fields := strings.Fields(text)
set := make(map[string]struct{}, len(fields))
for _, f := range fields {
if len(f) > 1 {
set[f] = struct{}{}
}
}
return set
}
// jaccard calculates the Jaccard similarity between two sets of tokens
func jaccard(a, b map[string]struct{}) float64 {
if len(a) == 0 && len(b) == 0 {
return 0
}
inter := 0
for k := range a {
if _, ok := b[k]; ok {
inter++
}
}
union := len(a) + len(b) - inter
if union == 0 {
return 0
}
return float64(inter) / float64(union)
}
// clampFloat clamps a float value between a minimum and maximum value
func clampFloat(v, minV, maxV float64) float64 {
if v < minV {
return minV
}
if v > maxV {
return maxV
}
return v
}
// getEnrichedPassage 合并Content和ImageInfo的文本内容
// getEnrichedPassage 合并Content、ImageInfo和GeneratedQuestions的文本内容
func getEnrichedPassage(ctx context.Context, result *types.SearchResult) string {
if result.ImageInfo == "" {
return result.Content
}
combinedText := result.Content
var enrichments []string
// 解析ImageInfo
if result.ImageInfo != "" {
var imageInfos []types.ImageInfo
err := json.Unmarshal([]byte(result.ImageInfo), &imageInfos)
if err != nil {
pipelineWarn(ctx, "Rerank", "image_info_parse", map[string]interface{}{
"error": err.Error(),
})
return result.Content
}
if len(imageInfos) == 0 {
return result.Content
}
} else {
// 提取所有图片的描述和OCR文本
var imageTexts []string
for _, img := range imageInfos {
if img.Caption != "" {
imageTexts = append(imageTexts, fmt.Sprintf("图片描述: %s", img.Caption))
enrichments = append(enrichments, fmt.Sprintf("图片描述: %s", img.Caption))
}
if img.OCRText != "" {
imageTexts = append(imageTexts, fmt.Sprintf("图片文本: %s", img.OCRText))
enrichments = append(enrichments, fmt.Sprintf("图片文本: %s", img.OCRText))
}
}
}
}
if len(imageTexts) == 0 {
return result.Content
// 解析ChunkMetadata中的GeneratedQuestions
if len(result.ChunkMetadata) > 0 {
var docMeta types.DocumentChunkMetadata
err := json.Unmarshal(result.ChunkMetadata, &docMeta)
if err != nil {
pipelineWarn(ctx, "Rerank", "chunk_metadata_parse", map[string]interface{}{
"error": err.Error(),
})
} else if len(docMeta.GeneratedQuestions) > 0 {
enrichments = append(enrichments, fmt.Sprintf("相关问题: %s", strings.Join(docMeta.GeneratedQuestions, "; ")))
}
}
// 组合内容和图片信息
combinedText := result.Content
if len(enrichments) == 0 {
return combinedText
}
// 组合内容和增强信息
if combinedText != "" {
combinedText += "\n\n"
}
combinedText += strings.Join(imageTexts, "\n")
combinedText += strings.Join(enrichments, "\n")
pipelineInfo(ctx, "Rerank", "image_info_merge", map[string]interface{}{
pipelineInfo(ctx, "Rerank", "passage_enrich", map[string]interface{}{
"content_len": len(result.Content),
"image_len": len(strings.Join(imageTexts, "\n")),
"enrichment": strings.Join(enrichments, "\n"),
"enrichment_len": len(strings.Join(enrichments, "\n")),
})
return combinedText

View File

@@ -2,13 +2,14 @@ package chatpipline
import (
"context"
"encoding/json"
"fmt"
"regexp"
"strings"
"sync"
"unicode"
"github.com/Tencent/WeKnora/internal/config"
"github.com/Tencent/WeKnora/internal/models/chat"
"github.com/Tencent/WeKnora/internal/logger"
"github.com/Tencent/WeKnora/internal/searchutil"
"github.com/Tencent/WeKnora/internal/types"
"github.com/Tencent/WeKnora/internal/types/interfaces"
@@ -18,7 +19,6 @@ import (
type PluginSearch struct {
knowledgeBaseService interfaces.KnowledgeBaseService
knowledgeService interfaces.KnowledgeService
modelService interfaces.ModelService
config *config.Config
webSearchService interfaces.WebSearchService
tenantService interfaces.TenantService
@@ -28,7 +28,6 @@ type PluginSearch struct {
func NewPluginSearch(eventManager *EventManager,
knowledgeBaseService interfaces.KnowledgeBaseService,
knowledgeService interfaces.KnowledgeService,
modelService interfaces.ModelService,
config *config.Config,
webSearchService interfaces.WebSearchService,
tenantService interfaces.TenantService,
@@ -37,7 +36,6 @@ func NewPluginSearch(eventManager *EventManager,
res := &PluginSearch{
knowledgeBaseService: knowledgeBaseService,
knowledgeService: knowledgeService,
modelService: modelService,
config: config,
webSearchService: webSearchService,
tenantService: tenantService,
@@ -77,7 +75,6 @@ func (p *PluginSearch) OnEvent(ctx context.Context,
pipelineInfo(ctx, "Search", "input", map[string]interface{}{
"session_id": chatManage.SessionID,
"rewrite_query": chatManage.RewriteQuery,
"processed_query": chatManage.ProcessedQuery,
"kb_ids": strings.Join(knowledgeBaseIDs, ","),
"tenant_id": chatManage.TenantID,
"web_enabled": chatManage.WebSearchEnabled,
@@ -121,6 +118,16 @@ func (p *PluginSearch) OnEvent(ctx context.Context,
chatManage.SearchResult = allResults
// Log all search results with scores before any processing
for i, r := range chatManage.SearchResult {
pipelineInfo(ctx, "Search", "result_score_before_normalize", map[string]interface{}{
"index": i,
"chunk_id": r.ID,
"score": fmt.Sprintf("%.4f", r.Score),
"match_type": r.MatchType,
})
}
// If recall is low, attempt query expansion with keyword-focused search
if chatManage.EnableQueryExpansion && len(chatManage.SearchResult) < max(1, chatManage.EmbeddingTopK/2) {
pipelineInfo(ctx, "Search", "recall_low", map[string]interface{}{
@@ -189,6 +196,7 @@ func (p *PluginSearch) OnEvent(ctx context.Context,
}
wgExp.Wait()
if len(expResults) > 0 {
// Scores already normalized in HybridSearch
pipelineInfo(ctx, "Search", "expansion_done", map[string]interface{}{
"added": len(expResults),
})
@@ -215,6 +223,16 @@ func (p *PluginSearch) OnEvent(ctx context.Context,
"after": len(chatManage.SearchResult),
})
// Log final scores after all processing
for i, r := range chatManage.SearchResult {
pipelineInfo(ctx, "Search", "final_score", map[string]interface{}{
"index": i,
"chunk_id": r.ID,
"score": fmt.Sprintf("%.4f", r.Score),
"match_type": r.MatchType,
})
}
// Return if we have results
if len(chatManage.SearchResult) != 0 {
pipelineInfo(ctx, "Search", "output", map[string]interface{}{
@@ -250,32 +268,33 @@ func (p *PluginSearch) getSearchResultFromHistory(chatManage *types.ChatManage)
func removeDuplicateResults(results []*types.SearchResult) []*types.SearchResult {
seen := make(map[string]bool)
contentSig := make(map[string]bool)
contentSig := make(map[string]string) // sig -> first chunk ID
var uniqueResults []*types.SearchResult
for _, r := range results {
keys := []string{r.ID}
if r.ParentChunkID != "" {
keys = append(keys, "parent:"+r.ParentChunkID)
}
if r.KnowledgeID != "" {
keys = append(keys, fmt.Sprintf("kb:%s#%d", r.KnowledgeID, r.ChunkIndex))
}
dup := false
dupKey := ""
for _, k := range keys {
if seen[k] {
dup = true
dupKey = k
break
}
}
if dup {
logger.Debugf(context.Background(), "Dedup: chunk %s removed due to key: %s", r.ID, dupKey)
continue
}
sig := buildContentSignature(r.Content)
if sig != "" {
if contentSig[sig] {
if firstChunk, exists := contentSig[sig]; exists {
logger.Debugf(context.Background(), "Dedup: chunk %s removed due to content signature (dup of %s, sig prefix: %.50s...)", r.ID, firstChunk, sig)
continue
}
contentSig[sig] = true
contentSig[sig] = r.ID
}
for _, k := range keys {
seen[k] = true
@@ -286,24 +305,16 @@ func removeDuplicateResults(results []*types.SearchResult) []*types.SearchResult
}
func buildContentSignature(content string) string {
c := strings.ToLower(strings.TrimSpace(content))
if c == "" {
return ""
}
c = strings.Join(strings.Fields(c), " ")
if len(c) > 128 {
c = c[:128]
}
return c
return searchutil.BuildContentSignature(content)
}
// searchKnowledgeBases performs KB searches for rewrite and processed queries across KB IDs
// searchKnowledgeBases performs KB searches across KB IDs using RewriteQuery only
func (p *PluginSearch) searchKnowledgeBases(
ctx context.Context,
knowledgeBaseIDs []string,
chatManage *types.ChatManage,
) []*types.SearchResult {
// Build base params for rewrite query
// Build params for rewrite query
baseParams := types.SearchParams{
QueryText: strings.TrimSpace(chatManage.RewriteQuery),
VectorThreshold: chatManage.VectorThreshold,
@@ -315,7 +326,7 @@ func (p *PluginSearch) searchKnowledgeBases(
var mu sync.Mutex
var results []*types.SearchResult
// Search with rewrite query
// Search with rewrite query only (removed duplicate ProcessedQuery search)
for _, kbID := range knowledgeBaseIDs {
wg.Add(1)
go func(knowledgeBaseID string) {
@@ -326,13 +337,11 @@ func (p *PluginSearch) searchKnowledgeBases(
"kb_id": knowledgeBaseID,
"query": baseParams.QueryText,
"error": err.Error(),
"query_ty": "rewrite",
})
return
}
pipelineInfo(ctx, "Search", "kb_result", map[string]interface{}{
"kb_id": knowledgeBaseID,
"query_ty": "rewrite",
"hit_count": len(res),
})
mu.Lock()
@@ -343,45 +352,6 @@ func (p *PluginSearch) searchKnowledgeBases(
wg.Wait()
// If processed query differs, search again
if chatManage.RewriteQuery != chatManage.ProcessedQuery {
paramsProcessed := baseParams
paramsProcessed.QueryText = strings.TrimSpace(chatManage.ProcessedQuery)
pipelineInfo(ctx, "Search", "processed_query_search", map[string]interface{}{
"query": paramsProcessed.QueryText,
})
wg = sync.WaitGroup{}
for _, kbID := range knowledgeBaseIDs {
wg.Add(1)
go func(knowledgeBaseID string) {
defer wg.Done()
res, err := p.knowledgeBaseService.HybridSearch(ctx, knowledgeBaseID, paramsProcessed)
if err != nil {
pipelineWarn(ctx, "Search", "kb_search_error", map[string]interface{}{
"kb_id": knowledgeBaseID,
"query": paramsProcessed.QueryText,
"error": err.Error(),
"query_ty": "processed",
})
return
}
pipelineInfo(ctx, "Search", "kb_result", map[string]interface{}{
"kb_id": knowledgeBaseID,
"query_ty": "processed",
"hit_count": len(res),
})
mu.Lock()
results = append(results, res...)
mu.Unlock()
}(kbID)
}
wg.Wait()
}
// Normalize keyword retriever scores after collecting all results from multiple knowledge bases
normalizeKeywordSearchResults(ctx, results)
pipelineInfo(ctx, "Search", "kb_result_summary", map[string]interface{}{
"total_hits": len(results),
})
@@ -413,11 +383,8 @@ func (p *PluginSearch) searchWebIfEnabled(ctx context.Context, chatManage *types
})
return nil
}
// Build questions (rewrite + processed if different)
// Build questions using RewriteQuery only
questions := []string{strings.TrimSpace(chatManage.RewriteQuery)}
if chatManage.ProcessedQuery != "" && chatManage.ProcessedQuery != chatManage.RewriteQuery {
questions = append(questions, strings.TrimSpace(chatManage.ProcessedQuery))
}
// Load session-scoped temp KB state from Redis using SessionService
tempKBID, seen, ids := p.sessionService.GetWebSearchTempKBState(ctx, chatManage.SessionID)
compressed, kbID, newSeen, newIDs, err := p.webSearchService.CompressWithRAG(
@@ -440,119 +407,155 @@ func (p *PluginSearch) searchWebIfEnabled(ctx context.Context, chatManage *types
return res
}
// expandQueries generates paraphrases and synonyms using chat model to improve keyword recall
// expandQueries generates query variants locally without LLM to improve keyword recall
// Uses simple techniques: word reordering, stopword removal, key phrase extraction
func (p *PluginSearch) expandQueries(ctx context.Context, chatManage *types.ChatManage) []string {
if p.modelService == nil || chatManage.ChatModelID == "" {
pipelineWarn(ctx, "Search", "expansion_skip", map[string]interface{}{
"reason": "no_model",
})
query := strings.TrimSpace(chatManage.RewriteQuery)
if query == "" {
return nil
}
model, err := p.modelService.GetChatModel(ctx, chatManage.ChatModelID)
if err != nil {
pipelineWarn(ctx, "Search", "expansion_get_model_failed", map[string]interface{}{
"error": err.Error(),
})
return nil
expansions := make([]string, 0, 5)
seen := make(map[string]struct{})
seen[strings.ToLower(query)] = struct{}{}
if q := strings.ToLower(chatManage.Query); q != "" {
seen[q] = struct{}{}
}
sys := "Generate up to 5 diverse paraphrases or keyword variants for the user query to improve keyword-based search recall. Respond ONLY with a JSON array of strings inside a fenced code block."
usr := chatManage.RewriteQuery
think := false
resp, err := model.Chat(ctx, []chat.Message{
{Role: "system", Content: sys},
{Role: "user", Content: usr},
}, &chat.ChatOptions{Temperature: 0.2, MaxCompletionTokens: 200, Thinking: &think})
if err != nil || resp.Content == "" {
pipelineWarn(ctx, "Search", "expansion_model_call_failed", map[string]interface{}{
"error": err,
})
return nil
}
body := extractJSONBlock(resp.Content)
var arr []string
if err := json.Unmarshal([]byte(body), &arr); err != nil || len(arr) == 0 {
// Fallback: split lines
lines := strings.Split(resp.Content, "\n")
for _, l := range lines {
l = strings.TrimSpace(l)
if l != "" {
arr = append(arr, l)
}
}
}
uniq := make(map[string]struct{})
base := []string{chatManage.Query, chatManage.RewriteQuery, chatManage.ProcessedQuery}
for _, b := range base {
if s := strings.TrimSpace(b); s != "" {
uniq[strings.ToLower(s)] = struct{}{}
}
}
expansions := make([]string, 0, len(arr))
for _, a := range arr {
s := strings.TrimSpace(a)
if s == "" {
continue
addIfNew := func(s string) {
s = strings.TrimSpace(s)
if s == "" || len(s) < 3 {
return
}
key := strings.ToLower(s)
if _, ok := uniq[key]; ok {
continue
if _, ok := seen[key]; ok {
return
}
uniq[key] = struct{}{}
seen[key] = struct{}{}
expansions = append(expansions, s)
if len(expansions) >= 5 {
break
}
// 1. Remove common stopwords and create keyword-only variant
keywords := extractKeywords(query)
if len(keywords) >= 2 {
addIfNew(strings.Join(keywords, " "))
}
// 2. Extract quoted phrases or key segments
phrases := extractPhrases(query)
for _, phrase := range phrases {
addIfNew(phrase)
}
// 3. Split by common delimiters and use longest segment
segments := splitByDelimiters(query)
for _, seg := range segments {
if len(seg) > 5 {
addIfNew(seg)
}
}
pipelineInfo(ctx, "Search", "expansion_result", map[string]interface{}{
// 4. Remove question words (什么/如何/怎么/为什么/哪个 etc.)
cleaned := removeQuestionWords(query)
if cleaned != query {
addIfNew(cleaned)
}
// Limit to 5 expansions
if len(expansions) > 5 {
expansions = expansions[:5]
}
pipelineInfo(ctx, "Search", "local_expansion_result", map[string]interface{}{
"variants": len(expansions),
})
return expansions
}
func extractJSONBlock(text string) string {
t := strings.TrimSpace(text)
if i := strings.Index(t, "["); i >= 0 {
j := strings.LastIndex(t, "]")
if j > i {
return t[i : j+1]
}
}
return "[]"
// Common Chinese and English stopwords
var stopwords = map[string]struct{}{
"的": {}, "是": {}, "在": {}, "了": {}, "": {}, "与": {}, "或": {},
"a": {}, "an": {}, "the": {}, "is": {}, "are": {}, "was": {}, "were": {},
"be": {}, "been": {}, "being": {}, "have": {}, "has": {}, "had": {},
"do": {}, "does": {}, "did": {}, "will": {}, "would": {}, "could": {},
"should": {}, "may": {}, "might": {}, "must": {}, "can": {},
"to": {}, "of": {}, "in": {}, "for": {}, "on": {}, "with": {}, "at": {},
"by": {}, "from": {}, "as": {}, "into": {}, "through": {}, "about": {},
"what": {}, "how": {}, "why": {}, "when": {}, "where": {}, "which": {},
"who": {}, "whom": {}, "whose": {},
}
// normalizeKeywordSearchResults normalizes keyword search result scores into [0,1] globally across all knowledge bases
// Improvements:
// 1. Uses robust normalization with percentile-based bounds to handle outliers
// 2. Handles edge cases: single result, no variance, negative scores
// 3. Global normalization ensures fair comparison across different knowledge bases
func normalizeKeywordSearchResults(ctx context.Context, results []*types.SearchResult) {
searchutil.NormalizeKeywordScores[*types.SearchResult](
results,
func(r *types.SearchResult) bool {
return r.MatchType == types.MatchTypeKeywords
},
func(r *types.SearchResult) float64 {
return r.Score
},
func(r *types.SearchResult, score float64) {
r.Score = score
},
searchutil.KeywordScoreCallbacks{
OnNoVariance: func(count int, score float64) {
pipelineInfo(ctx, "Search", "keyword_scores_no_variance", map[string]interface{}{
"count": count,
"score": score,
})
},
OnNormalized: func(count int, rawMin, rawMax, normalizeMin, normalizeMax float64) {
pipelineInfo(ctx, "Search", "normalize_keyword_scores", map[string]interface{}{
"count": count,
"raw_min": rawMin,
"raw_max": rawMax,
"normalize_min": normalizeMin,
"normalize_max": normalizeMax,
})
},
},
)
// Question words in Chinese
var questionWords = regexp.MustCompile(`^(什么是|什么|如何|怎么|怎样|为什么|为何|哪个|哪些|谁|何时|何地|请问|请告诉我|帮我|我想知道|我想了解)`)
func extractKeywords(text string) []string {
words := tokenize(text)
keywords := make([]string, 0, len(words))
for _, w := range words {
lower := strings.ToLower(w)
if _, isStop := stopwords[lower]; !isStop && len(w) > 1 {
keywords = append(keywords, w)
}
}
return keywords
}
func extractPhrases(text string) []string {
// Extract quoted content
var phrases []string
re := regexp.MustCompile(`["'"'「」『』]([^"'"'「」『』]+)["'"'「」『』]`)
matches := re.FindAllStringSubmatch(text, -1)
for _, m := range matches {
if len(m) > 1 && len(m[1]) > 2 {
phrases = append(phrases, m[1])
}
}
return phrases
}
func splitByDelimiters(text string) []string {
// Split by common delimiters
re := regexp.MustCompile(`[,;;、。!?!?\s]+`)
parts := re.Split(text, -1)
var result []string
for _, p := range parts {
p = strings.TrimSpace(p)
if p != "" {
result = append(result, p)
}
}
return result
}
func removeQuestionWords(text string) string {
return strings.TrimSpace(questionWords.ReplaceAllString(text, ""))
}
func tokenize(text string) []string {
var tokens []string
var current strings.Builder
for _, r := range text {
if unicode.IsLetter(r) || unicode.IsDigit(r) {
current.WriteRune(r)
} else if unicode.Is(unicode.Han, r) {
// Flush current token
if current.Len() > 0 {
tokens = append(tokens, current.String())
current.Reset()
}
// Chinese character as single token
tokens = append(tokens, string(r))
} else {
// Delimiter
if current.Len() > 0 {
tokens = append(tokens, current.String())
current.Reset()
}
}
}
if current.Len() > 0 {
tokens = append(tokens, current.String())
}
return tokens
}

View File

@@ -190,5 +190,6 @@ func chunk2SearchResult(chunk *types.Chunk, knowledge *types.Knowledge) *types.S
ImageInfo: chunk.ImageInfo,
KnowledgeFilename: knowledge.FileName,
KnowledgeSource: knowledge.Source,
ChunkMetadata: chunk.Metadata,
}
}

View File

@@ -0,0 +1,181 @@
package chatpipline
import (
"context"
"sync"
"github.com/Tencent/WeKnora/internal/config"
"github.com/Tencent/WeKnora/internal/logger"
"github.com/Tencent/WeKnora/internal/types"
"github.com/Tencent/WeKnora/internal/types/interfaces"
)
// PluginSearchParallel implements parallel search functionality combining chunk search and entity search
type PluginSearchParallel struct {
// Chunk search dependencies
knowledgeBaseService interfaces.KnowledgeBaseService
knowledgeService interfaces.KnowledgeService
config *config.Config
webSearchService interfaces.WebSearchService
tenantService interfaces.TenantService
sessionService interfaces.SessionService
// Entity search dependencies
graphRepo interfaces.RetrieveGraphRepository
chunkRepo interfaces.ChunkRepository
knowledgeRepo interfaces.KnowledgeRepository
// Internal plugins
searchPlugin *PluginSearch
searchEntityPlugin *PluginSearchEntity
}
// NewPluginSearchParallel creates a new parallel search plugin
func NewPluginSearchParallel(
eventManager *EventManager,
knowledgeBaseService interfaces.KnowledgeBaseService,
knowledgeService interfaces.KnowledgeService,
config *config.Config,
webSearchService interfaces.WebSearchService,
tenantService interfaces.TenantService,
sessionService interfaces.SessionService,
graphRepository interfaces.RetrieveGraphRepository,
chunkRepository interfaces.ChunkRepository,
knowledgeRepository interfaces.KnowledgeRepository,
) *PluginSearchParallel {
// Create internal plugins without registering them
searchPlugin := &PluginSearch{
knowledgeBaseService: knowledgeBaseService,
knowledgeService: knowledgeService,
config: config,
webSearchService: webSearchService,
tenantService: tenantService,
sessionService: sessionService,
}
searchEntityPlugin := &PluginSearchEntity{
graphRepo: graphRepository,
chunkRepo: chunkRepository,
knowledgeRepo: knowledgeRepository,
}
res := &PluginSearchParallel{
knowledgeBaseService: knowledgeBaseService,
knowledgeService: knowledgeService,
config: config,
webSearchService: webSearchService,
tenantService: tenantService,
sessionService: sessionService,
graphRepo: graphRepository,
chunkRepo: chunkRepository,
knowledgeRepo: knowledgeRepository,
searchPlugin: searchPlugin,
searchEntityPlugin: searchEntityPlugin,
}
eventManager.Register(res)
return res
}
// ActivationEvents returns the event types this plugin handles
func (p *PluginSearchParallel) ActivationEvents() []types.EventType {
return []types.EventType{types.CHUNK_SEARCH_PARALLEL}
}
// OnEvent handles parallel search events - runs chunk search and entity search concurrently
func (p *PluginSearchParallel) OnEvent(ctx context.Context,
eventType types.EventType, chatManage *types.ChatManage, next func() *PluginError,
) *PluginError {
pipelineInfo(ctx, "SearchParallel", "start", map[string]interface{}{
"session_id": chatManage.SessionID,
"has_entities": len(chatManage.Entity) > 0,
"rewrite_query": chatManage.RewriteQuery,
})
var wg sync.WaitGroup
var mu sync.Mutex
var chunkSearchErr *PluginError
var entitySearchErr *PluginError
// Use separate ChatManage copies to avoid concurrent write conflicts
chunkChatManage := *chatManage
chunkChatManage.SearchResult = nil
entityChatManage := *chatManage
entityChatManage.SearchResult = nil
// Run chunk search and entity search in parallel
wg.Add(2)
// Goroutine 1: Chunk Search
go func() {
defer wg.Done()
err := p.searchPlugin.OnEvent(ctx, types.CHUNK_SEARCH, &chunkChatManage, func() *PluginError {
return nil
})
if err != nil && err != ErrSearchNothing {
mu.Lock()
chunkSearchErr = err
mu.Unlock()
}
pipelineInfo(ctx, "SearchParallel", "chunk_search_done", map[string]interface{}{
"result_count": len(chunkChatManage.SearchResult),
"has_error": err != nil && err != ErrSearchNothing,
})
}()
// Goroutine 2: Entity Search (only if entities are available)
go func() {
defer wg.Done()
if len(chatManage.Entity) == 0 {
pipelineInfo(ctx, "SearchParallel", "entity_search_skip", map[string]interface{}{
"reason": "no_entities",
})
return
}
err := p.searchEntityPlugin.OnEvent(ctx, types.ENTITY_SEARCH, &entityChatManage, func() *PluginError {
return nil
})
if err != nil && err != ErrSearchNothing {
mu.Lock()
entitySearchErr = err
mu.Unlock()
}
pipelineInfo(ctx, "SearchParallel", "entity_search_done", map[string]interface{}{
"result_count": len(entityChatManage.SearchResult),
"has_error": err != nil && err != ErrSearchNothing,
})
}()
wg.Wait()
// Merge results from both searches (no concurrent access now)
chatManage.SearchResult = append(chunkChatManage.SearchResult, entityChatManage.SearchResult...)
chatManage.SearchResult = removeDuplicateResults(chatManage.SearchResult)
// Log any errors but don't fail the pipeline if at least one search succeeded
if chunkSearchErr != nil {
logger.Warnf(ctx, "[SearchParallel] Chunk search error: %v", chunkSearchErr.Err)
}
if entitySearchErr != nil {
logger.Warnf(ctx, "[SearchParallel] Entity search error: %v", entitySearchErr.Err)
}
pipelineInfo(ctx, "SearchParallel", "complete", map[string]interface{}{
"session_id": chatManage.SessionID,
"chunk_results": len(chunkChatManage.SearchResult),
"entity_results": len(entityChatManage.SearchResult),
"total_results": len(chatManage.SearchResult),
"chunk_search_error": chunkSearchErr != nil,
"entity_search_error": entitySearchErr != nil,
})
// Return error only if both searches failed and we have no results
if len(chatManage.SearchResult) == 0 {
if chunkSearchErr != nil {
return chunkSearchErr
}
return ErrSearchNothing
}
return next()
}

View File

@@ -34,7 +34,7 @@ func (p *PluginTracing) ActivationEvents() []types.EventType {
types.CHAT_COMPLETION_STREAM,
types.FILTER_TOP_K,
types.REWRITE_QUERY,
types.PREPROCESS_QUERY,
types.CHUNK_SEARCH_PARALLEL,
}
}
@@ -69,8 +69,8 @@ func (p *PluginTracing) OnEvent(ctx context.Context,
return p.FilterTopK(ctx, eventType, chatManage, next)
case types.REWRITE_QUERY:
return p.RewriteQuery(ctx, eventType, chatManage, next)
case types.PREPROCESS_QUERY:
return p.PreprocessQuery(ctx, eventType, chatManage, next)
case types.CHUNK_SEARCH_PARALLEL:
return p.SearchParallel(ctx, eventType, chatManage, next)
}
return next()
}
@@ -95,7 +95,6 @@ func (p *PluginTracing) Search(ctx context.Context,
}
span.SetAttributes(
attribute.String("hybrid_search", string(searchResultJson)),
attribute.String("processed_query", chatManage.ProcessedQuery),
attribute.Int("search_unique_count", len(unique)),
)
return err
@@ -119,7 +118,6 @@ func (p *PluginTracing) Rerank(ctx context.Context,
span.SetAttributes(
attribute.Int("rerank_resp_count", len(chatManage.RerankResult)),
attribute.String("rerank_resp_results", string(resultJson)),
attribute.String("query_intent", chatManage.QueryIntent),
)
return err
}
@@ -266,22 +264,20 @@ func (p *PluginTracing) RewriteQuery(ctx context.Context,
return err
}
// PreprocessQuery traces query preprocessing operations
func (p *PluginTracing) PreprocessQuery(ctx context.Context,
// SearchParallel traces parallel search operations (chunk + entity)
func (p *PluginTracing) SearchParallel(ctx context.Context,
eventType types.EventType, chatManage *types.ChatManage, next func() *PluginError,
) *PluginError {
_, span := tracing.ContextWithSpan(ctx, "PluginTracing.PreprocessQuery")
_, span := tracing.ContextWithSpan(ctx, "PluginTracing.SearchParallel")
defer span.End()
span.SetAttributes(
attribute.String("query", chatManage.Query),
attribute.String("rewrite_query", chatManage.RewriteQuery),
attribute.Int("entity_count", len(chatManage.Entity)),
)
err := next()
span.SetAttributes(
attribute.String("processed_query", chatManage.ProcessedQuery),
attribute.Int("search_result_count", len(chatManage.SearchResult)),
)
return err
}

View File

@@ -9,7 +9,6 @@ import (
"time"
"github.com/Tencent/WeKnora/internal/application/service/retriever"
"github.com/Tencent/WeKnora/internal/common"
"github.com/Tencent/WeKnora/internal/logger"
"github.com/Tencent/WeKnora/internal/models/embedding"
"github.com/Tencent/WeKnora/internal/types"
@@ -523,35 +522,113 @@ func (s *knowledgeBaseService) HybridSearch(ctx context.Context,
// Collect all results from different retrievers and deduplicate by chunk ID
logger.Infof(ctx, "Processing retrieval results")
matchResults := []*types.IndexWithScore{}
// Separate results by retriever type for RRF fusion
var vectorResults []*types.IndexWithScore
var keywordResults []*types.IndexWithScore
for _, retrieveResult := range retrieveResults {
logger.Infof(ctx, "Retrieval results, engine: %v, retriever: %v, count: %v",
retrieveResult.RetrieverEngineType,
retrieveResult.RetrieverType,
len(retrieveResult.Results),
)
matchResults = append(matchResults, retrieveResult.Results...)
if retrieveResult.RetrieverType == types.VectorRetrieverType {
vectorResults = append(vectorResults, retrieveResult.Results...)
} else {
keywordResults = append(keywordResults, retrieveResult.Results...)
}
}
// Early return if no results
if len(matchResults) == 0 {
if len(vectorResults) == 0 && len(keywordResults) == 0 {
logger.Info(ctx, "No search results found")
return nil, nil
}
logger.Infof(ctx, "Result count before deduplication: %d", len(matchResults))
logger.Infof(ctx, "Result count before RRF fusion: vector=%d, keyword=%d", len(vectorResults), len(keywordResults))
// First, try standard deduplication
deduplicatedChunks := common.DeduplicateWithScore(
func(r *types.IndexWithScore) string { return r.ChunkID },
matchResults...)
logger.Infof(ctx, "Result count after deduplication: %d", len(deduplicatedChunks))
// Use RRF (Reciprocal Rank Fusion) to merge results
// RRF score = sum(1 / (k + rank)) for each retriever where the chunk appears
// k=60 is a common choice that works well in practice
const rrfK = 60
// Build rank maps for each retriever (already sorted by score from retriever)
vectorRanks := make(map[string]int)
for i, r := range vectorResults {
if _, exists := vectorRanks[r.ChunkID]; !exists {
vectorRanks[r.ChunkID] = i + 1 // 1-indexed rank
}
}
keywordRanks := make(map[string]int)
for i, r := range keywordResults {
if _, exists := keywordRanks[r.ChunkID]; !exists {
keywordRanks[r.ChunkID] = i + 1 // 1-indexed rank
}
}
// Collect all unique chunks and compute RRF scores
chunkInfoMap := make(map[string]*types.IndexWithScore)
rrfScores := make(map[string]float64)
// Process vector results
for _, r := range vectorResults {
if _, exists := chunkInfoMap[r.ChunkID]; !exists {
chunkInfoMap[r.ChunkID] = r
}
}
// Process keyword results
for _, r := range keywordResults {
if _, exists := chunkInfoMap[r.ChunkID]; !exists {
chunkInfoMap[r.ChunkID] = r
}
}
// Compute RRF scores
for chunkID := range chunkInfoMap {
rrfScore := 0.0
if rank, ok := vectorRanks[chunkID]; ok {
rrfScore += 1.0 / float64(rrfK+rank)
}
if rank, ok := keywordRanks[chunkID]; ok {
rrfScore += 1.0 / float64(rrfK+rank)
}
rrfScores[chunkID] = rrfScore
}
// Convert to slice and sort by RRF score
deduplicatedChunks := make([]*types.IndexWithScore, 0, len(chunkInfoMap))
for chunkID, info := range chunkInfoMap {
// Store RRF score in the Score field for downstream processing
info.Score = rrfScores[chunkID]
deduplicatedChunks = append(deduplicatedChunks, info)
}
slices.SortFunc(deduplicatedChunks, func(a, b *types.IndexWithScore) int {
if a.Score > b.Score {
return -1
} else if a.Score < b.Score {
return 1
}
return 0
})
logger.Infof(ctx, "Result count after RRF fusion: %d", len(deduplicatedChunks))
// Log top results after RRF fusion for debugging
for i, chunk := range deduplicatedChunks {
if i < 15 {
vRank, vOk := vectorRanks[chunk.ChunkID]
kRank, kOk := keywordRanks[chunk.ChunkID]
logger.Debugf(ctx, "RRF rank %d: chunk_id=%s, rrf_score=%.6f, vector_rank=%v(%v), keyword_rank=%v(%v)",
i, chunk.ChunkID, chunk.Score, vRank, vOk, kRank, kOk)
}
}
kb.EnsureDefaults()
// Check if we need iterative retrieval for FAQ with separate indexing
// Only use iterative retrieval if we don't have enough unique chunks after first deduplication
totalRetrieved := len(vectorResults) + len(keywordResults)
needsIterativeRetrieval := len(deduplicatedChunks) < params.MatchCount &&
kb.Type == types.KnowledgeBaseTypeFAQ && len(matchResults) == matchCount
kb.Type == types.KnowledgeBaseTypeFAQ && totalRetrieved == matchCount*2
if needsIterativeRetrieval {
logger.Info(ctx, "Not enough unique chunks, using iterative retrieval for FAQ")
// Use iterative retrieval to get more unique chunks (with negative question filtering inside)
@@ -891,10 +968,38 @@ func (s *knowledgeBaseService) processSearchResults(ctx context.Context,
}
}
// Build final search results
// Build final search results - preserve original order from input chunks
var searchResults []*types.SearchResult
for chunkID, chunk := range chunkMap {
addedChunkIDs := make(map[string]bool)
// First pass: Add results in the original order from input chunks
for _, inputChunk := range chunks {
chunk, exists := chunkMap[inputChunk.ChunkID]
if !exists {
logger.Debugf(ctx, "Chunk not found in chunkMap: %s", inputChunk.ChunkID)
continue
}
if !s.isValidTextChunk(chunk) {
logger.Debugf(ctx, "Chunk is not valid text chunk: %s, type: %s", chunk.ID, chunk.ChunkType)
continue
}
if addedChunkIDs[chunk.ID] {
continue
}
score := chunkScores[chunk.ID]
if knowledge, ok := knowledgeMap[chunk.KnowledgeID]; ok {
matchType := chunkMatchTypes[chunk.ID]
searchResults = append(searchResults, s.buildSearchResult(chunk, knowledge, score, matchType))
addedChunkIDs[chunk.ID] = true
} else {
logger.Warnf(ctx, "Knowledge not found for chunk: %s, knowledge_id: %s", chunk.ID, chunk.KnowledgeID)
}
}
// Second pass: Add additional chunks (parent, nearby, relation) that weren't in original input
for chunkID, chunk := range chunkMap {
if addedChunkIDs[chunkID] || !s.isValidTextChunk(chunk) {
continue
}
@@ -959,6 +1064,7 @@ func (s *knowledgeBaseService) buildSearchResult(chunk *types.Chunk,
ImageInfo: chunk.ImageInfo,
KnowledgeFilename: knowledge.FileName,
KnowledgeSource: knowledge.Source,
ChunkMetadata: chunk.Metadata,
}
}

View File

@@ -800,7 +800,6 @@ func (s *sessionService) SearchKnowledge(ctx context.Context,
// Use specific event list, only including retrieval-related events, not LLM summarization
searchEvents := []types.EventType{
types.PREPROCESS_QUERY, // Preprocess query
types.CHUNK_SEARCH, // Vector search
types.CHUNK_RERANK, // Rerank search results
types.CHUNK_MERGE, // Merge search results

View File

@@ -140,10 +140,10 @@ func BuildContainer(container *dig.Container) *dig.Container {
must(container.Invoke(chatpipline.NewPluginChatCompletionStream))
must(container.Invoke(chatpipline.NewPluginStreamFilter))
must(container.Invoke(chatpipline.NewPluginFilterTopK))
must(container.Invoke(chatpipline.NewPluginPreprocess))
must(container.Invoke(chatpipline.NewPluginRewrite))
must(container.Invoke(chatpipline.NewPluginExtractEntity))
must(container.Invoke(chatpipline.NewPluginSearchEntity))
must(container.Invoke(chatpipline.NewPluginSearchParallel))
// HTTP handlers layer
must(container.Provide(handler.NewTenantHandler))

View File

@@ -5,7 +5,6 @@ package event
// QueryData represents query-related event data
type QueryData struct {
OriginalQuery string `json:"original_query"`
ProcessedQuery string `json:"processed_query,omitempty"`
RewrittenQuery string `json:"rewritten_query,omitempty"`
SessionID string `json:"session_id"`
UserID string `json:"user_id,omitempty"`

View File

@@ -0,0 +1,70 @@
package searchutil
import (
"crypto/md5"
"encoding/hex"
"strings"
)
// BuildContentSignature creates a normalized MD5 signature for content to detect duplicates.
// It normalizes the content by lowercasing, trimming whitespace, and collapsing multiple spaces.
func BuildContentSignature(content string) string {
c := strings.ToLower(strings.TrimSpace(content))
if c == "" {
return ""
}
// Normalize whitespace
c = strings.Join(strings.Fields(c), " ")
// Use MD5 hash of full content
hash := md5.Sum([]byte(c))
return hex.EncodeToString(hash[:])
}
// TokenizeSimple tokenizes text into a set of words (simple whitespace-based).
// Returns a map where keys are lowercase tokens with length > 1.
func TokenizeSimple(text string) map[string]struct{} {
text = strings.ToLower(text)
fields := strings.Fields(text)
set := make(map[string]struct{}, len(fields))
for _, f := range fields {
if len(f) > 1 {
set[f] = struct{}{}
}
}
return set
}
// Jaccard calculates Jaccard similarity between two token sets.
// Returns a value between 0 and 1, where 1 means identical sets.
func Jaccard(a, b map[string]struct{}) float64 {
if len(a) == 0 && len(b) == 0 {
return 0
}
// Calculate intersection
inter := 0
for k := range a {
if _, ok := b[k]; ok {
inter++
}
}
// Calculate union
union := len(a) + len(b) - inter
if union == 0 {
return 0
}
return float64(inter) / float64(union)
}
// ClampFloat clamps a float value to the specified range [minV, maxV].
func ClampFloat(v, minV, maxV float64) float64 {
if v < minV {
return minV
}
if v > maxV {
return maxV
}
return v
}

View File

@@ -5,9 +5,7 @@ package types
type ChatManage struct {
SessionID string `json:"session_id"` // Unique identifier for the chat session
Query string `json:"query,omitempty"` // Original user query
ProcessedQuery string `json:"processed_query,omitempty"` // Query after preprocessing
RewriteQuery string `json:"rewrite_query,omitempty"` // Query after rewriting for better retrieval
QueryIntent string `json:"query_intent,omitempty"` // Parsed intent: definition/howto/compare/qa/general
History []*History `json:"history,omitempty"` // Chat history for context
KnowledgeBaseID string `json:"knowledge_base_id"` // ID of the knowledge base to search against (deprecated, use KnowledgeBaseIDs)
@@ -60,9 +58,7 @@ func (c *ChatManage) Clone() *ChatManage {
return &ChatManage{
Query: c.Query,
ProcessedQuery: c.ProcessedQuery,
RewriteQuery: c.RewriteQuery,
QueryIntent: c.QueryIntent,
SessionID: c.SessionID,
KnowledgeBaseID: c.KnowledgeBaseID,
KnowledgeBaseIDs: knowledgeBaseIDs,
@@ -103,9 +99,9 @@ func (c *ChatManage) Clone() *ChatManage {
type EventType string
const (
PREPROCESS_QUERY EventType = "preprocess_query" // Query preprocessing stage
REWRITE_QUERY EventType = "rewrite_query" // Query rewriting for better retrieval
CHUNK_SEARCH EventType = "chunk_search" // Search for relevant chunks
CHUNK_SEARCH_PARALLEL EventType = "chunk_search_parallel" // Parallel search: chunks + entities
ENTITY_SEARCH EventType = "entity_search" // Search for relevant entities
CHUNK_RERANK EventType = "chunk_rerank" // Rerank search results
CHUNK_MERGE EventType = "chunk_merge" // Merge similar chunks
@@ -134,9 +130,7 @@ var Pipline = map[string][]EventType{
},
"rag_stream": { // Streaming Retrieval Augmented Generation
REWRITE_QUERY,
PREPROCESS_QUERY,
CHUNK_SEARCH,
ENTITY_SEARCH,
CHUNK_SEARCH_PARALLEL, // Parallel: CHUNK_SEARCH + ENTITY_SEARCH
CHUNK_RERANK,
CHUNK_MERGE,
FILTER_TOP_K,

View File

@@ -46,6 +46,9 @@ type SearchResult struct {
// Knowledge source
// Used to indicate the source of the knowledge, such as "url"
KnowledgeSource string `json:"knowledge_source"`
// ChunkMetadata stores chunk-level metadata (e.g., generated questions)
ChunkMetadata JSON `json:"chunk_metadata,omitempty"`
}
// SearchParams represents the search parameters