mirror of
https://github.com/Tencent/WeKnora.git
synced 2026-06-04 13:30:32 +08:00
Multi-KB hybrid search now groups KBs by their bound VectorStore (partition key (storeID, owner_tenant_id)), retrieves in parallel via errgroup with a SetLimit(4) cap and a per-group timeout (MULTI_STORE_RETRIEVE_TIMEOUT_SEC, default 30s), and merges results. When the collected results span more than one engine type, an EngineAwareNormalizer rescales vector scores to [0, 1]; keyword (BM25) scores pass through to the existing RRF fusion. Single-group calls take the fast path with zero fan-out overhead, preserving today's behavior for deployments where every KB has vector_store_id = NULL. Embedding-model consistency is now enforced explicitly via ResolveEmbeddingModelKeys. Multi-KB searches across KBs whose resolved model identities differ return BadRequest instead of silently producing incomparable scores. Cross-tenant Organization-shared KBs are preserved by partitioning on KB.TenantID so the factory's ownership lookup runs against the source tenant. Foreign-tenant KB UUIDs injected via the request body are rejected via kbShareService.HasTenantKBPermission (Plan 3 of #1303, 3-D capped) before any retrieval; rejected scopes surface as 404 to avoid leaking foreign KB existence. Service-layer typed AppErrors (ErrVectorStoreBindingInvalid 2200 / ErrVectorStoreUnavailable 2201) are mapped from PR2 sentinel hierarchy and preserved end-to-end: the iterative FAQ path returns them rather than swallowing, and the HybridSearch handler routes typed AppErrors to the client unchanged instead of downgrading to 500. Part of #993 (Phase 2: Per-KB VectorStore Binding). Phase 2 roadmap item: PR 4 (Multi-store fan-out search). Depends on #994, #1310, #1372.
388 lines
15 KiB
Go
388 lines
15 KiB
Go
package service
|
|
|
|
import (
|
|
"context"
|
|
|
|
"github.com/Tencent/WeKnora/internal/application/service/retriever"
|
|
apperrors "github.com/Tencent/WeKnora/internal/errors"
|
|
"github.com/Tencent/WeKnora/internal/logger"
|
|
"github.com/Tencent/WeKnora/internal/models/embedding"
|
|
"github.com/Tencent/WeKnora/internal/tracing/langfuse"
|
|
"github.com/Tencent/WeKnora/internal/types"
|
|
secutils "github.com/Tencent/WeKnora/internal/utils"
|
|
)
|
|
|
|
// GetQueryEmbedding computes the query embedding using the embedding model
|
|
// associated with the given knowledge base. Callers can pre-compute and reuse
|
|
// the result across multiple KBs that share the same embedding model to avoid
|
|
// redundant embedding API calls.
|
|
func (s *knowledgeBaseService) GetQueryEmbedding(ctx context.Context, kbID string, queryText string) ([]float32, error) {
|
|
kb, err := s.repo.GetKnowledgeBaseByID(ctx, kbID)
|
|
if err != nil {
|
|
return nil, err
|
|
}
|
|
|
|
currentTenantID := types.MustTenantIDFromContext(ctx)
|
|
var embeddingModel embedding.Embedder
|
|
|
|
if kb.TenantID != currentTenantID {
|
|
embeddingModel, err = s.modelService.GetEmbeddingModelForTenant(ctx, kb.EmbeddingModelID, kb.TenantID)
|
|
} else {
|
|
embeddingModel, err = s.modelService.GetEmbeddingModel(ctx, kb.EmbeddingModelID)
|
|
}
|
|
if err != nil {
|
|
logger.Errorf(ctx, "GetQueryEmbedding: failed to get embedding model %s: %v", kb.EmbeddingModelID, err)
|
|
return nil, err
|
|
}
|
|
|
|
return embeddingModel.Embed(ctx, queryText)
|
|
}
|
|
|
|
// ResolveEmbeddingModelKeys resolves embedding model IDs to their actual model
|
|
// identity key (name + endpoint). KBs using the same underlying model across
|
|
// different tenants will share the same key, enabling optimal grouping.
|
|
func (s *knowledgeBaseService) ResolveEmbeddingModelKeys(ctx context.Context, kbs []*types.KnowledgeBase) map[string]string {
|
|
type modelRef struct {
|
|
ModelID string
|
|
TenantID uint64
|
|
}
|
|
|
|
// Deduplicate model references
|
|
uniqueRefs := make(map[modelRef]struct{})
|
|
kbRefs := make(map[string]modelRef, len(kbs))
|
|
for _, kb := range kbs {
|
|
ref := modelRef{ModelID: kb.EmbeddingModelID, TenantID: kb.TenantID}
|
|
uniqueRefs[ref] = struct{}{}
|
|
kbRefs[kb.ID] = ref
|
|
}
|
|
|
|
// Resolve each unique (modelID, tenantID) to a model identity key
|
|
resolvedKeys := make(map[modelRef]string, len(uniqueRefs))
|
|
for ref := range uniqueRefs {
|
|
tenantCtx := context.WithValue(ctx, types.TenantIDContextKey, ref.TenantID)
|
|
model, err := s.modelService.GetModelByID(tenantCtx, ref.ModelID)
|
|
if err != nil || model == nil {
|
|
logger.Warnf(ctx, "ResolveEmbeddingModelKeys: cannot resolve model %s for tenant %d: %v", ref.ModelID, ref.TenantID, err)
|
|
resolvedKeys[ref] = ref.ModelID
|
|
continue
|
|
}
|
|
resolvedKeys[ref] = model.Name + "|" + model.Parameters.BaseURL
|
|
}
|
|
|
|
result := make(map[string]string, len(kbs))
|
|
for _, kb := range kbs {
|
|
result[kb.ID] = resolvedKeys[kbRefs[kb.ID]]
|
|
}
|
|
return result
|
|
}
|
|
|
|
// HybridSearch performs hybrid search, including vector retrieval and keyword retrieval.
|
|
//
|
|
// id is the "primary" knowledge base ID used to resolve the embedding model and
|
|
// determine the KB type (e.g. FAQ). When params.KnowledgeBaseIDs is set, those
|
|
// IDs are used for the actual retrieval scope instead of id alone, allowing a
|
|
// single call to span multiple KBs that share the same embedding model. In that
|
|
// case id should be any one of those KBs (typically the first) so that its
|
|
// embedding model and type configuration are used for the search.
|
|
func (s *knowledgeBaseService) HybridSearch(ctx context.Context,
|
|
id string,
|
|
params types.SearchParams,
|
|
) ([]*types.SearchResult, error) {
|
|
// Determine the set of KB IDs to search.
|
|
searchKBIDs := params.KnowledgeBaseIDs
|
|
if len(searchKBIDs) == 0 {
|
|
searchKBIDs = []string{id}
|
|
}
|
|
|
|
// QueryText is user-controlled; sanitize before logging to prevent
|
|
// CR/LF/tab log injection. Matches the handler-layer sanitization at
|
|
// handler/knowledgebase.go.
|
|
logger.Infof(ctx, "Hybrid search parameters, knowledge base IDs: %v, query text: %s",
|
|
searchKBIDs, secutils.SanitizeForLog(params.QueryText))
|
|
|
|
tenantInfo, _ := types.TenantInfoFromContext(ctx)
|
|
requestTenantID := types.MustTenantIDFromContext(ctx)
|
|
|
|
// Batch-load every KB in scope. Required for store grouping,
|
|
// embedding-model consistency validation, and FAQ type detection.
|
|
// GetKnowledgeBaseByIDs is intentionally tenant-agnostic at the
|
|
// repository layer so that Organization-shared KBs (owned by a
|
|
// different tenant) can be loaded here; authorization for each
|
|
// returned row is enforced explicitly below.
|
|
kbs, err := s.repo.GetKnowledgeBaseByIDs(ctx, searchKBIDs)
|
|
if err != nil {
|
|
logger.ErrorWithFields(ctx, err, map[string]interface{}{
|
|
"knowledge_base_ids": searchKBIDs,
|
|
})
|
|
return nil, err
|
|
}
|
|
if len(kbs) == 0 {
|
|
return nil, apperrors.NewNotFoundError("knowledge base not found")
|
|
}
|
|
|
|
// Authorize every KB the caller asked for. Same-tenant KBs are
|
|
// always accessible; foreign-tenant KBs (Organization-shared) must
|
|
// pass an explicit per-KB permission check. Without this guard, a
|
|
// caller could pass arbitrary KB UUIDs in params.KnowledgeBaseIDs
|
|
// and reach foreign tenants' bound vector stores via the per-group
|
|
// engine resolution downstream.
|
|
if err := s.authorizeKBAccess(ctx, kbs, requestTenantID); err != nil {
|
|
return nil, err
|
|
}
|
|
|
|
// Explicit embedding-model consistency check. Multi-KB searches that
|
|
// span different embedding spaces would otherwise silently produce
|
|
// meaningless cross-model scores. Same-model wiki/graph KBs are
|
|
// tolerated — see validateSameEmbeddingModel for the carve-out.
|
|
if err := s.validateSameEmbeddingModel(ctx, kbs); err != nil {
|
|
return nil, err
|
|
}
|
|
|
|
// Resolve the primary KB — embedding model + FAQ type come from this
|
|
// one. Miss → 404 (no kbs[0] fallback; a silent pivot to an arbitrary
|
|
// KB would hide caller bugs and reveal foreign KB metadata).
|
|
kb := pickPrimary(kbs, id)
|
|
if kb == nil {
|
|
return nil, apperrors.NewNotFoundError("knowledge base not found")
|
|
}
|
|
|
|
// Over-retrieval (existing rule, preserved): 5x per-KB matchCount,
|
|
// floor of 50, capped at 500 across the whole search.
|
|
matchCount := max(params.MatchCount*5, 50) * len(searchKBIDs)
|
|
if matchCount > 500 {
|
|
matchCount = 500
|
|
}
|
|
|
|
// Compute the query embedding once before fan-out and propagate via
|
|
// params.QueryEmbedding. Without this, each storeGroup's
|
|
// buildRetrievalParams would re-embed the same query text — for N
|
|
// stores that means N API calls of identical input.
|
|
//
|
|
// Skip when params already carries an embedding (e.g. the agent
|
|
// pre-computed it) or when the primary KB has no vector indexing
|
|
// configured.
|
|
if len(params.QueryEmbedding) == 0 &&
|
|
kb.IsVectorEnabled() && kb.EmbeddingModelID != "" &&
|
|
!params.DisableVectorMatch {
|
|
emb, embErr := s.GetQueryEmbedding(ctx, kb.ID, params.QueryText)
|
|
if embErr != nil {
|
|
return nil, embErr
|
|
}
|
|
params.QueryEmbedding = emb
|
|
}
|
|
|
|
// Group KBs by (storeID, owner tenant), resolve the bound engine for
|
|
// each group, and build the per-group base RetrieveParams once.
|
|
groups, err := s.resolveStoreGroups(ctx, kb, kbs, params, matchCount)
|
|
if err != nil {
|
|
return nil, err
|
|
}
|
|
if len(groups) == 0 || allBaseParamsEmpty(groups) {
|
|
// Wiki-only / graph-only fan-out: every KB is non-retrievable.
|
|
// Preserve the existing "return empty rather than error" contract
|
|
// so agent tools that combine multiple KB scopes degrade gracefully.
|
|
logger.Infof(ctx, "No retrievable indexing pipelines across %d KBs", len(kbs))
|
|
return nil, nil
|
|
}
|
|
|
|
// Execute retrieval with fan-out + score normalization (multi-store
|
|
// only) and a langfuse span around the entire retrieve step.
|
|
logger.Infof(ctx, "Starting multi-store retrieval, group count: %d", len(groups))
|
|
retrieveCtx, retrieveSpan := langfuse.GetManager().StartSpan(ctx, langfuse.SpanOptions{
|
|
Name: "retrieve",
|
|
Input: map[string]interface{}{
|
|
"kb_ids": searchKBIDs,
|
|
"group_count": len(groups),
|
|
"match_count": matchCount,
|
|
},
|
|
})
|
|
retrieveResults, err := s.retrieveFromStores(retrieveCtx, groups, retriever.EngineAwareNormalizer{})
|
|
retrieveSpan.Finish(map[string]interface{}{
|
|
"result_count": totalHits(retrieveResults),
|
|
}, nil, err)
|
|
if err != nil {
|
|
logger.ErrorWithFields(ctx, err, map[string]interface{}{
|
|
"knowledge_base_ids": searchKBIDs,
|
|
"query_text": params.QueryText,
|
|
})
|
|
return nil, err
|
|
}
|
|
|
|
// Separate and fuse retrieval results.
|
|
vectorResults, keywordResults := classifyRetrievalResults(ctx, retrieveResults)
|
|
if len(vectorResults) == 0 && len(keywordResults) == 0 {
|
|
logger.Info(ctx, "No search results found")
|
|
return nil, nil
|
|
}
|
|
logger.Infof(ctx, "Result count before fusion: vector=%d, keyword=%d",
|
|
len(vectorResults), len(keywordResults))
|
|
|
|
var retrievalCfg *types.RetrievalConfig
|
|
if tenantInfo != nil {
|
|
retrievalCfg = tenantInfo.RetrievalConfig
|
|
}
|
|
deduplicatedChunks := fuseOrDeduplicate(ctx, vectorResults, keywordResults, retrievalCfg)
|
|
|
|
kb.EnsureDefaults()
|
|
|
|
// FAQ-specific post-processing now operates on storeGroups so the
|
|
// iterative TopK growth applies uniformly across the fan-out. An
|
|
// AppError from inside the iterative fan-out path (e.g. a per-group
|
|
// timeout surfaced as ErrVectorStoreUnavailable) must surface to the
|
|
// caller rather than be silently converted to a truncated chunk list.
|
|
deduplicatedChunks, err = s.applyFAQPostProcessing(
|
|
ctx, kb, deduplicatedChunks, vectorResults, groups, params, matchCount)
|
|
if err != nil {
|
|
return nil, err
|
|
}
|
|
|
|
if len(deduplicatedChunks) > params.MatchCount {
|
|
deduplicatedChunks = deduplicatedChunks[:params.MatchCount]
|
|
}
|
|
|
|
return s.processSearchResults(ctx, deduplicatedChunks, params.SkipContextEnrichment)
|
|
}
|
|
|
|
// pickPrimary returns the KB whose ID matches id, or nil if id is not in
|
|
// scope. Callers map a nil result to NotFound; there is intentionally no
|
|
// kbs[0] fallback because it would mask caller bugs and could leak an
|
|
// unintended KB's embedding-model identity into the search path.
|
|
//
|
|
// The primary KB drives the embedding model and FAQ-type decisions for
|
|
// buildRetrievalParams. If the caller selects a wiki-only / graph-only
|
|
// KB as primary, the multi-KB search is implicitly demoted to
|
|
// keyword-only retrieval — vector retrieval is skipped because
|
|
// primary.IsVectorEnabled() is false. Callers that mix vector-enabled
|
|
// and non-vector KBs should pass a vector-enabled KB as id.
|
|
func pickPrimary(kbs []*types.KnowledgeBase, id string) *types.KnowledgeBase {
|
|
for _, kb := range kbs {
|
|
if kb.ID == id {
|
|
return kb
|
|
}
|
|
}
|
|
return nil
|
|
}
|
|
|
|
// allBaseParamsEmpty reports whether every store group has an empty
|
|
// BaseParams slice. True only when every KB in scope is wiki-only or
|
|
// graph-only with neither vector nor keyword indexing — HybridSearch then
|
|
// returns nil so callers that combine searchable + non-searchable KBs
|
|
// (agent tools, chat pipeline) degrade gracefully.
|
|
func allBaseParamsEmpty(groups []*storeGroup) bool {
|
|
for _, g := range groups {
|
|
if len(g.BaseParams) > 0 {
|
|
return false
|
|
}
|
|
}
|
|
return true
|
|
}
|
|
|
|
// totalHits counts the IndexWithScore entries across a slice of retrieve
|
|
// results. Used only for langfuse span metadata.
|
|
func totalHits(rrs []*types.RetrieveResult) int {
|
|
n := 0
|
|
for _, r := range rrs {
|
|
n += len(r.Results)
|
|
}
|
|
return n
|
|
}
|
|
|
|
// buildRetrievalParams constructs the vector and keyword retrieval parameters
|
|
// based on the knowledge base type, engine capabilities, and search params.
|
|
func (s *knowledgeBaseService) buildRetrievalParams(
|
|
ctx context.Context,
|
|
retrieveEngine *retriever.CompositeRetrieveEngine,
|
|
kb *types.KnowledgeBase,
|
|
params types.SearchParams,
|
|
searchKBIDs []string,
|
|
matchCount int,
|
|
) ([]types.RetrieveParams, error) {
|
|
currentTenantID := types.MustTenantIDFromContext(ctx)
|
|
var retrieveParams []types.RetrieveParams
|
|
|
|
// Respect the KB's IndexingStrategy: a KB that does not have vector
|
|
// indexing enabled (e.g. wiki-only or graph-only KBs) has no embeddings
|
|
// to retrieve from, and typically has no EmbeddingModelID configured
|
|
// either. Skipping vector retrieval for such KBs avoids spurious
|
|
// "model ID cannot be empty" errors when an agent's retrieval scope
|
|
// happens to include them (e.g. KBSelectionMode=all picking up a
|
|
// wiki-only KB).
|
|
vectorIndexed := kb.IsVectorEnabled() && kb.EmbeddingModelID != ""
|
|
|
|
// Add vector retrieval params if supported
|
|
if retrieveEngine.SupportRetriever(types.VectorRetrieverType) && !params.DisableVectorMatch && vectorIndexed {
|
|
logger.Info(ctx, "Vector retrieval supported, preparing vector retrieval parameters")
|
|
|
|
var queryEmbedding []float32
|
|
|
|
if len(params.QueryEmbedding) > 0 {
|
|
queryEmbedding = params.QueryEmbedding
|
|
logger.Infof(ctx, "Using pre-computed query embedding, vector length: %d", len(queryEmbedding))
|
|
} else {
|
|
logger.Infof(ctx, "Getting embedding model, model ID: %s", kb.EmbeddingModelID)
|
|
|
|
// Check if this is a cross-tenant shared knowledge base
|
|
// For shared KB, we must use the source tenant's embedding model to ensure vector compatibility
|
|
var embeddingModel embedding.Embedder
|
|
var err error
|
|
if kb.TenantID != currentTenantID {
|
|
logger.Infof(ctx, "Cross-tenant knowledge base detected, using source tenant's embedding model. KB tenant: %d, current tenant: %d", kb.TenantID, currentTenantID)
|
|
embeddingModel, err = s.modelService.GetEmbeddingModelForTenant(ctx, kb.EmbeddingModelID, kb.TenantID)
|
|
} else {
|
|
embeddingModel, err = s.modelService.GetEmbeddingModel(ctx, kb.EmbeddingModelID)
|
|
}
|
|
|
|
if err != nil {
|
|
logger.Errorf(ctx, "Failed to get embedding model, model ID: %s, error: %v", kb.EmbeddingModelID, err)
|
|
return nil, err
|
|
}
|
|
logger.Infof(ctx, "Embedding model retrieved: %v", embeddingModel)
|
|
|
|
logger.Info(ctx, "Starting to generate query embedding")
|
|
queryEmbedding, err = embeddingModel.Embed(ctx, params.QueryText)
|
|
if err != nil {
|
|
logger.Errorf(ctx, "Failed to embed query text, query text: %s, error: %v", params.QueryText, err)
|
|
return nil, err
|
|
}
|
|
logger.Infof(ctx, "Query embedding generated successfully, embedding vector length: %d", len(queryEmbedding))
|
|
}
|
|
|
|
vectorParams := types.RetrieveParams{
|
|
Query: params.QueryText,
|
|
Embedding: queryEmbedding,
|
|
KnowledgeBaseIDs: searchKBIDs,
|
|
TopK: matchCount,
|
|
Threshold: params.VectorThreshold,
|
|
RetrieverType: types.VectorRetrieverType,
|
|
KnowledgeIDs: params.KnowledgeIDs,
|
|
TagIDs: params.TagIDs,
|
|
}
|
|
|
|
// For FAQ knowledge base, use FAQ index
|
|
if kb.Type == types.KnowledgeBaseTypeFAQ {
|
|
vectorParams.KnowledgeType = types.KnowledgeTypeFAQ
|
|
}
|
|
|
|
retrieveParams = append(retrieveParams, vectorParams)
|
|
logger.Info(ctx, "Vector retrieval parameters setup completed")
|
|
}
|
|
|
|
// Add keyword retrieval params if supported, KB has keyword indexing, and not FAQ
|
|
if retrieveEngine.SupportRetriever(types.KeywordsRetrieverType) && !params.DisableKeywordsMatch &&
|
|
kb.IsKeywordEnabled() && kb.Type != types.KnowledgeBaseTypeFAQ {
|
|
logger.Info(ctx, "Keyword retrieval supported, preparing keyword retrieval parameters")
|
|
retrieveParams = append(retrieveParams, types.RetrieveParams{
|
|
Query: params.QueryText,
|
|
KnowledgeBaseIDs: searchKBIDs,
|
|
TopK: matchCount,
|
|
Threshold: params.KeywordThreshold,
|
|
RetrieverType: types.KeywordsRetrieverType,
|
|
KnowledgeIDs: params.KnowledgeIDs,
|
|
TagIDs: params.TagIDs,
|
|
})
|
|
logger.Info(ctx, "Keyword retrieval parameters setup completed")
|
|
}
|
|
|
|
return retrieveParams, nil
|
|
}
|