mirror of
https://github.com/Tencent/WeKnora.git
synced 2026-06-04 13:30:32 +08:00
feat: Implement parallel search functionality combining chunk and entity searches, enhancing retrieval efficiency and result accuracy
This commit is contained in:
6
go.mod
6
go.mod
@@ -29,7 +29,6 @@ require (
|
|||||||
github.com/spf13/viper v1.20.1
|
github.com/spf13/viper v1.20.1
|
||||||
github.com/stretchr/testify v1.11.1
|
github.com/stretchr/testify v1.11.1
|
||||||
github.com/tencentyun/cos-go-sdk-v5 v0.7.65
|
github.com/tencentyun/cos-go-sdk-v5 v0.7.65
|
||||||
github.com/xuri/excelize/v2 v2.10.0
|
|
||||||
github.com/yanyiwu/gojieba v1.4.5
|
github.com/yanyiwu/gojieba v1.4.5
|
||||||
go.opentelemetry.io/otel v1.37.0
|
go.opentelemetry.io/otel v1.37.0
|
||||||
go.opentelemetry.io/otel/exporters/otlp/otlptrace v1.37.0
|
go.opentelemetry.io/otel/exporters/otlp/otlptrace v1.37.0
|
||||||
@@ -106,8 +105,6 @@ require (
|
|||||||
github.com/pelletier/go-toml/v2 v2.2.3 // indirect
|
github.com/pelletier/go-toml/v2 v2.2.3 // indirect
|
||||||
github.com/pierrec/lz4/v4 v4.1.21 // indirect
|
github.com/pierrec/lz4/v4 v4.1.21 // indirect
|
||||||
github.com/pmezard/go-difflib v1.0.0 // indirect
|
github.com/pmezard/go-difflib v1.0.0 // indirect
|
||||||
github.com/richardlehane/mscfb v1.0.4 // indirect
|
|
||||||
github.com/richardlehane/msoleps v1.0.4 // indirect
|
|
||||||
github.com/rivo/uniseg v0.4.7 // indirect
|
github.com/rivo/uniseg v0.4.7 // indirect
|
||||||
github.com/robfig/cron/v3 v3.0.1 // indirect
|
github.com/robfig/cron/v3 v3.0.1 // indirect
|
||||||
github.com/rs/xid v1.6.0 // indirect
|
github.com/rs/xid v1.6.0 // indirect
|
||||||
@@ -117,12 +114,9 @@ require (
|
|||||||
github.com/spf13/cast v1.10.0 // indirect
|
github.com/spf13/cast v1.10.0 // indirect
|
||||||
github.com/spf13/pflag v1.0.6 // indirect
|
github.com/spf13/pflag v1.0.6 // indirect
|
||||||
github.com/subosito/gotenv v1.6.0 // indirect
|
github.com/subosito/gotenv v1.6.0 // indirect
|
||||||
github.com/tiendc/go-deepcopy v1.7.1 // indirect
|
|
||||||
github.com/twitchyliquid64/golang-asm v0.15.1 // indirect
|
github.com/twitchyliquid64/golang-asm v0.15.1 // indirect
|
||||||
github.com/ugorji/go/codec v1.2.12 // indirect
|
github.com/ugorji/go/codec v1.2.12 // indirect
|
||||||
github.com/wk8/go-ordered-map/v2 v2.1.8 // indirect
|
github.com/wk8/go-ordered-map/v2 v2.1.8 // indirect
|
||||||
github.com/xuri/efp v0.0.1 // indirect
|
|
||||||
github.com/xuri/nfp v0.0.2-0.20250530014748-2ddeb826f9a9 // indirect
|
|
||||||
github.com/yosida95/uritemplate/v3 v3.0.2 // indirect
|
github.com/yosida95/uritemplate/v3 v3.0.2 // indirect
|
||||||
go.opentelemetry.io/auto/sdk v1.1.0 // indirect
|
go.opentelemetry.io/auto/sdk v1.1.0 // indirect
|
||||||
go.opentelemetry.io/otel/metric v1.37.0 // indirect
|
go.opentelemetry.io/otel/metric v1.37.0 // indirect
|
||||||
|
|||||||
15
go.sum
15
go.sum
@@ -235,11 +235,6 @@ github.com/pmezard/go-difflib v1.0.0 h1:4DBwDE0NGyQoBHbLQYPwSUPoCMWR5BEzIk/f1lZb
|
|||||||
github.com/pmezard/go-difflib v1.0.0/go.mod h1:iKH77koFhYxTK1pcRnkKkqfTogsbg7gZNVY4sRDYZ/4=
|
github.com/pmezard/go-difflib v1.0.0/go.mod h1:iKH77koFhYxTK1pcRnkKkqfTogsbg7gZNVY4sRDYZ/4=
|
||||||
github.com/redis/go-redis/v9 v9.14.0 h1:u4tNCjXOyzfgeLN+vAZaW1xUooqWDqVEsZN0U01jfAE=
|
github.com/redis/go-redis/v9 v9.14.0 h1:u4tNCjXOyzfgeLN+vAZaW1xUooqWDqVEsZN0U01jfAE=
|
||||||
github.com/redis/go-redis/v9 v9.14.0/go.mod h1:huWgSWd8mW6+m0VPhJjSSQ+d6Nh1VICQ6Q5lHuCH/Iw=
|
github.com/redis/go-redis/v9 v9.14.0/go.mod h1:huWgSWd8mW6+m0VPhJjSSQ+d6Nh1VICQ6Q5lHuCH/Iw=
|
||||||
github.com/richardlehane/mscfb v1.0.4 h1:WULscsljNPConisD5hR0+OyZjwK46Pfyr6mPu5ZawpM=
|
|
||||||
github.com/richardlehane/mscfb v1.0.4/go.mod h1:YzVpcZg9czvAuhk9T+a3avCpcFPMUWm7gK3DypaEsUk=
|
|
||||||
github.com/richardlehane/msoleps v1.0.1/go.mod h1:BWev5JBpU9Ko2WAgmZEuiz4/u3ZYTKbjLycmwiWUfWg=
|
|
||||||
github.com/richardlehane/msoleps v1.0.4 h1:WuESlvhX3gH2IHcd8UqyCuFY5yiq/GR/yqaSM/9/g00=
|
|
||||||
github.com/richardlehane/msoleps v1.0.4/go.mod h1:BWev5JBpU9Ko2WAgmZEuiz4/u3ZYTKbjLycmwiWUfWg=
|
|
||||||
github.com/rivo/uniseg v0.2.0/go.mod h1:J6wj4VEh+S6ZtnVlnTBMWIodfgj8LQOQFoIToxlJtxc=
|
github.com/rivo/uniseg v0.2.0/go.mod h1:J6wj4VEh+S6ZtnVlnTBMWIodfgj8LQOQFoIToxlJtxc=
|
||||||
github.com/rivo/uniseg v0.4.7 h1:WUdvkW8uEhrYfLC4ZzdpI2ztxP1I582+49Oc5Mq64VQ=
|
github.com/rivo/uniseg v0.4.7 h1:WUdvkW8uEhrYfLC4ZzdpI2ztxP1I582+49Oc5Mq64VQ=
|
||||||
github.com/rivo/uniseg v0.4.7/go.mod h1:FN3SvrM+Zdj16jyLfmOkMNblXMcoc8DfTHruCPUcx88=
|
github.com/rivo/uniseg v0.4.7/go.mod h1:FN3SvrM+Zdj16jyLfmOkMNblXMcoc8DfTHruCPUcx88=
|
||||||
@@ -284,8 +279,6 @@ github.com/tencentcloud/tencentcloud-sdk-go/tencentcloud/common v1.0.563/go.mod
|
|||||||
github.com/tencentcloud/tencentcloud-sdk-go/tencentcloud/kms v1.0.563/go.mod h1:uom4Nvi9W+Qkom0exYiJ9VWJjXwyxtPYTkKkaLMlfE0=
|
github.com/tencentcloud/tencentcloud-sdk-go/tencentcloud/kms v1.0.563/go.mod h1:uom4Nvi9W+Qkom0exYiJ9VWJjXwyxtPYTkKkaLMlfE0=
|
||||||
github.com/tencentyun/cos-go-sdk-v5 v0.7.65 h1:+WBbfwThfZSbxpf1Dw6fyMwyzVtWBBExqfDJ5giiR2s=
|
github.com/tencentyun/cos-go-sdk-v5 v0.7.65 h1:+WBbfwThfZSbxpf1Dw6fyMwyzVtWBBExqfDJ5giiR2s=
|
||||||
github.com/tencentyun/cos-go-sdk-v5 v0.7.65/go.mod h1:8+hG+mQMuRP/OIS9d83syAvXvrMj9HhkND6Q1fLghw0=
|
github.com/tencentyun/cos-go-sdk-v5 v0.7.65/go.mod h1:8+hG+mQMuRP/OIS9d83syAvXvrMj9HhkND6Q1fLghw0=
|
||||||
github.com/tiendc/go-deepcopy v1.7.1 h1:LnubftI6nYaaMOcaz0LphzwraqN8jiWTwm416sitff4=
|
|
||||||
github.com/tiendc/go-deepcopy v1.7.1/go.mod h1:4bKjNC2r7boYOkD2IOuZpYjmlDdzjbpTRyCx+goBCJQ=
|
|
||||||
github.com/tmthrgd/go-hex v0.0.0-20190904060850-447a3041c3bc h1:9lRDQMhESg+zvGYmW5DyG0UqvY96Bu5QYsTLvCHdrgo=
|
github.com/tmthrgd/go-hex v0.0.0-20190904060850-447a3041c3bc h1:9lRDQMhESg+zvGYmW5DyG0UqvY96Bu5QYsTLvCHdrgo=
|
||||||
github.com/tmthrgd/go-hex v0.0.0-20190904060850-447a3041c3bc/go.mod h1:bciPuU6GHm1iF1pBvUfxfsH0Wmnc2VbpgvbI9ZWuIRs=
|
github.com/tmthrgd/go-hex v0.0.0-20190904060850-447a3041c3bc/go.mod h1:bciPuU6GHm1iF1pBvUfxfsH0Wmnc2VbpgvbI9ZWuIRs=
|
||||||
github.com/twitchyliquid64/golang-asm v0.15.1 h1:SU5vSMR7hnwNxj24w34ZyCi/FmDZTkS4MhqMhdFk5YI=
|
github.com/twitchyliquid64/golang-asm v0.15.1 h1:SU5vSMR7hnwNxj24w34ZyCi/FmDZTkS4MhqMhdFk5YI=
|
||||||
@@ -310,12 +303,6 @@ github.com/wk8/go-ordered-map/v2 v2.1.8 h1:5h/BUHu93oj4gIdvHHHGsScSTMijfx5PeYkE/
|
|||||||
github.com/wk8/go-ordered-map/v2 v2.1.8/go.mod h1:5nJHM5DyteebpVlHnWMV0rPz6Zp7+xBAnxjb1X5vnTw=
|
github.com/wk8/go-ordered-map/v2 v2.1.8/go.mod h1:5nJHM5DyteebpVlHnWMV0rPz6Zp7+xBAnxjb1X5vnTw=
|
||||||
github.com/x448/float16 v0.8.4 h1:qLwI1I70+NjRFUR3zs1JPUCgaCXSh3SW62uAKT1mSBM=
|
github.com/x448/float16 v0.8.4 h1:qLwI1I70+NjRFUR3zs1JPUCgaCXSh3SW62uAKT1mSBM=
|
||||||
github.com/x448/float16 v0.8.4/go.mod h1:14CWIYCyZA/cWjXOioeEpHeN/83MdbZDRQHoFcYsOfg=
|
github.com/x448/float16 v0.8.4/go.mod h1:14CWIYCyZA/cWjXOioeEpHeN/83MdbZDRQHoFcYsOfg=
|
||||||
github.com/xuri/efp v0.0.1 h1:fws5Rv3myXyYni8uwj2qKjVaRP30PdjeYe2Y6FDsCL8=
|
|
||||||
github.com/xuri/efp v0.0.1/go.mod h1:ybY/Jr0T0GTCnYjKqmdwxyxn2BQf2RcQIIvex5QldPI=
|
|
||||||
github.com/xuri/excelize/v2 v2.10.0 h1:8aKsP7JD39iKLc6dH5Tw3dgV3sPRh8uRVXu/fMstfW4=
|
|
||||||
github.com/xuri/excelize/v2 v2.10.0/go.mod h1:SC5TzhQkaOsTWpANfm+7bJCldzcnU/jrhqkTi/iBHBU=
|
|
||||||
github.com/xuri/nfp v0.0.2-0.20250530014748-2ddeb826f9a9 h1:+C0TIdyyYmzadGaL/HBLbf3WdLgC29pgyhTjAT/0nuE=
|
|
||||||
github.com/xuri/nfp v0.0.2-0.20250530014748-2ddeb826f9a9/go.mod h1:WwHg+CVyzlv/TX9xqBFXEZAuxOPxn2k1GNHwG41IIUQ=
|
|
||||||
github.com/yanyiwu/gojieba v1.4.5 h1:VyZogGtdFSnJbACHvDRvDreXPPVPCg8axKFUdblU/JI=
|
github.com/yanyiwu/gojieba v1.4.5 h1:VyZogGtdFSnJbACHvDRvDreXPPVPCg8axKFUdblU/JI=
|
||||||
github.com/yanyiwu/gojieba v1.4.5/go.mod h1:JUq4DddFVGdHXJHxxepxRmhrKlDpaBxR8O28v6fKYLY=
|
github.com/yanyiwu/gojieba v1.4.5/go.mod h1:JUq4DddFVGdHXJHxxepxRmhrKlDpaBxR8O28v6fKYLY=
|
||||||
github.com/yosida95/uritemplate/v3 v3.0.2 h1:Ed3Oyj9yrmi9087+NczuL5BwkIc4wvTb5zIM+UJPGz4=
|
github.com/yosida95/uritemplate/v3 v3.0.2 h1:Ed3Oyj9yrmi9087+NczuL5BwkIc4wvTb5zIM+UJPGz4=
|
||||||
@@ -359,8 +346,6 @@ golang.org/x/crypto v0.23.0/go.mod h1:CKFgDieR+mRhux2Lsu27y0fO304Db0wZe70UKqHu0v
|
|||||||
golang.org/x/crypto v0.31.0/go.mod h1:kDsLvtWBEx7MV9tJOj9bnXsPbxwJQ6csT/x4KIN4Ssk=
|
golang.org/x/crypto v0.31.0/go.mod h1:kDsLvtWBEx7MV9tJOj9bnXsPbxwJQ6csT/x4KIN4Ssk=
|
||||||
golang.org/x/crypto v0.43.0 h1:dduJYIi3A3KOfdGOHX8AVZ/jGiyPa3IbBozJ5kNuE04=
|
golang.org/x/crypto v0.43.0 h1:dduJYIi3A3KOfdGOHX8AVZ/jGiyPa3IbBozJ5kNuE04=
|
||||||
golang.org/x/crypto v0.43.0/go.mod h1:BFbav4mRNlXJL4wNeejLpWxB7wMbc79PdRGhWKncxR0=
|
golang.org/x/crypto v0.43.0/go.mod h1:BFbav4mRNlXJL4wNeejLpWxB7wMbc79PdRGhWKncxR0=
|
||||||
golang.org/x/image v0.25.0 h1:Y6uW6rH1y5y/LK1J8BPWZtr6yZ7hrsy6hFrXjgsc2fQ=
|
|
||||||
golang.org/x/image v0.25.0/go.mod h1:tCAmOEGthTtkalusGp1g3xa2gke8J6c2N565dTyl9Rs=
|
|
||||||
golang.org/x/mod v0.6.0-dev.0.20220419223038-86c51ed26bb4/go.mod h1:jJ57K6gSWd91VN4djpZkiMVwK6gcyfeH4XE8wZrZaV4=
|
golang.org/x/mod v0.6.0-dev.0.20220419223038-86c51ed26bb4/go.mod h1:jJ57K6gSWd91VN4djpZkiMVwK6gcyfeH4XE8wZrZaV4=
|
||||||
golang.org/x/mod v0.8.0/go.mod h1:iBbtSCu2XBx23ZKBPSOrRkjjQPZFPuis4dIYUhu/chs=
|
golang.org/x/mod v0.8.0/go.mod h1:iBbtSCu2XBx23ZKBPSOrRkjjQPZFPuis4dIYUhu/chs=
|
||||||
golang.org/x/mod v0.12.0/go.mod h1:iBbtSCu2XBx23ZKBPSOrRkjjQPZFPuis4dIYUhu/chs=
|
golang.org/x/mod v0.12.0/go.mod h1:iBbtSCu2XBx23ZKBPSOrRkjjQPZFPuis4dIYUhu/chs=
|
||||||
|
|||||||
@@ -8,6 +8,7 @@ import (
|
|||||||
"strings"
|
"strings"
|
||||||
|
|
||||||
"github.com/Tencent/WeKnora/internal/logger"
|
"github.com/Tencent/WeKnora/internal/logger"
|
||||||
|
"github.com/Tencent/WeKnora/internal/searchutil"
|
||||||
"github.com/Tencent/WeKnora/internal/types"
|
"github.com/Tencent/WeKnora/internal/types"
|
||||||
"gorm.io/gorm"
|
"gorm.io/gorm"
|
||||||
)
|
)
|
||||||
@@ -546,17 +547,7 @@ func (t *GrepChunksTool) deduplicateChunks(ctx context.Context, results []chunkW
|
|||||||
|
|
||||||
// buildContentSignature creates a normalized signature for content to detect near-duplicates
|
// buildContentSignature creates a normalized signature for content to detect near-duplicates
|
||||||
func (t *GrepChunksTool) buildContentSignature(content string) string {
|
func (t *GrepChunksTool) buildContentSignature(content string) string {
|
||||||
c := strings.ToLower(strings.TrimSpace(content))
|
return searchutil.BuildContentSignature(content)
|
||||||
if c == "" {
|
|
||||||
return ""
|
|
||||||
}
|
|
||||||
// Normalize whitespace
|
|
||||||
c = strings.Join(strings.Fields(c), " ")
|
|
||||||
// Use first 128 characters as signature
|
|
||||||
if len(c) > 128 {
|
|
||||||
c = c[:128]
|
|
||||||
}
|
|
||||||
return c
|
|
||||||
}
|
}
|
||||||
|
|
||||||
// scoreChunks calculates match scores for chunks based on pattern matches
|
// scoreChunks calculates match scores for chunks based on pattern matches
|
||||||
@@ -693,36 +684,10 @@ func (t *GrepChunksTool) applyMMR(
|
|||||||
|
|
||||||
// tokenizeSimple tokenizes text into a set of words (simple whitespace-based)
|
// tokenizeSimple tokenizes text into a set of words (simple whitespace-based)
|
||||||
func (t *GrepChunksTool) tokenizeSimple(text string) map[string]struct{} {
|
func (t *GrepChunksTool) tokenizeSimple(text string) map[string]struct{} {
|
||||||
text = strings.ToLower(text)
|
return searchutil.TokenizeSimple(text)
|
||||||
fields := strings.Fields(text)
|
|
||||||
set := make(map[string]struct{}, len(fields))
|
|
||||||
for _, f := range fields {
|
|
||||||
if len(f) > 1 {
|
|
||||||
set[f] = struct{}{}
|
|
||||||
}
|
|
||||||
}
|
|
||||||
return set
|
|
||||||
}
|
}
|
||||||
|
|
||||||
// jaccard calculates Jaccard similarity between two token sets
|
// jaccard calculates Jaccard similarity between two token sets
|
||||||
func (t *GrepChunksTool) jaccard(a, b map[string]struct{}) float64 {
|
func (t *GrepChunksTool) jaccard(a, b map[string]struct{}) float64 {
|
||||||
if len(a) == 0 && len(b) == 0 {
|
return searchutil.Jaccard(a, b)
|
||||||
return 0
|
|
||||||
}
|
|
||||||
|
|
||||||
// Calculate intersection
|
|
||||||
inter := 0
|
|
||||||
for k := range a {
|
|
||||||
if _, ok := b[k]; ok {
|
|
||||||
inter++
|
|
||||||
}
|
|
||||||
}
|
|
||||||
|
|
||||||
// Calculate union
|
|
||||||
union := len(a) + len(b) - inter
|
|
||||||
if union == 0 {
|
|
||||||
return 0
|
|
||||||
}
|
|
||||||
|
|
||||||
return float64(inter) / float64(union)
|
|
||||||
}
|
}
|
||||||
|
|||||||
@@ -264,15 +264,12 @@ func (t *KnowledgeSearchTool) Execute(ctx context.Context, args map[string]inter
|
|||||||
topK, vectorThreshold, keywordThreshold, kbTypeMap)
|
topK, vectorThreshold, keywordThreshold, kbTypeMap)
|
||||||
logger.Infof(ctx, "[Tool][KnowledgeSearch] Concurrent search completed: %d raw results", len(allResults))
|
logger.Infof(ctx, "[Tool][KnowledgeSearch] Concurrent search completed: %d raw results", len(allResults))
|
||||||
|
|
||||||
// Normalize keyword search results to ensure fair comparison across knowledge bases
|
// Note: HybridSearch now uses RRF (Reciprocal Rank Fusion) which produces normalized scores
|
||||||
logger.Debugf(ctx, "[Tool][KnowledgeSearch] Normalizing keyword search results...")
|
// RRF scores are in range [0, ~0.033] (max when rank=1 on both sides: 2/(60+1))
|
||||||
t.normalizeKeywordSearchResults(ctx, allResults)
|
// Threshold filtering is already done inside HybridSearch before RRF, so we skip it here
|
||||||
logger.Infof(ctx, "[Tool][KnowledgeSearch] After keyword normalization: %d results", len(allResults))
|
|
||||||
|
|
||||||
// Filter by threshold first
|
|
||||||
filteredResults := t.filterByThreshold(allResults, vectorThreshold, keywordThreshold)
|
|
||||||
// Deduplicate before reranking to reduce processing overhead
|
// Deduplicate before reranking to reduce processing overhead
|
||||||
deduplicatedBeforeRerank := t.deduplicateResults(filteredResults)
|
deduplicatedBeforeRerank := t.deduplicateResults(allResults)
|
||||||
|
|
||||||
// Apply ReRank if model is configured
|
// Apply ReRank if model is configured
|
||||||
// Prefer chatModel (LLM-based reranking) over rerankModel if both are available
|
// Prefer chatModel (LLM-based reranking) over rerankModel if both are available
|
||||||
@@ -286,6 +283,9 @@ func (t *KnowledgeSearchTool) Execute(ctx context.Context, args map[string]inter
|
|||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
|
// Variable to hold results through reranking and MMR stages
|
||||||
|
var filteredResults []*searchResultWithMeta
|
||||||
|
|
||||||
if t.chatModel != nil && len(deduplicatedBeforeRerank) > 0 && rerankQuery != "" {
|
if t.chatModel != nil && len(deduplicatedBeforeRerank) > 0 && rerankQuery != "" {
|
||||||
logger.Infof(
|
logger.Infof(
|
||||||
ctx,
|
ctx,
|
||||||
@@ -347,10 +347,9 @@ func (t *KnowledgeSearchTool) Execute(ctx context.Context, args map[string]inter
|
|||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
// Apply absolute minimum score filter to remove very low quality chunks
|
// Note: minScore filter is skipped because HybridSearch now uses RRF scores
|
||||||
logger.Debugf(ctx, "[Tool][KnowledgeSearch] Applying min_score filter (%.2f)...", minScore)
|
// RRF scores are in range [0, ~0.033], not [0, 1], so old thresholds don't apply
|
||||||
filteredResults = t.filterByMinScore(filteredResults, minScore)
|
// Threshold filtering is already done inside HybridSearch before RRF fusion
|
||||||
logger.Infof(ctx, "[Tool][KnowledgeSearch] After min_score filter: %d results", len(filteredResults))
|
|
||||||
|
|
||||||
// Final deduplication after rerank (in case rerank changed scores/order but duplicates remain)
|
// Final deduplication after rerank (in case rerank changed scores/order but duplicates remain)
|
||||||
logger.Debugf(ctx, "[Tool][KnowledgeSearch] Final deduplication after rerank...")
|
logger.Debugf(ctx, "[Tool][KnowledgeSearch] Final deduplication after rerank...")
|
||||||
@@ -465,45 +464,6 @@ func (t *KnowledgeSearchTool) concurrentSearch(
|
|||||||
return allResults
|
return allResults
|
||||||
}
|
}
|
||||||
|
|
||||||
// filterByThreshold filters results based on match type and threshold
|
|
||||||
// Special handling for history matches: uses lower threshold (reduced by 0.1, minimum 0.5)
|
|
||||||
func (t *KnowledgeSearchTool) filterByThreshold(
|
|
||||||
results []*searchResultWithMeta,
|
|
||||||
vectorThreshold, keywordThreshold float64,
|
|
||||||
) []*searchResultWithMeta {
|
|
||||||
filtered := make([]*searchResultWithMeta, 0)
|
|
||||||
for _, r := range results {
|
|
||||||
var threshold float64
|
|
||||||
|
|
||||||
// Special handling for history matches: use lower threshold
|
|
||||||
switch r.MatchType {
|
|
||||||
case types.MatchTypeHistory:
|
|
||||||
// Use the lower of the two thresholds, then reduce by 0.1 (minimum 0.5)
|
|
||||||
th := vectorThreshold
|
|
||||||
if keywordThreshold < th {
|
|
||||||
th = keywordThreshold
|
|
||||||
}
|
|
||||||
threshold = math.Max(th-0.1, 0.5)
|
|
||||||
case types.MatchTypeEmbedding:
|
|
||||||
threshold = vectorThreshold
|
|
||||||
case types.MatchTypeKeywords:
|
|
||||||
threshold = keywordThreshold
|
|
||||||
default:
|
|
||||||
// For other match types (graph, nearby chunk, etc.), use the lower threshold
|
|
||||||
threshold = vectorThreshold
|
|
||||||
if keywordThreshold < threshold {
|
|
||||||
threshold = keywordThreshold
|
|
||||||
}
|
|
||||||
}
|
|
||||||
|
|
||||||
// Check if result meets threshold
|
|
||||||
if r.Score >= threshold {
|
|
||||||
filtered = append(filtered, r)
|
|
||||||
}
|
|
||||||
}
|
|
||||||
return filtered
|
|
||||||
}
|
|
||||||
|
|
||||||
// rerankResults applies reranking to search results using LLM prompt scoring or rerank model
|
// rerankResults applies reranking to search results using LLM prompt scoring or rerank model
|
||||||
func (t *KnowledgeSearchTool) rerankResults(
|
func (t *KnowledgeSearchTool) rerankResults(
|
||||||
ctx context.Context,
|
ctx context.Context,
|
||||||
@@ -566,15 +526,13 @@ func (t *KnowledgeSearchTool) rerankResults(
|
|||||||
}
|
}
|
||||||
|
|
||||||
// Apply composite scoring to reranked results
|
// Apply composite scoring to reranked results
|
||||||
// Get query intent from context if available (optional)
|
logger.Debugf(ctx, "[Tool][KnowledgeSearch] Applying composite scoring")
|
||||||
queryIntent := t.getQueryIntentFromContext(ctx)
|
|
||||||
logger.Debugf(ctx, "[Tool][KnowledgeSearch] Applying composite scoring with query_intent=%s", queryIntent)
|
|
||||||
|
|
||||||
// Store base scores before composite scoring
|
// Store base scores before composite scoring
|
||||||
for _, result := range rerankedNonFAQ {
|
for _, result := range rerankedNonFAQ {
|
||||||
baseScore := result.Score
|
baseScore := result.Score
|
||||||
// Apply composite score
|
// Apply composite score
|
||||||
result.Score = t.compositeScore(result, result.Score, baseScore, queryIntent)
|
result.Score = t.compositeScore(result, result.Score, baseScore)
|
||||||
}
|
}
|
||||||
|
|
||||||
// Combine FAQ results (with original order) and reranked non-FAQ results
|
// Combine FAQ results (with original order) and reranked non-FAQ results
|
||||||
@@ -899,20 +857,6 @@ func (t *KnowledgeSearchTool) rerankWithModel(
|
|||||||
return reranked, nil
|
return reranked, nil
|
||||||
}
|
}
|
||||||
|
|
||||||
// filterByMinScore filters results by absolute minimum score
|
|
||||||
func (t *KnowledgeSearchTool) filterByMinScore(
|
|
||||||
results []*searchResultWithMeta,
|
|
||||||
minScore float64,
|
|
||||||
) []*searchResultWithMeta {
|
|
||||||
filtered := make([]*searchResultWithMeta, 0)
|
|
||||||
for _, r := range results {
|
|
||||||
if r.Score >= minScore {
|
|
||||||
filtered = append(filtered, r)
|
|
||||||
}
|
|
||||||
}
|
|
||||||
return filtered
|
|
||||||
}
|
|
||||||
|
|
||||||
// deduplicateResults removes duplicate chunks, keeping the highest score
|
// deduplicateResults removes duplicate chunks, keeping the highest score
|
||||||
// Uses multiple keys (ID, parent chunk ID, knowledge+index) and content signature for deduplication
|
// Uses multiple keys (ID, parent chunk ID, knowledge+index) and content signature for deduplication
|
||||||
func (t *KnowledgeSearchTool) deduplicateResults(results []*searchResultWithMeta) []*searchResultWithMeta {
|
func (t *KnowledgeSearchTool) deduplicateResults(results []*searchResultWithMeta) []*searchResultWithMeta {
|
||||||
@@ -984,17 +928,7 @@ func (t *KnowledgeSearchTool) deduplicateResults(results []*searchResultWithMeta
|
|||||||
|
|
||||||
// buildContentSignature creates a normalized signature for content to detect near-duplicates
|
// buildContentSignature creates a normalized signature for content to detect near-duplicates
|
||||||
func (t *KnowledgeSearchTool) buildContentSignature(content string) string {
|
func (t *KnowledgeSearchTool) buildContentSignature(content string) string {
|
||||||
c := strings.ToLower(strings.TrimSpace(content))
|
return searchutil.BuildContentSignature(content)
|
||||||
if c == "" {
|
|
||||||
return ""
|
|
||||||
}
|
|
||||||
// Normalize whitespace
|
|
||||||
c = strings.Join(strings.Fields(c), " ")
|
|
||||||
// Use first 128 characters as signature
|
|
||||||
if len(c) > 128 {
|
|
||||||
c = c[:128]
|
|
||||||
}
|
|
||||||
return c
|
|
||||||
}
|
}
|
||||||
|
|
||||||
// formatOutput formats the search results for display
|
// formatOutput formats the search results for display
|
||||||
@@ -1215,47 +1149,6 @@ type chunkRange struct {
|
|||||||
end int
|
end int
|
||||||
}
|
}
|
||||||
|
|
||||||
// normalizeKeywordSearchResults normalizes keyword search result scores into [0,1] globally across all knowledge bases
|
|
||||||
// Improvements:
|
|
||||||
// 1. Uses robust normalization with percentile-based bounds to handle outliers
|
|
||||||
// 2. Handles edge cases: single result, no variance, negative scores
|
|
||||||
// 3. Global normalization ensures fair comparison across different knowledge bases
|
|
||||||
func (t *KnowledgeSearchTool) normalizeKeywordSearchResults(ctx context.Context, results []*searchResultWithMeta) {
|
|
||||||
searchutil.NormalizeKeywordScores[*searchResultWithMeta](
|
|
||||||
results,
|
|
||||||
func(r *searchResultWithMeta) bool {
|
|
||||||
return r.MatchType == types.MatchTypeKeywords
|
|
||||||
},
|
|
||||||
func(r *searchResultWithMeta) float64 {
|
|
||||||
return r.Score
|
|
||||||
},
|
|
||||||
func(r *searchResultWithMeta, score float64) {
|
|
||||||
r.Score = score
|
|
||||||
},
|
|
||||||
searchutil.KeywordScoreCallbacks{
|
|
||||||
OnNoVariance: func(count int, score float64) {
|
|
||||||
logger.Infof(
|
|
||||||
ctx,
|
|
||||||
"[Tool][KnowledgeSearch] Keyword scores have no variance, all set to 1.0: count=%d, score=%.3f",
|
|
||||||
count,
|
|
||||||
score,
|
|
||||||
)
|
|
||||||
},
|
|
||||||
OnNormalized: func(count int, rawMin, rawMax, normalizeMin, normalizeMax float64) {
|
|
||||||
logger.Infof(
|
|
||||||
ctx,
|
|
||||||
"[Tool][KnowledgeSearch] Normalized keyword scores: count=%d, raw_min=%.3f, raw_max=%.3f, normalize_min=%.3f, normalize_max=%.3f",
|
|
||||||
count,
|
|
||||||
rawMin,
|
|
||||||
rawMax,
|
|
||||||
normalizeMin,
|
|
||||||
normalizeMax,
|
|
||||||
)
|
|
||||||
},
|
|
||||||
},
|
|
||||||
)
|
|
||||||
}
|
|
||||||
|
|
||||||
// getEnrichedPassage 合并Content和ImageInfo的文本内容
|
// getEnrichedPassage 合并Content和ImageInfo的文本内容
|
||||||
func (t *KnowledgeSearchTool) getEnrichedPassage(ctx context.Context, result *types.SearchResult) string {
|
func (t *KnowledgeSearchTool) getEnrichedPassage(ctx context.Context, result *types.SearchResult) string {
|
||||||
if result.ImageInfo == "" {
|
if result.ImageInfo == "" {
|
||||||
@@ -1302,19 +1195,10 @@ func (t *KnowledgeSearchTool) getEnrichedPassage(ctx context.Context, result *ty
|
|||||||
return combinedText
|
return combinedText
|
||||||
}
|
}
|
||||||
|
|
||||||
// getQueryIntentFromContext attempts to extract query intent from context (optional)
|
|
||||||
func (t *KnowledgeSearchTool) getQueryIntentFromContext(ctx context.Context) string {
|
|
||||||
// Try to get query intent from context if available
|
|
||||||
// This is optional and may not always be present in agent tool context
|
|
||||||
// Return empty string if not available
|
|
||||||
return ""
|
|
||||||
}
|
|
||||||
|
|
||||||
// compositeScore calculates a composite score considering multiple factors
|
// compositeScore calculates a composite score considering multiple factors
|
||||||
func (t *KnowledgeSearchTool) compositeScore(
|
func (t *KnowledgeSearchTool) compositeScore(
|
||||||
result *searchResultWithMeta,
|
result *searchResultWithMeta,
|
||||||
modelScore, baseScore float64,
|
modelScore, baseScore float64,
|
||||||
queryIntent string,
|
|
||||||
) float64 {
|
) float64 {
|
||||||
// Source weight: web_search results get slightly lower weight
|
// Source weight: web_search results get slightly lower weight
|
||||||
sourceWeight := 1.0
|
sourceWeight := 1.0
|
||||||
@@ -1322,26 +1206,6 @@ func (t *KnowledgeSearchTool) compositeScore(
|
|||||||
sourceWeight = 0.95
|
sourceWeight = 0.95
|
||||||
}
|
}
|
||||||
|
|
||||||
// Intent boost: adjust score based on query intent and chunk characteristics
|
|
||||||
intentBoost := 1.0
|
|
||||||
if queryIntent != "" {
|
|
||||||
switch queryIntent {
|
|
||||||
case "definition":
|
|
||||||
// Boost summary chunks for definition queries
|
|
||||||
if result.ChunkType == string(types.ChunkTypeSummary) {
|
|
||||||
intentBoost = 1.05
|
|
||||||
}
|
|
||||||
case "howto":
|
|
||||||
// Boost longer chunks for howto queries
|
|
||||||
if result.EndAt-result.StartAt > 300 {
|
|
||||||
intentBoost = 1.03
|
|
||||||
}
|
|
||||||
case "compare":
|
|
||||||
// No boost for compare queries
|
|
||||||
intentBoost = 1.0
|
|
||||||
}
|
|
||||||
}
|
|
||||||
|
|
||||||
// Position prior: slightly favor chunks earlier in the document
|
// Position prior: slightly favor chunks earlier in the document
|
||||||
positionPrior := 1.0
|
positionPrior := 1.0
|
||||||
if result.StartAt >= 0 && result.EndAt > result.StartAt {
|
if result.StartAt >= 0 && result.EndAt > result.StartAt {
|
||||||
@@ -1352,7 +1216,6 @@ func (t *KnowledgeSearchTool) compositeScore(
|
|||||||
|
|
||||||
// Composite formula: weighted combination of model score, base score, and source weight
|
// Composite formula: weighted combination of model score, base score, and source weight
|
||||||
composite := 0.6*modelScore + 0.3*baseScore + 0.1*sourceWeight
|
composite := 0.6*modelScore + 0.3*baseScore + 0.1*sourceWeight
|
||||||
composite *= intentBoost
|
|
||||||
composite *= positionPrior
|
composite *= positionPrior
|
||||||
|
|
||||||
// Clamp to [0, 1]
|
// Clamp to [0, 1]
|
||||||
@@ -1368,13 +1231,7 @@ func (t *KnowledgeSearchTool) compositeScore(
|
|||||||
|
|
||||||
// clampFloat clamps a float value to the specified range
|
// clampFloat clamps a float value to the specified range
|
||||||
func (t *KnowledgeSearchTool) clampFloat(v, minV, maxV float64) float64 {
|
func (t *KnowledgeSearchTool) clampFloat(v, minV, maxV float64) float64 {
|
||||||
if v < minV {
|
return searchutil.ClampFloat(v, minV, maxV)
|
||||||
return minV
|
|
||||||
}
|
|
||||||
if v > maxV {
|
|
||||||
return maxV
|
|
||||||
}
|
|
||||||
return v
|
|
||||||
}
|
}
|
||||||
|
|
||||||
// applyMMR applies Maximal Marginal Relevance algorithm to reduce redundancy
|
// applyMMR applies Maximal Marginal Relevance algorithm to reduce redundancy
|
||||||
@@ -1456,36 +1313,10 @@ func (t *KnowledgeSearchTool) applyMMR(
|
|||||||
|
|
||||||
// tokenizeSimple tokenizes text into a set of words (simple whitespace-based)
|
// tokenizeSimple tokenizes text into a set of words (simple whitespace-based)
|
||||||
func (t *KnowledgeSearchTool) tokenizeSimple(text string) map[string]struct{} {
|
func (t *KnowledgeSearchTool) tokenizeSimple(text string) map[string]struct{} {
|
||||||
text = strings.ToLower(text)
|
return searchutil.TokenizeSimple(text)
|
||||||
fields := strings.Fields(text)
|
|
||||||
set := make(map[string]struct{}, len(fields))
|
|
||||||
for _, f := range fields {
|
|
||||||
if len(f) > 1 {
|
|
||||||
set[f] = struct{}{}
|
|
||||||
}
|
|
||||||
}
|
|
||||||
return set
|
|
||||||
}
|
}
|
||||||
|
|
||||||
// jaccard calculates Jaccard similarity between two token sets
|
// jaccard calculates Jaccard similarity between two token sets
|
||||||
func (t *KnowledgeSearchTool) jaccard(a, b map[string]struct{}) float64 {
|
func (t *KnowledgeSearchTool) jaccard(a, b map[string]struct{}) float64 {
|
||||||
if len(a) == 0 && len(b) == 0 {
|
return searchutil.Jaccard(a, b)
|
||||||
return 0
|
|
||||||
}
|
|
||||||
|
|
||||||
// Calculate intersection
|
|
||||||
inter := 0
|
|
||||||
for k := range a {
|
|
||||||
if _, ok := b[k]; ok {
|
|
||||||
inter++
|
|
||||||
}
|
|
||||||
}
|
|
||||||
|
|
||||||
// Calculate union
|
|
||||||
union := len(a) + len(b) - inter
|
|
||||||
if union == 0 {
|
|
||||||
return 0
|
|
||||||
}
|
|
||||||
|
|
||||||
return float64(inter) / float64(union)
|
|
||||||
}
|
}
|
||||||
|
|||||||
@@ -1,166 +0,0 @@
|
|||||||
package chatpipline
|
|
||||||
|
|
||||||
import (
|
|
||||||
"context"
|
|
||||||
"encoding/json"
|
|
||||||
"regexp"
|
|
||||||
"strings"
|
|
||||||
|
|
||||||
"github.com/Tencent/WeKnora/internal/config"
|
|
||||||
"github.com/Tencent/WeKnora/internal/models/chat"
|
|
||||||
"github.com/Tencent/WeKnora/internal/types"
|
|
||||||
"github.com/Tencent/WeKnora/internal/types/interfaces"
|
|
||||||
)
|
|
||||||
|
|
||||||
// PluginPreprocess Query preprocessing plugin
|
|
||||||
type PluginPreprocess struct {
|
|
||||||
config *config.Config
|
|
||||||
modelService interfaces.ModelService
|
|
||||||
}
|
|
||||||
|
|
||||||
// Regular expressions for text cleaning
|
|
||||||
var (
|
|
||||||
multiSpaceRegex = regexp.MustCompile(`\s+`) // Multiple spaces
|
|
||||||
)
|
|
||||||
|
|
||||||
// NewPluginPreprocess Creates a new query preprocessing plugin
|
|
||||||
func NewPluginPreprocess(
|
|
||||||
eventManager *EventManager,
|
|
||||||
config *config.Config,
|
|
||||||
cleaner interfaces.ResourceCleaner,
|
|
||||||
modelService interfaces.ModelService,
|
|
||||||
) *PluginPreprocess {
|
|
||||||
res := &PluginPreprocess{
|
|
||||||
config: config,
|
|
||||||
modelService: modelService,
|
|
||||||
}
|
|
||||||
|
|
||||||
eventManager.Register(res)
|
|
||||||
return res
|
|
||||||
}
|
|
||||||
|
|
||||||
// ActivationEvents Register activation events
|
|
||||||
func (p *PluginPreprocess) ActivationEvents() []types.EventType {
|
|
||||||
return []types.EventType{types.PREPROCESS_QUERY}
|
|
||||||
}
|
|
||||||
|
|
||||||
// OnEvent Process events
|
|
||||||
func (p *PluginPreprocess) OnEvent(
|
|
||||||
ctx context.Context,
|
|
||||||
eventType types.EventType,
|
|
||||||
chatManage *types.ChatManage,
|
|
||||||
next func() *PluginError,
|
|
||||||
) *PluginError {
|
|
||||||
rawQuery := strings.TrimSpace(chatManage.RewriteQuery)
|
|
||||||
if rawQuery == "" {
|
|
||||||
return next()
|
|
||||||
}
|
|
||||||
|
|
||||||
pipelineInfo(ctx, "Preprocess", "input", map[string]interface{}{
|
|
||||||
"session_id": chatManage.SessionID,
|
|
||||||
"rewrite_query": rawQuery,
|
|
||||||
})
|
|
||||||
|
|
||||||
// Lightweight normalization: just collapse multiple spaces
|
|
||||||
processed := multiSpaceRegex.ReplaceAllString(rawQuery, " ")
|
|
||||||
processed = strings.TrimSpace(processed)
|
|
||||||
|
|
||||||
chatManage.ProcessedQuery = processed
|
|
||||||
chatManage.QueryIntent = p.detectIntentLLM(ctx, chatManage, processed)
|
|
||||||
|
|
||||||
pipelineInfo(ctx, "Preprocess", "output", map[string]interface{}{
|
|
||||||
"session_id": chatManage.SessionID,
|
|
||||||
"processed_query": processed,
|
|
||||||
"query_intent": chatManage.QueryIntent,
|
|
||||||
})
|
|
||||||
|
|
||||||
return next()
|
|
||||||
}
|
|
||||||
|
|
||||||
// intentResp is a response for intent detection
|
|
||||||
type intentResp struct {
|
|
||||||
Intent string `json:"intent"`
|
|
||||||
Confidence float64 `json:"confidence"`
|
|
||||||
}
|
|
||||||
|
|
||||||
// detectIntentLLM detects the intent of a query using an LLM
|
|
||||||
func (p *PluginPreprocess) detectIntentLLM(ctx context.Context, chatManage *types.ChatManage, text string) string {
|
|
||||||
if p.modelService == nil || chatManage.ChatModelID == "" {
|
|
||||||
pipelineWarn(
|
|
||||||
ctx,
|
|
||||||
"IntentDetect",
|
|
||||||
"skip",
|
|
||||||
map[string]interface{}{"reason": "no_model", "session_id": chatManage.SessionID},
|
|
||||||
)
|
|
||||||
return "general"
|
|
||||||
}
|
|
||||||
chatModel, err := p.modelService.GetChatModel(ctx, chatManage.ChatModelID)
|
|
||||||
if err != nil {
|
|
||||||
pipelineWarn(
|
|
||||||
ctx,
|
|
||||||
"IntentDetect",
|
|
||||||
"get_model_failed",
|
|
||||||
map[string]interface{}{"error": err.Error(), "model_id": chatManage.ChatModelID},
|
|
||||||
)
|
|
||||||
return "general"
|
|
||||||
}
|
|
||||||
pipelineInfo(
|
|
||||||
ctx,
|
|
||||||
"IntentDetect",
|
|
||||||
"start",
|
|
||||||
map[string]interface{}{"session_id": chatManage.SessionID, "model_id": chatManage.ChatModelID},
|
|
||||||
)
|
|
||||||
sys := "You are a query intent classifier. Classify the user's query into one of: definition, howto, compare, qa, general. Respond ONLY with a JSON object {\"intent\": \"...\", \"confidence\": 0.0 } inside a markdown fenced block."
|
|
||||||
usr := text
|
|
||||||
think := false
|
|
||||||
resp, err := chatModel.Chat(ctx, []chat.Message{
|
|
||||||
{Role: "system", Content: sys},
|
|
||||||
{Role: "user", Content: usr},
|
|
||||||
}, &chat.ChatOptions{Temperature: 0.0, MaxCompletionTokens: 64, Thinking: &think})
|
|
||||||
if err != nil || resp.Content == "" {
|
|
||||||
pipelineWarn(ctx, "IntentDetect", "model_call_failed", map[string]interface{}{"error": err})
|
|
||||||
return "general"
|
|
||||||
}
|
|
||||||
body := extractJSONBody(resp.Content)
|
|
||||||
var ir intentResp
|
|
||||||
if err := json.Unmarshal([]byte(body), &ir); err != nil {
|
|
||||||
pipelineWarn(ctx, "IntentDetect", "parse_failed", map[string]interface{}{"body": body, "error": err.Error()})
|
|
||||||
return "general"
|
|
||||||
}
|
|
||||||
pipelineInfo(
|
|
||||||
ctx,
|
|
||||||
"IntentDetect",
|
|
||||||
"result",
|
|
||||||
map[string]interface{}{"intent": ir.Intent, "confidence": ir.Confidence},
|
|
||||||
)
|
|
||||||
switch strings.ToLower(strings.TrimSpace(ir.Intent)) {
|
|
||||||
case "definition", "howto", "compare", "qa", "general":
|
|
||||||
return strings.ToLower(ir.Intent)
|
|
||||||
default:
|
|
||||||
return "general"
|
|
||||||
}
|
|
||||||
}
|
|
||||||
|
|
||||||
// extractJSONBody extracts a JSON body from a string
|
|
||||||
func extractJSONBody(text string) string {
|
|
||||||
t := strings.TrimSpace(text)
|
|
||||||
// Try fenced block first
|
|
||||||
if i := strings.Index(t, "{"); i >= 0 {
|
|
||||||
j := strings.LastIndex(t, "}")
|
|
||||||
if j > i {
|
|
||||||
return t[i : j+1]
|
|
||||||
}
|
|
||||||
}
|
|
||||||
return "{}"
|
|
||||||
}
|
|
||||||
|
|
||||||
// Close Releases resources
|
|
||||||
func (p *PluginPreprocess) Close() {
|
|
||||||
}
|
|
||||||
|
|
||||||
// ShutdownHandler Returns shutdown function
|
|
||||||
func (p *PluginPreprocess) ShutdownHandler() func() {
|
|
||||||
return func() {
|
|
||||||
p.Close()
|
|
||||||
}
|
|
||||||
}
|
|
||||||
@@ -8,6 +8,7 @@ import (
|
|||||||
"strings"
|
"strings"
|
||||||
|
|
||||||
"github.com/Tencent/WeKnora/internal/models/rerank"
|
"github.com/Tencent/WeKnora/internal/models/rerank"
|
||||||
|
"github.com/Tencent/WeKnora/internal/searchutil"
|
||||||
"github.com/Tencent/WeKnora/internal/types"
|
"github.com/Tencent/WeKnora/internal/types"
|
||||||
"github.com/Tencent/WeKnora/internal/types/interfaces"
|
"github.com/Tencent/WeKnora/internal/types/interfaces"
|
||||||
)
|
)
|
||||||
@@ -36,12 +37,11 @@ func (p *PluginRerank) OnEvent(ctx context.Context,
|
|||||||
eventType types.EventType, chatManage *types.ChatManage, next func() *PluginError,
|
eventType types.EventType, chatManage *types.ChatManage, next func() *PluginError,
|
||||||
) *PluginError {
|
) *PluginError {
|
||||||
pipelineInfo(ctx, "Rerank", "input", map[string]interface{}{
|
pipelineInfo(ctx, "Rerank", "input", map[string]interface{}{
|
||||||
"session_id": chatManage.SessionID,
|
"session_id": chatManage.SessionID,
|
||||||
"candidate_cnt": len(chatManage.SearchResult),
|
"candidate_cnt": len(chatManage.SearchResult),
|
||||||
"rerank_model": chatManage.RerankModelID,
|
"rerank_model": chatManage.RerankModelID,
|
||||||
"rerank_thresh": chatManage.RerankThreshold,
|
"rerank_thresh": chatManage.RerankThreshold,
|
||||||
"rewrite_query": chatManage.RewriteQuery,
|
"rewrite_query": chatManage.RewriteQuery,
|
||||||
"processed_query": chatManage.ProcessedQuery,
|
|
||||||
})
|
})
|
||||||
if len(chatManage.SearchResult) == 0 {
|
if len(chatManage.SearchResult) == 0 {
|
||||||
pipelineInfo(ctx, "Rerank", "skip", map[string]interface{}{
|
pipelineInfo(ctx, "Rerank", "skip", map[string]interface{}{
|
||||||
@@ -77,18 +77,40 @@ func (p *PluginRerank) OnEvent(ctx context.Context,
|
|||||||
passages = append(passages, passage)
|
passages = append(passages, passage)
|
||||||
}
|
}
|
||||||
|
|
||||||
// Try reranking with different query variants in priority order
|
// Single rerank call with RewriteQuery, use threshold degradation if no results
|
||||||
|
originalThreshold := chatManage.RerankThreshold
|
||||||
rerankResp := p.rerank(ctx, chatManage, rerankModel, chatManage.RewriteQuery, passages)
|
rerankResp := p.rerank(ctx, chatManage, rerankModel, chatManage.RewriteQuery, passages)
|
||||||
if len(rerankResp) == 0 {
|
|
||||||
rerankResp = p.rerank(ctx, chatManage, rerankModel, chatManage.ProcessedQuery, passages)
|
// If no results and threshold is high enough, try with lower threshold
|
||||||
if len(rerankResp) == 0 {
|
if len(rerankResp) == 0 && originalThreshold > 0.3 {
|
||||||
rerankResp = p.rerank(ctx, chatManage, rerankModel, chatManage.Query, passages)
|
degradedThreshold := originalThreshold * 0.7
|
||||||
|
if degradedThreshold < 0.3 {
|
||||||
|
degradedThreshold = 0.3
|
||||||
}
|
}
|
||||||
|
pipelineInfo(ctx, "Rerank", "threshold_degrade", map[string]interface{}{
|
||||||
|
"original": originalThreshold,
|
||||||
|
"degraded": degradedThreshold,
|
||||||
|
})
|
||||||
|
chatManage.RerankThreshold = degradedThreshold
|
||||||
|
rerankResp = p.rerank(ctx, chatManage, rerankModel, chatManage.RewriteQuery, passages)
|
||||||
|
// Restore original threshold
|
||||||
|
chatManage.RerankThreshold = originalThreshold
|
||||||
}
|
}
|
||||||
|
|
||||||
pipelineInfo(ctx, "Rerank", "model_response", map[string]interface{}{
|
pipelineInfo(ctx, "Rerank", "model_response", map[string]interface{}{
|
||||||
"result_cnt": len(rerankResp),
|
"result_cnt": len(rerankResp),
|
||||||
})
|
})
|
||||||
|
|
||||||
|
// Log input scores before reranking for debugging
|
||||||
|
for i, sr := range chatManage.SearchResult {
|
||||||
|
pipelineInfo(ctx, "Rerank", "input_score", map[string]interface{}{
|
||||||
|
"index": i,
|
||||||
|
"chunk_id": sr.ID,
|
||||||
|
"score": fmt.Sprintf("%.4f", sr.Score),
|
||||||
|
"match_type": sr.MatchType,
|
||||||
|
})
|
||||||
|
}
|
||||||
|
|
||||||
for i := range chatManage.SearchResult {
|
for i := range chatManage.SearchResult {
|
||||||
chatManage.SearchResult[i].Metadata = ensureMetadata(chatManage.SearchResult[i].Metadata)
|
chatManage.SearchResult[i].Metadata = ensureMetadata(chatManage.SearchResult[i].Metadata)
|
||||||
}
|
}
|
||||||
@@ -97,8 +119,15 @@ func (p *PluginRerank) OnEvent(ctx context.Context,
|
|||||||
sr := chatManage.SearchResult[rr.Index]
|
sr := chatManage.SearchResult[rr.Index]
|
||||||
base := sr.Score
|
base := sr.Score
|
||||||
sr.Metadata["base_score"] = fmt.Sprintf("%.4f", base)
|
sr.Metadata["base_score"] = fmt.Sprintf("%.4f", base)
|
||||||
sr.Score = rr.RelevanceScore
|
modelScore := rr.RelevanceScore
|
||||||
sr.Score = compositeScore(sr, rr.RelevanceScore, base, chatManage)
|
sr.Score = compositeScore(sr, modelScore, base)
|
||||||
|
pipelineInfo(ctx, "Rerank", "composite_calc", map[string]interface{}{
|
||||||
|
"chunk_id": sr.ID,
|
||||||
|
"base_score": fmt.Sprintf("%.4f", base),
|
||||||
|
"model_score": fmt.Sprintf("%.4f", modelScore),
|
||||||
|
"final_score": fmt.Sprintf("%.4f", sr.Score),
|
||||||
|
"match_type": sr.MatchType,
|
||||||
|
})
|
||||||
reranked = append(reranked, sr)
|
reranked = append(reranked, sr)
|
||||||
}
|
}
|
||||||
final := applyMMR(ctx, reranked, chatManage, min(len(reranked), max(1, chatManage.RerankTopK)), 0.7)
|
final := applyMMR(ctx, reranked, chatManage, min(len(reranked), max(1, chatManage.RerankTopK)), 0.7)
|
||||||
@@ -112,7 +141,6 @@ func (p *PluginRerank) OnEvent(ctx context.Context,
|
|||||||
"chunk_id": reranked[i].ID,
|
"chunk_id": reranked[i].ID,
|
||||||
"base_score": reranked[i].Metadata["base_score"],
|
"base_score": reranked[i].Metadata["base_score"],
|
||||||
"final_score": fmt.Sprintf("%.4f", reranked[i].Score),
|
"final_score": fmt.Sprintf("%.4f", reranked[i].Score),
|
||||||
"intent": chatManage.QueryIntent,
|
|
||||||
})
|
})
|
||||||
}
|
}
|
||||||
|
|
||||||
@@ -185,7 +213,7 @@ func ensureMetadata(m map[string]string) map[string]string {
|
|||||||
}
|
}
|
||||||
|
|
||||||
// compositeScore calculates the composite score for a search result
|
// compositeScore calculates the composite score for a search result
|
||||||
func compositeScore(sr *types.SearchResult, modelScore, baseScore float64, chatManage *types.ChatManage) float64 {
|
func compositeScore(sr *types.SearchResult, modelScore, baseScore float64) float64 {
|
||||||
sourceWeight := 1.0
|
sourceWeight := 1.0
|
||||||
switch strings.ToLower(sr.KnowledgeSource) {
|
switch strings.ToLower(sr.KnowledgeSource) {
|
||||||
case "web_search":
|
case "web_search":
|
||||||
@@ -193,27 +221,11 @@ func compositeScore(sr *types.SearchResult, modelScore, baseScore float64, chatM
|
|||||||
default:
|
default:
|
||||||
sourceWeight = 1.0
|
sourceWeight = 1.0
|
||||||
}
|
}
|
||||||
intentBoost := 1.0
|
|
||||||
if chatManage.QueryIntent != "" {
|
|
||||||
switch chatManage.QueryIntent {
|
|
||||||
case "definition":
|
|
||||||
if sr.ChunkType == string(types.ChunkTypeSummary) {
|
|
||||||
intentBoost = 1.05
|
|
||||||
}
|
|
||||||
case "howto":
|
|
||||||
if sr.EndAt-sr.StartAt > 300 {
|
|
||||||
intentBoost = 1.03
|
|
||||||
}
|
|
||||||
case "compare":
|
|
||||||
intentBoost = 1.0
|
|
||||||
}
|
|
||||||
}
|
|
||||||
positionPrior := 1.0
|
positionPrior := 1.0
|
||||||
if sr.StartAt >= 0 {
|
if sr.StartAt >= 0 {
|
||||||
positionPrior += clampFloat(1.0-float64(sr.StartAt)/float64(sr.EndAt+1), -0.05, 0.05)
|
positionPrior += searchutil.ClampFloat(1.0-float64(sr.StartAt)/float64(sr.EndAt+1), -0.05, 0.05)
|
||||||
}
|
}
|
||||||
composite := 0.6*modelScore + 0.3*baseScore + 0.1*sourceWeight
|
composite := 0.6*modelScore + 0.3*baseScore + 0.1*sourceWeight
|
||||||
composite *= intentBoost
|
|
||||||
composite *= positionPrior
|
composite *= positionPrior
|
||||||
if composite < 0 {
|
if composite < 0 {
|
||||||
composite = 0
|
composite = 0
|
||||||
@@ -224,7 +236,7 @@ func compositeScore(sr *types.SearchResult, modelScore, baseScore float64, chatM
|
|||||||
return composite
|
return composite
|
||||||
}
|
}
|
||||||
|
|
||||||
// applyMMR applies the MMR algorithm to the search results
|
// applyMMR applies the MMR algorithm to the search results with pre-computed token sets
|
||||||
func applyMMR(
|
func applyMMR(
|
||||||
ctx context.Context,
|
ctx context.Context,
|
||||||
results []*types.SearchResult,
|
results []*types.SearchResult,
|
||||||
@@ -240,40 +252,60 @@ func applyMMR(
|
|||||||
"k": k,
|
"k": k,
|
||||||
"candidates": len(results),
|
"candidates": len(results),
|
||||||
})
|
})
|
||||||
selected := make([]*types.SearchResult, 0, k)
|
|
||||||
candidates := make([]*types.SearchResult, len(results))
|
// Pre-compute all token sets upfront (optimization)
|
||||||
copy(candidates, results)
|
allTokenSets := make([]map[string]struct{}, len(results))
|
||||||
tokenSets := make([]map[string]struct{}, len(candidates))
|
for i, r := range results {
|
||||||
for i, r := range candidates {
|
allTokenSets[i] = searchutil.TokenizeSimple(getEnrichedPassage(ctx, r))
|
||||||
tokenSets[i] = tokenizeSimple(getEnrichedPassage(ctx, r))
|
|
||||||
}
|
}
|
||||||
for len(selected) < k && len(candidates) > 0 {
|
|
||||||
bestIdx := 0
|
selected := make([]*types.SearchResult, 0, k)
|
||||||
|
selectedTokenSets := make([]map[string]struct{}, 0, k)
|
||||||
|
selectedIndices := make(map[int]struct{})
|
||||||
|
|
||||||
|
for len(selected) < k && len(selectedIndices) < len(results) {
|
||||||
|
bestIdx := -1
|
||||||
bestScore := -1.0
|
bestScore := -1.0
|
||||||
for i, r := range candidates {
|
|
||||||
|
for i, r := range results {
|
||||||
|
if _, isSelected := selectedIndices[i]; isSelected {
|
||||||
|
continue
|
||||||
|
}
|
||||||
|
|
||||||
relevance := r.Score
|
relevance := r.Score
|
||||||
redundancy := 0.0
|
redundancy := 0.0
|
||||||
for _, s := range selected {
|
|
||||||
redundancy = math.Max(redundancy, jaccard(tokenSets[i], tokenizeSimple(getEnrichedPassage(ctx, s))))
|
// Use pre-computed token sets for redundancy calculation
|
||||||
|
for _, selTokens := range selectedTokenSets {
|
||||||
|
sim := searchutil.Jaccard(allTokenSets[i], selTokens)
|
||||||
|
if sim > redundancy {
|
||||||
|
redundancy = sim
|
||||||
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
mmr := lambda*relevance - (1.0-lambda)*redundancy
|
mmr := lambda*relevance - (1.0-lambda)*redundancy
|
||||||
if mmr > bestScore {
|
if mmr > bestScore {
|
||||||
bestScore = mmr
|
bestScore = mmr
|
||||||
bestIdx = i
|
bestIdx = i
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
selected = append(selected, candidates[bestIdx])
|
|
||||||
candidates = append(candidates[:bestIdx], candidates[bestIdx+1:]...)
|
if bestIdx < 0 {
|
||||||
|
break
|
||||||
|
}
|
||||||
|
|
||||||
|
selected = append(selected, results[bestIdx])
|
||||||
|
selectedTokenSets = append(selectedTokenSets, allTokenSets[bestIdx])
|
||||||
|
selectedIndices[bestIdx] = struct{}{}
|
||||||
}
|
}
|
||||||
// Compute average redundancy among selected
|
|
||||||
|
// Compute average redundancy among selected using pre-computed token sets
|
||||||
avgRed := 0.0
|
avgRed := 0.0
|
||||||
if len(selected) > 1 {
|
if len(selected) > 1 {
|
||||||
pairs := 0
|
pairs := 0
|
||||||
for i := 0; i < len(selected); i++ {
|
for i := 0; i < len(selectedTokenSets); i++ {
|
||||||
for j := i + 1; j < len(selected); j++ {
|
for j := i + 1; j < len(selectedTokenSets); j++ {
|
||||||
si := tokenizeSimple(getEnrichedPassage(ctx, selected[i]))
|
avgRed += searchutil.Jaccard(selectedTokenSets[i], selectedTokenSets[j])
|
||||||
sj := tokenizeSimple(getEnrichedPassage(ctx, selected[j]))
|
|
||||||
avgRed += jaccard(si, sj)
|
|
||||||
pairs++
|
pairs++
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
@@ -288,93 +320,59 @@ func applyMMR(
|
|||||||
return selected
|
return selected
|
||||||
}
|
}
|
||||||
|
|
||||||
// tokenizeSimple tokenizes a text into a set of tokens
|
// getEnrichedPassage 合并Content、ImageInfo和GeneratedQuestions的文本内容
|
||||||
func tokenizeSimple(text string) map[string]struct{} {
|
|
||||||
text = strings.ToLower(text)
|
|
||||||
fields := strings.Fields(text)
|
|
||||||
set := make(map[string]struct{}, len(fields))
|
|
||||||
for _, f := range fields {
|
|
||||||
if len(f) > 1 {
|
|
||||||
set[f] = struct{}{}
|
|
||||||
}
|
|
||||||
}
|
|
||||||
return set
|
|
||||||
}
|
|
||||||
|
|
||||||
// jaccard calculates the Jaccard similarity between two sets of tokens
|
|
||||||
func jaccard(a, b map[string]struct{}) float64 {
|
|
||||||
if len(a) == 0 && len(b) == 0 {
|
|
||||||
return 0
|
|
||||||
}
|
|
||||||
inter := 0
|
|
||||||
for k := range a {
|
|
||||||
if _, ok := b[k]; ok {
|
|
||||||
inter++
|
|
||||||
}
|
|
||||||
}
|
|
||||||
union := len(a) + len(b) - inter
|
|
||||||
if union == 0 {
|
|
||||||
return 0
|
|
||||||
}
|
|
||||||
return float64(inter) / float64(union)
|
|
||||||
}
|
|
||||||
|
|
||||||
// clampFloat clamps a float value between a minimum and maximum value
|
|
||||||
func clampFloat(v, minV, maxV float64) float64 {
|
|
||||||
if v < minV {
|
|
||||||
return minV
|
|
||||||
}
|
|
||||||
if v > maxV {
|
|
||||||
return maxV
|
|
||||||
}
|
|
||||||
return v
|
|
||||||
}
|
|
||||||
|
|
||||||
// getEnrichedPassage 合并Content和ImageInfo的文本内容
|
|
||||||
func getEnrichedPassage(ctx context.Context, result *types.SearchResult) string {
|
func getEnrichedPassage(ctx context.Context, result *types.SearchResult) string {
|
||||||
if result.ImageInfo == "" {
|
combinedText := result.Content
|
||||||
return result.Content
|
var enrichments []string
|
||||||
}
|
|
||||||
|
|
||||||
// 解析ImageInfo
|
// 解析ImageInfo
|
||||||
var imageInfos []types.ImageInfo
|
if result.ImageInfo != "" {
|
||||||
err := json.Unmarshal([]byte(result.ImageInfo), &imageInfos)
|
var imageInfos []types.ImageInfo
|
||||||
if err != nil {
|
err := json.Unmarshal([]byte(result.ImageInfo), &imageInfos)
|
||||||
pipelineWarn(ctx, "Rerank", "image_info_parse", map[string]interface{}{
|
if err != nil {
|
||||||
"error": err.Error(),
|
pipelineWarn(ctx, "Rerank", "image_info_parse", map[string]interface{}{
|
||||||
})
|
"error": err.Error(),
|
||||||
return result.Content
|
})
|
||||||
}
|
} else {
|
||||||
|
// 提取所有图片的描述和OCR文本
|
||||||
if len(imageInfos) == 0 {
|
for _, img := range imageInfos {
|
||||||
return result.Content
|
if img.Caption != "" {
|
||||||
}
|
enrichments = append(enrichments, fmt.Sprintf("图片描述: %s", img.Caption))
|
||||||
|
}
|
||||||
// 提取所有图片的描述和OCR文本
|
if img.OCRText != "" {
|
||||||
var imageTexts []string
|
enrichments = append(enrichments, fmt.Sprintf("图片文本: %s", img.OCRText))
|
||||||
for _, img := range imageInfos {
|
}
|
||||||
if img.Caption != "" {
|
}
|
||||||
imageTexts = append(imageTexts, fmt.Sprintf("图片描述: %s", img.Caption))
|
|
||||||
}
|
|
||||||
if img.OCRText != "" {
|
|
||||||
imageTexts = append(imageTexts, fmt.Sprintf("图片文本: %s", img.OCRText))
|
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
if len(imageTexts) == 0 {
|
// 解析ChunkMetadata中的GeneratedQuestions
|
||||||
return result.Content
|
if len(result.ChunkMetadata) > 0 {
|
||||||
|
var docMeta types.DocumentChunkMetadata
|
||||||
|
err := json.Unmarshal(result.ChunkMetadata, &docMeta)
|
||||||
|
if err != nil {
|
||||||
|
pipelineWarn(ctx, "Rerank", "chunk_metadata_parse", map[string]interface{}{
|
||||||
|
"error": err.Error(),
|
||||||
|
})
|
||||||
|
} else if len(docMeta.GeneratedQuestions) > 0 {
|
||||||
|
enrichments = append(enrichments, fmt.Sprintf("相关问题: %s", strings.Join(docMeta.GeneratedQuestions, "; ")))
|
||||||
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
// 组合内容和图片信息
|
if len(enrichments) == 0 {
|
||||||
combinedText := result.Content
|
return combinedText
|
||||||
|
}
|
||||||
|
|
||||||
|
// 组合内容和增强信息
|
||||||
if combinedText != "" {
|
if combinedText != "" {
|
||||||
combinedText += "\n\n"
|
combinedText += "\n\n"
|
||||||
}
|
}
|
||||||
combinedText += strings.Join(imageTexts, "\n")
|
combinedText += strings.Join(enrichments, "\n")
|
||||||
|
|
||||||
pipelineInfo(ctx, "Rerank", "image_info_merge", map[string]interface{}{
|
pipelineInfo(ctx, "Rerank", "passage_enrich", map[string]interface{}{
|
||||||
"content_len": len(result.Content),
|
"content_len": len(result.Content),
|
||||||
"image_len": len(strings.Join(imageTexts, "\n")),
|
"enrichment": strings.Join(enrichments, "\n"),
|
||||||
|
"enrichment_len": len(strings.Join(enrichments, "\n")),
|
||||||
})
|
})
|
||||||
|
|
||||||
return combinedText
|
return combinedText
|
||||||
|
|||||||
@@ -2,13 +2,14 @@ package chatpipline
|
|||||||
|
|
||||||
import (
|
import (
|
||||||
"context"
|
"context"
|
||||||
"encoding/json"
|
|
||||||
"fmt"
|
"fmt"
|
||||||
|
"regexp"
|
||||||
"strings"
|
"strings"
|
||||||
"sync"
|
"sync"
|
||||||
|
"unicode"
|
||||||
|
|
||||||
"github.com/Tencent/WeKnora/internal/config"
|
"github.com/Tencent/WeKnora/internal/config"
|
||||||
"github.com/Tencent/WeKnora/internal/models/chat"
|
"github.com/Tencent/WeKnora/internal/logger"
|
||||||
"github.com/Tencent/WeKnora/internal/searchutil"
|
"github.com/Tencent/WeKnora/internal/searchutil"
|
||||||
"github.com/Tencent/WeKnora/internal/types"
|
"github.com/Tencent/WeKnora/internal/types"
|
||||||
"github.com/Tencent/WeKnora/internal/types/interfaces"
|
"github.com/Tencent/WeKnora/internal/types/interfaces"
|
||||||
@@ -18,7 +19,6 @@ import (
|
|||||||
type PluginSearch struct {
|
type PluginSearch struct {
|
||||||
knowledgeBaseService interfaces.KnowledgeBaseService
|
knowledgeBaseService interfaces.KnowledgeBaseService
|
||||||
knowledgeService interfaces.KnowledgeService
|
knowledgeService interfaces.KnowledgeService
|
||||||
modelService interfaces.ModelService
|
|
||||||
config *config.Config
|
config *config.Config
|
||||||
webSearchService interfaces.WebSearchService
|
webSearchService interfaces.WebSearchService
|
||||||
tenantService interfaces.TenantService
|
tenantService interfaces.TenantService
|
||||||
@@ -28,7 +28,6 @@ type PluginSearch struct {
|
|||||||
func NewPluginSearch(eventManager *EventManager,
|
func NewPluginSearch(eventManager *EventManager,
|
||||||
knowledgeBaseService interfaces.KnowledgeBaseService,
|
knowledgeBaseService interfaces.KnowledgeBaseService,
|
||||||
knowledgeService interfaces.KnowledgeService,
|
knowledgeService interfaces.KnowledgeService,
|
||||||
modelService interfaces.ModelService,
|
|
||||||
config *config.Config,
|
config *config.Config,
|
||||||
webSearchService interfaces.WebSearchService,
|
webSearchService interfaces.WebSearchService,
|
||||||
tenantService interfaces.TenantService,
|
tenantService interfaces.TenantService,
|
||||||
@@ -37,7 +36,6 @@ func NewPluginSearch(eventManager *EventManager,
|
|||||||
res := &PluginSearch{
|
res := &PluginSearch{
|
||||||
knowledgeBaseService: knowledgeBaseService,
|
knowledgeBaseService: knowledgeBaseService,
|
||||||
knowledgeService: knowledgeService,
|
knowledgeService: knowledgeService,
|
||||||
modelService: modelService,
|
|
||||||
config: config,
|
config: config,
|
||||||
webSearchService: webSearchService,
|
webSearchService: webSearchService,
|
||||||
tenantService: tenantService,
|
tenantService: tenantService,
|
||||||
@@ -75,12 +73,11 @@ func (p *PluginSearch) OnEvent(ctx context.Context,
|
|||||||
}
|
}
|
||||||
|
|
||||||
pipelineInfo(ctx, "Search", "input", map[string]interface{}{
|
pipelineInfo(ctx, "Search", "input", map[string]interface{}{
|
||||||
"session_id": chatManage.SessionID,
|
"session_id": chatManage.SessionID,
|
||||||
"rewrite_query": chatManage.RewriteQuery,
|
"rewrite_query": chatManage.RewriteQuery,
|
||||||
"processed_query": chatManage.ProcessedQuery,
|
"kb_ids": strings.Join(knowledgeBaseIDs, ","),
|
||||||
"kb_ids": strings.Join(knowledgeBaseIDs, ","),
|
"tenant_id": chatManage.TenantID,
|
||||||
"tenant_id": chatManage.TenantID,
|
"web_enabled": chatManage.WebSearchEnabled,
|
||||||
"web_enabled": chatManage.WebSearchEnabled,
|
|
||||||
})
|
})
|
||||||
|
|
||||||
// Run KB search and web search concurrently
|
// Run KB search and web search concurrently
|
||||||
@@ -121,6 +118,16 @@ func (p *PluginSearch) OnEvent(ctx context.Context,
|
|||||||
|
|
||||||
chatManage.SearchResult = allResults
|
chatManage.SearchResult = allResults
|
||||||
|
|
||||||
|
// Log all search results with scores before any processing
|
||||||
|
for i, r := range chatManage.SearchResult {
|
||||||
|
pipelineInfo(ctx, "Search", "result_score_before_normalize", map[string]interface{}{
|
||||||
|
"index": i,
|
||||||
|
"chunk_id": r.ID,
|
||||||
|
"score": fmt.Sprintf("%.4f", r.Score),
|
||||||
|
"match_type": r.MatchType,
|
||||||
|
})
|
||||||
|
}
|
||||||
|
|
||||||
// If recall is low, attempt query expansion with keyword-focused search
|
// If recall is low, attempt query expansion with keyword-focused search
|
||||||
if chatManage.EnableQueryExpansion && len(chatManage.SearchResult) < max(1, chatManage.EmbeddingTopK/2) {
|
if chatManage.EnableQueryExpansion && len(chatManage.SearchResult) < max(1, chatManage.EmbeddingTopK/2) {
|
||||||
pipelineInfo(ctx, "Search", "recall_low", map[string]interface{}{
|
pipelineInfo(ctx, "Search", "recall_low", map[string]interface{}{
|
||||||
@@ -189,6 +196,7 @@ func (p *PluginSearch) OnEvent(ctx context.Context,
|
|||||||
}
|
}
|
||||||
wgExp.Wait()
|
wgExp.Wait()
|
||||||
if len(expResults) > 0 {
|
if len(expResults) > 0 {
|
||||||
|
// Scores already normalized in HybridSearch
|
||||||
pipelineInfo(ctx, "Search", "expansion_done", map[string]interface{}{
|
pipelineInfo(ctx, "Search", "expansion_done", map[string]interface{}{
|
||||||
"added": len(expResults),
|
"added": len(expResults),
|
||||||
})
|
})
|
||||||
@@ -215,6 +223,16 @@ func (p *PluginSearch) OnEvent(ctx context.Context,
|
|||||||
"after": len(chatManage.SearchResult),
|
"after": len(chatManage.SearchResult),
|
||||||
})
|
})
|
||||||
|
|
||||||
|
// Log final scores after all processing
|
||||||
|
for i, r := range chatManage.SearchResult {
|
||||||
|
pipelineInfo(ctx, "Search", "final_score", map[string]interface{}{
|
||||||
|
"index": i,
|
||||||
|
"chunk_id": r.ID,
|
||||||
|
"score": fmt.Sprintf("%.4f", r.Score),
|
||||||
|
"match_type": r.MatchType,
|
||||||
|
})
|
||||||
|
}
|
||||||
|
|
||||||
// Return if we have results
|
// Return if we have results
|
||||||
if len(chatManage.SearchResult) != 0 {
|
if len(chatManage.SearchResult) != 0 {
|
||||||
pipelineInfo(ctx, "Search", "output", map[string]interface{}{
|
pipelineInfo(ctx, "Search", "output", map[string]interface{}{
|
||||||
@@ -250,32 +268,33 @@ func (p *PluginSearch) getSearchResultFromHistory(chatManage *types.ChatManage)
|
|||||||
|
|
||||||
func removeDuplicateResults(results []*types.SearchResult) []*types.SearchResult {
|
func removeDuplicateResults(results []*types.SearchResult) []*types.SearchResult {
|
||||||
seen := make(map[string]bool)
|
seen := make(map[string]bool)
|
||||||
contentSig := make(map[string]bool)
|
contentSig := make(map[string]string) // sig -> first chunk ID
|
||||||
var uniqueResults []*types.SearchResult
|
var uniqueResults []*types.SearchResult
|
||||||
for _, r := range results {
|
for _, r := range results {
|
||||||
keys := []string{r.ID}
|
keys := []string{r.ID}
|
||||||
if r.ParentChunkID != "" {
|
if r.ParentChunkID != "" {
|
||||||
keys = append(keys, "parent:"+r.ParentChunkID)
|
keys = append(keys, "parent:"+r.ParentChunkID)
|
||||||
}
|
}
|
||||||
if r.KnowledgeID != "" {
|
|
||||||
keys = append(keys, fmt.Sprintf("kb:%s#%d", r.KnowledgeID, r.ChunkIndex))
|
|
||||||
}
|
|
||||||
dup := false
|
dup := false
|
||||||
|
dupKey := ""
|
||||||
for _, k := range keys {
|
for _, k := range keys {
|
||||||
if seen[k] {
|
if seen[k] {
|
||||||
dup = true
|
dup = true
|
||||||
|
dupKey = k
|
||||||
break
|
break
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
if dup {
|
if dup {
|
||||||
|
logger.Debugf(context.Background(), "Dedup: chunk %s removed due to key: %s", r.ID, dupKey)
|
||||||
continue
|
continue
|
||||||
}
|
}
|
||||||
sig := buildContentSignature(r.Content)
|
sig := buildContentSignature(r.Content)
|
||||||
if sig != "" {
|
if sig != "" {
|
||||||
if contentSig[sig] {
|
if firstChunk, exists := contentSig[sig]; exists {
|
||||||
|
logger.Debugf(context.Background(), "Dedup: chunk %s removed due to content signature (dup of %s, sig prefix: %.50s...)", r.ID, firstChunk, sig)
|
||||||
continue
|
continue
|
||||||
}
|
}
|
||||||
contentSig[sig] = true
|
contentSig[sig] = r.ID
|
||||||
}
|
}
|
||||||
for _, k := range keys {
|
for _, k := range keys {
|
||||||
seen[k] = true
|
seen[k] = true
|
||||||
@@ -286,24 +305,16 @@ func removeDuplicateResults(results []*types.SearchResult) []*types.SearchResult
|
|||||||
}
|
}
|
||||||
|
|
||||||
func buildContentSignature(content string) string {
|
func buildContentSignature(content string) string {
|
||||||
c := strings.ToLower(strings.TrimSpace(content))
|
return searchutil.BuildContentSignature(content)
|
||||||
if c == "" {
|
|
||||||
return ""
|
|
||||||
}
|
|
||||||
c = strings.Join(strings.Fields(c), " ")
|
|
||||||
if len(c) > 128 {
|
|
||||||
c = c[:128]
|
|
||||||
}
|
|
||||||
return c
|
|
||||||
}
|
}
|
||||||
|
|
||||||
// searchKnowledgeBases performs KB searches for rewrite and processed queries across KB IDs
|
// searchKnowledgeBases performs KB searches across KB IDs using RewriteQuery only
|
||||||
func (p *PluginSearch) searchKnowledgeBases(
|
func (p *PluginSearch) searchKnowledgeBases(
|
||||||
ctx context.Context,
|
ctx context.Context,
|
||||||
knowledgeBaseIDs []string,
|
knowledgeBaseIDs []string,
|
||||||
chatManage *types.ChatManage,
|
chatManage *types.ChatManage,
|
||||||
) []*types.SearchResult {
|
) []*types.SearchResult {
|
||||||
// Build base params for rewrite query
|
// Build params for rewrite query
|
||||||
baseParams := types.SearchParams{
|
baseParams := types.SearchParams{
|
||||||
QueryText: strings.TrimSpace(chatManage.RewriteQuery),
|
QueryText: strings.TrimSpace(chatManage.RewriteQuery),
|
||||||
VectorThreshold: chatManage.VectorThreshold,
|
VectorThreshold: chatManage.VectorThreshold,
|
||||||
@@ -315,7 +326,7 @@ func (p *PluginSearch) searchKnowledgeBases(
|
|||||||
var mu sync.Mutex
|
var mu sync.Mutex
|
||||||
var results []*types.SearchResult
|
var results []*types.SearchResult
|
||||||
|
|
||||||
// Search with rewrite query
|
// Search with rewrite query only (removed duplicate ProcessedQuery search)
|
||||||
for _, kbID := range knowledgeBaseIDs {
|
for _, kbID := range knowledgeBaseIDs {
|
||||||
wg.Add(1)
|
wg.Add(1)
|
||||||
go func(knowledgeBaseID string) {
|
go func(knowledgeBaseID string) {
|
||||||
@@ -323,16 +334,14 @@ func (p *PluginSearch) searchKnowledgeBases(
|
|||||||
res, err := p.knowledgeBaseService.HybridSearch(ctx, knowledgeBaseID, baseParams)
|
res, err := p.knowledgeBaseService.HybridSearch(ctx, knowledgeBaseID, baseParams)
|
||||||
if err != nil {
|
if err != nil {
|
||||||
pipelineWarn(ctx, "Search", "kb_search_error", map[string]interface{}{
|
pipelineWarn(ctx, "Search", "kb_search_error", map[string]interface{}{
|
||||||
"kb_id": knowledgeBaseID,
|
"kb_id": knowledgeBaseID,
|
||||||
"query": baseParams.QueryText,
|
"query": baseParams.QueryText,
|
||||||
"error": err.Error(),
|
"error": err.Error(),
|
||||||
"query_ty": "rewrite",
|
|
||||||
})
|
})
|
||||||
return
|
return
|
||||||
}
|
}
|
||||||
pipelineInfo(ctx, "Search", "kb_result", map[string]interface{}{
|
pipelineInfo(ctx, "Search", "kb_result", map[string]interface{}{
|
||||||
"kb_id": knowledgeBaseID,
|
"kb_id": knowledgeBaseID,
|
||||||
"query_ty": "rewrite",
|
|
||||||
"hit_count": len(res),
|
"hit_count": len(res),
|
||||||
})
|
})
|
||||||
mu.Lock()
|
mu.Lock()
|
||||||
@@ -343,45 +352,6 @@ func (p *PluginSearch) searchKnowledgeBases(
|
|||||||
|
|
||||||
wg.Wait()
|
wg.Wait()
|
||||||
|
|
||||||
// If processed query differs, search again
|
|
||||||
if chatManage.RewriteQuery != chatManage.ProcessedQuery {
|
|
||||||
paramsProcessed := baseParams
|
|
||||||
paramsProcessed.QueryText = strings.TrimSpace(chatManage.ProcessedQuery)
|
|
||||||
pipelineInfo(ctx, "Search", "processed_query_search", map[string]interface{}{
|
|
||||||
"query": paramsProcessed.QueryText,
|
|
||||||
})
|
|
||||||
|
|
||||||
wg = sync.WaitGroup{}
|
|
||||||
for _, kbID := range knowledgeBaseIDs {
|
|
||||||
wg.Add(1)
|
|
||||||
go func(knowledgeBaseID string) {
|
|
||||||
defer wg.Done()
|
|
||||||
res, err := p.knowledgeBaseService.HybridSearch(ctx, knowledgeBaseID, paramsProcessed)
|
|
||||||
if err != nil {
|
|
||||||
pipelineWarn(ctx, "Search", "kb_search_error", map[string]interface{}{
|
|
||||||
"kb_id": knowledgeBaseID,
|
|
||||||
"query": paramsProcessed.QueryText,
|
|
||||||
"error": err.Error(),
|
|
||||||
"query_ty": "processed",
|
|
||||||
})
|
|
||||||
return
|
|
||||||
}
|
|
||||||
pipelineInfo(ctx, "Search", "kb_result", map[string]interface{}{
|
|
||||||
"kb_id": knowledgeBaseID,
|
|
||||||
"query_ty": "processed",
|
|
||||||
"hit_count": len(res),
|
|
||||||
})
|
|
||||||
mu.Lock()
|
|
||||||
results = append(results, res...)
|
|
||||||
mu.Unlock()
|
|
||||||
}(kbID)
|
|
||||||
}
|
|
||||||
wg.Wait()
|
|
||||||
}
|
|
||||||
|
|
||||||
// Normalize keyword retriever scores after collecting all results from multiple knowledge bases
|
|
||||||
normalizeKeywordSearchResults(ctx, results)
|
|
||||||
|
|
||||||
pipelineInfo(ctx, "Search", "kb_result_summary", map[string]interface{}{
|
pipelineInfo(ctx, "Search", "kb_result_summary", map[string]interface{}{
|
||||||
"total_hits": len(results),
|
"total_hits": len(results),
|
||||||
})
|
})
|
||||||
@@ -413,11 +383,8 @@ func (p *PluginSearch) searchWebIfEnabled(ctx context.Context, chatManage *types
|
|||||||
})
|
})
|
||||||
return nil
|
return nil
|
||||||
}
|
}
|
||||||
// Build questions (rewrite + processed if different)
|
// Build questions using RewriteQuery only
|
||||||
questions := []string{strings.TrimSpace(chatManage.RewriteQuery)}
|
questions := []string{strings.TrimSpace(chatManage.RewriteQuery)}
|
||||||
if chatManage.ProcessedQuery != "" && chatManage.ProcessedQuery != chatManage.RewriteQuery {
|
|
||||||
questions = append(questions, strings.TrimSpace(chatManage.ProcessedQuery))
|
|
||||||
}
|
|
||||||
// Load session-scoped temp KB state from Redis using SessionService
|
// Load session-scoped temp KB state from Redis using SessionService
|
||||||
tempKBID, seen, ids := p.sessionService.GetWebSearchTempKBState(ctx, chatManage.SessionID)
|
tempKBID, seen, ids := p.sessionService.GetWebSearchTempKBState(ctx, chatManage.SessionID)
|
||||||
compressed, kbID, newSeen, newIDs, err := p.webSearchService.CompressWithRAG(
|
compressed, kbID, newSeen, newIDs, err := p.webSearchService.CompressWithRAG(
|
||||||
@@ -440,119 +407,155 @@ func (p *PluginSearch) searchWebIfEnabled(ctx context.Context, chatManage *types
|
|||||||
return res
|
return res
|
||||||
}
|
}
|
||||||
|
|
||||||
// expandQueries generates paraphrases and synonyms using chat model to improve keyword recall
|
// expandQueries generates query variants locally without LLM to improve keyword recall
|
||||||
|
// Uses simple techniques: word reordering, stopword removal, key phrase extraction
|
||||||
func (p *PluginSearch) expandQueries(ctx context.Context, chatManage *types.ChatManage) []string {
|
func (p *PluginSearch) expandQueries(ctx context.Context, chatManage *types.ChatManage) []string {
|
||||||
if p.modelService == nil || chatManage.ChatModelID == "" {
|
query := strings.TrimSpace(chatManage.RewriteQuery)
|
||||||
pipelineWarn(ctx, "Search", "expansion_skip", map[string]interface{}{
|
if query == "" {
|
||||||
"reason": "no_model",
|
|
||||||
})
|
|
||||||
return nil
|
return nil
|
||||||
}
|
}
|
||||||
model, err := p.modelService.GetChatModel(ctx, chatManage.ChatModelID)
|
|
||||||
if err != nil {
|
expansions := make([]string, 0, 5)
|
||||||
pipelineWarn(ctx, "Search", "expansion_get_model_failed", map[string]interface{}{
|
seen := make(map[string]struct{})
|
||||||
"error": err.Error(),
|
seen[strings.ToLower(query)] = struct{}{}
|
||||||
})
|
if q := strings.ToLower(chatManage.Query); q != "" {
|
||||||
return nil
|
seen[q] = struct{}{}
|
||||||
}
|
}
|
||||||
sys := "Generate up to 5 diverse paraphrases or keyword variants for the user query to improve keyword-based search recall. Respond ONLY with a JSON array of strings inside a fenced code block."
|
|
||||||
usr := chatManage.RewriteQuery
|
addIfNew := func(s string) {
|
||||||
think := false
|
s = strings.TrimSpace(s)
|
||||||
resp, err := model.Chat(ctx, []chat.Message{
|
if s == "" || len(s) < 3 {
|
||||||
{Role: "system", Content: sys},
|
return
|
||||||
{Role: "user", Content: usr},
|
|
||||||
}, &chat.ChatOptions{Temperature: 0.2, MaxCompletionTokens: 200, Thinking: &think})
|
|
||||||
if err != nil || resp.Content == "" {
|
|
||||||
pipelineWarn(ctx, "Search", "expansion_model_call_failed", map[string]interface{}{
|
|
||||||
"error": err,
|
|
||||||
})
|
|
||||||
return nil
|
|
||||||
}
|
|
||||||
body := extractJSONBlock(resp.Content)
|
|
||||||
var arr []string
|
|
||||||
if err := json.Unmarshal([]byte(body), &arr); err != nil || len(arr) == 0 {
|
|
||||||
// Fallback: split lines
|
|
||||||
lines := strings.Split(resp.Content, "\n")
|
|
||||||
for _, l := range lines {
|
|
||||||
l = strings.TrimSpace(l)
|
|
||||||
if l != "" {
|
|
||||||
arr = append(arr, l)
|
|
||||||
}
|
|
||||||
}
|
|
||||||
}
|
|
||||||
uniq := make(map[string]struct{})
|
|
||||||
base := []string{chatManage.Query, chatManage.RewriteQuery, chatManage.ProcessedQuery}
|
|
||||||
for _, b := range base {
|
|
||||||
if s := strings.TrimSpace(b); s != "" {
|
|
||||||
uniq[strings.ToLower(s)] = struct{}{}
|
|
||||||
}
|
|
||||||
}
|
|
||||||
expansions := make([]string, 0, len(arr))
|
|
||||||
for _, a := range arr {
|
|
||||||
s := strings.TrimSpace(a)
|
|
||||||
if s == "" {
|
|
||||||
continue
|
|
||||||
}
|
}
|
||||||
key := strings.ToLower(s)
|
key := strings.ToLower(s)
|
||||||
if _, ok := uniq[key]; ok {
|
if _, ok := seen[key]; ok {
|
||||||
continue
|
return
|
||||||
}
|
}
|
||||||
uniq[key] = struct{}{}
|
seen[key] = struct{}{}
|
||||||
expansions = append(expansions, s)
|
expansions = append(expansions, s)
|
||||||
if len(expansions) >= 5 {
|
}
|
||||||
break
|
|
||||||
|
// 1. Remove common stopwords and create keyword-only variant
|
||||||
|
keywords := extractKeywords(query)
|
||||||
|
if len(keywords) >= 2 {
|
||||||
|
addIfNew(strings.Join(keywords, " "))
|
||||||
|
}
|
||||||
|
|
||||||
|
// 2. Extract quoted phrases or key segments
|
||||||
|
phrases := extractPhrases(query)
|
||||||
|
for _, phrase := range phrases {
|
||||||
|
addIfNew(phrase)
|
||||||
|
}
|
||||||
|
|
||||||
|
// 3. Split by common delimiters and use longest segment
|
||||||
|
segments := splitByDelimiters(query)
|
||||||
|
for _, seg := range segments {
|
||||||
|
if len(seg) > 5 {
|
||||||
|
addIfNew(seg)
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
pipelineInfo(ctx, "Search", "expansion_result", map[string]interface{}{
|
|
||||||
|
// 4. Remove question words (什么/如何/怎么/为什么/哪个 etc.)
|
||||||
|
cleaned := removeQuestionWords(query)
|
||||||
|
if cleaned != query {
|
||||||
|
addIfNew(cleaned)
|
||||||
|
}
|
||||||
|
|
||||||
|
// Limit to 5 expansions
|
||||||
|
if len(expansions) > 5 {
|
||||||
|
expansions = expansions[:5]
|
||||||
|
}
|
||||||
|
|
||||||
|
pipelineInfo(ctx, "Search", "local_expansion_result", map[string]interface{}{
|
||||||
"variants": len(expansions),
|
"variants": len(expansions),
|
||||||
})
|
})
|
||||||
return expansions
|
return expansions
|
||||||
}
|
}
|
||||||
|
|
||||||
func extractJSONBlock(text string) string {
|
// Common Chinese and English stopwords
|
||||||
t := strings.TrimSpace(text)
|
var stopwords = map[string]struct{}{
|
||||||
if i := strings.Index(t, "["); i >= 0 {
|
"的": {}, "是": {}, "在": {}, "了": {}, "和": {}, "与": {}, "或": {},
|
||||||
j := strings.LastIndex(t, "]")
|
"a": {}, "an": {}, "the": {}, "is": {}, "are": {}, "was": {}, "were": {},
|
||||||
if j > i {
|
"be": {}, "been": {}, "being": {}, "have": {}, "has": {}, "had": {},
|
||||||
return t[i : j+1]
|
"do": {}, "does": {}, "did": {}, "will": {}, "would": {}, "could": {},
|
||||||
}
|
"should": {}, "may": {}, "might": {}, "must": {}, "can": {},
|
||||||
}
|
"to": {}, "of": {}, "in": {}, "for": {}, "on": {}, "with": {}, "at": {},
|
||||||
return "[]"
|
"by": {}, "from": {}, "as": {}, "into": {}, "through": {}, "about": {},
|
||||||
|
"what": {}, "how": {}, "why": {}, "when": {}, "where": {}, "which": {},
|
||||||
|
"who": {}, "whom": {}, "whose": {},
|
||||||
}
|
}
|
||||||
|
|
||||||
// normalizeKeywordSearchResults normalizes keyword search result scores into [0,1] globally across all knowledge bases
|
// Question words in Chinese
|
||||||
// Improvements:
|
var questionWords = regexp.MustCompile(`^(什么是|什么|如何|怎么|怎样|为什么|为何|哪个|哪些|谁|何时|何地|请问|请告诉我|帮我|我想知道|我想了解)`)
|
||||||
// 1. Uses robust normalization with percentile-based bounds to handle outliers
|
|
||||||
// 2. Handles edge cases: single result, no variance, negative scores
|
func extractKeywords(text string) []string {
|
||||||
// 3. Global normalization ensures fair comparison across different knowledge bases
|
words := tokenize(text)
|
||||||
func normalizeKeywordSearchResults(ctx context.Context, results []*types.SearchResult) {
|
keywords := make([]string, 0, len(words))
|
||||||
searchutil.NormalizeKeywordScores[*types.SearchResult](
|
for _, w := range words {
|
||||||
results,
|
lower := strings.ToLower(w)
|
||||||
func(r *types.SearchResult) bool {
|
if _, isStop := stopwords[lower]; !isStop && len(w) > 1 {
|
||||||
return r.MatchType == types.MatchTypeKeywords
|
keywords = append(keywords, w)
|
||||||
},
|
}
|
||||||
func(r *types.SearchResult) float64 {
|
}
|
||||||
return r.Score
|
return keywords
|
||||||
},
|
}
|
||||||
func(r *types.SearchResult, score float64) {
|
|
||||||
r.Score = score
|
func extractPhrases(text string) []string {
|
||||||
},
|
// Extract quoted content
|
||||||
searchutil.KeywordScoreCallbacks{
|
var phrases []string
|
||||||
OnNoVariance: func(count int, score float64) {
|
re := regexp.MustCompile(`["'"'「」『』]([^"'"'「」『』]+)["'"'「」『』]`)
|
||||||
pipelineInfo(ctx, "Search", "keyword_scores_no_variance", map[string]interface{}{
|
matches := re.FindAllStringSubmatch(text, -1)
|
||||||
"count": count,
|
for _, m := range matches {
|
||||||
"score": score,
|
if len(m) > 1 && len(m[1]) > 2 {
|
||||||
})
|
phrases = append(phrases, m[1])
|
||||||
},
|
}
|
||||||
OnNormalized: func(count int, rawMin, rawMax, normalizeMin, normalizeMax float64) {
|
}
|
||||||
pipelineInfo(ctx, "Search", "normalize_keyword_scores", map[string]interface{}{
|
return phrases
|
||||||
"count": count,
|
}
|
||||||
"raw_min": rawMin,
|
|
||||||
"raw_max": rawMax,
|
func splitByDelimiters(text string) []string {
|
||||||
"normalize_min": normalizeMin,
|
// Split by common delimiters
|
||||||
"normalize_max": normalizeMax,
|
re := regexp.MustCompile(`[,,;;、。!?!?\s]+`)
|
||||||
})
|
parts := re.Split(text, -1)
|
||||||
},
|
var result []string
|
||||||
},
|
for _, p := range parts {
|
||||||
)
|
p = strings.TrimSpace(p)
|
||||||
|
if p != "" {
|
||||||
|
result = append(result, p)
|
||||||
|
}
|
||||||
|
}
|
||||||
|
return result
|
||||||
|
}
|
||||||
|
|
||||||
|
func removeQuestionWords(text string) string {
|
||||||
|
return strings.TrimSpace(questionWords.ReplaceAllString(text, ""))
|
||||||
|
}
|
||||||
|
|
||||||
|
func tokenize(text string) []string {
|
||||||
|
var tokens []string
|
||||||
|
var current strings.Builder
|
||||||
|
|
||||||
|
for _, r := range text {
|
||||||
|
if unicode.IsLetter(r) || unicode.IsDigit(r) {
|
||||||
|
current.WriteRune(r)
|
||||||
|
} else if unicode.Is(unicode.Han, r) {
|
||||||
|
// Flush current token
|
||||||
|
if current.Len() > 0 {
|
||||||
|
tokens = append(tokens, current.String())
|
||||||
|
current.Reset()
|
||||||
|
}
|
||||||
|
// Chinese character as single token
|
||||||
|
tokens = append(tokens, string(r))
|
||||||
|
} else {
|
||||||
|
// Delimiter
|
||||||
|
if current.Len() > 0 {
|
||||||
|
tokens = append(tokens, current.String())
|
||||||
|
current.Reset()
|
||||||
|
}
|
||||||
|
}
|
||||||
|
}
|
||||||
|
if current.Len() > 0 {
|
||||||
|
tokens = append(tokens, current.String())
|
||||||
|
}
|
||||||
|
return tokens
|
||||||
}
|
}
|
||||||
|
|||||||
@@ -190,5 +190,6 @@ func chunk2SearchResult(chunk *types.Chunk, knowledge *types.Knowledge) *types.S
|
|||||||
ImageInfo: chunk.ImageInfo,
|
ImageInfo: chunk.ImageInfo,
|
||||||
KnowledgeFilename: knowledge.FileName,
|
KnowledgeFilename: knowledge.FileName,
|
||||||
KnowledgeSource: knowledge.Source,
|
KnowledgeSource: knowledge.Source,
|
||||||
|
ChunkMetadata: chunk.Metadata,
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|||||||
181
internal/application/service/chat_pipline/search_parallel.go
Normal file
181
internal/application/service/chat_pipline/search_parallel.go
Normal file
@@ -0,0 +1,181 @@
|
|||||||
|
package chatpipline
|
||||||
|
|
||||||
|
import (
|
||||||
|
"context"
|
||||||
|
"sync"
|
||||||
|
|
||||||
|
"github.com/Tencent/WeKnora/internal/config"
|
||||||
|
"github.com/Tencent/WeKnora/internal/logger"
|
||||||
|
"github.com/Tencent/WeKnora/internal/types"
|
||||||
|
"github.com/Tencent/WeKnora/internal/types/interfaces"
|
||||||
|
)
|
||||||
|
|
||||||
|
// PluginSearchParallel implements parallel search functionality combining chunk search and entity search
|
||||||
|
type PluginSearchParallel struct {
|
||||||
|
// Chunk search dependencies
|
||||||
|
knowledgeBaseService interfaces.KnowledgeBaseService
|
||||||
|
knowledgeService interfaces.KnowledgeService
|
||||||
|
config *config.Config
|
||||||
|
webSearchService interfaces.WebSearchService
|
||||||
|
tenantService interfaces.TenantService
|
||||||
|
sessionService interfaces.SessionService
|
||||||
|
|
||||||
|
// Entity search dependencies
|
||||||
|
graphRepo interfaces.RetrieveGraphRepository
|
||||||
|
chunkRepo interfaces.ChunkRepository
|
||||||
|
knowledgeRepo interfaces.KnowledgeRepository
|
||||||
|
|
||||||
|
// Internal plugins
|
||||||
|
searchPlugin *PluginSearch
|
||||||
|
searchEntityPlugin *PluginSearchEntity
|
||||||
|
}
|
||||||
|
|
||||||
|
// NewPluginSearchParallel creates a new parallel search plugin
|
||||||
|
func NewPluginSearchParallel(
|
||||||
|
eventManager *EventManager,
|
||||||
|
knowledgeBaseService interfaces.KnowledgeBaseService,
|
||||||
|
knowledgeService interfaces.KnowledgeService,
|
||||||
|
config *config.Config,
|
||||||
|
webSearchService interfaces.WebSearchService,
|
||||||
|
tenantService interfaces.TenantService,
|
||||||
|
sessionService interfaces.SessionService,
|
||||||
|
graphRepository interfaces.RetrieveGraphRepository,
|
||||||
|
chunkRepository interfaces.ChunkRepository,
|
||||||
|
knowledgeRepository interfaces.KnowledgeRepository,
|
||||||
|
) *PluginSearchParallel {
|
||||||
|
// Create internal plugins without registering them
|
||||||
|
searchPlugin := &PluginSearch{
|
||||||
|
knowledgeBaseService: knowledgeBaseService,
|
||||||
|
knowledgeService: knowledgeService,
|
||||||
|
config: config,
|
||||||
|
webSearchService: webSearchService,
|
||||||
|
tenantService: tenantService,
|
||||||
|
sessionService: sessionService,
|
||||||
|
}
|
||||||
|
|
||||||
|
searchEntityPlugin := &PluginSearchEntity{
|
||||||
|
graphRepo: graphRepository,
|
||||||
|
chunkRepo: chunkRepository,
|
||||||
|
knowledgeRepo: knowledgeRepository,
|
||||||
|
}
|
||||||
|
|
||||||
|
res := &PluginSearchParallel{
|
||||||
|
knowledgeBaseService: knowledgeBaseService,
|
||||||
|
knowledgeService: knowledgeService,
|
||||||
|
config: config,
|
||||||
|
webSearchService: webSearchService,
|
||||||
|
tenantService: tenantService,
|
||||||
|
sessionService: sessionService,
|
||||||
|
graphRepo: graphRepository,
|
||||||
|
chunkRepo: chunkRepository,
|
||||||
|
knowledgeRepo: knowledgeRepository,
|
||||||
|
searchPlugin: searchPlugin,
|
||||||
|
searchEntityPlugin: searchEntityPlugin,
|
||||||
|
}
|
||||||
|
eventManager.Register(res)
|
||||||
|
return res
|
||||||
|
}
|
||||||
|
|
||||||
|
// ActivationEvents returns the event types this plugin handles
|
||||||
|
func (p *PluginSearchParallel) ActivationEvents() []types.EventType {
|
||||||
|
return []types.EventType{types.CHUNK_SEARCH_PARALLEL}
|
||||||
|
}
|
||||||
|
|
||||||
|
// OnEvent handles parallel search events - runs chunk search and entity search concurrently
|
||||||
|
func (p *PluginSearchParallel) OnEvent(ctx context.Context,
|
||||||
|
eventType types.EventType, chatManage *types.ChatManage, next func() *PluginError,
|
||||||
|
) *PluginError {
|
||||||
|
pipelineInfo(ctx, "SearchParallel", "start", map[string]interface{}{
|
||||||
|
"session_id": chatManage.SessionID,
|
||||||
|
"has_entities": len(chatManage.Entity) > 0,
|
||||||
|
"rewrite_query": chatManage.RewriteQuery,
|
||||||
|
})
|
||||||
|
|
||||||
|
var wg sync.WaitGroup
|
||||||
|
var mu sync.Mutex
|
||||||
|
var chunkSearchErr *PluginError
|
||||||
|
var entitySearchErr *PluginError
|
||||||
|
|
||||||
|
// Use separate ChatManage copies to avoid concurrent write conflicts
|
||||||
|
chunkChatManage := *chatManage
|
||||||
|
chunkChatManage.SearchResult = nil
|
||||||
|
|
||||||
|
entityChatManage := *chatManage
|
||||||
|
entityChatManage.SearchResult = nil
|
||||||
|
|
||||||
|
// Run chunk search and entity search in parallel
|
||||||
|
wg.Add(2)
|
||||||
|
|
||||||
|
// Goroutine 1: Chunk Search
|
||||||
|
go func() {
|
||||||
|
defer wg.Done()
|
||||||
|
err := p.searchPlugin.OnEvent(ctx, types.CHUNK_SEARCH, &chunkChatManage, func() *PluginError {
|
||||||
|
return nil
|
||||||
|
})
|
||||||
|
if err != nil && err != ErrSearchNothing {
|
||||||
|
mu.Lock()
|
||||||
|
chunkSearchErr = err
|
||||||
|
mu.Unlock()
|
||||||
|
}
|
||||||
|
pipelineInfo(ctx, "SearchParallel", "chunk_search_done", map[string]interface{}{
|
||||||
|
"result_count": len(chunkChatManage.SearchResult),
|
||||||
|
"has_error": err != nil && err != ErrSearchNothing,
|
||||||
|
})
|
||||||
|
}()
|
||||||
|
|
||||||
|
// Goroutine 2: Entity Search (only if entities are available)
|
||||||
|
go func() {
|
||||||
|
defer wg.Done()
|
||||||
|
if len(chatManage.Entity) == 0 {
|
||||||
|
pipelineInfo(ctx, "SearchParallel", "entity_search_skip", map[string]interface{}{
|
||||||
|
"reason": "no_entities",
|
||||||
|
})
|
||||||
|
return
|
||||||
|
}
|
||||||
|
err := p.searchEntityPlugin.OnEvent(ctx, types.ENTITY_SEARCH, &entityChatManage, func() *PluginError {
|
||||||
|
return nil
|
||||||
|
})
|
||||||
|
if err != nil && err != ErrSearchNothing {
|
||||||
|
mu.Lock()
|
||||||
|
entitySearchErr = err
|
||||||
|
mu.Unlock()
|
||||||
|
}
|
||||||
|
pipelineInfo(ctx, "SearchParallel", "entity_search_done", map[string]interface{}{
|
||||||
|
"result_count": len(entityChatManage.SearchResult),
|
||||||
|
"has_error": err != nil && err != ErrSearchNothing,
|
||||||
|
})
|
||||||
|
}()
|
||||||
|
|
||||||
|
wg.Wait()
|
||||||
|
|
||||||
|
// Merge results from both searches (no concurrent access now)
|
||||||
|
chatManage.SearchResult = append(chunkChatManage.SearchResult, entityChatManage.SearchResult...)
|
||||||
|
chatManage.SearchResult = removeDuplicateResults(chatManage.SearchResult)
|
||||||
|
|
||||||
|
// Log any errors but don't fail the pipeline if at least one search succeeded
|
||||||
|
if chunkSearchErr != nil {
|
||||||
|
logger.Warnf(ctx, "[SearchParallel] Chunk search error: %v", chunkSearchErr.Err)
|
||||||
|
}
|
||||||
|
if entitySearchErr != nil {
|
||||||
|
logger.Warnf(ctx, "[SearchParallel] Entity search error: %v", entitySearchErr.Err)
|
||||||
|
}
|
||||||
|
|
||||||
|
pipelineInfo(ctx, "SearchParallel", "complete", map[string]interface{}{
|
||||||
|
"session_id": chatManage.SessionID,
|
||||||
|
"chunk_results": len(chunkChatManage.SearchResult),
|
||||||
|
"entity_results": len(entityChatManage.SearchResult),
|
||||||
|
"total_results": len(chatManage.SearchResult),
|
||||||
|
"chunk_search_error": chunkSearchErr != nil,
|
||||||
|
"entity_search_error": entitySearchErr != nil,
|
||||||
|
})
|
||||||
|
|
||||||
|
// Return error only if both searches failed and we have no results
|
||||||
|
if len(chatManage.SearchResult) == 0 {
|
||||||
|
if chunkSearchErr != nil {
|
||||||
|
return chunkSearchErr
|
||||||
|
}
|
||||||
|
return ErrSearchNothing
|
||||||
|
}
|
||||||
|
|
||||||
|
return next()
|
||||||
|
}
|
||||||
@@ -34,7 +34,7 @@ func (p *PluginTracing) ActivationEvents() []types.EventType {
|
|||||||
types.CHAT_COMPLETION_STREAM,
|
types.CHAT_COMPLETION_STREAM,
|
||||||
types.FILTER_TOP_K,
|
types.FILTER_TOP_K,
|
||||||
types.REWRITE_QUERY,
|
types.REWRITE_QUERY,
|
||||||
types.PREPROCESS_QUERY,
|
types.CHUNK_SEARCH_PARALLEL,
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
@@ -69,8 +69,8 @@ func (p *PluginTracing) OnEvent(ctx context.Context,
|
|||||||
return p.FilterTopK(ctx, eventType, chatManage, next)
|
return p.FilterTopK(ctx, eventType, chatManage, next)
|
||||||
case types.REWRITE_QUERY:
|
case types.REWRITE_QUERY:
|
||||||
return p.RewriteQuery(ctx, eventType, chatManage, next)
|
return p.RewriteQuery(ctx, eventType, chatManage, next)
|
||||||
case types.PREPROCESS_QUERY:
|
case types.CHUNK_SEARCH_PARALLEL:
|
||||||
return p.PreprocessQuery(ctx, eventType, chatManage, next)
|
return p.SearchParallel(ctx, eventType, chatManage, next)
|
||||||
}
|
}
|
||||||
return next()
|
return next()
|
||||||
}
|
}
|
||||||
@@ -95,7 +95,6 @@ func (p *PluginTracing) Search(ctx context.Context,
|
|||||||
}
|
}
|
||||||
span.SetAttributes(
|
span.SetAttributes(
|
||||||
attribute.String("hybrid_search", string(searchResultJson)),
|
attribute.String("hybrid_search", string(searchResultJson)),
|
||||||
attribute.String("processed_query", chatManage.ProcessedQuery),
|
|
||||||
attribute.Int("search_unique_count", len(unique)),
|
attribute.Int("search_unique_count", len(unique)),
|
||||||
)
|
)
|
||||||
return err
|
return err
|
||||||
@@ -119,7 +118,6 @@ func (p *PluginTracing) Rerank(ctx context.Context,
|
|||||||
span.SetAttributes(
|
span.SetAttributes(
|
||||||
attribute.Int("rerank_resp_count", len(chatManage.RerankResult)),
|
attribute.Int("rerank_resp_count", len(chatManage.RerankResult)),
|
||||||
attribute.String("rerank_resp_results", string(resultJson)),
|
attribute.String("rerank_resp_results", string(resultJson)),
|
||||||
attribute.String("query_intent", chatManage.QueryIntent),
|
|
||||||
)
|
)
|
||||||
return err
|
return err
|
||||||
}
|
}
|
||||||
@@ -266,22 +264,20 @@ func (p *PluginTracing) RewriteQuery(ctx context.Context,
|
|||||||
return err
|
return err
|
||||||
}
|
}
|
||||||
|
|
||||||
// PreprocessQuery traces query preprocessing operations
|
// SearchParallel traces parallel search operations (chunk + entity)
|
||||||
func (p *PluginTracing) PreprocessQuery(ctx context.Context,
|
func (p *PluginTracing) SearchParallel(ctx context.Context,
|
||||||
eventType types.EventType, chatManage *types.ChatManage, next func() *PluginError,
|
eventType types.EventType, chatManage *types.ChatManage, next func() *PluginError,
|
||||||
) *PluginError {
|
) *PluginError {
|
||||||
_, span := tracing.ContextWithSpan(ctx, "PluginTracing.PreprocessQuery")
|
_, span := tracing.ContextWithSpan(ctx, "PluginTracing.SearchParallel")
|
||||||
defer span.End()
|
defer span.End()
|
||||||
|
|
||||||
span.SetAttributes(
|
span.SetAttributes(
|
||||||
attribute.String("query", chatManage.Query),
|
attribute.String("query", chatManage.Query),
|
||||||
|
attribute.String("rewrite_query", chatManage.RewriteQuery),
|
||||||
|
attribute.Int("entity_count", len(chatManage.Entity)),
|
||||||
)
|
)
|
||||||
|
|
||||||
err := next()
|
err := next()
|
||||||
|
|
||||||
span.SetAttributes(
|
span.SetAttributes(
|
||||||
attribute.String("processed_query", chatManage.ProcessedQuery),
|
attribute.Int("search_result_count", len(chatManage.SearchResult)),
|
||||||
)
|
)
|
||||||
|
|
||||||
return err
|
return err
|
||||||
}
|
}
|
||||||
|
|||||||
@@ -9,7 +9,6 @@ import (
|
|||||||
"time"
|
"time"
|
||||||
|
|
||||||
"github.com/Tencent/WeKnora/internal/application/service/retriever"
|
"github.com/Tencent/WeKnora/internal/application/service/retriever"
|
||||||
"github.com/Tencent/WeKnora/internal/common"
|
|
||||||
"github.com/Tencent/WeKnora/internal/logger"
|
"github.com/Tencent/WeKnora/internal/logger"
|
||||||
"github.com/Tencent/WeKnora/internal/models/embedding"
|
"github.com/Tencent/WeKnora/internal/models/embedding"
|
||||||
"github.com/Tencent/WeKnora/internal/types"
|
"github.com/Tencent/WeKnora/internal/types"
|
||||||
@@ -523,35 +522,113 @@ func (s *knowledgeBaseService) HybridSearch(ctx context.Context,
|
|||||||
|
|
||||||
// Collect all results from different retrievers and deduplicate by chunk ID
|
// Collect all results from different retrievers and deduplicate by chunk ID
|
||||||
logger.Infof(ctx, "Processing retrieval results")
|
logger.Infof(ctx, "Processing retrieval results")
|
||||||
matchResults := []*types.IndexWithScore{}
|
|
||||||
|
// Separate results by retriever type for RRF fusion
|
||||||
|
var vectorResults []*types.IndexWithScore
|
||||||
|
var keywordResults []*types.IndexWithScore
|
||||||
for _, retrieveResult := range retrieveResults {
|
for _, retrieveResult := range retrieveResults {
|
||||||
logger.Infof(ctx, "Retrieval results, engine: %v, retriever: %v, count: %v",
|
logger.Infof(ctx, "Retrieval results, engine: %v, retriever: %v, count: %v",
|
||||||
retrieveResult.RetrieverEngineType,
|
retrieveResult.RetrieverEngineType,
|
||||||
retrieveResult.RetrieverType,
|
retrieveResult.RetrieverType,
|
||||||
len(retrieveResult.Results),
|
len(retrieveResult.Results),
|
||||||
)
|
)
|
||||||
matchResults = append(matchResults, retrieveResult.Results...)
|
if retrieveResult.RetrieverType == types.VectorRetrieverType {
|
||||||
|
vectorResults = append(vectorResults, retrieveResult.Results...)
|
||||||
|
} else {
|
||||||
|
keywordResults = append(keywordResults, retrieveResult.Results...)
|
||||||
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
// Early return if no results
|
// Early return if no results
|
||||||
if len(matchResults) == 0 {
|
if len(vectorResults) == 0 && len(keywordResults) == 0 {
|
||||||
logger.Info(ctx, "No search results found")
|
logger.Info(ctx, "No search results found")
|
||||||
return nil, nil
|
return nil, nil
|
||||||
}
|
}
|
||||||
logger.Infof(ctx, "Result count before deduplication: %d", len(matchResults))
|
logger.Infof(ctx, "Result count before RRF fusion: vector=%d, keyword=%d", len(vectorResults), len(keywordResults))
|
||||||
|
|
||||||
// First, try standard deduplication
|
// Use RRF (Reciprocal Rank Fusion) to merge results
|
||||||
deduplicatedChunks := common.DeduplicateWithScore(
|
// RRF score = sum(1 / (k + rank)) for each retriever where the chunk appears
|
||||||
func(r *types.IndexWithScore) string { return r.ChunkID },
|
// k=60 is a common choice that works well in practice
|
||||||
matchResults...)
|
const rrfK = 60
|
||||||
logger.Infof(ctx, "Result count after deduplication: %d", len(deduplicatedChunks))
|
|
||||||
|
// Build rank maps for each retriever (already sorted by score from retriever)
|
||||||
|
vectorRanks := make(map[string]int)
|
||||||
|
for i, r := range vectorResults {
|
||||||
|
if _, exists := vectorRanks[r.ChunkID]; !exists {
|
||||||
|
vectorRanks[r.ChunkID] = i + 1 // 1-indexed rank
|
||||||
|
}
|
||||||
|
}
|
||||||
|
keywordRanks := make(map[string]int)
|
||||||
|
for i, r := range keywordResults {
|
||||||
|
if _, exists := keywordRanks[r.ChunkID]; !exists {
|
||||||
|
keywordRanks[r.ChunkID] = i + 1 // 1-indexed rank
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
// Collect all unique chunks and compute RRF scores
|
||||||
|
chunkInfoMap := make(map[string]*types.IndexWithScore)
|
||||||
|
rrfScores := make(map[string]float64)
|
||||||
|
|
||||||
|
// Process vector results
|
||||||
|
for _, r := range vectorResults {
|
||||||
|
if _, exists := chunkInfoMap[r.ChunkID]; !exists {
|
||||||
|
chunkInfoMap[r.ChunkID] = r
|
||||||
|
}
|
||||||
|
}
|
||||||
|
// Process keyword results
|
||||||
|
for _, r := range keywordResults {
|
||||||
|
if _, exists := chunkInfoMap[r.ChunkID]; !exists {
|
||||||
|
chunkInfoMap[r.ChunkID] = r
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
// Compute RRF scores
|
||||||
|
for chunkID := range chunkInfoMap {
|
||||||
|
rrfScore := 0.0
|
||||||
|
if rank, ok := vectorRanks[chunkID]; ok {
|
||||||
|
rrfScore += 1.0 / float64(rrfK+rank)
|
||||||
|
}
|
||||||
|
if rank, ok := keywordRanks[chunkID]; ok {
|
||||||
|
rrfScore += 1.0 / float64(rrfK+rank)
|
||||||
|
}
|
||||||
|
rrfScores[chunkID] = rrfScore
|
||||||
|
}
|
||||||
|
|
||||||
|
// Convert to slice and sort by RRF score
|
||||||
|
deduplicatedChunks := make([]*types.IndexWithScore, 0, len(chunkInfoMap))
|
||||||
|
for chunkID, info := range chunkInfoMap {
|
||||||
|
// Store RRF score in the Score field for downstream processing
|
||||||
|
info.Score = rrfScores[chunkID]
|
||||||
|
deduplicatedChunks = append(deduplicatedChunks, info)
|
||||||
|
}
|
||||||
|
slices.SortFunc(deduplicatedChunks, func(a, b *types.IndexWithScore) int {
|
||||||
|
if a.Score > b.Score {
|
||||||
|
return -1
|
||||||
|
} else if a.Score < b.Score {
|
||||||
|
return 1
|
||||||
|
}
|
||||||
|
return 0
|
||||||
|
})
|
||||||
|
|
||||||
|
logger.Infof(ctx, "Result count after RRF fusion: %d", len(deduplicatedChunks))
|
||||||
|
|
||||||
|
// Log top results after RRF fusion for debugging
|
||||||
|
for i, chunk := range deduplicatedChunks {
|
||||||
|
if i < 15 {
|
||||||
|
vRank, vOk := vectorRanks[chunk.ChunkID]
|
||||||
|
kRank, kOk := keywordRanks[chunk.ChunkID]
|
||||||
|
logger.Debugf(ctx, "RRF rank %d: chunk_id=%s, rrf_score=%.6f, vector_rank=%v(%v), keyword_rank=%v(%v)",
|
||||||
|
i, chunk.ChunkID, chunk.Score, vRank, vOk, kRank, kOk)
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
kb.EnsureDefaults()
|
kb.EnsureDefaults()
|
||||||
|
|
||||||
// Check if we need iterative retrieval for FAQ with separate indexing
|
// Check if we need iterative retrieval for FAQ with separate indexing
|
||||||
// Only use iterative retrieval if we don't have enough unique chunks after first deduplication
|
// Only use iterative retrieval if we don't have enough unique chunks after first deduplication
|
||||||
|
totalRetrieved := len(vectorResults) + len(keywordResults)
|
||||||
needsIterativeRetrieval := len(deduplicatedChunks) < params.MatchCount &&
|
needsIterativeRetrieval := len(deduplicatedChunks) < params.MatchCount &&
|
||||||
kb.Type == types.KnowledgeBaseTypeFAQ && len(matchResults) == matchCount
|
kb.Type == types.KnowledgeBaseTypeFAQ && totalRetrieved == matchCount*2
|
||||||
if needsIterativeRetrieval {
|
if needsIterativeRetrieval {
|
||||||
logger.Info(ctx, "Not enough unique chunks, using iterative retrieval for FAQ")
|
logger.Info(ctx, "Not enough unique chunks, using iterative retrieval for FAQ")
|
||||||
// Use iterative retrieval to get more unique chunks (with negative question filtering inside)
|
// Use iterative retrieval to get more unique chunks (with negative question filtering inside)
|
||||||
@@ -891,10 +968,38 @@ func (s *knowledgeBaseService) processSearchResults(ctx context.Context,
|
|||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
// Build final search results
|
// Build final search results - preserve original order from input chunks
|
||||||
var searchResults []*types.SearchResult
|
var searchResults []*types.SearchResult
|
||||||
for chunkID, chunk := range chunkMap {
|
addedChunkIDs := make(map[string]bool)
|
||||||
|
|
||||||
|
// First pass: Add results in the original order from input chunks
|
||||||
|
for _, inputChunk := range chunks {
|
||||||
|
chunk, exists := chunkMap[inputChunk.ChunkID]
|
||||||
|
if !exists {
|
||||||
|
logger.Debugf(ctx, "Chunk not found in chunkMap: %s", inputChunk.ChunkID)
|
||||||
|
continue
|
||||||
|
}
|
||||||
if !s.isValidTextChunk(chunk) {
|
if !s.isValidTextChunk(chunk) {
|
||||||
|
logger.Debugf(ctx, "Chunk is not valid text chunk: %s, type: %s", chunk.ID, chunk.ChunkType)
|
||||||
|
continue
|
||||||
|
}
|
||||||
|
if addedChunkIDs[chunk.ID] {
|
||||||
|
continue
|
||||||
|
}
|
||||||
|
|
||||||
|
score := chunkScores[chunk.ID]
|
||||||
|
if knowledge, ok := knowledgeMap[chunk.KnowledgeID]; ok {
|
||||||
|
matchType := chunkMatchTypes[chunk.ID]
|
||||||
|
searchResults = append(searchResults, s.buildSearchResult(chunk, knowledge, score, matchType))
|
||||||
|
addedChunkIDs[chunk.ID] = true
|
||||||
|
} else {
|
||||||
|
logger.Warnf(ctx, "Knowledge not found for chunk: %s, knowledge_id: %s", chunk.ID, chunk.KnowledgeID)
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
// Second pass: Add additional chunks (parent, nearby, relation) that weren't in original input
|
||||||
|
for chunkID, chunk := range chunkMap {
|
||||||
|
if addedChunkIDs[chunkID] || !s.isValidTextChunk(chunk) {
|
||||||
continue
|
continue
|
||||||
}
|
}
|
||||||
|
|
||||||
@@ -959,6 +1064,7 @@ func (s *knowledgeBaseService) buildSearchResult(chunk *types.Chunk,
|
|||||||
ImageInfo: chunk.ImageInfo,
|
ImageInfo: chunk.ImageInfo,
|
||||||
KnowledgeFilename: knowledge.FileName,
|
KnowledgeFilename: knowledge.FileName,
|
||||||
KnowledgeSource: knowledge.Source,
|
KnowledgeSource: knowledge.Source,
|
||||||
|
ChunkMetadata: chunk.Metadata,
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
|
|||||||
@@ -800,11 +800,10 @@ func (s *sessionService) SearchKnowledge(ctx context.Context,
|
|||||||
|
|
||||||
// Use specific event list, only including retrieval-related events, not LLM summarization
|
// Use specific event list, only including retrieval-related events, not LLM summarization
|
||||||
searchEvents := []types.EventType{
|
searchEvents := []types.EventType{
|
||||||
types.PREPROCESS_QUERY, // Preprocess query
|
types.CHUNK_SEARCH, // Vector search
|
||||||
types.CHUNK_SEARCH, // Vector search
|
types.CHUNK_RERANK, // Rerank search results
|
||||||
types.CHUNK_RERANK, // Rerank search results
|
types.CHUNK_MERGE, // Merge search results
|
||||||
types.CHUNK_MERGE, // Merge search results
|
types.FILTER_TOP_K, // Filter top K results
|
||||||
types.FILTER_TOP_K, // Filter top K results
|
|
||||||
}
|
}
|
||||||
|
|
||||||
ctx, span := tracing.ContextWithSpan(ctx, "SessionService.SearchKnowledge")
|
ctx, span := tracing.ContextWithSpan(ctx, "SessionService.SearchKnowledge")
|
||||||
|
|||||||
@@ -140,10 +140,10 @@ func BuildContainer(container *dig.Container) *dig.Container {
|
|||||||
must(container.Invoke(chatpipline.NewPluginChatCompletionStream))
|
must(container.Invoke(chatpipline.NewPluginChatCompletionStream))
|
||||||
must(container.Invoke(chatpipline.NewPluginStreamFilter))
|
must(container.Invoke(chatpipline.NewPluginStreamFilter))
|
||||||
must(container.Invoke(chatpipline.NewPluginFilterTopK))
|
must(container.Invoke(chatpipline.NewPluginFilterTopK))
|
||||||
must(container.Invoke(chatpipline.NewPluginPreprocess))
|
|
||||||
must(container.Invoke(chatpipline.NewPluginRewrite))
|
must(container.Invoke(chatpipline.NewPluginRewrite))
|
||||||
must(container.Invoke(chatpipline.NewPluginExtractEntity))
|
must(container.Invoke(chatpipline.NewPluginExtractEntity))
|
||||||
must(container.Invoke(chatpipline.NewPluginSearchEntity))
|
must(container.Invoke(chatpipline.NewPluginSearchEntity))
|
||||||
|
must(container.Invoke(chatpipline.NewPluginSearchParallel))
|
||||||
|
|
||||||
// HTTP handlers layer
|
// HTTP handlers layer
|
||||||
must(container.Provide(handler.NewTenantHandler))
|
must(container.Provide(handler.NewTenantHandler))
|
||||||
|
|||||||
@@ -5,7 +5,6 @@ package event
|
|||||||
// QueryData represents query-related event data
|
// QueryData represents query-related event data
|
||||||
type QueryData struct {
|
type QueryData struct {
|
||||||
OriginalQuery string `json:"original_query"`
|
OriginalQuery string `json:"original_query"`
|
||||||
ProcessedQuery string `json:"processed_query,omitempty"`
|
|
||||||
RewrittenQuery string `json:"rewritten_query,omitempty"`
|
RewrittenQuery string `json:"rewritten_query,omitempty"`
|
||||||
SessionID string `json:"session_id"`
|
SessionID string `json:"session_id"`
|
||||||
UserID string `json:"user_id,omitempty"`
|
UserID string `json:"user_id,omitempty"`
|
||||||
|
|||||||
70
internal/searchutil/textutil.go
Normal file
70
internal/searchutil/textutil.go
Normal file
@@ -0,0 +1,70 @@
|
|||||||
|
package searchutil
|
||||||
|
|
||||||
|
import (
|
||||||
|
"crypto/md5"
|
||||||
|
"encoding/hex"
|
||||||
|
"strings"
|
||||||
|
)
|
||||||
|
|
||||||
|
// BuildContentSignature creates a normalized MD5 signature for content to detect duplicates.
|
||||||
|
// It normalizes the content by lowercasing, trimming whitespace, and collapsing multiple spaces.
|
||||||
|
func BuildContentSignature(content string) string {
|
||||||
|
c := strings.ToLower(strings.TrimSpace(content))
|
||||||
|
if c == "" {
|
||||||
|
return ""
|
||||||
|
}
|
||||||
|
// Normalize whitespace
|
||||||
|
c = strings.Join(strings.Fields(c), " ")
|
||||||
|
// Use MD5 hash of full content
|
||||||
|
hash := md5.Sum([]byte(c))
|
||||||
|
return hex.EncodeToString(hash[:])
|
||||||
|
}
|
||||||
|
|
||||||
|
// TokenizeSimple tokenizes text into a set of words (simple whitespace-based).
|
||||||
|
// Returns a map where keys are lowercase tokens with length > 1.
|
||||||
|
func TokenizeSimple(text string) map[string]struct{} {
|
||||||
|
text = strings.ToLower(text)
|
||||||
|
fields := strings.Fields(text)
|
||||||
|
set := make(map[string]struct{}, len(fields))
|
||||||
|
for _, f := range fields {
|
||||||
|
if len(f) > 1 {
|
||||||
|
set[f] = struct{}{}
|
||||||
|
}
|
||||||
|
}
|
||||||
|
return set
|
||||||
|
}
|
||||||
|
|
||||||
|
// Jaccard calculates Jaccard similarity between two token sets.
|
||||||
|
// Returns a value between 0 and 1, where 1 means identical sets.
|
||||||
|
func Jaccard(a, b map[string]struct{}) float64 {
|
||||||
|
if len(a) == 0 && len(b) == 0 {
|
||||||
|
return 0
|
||||||
|
}
|
||||||
|
|
||||||
|
// Calculate intersection
|
||||||
|
inter := 0
|
||||||
|
for k := range a {
|
||||||
|
if _, ok := b[k]; ok {
|
||||||
|
inter++
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
// Calculate union
|
||||||
|
union := len(a) + len(b) - inter
|
||||||
|
if union == 0 {
|
||||||
|
return 0
|
||||||
|
}
|
||||||
|
|
||||||
|
return float64(inter) / float64(union)
|
||||||
|
}
|
||||||
|
|
||||||
|
// ClampFloat clamps a float value to the specified range [minV, maxV].
|
||||||
|
func ClampFloat(v, minV, maxV float64) float64 {
|
||||||
|
if v < minV {
|
||||||
|
return minV
|
||||||
|
}
|
||||||
|
if v > maxV {
|
||||||
|
return maxV
|
||||||
|
}
|
||||||
|
return v
|
||||||
|
}
|
||||||
@@ -3,12 +3,10 @@ package types
|
|||||||
// ChatManage represents the configuration and state for a chat session
|
// ChatManage represents the configuration and state for a chat session
|
||||||
// including query processing, search parameters, and model configurations
|
// including query processing, search parameters, and model configurations
|
||||||
type ChatManage struct {
|
type ChatManage struct {
|
||||||
SessionID string `json:"session_id"` // Unique identifier for the chat session
|
SessionID string `json:"session_id"` // Unique identifier for the chat session
|
||||||
Query string `json:"query,omitempty"` // Original user query
|
Query string `json:"query,omitempty"` // Original user query
|
||||||
ProcessedQuery string `json:"processed_query,omitempty"` // Query after preprocessing
|
RewriteQuery string `json:"rewrite_query,omitempty"` // Query after rewriting for better retrieval
|
||||||
RewriteQuery string `json:"rewrite_query,omitempty"` // Query after rewriting for better retrieval
|
History []*History `json:"history,omitempty"` // Chat history for context
|
||||||
QueryIntent string `json:"query_intent,omitempty"` // Parsed intent: definition/howto/compare/qa/general
|
|
||||||
History []*History `json:"history,omitempty"` // Chat history for context
|
|
||||||
|
|
||||||
KnowledgeBaseID string `json:"knowledge_base_id"` // ID of the knowledge base to search against (deprecated, use KnowledgeBaseIDs)
|
KnowledgeBaseID string `json:"knowledge_base_id"` // ID of the knowledge base to search against (deprecated, use KnowledgeBaseIDs)
|
||||||
KnowledgeBaseIDs []string `json:"knowledge_base_ids"` // IDs of knowledge bases to search (multi-KB support)
|
KnowledgeBaseIDs []string `json:"knowledge_base_ids"` // IDs of knowledge bases to search (multi-KB support)
|
||||||
@@ -60,9 +58,7 @@ func (c *ChatManage) Clone() *ChatManage {
|
|||||||
|
|
||||||
return &ChatManage{
|
return &ChatManage{
|
||||||
Query: c.Query,
|
Query: c.Query,
|
||||||
ProcessedQuery: c.ProcessedQuery,
|
|
||||||
RewriteQuery: c.RewriteQuery,
|
RewriteQuery: c.RewriteQuery,
|
||||||
QueryIntent: c.QueryIntent,
|
|
||||||
SessionID: c.SessionID,
|
SessionID: c.SessionID,
|
||||||
KnowledgeBaseID: c.KnowledgeBaseID,
|
KnowledgeBaseID: c.KnowledgeBaseID,
|
||||||
KnowledgeBaseIDs: knowledgeBaseIDs,
|
KnowledgeBaseIDs: knowledgeBaseIDs,
|
||||||
@@ -103,9 +99,9 @@ func (c *ChatManage) Clone() *ChatManage {
|
|||||||
type EventType string
|
type EventType string
|
||||||
|
|
||||||
const (
|
const (
|
||||||
PREPROCESS_QUERY EventType = "preprocess_query" // Query preprocessing stage
|
|
||||||
REWRITE_QUERY EventType = "rewrite_query" // Query rewriting for better retrieval
|
REWRITE_QUERY EventType = "rewrite_query" // Query rewriting for better retrieval
|
||||||
CHUNK_SEARCH EventType = "chunk_search" // Search for relevant chunks
|
CHUNK_SEARCH EventType = "chunk_search" // Search for relevant chunks
|
||||||
|
CHUNK_SEARCH_PARALLEL EventType = "chunk_search_parallel" // Parallel search: chunks + entities
|
||||||
ENTITY_SEARCH EventType = "entity_search" // Search for relevant entities
|
ENTITY_SEARCH EventType = "entity_search" // Search for relevant entities
|
||||||
CHUNK_RERANK EventType = "chunk_rerank" // Rerank search results
|
CHUNK_RERANK EventType = "chunk_rerank" // Rerank search results
|
||||||
CHUNK_MERGE EventType = "chunk_merge" // Merge similar chunks
|
CHUNK_MERGE EventType = "chunk_merge" // Merge similar chunks
|
||||||
@@ -134,9 +130,7 @@ var Pipline = map[string][]EventType{
|
|||||||
},
|
},
|
||||||
"rag_stream": { // Streaming Retrieval Augmented Generation
|
"rag_stream": { // Streaming Retrieval Augmented Generation
|
||||||
REWRITE_QUERY,
|
REWRITE_QUERY,
|
||||||
PREPROCESS_QUERY,
|
CHUNK_SEARCH_PARALLEL, // Parallel: CHUNK_SEARCH + ENTITY_SEARCH
|
||||||
CHUNK_SEARCH,
|
|
||||||
ENTITY_SEARCH,
|
|
||||||
CHUNK_RERANK,
|
CHUNK_RERANK,
|
||||||
CHUNK_MERGE,
|
CHUNK_MERGE,
|
||||||
FILTER_TOP_K,
|
FILTER_TOP_K,
|
||||||
|
|||||||
@@ -46,6 +46,9 @@ type SearchResult struct {
|
|||||||
// Knowledge source
|
// Knowledge source
|
||||||
// Used to indicate the source of the knowledge, such as "url"
|
// Used to indicate the source of the knowledge, such as "url"
|
||||||
KnowledgeSource string `json:"knowledge_source"`
|
KnowledgeSource string `json:"knowledge_source"`
|
||||||
|
|
||||||
|
// ChunkMetadata stores chunk-level metadata (e.g., generated questions)
|
||||||
|
ChunkMetadata JSON `json:"chunk_metadata,omitempty"`
|
||||||
}
|
}
|
||||||
|
|
||||||
// SearchParams represents the search parameters
|
// SearchParams represents the search parameters
|
||||||
|
|||||||
Reference in New Issue
Block a user