feat: 支持知识库选择模式和标题生成模型配置

- 移除对rerankModel的强制检查,支持无知识库场景
- 优化知识库解析逻辑,优先处理用户明确指定的知识库
This commit is contained in:
wizardchen
2025-12-29 14:35:13 +08:00
parent 6bc19ab082
commit 6c8e261a5f
6 changed files with 92 additions and 44 deletions

Binary file not shown.

Before

Width:  |  Height:  |  Size: 14 MiB

After

Width:  |  Height:  |  Size: 15 MiB

View File

@@ -83,9 +83,8 @@ func (s *agentService) CreateAgentEngine(
return nil, fmt.Errorf("chat model is nil after initialization")
}
if rerankModel == nil {
return nil, fmt.Errorf("rerank model is nil after initialization")
}
// Note: rerankModel can be nil when no knowledge bases are configured
// The registerTools function will filter out knowledge-related tools in this case
// Create tool registry
toolRegistry := tools.NewToolRegistry(s.knowledgeService, s.chunkService, s.db)

View File

@@ -218,8 +218,9 @@ func (s *sessionService) DeleteSession(ctx context.Context, id string) error {
}
// GenerateTitle generates a title for the current conversation content
// modelID: optional model ID to use for title generation (if empty, uses first available KnowledgeQA model)
func (s *sessionService) GenerateTitle(ctx context.Context,
session *types.Session, messages []types.Message,
session *types.Session, messages []types.Message, modelID string,
) (string, error) {
if session == nil {
logger.Error(ctx, "Failed to generate title: session cannot be empty")
@@ -256,28 +257,29 @@ func (s *sessionService) GenerateTitle(ctx context.Context,
return "", errors.New("no user message found")
}
// Get chat model, find an available model
modelID := ""
// Try to get an available KnowledgeQA model
models, err := s.modelService.ListModels(ctx)
if err != nil {
logger.ErrorWithFields(ctx, err, nil)
return "", fmt.Errorf("failed to list models: %w", err)
}
// Find first available KnowledgeQA model
for _, model := range models {
if model == nil {
continue
}
if model.Type == types.ModelTypeKnowledgeQA {
modelID = model.ID
logger.Infof(ctx, "Using first available KnowledgeQA model: %s", modelID)
break
}
}
// Use provided modelID, or fallback to first available KnowledgeQA model
if modelID == "" {
logger.Error(ctx, "No KnowledgeQA model found")
return "", errors.New("no KnowledgeQA model available for title generation")
models, err := s.modelService.ListModels(ctx)
if err != nil {
logger.ErrorWithFields(ctx, err, nil)
return "", fmt.Errorf("failed to list models: %w", err)
}
for _, model := range models {
if model == nil {
continue
}
if model.Type == types.ModelTypeKnowledgeQA {
modelID = model.ID
logger.Infof(ctx, "Using first available KnowledgeQA model for title: %s", modelID)
break
}
}
if modelID == "" {
logger.Error(ctx, "No KnowledgeQA model found")
return "", errors.New("no KnowledgeQA model available for title generation")
}
} else {
logger.Infof(ctx, "Using specified model for title generation: %s", modelID)
}
chatModel, err := s.modelService.GetChatModel(ctx, modelID)
@@ -324,10 +326,12 @@ func (s *sessionService) GenerateTitle(ctx context.Context,
// GenerateTitleAsync generates a title for the session asynchronously
// This method clones the session and generates the title in a goroutine
// It emits an event when the title is generated
// modelID: optional model ID to use for title generation (if empty, uses first available KnowledgeQA model)
func (s *sessionService) GenerateTitleAsync(
ctx context.Context,
session *types.Session,
userQuery string,
modelID string,
eventBus *event.EventBus,
) {
// Extract values from context before cloning
@@ -356,7 +360,7 @@ func (s *sessionService) GenerateTitleAsync(
},
}
title, err := s.GenerateTitle(bgCtx, session, messages)
title, err := s.GenerateTitle(bgCtx, session, messages, modelID)
if err != nil {
logger.ErrorWithFields(bgCtx, err, map[string]interface{}{
"session_id": session.ID,
@@ -410,10 +414,9 @@ func (s *sessionService) KnowledgeQA(
// Use custom agent's knowledge bases only if request didn't specify any
// When user explicitly @mentions a knowledge base or document, only search those
if len(knowledgeBaseIDs) == 0 && len(knowledgeIDs) == 0 && customAgent != nil && len(customAgent.Config.KnowledgeBases) > 0 {
knowledgeBaseIDs = customAgent.Config.KnowledgeBases
logger.Infof(ctx, "No knowledge bases specified in request, using custom agent's knowledge bases: %v", knowledgeBaseIDs)
} else if len(knowledgeBaseIDs) > 0 || len(knowledgeIDs) > 0 {
if len(knowledgeBaseIDs) == 0 && len(knowledgeIDs) == 0 {
knowledgeBaseIDs = s.resolveKnowledgeBasesFromAgent(ctx, customAgent)
} else {
logger.Infof(ctx, "Using request-specified targets (ignoring agent config): kbs=%v, docs=%v", knowledgeBaseIDs, knowledgeIDs)
}
@@ -776,6 +779,48 @@ func (s *sessionService) selectChatModelID(
return "", errors.New("no chat model ID available: no knowledge bases configured and no available models")
}
// resolveKnowledgeBasesFromAgent resolves knowledge base IDs based on agent's KBSelectionMode
// Returns the resolved knowledge base IDs based on the selection mode:
// - "all": fetches all knowledge bases for the tenant
// - "selected": uses the explicitly configured knowledge bases
// - "none": returns empty slice
// - default: falls back to configured knowledge bases for backward compatibility
func (s *sessionService) resolveKnowledgeBasesFromAgent(
ctx context.Context,
customAgent *types.CustomAgent,
) []string {
if customAgent == nil {
return nil
}
switch customAgent.Config.KBSelectionMode {
case "all":
allKBs, err := s.knowledgeBaseService.ListKnowledgeBases(ctx)
if err != nil {
logger.Warnf(ctx, "Failed to list all knowledge bases: %v", err)
return nil
}
kbIDs := make([]string, 0, len(allKBs))
for _, kb := range allKBs {
kbIDs = append(kbIDs, kb.ID)
}
logger.Infof(ctx, "KBSelectionMode=all: loaded %d knowledge bases", len(kbIDs))
return kbIDs
case "selected":
logger.Infof(ctx, "KBSelectionMode=selected: using %d configured knowledge bases", len(customAgent.Config.KnowledgeBases))
return customAgent.Config.KnowledgeBases
case "none":
logger.Infof(ctx, "KBSelectionMode=none: no knowledge bases configured")
return nil
default:
// Default to "selected" behavior for backward compatibility
if len(customAgent.Config.KnowledgeBases) > 0 {
logger.Infof(ctx, "KBSelectionMode not set: using %d configured knowledge bases", len(customAgent.Config.KnowledgeBases))
}
return customAgent.Config.KnowledgeBases
}
}
// buildSearchTargets computes the unified search targets from knowledgeBaseIDs and knowledgeIDs
// This is called once at the request entry point to avoid repeated queries later in the pipeline
// Logic:
@@ -1055,27 +1100,24 @@ func (s *sessionService) AgentQA(
WebSearchMaxResults: customAgent.Config.WebSearchMaxResults,
MultiTurnEnabled: customAgent.Config.MultiTurnEnabled,
HistoryTurns: customAgent.Config.HistoryTurns,
KnowledgeBases: customAgent.Config.KnowledgeBases,
MCPSelectionMode: customAgent.Config.MCPSelectionMode,
MCPServices: customAgent.Config.MCPServices,
}
// Handle knowledge bases: request-level @ mentions take priority over agent config
// When user explicitly @mentions a knowledge base or document, only search those
// Resolve knowledge bases: request-level @ mentions take priority over agent config
if len(knowledgeBaseIDs) > 0 || len(knowledgeIDs) > 0 {
// User explicitly specified via @ mention, use only those (clear agent's default KBs)
// User explicitly specified via @ mention
if len(knowledgeBaseIDs) > 0 {
agentConfig.KnowledgeBases = knowledgeBaseIDs
logger.Infof(ctx, "Using request-specified knowledge bases (ignoring agent config): %v", knowledgeBaseIDs)
} else {
// User only specified documents, clear the default KBs
agentConfig.KnowledgeBases = nil
logger.Infof(ctx, "User specified documents only, clearing agent's default knowledge bases")
logger.Infof(ctx, "Using request-specified knowledge bases: %v", knowledgeBaseIDs)
}
if len(knowledgeIDs) > 0 {
agentConfig.KnowledgeIDs = knowledgeIDs
logger.Infof(ctx, "Using request-specified knowledge IDs: %v", knowledgeIDs)
}
} else {
// Use agent's configured knowledge bases based on KBSelectionMode
agentConfig.KnowledgeBases = s.resolveKnowledgeBasesFromAgent(ctx, customAgent)
}
// Use custom agent's allowed tools if specified, otherwise use defaults

View File

@@ -147,8 +147,13 @@ func (h *Handler) setupSSEStream(reqCtx *qaRequestContext) *sseStreamContext {
// Generate title if needed
if reqCtx.session.Title == "" {
logger.Infof(reqCtx.ctx, "Session has no title, starting async title generation, session ID: %s", reqCtx.sessionID)
h.sessionService.GenerateTitleAsync(asyncCtx, reqCtx.session, reqCtx.query, eventBus)
// Use the same model as the conversation for title generation
modelID := ""
if reqCtx.customAgent != nil && reqCtx.customAgent.Config.ModelID != "" {
modelID = reqCtx.customAgent.Config.ModelID
}
logger.Infof(reqCtx.ctx, "Session has no title, starting async title generation, session ID: %s, model: %s", reqCtx.sessionID, modelID)
h.sessionService.GenerateTitleAsync(asyncCtx, reqCtx.session, reqCtx.query, modelID, eventBus)
}
return streamCtx

View File

@@ -52,7 +52,7 @@ func (h *Handler) GenerateTitle(c *gin.Context) {
// Call service to generate title
logger.Infof(ctx, "Generating session title, session ID: %s, message count: %d", sessionID, len(request.Messages))
title, err := h.sessionService.GenerateTitle(ctx, session, request.Messages)
title, err := h.sessionService.GenerateTitle(ctx, session, request.Messages, "")
if err != nil {
logger.ErrorWithFields(ctx, err, nil)
c.Error(errors.NewInternalServerError(err.Error()))

View File

@@ -22,10 +22,12 @@ type SessionService interface {
// DeleteSession deletes a session
DeleteSession(ctx context.Context, id string) error
// GenerateTitle generates a title for the current conversation
GenerateTitle(ctx context.Context, session *types.Session, messages []types.Message) (string, error)
// modelID: optional model ID to use for title generation (if empty, uses first available KnowledgeQA model)
GenerateTitle(ctx context.Context, session *types.Session, messages []types.Message, modelID string) (string, error)
// GenerateTitleAsync generates a title for the session asynchronously
// It emits an event when the title is generated
GenerateTitleAsync(ctx context.Context, session *types.Session, userQuery string, eventBus *event.EventBus)
// modelID: optional model ID to use for title generation (if empty, uses first available KnowledgeQA model)
GenerateTitleAsync(ctx context.Context, session *types.Session, userQuery string, modelID string, eventBus *event.EventBus)
// KnowledgeQA performs knowledge-based question answering
// knowledgeBaseIDs: list of knowledge base IDs to search (supports multi-KB)
// knowledgeIDs: list of specific knowledge (file) IDs to search