Files
WeKnora/internal/application/service/knowledge.go
wizardchen 59d2f6b54c feat(graph): enhance graph extraction prompt rendering and language context handling
- Added a new method to render graph extraction prompts with shared placeholders for language.
- Updated entity and relationship extraction methods to utilize the new rendering function, improving prompt customization.
- Enhanced question generation task to include language context, ensuring prompts are generated in the correct language.
- Improved error handling for empty prompt configurations in question generation, enhancing robustness.
2026-03-25 22:08:29 +08:00

9071 lines
302 KiB
Go
Raw Blame History

This file contains ambiguous Unicode characters
This file contains Unicode characters that might be confused with other characters. If you think that this is intentional, you can safely ignore this warning. Use the Escape button to reveal them.
package service
import (
"context"
"crypto/md5"
"encoding/hex"
"encoding/json"
"errors"
"fmt"
"io"
"mime/multipart"
"net/http"
"net/url"
"os"
"path"
"regexp"
"runtime"
"slices"
"sort"
"strings"
"sync"
"time"
filesvc "github.com/Tencent/WeKnora/internal/application/service/file"
"github.com/Tencent/WeKnora/internal/application/service/retriever"
"github.com/Tencent/WeKnora/internal/config"
werrors "github.com/Tencent/WeKnora/internal/errors"
"github.com/Tencent/WeKnora/internal/infrastructure/chunker"
"github.com/Tencent/WeKnora/internal/infrastructure/docparser"
"github.com/Tencent/WeKnora/internal/logger"
"github.com/Tencent/WeKnora/internal/models/chat"
"github.com/Tencent/WeKnora/internal/models/embedding"
"github.com/Tencent/WeKnora/internal/tracing"
"github.com/Tencent/WeKnora/internal/types"
"github.com/Tencent/WeKnora/internal/types/interfaces"
secutils "github.com/Tencent/WeKnora/internal/utils"
"github.com/google/uuid"
"github.com/hibiken/asynq"
"github.com/redis/go-redis/v9"
"go.opentelemetry.io/otel/attribute"
"golang.org/x/sync/errgroup"
)
// Error definitions for knowledge service operations
var (
// ErrInvalidFileType is returned when an unsupported file type is provided
ErrInvalidFileType = errors.New("unsupported file type")
// ErrInvalidURL is returned when an invalid URL is provided
ErrInvalidURL = errors.New("invalid URL")
// ErrChunkNotFound is returned when a requested chunk cannot be found
ErrChunkNotFound = errors.New("chunk not found")
// ErrDuplicateFile is returned when trying to add a file that already exists
ErrDuplicateFile = errors.New("file already exists")
// ErrDuplicateURL is returned when trying to add a URL that already exists
ErrDuplicateURL = errors.New("URL already exists")
// ErrImageNotParse is returned when trying to update image information without enabling multimodel
ErrImageNotParse = errors.New("image not parse without enable multimodel")
)
// knowledgeService implements the knowledge service interface
// service 实现知识服务接口
type knowledgeService struct {
config *config.Config
retrieveEngine interfaces.RetrieveEngineRegistry
repo interfaces.KnowledgeRepository
kbService interfaces.KnowledgeBaseService
tenantRepo interfaces.TenantRepository
documentReader interfaces.DocumentReader
chunkService interfaces.ChunkService
chunkRepo interfaces.ChunkRepository
tagRepo interfaces.KnowledgeTagRepository
tagService interfaces.KnowledgeTagService
fileSvc interfaces.FileService
modelService interfaces.ModelService
task interfaces.TaskEnqueuer
graphEngine interfaces.RetrieveGraphRepository
redisClient *redis.Client
kbShareService interfaces.KBShareService
imageResolver *docparser.ImageResolver
}
const (
manualContentMaxLength = 200000
manualFileExtension = ".md"
faqImportBatchSize = 50 // 每批处理的FAQ条目数
)
// NewKnowledgeService creates a new knowledge service instance
func NewKnowledgeService(
config *config.Config,
repo interfaces.KnowledgeRepository,
documentReader interfaces.DocumentReader,
kbService interfaces.KnowledgeBaseService,
tenantRepo interfaces.TenantRepository,
chunkService interfaces.ChunkService,
chunkRepo interfaces.ChunkRepository,
tagRepo interfaces.KnowledgeTagRepository,
tagService interfaces.KnowledgeTagService,
fileSvc interfaces.FileService,
modelService interfaces.ModelService,
task interfaces.TaskEnqueuer,
graphEngine interfaces.RetrieveGraphRepository,
retrieveEngine interfaces.RetrieveEngineRegistry,
redisClient *redis.Client,
kbShareService interfaces.KBShareService,
imageResolver *docparser.ImageResolver,
) (interfaces.KnowledgeService, error) {
return &knowledgeService{
config: config,
repo: repo,
kbService: kbService,
tenantRepo: tenantRepo,
documentReader: documentReader,
chunkService: chunkService,
chunkRepo: chunkRepo,
tagRepo: tagRepo,
tagService: tagService,
fileSvc: fileSvc,
modelService: modelService,
task: task,
graphEngine: graphEngine,
retrieveEngine: retrieveEngine,
redisClient: redisClient,
kbShareService: kbShareService,
imageResolver: imageResolver,
}, nil
}
// getParserEngineOverridesFromContext returns parser engine overrides from tenant in context (e.g. MinerU endpoint, API key).
// Used when building document ReadRequest so UI-configured values take precedence over env.
func (s *knowledgeService) getParserEngineOverridesFromContext(ctx context.Context) map[string]string {
if v := ctx.Value(types.TenantInfoContextKey); v != nil {
if tenant, ok := v.(*types.Tenant); ok && tenant != nil {
return tenant.ParserEngineConfig.ToOverridesMap()
}
}
return nil
}
// GetRepository gets the knowledge repository
// Parameters:
// - ctx: Context with authentication and request information
//
// Returns:
// - interfaces.KnowledgeRepository: Knowledge repository
func (s *knowledgeService) GetRepository() interfaces.KnowledgeRepository {
return s.repo
}
// isKnowledgeDeleting checks if a knowledge entry is being deleted.
// This is used to prevent async tasks from conflicting with deletion operations.
func (s *knowledgeService) isKnowledgeDeleting(ctx context.Context, tenantID uint64, knowledgeID string) bool {
knowledge, err := s.repo.GetKnowledgeByID(ctx, tenantID, knowledgeID)
if err != nil {
// If we can't find the knowledge, assume it's deleted
logger.Warnf(ctx, "Failed to check knowledge deletion status (assuming deleted): %v", err)
return true
}
if knowledge == nil {
return true
}
return knowledge.ParseStatus == types.ParseStatusDeleting
}
// checkStorageEngineConfigured verifies that the knowledge base has a storage engine configured
// (either at the KB level or via the tenant default). Returns an error if no storage engine is found.
func checkStorageEngineConfigured(ctx context.Context, kb *types.KnowledgeBase) error {
provider := kb.GetStorageProvider()
if provider == "" {
tenant, _ := ctx.Value(types.TenantInfoContextKey).(*types.Tenant)
if tenant != nil && tenant.StorageEngineConfig != nil {
provider = strings.ToLower(strings.TrimSpace(tenant.StorageEngineConfig.DefaultProvider))
}
}
if provider == "" {
return werrors.NewBadRequestError("请先为知识库选择存储引擎,再上传内容。请前往知识库设置页面进行配置。")
}
return nil
}
// CreateKnowledgeFromFile creates a knowledge entry from an uploaded file
func (s *knowledgeService) CreateKnowledgeFromFile(ctx context.Context,
kbID string, file *multipart.FileHeader, metadata map[string]string, enableMultimodel *bool, customFileName string, tagID string,
) (*types.Knowledge, error) {
logger.Info(ctx, "Start creating knowledge from file")
// Use custom filename if provided, otherwise use original filename
fileName := file.Filename
if customFileName != "" {
fileName = customFileName
logger.Infof(ctx, "Using custom filename: %s (original: %s)", customFileName, file.Filename)
}
logger.Infof(ctx, "Knowledge base ID: %s, file: %s", kbID, fileName)
// Get knowledge base configuration
logger.Info(ctx, "Getting knowledge base configuration")
kb, err := s.kbService.GetKnowledgeBaseByID(ctx, kbID)
if err != nil {
logger.Errorf(ctx, "Failed to get knowledge base: %v", err)
return nil, err
}
if err := checkStorageEngineConfigured(ctx, kb); err != nil {
return nil, err
}
// 检查多模态配置完整性 - 只在图片文件时校验
if !IsImageType(getFileType(fileName)) {
logger.Info(ctx, "Non-image file with multimodal enabled, skipping COS/VLM validation")
} else {
// 解析有效 provider优先 KB 级别(新字段 > 旧字段),其次租户默认
provider := kb.GetStorageProvider()
tenant, _ := ctx.Value(types.TenantInfoContextKey).(*types.Tenant)
if provider == "" && tenant != nil && tenant.StorageEngineConfig != nil {
provider = strings.ToLower(strings.TrimSpace(tenant.StorageEngineConfig.DefaultProvider))
}
// 根据 provider 校验租户级存储引擎配置
switch provider {
case "cos":
if tenant == nil || tenant.StorageEngineConfig == nil || tenant.StorageEngineConfig.COS == nil ||
tenant.StorageEngineConfig.COS.SecretID == "" || tenant.StorageEngineConfig.COS.SecretKey == "" ||
tenant.StorageEngineConfig.COS.Region == "" || tenant.StorageEngineConfig.COS.BucketName == "" {
logger.Error(ctx, "COS configuration incomplete for image multimodal processing")
return nil, werrors.NewBadRequestError("上传图片文件需要完整的对象存储配置信息, 请前往知识库存储设置或系统设置页面进行补全")
}
case "minio":
ok := false
if tenant != nil && tenant.StorageEngineConfig != nil && tenant.StorageEngineConfig.MinIO != nil {
m := tenant.StorageEngineConfig.MinIO
if m.Mode == "remote" {
ok = m.Endpoint != "" && m.AccessKeyID != "" && m.SecretAccessKey != "" && m.BucketName != ""
} else {
ok = os.Getenv("MINIO_ENDPOINT") != "" && os.Getenv("MINIO_ACCESS_KEY_ID") != "" &&
os.Getenv("MINIO_SECRET_ACCESS_KEY") != "" &&
(m.BucketName != "" || os.Getenv("MINIO_BUCKET_NAME") != "")
}
}
if !ok {
logger.Error(ctx, "MinIO configuration incomplete for image multimodal processing")
return nil, werrors.NewBadRequestError("上传图片文件需要完整的对象存储配置信息, 请前往知识库存储设置或系统设置页面进行补全")
}
}
// 检查VLM配置
if !kb.VLMConfig.Enabled || kb.VLMConfig.ModelID == "" {
logger.Error(ctx, "VLM model is not configured")
return nil, werrors.NewBadRequestError("上传图片文件需要设置VLM模型")
}
logger.Info(ctx, "Image multimodal configuration validation passed")
}
// Validate file type
logger.Infof(ctx, "Checking file type: %s", fileName)
if !isValidFileType(fileName) {
logger.Error(ctx, "Invalid file type")
return nil, ErrInvalidFileType
}
// Calculate file hash for deduplication
logger.Info(ctx, "Calculating file hash")
hash, err := calculateFileHash(file)
if err != nil {
logger.Errorf(ctx, "Failed to calculate file hash: %v", err)
return nil, err
}
// Check if file already exists
tenantID := ctx.Value(types.TenantIDContextKey).(uint64)
logger.Infof(ctx, "Checking if file exists, tenant ID: %d", tenantID)
exists, existingKnowledge, err := s.repo.CheckKnowledgeExists(ctx, tenantID, kbID, &types.KnowledgeCheckParams{
Type: "file",
FileName: fileName,
FileSize: file.Size,
FileHash: hash,
})
if err != nil {
logger.Errorf(ctx, "Failed to check knowledge existence: %v", err)
return nil, err
}
if exists {
logger.Infof(ctx, "File already exists: %s", fileName)
// Update creation time for existing knowledge
if err := s.repo.UpdateKnowledgeColumn(ctx, existingKnowledge.ID, "created_at", time.Now()); err != nil {
logger.Errorf(ctx, "Failed to update existing knowledge: %v", err)
return nil, err
}
return existingKnowledge, types.NewDuplicateFileError(existingKnowledge)
}
// Check storage quota
tenantInfo := ctx.Value(types.TenantInfoContextKey).(*types.Tenant)
if tenantInfo.StorageQuota > 0 && tenantInfo.StorageUsed >= tenantInfo.StorageQuota {
logger.Error(ctx, "Storage quota exceeded")
return nil, types.NewStorageQuotaExceededError()
}
// Convert metadata to JSON format if provided
var metadataJSON types.JSON
if metadata != nil {
metadataBytes, err := json.Marshal(metadata)
if err != nil {
logger.Errorf(ctx, "Failed to marshal metadata: %v", err)
return nil, err
}
metadataJSON = types.JSON(metadataBytes)
}
// 验证文件名安全性
safeFilename, isValid := secutils.ValidateInput(fileName)
if !isValid {
logger.Errorf(ctx, "Invalid filename: %s", fileName)
return nil, werrors.NewValidationError("文件名包含非法字符")
}
// Create knowledge record
logger.Info(ctx, "Creating knowledge record")
knowledge := &types.Knowledge{
TenantID: tenantID,
KnowledgeBaseID: kbID,
TagID: tagID, // 设置分类ID用于知识分类管理
Type: "file",
Title: safeFilename,
FileName: safeFilename,
FileType: getFileType(safeFilename),
FileSize: file.Size,
FileHash: hash,
ParseStatus: "pending",
EnableStatus: "disabled",
CreatedAt: time.Now(),
UpdatedAt: time.Now(),
EmbeddingModelID: kb.EmbeddingModelID,
Metadata: metadataJSON,
}
// Save knowledge record to database
logger.Info(ctx, "Saving knowledge record to database")
if err := s.repo.CreateKnowledge(ctx, knowledge); err != nil {
logger.Errorf(ctx, "Failed to create knowledge record, ID: %s, error: %v", knowledge.ID, err)
return nil, err
}
// Save the file to storage (use KB-level storage engine if configured)
logger.Infof(ctx, "Saving file, knowledge ID: %s", knowledge.ID)
filePath, err := s.resolveFileService(ctx, kb).SaveFile(ctx, file, knowledge.TenantID, knowledge.ID)
if err != nil {
logger.Errorf(ctx, "Failed to save file, knowledge ID: %s, error: %v", knowledge.ID, err)
return nil, err
}
knowledge.FilePath = filePath
// Update knowledge record with file path
logger.Info(ctx, "Updating knowledge record with file path")
if err := s.repo.UpdateKnowledge(ctx, knowledge); err != nil {
logger.Errorf(ctx, "Failed to update knowledge with file path, ID: %s, error: %v", knowledge.ID, err)
return nil, err
}
// Enqueue document processing task to Asynq
logger.Info(ctx, "Enqueuing document processing task to Asynq")
enableMultimodelValue := false
if enableMultimodel != nil {
enableMultimodelValue = *enableMultimodel
} else {
enableMultimodelValue = kb.IsMultimodalEnabled()
}
// Check question generation config
enableQuestionGeneration := false
questionCount := 3 // default
if kb.QuestionGenerationConfig != nil && kb.QuestionGenerationConfig.Enabled {
enableQuestionGeneration = true
if kb.QuestionGenerationConfig.QuestionCount > 0 {
questionCount = kb.QuestionGenerationConfig.QuestionCount
}
}
taskPayload := types.DocumentProcessPayload{
TenantID: tenantID,
KnowledgeID: knowledge.ID,
KnowledgeBaseID: kbID,
FilePath: filePath,
FileName: safeFilename,
FileType: getFileType(safeFilename),
EnableMultimodel: enableMultimodelValue,
EnableQuestionGeneration: enableQuestionGeneration,
QuestionCount: questionCount,
}
payloadBytes, err := json.Marshal(taskPayload)
if err != nil {
logger.Errorf(ctx, "Failed to marshal document process task payload: %v", err)
// 即使入队失败也返回knowledge因为文件已保存
return knowledge, nil
}
task := asynq.NewTask(types.TypeDocumentProcess, payloadBytes, asynq.Queue("default"), asynq.MaxRetry(3))
info, err := s.task.Enqueue(task)
if err != nil {
logger.Errorf(ctx, "Failed to enqueue document process task: %v", err)
// 即使入队失败也返回knowledge因为文件已保存
return knowledge, nil
}
logger.Infof(
ctx,
"Enqueued document process task: id=%s queue=%s knowledge_id=%s",
info.ID,
info.Queue,
knowledge.ID,
)
if slices.Contains([]string{"csv", "xlsx", "xls"}, getFileType(safeFilename)) {
NewDataTableSummaryTask(ctx, s.task, tenantID, knowledge.ID, kb.SummaryModelID, kb.EmbeddingModelID)
}
logger.Infof(ctx, "Knowledge from file created successfully, ID: %s", knowledge.ID)
return knowledge, nil
}
// CreateKnowledgeFromURL creates a knowledge entry from a URL source
// tagID is optional - when provided, the knowledge will be assigned to the specified tag/category.
// isFileURL reports whether the given URL should be treated as a direct file download.
// Priority: URL path has a known file extension first, then fall back to user-provided fileName/fileType hints.
func isFileURL(rawURL, fileName, fileType string) bool {
u, err := url.Parse(rawURL)
if err == nil {
ext := strings.ToLower(strings.TrimPrefix(path.Ext(u.Path), "."))
if ext != "" && allowedFileURLExtensions[ext] {
return true
}
}
// Fall back to user-provided hints
return fileName != "" || fileType != ""
}
func (s *knowledgeService) CreateKnowledgeFromURL(ctx context.Context,
kbID string, rawURL string, fileName string, fileType string, enableMultimodel *bool, title string, tagID string,
) (*types.Knowledge, error) {
logger.Info(ctx, "Start creating knowledge from URL")
logger.Infof(ctx, "Knowledge base ID: %s, URL: %s", kbID, rawURL)
// Route to file_url logic when the URL points to a downloadable file
if isFileURL(rawURL, fileName, fileType) {
return s.createKnowledgeFromFileURL(ctx, kbID, rawURL, fileName, fileType, enableMultimodel, title, tagID)
}
url := rawURL
// Get knowledge base configuration
logger.Info(ctx, "Getting knowledge base configuration")
kb, err := s.kbService.GetKnowledgeBaseByID(ctx, kbID)
if err != nil {
logger.Errorf(ctx, "Failed to get knowledge base: %v", err)
return nil, err
}
if err := checkStorageEngineConfigured(ctx, kb); err != nil {
return nil, err
}
// Validate URL format and security
logger.Info(ctx, "Validating URL")
if !isValidURL(url) || !secutils.IsValidURL(url) {
logger.Error(ctx, "Invalid or unsafe URL format")
return nil, ErrInvalidURL
}
// SSRF protection: validate URL is safe to fetch
if safe, reason := secutils.IsSSRFSafeURL(url); !safe {
logger.Errorf(ctx, "URL rejected for SSRF protection: %s, reason: %s", url, reason)
return nil, ErrInvalidURL
}
// Check if URL already exists in the knowledge base
tenantID := ctx.Value(types.TenantIDContextKey).(uint64)
logger.Infof(ctx, "Checking if URL exists, tenant ID: %d", tenantID)
fileHash := calculateStr(url)
exists, existingKnowledge, err := s.repo.CheckKnowledgeExists(ctx, tenantID, kbID, &types.KnowledgeCheckParams{
Type: "url",
URL: url,
FileHash: fileHash,
})
if err != nil {
logger.Errorf(ctx, "Failed to check knowledge existence: %v", err)
return nil, err
}
if exists {
logger.Infof(ctx, "URL already exists: %s", url)
// Update creation time for existing knowledge
existingKnowledge.CreatedAt = time.Now()
existingKnowledge.UpdatedAt = time.Now()
if err := s.repo.UpdateKnowledge(ctx, existingKnowledge); err != nil {
logger.Errorf(ctx, "Failed to update existing knowledge: %v", err)
return nil, err
}
return existingKnowledge, types.NewDuplicateURLError(existingKnowledge)
}
// Check storage quota
tenantInfo := ctx.Value(types.TenantInfoContextKey).(*types.Tenant)
if tenantInfo.StorageQuota > 0 && tenantInfo.StorageUsed >= tenantInfo.StorageQuota {
logger.Error(ctx, "Storage quota exceeded")
return nil, types.NewStorageQuotaExceededError()
}
// Create knowledge record
logger.Info(ctx, "Creating knowledge record")
knowledge := &types.Knowledge{
ID: uuid.New().String(),
TenantID: tenantID,
KnowledgeBaseID: kbID,
Type: "url",
Title: title,
Source: url,
FileHash: fileHash,
ParseStatus: "pending",
EnableStatus: "disabled",
CreatedAt: time.Now(),
UpdatedAt: time.Now(),
EmbeddingModelID: kb.EmbeddingModelID,
TagID: tagID, // 设置分类ID用于知识分类管理
}
// Save knowledge record
logger.Infof(ctx, "Saving knowledge record to database, ID: %s", knowledge.ID)
if err := s.repo.CreateKnowledge(ctx, knowledge); err != nil {
logger.Errorf(ctx, "Failed to create knowledge record: %v", err)
return nil, err
}
// Enqueue URL processing task to Asynq
logger.Info(ctx, "Enqueuing URL processing task to Asynq")
enableMultimodelValue := false
if enableMultimodel != nil {
enableMultimodelValue = *enableMultimodel
} else {
enableMultimodelValue = kb.IsMultimodalEnabled()
}
// Check question generation config
enableQuestionGeneration := false
questionCount := 3 // default
if kb.QuestionGenerationConfig != nil && kb.QuestionGenerationConfig.Enabled {
enableQuestionGeneration = true
if kb.QuestionGenerationConfig.QuestionCount > 0 {
questionCount = kb.QuestionGenerationConfig.QuestionCount
}
}
taskPayload := types.DocumentProcessPayload{
TenantID: tenantID,
KnowledgeID: knowledge.ID,
KnowledgeBaseID: kbID,
URL: url,
EnableMultimodel: enableMultimodelValue,
EnableQuestionGeneration: enableQuestionGeneration,
QuestionCount: questionCount,
}
payloadBytes, err := json.Marshal(taskPayload)
if err != nil {
logger.Errorf(ctx, "Failed to marshal URL process task payload: %v", err)
return knowledge, nil
}
task := asynq.NewTask(types.TypeDocumentProcess, payloadBytes, asynq.Queue("default"), asynq.MaxRetry(3))
info, err := s.task.Enqueue(task)
if err != nil {
logger.Errorf(ctx, "Failed to enqueue URL process task: %v", err)
return knowledge, nil
}
logger.Infof(ctx, "Enqueued URL process task: id=%s queue=%s knowledge_id=%s", info.ID, info.Queue, knowledge.ID)
logger.Infof(ctx, "Knowledge from URL created successfully, ID: %s", knowledge.ID)
return knowledge, nil
}
// allowedFileURLExtensions defines the supported file extensions for file URL import
var allowedFileURLExtensions = map[string]bool{
"txt": true,
"md": true,
"pdf": true,
"docx": true,
"doc": true,
}
// maxFileURLSize is the maximum allowed file size for file URL import (10MB)
const maxFileURLSize = 10 * 1024 * 1024
// extractFileNameFromURL extracts the filename from a URL path
func extractFileNameFromURL(rawURL string) string {
u, err := url.Parse(rawURL)
if err != nil {
return ""
}
base := path.Base(u.Path)
if base == "." || base == "/" {
return ""
}
return base
}
// extractFileNameFromContentDisposition extracts filename from Content-Disposition header
func extractFileNameFromContentDisposition(header string) string {
// e.g. attachment; filename="document.pdf" or filename*=UTF-8''document.pdf
for _, part := range strings.Split(header, ";") {
part = strings.TrimSpace(part)
if strings.HasPrefix(strings.ToLower(part), "filename=") {
name := strings.TrimPrefix(part, "filename=")
name = strings.TrimPrefix(part[len("filename="):], "")
name = strings.Trim(name, `"'`)
if name != "" {
return name
}
}
}
return ""
}
// createKnowledgeFromFileURL is the internal implementation for file URL knowledge creation.
// Called by CreateKnowledgeFromURL when the URL is detected as a direct file download.
func (s *knowledgeService) createKnowledgeFromFileURL(
ctx context.Context,
kbID string,
fileURL string,
fileName string,
fileType string,
enableMultimodel *bool,
title string,
tagID string,
) (*types.Knowledge, error) {
logger.Info(ctx, "Start creating knowledge from file URL")
logger.Infof(ctx, "Knowledge base ID: %s, file URL: %s", kbID, fileURL)
// Get knowledge base configuration
kb, err := s.kbService.GetKnowledgeBaseByID(ctx, kbID)
if err != nil {
logger.Errorf(ctx, "Failed to get knowledge base: %v", err)
return nil, err
}
if err := checkStorageEngineConfigured(ctx, kb); err != nil {
return nil, err
}
// Validate URL format and security (static check only, no HEAD request)
if !isValidURL(fileURL) || !secutils.IsValidURL(fileURL) {
logger.Error(ctx, "Invalid or unsafe file URL format")
return nil, ErrInvalidURL
}
if safe, reason := secutils.IsSSRFSafeURL(fileURL); !safe {
logger.Errorf(ctx, "File URL rejected for SSRF protection: %s, reason: %s", fileURL, reason)
return nil, ErrInvalidURL
}
// Resolve fileName: user-provided > extracted from URL path
if fileName == "" {
fileName = extractFileNameFromURL(fileURL)
}
// Resolve fileType: user-provided > inferred from fileName
if fileType == "" && fileName != "" {
fileType = getFileType(fileName)
}
// Validate file extension against whitelist (if we can determine it)
if fileType != "" {
if !allowedFileURLExtensions[strings.ToLower(fileType)] {
logger.Errorf(ctx, "Unsupported file type for file URL import: %s", fileType)
return nil, werrors.NewBadRequestError(fmt.Sprintf("不支持的文件类型: %s仅支持 txt, md, pdf, docx, doc", fileType))
}
}
// Use title as display name if fileName is still empty
displayName := fileName
if displayName == "" {
displayName = title
}
if displayName == "" {
// Fallback: use last segment of URL
displayName = extractFileNameFromURL(fileURL)
}
if displayName == "" {
displayName = fileURL
}
// Check for duplicate (by URL hash)
tenantID := ctx.Value(types.TenantIDContextKey).(uint64)
fileHash := calculateStr(fileURL)
exists, existingKnowledge, err := s.repo.CheckKnowledgeExists(ctx, tenantID, kbID, &types.KnowledgeCheckParams{
Type: "file_url",
URL: fileURL,
FileHash: fileHash,
})
if err != nil {
logger.Errorf(ctx, "Failed to check knowledge existence: %v", err)
return nil, err
}
if exists {
logger.Infof(ctx, "File URL already exists: %s", fileURL)
existingKnowledge.CreatedAt = time.Now()
existingKnowledge.UpdatedAt = time.Now()
if err := s.repo.UpdateKnowledge(ctx, existingKnowledge); err != nil {
logger.Errorf(ctx, "Failed to update existing knowledge: %v", err)
return nil, err
}
return existingKnowledge, types.NewDuplicateURLError(existingKnowledge)
}
// Check storage quota
tenantInfo := ctx.Value(types.TenantInfoContextKey).(*types.Tenant)
if tenantInfo.StorageQuota > 0 && tenantInfo.StorageUsed >= tenantInfo.StorageQuota {
logger.Error(ctx, "Storage quota exceeded")
return nil, types.NewStorageQuotaExceededError()
}
// Create knowledge record
knowledge := &types.Knowledge{
ID: uuid.New().String(),
TenantID: tenantID,
KnowledgeBaseID: kbID,
Type: "file_url",
Title: title,
FileName: displayName,
FileType: fileType,
Source: fileURL,
FileHash: fileHash,
ParseStatus: "pending",
EnableStatus: "disabled",
CreatedAt: time.Now(),
UpdatedAt: time.Now(),
EmbeddingModelID: kb.EmbeddingModelID,
TagID: tagID,
}
if knowledge.Title == "" {
knowledge.Title = displayName
}
if err := s.repo.CreateKnowledge(ctx, knowledge); err != nil {
logger.Errorf(ctx, "Failed to create knowledge record: %v", err)
return nil, err
}
// Build async task payload
enableMultimodelValue := false
if enableMultimodel != nil {
enableMultimodelValue = *enableMultimodel
} else {
enableMultimodelValue = kb.IsMultimodalEnabled()
}
enableQuestionGeneration := false
questionCount := 3
if kb.QuestionGenerationConfig != nil && kb.QuestionGenerationConfig.Enabled {
enableQuestionGeneration = true
if kb.QuestionGenerationConfig.QuestionCount > 0 {
questionCount = kb.QuestionGenerationConfig.QuestionCount
}
}
taskPayload := types.DocumentProcessPayload{
TenantID: tenantID,
KnowledgeID: knowledge.ID,
KnowledgeBaseID: kbID,
FileURL: fileURL,
FileName: fileName,
FileType: fileType,
EnableMultimodel: enableMultimodelValue,
EnableQuestionGeneration: enableQuestionGeneration,
QuestionCount: questionCount,
}
payloadBytes, err := json.Marshal(taskPayload)
if err != nil {
logger.Errorf(ctx, "Failed to marshal file URL process task payload: %v", err)
return knowledge, nil
}
task := asynq.NewTask(types.TypeDocumentProcess, payloadBytes, asynq.Queue("default"))
info, err := s.task.Enqueue(task)
if err != nil {
logger.Errorf(ctx, "Failed to enqueue file URL process task: %v", err)
return knowledge, nil
}
logger.Infof(ctx, "Enqueued file URL process task: id=%s queue=%s knowledge_id=%s", info.ID, info.Queue, knowledge.ID)
logger.Infof(ctx, "Knowledge from file URL created successfully, ID: %s", knowledge.ID)
return knowledge, nil
}
// CreateKnowledgeFromPassage creates a knowledge entry from text passages
func (s *knowledgeService) CreateKnowledgeFromPassage(ctx context.Context,
kbID string, passage []string,
) (*types.Knowledge, error) {
return s.createKnowledgeFromPassageInternal(ctx, kbID, passage, false)
}
// CreateKnowledgeFromPassageSync creates a knowledge entry from text passages and waits for indexing to complete.
func (s *knowledgeService) CreateKnowledgeFromPassageSync(ctx context.Context,
kbID string, passage []string,
) (*types.Knowledge, error) {
return s.createKnowledgeFromPassageInternal(ctx, kbID, passage, true)
}
// CreateKnowledgeFromManual creates or saves manual Markdown knowledge content.
func (s *knowledgeService) CreateKnowledgeFromManual(ctx context.Context,
kbID string, payload *types.ManualKnowledgePayload,
) (*types.Knowledge, error) {
logger.Info(ctx, "Start creating manual knowledge entry")
if payload == nil {
return nil, werrors.NewBadRequestError("请求内容不能为空")
}
cleanContent := secutils.CleanMarkdown(payload.Content)
if strings.TrimSpace(cleanContent) == "" {
return nil, werrors.NewValidationError("内容不能为空")
}
if len([]rune(cleanContent)) > manualContentMaxLength {
return nil, werrors.NewValidationError(fmt.Sprintf("内容长度超出限制(最多%d个字符", manualContentMaxLength))
}
safeTitle, ok := secutils.ValidateInput(payload.Title)
if !ok {
return nil, werrors.NewValidationError("标题包含非法字符或超出长度限制")
}
status := strings.ToLower(strings.TrimSpace(payload.Status))
if status == "" {
status = types.ManualKnowledgeStatusDraft
}
if status != types.ManualKnowledgeStatusDraft && status != types.ManualKnowledgeStatusPublish {
return nil, werrors.NewValidationError("状态仅支持 draft 或 publish")
}
kb, err := s.kbService.GetKnowledgeBaseByID(ctx, kbID)
if err != nil {
logger.Errorf(ctx, "Failed to get knowledge base: %v", err)
return nil, err
}
if err := checkStorageEngineConfigured(ctx, kb); err != nil {
return nil, err
}
tenantID := ctx.Value(types.TenantIDContextKey).(uint64)
now := time.Now()
title := safeTitle
if title == "" {
title = fmt.Sprintf("Knowledge-%s", now.Format("20060102-150405"))
}
fileName := ensureManualFileName(title)
meta := types.NewManualKnowledgeMetadata(cleanContent, status, 1)
knowledge := &types.Knowledge{
TenantID: tenantID,
KnowledgeBaseID: kbID,
Type: types.KnowledgeTypeManual,
Title: title,
Description: "",
Source: types.KnowledgeTypeManual,
ParseStatus: types.ManualKnowledgeStatusDraft,
EnableStatus: "disabled",
CreatedAt: now,
UpdatedAt: now,
EmbeddingModelID: kb.EmbeddingModelID,
FileName: fileName,
FileType: types.KnowledgeTypeManual,
TagID: payload.TagID, // 设置分类ID用于知识分类管理
}
if err := knowledge.SetManualMetadata(meta); err != nil {
logger.Errorf(ctx, "Failed to set manual metadata: %v", err)
return nil, err
}
knowledge.EnsureManualDefaults()
if status == types.ManualKnowledgeStatusPublish {
knowledge.ParseStatus = "pending"
}
if err := s.repo.CreateKnowledge(ctx, knowledge); err != nil {
logger.Errorf(ctx, "Failed to create manual knowledge record: %v", err)
return nil, err
}
if status == types.ManualKnowledgeStatusPublish {
logger.Infof(ctx, "Manual knowledge created, enqueuing async processing task, ID: %s", knowledge.ID)
if err := s.enqueueManualProcessing(ctx, knowledge, cleanContent, false); err != nil {
logger.Errorf(ctx, "Failed to enqueue manual processing task for new knowledge: %v", err)
// Non-fatal: mark as failed so user can retry
knowledge.ParseStatus = "failed"
knowledge.ErrorMessage = "Failed to enqueue processing task"
s.repo.UpdateKnowledge(ctx, knowledge)
}
}
return knowledge, nil
}
// createKnowledgeFromPassageInternal consolidates the common logic for creating knowledge from passages.
// When syncMode is true, chunk processing is performed synchronously; otherwise, it's processed asynchronously.
func (s *knowledgeService) createKnowledgeFromPassageInternal(ctx context.Context,
kbID string, passage []string, syncMode bool,
) (*types.Knowledge, error) {
if syncMode {
logger.Info(ctx, "Start creating knowledge from passage (sync)")
} else {
logger.Info(ctx, "Start creating knowledge from passage")
}
logger.Infof(ctx, "Knowledge base ID: %s, passage count: %d", kbID, len(passage))
// 验证段落内容安全性
safePassages := make([]string, 0, len(passage))
for i, p := range passage {
safePassage, isValid := secutils.ValidateInput(p)
if !isValid {
logger.Errorf(ctx, "Invalid passage content at index %d", i)
return nil, werrors.NewValidationError(fmt.Sprintf("段落 %d 包含非法内容", i+1))
}
safePassages = append(safePassages, safePassage)
}
// Get knowledge base configuration
logger.Info(ctx, "Getting knowledge base configuration")
kb, err := s.kbService.GetKnowledgeBaseByID(ctx, kbID)
if err != nil {
logger.Errorf(ctx, "Failed to get knowledge base: %v", err)
return nil, err
}
// Create knowledge record
if syncMode {
logger.Info(ctx, "Creating knowledge record (sync)")
} else {
logger.Info(ctx, "Creating knowledge record")
}
knowledge := &types.Knowledge{
ID: uuid.New().String(),
TenantID: ctx.Value(types.TenantIDContextKey).(uint64),
KnowledgeBaseID: kbID,
Type: "passage",
ParseStatus: "pending",
EnableStatus: "disabled",
CreatedAt: time.Now(),
UpdatedAt: time.Now(),
EmbeddingModelID: kb.EmbeddingModelID,
}
// Save knowledge record
logger.Infof(ctx, "Saving knowledge record to database, ID: %s", knowledge.ID)
if err := s.repo.CreateKnowledge(ctx, knowledge); err != nil {
logger.Errorf(ctx, "Failed to create knowledge record: %v", err)
return nil, err
}
// Process passages
if syncMode {
logger.Info(ctx, "Processing passage synchronously")
s.processDocumentFromPassage(ctx, kb, knowledge, safePassages)
logger.Infof(ctx, "Knowledge from passage created successfully (sync), ID: %s", knowledge.ID)
} else {
// Enqueue passage processing task to Asynq
logger.Info(ctx, "Enqueuing passage processing task to Asynq")
tenantID := ctx.Value(types.TenantIDContextKey).(uint64)
// Check question generation config
enableQuestionGeneration := false
questionCount := 3 // default
if kb.QuestionGenerationConfig != nil && kb.QuestionGenerationConfig.Enabled {
enableQuestionGeneration = true
if kb.QuestionGenerationConfig.QuestionCount > 0 {
questionCount = kb.QuestionGenerationConfig.QuestionCount
}
}
taskPayload := types.DocumentProcessPayload{
TenantID: tenantID,
KnowledgeID: knowledge.ID,
KnowledgeBaseID: kbID,
Passages: safePassages,
EnableMultimodel: false, // 文本段落不支持多模态
EnableQuestionGeneration: enableQuestionGeneration,
QuestionCount: questionCount,
}
payloadBytes, err := json.Marshal(taskPayload)
if err != nil {
logger.Errorf(ctx, "Failed to marshal passage process task payload: %v", err)
// 即使入队失败也返回knowledge
return knowledge, nil
}
task := asynq.NewTask(types.TypeDocumentProcess, payloadBytes, asynq.Queue("default"), asynq.MaxRetry(3))
info, err := s.task.Enqueue(task)
if err != nil {
logger.Errorf(ctx, "Failed to enqueue passage process task: %v", err)
return knowledge, nil
}
logger.Infof(ctx, "Enqueued passage process task: id=%s queue=%s knowledge_id=%s", info.ID, info.Queue, knowledge.ID)
logger.Infof(ctx, "Knowledge from passage created successfully, ID: %s", knowledge.ID)
}
return knowledge, nil
}
// GetKnowledgeByID retrieves a knowledge entry by its ID
func (s *knowledgeService) GetKnowledgeByID(ctx context.Context, id string) (*types.Knowledge, error) {
tenantID := ctx.Value(types.TenantIDContextKey).(uint64)
knowledge, err := s.repo.GetKnowledgeByID(ctx, tenantID, id)
if err != nil {
logger.ErrorWithFields(ctx, err, map[string]interface{}{
"knowledge_id": id,
"tenant_id": tenantID,
})
return nil, err
}
logger.Infof(ctx, "Knowledge retrieved successfully, ID: %s, type: %s", knowledge.ID, knowledge.Type)
return knowledge, nil
}
// GetKnowledgeByIDOnly retrieves knowledge by ID without tenant filter (for permission resolution).
func (s *knowledgeService) GetKnowledgeByIDOnly(ctx context.Context, id string) (*types.Knowledge, error) {
return s.repo.GetKnowledgeByIDOnly(ctx, id)
}
// ListKnowledgeByKnowledgeBaseID returns all knowledge entries in a knowledge base
func (s *knowledgeService) ListKnowledgeByKnowledgeBaseID(ctx context.Context,
kbID string,
) ([]*types.Knowledge, error) {
return s.repo.ListKnowledgeByKnowledgeBaseID(ctx, ctx.Value(types.TenantIDContextKey).(uint64), kbID)
}
// ListPagedKnowledgeByKnowledgeBaseID returns paginated knowledge entries in a knowledge base
func (s *knowledgeService) ListPagedKnowledgeByKnowledgeBaseID(ctx context.Context,
kbID string, page *types.Pagination, tagID string, keyword string, fileType string,
) (*types.PageResult, error) {
knowledges, total, err := s.repo.ListPagedKnowledgeByKnowledgeBaseID(ctx,
ctx.Value(types.TenantIDContextKey).(uint64), kbID, page, tagID, keyword, fileType)
if err != nil {
return nil, err
}
return types.NewPageResult(total, page, knowledges), nil
}
// DeleteKnowledge deletes a knowledge entry and all related resources
func (s *knowledgeService) DeleteKnowledge(ctx context.Context, id string) error {
// Get the knowledge entry
knowledge, err := s.repo.GetKnowledgeByID(ctx, ctx.Value(types.TenantIDContextKey).(uint64), id)
if err != nil {
return err
}
// Mark as deleting first to prevent async task conflicts
// This ensures that any running async tasks will detect the deletion and abort
originalStatus := knowledge.ParseStatus
knowledge.ParseStatus = types.ParseStatusDeleting
knowledge.UpdatedAt = time.Now()
if err := s.repo.UpdateKnowledge(ctx, knowledge); err != nil {
logger.GetLogger(ctx).WithField("error", err).Errorf("DeleteKnowledge failed to mark as deleting")
// Continue with deletion even if marking fails
} else {
logger.Infof(ctx, "Marked knowledge %s as deleting (previous status: %s)", id, originalStatus)
}
// Resolve file service for this KB before spawning goroutines
kb, _ := s.kbService.GetKnowledgeBaseByID(ctx, knowledge.KnowledgeBaseID)
kbFileSvc := s.resolveFileService(ctx, kb)
wg := errgroup.Group{}
// Delete knowledge embeddings from vector store
wg.Go(func() error {
tenantInfo := ctx.Value(types.TenantInfoContextKey).(*types.Tenant)
retrieveEngine, err := retriever.NewCompositeRetrieveEngine(
s.retrieveEngine,
tenantInfo.GetEffectiveEngines(),
)
if err != nil {
logger.GetLogger(ctx).WithField("error", err).Errorf("DeleteKnowledge delete knowledge embedding failed")
return err
}
embeddingModel, err := s.modelService.GetEmbeddingModel(ctx, knowledge.EmbeddingModelID)
if err != nil {
logger.GetLogger(ctx).WithField("error", err).Errorf("DeleteKnowledge delete knowledge embedding failed")
return err
}
if err := retrieveEngine.DeleteByKnowledgeIDList(ctx, []string{knowledge.ID}, embeddingModel.GetDimensions(), knowledge.Type); err != nil {
logger.GetLogger(ctx).WithField("error", err).Errorf("DeleteKnowledge delete knowledge embedding failed")
return err
}
return nil
})
// Delete all chunks associated with this knowledge
wg.Go(func() error {
if err := s.chunkService.DeleteChunksByKnowledgeID(ctx, knowledge.ID); err != nil {
logger.GetLogger(ctx).WithField("error", err).Errorf("DeleteKnowledge delete chunks failed")
return err
}
return nil
})
// Delete the physical file if it exists
wg.Go(func() error {
if knowledge.FilePath != "" {
if err := kbFileSvc.DeleteFile(ctx, knowledge.FilePath); err != nil {
logger.GetLogger(ctx).WithField("error", err).Errorf("DeleteKnowledge delete file failed")
}
}
tenantInfo := ctx.Value(types.TenantInfoContextKey).(*types.Tenant)
tenantInfo.StorageUsed -= knowledge.StorageSize
if err := s.tenantRepo.AdjustStorageUsed(ctx, tenantInfo.ID, -knowledge.StorageSize); err != nil {
logger.GetLogger(ctx).WithField("error", err).Errorf("DeleteKnowledge update tenant storage used failed")
}
return nil
})
// Delete the knowledge graph
wg.Go(func() error {
namespace := types.NameSpace{KnowledgeBase: knowledge.KnowledgeBaseID, Knowledge: knowledge.ID}
if err := s.graphEngine.DelGraph(ctx, []types.NameSpace{namespace}); err != nil {
logger.GetLogger(ctx).WithField("error", err).Errorf("DeleteKnowledge delete knowledge graph failed")
return err
}
return nil
})
if err = wg.Wait(); err != nil {
return err
}
// Delete the knowledge entry itself from the database
return s.repo.DeleteKnowledge(ctx, ctx.Value(types.TenantIDContextKey).(uint64), id)
}
// DeleteKnowledgeList deletes a knowledge entry and all related resources
func (s *knowledgeService) DeleteKnowledgeList(ctx context.Context, ids []string) error {
if len(ids) == 0 {
return nil
}
// 1. Get the knowledge entry
tenantInfo := ctx.Value(types.TenantInfoContextKey).(*types.Tenant)
knowledgeList, err := s.repo.GetKnowledgeBatch(ctx, tenantInfo.ID, ids)
if err != nil {
return err
}
// Mark all as deleting first to prevent async task conflicts
for _, knowledge := range knowledgeList {
knowledge.ParseStatus = types.ParseStatusDeleting
knowledge.UpdatedAt = time.Now()
if err := s.repo.UpdateKnowledge(ctx, knowledge); err != nil {
logger.GetLogger(ctx).WithField("error", err).WithField("knowledge_id", knowledge.ID).
Errorf("DeleteKnowledgeList failed to mark as deleting")
// Continue with deletion even if marking fails
}
}
logger.Infof(ctx, "Marked %d knowledge entries as deleting", len(knowledgeList))
// Pre-resolve file services per KB so goroutines don't need DB access
kbFileServices := make(map[string]interfaces.FileService)
for _, knowledge := range knowledgeList {
if _, ok := kbFileServices[knowledge.KnowledgeBaseID]; !ok {
kb, _ := s.kbService.GetKnowledgeBaseByID(ctx, knowledge.KnowledgeBaseID)
kbFileServices[knowledge.KnowledgeBaseID] = s.resolveFileService(ctx, kb)
}
}
wg := errgroup.Group{}
// 2. Delete knowledge embeddings from vector store
wg.Go(func() error {
tenantInfo := ctx.Value(types.TenantInfoContextKey).(*types.Tenant)
retrieveEngine, err := retriever.NewCompositeRetrieveEngine(
s.retrieveEngine,
tenantInfo.GetEffectiveEngines(),
)
if err != nil {
logger.GetLogger(ctx).WithField("error", err).Errorf("DeleteKnowledge delete knowledge embedding failed")
return err
}
// Group by EmbeddingModelID and Type
type groupKey struct {
EmbeddingModelID string
Type string
}
group := map[groupKey][]string{}
for _, knowledge := range knowledgeList {
key := groupKey{EmbeddingModelID: knowledge.EmbeddingModelID, Type: knowledge.Type}
group[key] = append(group[key], knowledge.ID)
}
for key, knowledgeIDs := range group {
embeddingModel, err := s.modelService.GetEmbeddingModel(ctx, key.EmbeddingModelID)
if err != nil {
logger.GetLogger(ctx).WithField("error", err).Errorf("DeleteKnowledge get embedding model failed")
return err
}
if err := retrieveEngine.DeleteByKnowledgeIDList(ctx, knowledgeIDs, embeddingModel.GetDimensions(), key.Type); err != nil {
logger.GetLogger(ctx).
WithField("error", err).
Errorf("DeleteKnowledge delete knowledge embedding failed")
return err
}
}
return nil
})
// 3. Delete all chunks associated with this knowledge
wg.Go(func() error {
if err := s.chunkService.DeleteByKnowledgeList(ctx, ids); err != nil {
logger.GetLogger(ctx).WithField("error", err).Errorf("DeleteKnowledge delete chunks failed")
return err
}
return nil
})
// 4. Delete the physical file if it exists
wg.Go(func() error {
storageAdjust := int64(0)
for _, knowledge := range knowledgeList {
if knowledge.FilePath != "" {
fSvc := kbFileServices[knowledge.KnowledgeBaseID]
if err := fSvc.DeleteFile(ctx, knowledge.FilePath); err != nil {
logger.GetLogger(ctx).WithField("error", err).Errorf("DeleteKnowledge delete file failed")
}
}
storageAdjust -= knowledge.StorageSize
}
tenantInfo.StorageUsed += storageAdjust
if err := s.tenantRepo.AdjustStorageUsed(ctx, tenantInfo.ID, storageAdjust); err != nil {
logger.GetLogger(ctx).WithField("error", err).Errorf("DeleteKnowledge update tenant storage used failed")
}
return nil
})
// Delete the knowledge graph
wg.Go(func() error {
namespaces := []types.NameSpace{}
for _, knowledge := range knowledgeList {
namespaces = append(
namespaces,
types.NameSpace{KnowledgeBase: knowledge.KnowledgeBaseID, Knowledge: knowledge.ID},
)
}
if err := s.graphEngine.DelGraph(ctx, namespaces); err != nil {
logger.GetLogger(ctx).WithField("error", err).Errorf("DeleteKnowledge delete knowledge graph failed")
return err
}
return nil
})
if err = wg.Wait(); err != nil {
return err
}
// 5. Delete the knowledge entry itself from the database
return s.repo.DeleteKnowledgeList(ctx, tenantInfo.ID, ids)
}
func (s *knowledgeService) cloneKnowledge(
ctx context.Context,
src *types.Knowledge,
targetKB *types.KnowledgeBase,
) (err error) {
if src.ParseStatus != "completed" {
logger.GetLogger(ctx).WithField("knowledge_id", src.ID).Errorf("MoveKnowledge parse status is not completed")
return nil
}
tenantInfo := ctx.Value(types.TenantInfoContextKey).(*types.Tenant)
dst := &types.Knowledge{
ID: uuid.New().String(),
TenantID: targetKB.TenantID,
KnowledgeBaseID: targetKB.ID,
Type: src.Type,
Title: src.Title,
Description: src.Description,
Source: src.Source,
ParseStatus: "processing",
EnableStatus: "disabled",
EmbeddingModelID: targetKB.EmbeddingModelID,
FileName: src.FileName,
FileType: src.FileType,
FileSize: src.FileSize,
FileHash: src.FileHash,
FilePath: src.FilePath,
StorageSize: src.StorageSize,
Metadata: src.Metadata,
}
defer func() {
if err != nil {
dst.ParseStatus = "failed"
dst.ErrorMessage = err.Error()
_ = s.repo.UpdateKnowledge(ctx, dst)
logger.GetLogger(ctx).WithField("error", err).Errorf("MoveKnowledge failed to move knowledge")
} else {
dst.ParseStatus = "completed"
dst.EnableStatus = "enabled"
_ = s.repo.UpdateKnowledge(ctx, dst)
logger.GetLogger(ctx).WithField("knowledge_id", dst.ID).Infof("MoveKnowledge move knowledge successfully")
}
}()
if err = s.repo.CreateKnowledge(ctx, dst); err != nil {
logger.GetLogger(ctx).WithField("error", err).Errorf("MoveKnowledge create knowledge failed")
return
}
tenantInfo.StorageUsed += dst.StorageSize
if err = s.tenantRepo.AdjustStorageUsed(ctx, tenantInfo.ID, dst.StorageSize); err != nil {
logger.GetLogger(ctx).WithField("error", err).Errorf("MoveKnowledge update tenant storage used failed")
return
}
if err = s.CloneChunk(ctx, src, dst); err != nil {
logger.GetLogger(ctx).WithField("knowledge_id", dst.ID).
WithField("error", err).Errorf("MoveKnowledge move chunks failed")
return
}
return
}
// processDocumentFromPassage handles asynchronous processing of text passages
func (s *knowledgeService) processDocumentFromPassage(ctx context.Context,
kb *types.KnowledgeBase, knowledge *types.Knowledge, passage []string,
) {
// Update status to processing
knowledge.ParseStatus = "processing"
knowledge.UpdatedAt = time.Now()
if err := s.repo.UpdateKnowledge(ctx, knowledge); err != nil {
return
}
// Convert passages to chunks
chunks := make([]types.ParsedChunk, 0, len(passage))
start, end := 0, 0
for i, p := range passage {
if p == "" {
continue
}
end += len([]rune(p))
chunks = append(chunks, types.ParsedChunk{
Content: p,
Seq: i,
Start: start,
End: end,
})
start = end
}
// Process and store chunks
s.processChunks(ctx, kb, knowledge, chunks)
}
// ProcessChunksOptions contains options for processing chunks
type ProcessChunksOptions struct {
EnableQuestionGeneration bool
QuestionCount int
EnableMultimodel bool
StoredImages []docparser.StoredImage
// ParentChunks holds parent chunk data when parent-child chunking is enabled.
// When set, the chunks passed to processChunks are child chunks, and each
// child's ParentIndex references an entry in this slice.
ParentChunks []types.ParsedParentChunk
}
// buildParentChildConfigs derives parent and child SplitterConfig from ChunkingConfig.
// The base config (already validated with defaults) is used for separators.
func buildParentChildConfigs(cc types.ChunkingConfig, base chunker.SplitterConfig) (parent, child chunker.SplitterConfig) {
parentSize := cc.ParentChunkSize
if parentSize <= 0 {
parentSize = 4096
}
childSize := cc.ChildChunkSize
if childSize <= 0 {
childSize = 384
}
parent = chunker.SplitterConfig{
ChunkSize: parentSize,
ChunkOverlap: base.ChunkOverlap, // reuse configured overlap for parents
Separators: base.Separators,
}
child = chunker.SplitterConfig{
ChunkSize: childSize,
ChunkOverlap: childSize / 5, // ~20% overlap for child chunks
Separators: base.Separators,
}
return
}
// processChunks processes chunks and creates embeddings for knowledge content
func (s *knowledgeService) processChunks(ctx context.Context,
kb *types.KnowledgeBase, knowledge *types.Knowledge, chunks []types.ParsedChunk,
opts ...ProcessChunksOptions,
) {
// Get options
var options ProcessChunksOptions
if len(opts) > 0 {
options = opts[0]
}
ctx, span := tracing.ContextWithSpan(ctx, "knowledgeService.processChunks")
defer span.End()
span.SetAttributes(
attribute.Int("tenant_id", int(knowledge.TenantID)),
attribute.String("knowledge_base_id", knowledge.KnowledgeBaseID),
attribute.String("knowledge_id", knowledge.ID),
attribute.String("embedding_model_id", kb.EmbeddingModelID),
attribute.Int("chunk_count", len(chunks)),
)
// Check if knowledge is being deleted before processing
if s.isKnowledgeDeleting(ctx, knowledge.TenantID, knowledge.ID) {
logger.Infof(ctx, "Knowledge is being deleted, aborting chunk processing: %s", knowledge.ID)
span.AddEvent("aborted: knowledge is being deleted")
return
}
// Get embedding model for vectorization
embeddingModel, err := s.modelService.GetEmbeddingModel(ctx, kb.EmbeddingModelID)
if err != nil {
logger.GetLogger(ctx).WithField("error", err).Errorf("processChunks get embedding model failed")
span.RecordError(err)
return
}
// 幂等性处理清理旧的chunks和索引数据避免重复数据
logger.Infof(ctx, "Cleaning up existing chunks and index data for knowledge: %s", knowledge.ID)
// 删除旧的chunks
if err := s.chunkService.DeleteChunksByKnowledgeID(ctx, knowledge.ID); err != nil {
logger.Warnf(ctx, "Failed to delete existing chunks (may not exist): %v", err)
// 不返回错误,继续处理(可能没有旧数据)
}
// 删除旧的索引数据
tenantInfo := ctx.Value(types.TenantInfoContextKey).(*types.Tenant)
retrieveEngine, err := retriever.NewCompositeRetrieveEngine(s.retrieveEngine, tenantInfo.GetEffectiveEngines())
if err == nil {
if err := retrieveEngine.DeleteByKnowledgeIDList(ctx, []string{knowledge.ID}, embeddingModel.GetDimensions(), knowledge.Type); err != nil {
logger.Warnf(ctx, "Failed to delete existing index data (may not exist): %v", err)
// 不返回错误,继续处理(可能没有旧数据)
} else {
logger.Infof(ctx, "Successfully deleted existing index data for knowledge: %s", knowledge.ID)
}
}
// 删除知识图谱数据(如果存在)
namespace := types.NameSpace{KnowledgeBase: knowledge.KnowledgeBaseID, Knowledge: knowledge.ID}
if err := s.graphEngine.DelGraph(ctx, []types.NameSpace{namespace}); err != nil {
logger.Warnf(ctx, "Failed to delete existing graph data (may not exist): %v", err)
// 不返回错误,继续处理
}
logger.Infof(ctx, "Cleanup completed, starting to process new chunks")
// ========== DocReader 解析结果日志 ==========
logger.Infof(ctx, "[DocReader] ========== 解析结果概览 ==========")
logger.Infof(ctx, "[DocReader] 知识ID: %s, 知识库ID: %s", knowledge.ID, knowledge.KnowledgeBaseID)
logger.Infof(ctx, "[DocReader] 总Chunk数量: %d", len(chunks))
// 统计图片信息
totalImages := 0
chunksWithImages := 0
for _, chunkData := range chunks {
if len(chunkData.Images) > 0 {
chunksWithImages++
totalImages += len(chunkData.Images)
}
}
logger.Infof(ctx, "[DocReader] 包含图片的Chunk数: %d, 总图片数: %d", chunksWithImages, totalImages)
// 打印每个Chunk的详细信息
for idx, chunkData := range chunks {
contentPreview := chunkData.Content
if len(contentPreview) > 200 {
contentPreview = contentPreview[:200] + "..."
}
logger.Infof(ctx, "[DocReader] Chunk #%d (seq=%d): 内容长度=%d, 图片数=%d, 范围=[%d-%d]",
idx, chunkData.Seq, len(chunkData.Content), len(chunkData.Images), chunkData.Start, chunkData.End)
logger.Debugf(ctx, "[DocReader] Chunk #%d 内容预览: %s", idx, contentPreview)
// 打印图片详细信息
for imgIdx, img := range chunkData.Images {
logger.Infof(ctx, "[DocReader] 图片 #%d: URL=%s", imgIdx, img.URL)
logger.Infof(ctx, "[DocReader] 图片 #%d: OriginalURL=%s", imgIdx, img.OriginalURL)
if img.Caption != "" {
captionPreview := img.Caption
if len(captionPreview) > 100 {
captionPreview = captionPreview[:100] + "..."
}
logger.Infof(ctx, "[DocReader] 图片 #%d: Caption=%s", imgIdx, captionPreview)
}
if img.OCRText != "" {
ocrPreview := img.OCRText
if len(ocrPreview) > 100 {
ocrPreview = ocrPreview[:100] + "..."
}
logger.Infof(ctx, "[DocReader] 图片 #%d: OCRText=%s", imgIdx, ocrPreview)
}
logger.Infof(ctx, "[DocReader] 图片 #%d: 位置=[%d-%d]", imgIdx, img.Start, img.End)
}
}
logger.Infof(ctx, "[DocReader] ========== 解析结果概览结束 ==========")
// Create chunk objects from proto chunks
maxSeq := 0
// 统计图片相关的子Chunk数量用于扩展insertChunks的容量
imageChunkCount := 0
for _, chunkData := range chunks {
if len(chunkData.Images) > 0 {
// 为每个图片的OCR和Caption分别创建一个Chunk
imageChunkCount += len(chunkData.Images) * 2
}
if int(chunkData.Seq) > maxSeq {
maxSeq = int(chunkData.Seq)
}
}
// === Parent-Child Chunking: create parent chunks first ===
hasParentChild := len(options.ParentChunks) > 0
var parentDBChunks []*types.Chunk // indexed by ParsedParentChunk position
if hasParentChild {
parentDBChunks = make([]*types.Chunk, len(options.ParentChunks))
for i, pc := range options.ParentChunks {
parentDBChunks[i] = &types.Chunk{
ID: uuid.New().String(),
TenantID: knowledge.TenantID,
KnowledgeID: knowledge.ID,
KnowledgeBaseID: knowledge.KnowledgeBaseID,
Content: pc.Content,
ChunkIndex: pc.Seq,
IsEnabled: true,
CreatedAt: time.Now(),
UpdatedAt: time.Now(),
StartAt: pc.Start,
EndAt: pc.End,
ChunkType: types.ChunkTypeParentText,
}
}
// Set prev/next links for parent chunks
for i := range parentDBChunks {
if i > 0 {
parentDBChunks[i-1].NextChunkID = parentDBChunks[i].ID
parentDBChunks[i].PreChunkID = parentDBChunks[i-1].ID
}
}
logger.Infof(ctx, "Created %d parent chunks for parent-child strategy", len(parentDBChunks))
}
// 重新分配容量考虑图片相关的Chunk + parent chunks
parentCount := len(options.ParentChunks)
insertChunks := make([]*types.Chunk, 0, len(chunks)+imageChunkCount+parentCount)
// Add parent chunks first (they go into DB but NOT into the vector index)
if hasParentChild {
insertChunks = append(insertChunks, parentDBChunks...)
}
for idx, chunkData := range chunks {
if strings.TrimSpace(chunkData.Content) == "" {
continue
}
// 创建主文本Chunk
textChunk := &types.Chunk{
ID: uuid.New().String(),
TenantID: knowledge.TenantID,
KnowledgeID: knowledge.ID,
KnowledgeBaseID: knowledge.KnowledgeBaseID,
Content: chunkData.Content,
ChunkIndex: int(chunkData.Seq),
IsEnabled: true,
CreatedAt: time.Now(),
UpdatedAt: time.Now(),
StartAt: int(chunkData.Start),
EndAt: int(chunkData.End),
ChunkType: types.ChunkTypeText,
}
// Wire up ParentChunkID for child chunks
if hasParentChild && chunkData.ParentIndex >= 0 && chunkData.ParentIndex < len(parentDBChunks) {
textChunk.ParentChunkID = parentDBChunks[chunkData.ParentIndex].ID
}
chunks[idx].ChunkID = textChunk.ID
insertChunks = append(insertChunks, textChunk)
}
// Sort chunks by index for proper ordering
sort.Slice(insertChunks, func(i, j int) bool {
return insertChunks[i].ChunkIndex < insertChunks[j].ChunkIndex
})
// 仅为文本类型的Chunk设置前后关系child chunks only, parents already linked above
textChunks := make([]*types.Chunk, 0, len(chunks))
for _, chunk := range insertChunks {
if chunk.ChunkType == types.ChunkTypeText && chunk.ParentChunkID != "" {
// This is a child chunk in parent-child mode
textChunks = append(textChunks, chunk)
} else if chunk.ChunkType == types.ChunkTypeText && !hasParentChild {
// Normal flat chunk (no parent-child mode)
textChunks = append(textChunks, chunk)
}
}
// 设置文本Chunk之间的前后关系 (skip if parent-child, children don't need prev/next links)
if !hasParentChild {
for i, chunk := range textChunks {
if i > 0 {
textChunks[i-1].NextChunkID = chunk.ID
}
if i < len(textChunks)-1 {
textChunks[i+1].PreChunkID = chunk.ID
}
}
}
// Create index information — only for child/flat chunks, NOT parent chunks.
// Parent chunks are stored for context retrieval but do not need vector embeddings.
// Prepend the document title to improve semantic alignment between
// question-style queries and statement-style chunk content.
indexInfoList := make([]*types.IndexInfo, 0, len(textChunks))
titlePrefix := ""
if t := strings.TrimSpace(knowledge.Title); t != "" {
titlePrefix = t + "\n"
}
for _, chunk := range textChunks {
indexContent := titlePrefix + chunk.Content
indexInfoList = append(indexInfoList, &types.IndexInfo{
Content: indexContent,
SourceID: chunk.ID,
SourceType: types.ChunkSourceType,
ChunkID: chunk.ID,
KnowledgeID: knowledge.ID,
KnowledgeBaseID: knowledge.KnowledgeBaseID,
IsEnabled: true,
})
}
// Initialize retrieval engine
// Calculate storage size required for embeddings
span.AddEvent("estimate storage size")
totalStorageSize := retrieveEngine.EstimateStorageSize(ctx, embeddingModel, indexInfoList)
if tenantInfo.StorageQuota > 0 {
// Re-fetch tenant storage information
tenantInfo, err = s.tenantRepo.GetTenantByID(ctx, tenantInfo.ID)
if err != nil {
knowledge.ParseStatus = types.ParseStatusFailed
knowledge.ErrorMessage = err.Error()
knowledge.UpdatedAt = time.Now()
s.repo.UpdateKnowledge(ctx, knowledge)
span.RecordError(err)
return
}
// Check if there's enough storage quota available
if tenantInfo.StorageUsed+totalStorageSize > tenantInfo.StorageQuota {
knowledge.ParseStatus = types.ParseStatusFailed
knowledge.ErrorMessage = "存储空间不足"
knowledge.UpdatedAt = time.Now()
s.repo.UpdateKnowledge(ctx, knowledge)
span.RecordError(errors.New("storage quota exceeded"))
return
}
}
// Check again if knowledge is being deleted before writing to database
if s.isKnowledgeDeleting(ctx, knowledge.TenantID, knowledge.ID) {
logger.Infof(ctx, "Knowledge is being deleted, aborting before saving chunks: %s", knowledge.ID)
span.AddEvent("aborted: knowledge is being deleted before saving")
return
}
// Save chunks to database
span.AddEvent("create chunks")
if err := s.chunkService.CreateChunks(ctx, insertChunks); err != nil {
knowledge.ParseStatus = types.ParseStatusFailed
knowledge.ErrorMessage = err.Error()
knowledge.UpdatedAt = time.Now()
s.repo.UpdateKnowledge(ctx, knowledge)
span.RecordError(err)
return
}
// Check again before batch indexing (this is a heavy operation)
if s.isKnowledgeDeleting(ctx, knowledge.TenantID, knowledge.ID) {
logger.Infof(ctx, "Knowledge is being deleted, cleaning up and aborting before indexing: %s", knowledge.ID)
// Clean up the chunks we just created
if err := s.chunkService.DeleteChunksByKnowledgeID(ctx, knowledge.ID); err != nil {
logger.Warnf(ctx, "Failed to cleanup chunks after deletion detected: %v", err)
}
span.AddEvent("aborted: knowledge is being deleted before indexing")
return
}
span.AddEvent("batch index")
err = retrieveEngine.BatchIndex(ctx, embeddingModel, indexInfoList)
if err != nil {
knowledge.ParseStatus = types.ParseStatusFailed
knowledge.ErrorMessage = err.Error()
knowledge.UpdatedAt = time.Now()
s.repo.UpdateKnowledge(ctx, knowledge)
// delete failed chunks
if err := s.chunkService.DeleteChunksByKnowledgeID(ctx, knowledge.ID); err != nil {
logger.Errorf(ctx, "Delete chunks failed: %v", err)
}
// delete index
if err := retrieveEngine.DeleteByKnowledgeIDList(
ctx, []string{knowledge.ID}, embeddingModel.GetDimensions(), kb.Type,
); err != nil {
logger.Errorf(ctx, "Delete index failed: %v", err)
}
span.RecordError(err)
return
}
logger.GetLogger(ctx).Infof("processChunks batch index successfully, with %d index", len(indexInfoList))
logger.Infof(ctx, "processChunks create relationship rag task")
if kb.ExtractConfig != nil && kb.ExtractConfig.Enabled {
for _, chunk := range textChunks {
err := NewChunkExtractTask(ctx, s.task, chunk.TenantID, chunk.ID, kb.SummaryModelID)
if err != nil {
logger.GetLogger(ctx).WithField("error", err).Errorf("processChunks create chunk extract task failed")
span.RecordError(err)
}
}
}
// Final check before marking as completed - if deleted during processing, don't update status
if s.isKnowledgeDeleting(ctx, knowledge.TenantID, knowledge.ID) {
logger.Infof(ctx, "Knowledge was deleted during processing, skipping completion update: %s", knowledge.ID)
// Clean up the data we just created since the knowledge is being deleted
if err := s.chunkService.DeleteChunksByKnowledgeID(ctx, knowledge.ID); err != nil {
logger.Warnf(ctx, "Failed to cleanup chunks after deletion detected: %v", err)
}
if err := retrieveEngine.DeleteByKnowledgeIDList(ctx, []string{knowledge.ID}, embeddingModel.GetDimensions(), kb.Type); err != nil {
logger.Warnf(ctx, "Failed to cleanup index after deletion detected: %v", err)
}
span.AddEvent("aborted: knowledge was deleted during processing")
return
}
// Skip summary/question generation for image-type knowledge — the text chunk
// is just a markdown image reference, so LLM summary would be useless.
// The multimodal task will provide a caption as the description instead.
isImage := IsImageType(knowledge.FileType)
pendingMultimodal := isImage && options.EnableMultimodel && len(options.StoredImages) > 0
// For image files with pending multimodal processing, keep "processing" status
// so the frontend waits until the description is ready before showing "completed".
if pendingMultimodal {
knowledge.ParseStatus = types.ParseStatusProcessing
} else {
knowledge.ParseStatus = types.ParseStatusCompleted
}
knowledge.EnableStatus = "enabled"
knowledge.StorageSize = totalStorageSize
now := time.Now()
knowledge.ProcessedAt = &now
knowledge.UpdatedAt = now
// Set summary status based on whether summary generation will be triggered
if len(textChunks) > 0 && !isImage {
knowledge.SummaryStatus = types.SummaryStatusPending
} else {
knowledge.SummaryStatus = types.SummaryStatusNone
}
if err := s.repo.UpdateKnowledge(ctx, knowledge); err != nil {
logger.GetLogger(ctx).WithField("error", err).Errorf("processChunks update knowledge failed")
}
// Enqueue question generation task if enabled (async, non-blocking)
if options.EnableQuestionGeneration && len(textChunks) > 0 && !isImage {
questionCount := options.QuestionCount
if questionCount <= 0 {
questionCount = 3
}
if questionCount > 10 {
questionCount = 10
}
s.enqueueQuestionGenerationTask(ctx, knowledge.KnowledgeBaseID, knowledge.ID, questionCount)
}
// Enqueue summary generation task (async, non-blocking)
if len(textChunks) > 0 && !isImage {
s.enqueueSummaryGenerationTask(ctx, knowledge.KnowledgeBaseID, knowledge.ID)
}
// Enqueue multimodal tasks for images (async, non-blocking)
if options.EnableMultimodel && len(options.StoredImages) > 0 {
s.enqueueImageMultimodalTasks(ctx, knowledge, kb, options.StoredImages, chunks)
}
// Update tenant's storage usage
tenantInfo.StorageUsed += totalStorageSize
if err := s.tenantRepo.AdjustStorageUsed(ctx, tenantInfo.ID, totalStorageSize); err != nil {
logger.GetLogger(ctx).WithField("error", err).Errorf("processChunks update tenant storage used failed")
}
logger.GetLogger(ctx).Infof("processChunks successfully")
}
// GetSummary generates a summary for knowledge content using an AI model
func (s *knowledgeService) getSummary(ctx context.Context,
summaryModel chat.Chat, knowledge *types.Knowledge, chunks []*types.Chunk,
) (string, error) {
// Get knowledge info from the first chunk
if len(chunks) == 0 {
return "", fmt.Errorf("no chunks provided for summary generation")
}
// concat chunk contents
chunkContents := ""
allImageInfos := make([]*types.ImageInfo, 0)
// then, sort chunks by StartAt
sortedChunks := make([]*types.Chunk, len(chunks))
copy(sortedChunks, chunks)
sort.Slice(sortedChunks, func(i, j int) bool {
return sortedChunks[i].StartAt < sortedChunks[j].StartAt
})
// concat chunk contents and collect image infos
for _, chunk := range sortedChunks {
if chunk.EndAt > 4096 {
break
}
// Ensure we don't slice beyond the current content length
runes := []rune(chunkContents)
if chunk.StartAt <= len(runes) {
chunkContents = string(runes[:chunk.StartAt]) + chunk.Content
} else {
// If StartAt is beyond current content, just append
chunkContents = chunkContents + chunk.Content
}
if chunk.ImageInfo != "" {
var images []*types.ImageInfo
if err := json.Unmarshal([]byte(chunk.ImageInfo), &images); err == nil {
allImageInfos = append(allImageInfos, images...)
}
}
}
// remove markdown image syntax
re := regexp.MustCompile(`!\[[^\]]*\]\([^)]+\)`)
chunkContents = re.ReplaceAllString(chunkContents, "")
// collect all image infos
if len(allImageInfos) > 0 {
// add image infos to chunk contents
var imageAnnotations string
for _, img := range allImageInfos {
if img.Caption != "" {
imageAnnotations += fmt.Sprintf("\n[Image Description: %s]", img.Caption)
}
if img.OCRText != "" {
imageAnnotations += fmt.Sprintf("\n[Image OCR Text: %s]", img.OCRText)
}
}
// concat chunk contents and image annotations
chunkContents = chunkContents + imageAnnotations
}
if len(chunkContents) < 300 {
return chunkContents, nil
}
// Prepare content with metadata for summary generation
contentWithMetadata := chunkContents
// Add knowledge metadata if available
if knowledge != nil {
metadataIntro := fmt.Sprintf("Document Type: %s\nFile Name: %s\n", knowledge.FileType, knowledge.FileName)
// Add additional metadata if available
if knowledge.Type != "" {
metadataIntro += fmt.Sprintf("Knowledge Type: %s\n", knowledge.Type)
}
// Prepend metadata to content
contentWithMetadata = metadataIntro + "\nContent:\n" + contentWithMetadata
}
// Generate summary using AI model
summaryPrompt := types.RenderPromptPlaceholders(s.config.Conversation.GenerateSummaryPrompt, types.PlaceholderValues{
"language": types.LanguageNameFromContext(ctx),
})
thinking := false
summary, err := summaryModel.Chat(ctx, []chat.Message{
{
Role: "system",
Content: summaryPrompt,
},
{
Role: "user",
Content: contentWithMetadata,
},
}, &chat.ChatOptions{
Temperature: 0.3,
MaxTokens: 1024,
Thinking: &thinking,
})
if err != nil {
logger.GetLogger(ctx).WithField("error", err).Errorf("GetSummary failed")
return "", err
}
logger.GetLogger(ctx).WithField("summary", summary.Content).Infof("GetSummary success")
return summary.Content, nil
}
// enqueueQuestionGenerationTask enqueues an async task for question generation
func (s *knowledgeService) enqueueQuestionGenerationTask(ctx context.Context,
kbID, knowledgeID string, questionCount int,
) {
tenantID := ctx.Value(types.TenantIDContextKey).(uint64)
lang, _ := types.LanguageFromContext(ctx)
payload := types.QuestionGenerationPayload{
TenantID: tenantID,
KnowledgeBaseID: kbID,
KnowledgeID: knowledgeID,
QuestionCount: questionCount,
Language: lang,
}
payloadBytes, err := json.Marshal(payload)
if err != nil {
logger.Errorf(ctx, "Failed to marshal question generation payload: %v", err)
return
}
task := asynq.NewTask(types.TypeQuestionGeneration, payloadBytes, asynq.Queue("low"), asynq.MaxRetry(3))
info, err := s.task.Enqueue(task)
if err != nil {
logger.Errorf(ctx, "Failed to enqueue question generation task: %v", err)
return
}
logger.Infof(ctx, "Enqueued question generation task: %s for knowledge: %s", info.ID, knowledgeID)
}
// enqueueSummaryGenerationTask enqueues an async task for summary generation
func (s *knowledgeService) enqueueSummaryGenerationTask(ctx context.Context,
kbID, knowledgeID string,
) {
tenantID := ctx.Value(types.TenantIDContextKey).(uint64)
lang, _ := types.LanguageFromContext(ctx)
payload := types.SummaryGenerationPayload{
TenantID: tenantID,
KnowledgeBaseID: kbID,
KnowledgeID: knowledgeID,
Language: lang,
}
payloadBytes, err := json.Marshal(payload)
if err != nil {
logger.Errorf(ctx, "Failed to marshal summary generation payload: %v", err)
return
}
task := asynq.NewTask(types.TypeSummaryGeneration, payloadBytes, asynq.Queue("low"), asynq.MaxRetry(3))
info, err := s.task.Enqueue(task)
if err != nil {
logger.Errorf(ctx, "Failed to enqueue summary generation task: %v", err)
return
}
logger.Infof(ctx, "Enqueued summary generation task: %s for knowledge: %s", info.ID, knowledgeID)
}
// ProcessSummaryGeneration handles async summary generation task
func (s *knowledgeService) ProcessSummaryGeneration(ctx context.Context, t *asynq.Task) error {
var payload types.SummaryGenerationPayload
if err := json.Unmarshal(t.Payload(), &payload); err != nil {
logger.Errorf(ctx, "Failed to unmarshal summary generation payload: %v", err)
return nil // Don't retry on unmarshal error
}
logger.Infof(ctx, "Processing summary generation for knowledge: %s", payload.KnowledgeID)
// Set tenant and language context
ctx = context.WithValue(ctx, types.TenantIDContextKey, payload.TenantID)
if payload.Language != "" {
ctx = context.WithValue(ctx, types.LanguageContextKey, payload.Language)
}
// Get knowledge base
kb, err := s.kbService.GetKnowledgeBaseByID(ctx, payload.KnowledgeBaseID)
if err != nil {
logger.Errorf(ctx, "Failed to get knowledge base: %v", err)
return nil
}
if kb.SummaryModelID == "" {
logger.Warn(ctx, "Knowledge base summary model ID is empty, skipping summary generation")
return nil
}
// Get knowledge
knowledge, err := s.repo.GetKnowledgeByID(ctx, payload.TenantID, payload.KnowledgeID)
if err != nil {
logger.Errorf(ctx, "Failed to get knowledge: %v", err)
return nil
}
// Update summary status to processing
knowledge.SummaryStatus = types.SummaryStatusProcessing
knowledge.UpdatedAt = time.Now()
if err := s.repo.UpdateKnowledge(ctx, knowledge); err != nil {
logger.Warnf(ctx, "Failed to update summary status to processing: %v", err)
}
// Helper function to mark summary as failed
markSummaryFailed := func() {
knowledge.SummaryStatus = types.SummaryStatusFailed
knowledge.UpdatedAt = time.Now()
if err := s.repo.UpdateKnowledge(ctx, knowledge); err != nil {
logger.Warnf(ctx, "Failed to update summary status to failed: %v", err)
}
}
// Get text chunks for this knowledge
chunks, err := s.chunkService.ListChunksByKnowledgeID(ctx, payload.KnowledgeID)
if err != nil {
logger.Errorf(ctx, "Failed to get chunks: %v", err)
markSummaryFailed()
return nil
}
// Filter text chunks only
textChunks := make([]*types.Chunk, 0)
for _, chunk := range chunks {
if chunk.ChunkType == types.ChunkTypeText {
textChunks = append(textChunks, chunk)
}
}
if len(textChunks) == 0 {
logger.Infof(ctx, "No text chunks found for knowledge: %s", payload.KnowledgeID)
// Mark as completed since there's nothing to summarize
knowledge.SummaryStatus = types.SummaryStatusCompleted
knowledge.UpdatedAt = time.Now()
s.repo.UpdateKnowledge(ctx, knowledge)
return nil
}
// Sort chunks by ChunkIndex for proper ordering
sort.Slice(textChunks, func(i, j int) bool {
return textChunks[i].ChunkIndex < textChunks[j].ChunkIndex
})
// Initialize chat model for summary
chatModel, err := s.modelService.GetChatModel(ctx, kb.SummaryModelID)
if err != nil {
logger.Errorf(ctx, "Failed to get chat model: %v", err)
markSummaryFailed()
return fmt.Errorf("failed to get chat model: %w", err)
}
// Generate summary
summary, err := s.getSummary(ctx, chatModel, knowledge, textChunks)
if err != nil {
logger.Errorf(ctx, "Failed to generate summary for knowledge %s: %v", payload.KnowledgeID, err)
// Use first chunk content as fallback
if len(textChunks) > 0 {
summary = textChunks[0].Content
if len(summary) > 500 {
summary = summary[:500]
}
}
}
// Update knowledge description
knowledge.Description = summary
knowledge.SummaryStatus = types.SummaryStatusCompleted
knowledge.UpdatedAt = time.Now()
if err := s.repo.UpdateKnowledge(ctx, knowledge); err != nil {
logger.Errorf(ctx, "Failed to update knowledge description: %v", err)
return fmt.Errorf("failed to update knowledge: %w", err)
}
// Create summary chunk and index it
if strings.TrimSpace(summary) != "" {
// Get max chunk index
maxChunkIndex := 0
for _, chunk := range chunks {
if chunk.ChunkIndex > maxChunkIndex {
maxChunkIndex = chunk.ChunkIndex
}
}
summaryChunk := &types.Chunk{
ID: uuid.New().String(),
TenantID: knowledge.TenantID,
KnowledgeID: knowledge.ID,
KnowledgeBaseID: knowledge.KnowledgeBaseID,
Content: fmt.Sprintf("# Document\n%s\n\n# Summary\n%s", knowledge.FileName, summary),
ChunkIndex: maxChunkIndex + 1,
IsEnabled: true,
CreatedAt: time.Now(),
UpdatedAt: time.Now(),
StartAt: 0,
EndAt: 0,
ChunkType: types.ChunkTypeSummary,
ParentChunkID: textChunks[0].ID,
}
// Save summary chunk
if err := s.chunkService.CreateChunks(ctx, []*types.Chunk{summaryChunk}); err != nil {
logger.Errorf(ctx, "Failed to create summary chunk: %v", err)
return fmt.Errorf("failed to create summary chunk: %w", err)
}
// Index summary chunk
tenantInfo, err := s.tenantRepo.GetTenantByID(ctx, payload.TenantID)
if err != nil {
logger.Errorf(ctx, "Failed to get tenant info: %v", err)
return fmt.Errorf("failed to get tenant info: %w", err)
}
ctx = context.WithValue(ctx, types.TenantInfoContextKey, tenantInfo)
retrieveEngine, err := retriever.NewCompositeRetrieveEngine(s.retrieveEngine, tenantInfo.GetEffectiveEngines())
if err != nil {
logger.Errorf(ctx, "Failed to init retrieve engine: %v", err)
return fmt.Errorf("failed to init retrieve engine: %w", err)
}
embeddingModel, err := s.modelService.GetEmbeddingModel(ctx, kb.EmbeddingModelID)
if err != nil {
logger.Errorf(ctx, "Failed to get embedding model: %v", err)
return fmt.Errorf("failed to get embedding model: %w", err)
}
indexInfo := []*types.IndexInfo{{
Content: summaryChunk.Content,
SourceID: summaryChunk.ID,
SourceType: types.ChunkSourceType,
ChunkID: summaryChunk.ID,
KnowledgeID: knowledge.ID,
KnowledgeBaseID: knowledge.KnowledgeBaseID,
IsEnabled: true,
}}
if err := retrieveEngine.BatchIndex(ctx, embeddingModel, indexInfo); err != nil {
logger.Errorf(ctx, "Failed to index summary chunk: %v", err)
return fmt.Errorf("failed to index summary chunk: %w", err)
}
logger.Infof(ctx, "Successfully created and indexed summary chunk for knowledge: %s", payload.KnowledgeID)
}
logger.Infof(ctx, "Successfully generated summary for knowledge: %s", payload.KnowledgeID)
return nil
}
// ProcessQuestionGeneration handles async question generation task
func (s *knowledgeService) ProcessQuestionGeneration(ctx context.Context, t *asynq.Task) error {
ctx, span := tracing.ContextWithSpan(ctx, "knowledgeService.ProcessQuestionGeneration")
defer span.End()
var payload types.QuestionGenerationPayload
if err := json.Unmarshal(t.Payload(), &payload); err != nil {
logger.Errorf(ctx, "Failed to unmarshal question generation payload: %v", err)
return nil // Don't retry on unmarshal error
}
logger.Infof(ctx, "Processing question generation for knowledge: %s", payload.KnowledgeID)
// Set tenant context
ctx = context.WithValue(ctx, types.TenantIDContextKey, payload.TenantID)
if payload.Language != "" {
ctx = context.WithValue(ctx, types.LanguageContextKey, payload.Language)
}
if strings.TrimSpace(s.config.Conversation.GenerateQuestionsPrompt) == "" {
logger.Errorf(ctx, "GenerateQuestionsPrompt is empty: configure conversation.generate_questions_prompt_id")
return fmt.Errorf("generate questions prompt not configured")
}
// Get knowledge base
kb, err := s.kbService.GetKnowledgeBaseByID(ctx, payload.KnowledgeBaseID)
if err != nil {
logger.Errorf(ctx, "Failed to get knowledge base: %v", err)
return nil
}
// Get knowledge
knowledge, err := s.repo.GetKnowledgeByID(ctx, payload.TenantID, payload.KnowledgeID)
if err != nil {
logger.Errorf(ctx, "Failed to get knowledge: %v", err)
return nil
}
// Get text chunks for this knowledge
chunks, err := s.chunkService.ListChunksByKnowledgeID(ctx, payload.KnowledgeID)
if err != nil {
logger.Errorf(ctx, "Failed to get chunks: %v", err)
return nil
}
// Filter text chunks only
textChunks := make([]*types.Chunk, 0)
for _, chunk := range chunks {
if chunk.ChunkType == types.ChunkTypeText {
textChunks = append(textChunks, chunk)
}
}
if len(textChunks) == 0 {
logger.Infof(ctx, "No text chunks found for knowledge: %s", payload.KnowledgeID)
return nil
}
// Sort chunks by StartAt for context building
sort.Slice(textChunks, func(i, j int) bool {
return textChunks[i].StartAt < textChunks[j].StartAt
})
// Initialize chat model
chatModel, err := s.modelService.GetChatModel(ctx, kb.SummaryModelID)
if err != nil {
logger.Errorf(ctx, "Failed to get chat model: %v", err)
return fmt.Errorf("failed to get chat model: %w", err)
}
// Initialize embedding model and retrieval engine
embeddingModel, err := s.modelService.GetEmbeddingModel(ctx, kb.EmbeddingModelID)
if err != nil {
logger.Errorf(ctx, "Failed to get embedding model: %v", err)
return fmt.Errorf("failed to get embedding model: %w", err)
}
tenantInfo, err := s.tenantRepo.GetTenantByID(ctx, payload.TenantID)
if err != nil {
logger.Errorf(ctx, "Failed to get tenant info: %v", err)
return fmt.Errorf("failed to get tenant info: %w", err)
}
ctx = context.WithValue(ctx, types.TenantInfoContextKey, tenantInfo)
retrieveEngine, err := retriever.NewCompositeRetrieveEngine(s.retrieveEngine, tenantInfo.GetEffectiveEngines())
if err != nil {
logger.Errorf(ctx, "Failed to init retrieve engine: %v", err)
return fmt.Errorf("failed to init retrieve engine: %w", err)
}
questionCount := payload.QuestionCount
if questionCount <= 0 {
questionCount = 3
}
if questionCount > 10 {
questionCount = 10
}
// Generate questions for each chunk with context
var indexInfoList []*types.IndexInfo
for i, chunk := range textChunks {
// Build context from adjacent chunks
var prevContent, nextContent string
if i > 0 {
prevContent = textChunks[i-1].Content
// Limit context size
if len(prevContent) > 500 {
prevContent = prevContent[len(prevContent)-500:]
}
}
if i < len(textChunks)-1 {
nextContent = textChunks[i+1].Content
// Limit context size
if len(nextContent) > 500 {
nextContent = nextContent[:500]
}
}
questions, err := s.generateQuestionsWithContext(ctx, chatModel, chunk.Content, prevContent, nextContent, knowledge.Title, questionCount)
if err != nil {
logger.Warnf(ctx, "Failed to generate questions for chunk %s: %v", chunk.ID, err)
continue
}
if len(questions) == 0 {
continue
}
// Update chunk metadata with unique IDs for each question
generatedQuestions := make([]types.GeneratedQuestion, len(questions))
for j, question := range questions {
questionID := fmt.Sprintf("q%d", time.Now().UnixNano()+int64(j))
generatedQuestions[j] = types.GeneratedQuestion{
ID: questionID,
Question: question,
}
}
meta := &types.DocumentChunkMetadata{
GeneratedQuestions: generatedQuestions,
}
if err := chunk.SetDocumentMetadata(meta); err != nil {
logger.Warnf(ctx, "Failed to set document metadata for chunk %s: %v", chunk.ID, err)
continue
}
// Update chunk in database
if err := s.chunkService.UpdateChunk(ctx, chunk); err != nil {
logger.Warnf(ctx, "Failed to update chunk %s: %v", chunk.ID, err)
continue
}
// Create index entries for generated questions
for _, gq := range generatedQuestions {
sourceID := fmt.Sprintf("%s-%s", chunk.ID, gq.ID)
indexInfoList = append(indexInfoList, &types.IndexInfo{
Content: gq.Question,
SourceID: sourceID,
SourceType: types.ChunkSourceType,
ChunkID: chunk.ID,
KnowledgeID: knowledge.ID,
KnowledgeBaseID: knowledge.KnowledgeBaseID,
IsEnabled: true,
})
}
logger.Debugf(ctx, "Generated %d questions for chunk %s", len(questions), chunk.ID)
}
// Index generated questions
if len(indexInfoList) > 0 {
if err := retrieveEngine.BatchIndex(ctx, embeddingModel, indexInfoList); err != nil {
logger.Errorf(ctx, "Failed to index generated questions: %v", err)
return fmt.Errorf("failed to index questions: %w", err)
}
logger.Infof(ctx, "Successfully indexed %d generated questions for knowledge: %s", len(indexInfoList), payload.KnowledgeID)
}
return nil
}
// generateQuestionsWithContext generates questions for a chunk with surrounding context
func (s *knowledgeService) generateQuestionsWithContext(ctx context.Context,
chatModel chat.Chat, content, prevContent, nextContent, docName string, questionCount int,
) ([]string, error) {
if content == "" || questionCount <= 0 {
return nil, nil
}
prompt := strings.TrimSpace(s.config.Conversation.GenerateQuestionsPrompt)
if prompt == "" {
return nil, fmt.Errorf("generate questions prompt not configured")
}
// Build context section
var contextSection string
if prevContent != "" || nextContent != "" {
contextSection = "## Context Information (for reference only, to help understand the main content)\n"
if prevContent != "" {
contextSection += fmt.Sprintf("[Preceding Context] %s\n", prevContent)
}
if nextContent != "" {
contextSection += fmt.Sprintf("[Following Context] %s\n", nextContent)
}
contextSection += "\n"
}
langName := types.LanguageNameFromContext(ctx)
prompt = types.RenderPromptPlaceholders(prompt, types.PlaceholderValues{
"question_count": fmt.Sprintf("%d", questionCount),
"content": content,
"context": contextSection,
"doc_name": docName,
"language": langName,
})
thinking := false
response, err := chatModel.Chat(ctx, []chat.Message{
{
Role: "user",
Content: prompt,
},
}, &chat.ChatOptions{
Temperature: 0.7,
MaxTokens: 512,
Thinking: &thinking,
})
if err != nil {
return nil, fmt.Errorf("failed to generate questions: %w", err)
}
// Parse response
lines := strings.Split(response.Content, "\n")
questions := make([]string, 0, questionCount)
for _, line := range lines {
line = strings.TrimSpace(line)
if line == "" {
continue
}
line = strings.TrimLeft(line, "0123456789.-*) ")
line = strings.TrimSpace(line)
if line != "" && len(line) > 5 {
questions = append(questions, line)
if len(questions) >= questionCount {
break
}
}
}
return questions, nil
}
// GetKnowledgeFile retrieves the physical file associated with a knowledge entry
func (s *knowledgeService) GetKnowledgeFile(ctx context.Context, id string) (io.ReadCloser, string, error) {
// Get knowledge record
tenantID := ctx.Value(types.TenantIDContextKey).(uint64)
knowledge, err := s.repo.GetKnowledgeByID(ctx, tenantID, id)
if err != nil {
return nil, "", err
}
// Manual knowledge stores content in Metadata — stream it directly as a .md file.
if knowledge.IsManual() {
meta, err := knowledge.ManualMetadata()
if err != nil {
return nil, "", err
}
// ManualMetadata returns (nil, nil) when Metadata column is empty; treat as empty content.
content := ""
if meta != nil {
content = meta.Content
}
filename := sanitizeManualDownloadFilename(knowledge.Title)
return io.NopCloser(strings.NewReader(content)), filename, nil
}
// Resolve KB-level file service with FilePath fallback protection
kb, _ := s.kbService.GetKnowledgeBaseByID(ctx, knowledge.KnowledgeBaseID)
file, err := s.resolveFileServiceForPath(ctx, kb, knowledge.FilePath).GetFile(ctx, knowledge.FilePath)
if err != nil {
return nil, "", err
}
return file, knowledge.FileName, nil
}
func (s *knowledgeService) UpdateKnowledge(ctx context.Context, knowledge *types.Knowledge) error {
record, err := s.repo.GetKnowledgeByID(ctx, ctx.Value(types.TenantIDContextKey).(uint64), knowledge.ID)
if err != nil {
logger.Errorf(ctx, "Failed to get knowledge record: %v", err)
return err
}
// if need other fields update, please add here
if knowledge.Title != "" {
record.Title = knowledge.Title
}
if knowledge.Description != "" {
record.Description = knowledge.Description
}
// Update knowledge record in the repository
if err := s.repo.UpdateKnowledge(ctx, record); err != nil {
logger.Errorf(ctx, "Failed to update knowledge: %v", err)
return err
}
logger.Infof(ctx, "Knowledge updated successfully, ID: %s", knowledge.ID)
return nil
}
// UpdateManualKnowledge updates manual Markdown knowledge content.
// For publish status, the heavy operations (cleanup old indexes, re-chunking,
// re-embedding) are offloaded to an Asynq task so the HTTP response returns quickly.
func (s *knowledgeService) UpdateManualKnowledge(ctx context.Context,
knowledgeID string, payload *types.ManualKnowledgePayload,
) (*types.Knowledge, error) {
logger.Info(ctx, "Start updating manual knowledge entry")
if payload == nil {
return nil, werrors.NewBadRequestError("请求内容不能为空")
}
cleanContent := secutils.CleanMarkdown(payload.Content)
if strings.TrimSpace(cleanContent) == "" {
return nil, werrors.NewValidationError("内容不能为空")
}
if len([]rune(cleanContent)) > manualContentMaxLength {
return nil, werrors.NewValidationError(fmt.Sprintf("内容长度超出限制(最多%d个字符", manualContentMaxLength))
}
safeTitle, ok := secutils.ValidateInput(payload.Title)
if !ok {
return nil, werrors.NewValidationError("标题包含非法字符或超出长度限制")
}
status := strings.ToLower(strings.TrimSpace(payload.Status))
if status == "" {
status = types.ManualKnowledgeStatusDraft
}
if status != types.ManualKnowledgeStatusDraft && status != types.ManualKnowledgeStatusPublish {
return nil, werrors.NewValidationError("状态仅支持 draft 或 publish")
}
tenantID := ctx.Value(types.TenantIDContextKey).(uint64)
existing, err := s.repo.GetKnowledgeByID(ctx, tenantID, knowledgeID)
if err != nil {
logger.Errorf(ctx, "Failed to load knowledge: %v", err)
return nil, err
}
if !existing.IsManual() {
return nil, werrors.NewBadRequestError("仅支持手工知识的在线编辑")
}
kb, err := s.kbService.GetKnowledgeBaseByID(ctx, existing.KnowledgeBaseID)
if err != nil {
logger.Errorf(ctx, "Failed to get knowledge base for manual update: %v", err)
return nil, err
}
var version int
if meta, err := existing.ManualMetadata(); err == nil && meta != nil {
version = meta.Version + 1
} else {
version = 1
}
meta := types.NewManualKnowledgeMetadata(cleanContent, status, version)
if err := existing.SetManualMetadata(meta); err != nil {
logger.Errorf(ctx, "Failed to set manual metadata during update: %v", err)
return nil, err
}
if safeTitle != "" {
existing.Title = safeTitle
} else if existing.Title == "" {
existing.Title = fmt.Sprintf("手工知识-%s", time.Now().Format("20060102-150405"))
}
existing.FileName = ensureManualFileName(existing.Title)
existing.FileType = types.KnowledgeTypeManual
existing.Type = types.KnowledgeTypeManual
existing.Source = types.KnowledgeTypeManual
existing.EnableStatus = "disabled"
existing.UpdatedAt = time.Now()
existing.EmbeddingModelID = kb.EmbeddingModelID
if status == types.ManualKnowledgeStatusDraft {
existing.ParseStatus = types.ManualKnowledgeStatusDraft
existing.Description = ""
existing.ProcessedAt = nil
if err := s.repo.UpdateKnowledge(ctx, existing); err != nil {
logger.Errorf(ctx, "Failed to persist manual draft: %v", err)
return nil, err
}
return existing, nil
}
// Publish: persist pending status and enqueue async task for cleanup + re-indexing
existing.ParseStatus = "pending"
existing.Description = ""
existing.ProcessedAt = nil
if err := s.repo.UpdateKnowledge(ctx, existing); err != nil {
logger.Errorf(ctx, "Failed to persist manual knowledge before indexing: %v", err)
return nil, err
}
logger.Infof(ctx, "Manual knowledge updated, enqueuing async processing task, ID: %s", existing.ID)
if err := s.enqueueManualProcessing(ctx, existing, cleanContent, true); err != nil {
logger.Errorf(ctx, "Failed to enqueue manual processing task: %v", err)
// Non-fatal: mark as failed so user can retry
existing.ParseStatus = "failed"
existing.ErrorMessage = "Failed to enqueue processing task"
s.repo.UpdateKnowledge(ctx, existing)
return nil, werrors.NewInternalServerError("Failed to submit processing task")
}
return existing, nil
}
// enqueueManualProcessing enqueues a manual:process Asynq task for async cleanup + re-indexing.
func (s *knowledgeService) enqueueManualProcessing(ctx context.Context,
knowledge *types.Knowledge, content string, needCleanup bool,
) error {
requestID, _ := types.RequestIDFromContext(ctx)
payload := types.ManualProcessPayload{
RequestId: requestID,
TenantID: knowledge.TenantID,
KnowledgeID: knowledge.ID,
KnowledgeBaseID: knowledge.KnowledgeBaseID,
Content: content,
NeedCleanup: needCleanup,
}
payloadBytes, err := json.Marshal(payload)
if err != nil {
return fmt.Errorf("failed to marshal manual process payload: %w", err)
}
task := asynq.NewTask(types.TypeManualProcess, payloadBytes, asynq.Queue("default"), asynq.MaxRetry(3))
info, err := s.task.Enqueue(task)
if err != nil {
return fmt.Errorf("failed to enqueue manual process task: %w", err)
}
logger.Infof(ctx, "Enqueued manual process task: knowledge_id=%s, asynq_id=%s", knowledge.ID, info.ID)
return nil
}
// ReparseKnowledge deletes existing document content and re-parses the knowledge asynchronously.
// This method reuses the logic from UpdateManualKnowledge for resource cleanup and async parsing.
func (s *knowledgeService) ReparseKnowledge(ctx context.Context, knowledgeID string) (*types.Knowledge, error) {
logger.Info(ctx, "Start re-parsing knowledge")
tenantID := ctx.Value(types.TenantIDContextKey).(uint64)
existing, err := s.repo.GetKnowledgeByID(ctx, tenantID, knowledgeID)
if err != nil {
logger.Errorf(ctx, "Failed to load knowledge: %v", err)
return nil, err
}
// Get knowledge base configuration
kb, err := s.kbService.GetKnowledgeBaseByID(ctx, existing.KnowledgeBaseID)
if err != nil {
logger.Errorf(ctx, "Failed to get knowledge base for reparse: %v", err)
return nil, err
}
// For manual knowledge, use async manual processing (cleanup + re-indexing in worker)
if existing.IsManual() {
meta, metaErr := existing.ManualMetadata()
if metaErr != nil || meta == nil {
logger.Errorf(ctx, "Failed to get manual metadata for reparse: %v", metaErr)
return nil, werrors.NewBadRequestError("无法获取手工知识内容")
}
existing.ParseStatus = "pending"
existing.EnableStatus = "disabled"
existing.Description = ""
existing.ProcessedAt = nil
existing.EmbeddingModelID = kb.EmbeddingModelID
if err := s.repo.UpdateKnowledge(ctx, existing); err != nil {
logger.Errorf(ctx, "Failed to update knowledge status before reparse: %v", err)
return nil, err
}
if err := s.enqueueManualProcessing(ctx, existing, meta.Content, true); err != nil {
logger.Errorf(ctx, "Failed to enqueue manual reparse task: %v", err)
existing.ParseStatus = "failed"
existing.ErrorMessage = "Failed to enqueue processing task"
s.repo.UpdateKnowledge(ctx, existing)
}
return existing, nil
}
// For non-manual knowledge, cleanup synchronously then enqueue document processing
logger.Infof(ctx, "Cleaning up existing resources for knowledge: %s", knowledgeID)
if err := s.cleanupKnowledgeResources(ctx, existing); err != nil {
logger.ErrorWithFields(ctx, err, map[string]interface{}{
"knowledge_id": knowledgeID,
})
return nil, err
}
// Step 2: Update knowledge status and metadata
existing.ParseStatus = "pending"
existing.EnableStatus = "disabled"
existing.Description = ""
existing.ProcessedAt = nil
existing.EmbeddingModelID = kb.EmbeddingModelID
if err := s.repo.UpdateKnowledge(ctx, existing); err != nil {
logger.Errorf(ctx, "Failed to update knowledge status before reparse: %v", err)
return nil, err
}
// Step 3: Trigger async re-parsing based on knowledge type
logger.Infof(ctx, "Knowledge status updated, scheduling async reparse, ID: %s, Type: %s", existing.ID, existing.Type)
// For file-based knowledge, enqueue document processing task
if existing.FilePath != "" {
tenantID := ctx.Value(types.TenantIDContextKey).(uint64)
// Determine multimodal setting
enableMultimodel := kb.IsMultimodalEnabled()
// Check question generation config
enableQuestionGeneration := false
questionCount := 3 // default
if kb.QuestionGenerationConfig != nil && kb.QuestionGenerationConfig.Enabled {
enableQuestionGeneration = true
if kb.QuestionGenerationConfig.QuestionCount > 0 {
questionCount = kb.QuestionGenerationConfig.QuestionCount
}
}
taskPayload := types.DocumentProcessPayload{
TenantID: tenantID,
KnowledgeID: existing.ID,
KnowledgeBaseID: existing.KnowledgeBaseID,
FilePath: existing.FilePath,
FileName: existing.FileName,
FileType: getFileType(existing.FileName),
EnableMultimodel: enableMultimodel,
EnableQuestionGeneration: enableQuestionGeneration,
QuestionCount: questionCount,
}
payloadBytes, err := json.Marshal(taskPayload)
if err != nil {
logger.Errorf(ctx, "Failed to marshal reparse task payload: %v", err)
return existing, nil
}
task := asynq.NewTask(types.TypeDocumentProcess, payloadBytes, asynq.Queue("default"), asynq.MaxRetry(3))
info, err := s.task.Enqueue(task)
if err != nil {
logger.Errorf(ctx, "Failed to enqueue reparse task: %v", err)
return existing, nil
}
logger.Infof(ctx, "Enqueued reparse task: id=%s queue=%s knowledge_id=%s", info.ID, info.Queue, existing.ID)
// For data tables (csv, xlsx, xls), also enqueue summary task
if slices.Contains([]string{"csv", "xlsx", "xls"}, getFileType(existing.FileName)) {
NewDataTableSummaryTask(ctx, s.task, tenantID, existing.ID, kb.SummaryModelID, kb.EmbeddingModelID)
}
return existing, nil
}
// For file-URL-based knowledge, enqueue document processing task with FileURL field
if existing.Type == "file_url" && existing.Source != "" {
tenantID := ctx.Value(types.TenantIDContextKey).(uint64)
enableMultimodel := kb.IsMultimodalEnabled()
// Check question generation config
enableQuestionGeneration := false
questionCount := 3
if kb.QuestionGenerationConfig != nil && kb.QuestionGenerationConfig.Enabled {
enableQuestionGeneration = true
if kb.QuestionGenerationConfig.QuestionCount > 0 {
questionCount = kb.QuestionGenerationConfig.QuestionCount
}
}
taskPayload := types.DocumentProcessPayload{
TenantID: tenantID,
KnowledgeID: existing.ID,
KnowledgeBaseID: existing.KnowledgeBaseID,
FileURL: existing.Source,
FileName: existing.FileName,
FileType: existing.FileType,
EnableMultimodel: enableMultimodel,
EnableQuestionGeneration: enableQuestionGeneration,
QuestionCount: questionCount,
}
payloadBytes, err := json.Marshal(taskPayload)
if err != nil {
logger.Errorf(ctx, "Failed to marshal file URL reparse task payload: %v", err)
return existing, nil
}
task := asynq.NewTask(types.TypeDocumentProcess, payloadBytes, asynq.Queue("default"))
info, err := s.task.Enqueue(task)
if err != nil {
logger.Errorf(ctx, "Failed to enqueue file URL reparse task: %v", err)
return existing, nil
}
logger.Infof(ctx, "Enqueued file URL reparse task: id=%s queue=%s knowledge_id=%s", info.ID, info.Queue, existing.ID)
return existing, nil
}
// For URL-based knowledge, enqueue URL processing task
if existing.Type == "url" && existing.Source != "" {
tenantID := ctx.Value(types.TenantIDContextKey).(uint64)
enableMultimodel := kb.IsMultimodalEnabled()
// Check question generation config
enableQuestionGeneration := false
questionCount := 3
if kb.QuestionGenerationConfig != nil && kb.QuestionGenerationConfig.Enabled {
enableQuestionGeneration = true
if kb.QuestionGenerationConfig.QuestionCount > 0 {
questionCount = kb.QuestionGenerationConfig.QuestionCount
}
}
taskPayload := types.DocumentProcessPayload{
TenantID: tenantID,
KnowledgeID: existing.ID,
KnowledgeBaseID: existing.KnowledgeBaseID,
URL: existing.Source,
EnableMultimodel: enableMultimodel,
EnableQuestionGeneration: enableQuestionGeneration,
QuestionCount: questionCount,
}
payloadBytes, err := json.Marshal(taskPayload)
if err != nil {
logger.Errorf(ctx, "Failed to marshal URL reparse task payload: %v", err)
return existing, nil
}
task := asynq.NewTask(types.TypeDocumentProcess, payloadBytes, asynq.Queue("default"), asynq.MaxRetry(3))
info, err := s.task.Enqueue(task)
if err != nil {
logger.Errorf(ctx, "Failed to enqueue URL reparse task: %v", err)
return existing, nil
}
logger.Infof(ctx, "Enqueued URL reparse task: id=%s queue=%s knowledge_id=%s", info.ID, info.Queue, existing.ID)
return existing, nil
}
logger.Warnf(ctx, "Knowledge %s has no parseable content (no file, URL, or manual content)", knowledgeID)
return existing, nil
}
// isValidFileType checks if a file type is supported
func isValidFileType(filename string) bool {
switch strings.ToLower(getFileType(filename)) {
case "pdf", "txt", "docx", "doc", "md", "markdown", "png", "jpg", "jpeg", "gif", "csv", "xlsx", "xls", "pptx", "ppt":
return true
default:
return false
}
}
// getFileType extracts the file extension from a filename
func getFileType(filename string) string {
ext := strings.Split(filename, ".")
if len(ext) < 2 {
return "unknown"
}
return ext[len(ext)-1]
}
// isValidURL verifies if a URL is valid
// isValidURL 检查URL是否有效
func isValidURL(url string) bool {
if strings.HasPrefix(url, "http://") || strings.HasPrefix(url, "https://") {
return true
}
return false
}
// GetKnowledgeBatch retrieves multiple knowledge entries by their IDs
func (s *knowledgeService) GetKnowledgeBatch(ctx context.Context,
tenantID uint64, ids []string,
) ([]*types.Knowledge, error) {
if len(ids) == 0 {
return nil, nil
}
return s.repo.GetKnowledgeBatch(ctx, tenantID, ids)
}
// GetKnowledgeBatchWithSharedAccess retrieves knowledge by IDs, including items from shared KBs the user has access to.
// Used when building search targets so that @mentioned files from shared KBs are included.
func (s *knowledgeService) GetKnowledgeBatchWithSharedAccess(ctx context.Context,
tenantID uint64, ids []string,
) ([]*types.Knowledge, error) {
if len(ids) == 0 {
return nil, nil
}
ownList, err := s.repo.GetKnowledgeBatch(ctx, tenantID, ids)
if err != nil {
return nil, err
}
foundSet := make(map[string]bool)
for _, k := range ownList {
if k != nil {
foundSet[k.ID] = true
}
}
userIDVal := ctx.Value(types.UserIDContextKey)
if userIDVal == nil {
return ownList, nil
}
userID, ok := userIDVal.(string)
if !ok || userID == "" {
return ownList, nil
}
for _, id := range ids {
if foundSet[id] {
continue
}
k, err := s.repo.GetKnowledgeByIDOnly(ctx, id)
if err != nil || k == nil || k.KnowledgeBaseID == "" {
continue
}
hasPermission, err := s.kbShareService.HasKBPermission(ctx, k.KnowledgeBaseID, userID, types.OrgRoleViewer)
if err != nil || !hasPermission {
continue
}
foundSet[k.ID] = true
ownList = append(ownList, k)
}
return ownList, nil
}
// calculateFileHash calculates MD5 hash of a file
func calculateFileHash(file *multipart.FileHeader) (string, error) {
f, err := file.Open()
if err != nil {
return "", err
}
defer f.Close()
h := md5.New()
if _, err := io.Copy(h, f); err != nil {
return "", err
}
// Reset file pointer for subsequent operations
if _, err := f.Seek(0, 0); err != nil {
return "", err
}
return hex.EncodeToString(h.Sum(nil)), nil
}
func calculateStr(strList ...string) string {
h := md5.New()
input := strings.Join(strList, "")
h.Write([]byte(input))
return hex.EncodeToString(h.Sum(nil))
}
func (s *knowledgeService) CloneKnowledgeBase(ctx context.Context, srcID, dstID string) error {
srcKB, dstKB, err := s.kbService.CopyKnowledgeBase(ctx, srcID, dstID)
if err != nil {
logger.Errorf(ctx, "Failed to copy knowledge base: %v", err)
return err
}
addKnowledge, err := s.repo.AminusB(ctx, srcKB.TenantID, srcKB.ID, dstKB.TenantID, dstKB.ID)
if err != nil {
logger.Errorf(ctx, "Failed to get knowledge: %v", err)
return err
}
delKnowledge, err := s.repo.AminusB(ctx, dstKB.TenantID, dstKB.ID, srcKB.TenantID, srcKB.ID)
if err != nil {
logger.Errorf(ctx, "Failed to get knowledge: %v", err)
return err
}
logger.Infof(ctx, "Knowledge after update to add: %d, delete: %d", len(addKnowledge), len(delKnowledge))
batch := 10
g, gctx := errgroup.WithContext(ctx)
for ids := range slices.Chunk(delKnowledge, batch) {
g.Go(func() error {
err := s.DeleteKnowledgeList(gctx, ids)
if err != nil {
logger.Errorf(gctx, "delete partial knowledge %v: %v", ids, err)
return err
}
return nil
})
}
err = g.Wait()
if err != nil {
logger.Errorf(ctx, "delete total knowledge %d: %v", len(delKnowledge), err)
return err
}
// Copy context out of auto-stop task
g, gctx = errgroup.WithContext(ctx)
g.SetLimit(batch)
for _, knowledge := range addKnowledge {
g.Go(func() error {
srcKn, err := s.repo.GetKnowledgeByID(gctx, srcKB.TenantID, knowledge)
if err != nil {
logger.Errorf(gctx, "get knowledge %s: %v", knowledge, err)
return err
}
err = s.cloneKnowledge(gctx, srcKn, dstKB)
if err != nil {
logger.Errorf(gctx, "clone knowledge %s: %v", knowledge, err)
return err
}
return nil
})
}
err = g.Wait()
if err != nil {
logger.Errorf(ctx, "add total knowledge %d: %v", len(addKnowledge), err)
return err
}
return nil
}
func (s *knowledgeService) updateChunkVector(ctx context.Context, kbID string, chunks []*types.Chunk) error {
// Get embedding model from knowledge base
sourceKB, err := s.kbService.GetKnowledgeBaseByID(ctx, kbID)
if err != nil {
return err
}
embeddingModel, err := s.modelService.GetEmbeddingModel(ctx, sourceKB.EmbeddingModelID)
if err != nil {
return err
}
// Initialize composite retrieve engine from tenant configuration
indexInfo := make([]*types.IndexInfo, 0, len(chunks))
ids := make([]string, 0, len(chunks))
for _, chunk := range chunks {
if chunk.KnowledgeBaseID != kbID {
logger.Warnf(ctx, "Knowledge base ID mismatch: %s != %s", chunk.KnowledgeBaseID, kbID)
continue
}
indexInfo = append(indexInfo, &types.IndexInfo{
Content: chunk.Content,
SourceID: chunk.ID,
SourceType: types.ChunkSourceType,
ChunkID: chunk.ID,
KnowledgeID: chunk.KnowledgeID,
KnowledgeBaseID: chunk.KnowledgeBaseID,
IsEnabled: true,
})
ids = append(ids, chunk.ID)
}
tenantInfo := ctx.Value(types.TenantInfoContextKey).(*types.Tenant)
retrieveEngine, err := retriever.NewCompositeRetrieveEngine(s.retrieveEngine, tenantInfo.GetEffectiveEngines())
if err != nil {
return err
}
// Delete old vector representation of the chunk
err = retrieveEngine.DeleteByChunkIDList(ctx, ids, embeddingModel.GetDimensions(), sourceKB.Type)
if err != nil {
return err
}
// Index updated chunk content with new vector representation
err = retrieveEngine.BatchIndex(ctx, embeddingModel, indexInfo)
if err != nil {
return err
}
return nil
}
func (s *knowledgeService) UpdateImageInfo(
ctx context.Context,
knowledgeID string,
chunkID string,
imageInfo string,
) error {
var images []*types.ImageInfo
if err := json.Unmarshal([]byte(imageInfo), &images); err != nil {
logger.Errorf(ctx, "Failed to unmarshal image info: %v", err)
return err
}
if len(images) != 1 {
logger.Warnf(ctx, "Expected exactly one image info, got %d", len(images))
return nil
}
image := images[0]
// Retrieve all chunks with the given parent chunk ID
chunk, err := s.chunkService.GetChunkByID(ctx, chunkID)
if err != nil {
logger.Errorf(ctx, "Failed to get chunk: %v", err)
return err
}
chunk.ImageInfo = imageInfo
tenantID := ctx.Value(types.TenantIDContextKey).(uint64)
chunkChildren, err := s.chunkService.ListChunkByParentID(ctx, tenantID, chunkID)
if err != nil {
logger.ErrorWithFields(ctx, err, map[string]interface{}{
"parent_chunk_id": chunkID,
"tenant_id": tenantID,
})
return err
}
logger.Infof(ctx, "Found %d chunks with parent chunk ID: %s", len(chunkChildren), chunkID)
// Iterate through each chunk and update its content based on the image information
updateChunk := []*types.Chunk{chunk}
var addChunk []*types.Chunk
// Track whether we've found OCR and caption child chunks for this image
hasOCRChunk := false
hasCaptionChunk := false
for i, child := range chunkChildren {
// Skip chunks that are not image types
var cImageInfo []*types.ImageInfo
err = json.Unmarshal([]byte(child.ImageInfo), &cImageInfo)
if err != nil {
logger.Warnf(ctx, "Failed to unmarshal image %s info: %v", child.ID, err)
continue
}
if len(cImageInfo) == 0 {
continue
}
if cImageInfo[0].OriginalURL != image.OriginalURL {
logger.Warnf(ctx, "Skipping chunk ID: %s, image URL mismatch: %s != %s",
child.ID, cImageInfo[0].OriginalURL, image.OriginalURL)
continue
}
// Mark that we've found chunks for this image
switch child.ChunkType {
case types.ChunkTypeImageCaption:
hasCaptionChunk = true
// Update caption if it has changed
if image.Caption != cImageInfo[0].Caption {
child.Content = image.Caption
child.ImageInfo = imageInfo
updateChunk = append(updateChunk, chunkChildren[i])
}
case types.ChunkTypeImageOCR:
hasOCRChunk = true
// Update OCR if it has changed
if image.OCRText != cImageInfo[0].OCRText {
child.Content = image.OCRText
child.ImageInfo = imageInfo
updateChunk = append(updateChunk, chunkChildren[i])
}
}
}
// Create a new caption chunk if it doesn't exist and we have caption data
if !hasCaptionChunk && image.Caption != "" {
captionChunk := &types.Chunk{
ID: uuid.New().String(),
TenantID: tenantID,
KnowledgeID: chunk.KnowledgeID,
KnowledgeBaseID: chunk.KnowledgeBaseID,
Content: image.Caption,
ChunkType: types.ChunkTypeImageCaption,
ParentChunkID: chunk.ID,
ImageInfo: imageInfo,
}
addChunk = append(addChunk, captionChunk)
logger.Infof(ctx, "Created new caption chunk ID: %s for image URL: %s", captionChunk.ID, image.OriginalURL)
}
// Create a new OCR chunk if it doesn't exist and we have OCR data
if !hasOCRChunk && image.OCRText != "" {
ocrChunk := &types.Chunk{
ID: uuid.New().String(),
TenantID: tenantID,
KnowledgeID: chunk.KnowledgeID,
KnowledgeBaseID: chunk.KnowledgeBaseID,
Content: image.OCRText,
ChunkType: types.ChunkTypeImageOCR,
ParentChunkID: chunk.ID,
ImageInfo: imageInfo,
}
addChunk = append(addChunk, ocrChunk)
logger.Infof(ctx, "Created new OCR chunk ID: %s for image URL: %s", ocrChunk.ID, image.OriginalURL)
}
logger.Infof(ctx, "Updated %d chunks out of %d total chunks", len(updateChunk), len(chunkChildren)+1)
if len(addChunk) > 0 {
err := s.chunkService.CreateChunks(ctx, addChunk)
if err != nil {
logger.ErrorWithFields(ctx, err, map[string]interface{}{
"add_chunk_size": len(addChunk),
})
return err
}
}
// Update the chunks
for _, c := range updateChunk {
err := s.chunkService.UpdateChunk(ctx, c)
if err != nil {
logger.ErrorWithFields(ctx, err, map[string]interface{}{
"chunk_id": c.ID,
"knowledge_id": c.KnowledgeID,
})
return err
}
}
// Update the chunk vector
err = s.updateChunkVector(ctx, chunk.KnowledgeBaseID, append(updateChunk, addChunk...))
if err != nil {
logger.ErrorWithFields(ctx, err, map[string]interface{}{
"chunk_id": chunk.ID,
"knowledge_id": chunk.KnowledgeID,
})
return err
}
// Update the knowledge file hash
knowledge, err := s.repo.GetKnowledgeByID(ctx, tenantID, knowledgeID)
if err != nil {
logger.Errorf(ctx, "Failed to get knowledge: %v", err)
return err
}
fileHash := calculateStr(knowledgeID, knowledge.FileHash, imageInfo)
knowledge.FileHash = fileHash
err = s.repo.UpdateKnowledge(ctx, knowledge)
if err != nil {
logger.Warnf(ctx, "Failed to update knowledge file hash: %v", err)
}
logger.Infof(ctx, "Updated chunk successfully, chunk ID: %s, knowledge ID: %s", chunk.ID, chunk.KnowledgeID)
return nil
}
// CloneChunk clone chunks from one knowledge to another
// This method transfers a chunk from a source knowledge document to a target knowledge document
// It handles the creation of new chunks in the target knowledge and updates the vector database accordingly
// Parameters:
// - ctx: Context with authentication and request information
// - src: Source knowledge document containing the chunk to move
// - dst: Target knowledge document where the chunk will be moved
//
// Returns:
// - error: Any error encountered during the move operation
//
// This method handles the chunk transfer logic, including creating new chunks in the target knowledge
// and updating the vector database representation of the moved chunks.
// It also ensures that the chunk's relationships (like pre and next chunk IDs) are maintained
// by mapping the source chunk IDs to the new target chunk IDs.
func (s *knowledgeService) CloneChunk(ctx context.Context, src, dst *types.Knowledge) error {
chunkPage := 1
chunkPageSize := 100
srcTodst := map[string]string{}
tagIDMapping := map[string]string{} // srcTagID -> dstTagID
targetChunks := make([]*types.Chunk, 0, 10)
chunkType := []types.ChunkType{
types.ChunkTypeText, types.ChunkTypeParentText, types.ChunkTypeSummary,
types.ChunkTypeImageCaption, types.ChunkTypeImageOCR,
}
for {
sourceChunks, _, err := s.chunkRepo.ListPagedChunksByKnowledgeID(ctx,
src.TenantID,
src.ID,
&types.Pagination{
Page: chunkPage,
PageSize: chunkPageSize,
},
chunkType,
"",
"",
"",
"",
"",
)
chunkPage++
if err != nil {
return err
}
if len(sourceChunks) == 0 {
break
}
now := time.Now()
for _, sourceChunk := range sourceChunks {
// Map TagID to target knowledge base
targetTagID := ""
if sourceChunk.TagID != "" {
if mappedTagID, ok := tagIDMapping[sourceChunk.TagID]; ok {
targetTagID = mappedTagID
} else {
// Try to find or create the tag in target knowledge base
targetTagID = s.getOrCreateTagInTarget(ctx, src.TenantID, dst.TenantID, dst.KnowledgeBaseID, sourceChunk.TagID, tagIDMapping)
}
}
targetChunk := &types.Chunk{
ID: uuid.New().String(),
TenantID: dst.TenantID,
KnowledgeID: dst.ID,
KnowledgeBaseID: dst.KnowledgeBaseID,
TagID: targetTagID,
Content: sourceChunk.Content,
ChunkIndex: sourceChunk.ChunkIndex,
IsEnabled: sourceChunk.IsEnabled,
Flags: sourceChunk.Flags,
Status: sourceChunk.Status,
StartAt: sourceChunk.StartAt,
EndAt: sourceChunk.EndAt,
PreChunkID: sourceChunk.PreChunkID,
NextChunkID: sourceChunk.NextChunkID,
ChunkType: sourceChunk.ChunkType,
ParentChunkID: sourceChunk.ParentChunkID,
Metadata: sourceChunk.Metadata,
ContentHash: sourceChunk.ContentHash,
ImageInfo: sourceChunk.ImageInfo,
CreatedAt: now,
UpdatedAt: now,
}
targetChunks = append(targetChunks, targetChunk)
srcTodst[sourceChunk.ID] = targetChunk.ID
}
}
for _, targetChunk := range targetChunks {
if val, ok := srcTodst[targetChunk.PreChunkID]; ok {
targetChunk.PreChunkID = val
} else {
targetChunk.PreChunkID = ""
}
if val, ok := srcTodst[targetChunk.NextChunkID]; ok {
targetChunk.NextChunkID = val
} else {
targetChunk.NextChunkID = ""
}
if val, ok := srcTodst[targetChunk.ParentChunkID]; ok {
targetChunk.ParentChunkID = val
} else {
targetChunk.ParentChunkID = ""
}
}
for chunks := range slices.Chunk(targetChunks, chunkPageSize) {
err := s.chunkRepo.CreateChunks(ctx, chunks)
if err != nil {
return err
}
}
tenantInfo := ctx.Value(types.TenantInfoContextKey).(*types.Tenant)
retrieveEngine, err := retriever.NewCompositeRetrieveEngine(s.retrieveEngine, tenantInfo.GetEffectiveEngines())
if err != nil {
return err
}
embeddingModel, err := s.modelService.GetEmbeddingModel(ctx, dst.EmbeddingModelID)
if err != nil {
return err
}
if err := retrieveEngine.CopyIndices(ctx, src.KnowledgeBaseID, dst.KnowledgeBaseID,
map[string]string{src.ID: dst.ID},
srcTodst,
embeddingModel.GetDimensions(),
dst.Type,
); err != nil {
return err
}
return nil
}
// ListFAQEntries lists FAQ entries under a FAQ knowledge base.
func (s *knowledgeService) ListFAQEntries(ctx context.Context,
kbID string, page *types.Pagination, tagSeqID int64, keyword string, searchField string, sortOrder string,
) (*types.PageResult, error) {
if page == nil {
page = &types.Pagination{}
}
keyword = strings.TrimSpace(keyword)
kb, err := s.validateFAQKnowledgeBase(ctx, kbID)
if err != nil {
return nil, err
}
// Check if this is a shared knowledge base access
tenantID := ctx.Value(types.TenantIDContextKey).(uint64)
effectiveTenantID := tenantID
// If the kb belongs to a different tenant, check for shared access
if kb.TenantID != tenantID {
// Get user ID from context
userIDVal := ctx.Value(types.UserIDContextKey)
if userIDVal == nil {
return nil, werrors.NewForbiddenError("无权访问该知识库")
}
userID := userIDVal.(string)
// Check if user has at least viewer permission through organization sharing
hasPermission, err := s.kbShareService.HasKBPermission(ctx, kbID, userID, types.OrgRoleViewer)
if err != nil || !hasPermission {
return nil, werrors.NewForbiddenError("无权访问该知识库")
}
// Use the source tenant ID for data access
sourceTenantID, err := s.kbShareService.GetKBSourceTenant(ctx, kbID)
if err != nil {
return nil, werrors.NewForbiddenError("无权访问该知识库")
}
effectiveTenantID = sourceTenantID
}
faqKnowledge, err := s.findFAQKnowledge(ctx, effectiveTenantID, kb.ID)
if err != nil {
return nil, err
}
if faqKnowledge == nil {
return types.NewPageResult(0, page, []*types.FAQEntry{}), nil
}
// Convert tagSeqID to tagID (UUID)
var tagID string
if tagSeqID > 0 {
tag, err := s.tagRepo.GetBySeqID(ctx, effectiveTenantID, tagSeqID)
if err != nil {
return nil, werrors.NewNotFoundError("标签不存在")
}
tagID = tag.ID
}
chunkType := []types.ChunkType{types.ChunkTypeFAQ}
chunks, total, err := s.chunkRepo.ListPagedChunksByKnowledgeID(
ctx, effectiveTenantID, faqKnowledge.ID, page, chunkType, tagID, keyword, searchField, sortOrder, types.KnowledgeTypeFAQ,
)
if err != nil {
return nil, err
}
// Build tag ID to name and seq_id mapping for all unique tag IDs (batch query)
tagNameMap := make(map[string]string)
tagSeqIDMap := make(map[string]int64)
tagIDs := make([]string, 0)
tagIDSet := make(map[string]struct{})
for _, chunk := range chunks {
if chunk.TagID != "" {
if _, exists := tagIDSet[chunk.TagID]; !exists {
tagIDSet[chunk.TagID] = struct{}{}
tagIDs = append(tagIDs, chunk.TagID)
}
}
}
if len(tagIDs) > 0 {
tags, err := s.tagRepo.GetByIDs(ctx, effectiveTenantID, tagIDs)
if err == nil {
for _, tag := range tags {
tagNameMap[tag.ID] = tag.Name
tagSeqIDMap[tag.ID] = tag.SeqID
}
}
}
kb.EnsureDefaults()
entries := make([]*types.FAQEntry, 0, len(chunks))
for _, chunk := range chunks {
entry, err := s.chunkToFAQEntry(chunk, kb, tagSeqIDMap)
if err != nil {
return nil, err
}
// Set tag name from mapping
if chunk.TagID != "" {
entry.TagName = tagNameMap[chunk.TagID]
}
entries = append(entries, entry)
}
return types.NewPageResult(total, page, entries), nil
}
// UpsertFAQEntries imports or appends FAQ entries asynchronously.
// Returns task ID (UUID) for tracking import progress.
func (s *knowledgeService) UpsertFAQEntries(ctx context.Context,
kbID string, payload *types.FAQBatchUpsertPayload,
) (string, error) {
if payload == nil || len(payload.Entries) == 0 {
return "", werrors.NewBadRequestError("FAQ 条目不能为空")
}
if payload.Mode == "" {
payload.Mode = types.FAQBatchModeAppend
}
if payload.Mode != types.FAQBatchModeAppend && payload.Mode != types.FAQBatchModeReplace {
return "", werrors.NewBadRequestError("模式仅支持 append 或 replace")
}
// 验证知识库是否存在且有效
kb, err := s.validateFAQKnowledgeBase(ctx, kbID)
if err != nil {
return "", err
}
tenantID := ctx.Value(types.TenantIDContextKey).(uint64)
// 使用传入的TaskID如果没传则生成增强的TaskID
taskID := payload.TaskID
if taskID == "" {
taskID = secutils.GenerateTaskID("faq_import", tenantID, kbID)
}
var knowledgeID string
// 检查是否有正在进行的导入任务通过Redis
runningTaskID, err := s.getRunningFAQImportTaskID(ctx, kbID)
if err != nil {
logger.Errorf(ctx, "Failed to check running import task: %v", err)
// 检查失败不影响导入,继续执行
} else if runningTaskID != "" {
logger.Warnf(ctx, "Import task already running for KB %s: %s", kbID, runningTaskID)
return "", werrors.NewBadRequestError(fmt.Sprintf("该知识库已有导入任务正在进行中任务ID: %s请等待完成后再试", runningTaskID))
}
// 确保 FAQ knowledge 存在
faqKnowledge, err := s.ensureFAQKnowledge(ctx, tenantID, kb)
if err != nil {
return "", fmt.Errorf("failed to ensure FAQ knowledge: %w", err)
}
knowledgeID = faqKnowledge.ID
// 记录任务入队时间
enqueuedAt := time.Now().Unix()
// 设置 KB 的运行中任务信息
if err := s.setRunningFAQImportInfo(ctx, kbID, &runningFAQImportInfo{
TaskID: taskID,
EnqueuedAt: enqueuedAt,
}); err != nil {
logger.Errorf(ctx, "Failed to set running FAQ import task info: %v", err)
// 不影响任务执行,继续
}
// 初始化导入任务状态到Redis
progress := &types.FAQImportProgress{
TaskID: taskID,
KBID: kbID,
KnowledgeID: knowledgeID,
Status: types.FAQImportStatusPending,
Progress: 0,
Total: len(payload.Entries),
Processed: 0,
SuccessCount: 0,
FailedCount: 0,
FailedEntries: make([]types.FAQFailedEntry, 0),
Message: "任务已创建,等待处理",
CreatedAt: time.Now().Unix(),
UpdatedAt: time.Now().Unix(),
DryRun: payload.DryRun,
}
if err := s.saveFAQImportProgress(ctx, progress); err != nil {
logger.Errorf(ctx, "Failed to initialize FAQ import task status: %v", err)
return "", fmt.Errorf("failed to initialize task: %w", err)
}
logger.Infof(ctx, "FAQ import task initialized: %s, kb_id: %s, total entries: %d, dry_run: %v",
taskID, kbID, len(payload.Entries), payload.DryRun)
// Enqueue FAQ import task to Asynq
logger.Info(ctx, "Enqueuing FAQ import task to Asynq")
// 构建任务 payload
taskPayload := types.FAQImportPayload{
TenantID: tenantID,
TaskID: taskID,
KBID: kbID,
KnowledgeID: knowledgeID,
Mode: payload.Mode,
DryRun: payload.DryRun,
EnqueuedAt: enqueuedAt,
}
// 阈值:超过 200 条或序列化后超过 50KB 时使用对象存储
const (
entryCountThreshold = 200
payloadSizeThreshold = 50 * 1024 // 50KB
)
entryCount := len(payload.Entries)
if entryCount > entryCountThreshold {
// 数据量较大,上传到对象存储
entriesData, err := json.Marshal(payload.Entries)
if err != nil {
logger.Errorf(ctx, "Failed to marshal FAQ entries: %v", err)
return "", fmt.Errorf("failed to marshal entries: %w", err)
}
logger.Infof(ctx, "FAQ entries size: %d bytes, uploading to object storage", len(entriesData))
// 上传到私有桶(主桶),任务处理完成后清理
fileName := fmt.Sprintf("faq_import_entries_%s_%d.json", taskID, enqueuedAt)
entriesURL, err := s.fileSvc.SaveBytes(ctx, entriesData, tenantID, fileName, false)
if err != nil {
logger.Errorf(ctx, "Failed to upload FAQ entries to object storage: %v", err)
return "", fmt.Errorf("failed to upload entries: %w", err)
}
logger.Infof(ctx, "FAQ entries uploaded to: %s", entriesURL)
taskPayload.EntriesURL = entriesURL
taskPayload.EntryCount = entryCount
} else {
// 数据量较小,直接存储在 payload 中
taskPayload.Entries = payload.Entries
}
payloadBytes, err := json.Marshal(taskPayload)
if err != nil {
logger.Errorf(ctx, "Failed to marshal FAQ import task payload: %v", err)
return "", fmt.Errorf("failed to marshal task payload: %w", err)
}
// 再次检查 payload 大小
if len(payloadBytes) > payloadSizeThreshold && taskPayload.EntriesURL == "" {
// payload 太大但还没上传,现在上传
entriesData, _ := json.Marshal(payload.Entries)
fileName := fmt.Sprintf("faq_import_entries_%s_%d.json", taskID, enqueuedAt)
entriesURL, err := s.fileSvc.SaveBytes(ctx, entriesData, tenantID, fileName, false)
if err != nil {
logger.Errorf(ctx, "Failed to upload FAQ entries to object storage: %v", err)
return "", fmt.Errorf("failed to upload entries: %w", err)
}
logger.Infof(ctx, "FAQ entries uploaded to (size exceeded): %s", entriesURL)
taskPayload.Entries = nil
taskPayload.EntriesURL = entriesURL
taskPayload.EntryCount = entryCount
payloadBytes, _ = json.Marshal(taskPayload)
}
logger.Infof(ctx, "FAQ import task payload size: %d bytes", len(payloadBytes))
maxRetry := 5
if payload.DryRun {
maxRetry = 3 // dry run 重试次数少一些
}
// 使用 taskID:enqueuedAt 作为 asynq 的唯一任务标识
// 这样同一个用户 TaskID 的不同次提交不会冲突
asynqTaskID := fmt.Sprintf("%s:%d", taskID, enqueuedAt)
task := asynq.NewTask(
types.TypeFAQImport,
payloadBytes,
asynq.TaskID(asynqTaskID),
asynq.Queue("default"),
asynq.MaxRetry(maxRetry),
)
info, err := s.task.Enqueue(task)
if err != nil {
logger.Errorf(ctx, "Failed to enqueue FAQ import task: %v", err)
return "", fmt.Errorf("failed to enqueue task: %w", err)
}
logger.Infof(ctx, "Enqueued FAQ import task: id=%s queue=%s task_id=%s dry_run=%v", info.ID, info.Queue, taskID, payload.DryRun)
return taskID, nil
}
// generateFailedEntriesCSV 生成失败条目的 CSV 文件并上传
func (s *knowledgeService) generateFailedEntriesCSV(ctx context.Context,
tenantID uint64, taskID string, failedEntries []types.FAQFailedEntry,
) (string, error) {
// 生成 CSV 内容
var buf strings.Builder
// 写入 BOM 以支持 Excel 正确识别 UTF-8
buf.WriteString("\xEF\xBB\xBF")
// 写入表头
buf.WriteString("错误原因,分类(必填),问题(必填),相似问题(选填-多个用##分隔),反例问题(选填-多个用##分隔),机器人回答(必填-多个用##分隔),是否全部回复(选填-默认FALSE),是否停用(选填-默认FALSE)\n")
// 写入数据行
for _, entry := range failedEntries {
// CSV 转义:如果内容包含逗号、引号或换行,需要用引号包裹并转义内部引号
reason := csvEscape(entry.Reason)
tagName := csvEscape(entry.TagName)
standardQ := csvEscape(entry.StandardQuestion)
similarQs := ""
if len(entry.SimilarQuestions) > 0 {
similarQs = csvEscape(strings.Join(entry.SimilarQuestions, "##"))
}
negativeQs := ""
if len(entry.NegativeQuestions) > 0 {
negativeQs = csvEscape(strings.Join(entry.NegativeQuestions, "##"))
}
answers := ""
if len(entry.Answers) > 0 {
answers = csvEscape(strings.Join(entry.Answers, "##"))
}
answerAll := "false"
if entry.AnswerAll {
answerAll = "true"
}
isDisabled := "false"
if entry.IsDisabled {
isDisabled = "true"
}
buf.WriteString(fmt.Sprintf("%s,%s,%s,%s,%s,%s,%s,%s\n",
reason, tagName, standardQ, similarQs, negativeQs, answers, answerAll, isDisabled))
}
// 上传 CSV 文件到临时存储(会自动过期)
fileName := fmt.Sprintf("faq_dryrun_failed_%s.csv", taskID)
filePath, err := s.fileSvc.SaveBytes(ctx, []byte(buf.String()), tenantID, fileName, true)
if err != nil {
return "", fmt.Errorf("failed to save CSV file: %w", err)
}
// 获取下载 URL
fileURL, err := s.fileSvc.GetFileURL(ctx, filePath)
if err != nil {
return "", fmt.Errorf("failed to get file URL: %w", err)
}
logger.Infof(ctx, "Generated failed entries CSV: %s, entries: %d", fileURL, len(failedEntries))
return fileURL, nil
}
// csvEscape 转义 CSV 字段
func csvEscape(s string) string {
if strings.ContainsAny(s, ",\"\n\r") {
// 将内部引号替换为两个引号,并用引号包裹整个字段
return "\"" + strings.ReplaceAll(s, "\"", "\"\"") + "\""
}
return s
}
// saveFAQImportResultToDatabase 保存FAQ导入结果统计到数据库
func (s *knowledgeService) saveFAQImportResultToDatabase(ctx context.Context,
payload *types.FAQImportPayload, progress *types.FAQImportProgress, originalTotalEntries int,
) error {
// 获取FAQ知识库实例
tenantID := ctx.Value(types.TenantIDContextKey).(uint64)
knowledge, err := s.repo.GetKnowledgeByID(ctx, tenantID, payload.KnowledgeID)
if err != nil {
return fmt.Errorf("failed to get FAQ knowledge: %w", err)
}
// 计算跳过的条目数(总数 - 成功 - 失败)
skippedCount := originalTotalEntries - progress.SuccessCount - progress.FailedCount
if skippedCount < 0 {
skippedCount = 0
}
// 创建导入结果统计
importResult := &types.FAQImportResult{
TotalEntries: originalTotalEntries,
SuccessCount: progress.SuccessCount,
FailedCount: progress.FailedCount,
SkippedCount: skippedCount,
ImportMode: payload.Mode,
ImportedAt: time.Now(),
TaskID: payload.TaskID,
ProcessingTime: time.Now().Unix() - progress.CreatedAt, // 处理耗时(秒)
DisplayStatus: "open", // 新导入的结果默认显示
}
// 如果有失败条目且提供了下载URL设置失败URL
if progress.FailedCount > 0 && progress.FailedEntriesURL != "" {
importResult.FailedEntriesURL = progress.FailedEntriesURL
}
// 设置导入结果到Knowledge的metadata中
if err := knowledge.SetLastFAQImportResult(importResult); err != nil {
return fmt.Errorf("failed to set FAQ import result: %w", err)
}
// 更新数据库
if err := s.repo.UpdateKnowledge(ctx, knowledge); err != nil {
return fmt.Errorf("failed to update knowledge with import result: %w", err)
}
logger.Infof(ctx, "Saved FAQ import result to database: knowledge_id=%s, task_id=%s, total=%d, success=%d, failed=%d, skipped=%d",
payload.KnowledgeID, payload.TaskID, originalTotalEntries, progress.SuccessCount, progress.FailedCount, skippedCount)
return nil
}
// buildFAQFailedEntry 构建 FAQFailedEntry
func buildFAQFailedEntry(idx int, reason string, entry *types.FAQEntryPayload) types.FAQFailedEntry {
answerAll := false
if entry.AnswerStrategy != nil && *entry.AnswerStrategy == types.AnswerStrategyAll {
answerAll = true
}
isDisabled := false
if entry.IsEnabled != nil && !*entry.IsEnabled {
isDisabled = true
}
return types.FAQFailedEntry{
Index: idx,
Reason: reason,
TagName: entry.TagName,
StandardQuestion: strings.TrimSpace(entry.StandardQuestion),
SimilarQuestions: entry.SimilarQuestions,
NegativeQuestions: entry.NegativeQuestions,
Answers: entry.Answers,
AnswerAll: answerAll,
IsDisabled: isDisabled,
}
}
// executeFAQDryRunValidation 执行 FAQ dry run 验证,返回通过验证的条目索引
func (s *knowledgeService) executeFAQDryRunValidation(ctx context.Context,
payload *types.FAQImportPayload, progress *types.FAQImportProgress,
) []int {
entries := payload.Entries
// 用于记录已通过基本验证和重复检查的条目索引,后续进行安全检查
validEntryIndices := make([]int, 0, len(entries))
// 根据模式选择不同的验证逻辑
if payload.Mode == types.FAQBatchModeAppend {
validEntryIndices = s.validateEntriesForAppendModeWithProgress(ctx, payload.TenantID, payload.KBID, entries, progress)
} else {
validEntryIndices = s.validateEntriesForReplaceModeWithProgress(ctx, entries, progress)
}
return validEntryIndices
}
// validateEntriesForAppendModeWithProgress 验证 Append 模式下的条目(带进度更新)
// 注意:验证阶段不更新 Processed只有实际导入时才更新
func (s *knowledgeService) validateEntriesForAppendModeWithProgress(ctx context.Context,
tenantID uint64, kbID string, entries []types.FAQEntryPayload, progress *types.FAQImportProgress,
) []int {
validIndices := make([]int, 0, len(entries))
// 查询知识库中已有的所有FAQ chunks的metadata
existingChunks, err := s.chunkRepo.ListAllFAQChunksWithMetadataByKnowledgeBaseID(ctx, tenantID, kbID)
if err != nil {
logger.Warnf(ctx, "Failed to list existing FAQ chunks for dry run: %v", err)
// 无法获取已有数据时,仅做批次内验证
}
// 构建已存在的标准问和相似问集合
existingQuestions := make(map[string]bool)
for _, chunk := range existingChunks {
meta, err := chunk.FAQMetadata()
if err != nil || meta == nil {
continue
}
if meta.StandardQuestion != "" {
existingQuestions[meta.StandardQuestion] = true
}
for _, q := range meta.SimilarQuestions {
if q != "" {
existingQuestions[q] = true
}
}
}
// 构建当前批次的标准问和相似问集合(用于批次内去重)
batchQuestions := make(map[string]int) // value 为首次出现的索引
for i, entry := range entries {
// 验证条目基本格式
if err := validateFAQEntryPayloadBasic(&entry); err != nil {
progress.FailedCount++
progress.FailedEntries = append(progress.FailedEntries, buildFAQFailedEntry(i, err.Error(), &entry))
continue
}
standardQ := strings.TrimSpace(entry.StandardQuestion)
// 检查标准问是否与已有知识库重复
if existingQuestions[standardQ] {
progress.FailedCount++
progress.FailedEntries = append(progress.FailedEntries, buildFAQFailedEntry(i, "标准问与知识库中已有问题重复", &entry))
continue
}
// 检查标准问是否与同批次重复
if firstIdx, exists := batchQuestions[standardQ]; exists {
progress.FailedCount++
progress.FailedEntries = append(progress.FailedEntries, buildFAQFailedEntry(i, fmt.Sprintf("标准问与批次内第 %d 条重复", firstIdx+1), &entry))
continue
}
// 检查相似问是否有重复
hasDuplicate := false
for _, q := range entry.SimilarQuestions {
q = strings.TrimSpace(q)
if q == "" {
continue
}
if existingQuestions[q] {
progress.FailedCount++
progress.FailedEntries = append(progress.FailedEntries, buildFAQFailedEntry(i, fmt.Sprintf("相似问 \"%s\" 与知识库中已有问题重复", q), &entry))
hasDuplicate = true
break
}
if firstIdx, exists := batchQuestions[q]; exists {
progress.FailedCount++
progress.FailedEntries = append(progress.FailedEntries, buildFAQFailedEntry(i, fmt.Sprintf("相似问 \"%s\" 与批次内第 %d 条重复", q, firstIdx+1), &entry))
hasDuplicate = true
break
}
}
if hasDuplicate {
continue
}
// 将当前条目的标准问和相似问加入批次集合
batchQuestions[standardQ] = i
for _, q := range entry.SimilarQuestions {
q = strings.TrimSpace(q)
if q != "" {
batchQuestions[q] = i
}
}
// 记录通过验证的条目索引
validIndices = append(validIndices, i)
// 定期更新进度消息(验证阶段不更新 Processed
if (i+1)%100 == 0 {
progress.Message = fmt.Sprintf("正在验证条目 %d/%d...", i+1, len(entries))
progress.UpdatedAt = time.Now().Unix()
if err := s.saveFAQImportProgress(ctx, progress); err != nil {
logger.Warnf(ctx, "Failed to update FAQ dry run progress: %v", err)
}
}
}
return validIndices
}
// validateEntriesForReplaceModeWithProgress 验证 Replace 模式下的条目(带进度更新)
// 注意:验证阶段不更新 Processed只有实际导入时才更新
func (s *knowledgeService) validateEntriesForReplaceModeWithProgress(ctx context.Context,
entries []types.FAQEntryPayload, progress *types.FAQImportProgress,
) []int {
validIndices := make([]int, 0, len(entries))
// Replace 模式下只检查批次内重复
batchQuestions := make(map[string]int) // value 为首次出现的索引
for i, entry := range entries {
// 验证条目基本格式
if err := validateFAQEntryPayloadBasic(&entry); err != nil {
progress.FailedCount++
progress.FailedEntries = append(progress.FailedEntries, buildFAQFailedEntry(i, err.Error(), &entry))
continue
}
standardQ := strings.TrimSpace(entry.StandardQuestion)
// 检查标准问是否与同批次重复
if firstIdx, exists := batchQuestions[standardQ]; exists {
progress.FailedCount++
progress.FailedEntries = append(progress.FailedEntries, buildFAQFailedEntry(i, fmt.Sprintf("标准问与批次内第 %d 条重复", firstIdx+1), &entry))
continue
}
// 检查相似问是否有重复
hasDuplicate := false
for _, q := range entry.SimilarQuestions {
q = strings.TrimSpace(q)
if q == "" {
continue
}
if firstIdx, exists := batchQuestions[q]; exists {
progress.FailedCount++
progress.FailedEntries = append(progress.FailedEntries, buildFAQFailedEntry(i, fmt.Sprintf("相似问 \"%s\" 与批次内第 %d 条重复", q, firstIdx+1), &entry))
hasDuplicate = true
break
}
}
if hasDuplicate {
continue
}
// 将当前条目的标准问和相似问加入批次集合
batchQuestions[standardQ] = i
for _, q := range entry.SimilarQuestions {
q = strings.TrimSpace(q)
if q != "" {
batchQuestions[q] = i
}
}
// 记录通过验证的条目索引
validIndices = append(validIndices, i)
// 定期更新进度消息(验证阶段不更新 Processed
if (i+1)%100 == 0 {
progress.Message = fmt.Sprintf("正在验证条目 %d/%d...", i+1, len(entries))
progress.UpdatedAt = time.Now().Unix()
if err := s.saveFAQImportProgress(ctx, progress); err != nil {
logger.Warnf(ctx, "Failed to update FAQ dry run progress: %v", err)
}
}
}
return validIndices
}
// validateFAQEntryPayloadBasic 验证 FAQ 条目的基本格式
func validateFAQEntryPayloadBasic(entry *types.FAQEntryPayload) error {
if entry == nil {
return fmt.Errorf("条目不能为空")
}
standardQ := strings.TrimSpace(entry.StandardQuestion)
if standardQ == "" {
return fmt.Errorf("标准问不能为空")
}
if len(entry.Answers) == 0 {
return fmt.Errorf("答案不能为空")
}
hasValidAnswer := false
for _, a := range entry.Answers {
if strings.TrimSpace(a) != "" {
hasValidAnswer = true
break
}
}
if !hasValidAnswer {
return fmt.Errorf("答案不能全为空")
}
return nil
}
// calculateAppendOperations 计算Append模式下需要处理的条目跳过已存在且内容相同的条目
// 同时过滤掉标准问或相似问与同批次或已有知识库中重复的条目
func (s *knowledgeService) calculateAppendOperations(ctx context.Context,
tenantID uint64, kbID string, entries []types.FAQEntryPayload,
) ([]types.FAQEntryPayload, int, error) {
if len(entries) == 0 {
return []types.FAQEntryPayload{}, 0, nil
}
// 1. 查询知识库中已有的所有FAQ chunks的metadata
existingChunks, err := s.chunkRepo.ListAllFAQChunksWithMetadataByKnowledgeBaseID(ctx, tenantID, kbID)
if err != nil {
return nil, 0, fmt.Errorf("failed to list existing FAQ chunks: %w", err)
}
// 2. 构建已存在的标准问和相似问集合
existingQuestions := make(map[string]bool)
for _, chunk := range existingChunks {
meta, err := chunk.FAQMetadata()
if err != nil || meta == nil {
continue
}
// 添加标准问
if meta.StandardQuestion != "" {
existingQuestions[meta.StandardQuestion] = true
}
// 添加相似问
for _, q := range meta.SimilarQuestions {
if q != "" {
existingQuestions[q] = true
}
}
}
// 3. 构建当前批次的标准问和相似问集合(用于批次内去重)
batchQuestions := make(map[string]bool)
entriesToProcess := make([]types.FAQEntryPayload, 0, len(entries))
skippedCount := 0
for _, entry := range entries {
meta, err := sanitizeFAQEntryPayload(&entry)
if err != nil {
// 跳过无效条目
skippedCount++
logger.Warnf(ctx, "Skipping invalid FAQ entry: %v", err)
continue
}
// 检查标准问是否重复(与已有或同批次)
if existingQuestions[meta.StandardQuestion] || batchQuestions[meta.StandardQuestion] {
skippedCount++
logger.Infof(ctx, "Skipping FAQ entry with duplicate standard question: %s", meta.StandardQuestion)
continue
}
// 检查相似问是否有重复(与已有或同批次)
hasDuplicateSimilar := false
for _, q := range meta.SimilarQuestions {
if existingQuestions[q] || batchQuestions[q] {
hasDuplicateSimilar = true
logger.Infof(ctx, "Skipping FAQ entry with duplicate similar question: %s (standard: %s)", q, meta.StandardQuestion)
break
}
}
if hasDuplicateSimilar {
skippedCount++
continue
}
// 将当前条目的标准问和相似问加入批次集合
batchQuestions[meta.StandardQuestion] = true
for _, q := range meta.SimilarQuestions {
batchQuestions[q] = true
}
entriesToProcess = append(entriesToProcess, entry)
}
return entriesToProcess, skippedCount, nil
}
// calculateReplaceOperations 计算Replace模式下需要删除、创建、更新的条目
// 同时过滤掉同批次内标准问或相似问重复的条目
func (s *knowledgeService) calculateReplaceOperations(ctx context.Context,
tenantID uint64, knowledgeID string, newEntries []types.FAQEntryPayload,
) ([]types.FAQEntryPayload, []*types.Chunk, int, error) {
// 获取 kbID 用于解析 tag
var kbID string
if len(newEntries) > 0 {
// 从 knowledgeID 获取 kbID
knowledge, err := s.repo.GetKnowledgeByID(ctx, tenantID, knowledgeID)
if err != nil {
return nil, nil, 0, fmt.Errorf("failed to get knowledge: %w", err)
}
if knowledge != nil {
kbID = knowledge.KnowledgeBaseID
}
}
// 计算所有新条目的 content hash并同时构建 hash 到 entry 的映射
type entryWithHash struct {
entry types.FAQEntryPayload
hash string
meta *types.FAQChunkMetadata
}
entriesWithHash := make([]entryWithHash, 0, len(newEntries))
newHashSet := make(map[string]bool)
// 用于批次内标准问和相似问去重
batchQuestions := make(map[string]bool)
batchSkippedCount := 0
for _, entry := range newEntries {
meta, err := sanitizeFAQEntryPayload(&entry)
if err != nil {
batchSkippedCount++
logger.Warnf(ctx, "Skipping invalid FAQ entry in replace mode: %v", err)
continue
}
// 检查标准问是否在同批次中重复
if batchQuestions[meta.StandardQuestion] {
batchSkippedCount++
logger.Infof(ctx, "Skipping FAQ entry with duplicate standard question in batch: %s", meta.StandardQuestion)
continue
}
// 检查相似问是否在同批次中重复
hasDuplicateSimilar := false
for _, q := range meta.SimilarQuestions {
if batchQuestions[q] {
hasDuplicateSimilar = true
logger.Infof(ctx, "Skipping FAQ entry with duplicate similar question in batch: %s (standard: %s)", q, meta.StandardQuestion)
break
}
}
if hasDuplicateSimilar {
batchSkippedCount++
continue
}
// 将当前条目的标准问和相似问加入批次集合
batchQuestions[meta.StandardQuestion] = true
for _, q := range meta.SimilarQuestions {
batchQuestions[q] = true
}
hash := types.CalculateFAQContentHash(meta)
if hash != "" {
entriesWithHash = append(entriesWithHash, entryWithHash{entry: entry, hash: hash, meta: meta})
newHashSet[hash] = true
}
}
// 查询所有已存在的chunks
allExistingChunks, err := s.chunkRepo.ListAllFAQChunksByKnowledgeID(ctx, tenantID, knowledgeID)
if err != nil {
return nil, nil, 0, fmt.Errorf("failed to list existing chunks: %w", err)
}
// 在内存中过滤出匹配新条目hash的chunks并构建map
existingHashMap := make(map[string]*types.Chunk)
for _, chunk := range allExistingChunks {
if chunk.ContentHash != "" && newHashSet[chunk.ContentHash] {
existingHashMap[chunk.ContentHash] = chunk
}
}
// 计算需要删除的chunks数据库中有但新批次中没有的或hash不匹配的
chunksToDelete := make([]*types.Chunk, 0)
for _, chunk := range allExistingChunks {
if chunk.ContentHash == "" {
// 如果没有hash需要删除可能是旧数据
chunksToDelete = append(chunksToDelete, chunk)
} else if !newHashSet[chunk.ContentHash] {
// hash不在新条目中需要删除
chunksToDelete = append(chunksToDelete, chunk)
}
}
// 计算需要创建的条目利用已经计算好的hash避免重复计算
entriesToProcess := make([]types.FAQEntryPayload, 0, len(entriesWithHash))
skippedCount := batchSkippedCount
for _, ewh := range entriesWithHash {
existingChunk := existingHashMap[ewh.hash]
if existingChunk != nil {
// hash 匹配,检查 tag 是否变化
newTagID, err := s.resolveTagID(ctx, kbID, &ewh.entry)
if err != nil {
logger.Warnf(ctx, "Failed to resolve tag for entry, treating as new: %v", err)
entriesToProcess = append(entriesToProcess, ewh.entry)
continue
}
if existingChunk.TagID != newTagID {
// tag 变化了,需要删除旧的并创建新的
logger.Infof(ctx, "FAQ entry tag changed from %s to %s, will update", existingChunk.TagID, newTagID)
chunksToDelete = append(chunksToDelete, existingChunk)
entriesToProcess = append(entriesToProcess, ewh.entry)
} else {
// hash 和 tag 都相同,跳过
skippedCount++
}
continue
}
// hash不匹配或不存在需要创建
entriesToProcess = append(entriesToProcess, ewh.entry)
}
return entriesToProcess, chunksToDelete, skippedCount, nil
}
// executeFAQImport 执行实际的FAQ导入逻辑
func (s *knowledgeService) executeFAQImport(ctx context.Context, taskID string, kbID string,
payload *types.FAQBatchUpsertPayload, tenantID uint64, processedCount int,
progress *types.FAQImportProgress,
) (err error) {
// 保存知识库和embedding模型信息用于清理索引
var kb *types.KnowledgeBase
var embeddingModel embedding.Embedder
totalEntries := len(payload.Entries) + processedCount
// Recovery机制如果发生任何错误或panic回滚所有已创建的chunks和索引数据
defer func() {
// 捕获panic
if r := recover(); r != nil {
buf := make([]byte, 8192)
n := runtime.Stack(buf, false)
stack := string(buf[:n])
logger.Errorf(ctx, "FAQ import task %s panicked: %v\n%s", taskID, r, stack)
err = fmt.Errorf("panic during FAQ import: %v", r)
}
}()
kb, err = s.validateFAQKnowledgeBase(ctx, kbID)
if err != nil {
return err
}
kb.EnsureDefaults()
// 获取embedding模型用于后续清理索引
embeddingModel, err = s.modelService.GetEmbeddingModel(ctx, kb.EmbeddingModelID)
if err != nil {
return fmt.Errorf("failed to get embedding model: %w", err)
}
faqKnowledge, err := s.ensureFAQKnowledge(ctx, tenantID, kb)
if err != nil {
return err
}
// 获取索引模式
indexMode := types.FAQIndexModeQuestionOnly
if kb.FAQConfig != nil && kb.FAQConfig.IndexMode != "" {
indexMode = kb.FAQConfig.IndexMode
}
// 增量更新逻辑:计算需要处理的条目
var entriesToProcess []types.FAQEntryPayload
var chunksToDelete []*types.Chunk
var skippedCount int
if payload.Mode == types.FAQBatchModeReplace {
// Replace模式计算需要删除、创建、更新的条目
entriesToProcess, chunksToDelete, skippedCount, err = s.calculateReplaceOperations(
ctx,
tenantID,
faqKnowledge.ID,
payload.Entries,
)
if err != nil {
return fmt.Errorf("failed to calculate replace operations: %w", err)
}
// 删除需要删除的chunks包括需要更新的旧chunks
if len(chunksToDelete) > 0 {
chunkIDsToDelete := make([]string, 0, len(chunksToDelete))
for _, chunk := range chunksToDelete {
chunkIDsToDelete = append(chunkIDsToDelete, chunk.ID)
}
if err := s.chunkRepo.DeleteChunks(ctx, tenantID, chunkIDsToDelete); err != nil {
return fmt.Errorf("failed to delete chunks: %w", err)
}
// 删除索引
if err := s.deleteFAQChunkVectors(ctx, kb, faqKnowledge, chunksToDelete); err != nil {
return fmt.Errorf("failed to delete chunk vectors: %w", err)
}
logger.Infof(ctx, "FAQ import task %s: deleted %d chunks (including updates)", taskID, len(chunksToDelete))
}
} else {
// Append模式查询已存在的条目跳过未变化的
entriesToProcess, skippedCount, err = s.calculateAppendOperations(ctx, tenantID, kb.ID, payload.Entries)
if err != nil {
return fmt.Errorf("failed to calculate append operations: %w", err)
}
}
logger.Infof(
ctx,
"FAQ import task %s: total entries: %d, to process: %d, skipped: %d",
taskID,
len(payload.Entries),
len(entriesToProcess),
skippedCount,
)
// 如果没有需要处理的条目,直接返回
if len(entriesToProcess) == 0 {
logger.Infof(ctx, "FAQ import task %s: no entries to process, all skipped", taskID)
return nil
}
// 分批处理需要创建的条目
remainingEntries := len(entriesToProcess)
totalStartTime := time.Now()
actualProcessed := skippedCount + processedCount
logger.Infof(
ctx,
"FAQ import task %s: starting batch processing, remaining entries: %d, total entries: %d, batch size: %d",
taskID,
remainingEntries,
totalEntries,
faqImportBatchSize,
)
for i := 0; i < remainingEntries; i += faqImportBatchSize {
batchStartTime := time.Now()
end := i + faqImportBatchSize
if end > remainingEntries {
end = remainingEntries
}
batch := entriesToProcess[i:end]
logger.Infof(ctx, "FAQ import task %s: processing batch %d-%d (%d entries)", taskID, i+1, end, len(batch))
// 构建chunks
buildStartTime := time.Now()
chunks := make([]*types.Chunk, 0, len(batch))
chunkIds := make([]string, 0, len(batch))
for idx, entry := range batch {
meta, err := sanitizeFAQEntryPayload(&entry)
if err != nil {
logger.ErrorWithFields(ctx, err, map[string]interface{}{
"entry": entry,
"task_id": taskID,
})
return fmt.Errorf("failed to sanitize entry at index %d: %w", i+idx, err)
}
// 解析 TagID
tagID, err := s.resolveTagID(ctx, kbID, &entry)
if err != nil {
logger.ErrorWithFields(ctx, err, map[string]interface{}{
"entry": entry,
"task_id": taskID,
})
return fmt.Errorf("failed to resolve tag for entry at index %d: %w", i+idx, err)
}
isEnabled := true
if entry.IsEnabled != nil {
isEnabled = *entry.IsEnabled
}
// ChunkIndex计算startChunkIndex + (i+idx) + initialProcessed
chunk := &types.Chunk{
ID: uuid.New().String(),
TenantID: tenantID,
KnowledgeID: faqKnowledge.ID,
KnowledgeBaseID: kb.ID,
Content: buildFAQChunkContent(meta, indexMode),
// ChunkIndex: 0,
IsEnabled: isEnabled,
ChunkType: types.ChunkTypeFAQ,
TagID: tagID, // 使用解析后的 TagID
Status: int(types.ChunkStatusStored), // store but not indexed
}
// 如果指定了 ID用于数据迁移设置 SeqID
if entry.ID != nil && *entry.ID > 0 {
chunk.SeqID = *entry.ID
}
if err := chunk.SetFAQMetadata(meta); err != nil {
return fmt.Errorf("failed to set FAQ metadata: %w", err)
}
chunks = append(chunks, chunk)
chunkIds = append(chunkIds, chunk.ID)
}
buildDuration := time.Since(buildStartTime)
logger.Debugf(ctx, "FAQ import task %s: batch %d-%d built %d chunks in %v, chunk IDs: %v",
taskID, i+1, end, len(chunks), buildDuration, chunkIds)
// 创建chunks
createStartTime := time.Now()
if err := s.chunkService.CreateChunks(ctx, chunks); err != nil {
return fmt.Errorf("failed to create chunks: %w", err)
}
createDuration := time.Since(createStartTime)
logger.Infof(
ctx,
"FAQ import task %s: batch %d-%d created %d chunks in %v",
taskID,
i+1,
end,
len(chunks),
createDuration,
)
// 索引chunks
indexStartTime := time.Now()
// 注意如果索引失败defer中的recovery机制会自动回滚已创建的chunks和索引数据
if err := s.indexFAQChunks(ctx, kb, faqKnowledge, chunks, embeddingModel, true, false); err != nil {
return fmt.Errorf("failed to index chunks: %w", err)
}
indexDuration := time.Since(indexStartTime)
logger.Infof(
ctx,
"FAQ import task %s: batch %d-%d indexed %d chunks in %v",
taskID,
i+1,
end,
len(chunks),
indexDuration,
)
// 更新chunks的Status为已索引
chunksToUpdate := make([]*types.Chunk, 0, len(chunks))
for _, chunk := range chunks {
chunk.Status = int(types.ChunkStatusIndexed) // indexed
chunksToUpdate = append(chunksToUpdate, chunk)
}
if err := s.chunkService.UpdateChunks(ctx, chunksToUpdate); err != nil {
return fmt.Errorf("failed to update chunks status: %w", err)
}
// 收集成功条目信息
for idx, chunk := range chunks {
entryIdx := i + idx + processedCount // 原始条目索引
meta, _ := chunk.FAQMetadata()
standardQ := ""
if meta != nil {
standardQ = meta.StandardQuestion
}
// 获取 tag info
var tagID int64
tagName := ""
if chunk.TagID != "" {
if tag, err := s.tagRepo.GetByID(ctx, tenantID, chunk.TagID); err == nil && tag != nil {
tagID = tag.SeqID
tagName = tag.Name
}
}
progress.SuccessEntries = append(progress.SuccessEntries, types.FAQSuccessEntry{
Index: entryIdx,
SeqID: chunk.SeqID,
TagID: tagID,
TagName: tagName,
StandardQuestion: standardQ,
})
}
actualProcessed += len(batch)
// 更新任务进度
progress := int(float64(actualProcessed) / float64(totalEntries) * 100)
if err := s.updateFAQImportProgressStatus(ctx, taskID, types.FAQImportStatusProcessing, progress, totalEntries, actualProcessed, fmt.Sprintf("正在处理第 %d/%d 条", actualProcessed, totalEntries), ""); err != nil {
logger.Errorf(ctx, "Failed to update task progress: %v", err)
}
batchDuration := time.Since(batchStartTime)
logger.Infof(
ctx,
"FAQ import task %s: batch %d-%d completed in %v (build: %v, create: %v, index: %v), total progress: %d/%d (%d%%)",
taskID,
i+1,
end,
batchDuration,
buildDuration,
createDuration,
indexDuration,
actualProcessed,
totalEntries,
progress,
)
}
totalDuration := time.Since(totalStartTime)
logger.Infof(
ctx,
"FAQ import task %s: all batches completed, processed: %d entries (skipped: %d) in %v, avg: %v per entry",
taskID,
actualProcessed,
skippedCount,
totalDuration,
totalDuration/time.Duration(actualProcessed),
)
return nil
}
// CreateFAQEntry creates a single FAQ entry synchronously.
func (s *knowledgeService) CreateFAQEntry(ctx context.Context,
kbID string, payload *types.FAQEntryPayload,
) (*types.FAQEntry, error) {
if payload == nil {
return nil, werrors.NewBadRequestError("请求体不能为空")
}
kb, err := s.validateFAQKnowledgeBase(ctx, kbID)
if err != nil {
return nil, err
}
kb.EnsureDefaults()
tenantID := ctx.Value(types.TenantIDContextKey).(uint64)
// 验证并清理输入
meta, err := sanitizeFAQEntryPayload(payload)
if err != nil {
return nil, err
}
// 解析 TagID
tagID, err := s.resolveTagID(ctx, kbID, payload)
if err != nil {
return nil, err
}
// 检查标准问和相似问是否与其他条目重复
if err := s.checkFAQQuestionDuplicate(ctx, tenantID, kb.ID, "", meta); err != nil {
return nil, err
}
// 确保FAQ Knowledge存在
faqKnowledge, err := s.ensureFAQKnowledge(ctx, tenantID, kb)
if err != nil {
return nil, fmt.Errorf("failed to ensure FAQ knowledge: %w", err)
}
// 获取索引模式
indexMode := types.FAQIndexModeQuestionOnly
if kb.FAQConfig != nil && kb.FAQConfig.IndexMode != "" {
indexMode = kb.FAQConfig.IndexMode
}
// 获取embedding模型
embeddingModel, err := s.modelService.GetEmbeddingModel(ctx, kb.EmbeddingModelID)
if err != nil {
return nil, fmt.Errorf("failed to get embedding model: %w", err)
}
// 创建chunk
isEnabled := true
if payload.IsEnabled != nil {
isEnabled = *payload.IsEnabled
}
// 默认可推荐
flags := types.ChunkFlagRecommended
if payload.IsRecommended != nil && !*payload.IsRecommended {
flags = 0
}
chunk := &types.Chunk{
ID: uuid.New().String(),
TenantID: tenantID,
KnowledgeID: faqKnowledge.ID,
KnowledgeBaseID: kb.ID,
Content: buildFAQChunkContent(meta, indexMode),
IsEnabled: isEnabled,
Flags: flags,
ChunkType: types.ChunkTypeFAQ,
TagID: tagID, // 使用解析后的 TagID
Status: int(types.ChunkStatusStored),
}
// 如果指定了 ID用于数据迁移设置 SeqID
if payload.ID != nil && *payload.ID > 0 {
chunk.SeqID = *payload.ID
}
if err := chunk.SetFAQMetadata(meta); err != nil {
return nil, fmt.Errorf("failed to set FAQ metadata: %w", err)
}
// 保存chunk
if err := s.chunkService.CreateChunks(ctx, []*types.Chunk{chunk}); err != nil {
return nil, fmt.Errorf("failed to create chunk: %w", err)
}
// 索引chunk
if err := s.indexFAQChunks(ctx, kb, faqKnowledge, []*types.Chunk{chunk}, embeddingModel, true, false); err != nil {
// 如果索引失败删除已创建的chunk
_ = s.chunkService.DeleteChunk(ctx, chunk.ID)
return nil, fmt.Errorf("failed to index chunk: %w", err)
}
// 更新chunk状态为已索引
chunk.Status = int(types.ChunkStatusIndexed)
if err := s.chunkService.UpdateChunk(ctx, chunk); err != nil {
return nil, fmt.Errorf("failed to update chunk status: %w", err)
}
// Build tag seq_id map for conversion
tagSeqIDMap := make(map[string]int64)
if chunk.TagID != "" {
tag, tagErr := s.tagRepo.GetByID(ctx, tenantID, chunk.TagID)
if tagErr == nil && tag != nil {
tagSeqIDMap[tag.ID] = tag.SeqID
}
}
// 转换为FAQEntry返回
entry, err := s.chunkToFAQEntry(chunk, kb, tagSeqIDMap)
if err != nil {
return nil, err
}
// 查询TagName
if chunk.TagID != "" {
tag, tagErr := s.tagRepo.GetByID(ctx, tenantID, chunk.TagID)
if tagErr == nil && tag != nil {
entry.TagName = tag.Name
}
}
return entry, nil
}
// GetFAQEntry retrieves a single FAQ entry by seq_id.
func (s *knowledgeService) GetFAQEntry(ctx context.Context,
kbID string, entrySeqID int64,
) (*types.FAQEntry, error) {
if entrySeqID <= 0 {
return nil, werrors.NewBadRequestError("条目ID不能为空")
}
kb, err := s.validateFAQKnowledgeBase(ctx, kbID)
if err != nil {
return nil, err
}
kb.EnsureDefaults()
tenantID := ctx.Value(types.TenantIDContextKey).(uint64)
// 获取chunk by seq_id
chunk, err := s.chunkRepo.GetChunkBySeqID(ctx, tenantID, entrySeqID)
if err != nil {
return nil, werrors.NewNotFoundError("FAQ条目不存在")
}
// 验证chunk属于当前知识库
if chunk.KnowledgeBaseID != kb.ID || chunk.TenantID != tenantID {
return nil, werrors.NewNotFoundError("FAQ条目不存在")
}
// 验证是FAQ类型
if chunk.ChunkType != types.ChunkTypeFAQ {
return nil, werrors.NewNotFoundError("FAQ条目不存在")
}
// Build tag seq_id map for conversion
tagSeqIDMap := make(map[string]int64)
if chunk.TagID != "" {
tag, tagErr := s.tagRepo.GetByID(ctx, tenantID, chunk.TagID)
if tagErr == nil && tag != nil {
tagSeqIDMap[tag.ID] = tag.SeqID
}
}
// 转换为FAQEntry返回
entry, err := s.chunkToFAQEntry(chunk, kb, tagSeqIDMap)
if err != nil {
return nil, err
}
// 查询TagName
if chunk.TagID != "" {
tag, tagErr := s.tagRepo.GetByID(ctx, tenantID, chunk.TagID)
if tagErr == nil && tag != nil {
entry.TagName = tag.Name
}
}
return entry, nil
}
// UpdateFAQEntry updates a single FAQ entry.
func (s *knowledgeService) UpdateFAQEntry(ctx context.Context,
kbID string, entrySeqID int64, payload *types.FAQEntryPayload,
) (*types.FAQEntry, error) {
if payload == nil {
return nil, werrors.NewBadRequestError("请求体不能为空")
}
kb, err := s.validateFAQKnowledgeBase(ctx, kbID)
if err != nil {
return nil, err
}
kb.EnsureDefaults()
tenantID := ctx.Value(types.TenantIDContextKey).(uint64)
chunk, err := s.chunkRepo.GetChunkBySeqID(ctx, tenantID, entrySeqID)
if err != nil {
return nil, werrors.NewNotFoundError("FAQ条目不存在")
}
if chunk.KnowledgeBaseID != kb.ID {
return nil, werrors.NewForbiddenError("无权操作该 FAQ 条目")
}
if chunk.ChunkType != types.ChunkTypeFAQ {
return nil, werrors.NewBadRequestError("仅支持更新 FAQ 条目")
}
meta, err := sanitizeFAQEntryPayload(payload)
if err != nil {
return nil, err
}
// 检查标准问和相似问是否与其他条目重复
if err := s.checkFAQQuestionDuplicate(ctx, tenantID, kb.ID, chunk.ID, meta); err != nil {
return nil, err
}
// 获取旧的相似问列表,用于增量更新
var oldSimilarQuestions []string
var oldStandardQuestion string
var oldAnswers []string
questionIndexMode := types.FAQQuestionIndexModeCombined
if kb.FAQConfig != nil && kb.FAQConfig.QuestionIndexMode != "" {
questionIndexMode = kb.FAQConfig.QuestionIndexMode
}
if existing, err := chunk.FAQMetadata(); err == nil && existing != nil {
meta.Version = existing.Version + 1
// 保存旧的内容用于增量比较
if questionIndexMode == types.FAQQuestionIndexModeSeparate {
oldSimilarQuestions = existing.SimilarQuestions
oldStandardQuestion = existing.StandardQuestion
oldAnswers = existing.Answers
}
}
if err := chunk.SetFAQMetadata(meta); err != nil {
return nil, err
}
// 获取索引模式
indexMode := types.FAQIndexModeQuestionOnly
if kb.FAQConfig != nil && kb.FAQConfig.IndexMode != "" {
indexMode = kb.FAQConfig.IndexMode
}
chunk.Content = buildFAQChunkContent(meta, indexMode)
// Convert tag seq_id to UUID
if payload.TagID > 0 {
tag, tagErr := s.tagRepo.GetBySeqID(ctx, tenantID, payload.TagID)
if tagErr != nil {
return nil, werrors.NewNotFoundError("标签不存在")
}
chunk.TagID = tag.ID
} else {
chunk.TagID = ""
}
if payload.IsEnabled != nil {
chunk.IsEnabled = *payload.IsEnabled
}
// 处理推荐状态
if payload.IsRecommended != nil {
if *payload.IsRecommended {
chunk.Flags = chunk.Flags.SetFlag(types.ChunkFlagRecommended)
} else {
chunk.Flags = chunk.Flags.ClearFlag(types.ChunkFlagRecommended)
}
}
chunk.UpdatedAt = time.Now()
if err := s.chunkService.UpdateChunk(ctx, chunk); err != nil {
return nil, err
}
// Note: We don't need to call BatchUpdateChunkEnabledStatus here because
// indexFAQChunks will delete old vectors and re-insert with the latest chunk data
// (including the updated is_enabled status). Calling both would cause version conflicts.
faqKnowledge, err := s.repo.GetKnowledgeByID(ctx, tenantID, chunk.KnowledgeID)
if err != nil {
return nil, err
}
embeddingModel, err := s.modelService.GetEmbeddingModel(ctx, kb.EmbeddingModelID)
if err != nil {
return nil, err
}
// 增量索引优化:只对变化的内容进行索引操作
if questionIndexMode == types.FAQQuestionIndexModeSeparate && len(oldSimilarQuestions) > 0 {
// 分别索引模式下的增量更新
if err := s.incrementalIndexFAQEntry(ctx, kb, faqKnowledge, chunk, embeddingModel,
oldStandardQuestion, oldSimilarQuestions, oldAnswers, meta); err != nil {
return nil, err
}
} else {
// Combined 模式或首次创建,使用全量索引
// 增量删除:只删除被移除的相似问索引
oldSimilarQuestionCount := len(oldSimilarQuestions)
newSimilarQuestionCount := len(meta.SimilarQuestions)
if questionIndexMode == types.FAQQuestionIndexModeSeparate && oldSimilarQuestionCount > newSimilarQuestionCount {
tenantInfo := ctx.Value(types.TenantInfoContextKey).(*types.Tenant)
retrieveEngine, engineErr := retriever.NewCompositeRetrieveEngine(s.retrieveEngine, tenantInfo.GetEffectiveEngines())
if engineErr == nil {
sourceIDsToDelete := make([]string, 0, oldSimilarQuestionCount-newSimilarQuestionCount)
for i := newSimilarQuestionCount; i < oldSimilarQuestionCount; i++ {
sourceIDsToDelete = append(sourceIDsToDelete, fmt.Sprintf("%s-%d", chunk.ID, i))
}
if len(sourceIDsToDelete) > 0 {
logger.Debugf(ctx, "UpdateFAQEntry: incremental delete %d obsolete source IDs", len(sourceIDsToDelete))
if delErr := retrieveEngine.DeleteBySourceIDList(ctx, sourceIDsToDelete, embeddingModel.GetDimensions(), types.KnowledgeTypeFAQ); delErr != nil {
logger.Warnf(ctx, "UpdateFAQEntry: failed to delete obsolete source IDs: %v", delErr)
}
}
}
}
// 使用 needDelete=false因为 EFPutDocument 会自动覆盖相同 SourceID 的文档
if err := s.indexFAQChunks(ctx, kb, faqKnowledge, []*types.Chunk{chunk}, embeddingModel, false, false); err != nil {
return nil, err
}
}
// Build tag seq_id map for conversion
tagSeqIDMap := make(map[string]int64)
if chunk.TagID != "" {
tag, tagErr := s.tagRepo.GetByID(ctx, tenantID, chunk.TagID)
if tagErr == nil && tag != nil {
tagSeqIDMap[tag.ID] = tag.SeqID
}
}
// 转换为FAQEntry返回
entry, err := s.chunkToFAQEntry(chunk, kb, tagSeqIDMap)
if err != nil {
return nil, err
}
// 查询TagName
if chunk.TagID != "" {
tag, tagErr := s.tagRepo.GetByID(ctx, tenantID, chunk.TagID)
if tagErr == nil && tag != nil {
entry.TagName = tag.Name
}
}
return entry, nil
}
// AddSimilarQuestions adds similar questions to a FAQ entry.
// This will append the new questions to the existing similar questions list.
func (s *knowledgeService) AddSimilarQuestions(ctx context.Context,
kbID string, entrySeqID int64, questions []string,
) (*types.FAQEntry, error) {
if len(questions) == 0 {
return nil, werrors.NewBadRequestError("相似问列表不能为空")
}
kb, err := s.validateFAQKnowledgeBase(ctx, kbID)
if err != nil {
return nil, err
}
kb.EnsureDefaults()
tenantID := ctx.Value(types.TenantIDContextKey).(uint64)
// Get existing FAQ entry
chunk, err := s.chunkRepo.GetChunkBySeqID(ctx, tenantID, entrySeqID)
if err != nil {
return nil, werrors.NewNotFoundError("FAQ条目不存在")
}
if chunk.KnowledgeBaseID != kb.ID {
return nil, werrors.NewForbiddenError("无权操作该 FAQ 条目")
}
if chunk.ChunkType != types.ChunkTypeFAQ {
return nil, werrors.NewBadRequestError("仅支持更新 FAQ 条目")
}
// Get existing metadata
meta, err := chunk.FAQMetadata()
if err != nil || meta == nil {
return nil, werrors.NewBadRequestError("获取 FAQ 元数据失败")
}
// Deduplicate and sanitize new questions
existingSet := make(map[string]struct{})
for _, q := range meta.SimilarQuestions {
existingSet[q] = struct{}{}
}
// Also add standard question to prevent duplicates
existingSet[meta.StandardQuestion] = struct{}{}
newQuestions := make([]string, 0, len(questions))
for _, q := range questions {
q = strings.TrimSpace(q)
if q == "" {
continue
}
if _, exists := existingSet[q]; exists {
continue
}
existingSet[q] = struct{}{}
newQuestions = append(newQuestions, q)
}
if len(newQuestions) == 0 {
// No new questions to add, return current entry
tagSeqIDMap := make(map[string]int64)
if chunk.TagID != "" {
tag, tagErr := s.tagRepo.GetByID(ctx, tenantID, chunk.TagID)
if tagErr == nil && tag != nil {
tagSeqIDMap[tag.ID] = tag.SeqID
}
}
return s.chunkToFAQEntry(chunk, kb, tagSeqIDMap)
}
// Check for duplicates with other entries
tempMeta := &types.FAQChunkMetadata{
StandardQuestion: meta.StandardQuestion,
SimilarQuestions: append(meta.SimilarQuestions, newQuestions...),
}
if err := s.checkFAQQuestionDuplicate(ctx, tenantID, kb.ID, chunk.ID, tempMeta); err != nil {
return nil, err
}
// Update metadata
oldSimilarQuestions := meta.SimilarQuestions
meta.SimilarQuestions = append(meta.SimilarQuestions, newQuestions...)
meta.Version++
if err := chunk.SetFAQMetadata(meta); err != nil {
return nil, err
}
// Update chunk content
indexMode := types.FAQIndexModeQuestionOnly
if kb.FAQConfig != nil && kb.FAQConfig.IndexMode != "" {
indexMode = kb.FAQConfig.IndexMode
}
chunk.Content = buildFAQChunkContent(meta, indexMode)
chunk.UpdatedAt = time.Now()
if err := s.chunkService.UpdateChunk(ctx, chunk); err != nil {
return nil, err
}
// Index new similar questions
faqKnowledge, err := s.repo.GetKnowledgeByID(ctx, tenantID, chunk.KnowledgeID)
if err != nil {
return nil, err
}
embeddingModel, err := s.modelService.GetEmbeddingModel(ctx, kb.EmbeddingModelID)
if err != nil {
return nil, err
}
questionIndexMode := types.FAQQuestionIndexModeCombined
if kb.FAQConfig != nil && kb.FAQConfig.QuestionIndexMode != "" {
questionIndexMode = kb.FAQConfig.QuestionIndexMode
}
if questionIndexMode == types.FAQQuestionIndexModeSeparate {
// Only index the new similar questions
if err := s.incrementalIndexFAQEntry(ctx, kb, faqKnowledge, chunk, embeddingModel,
meta.StandardQuestion, oldSimilarQuestions, meta.Answers, meta); err != nil {
return nil, err
}
} else {
// Combined mode, re-index the whole entry
if err := s.indexFAQChunks(ctx, kb, faqKnowledge, []*types.Chunk{chunk}, embeddingModel, false, false); err != nil {
return nil, err
}
}
// Build response
tagSeqIDMap := make(map[string]int64)
if chunk.TagID != "" {
tag, tagErr := s.tagRepo.GetByID(ctx, tenantID, chunk.TagID)
if tagErr == nil && tag != nil {
tagSeqIDMap[tag.ID] = tag.SeqID
}
}
entry, err := s.chunkToFAQEntry(chunk, kb, tagSeqIDMap)
if err != nil {
return nil, err
}
if chunk.TagID != "" {
tag, tagErr := s.tagRepo.GetByID(ctx, tenantID, chunk.TagID)
if tagErr == nil && tag != nil {
entry.TagName = tag.Name
}
}
return entry, nil
}
// UpdateFAQEntryStatus updates enable status for a FAQ entry.
func (s *knowledgeService) UpdateFAQEntryStatus(ctx context.Context,
kbID string, entryID string, isEnabled bool,
) error {
kb, err := s.validateFAQKnowledgeBase(ctx, kbID)
if err != nil {
return err
}
tenantID := ctx.Value(types.TenantIDContextKey).(uint64)
chunk, err := s.chunkRepo.GetChunkByID(ctx, tenantID, entryID)
if err != nil {
return err
}
if chunk.KnowledgeBaseID != kb.ID || chunk.ChunkType != types.ChunkTypeFAQ {
return werrors.NewBadRequestError("仅支持更新 FAQ 条目")
}
if chunk.IsEnabled == isEnabled {
return nil
}
chunk.IsEnabled = isEnabled
chunk.UpdatedAt = time.Now()
if err := s.chunkService.UpdateChunk(ctx, chunk); err != nil {
return err
}
// Sync update to retriever engines
chunkStatusMap := map[string]bool{chunk.ID: isEnabled}
tenantInfo := ctx.Value(types.TenantInfoContextKey).(*types.Tenant)
retrieveEngine, err := retriever.NewCompositeRetrieveEngine(s.retrieveEngine, tenantInfo.GetEffectiveEngines())
if err != nil {
return err
}
if err := retrieveEngine.BatchUpdateChunkEnabledStatus(ctx, chunkStatusMap); err != nil {
return err
}
return nil
}
// UpdateFAQEntryFieldsBatch updates multiple fields for FAQ entries in batch.
// This is the unified API for batch updating FAQ entry fields.
// Supports two modes:
// 1. By entry seq_id: use ByID field
// 2. By Tag seq_id: use ByTag field to apply the same update to all entries under a tag
func (s *knowledgeService) UpdateFAQEntryFieldsBatch(ctx context.Context,
kbID string, req *types.FAQEntryFieldsBatchUpdate,
) error {
if req == nil || (len(req.ByID) == 0 && len(req.ByTag) == 0) {
return nil
}
kb, err := s.validateFAQKnowledgeBase(ctx, kbID)
if err != nil {
return err
}
tenantID := ctx.Value(types.TenantIDContextKey).(uint64)
enabledUpdates := make(map[string]bool)
tagUpdates := make(map[string]string)
// Convert exclude seq_ids to UUIDs
excludeUUIDs := make([]string, 0, len(req.ExcludeIDs))
if len(req.ExcludeIDs) > 0 {
excludeChunks, err := s.chunkRepo.ListChunksBySeqID(ctx, tenantID, req.ExcludeIDs)
if err == nil {
for _, c := range excludeChunks {
excludeUUIDs = append(excludeUUIDs, c.ID)
}
}
}
// Handle ByTag updates first (by tag seq_id)
if len(req.ByTag) > 0 {
for tagSeqID, update := range req.ByTag {
// Convert tag seq_id to UUID
tag, err := s.tagRepo.GetBySeqID(ctx, tenantID, tagSeqID)
if err != nil {
return werrors.NewNotFoundError(fmt.Sprintf("标签 %d 不存在", tagSeqID))
}
var setFlags, clearFlags types.ChunkFlags
// Handle IsRecommended
if update.IsRecommended != nil {
if *update.IsRecommended {
setFlags = types.ChunkFlagRecommended
} else {
clearFlags = types.ChunkFlagRecommended
}
}
// Convert new tag seq_id to UUID if provided
var newTagUUID *string
if update.TagID != nil {
if *update.TagID > 0 {
newTag, err := s.tagRepo.GetBySeqID(ctx, tenantID, *update.TagID)
if err != nil {
return werrors.NewNotFoundError(fmt.Sprintf("标签 %d 不存在", *update.TagID))
}
newTagUUID = &newTag.ID
} else {
emptyStr := ""
newTagUUID = &emptyStr
}
}
// Update all chunks with this tag
affectedIDs, err := s.chunkRepo.UpdateChunkFieldsByTagID(
ctx, tenantID, kb.ID, tag.ID,
update.IsEnabled, setFlags, clearFlags, newTagUUID, excludeUUIDs,
)
if err != nil {
return err
}
// Collect affected IDs for retriever sync
if len(affectedIDs) > 0 {
if update.IsEnabled != nil {
for _, id := range affectedIDs {
enabledUpdates[id] = *update.IsEnabled
}
}
if newTagUUID != nil {
for _, id := range affectedIDs {
tagUpdates[id] = *newTagUUID
}
}
}
}
}
// Handle ByID updates (by entry seq_id)
if len(req.ByID) > 0 {
entrySeqIDs := make([]int64, 0, len(req.ByID))
for entrySeqID := range req.ByID {
entrySeqIDs = append(entrySeqIDs, entrySeqID)
}
chunks, err := s.chunkRepo.ListChunksBySeqID(ctx, tenantID, entrySeqIDs)
if err != nil {
return err
}
// Build chunk seq_id to chunk map
chunkBySeqID := make(map[int64]*types.Chunk)
for _, chunk := range chunks {
chunkBySeqID[chunk.SeqID] = chunk
}
setFlags := make(map[string]types.ChunkFlags)
clearFlags := make(map[string]types.ChunkFlags)
chunksToUpdate := make([]*types.Chunk, 0)
for entrySeqID, update := range req.ByID {
chunk, exists := chunkBySeqID[entrySeqID]
if !exists {
continue
}
if chunk.KnowledgeBaseID != kb.ID || chunk.ChunkType != types.ChunkTypeFAQ {
continue
}
needUpdate := false
// Handle IsEnabled
if update.IsEnabled != nil && chunk.IsEnabled != *update.IsEnabled {
chunk.IsEnabled = *update.IsEnabled
enabledUpdates[chunk.ID] = *update.IsEnabled
needUpdate = true
}
// Handle IsRecommended (via Flags)
if update.IsRecommended != nil {
currentRecommended := chunk.Flags.HasFlag(types.ChunkFlagRecommended)
if currentRecommended != *update.IsRecommended {
if *update.IsRecommended {
setFlags[chunk.ID] = types.ChunkFlagRecommended
} else {
clearFlags[chunk.ID] = types.ChunkFlagRecommended
}
}
}
// Handle TagID (convert seq_id to UUID)
if update.TagID != nil {
var newTagID string
if *update.TagID > 0 {
newTag, err := s.tagRepo.GetBySeqID(ctx, tenantID, *update.TagID)
if err != nil {
return werrors.NewNotFoundError(fmt.Sprintf("标签 %d 不存在", *update.TagID))
}
newTagID = newTag.ID
}
if chunk.TagID != newTagID {
chunk.TagID = newTagID
tagUpdates[chunk.ID] = newTagID
needUpdate = true
}
}
if needUpdate {
chunk.UpdatedAt = time.Now()
chunksToUpdate = append(chunksToUpdate, chunk)
}
}
// Batch update chunks (for IsEnabled and TagID)
if len(chunksToUpdate) > 0 {
if err := s.chunkRepo.UpdateChunks(ctx, chunksToUpdate); err != nil {
return err
}
}
// Batch update flags (for IsRecommended)
if len(setFlags) > 0 || len(clearFlags) > 0 {
if err := s.chunkRepo.UpdateChunkFlagsBatch(ctx, tenantID, kb.ID, setFlags, clearFlags); err != nil {
return err
}
}
}
// Sync to retriever engines
if len(enabledUpdates) > 0 || len(tagUpdates) > 0 {
tenantInfo := ctx.Value(types.TenantInfoContextKey).(*types.Tenant)
retrieveEngine, err := retriever.NewCompositeRetrieveEngine(
s.retrieveEngine,
tenantInfo.GetEffectiveEngines(),
)
if err != nil {
return err
}
if len(enabledUpdates) > 0 {
if err := retrieveEngine.BatchUpdateChunkEnabledStatus(ctx, enabledUpdates); err != nil {
return err
}
}
if len(tagUpdates) > 0 {
if err := retrieveEngine.BatchUpdateChunkTagID(ctx, tagUpdates); err != nil {
return err
}
}
}
return nil
}
// UpdateKnowledgeTag updates the tag assigned to a knowledge document.
func (s *knowledgeService) UpdateKnowledgeTag(ctx context.Context, knowledgeID string, tagID *string) error {
tenantID := ctx.Value(types.TenantIDContextKey).(uint64)
knowledge, err := s.repo.GetKnowledgeByID(ctx, tenantID, knowledgeID)
if err != nil {
return err
}
var resolvedTagID string
if tagID != nil && *tagID != "" {
tag, err := s.tagRepo.GetByID(ctx, tenantID, *tagID)
if err != nil {
return err
}
if tag.KnowledgeBaseID != knowledge.KnowledgeBaseID {
return werrors.NewBadRequestError("标签不属于当前知识库")
}
resolvedTagID = tag.ID
}
knowledge.TagID = resolvedTagID
return s.repo.UpdateKnowledge(ctx, knowledge)
}
// UpdateKnowledgeTagBatch updates tags for document knowledge items in batch.
func (s *knowledgeService) UpdateKnowledgeTagBatch(ctx context.Context, updates map[string]*string) error {
if len(updates) == 0 {
return nil
}
tenantIDVal := ctx.Value(types.TenantIDContextKey)
if tenantIDVal == nil {
return werrors.NewUnauthorizedError("tenant ID not found in context")
}
tenantID, ok := tenantIDVal.(uint64)
if !ok {
return werrors.NewUnauthorizedError("invalid tenant ID in context")
}
// Get all knowledge items in batch
knowledgeIDs := make([]string, 0, len(updates))
for knowledgeID := range updates {
knowledgeIDs = append(knowledgeIDs, knowledgeID)
}
knowledgeList, err := s.repo.GetKnowledgeBatch(ctx, tenantID, knowledgeIDs)
if err != nil {
return err
}
// Build tag ID map for validation
tagIDSet := make(map[string]bool)
for _, tagID := range updates {
if tagID != nil && *tagID != "" {
tagIDSet[*tagID] = true
}
}
// Validate all tags in batch
tagMap := make(map[string]*types.KnowledgeTag)
if len(tagIDSet) > 0 {
tagIDs := make([]string, 0, len(tagIDSet))
for tagID := range tagIDSet {
tagIDs = append(tagIDs, tagID)
}
for _, tagID := range tagIDs {
tag, err := s.tagRepo.GetByID(ctx, tenantID, tagID)
if err != nil {
return err
}
tagMap[tagID] = tag
}
}
// Update knowledge items
knowledgeToUpdate := make([]*types.Knowledge, 0)
for _, knowledge := range knowledgeList {
tagID, exists := updates[knowledge.ID]
if !exists {
continue
}
var resolvedTagID string
if tagID != nil && *tagID != "" {
tag, ok := tagMap[*tagID]
if !ok {
return werrors.NewBadRequestError(fmt.Sprintf("标签 %s 不存在", *tagID))
}
if tag.KnowledgeBaseID != knowledge.KnowledgeBaseID {
return werrors.NewBadRequestError(fmt.Sprintf("标签 %s 不属于知识库 %s", *tagID, knowledge.KnowledgeBaseID))
}
resolvedTagID = tag.ID
}
knowledge.TagID = resolvedTagID
knowledgeToUpdate = append(knowledgeToUpdate, knowledge)
}
if len(knowledgeToUpdate) > 0 {
return s.repo.UpdateKnowledgeBatch(ctx, knowledgeToUpdate)
}
return nil
}
// UpdateFAQEntryTag updates the tag assigned to an FAQ entry.
func (s *knowledgeService) UpdateFAQEntryTag(ctx context.Context, kbID string, entryID string, tagID *string) error {
kb, err := s.validateFAQKnowledgeBase(ctx, kbID)
if err != nil {
return err
}
tenantID := ctx.Value(types.TenantIDContextKey).(uint64)
chunk, err := s.chunkRepo.GetChunkByID(ctx, tenantID, entryID)
if err != nil {
return err
}
if chunk.KnowledgeBaseID != kb.ID || chunk.ChunkType != types.ChunkTypeFAQ {
return werrors.NewBadRequestError("仅支持更新 FAQ 条目标签")
}
var resolvedTagID string
if tagID != nil && *tagID != "" {
tag, err := s.tagRepo.GetByID(ctx, tenantID, *tagID)
if err != nil {
return err
}
if tag.KnowledgeBaseID != kb.ID {
return werrors.NewBadRequestError("标签不属于当前知识库")
}
resolvedTagID = tag.ID
}
// Check if tag actually changed
if chunk.TagID == resolvedTagID {
return nil
}
chunk.TagID = resolvedTagID
chunk.UpdatedAt = time.Now()
if err := s.chunkRepo.UpdateChunk(ctx, chunk); err != nil {
return err
}
// Sync tag update to retriever engines
tenantInfo := ctx.Value(types.TenantInfoContextKey).(*types.Tenant)
retrieveEngine, err := retriever.NewCompositeRetrieveEngine(
s.retrieveEngine,
tenantInfo.GetEffectiveEngines(),
)
if err != nil {
return err
}
return retrieveEngine.BatchUpdateChunkTagID(ctx, map[string]string{chunk.ID: resolvedTagID})
}
// UpdateFAQEntryTagBatch updates tags for FAQ entries in batch.
// Key: entry seq_id, Value: tag seq_id (nil to remove tag)
func (s *knowledgeService) UpdateFAQEntryTagBatch(ctx context.Context, kbID string, updates map[int64]*int64) error {
if len(updates) == 0 {
return nil
}
kb, err := s.validateFAQKnowledgeBase(ctx, kbID)
if err != nil {
return err
}
tenantID := ctx.Value(types.TenantIDContextKey).(uint64)
// Get all chunks in batch by seq_id
entrySeqIDs := make([]int64, 0, len(updates))
for entrySeqID := range updates {
entrySeqIDs = append(entrySeqIDs, entrySeqID)
}
chunks, err := s.chunkRepo.ListChunksBySeqID(ctx, tenantID, entrySeqIDs)
if err != nil {
return err
}
// Build chunk seq_id to chunk map
chunkBySeqID := make(map[int64]*types.Chunk)
for _, chunk := range chunks {
chunkBySeqID[chunk.SeqID] = chunk
}
// Build tag seq_id set for validation
tagSeqIDSet := make(map[int64]bool)
for _, tagSeqID := range updates {
if tagSeqID != nil && *tagSeqID > 0 {
tagSeqIDSet[*tagSeqID] = true
}
}
// Validate all tags in batch by seq_id
tagMap := make(map[int64]*types.KnowledgeTag)
if len(tagSeqIDSet) > 0 {
tagSeqIDs := make([]int64, 0, len(tagSeqIDSet))
for tagSeqID := range tagSeqIDSet {
tagSeqIDs = append(tagSeqIDs, tagSeqID)
}
tags, err := s.tagRepo.GetBySeqIDs(ctx, tenantID, tagSeqIDs)
if err != nil {
return err
}
for _, tag := range tags {
if tag.KnowledgeBaseID != kb.ID {
return werrors.NewBadRequestError(fmt.Sprintf("标签 %d 不属于当前知识库", tag.SeqID))
}
tagMap[tag.SeqID] = tag
}
}
// Update chunks
chunksToUpdate := make([]*types.Chunk, 0)
for entrySeqID, tagSeqID := range updates {
chunk, exists := chunkBySeqID[entrySeqID]
if !exists {
continue
}
if chunk.KnowledgeBaseID != kb.ID || chunk.ChunkType != types.ChunkTypeFAQ {
continue
}
var resolvedTagID string
if tagSeqID != nil && *tagSeqID > 0 {
tag, ok := tagMap[*tagSeqID]
if !ok {
return werrors.NewBadRequestError(fmt.Sprintf("标签 %d 不存在", *tagSeqID))
}
resolvedTagID = tag.ID
}
chunk.TagID = resolvedTagID
chunk.UpdatedAt = time.Now()
chunksToUpdate = append(chunksToUpdate, chunk)
}
if len(chunksToUpdate) > 0 {
if err := s.chunkRepo.UpdateChunks(ctx, chunksToUpdate); err != nil {
return err
}
// Sync tag updates to retriever engines
tagUpdates := make(map[string]string)
for _, chunk := range chunksToUpdate {
tagUpdates[chunk.ID] = chunk.TagID
}
tenantInfo := ctx.Value(types.TenantInfoContextKey).(*types.Tenant)
retrieveEngine, err := retriever.NewCompositeRetrieveEngine(
s.retrieveEngine,
tenantInfo.GetEffectiveEngines(),
)
if err != nil {
return err
}
if err := retrieveEngine.BatchUpdateChunkTagID(ctx, tagUpdates); err != nil {
return err
}
}
return nil
}
// SearchFAQEntries searches FAQ entries using hybrid search.
func (s *knowledgeService) SearchFAQEntries(ctx context.Context,
kbID string, req *types.FAQSearchRequest,
) ([]*types.FAQEntry, error) {
// Validate FAQ knowledge base
kb, err := s.validateFAQKnowledgeBase(ctx, kbID)
if err != nil {
return nil, err
}
// Set default values
if req.VectorThreshold <= 0 {
req.VectorThreshold = 0.7
}
if req.MatchCount <= 0 {
req.MatchCount = 10
}
if req.MatchCount > 50 {
req.MatchCount = 50
}
tenantID := ctx.Value(types.TenantIDContextKey).(uint64)
// Convert tag seq_ids to UUIDs
var firstPriorityTagUUIDs, secondPriorityTagUUIDs []string
firstPrioritySeqIDSet := make(map[int64]struct{})
secondPrioritySeqIDSet := make(map[int64]struct{})
if len(req.FirstPriorityTagIDs) > 0 {
tags, err := s.tagRepo.GetBySeqIDs(ctx, tenantID, req.FirstPriorityTagIDs)
if err == nil {
firstPriorityTagUUIDs = make([]string, 0, len(tags))
for _, tag := range tags {
firstPriorityTagUUIDs = append(firstPriorityTagUUIDs, tag.ID)
firstPrioritySeqIDSet[tag.SeqID] = struct{}{}
}
}
}
if len(req.SecondPriorityTagIDs) > 0 {
tags, err := s.tagRepo.GetBySeqIDs(ctx, tenantID, req.SecondPriorityTagIDs)
if err == nil {
secondPriorityTagUUIDs = make([]string, 0, len(tags))
for _, tag := range tags {
secondPriorityTagUUIDs = append(secondPriorityTagUUIDs, tag.ID)
secondPrioritySeqIDSet[tag.SeqID] = struct{}{}
}
}
}
// Build priority tag sets for sorting (using UUID)
hasFirstPriority := len(firstPriorityTagUUIDs) > 0
hasSecondPriority := len(secondPriorityTagUUIDs) > 0
hasPriorityFilter := hasFirstPriority || hasSecondPriority
firstPrioritySet := make(map[string]struct{}, len(firstPriorityTagUUIDs))
for _, tagID := range firstPriorityTagUUIDs {
firstPrioritySet[tagID] = struct{}{}
}
secondPrioritySet := make(map[string]struct{}, len(secondPriorityTagUUIDs))
for _, tagID := range secondPriorityTagUUIDs {
secondPrioritySet[tagID] = struct{}{}
}
// Perform separate searches for each priority level to ensure FirstPriority results
// are not crowded out by higher-scoring SecondPriority results in TopK truncation
var searchResults []*types.SearchResult
if hasPriorityFilter {
// Use goroutines to search both priority levels concurrently
var (
firstResults []*types.SearchResult
secondResults []*types.SearchResult
firstErr error
secondErr error
wg sync.WaitGroup
)
if hasFirstPriority {
wg.Add(1)
go func() {
defer wg.Done()
firstParams := types.SearchParams{
QueryText: secutils.SanitizeForLog(req.QueryText),
VectorThreshold: req.VectorThreshold,
MatchCount: req.MatchCount,
DisableKeywordsMatch: true,
TagIDs: firstPriorityTagUUIDs,
OnlyRecommended: req.OnlyRecommended,
}
firstResults, firstErr = s.kbService.HybridSearch(ctx, kbID, firstParams)
}()
}
if hasSecondPriority {
wg.Add(1)
go func() {
defer wg.Done()
secondParams := types.SearchParams{
QueryText: secutils.SanitizeForLog(req.QueryText),
VectorThreshold: req.VectorThreshold,
MatchCount: req.MatchCount,
DisableKeywordsMatch: true,
TagIDs: secondPriorityTagUUIDs,
OnlyRecommended: req.OnlyRecommended,
}
secondResults, secondErr = s.kbService.HybridSearch(ctx, kbID, secondParams)
}()
}
wg.Wait()
// Check errors
if firstErr != nil {
return nil, firstErr
}
if secondErr != nil {
return nil, secondErr
}
// Merge results: FirstPriority first, then SecondPriority (deduplicated)
seenChunkIDs := make(map[string]struct{})
for _, result := range firstResults {
if _, exists := seenChunkIDs[result.ID]; !exists {
seenChunkIDs[result.ID] = struct{}{}
searchResults = append(searchResults, result)
}
}
for _, result := range secondResults {
if _, exists := seenChunkIDs[result.ID]; !exists {
seenChunkIDs[result.ID] = struct{}{}
searchResults = append(searchResults, result)
}
}
} else {
// No priority filter, search all
searchParams := types.SearchParams{
QueryText: secutils.SanitizeForLog(req.QueryText),
VectorThreshold: req.VectorThreshold,
MatchCount: req.MatchCount,
DisableKeywordsMatch: true,
}
var err error
searchResults, err = s.kbService.HybridSearch(ctx, kbID, searchParams)
if err != nil {
return nil, err
}
}
if len(searchResults) == 0 {
return []*types.FAQEntry{}, nil
}
// Extract chunk IDs and build score/match type/matched content maps
chunkIDs := make([]string, 0, len(searchResults))
chunkScores := make(map[string]float64)
chunkMatchTypes := make(map[string]types.MatchType)
chunkMatchedContents := make(map[string]string)
for _, result := range searchResults {
// SearchResult.ID is the chunk ID
chunkID := result.ID
chunkIDs = append(chunkIDs, chunkID)
chunkScores[chunkID] = result.Score
chunkMatchTypes[chunkID] = result.MatchType
chunkMatchedContents[chunkID] = result.MatchedContent
}
// Batch fetch chunks
chunks, err := s.chunkRepo.ListChunksByID(ctx, tenantID, chunkIDs)
if err != nil {
return nil, err
}
// Build tag UUID to seq_id map for conversion
tagSeqIDMap := make(map[string]int64)
tagIDs := make([]string, 0)
tagIDSet := make(map[string]struct{})
for _, chunk := range chunks {
if chunk.TagID != "" {
if _, exists := tagIDSet[chunk.TagID]; !exists {
tagIDSet[chunk.TagID] = struct{}{}
tagIDs = append(tagIDs, chunk.TagID)
}
}
}
if len(tagIDs) > 0 {
tags, err := s.tagRepo.GetByIDs(ctx, tenantID, tagIDs)
if err == nil {
for _, tag := range tags {
tagSeqIDMap[tag.ID] = tag.SeqID
}
}
}
// Filter FAQ chunks and convert to FAQEntry
kb.EnsureDefaults()
entries := make([]*types.FAQEntry, 0, len(chunks))
for _, chunk := range chunks {
// Only process FAQ type chunks
if chunk.ChunkType != types.ChunkTypeFAQ {
continue
}
if !chunk.IsEnabled {
continue
}
entry, err := s.chunkToFAQEntry(chunk, kb, tagSeqIDMap)
if err != nil {
logger.Warnf(ctx, "Failed to convert chunk to FAQ entry: %v", err)
continue
}
// Preserve score and match type from search results
// Note: Negative question filtering is now handled in HybridSearch
if score, ok := chunkScores[chunk.ID]; ok {
entry.Score = score
}
if matchType, ok := chunkMatchTypes[chunk.ID]; ok {
entry.MatchType = matchType
}
// Set MatchedQuestion from search result's matched content
if matchedContent, ok := chunkMatchedContents[chunk.ID]; ok && matchedContent != "" {
entry.MatchedQuestion = matchedContent
}
entries = append(entries, entry)
}
// Sort entries with two-level priority tag support
if hasPriorityFilter {
// getPriorityLevel returns: 0 = first priority, 1 = second priority, 2 = no priority
// Use chunk.TagID (UUID) for comparison
getPriorityLevel := func(chunk *types.Chunk) int {
if _, ok := firstPrioritySet[chunk.TagID]; ok {
return 0
}
if _, ok := secondPrioritySet[chunk.TagID]; ok {
return 1
}
return 2
}
// Build chunk map for priority lookup
chunkMap := make(map[int64]*types.Chunk)
for _, chunk := range chunks {
chunkMap[chunk.SeqID] = chunk
}
slices.SortFunc(entries, func(a, b *types.FAQEntry) int {
aChunk := chunkMap[a.ID]
bChunk := chunkMap[b.ID]
var aPriority, bPriority int
if aChunk != nil {
aPriority = getPriorityLevel(aChunk)
} else {
aPriority = 2
}
if bChunk != nil {
bPriority = getPriorityLevel(bChunk)
} else {
bPriority = 2
}
// Compare by priority level first
if aPriority != bPriority {
return aPriority - bPriority // Lower level = higher priority
}
// Same priority level, sort by score descending
if b.Score > a.Score {
return 1
} else if b.Score < a.Score {
return -1
}
return 0
})
} else {
// No priority tags, sort by score only
slices.SortFunc(entries, func(a, b *types.FAQEntry) int {
if b.Score > a.Score {
return 1
} else if b.Score < a.Score {
return -1
}
return 0
})
}
// Limit results to requested match count
if len(entries) > req.MatchCount {
entries = entries[:req.MatchCount]
}
// 批量查询TagName并补充到结果中
if len(entries) > 0 {
// 收集所有需要查询的TagID (seq_id)
tagSeqIDs := make([]int64, 0)
tagSeqIDSet := make(map[int64]struct{})
for _, entry := range entries {
if entry.TagID != 0 {
if _, exists := tagSeqIDSet[entry.TagID]; !exists {
tagSeqIDs = append(tagSeqIDs, entry.TagID)
tagSeqIDSet[entry.TagID] = struct{}{}
}
}
}
// 批量查询标签
if len(tagSeqIDs) > 0 {
tags, err := s.tagRepo.GetBySeqIDs(ctx, tenantID, tagSeqIDs)
if err != nil {
logger.Warnf(ctx, "Failed to batch query tags: %v", err)
} else {
// 构建TagSeqID到TagName的映射
tagNameMap := make(map[int64]string)
for _, tag := range tags {
tagNameMap[tag.SeqID] = tag.Name
}
// 补充TagName
for _, entry := range entries {
if entry.TagID != 0 {
if tagName, exists := tagNameMap[entry.TagID]; exists {
entry.TagName = tagName
}
}
}
}
}
}
return entries, nil
}
// DeleteFAQEntries deletes FAQ entries in batch by seq_id.
func (s *knowledgeService) DeleteFAQEntries(ctx context.Context,
kbID string, entrySeqIDs []int64,
) error {
if len(entrySeqIDs) == 0 {
return werrors.NewBadRequestError("请选择需要删除的 FAQ 条目")
}
kb, err := s.validateFAQKnowledgeBase(ctx, kbID)
if err != nil {
return err
}
tenantID := ctx.Value(types.TenantIDContextKey).(uint64)
var faqKnowledge *types.Knowledge
chunksToRemove := make([]*types.Chunk, 0, len(entrySeqIDs))
for _, seqID := range entrySeqIDs {
if seqID <= 0 {
continue
}
chunk, err := s.chunkRepo.GetChunkBySeqID(ctx, tenantID, seqID)
if err != nil {
return werrors.NewNotFoundError("FAQ条目不存在")
}
if chunk.KnowledgeBaseID != kb.ID || chunk.ChunkType != types.ChunkTypeFAQ {
return werrors.NewBadRequestError("包含无效的 FAQ 条目")
}
if err := s.chunkService.DeleteChunk(ctx, chunk.ID); err != nil {
return err
}
if faqKnowledge == nil {
faqKnowledge, err = s.repo.GetKnowledgeByID(ctx, tenantID, chunk.KnowledgeID)
if err != nil {
return err
}
}
chunksToRemove = append(chunksToRemove, chunk)
}
if len(chunksToRemove) > 0 && faqKnowledge != nil {
if err := s.deleteFAQChunkVectors(ctx, kb, faqKnowledge, chunksToRemove); err != nil {
return err
}
}
return nil
}
// ExportFAQEntries exports all FAQ entries for a knowledge base as CSV data.
// The CSV format matches the import example format with 8 columns:
// 分类(必填), 问题(必填), 相似问题(选填-多个用##分隔), 反例问题(选填-多个用##分隔),
// 机器人回答(必填-多个用##分隔), 是否全部回复(选填-默认FALSE), 是否停用(选填-默认FALSE),
// 是否禁止被推荐(选填-默认False 可被推荐)
func (s *knowledgeService) ExportFAQEntries(ctx context.Context, kbID string) ([]byte, error) {
kb, err := s.validateFAQKnowledgeBase(ctx, kbID)
if err != nil {
return nil, err
}
tenantID := ctx.Value(types.TenantIDContextKey).(uint64)
faqKnowledge, err := s.findFAQKnowledge(ctx, tenantID, kb.ID)
if err != nil {
return nil, err
}
if faqKnowledge == nil {
// Return empty CSV with headers only
return s.buildFAQCSV(nil, nil), nil
}
// Get all FAQ chunks
chunks, err := s.chunkRepo.ListAllFAQChunksForExport(ctx, tenantID, faqKnowledge.ID)
if err != nil {
return nil, fmt.Errorf("failed to list FAQ chunks: %w", err)
}
// Build tag map for tag_id -> tag_name conversion
tagMap, err := s.buildTagMap(ctx, tenantID, kbID)
if err != nil {
return nil, fmt.Errorf("failed to build tag map: %w", err)
}
return s.buildFAQCSV(chunks, tagMap), nil
}
// buildTagMap builds a map from tag_id to tag_name for the given knowledge base.
func (s *knowledgeService) buildTagMap(ctx context.Context, tenantID uint64, kbID string) (map[string]string, error) {
// Get all tags for this knowledge base (no pagination limit)
page := &types.Pagination{Page: 1, PageSize: 10000}
tags, _, err := s.tagRepo.ListByKB(ctx, tenantID, kbID, page, "")
if err != nil {
return nil, err
}
tagMap := make(map[string]string, len(tags))
for _, tag := range tags {
if tag != nil {
tagMap[tag.ID] = tag.Name
}
}
return tagMap, nil
}
// buildFAQCSV builds CSV content from FAQ chunks.
func (s *knowledgeService) buildFAQCSV(chunks []*types.Chunk, tagMap map[string]string) []byte {
var buf strings.Builder
// Write CSV header (matching import example format)
headers := []string{
"分类(必填)",
"问题(必填)",
"相似问题(选填-多个用##分隔)",
"反例问题(选填-多个用##分隔)",
"机器人回答(必填-多个用##分隔)",
"是否全部回复(选填-默认FALSE)",
"是否停用(选填-默认FALSE)",
"是否禁止被推荐(选填-默认False 可被推荐)",
}
buf.WriteString(strings.Join(headers, ","))
buf.WriteString("\n")
// Write data rows
for _, chunk := range chunks {
meta, err := chunk.FAQMetadata()
if err != nil || meta == nil {
continue
}
// Get tag name
tagName := ""
if chunk.TagID != "" && tagMap != nil {
if name, ok := tagMap[chunk.TagID]; ok {
tagName = name
}
}
// Build row
row := []string{
escapeCSVField(tagName),
escapeCSVField(meta.StandardQuestion),
escapeCSVField(strings.Join(meta.SimilarQuestions, "##")),
escapeCSVField(strings.Join(meta.NegativeQuestions, "##")),
escapeCSVField(strings.Join(meta.Answers, "##")),
boolToCSV(meta.AnswerStrategy == types.AnswerStrategyAll),
boolToCSV(!chunk.IsEnabled), // 是否停用:取反
boolToCSV(!chunk.Flags.HasFlag(types.ChunkFlagRecommended)), // 是否禁止被推荐:取反
}
buf.WriteString(strings.Join(row, ","))
buf.WriteString("\n")
}
return []byte(buf.String())
}
// escapeCSVField escapes a field for CSV format.
func escapeCSVField(field string) string {
// If field contains comma, newline, or quote, wrap in quotes and escape internal quotes
if strings.ContainsAny(field, ",\"\n\r") {
return "\"" + strings.ReplaceAll(field, "\"", "\"\"") + "\""
}
return field
}
// boolToCSV converts a boolean to CSV TRUE/FALSE string.
func boolToCSV(b bool) string {
if b {
return "TRUE"
}
return "FALSE"
}
func (s *knowledgeService) validateFAQKnowledgeBase(ctx context.Context, kbID string) (*types.KnowledgeBase, error) {
if kbID == "" {
return nil, werrors.NewBadRequestError("知识库 ID 不能为空")
}
kb, err := s.kbService.GetKnowledgeBaseByID(ctx, kbID)
if err != nil {
return nil, err
}
kb.EnsureDefaults()
if kb.Type != types.KnowledgeBaseTypeFAQ {
return nil, werrors.NewBadRequestError("仅 FAQ 知识库支持该操作")
}
return kb, nil
}
func (s *knowledgeService) findFAQKnowledge(
ctx context.Context,
tenantID uint64,
kbID string,
) (*types.Knowledge, error) {
knowledges, err := s.repo.ListKnowledgeByKnowledgeBaseID(ctx, tenantID, kbID)
if err != nil {
return nil, err
}
for _, knowledge := range knowledges {
if knowledge.Type == types.KnowledgeTypeFAQ {
return knowledge, nil
}
}
return nil, nil
}
func (s *knowledgeService) ensureFAQKnowledge(
ctx context.Context,
tenantID uint64,
kb *types.KnowledgeBase,
) (*types.Knowledge, error) {
existing, err := s.findFAQKnowledge(ctx, tenantID, kb.ID)
if err != nil {
return nil, err
}
if existing != nil {
return existing, nil
}
knowledge := &types.Knowledge{
TenantID: tenantID,
KnowledgeBaseID: kb.ID,
Type: types.KnowledgeTypeFAQ,
Title: fmt.Sprintf("%s - FAQ", kb.Name),
Description: "FAQ 条目容器",
Source: types.KnowledgeTypeFAQ,
ParseStatus: "completed",
EnableStatus: "enabled",
EmbeddingModelID: kb.EmbeddingModelID,
CreatedAt: time.Now(),
UpdatedAt: time.Now(),
}
if err := s.repo.CreateKnowledge(ctx, knowledge); err != nil {
return nil, err
}
return knowledge, nil
}
// updateFAQImportProgressStatus updates the FAQ import progress in Redis
func (s *knowledgeService) updateFAQImportProgressStatus(
ctx context.Context,
taskID string,
status types.FAQImportTaskStatus,
progress, total, processed int,
message, errorMsg string,
) error {
// Get existing progress from Redis
existingProgress, err := s.GetFAQImportProgress(ctx, taskID)
if err != nil {
// If not found, create a new progress entry
existingProgress = &types.FAQImportProgress{
TaskID: taskID,
CreatedAt: time.Now().Unix(),
}
}
// Update progress fields
existingProgress.Status = status
existingProgress.Progress = progress
existingProgress.Total = total
existingProgress.Processed = processed
if message != "" {
existingProgress.Message = message
}
existingProgress.Error = errorMsg
if status == types.FAQImportStatusCompleted {
existingProgress.Error = ""
}
// 任务完成或失败时,清除 running key
if status == types.FAQImportStatusCompleted || status == types.FAQImportStatusFailed {
if existingProgress.KBID != "" {
if clearErr := s.clearRunningFAQImportTaskID(ctx, existingProgress.KBID); clearErr != nil {
logger.Errorf(ctx, "Failed to clear running FAQ import task ID: %v", clearErr)
}
}
}
return s.saveFAQImportProgress(ctx, existingProgress)
}
// cleanupFAQEntriesFileOnFinalFailure 在任务最终失败时清理对象存储中的 entries 文件
// 只有当 retryCount >= maxRetry 时才执行清理,否则重试时还需要使用这个文件
func (s *knowledgeService) cleanupFAQEntriesFileOnFinalFailure(ctx context.Context, entriesURL string, retryCount, maxRetry int) {
if entriesURL == "" || retryCount < maxRetry {
return
}
if err := s.fileSvc.DeleteFile(ctx, entriesURL); err != nil {
logger.Warnf(ctx, "Failed to delete FAQ entries file from object storage on final failure: %v", err)
} else {
logger.Infof(ctx, "Deleted FAQ entries file from object storage on final failure: %s", entriesURL)
}
}
// runningFAQImportInfo stores the task ID and enqueued timestamp for uniquely identifying a task instance
type runningFAQImportInfo struct {
TaskID string `json:"task_id"`
EnqueuedAt int64 `json:"enqueued_at"`
}
// getRunningFAQImportInfo checks if there's a running FAQ import task for the given KB
// Returns the task info if found, nil otherwise
func (s *knowledgeService) getRunningFAQImportInfo(ctx context.Context, kbID string) (*runningFAQImportInfo, error) {
key := getFAQImportRunningKey(kbID)
data, err := s.redisClient.Get(ctx, key).Result()
if err != nil {
if errors.Is(err, redis.Nil) {
return nil, nil
}
return nil, fmt.Errorf("failed to get running FAQ import task: %w", err)
}
// Try to parse as JSON first (new format)
var info runningFAQImportInfo
if err := json.Unmarshal([]byte(data), &info); err != nil {
// Fallback: old format was just taskID string
return &runningFAQImportInfo{TaskID: data, EnqueuedAt: 0}, nil
}
return &info, nil
}
// getRunningFAQImportTaskID checks if there's a running FAQ import task for the given KB
// Returns the task ID if found, empty string otherwise (for backward compatibility)
func (s *knowledgeService) getRunningFAQImportTaskID(ctx context.Context, kbID string) (string, error) {
info, err := s.getRunningFAQImportInfo(ctx, kbID)
if err != nil {
return "", err
}
if info == nil {
return "", nil
}
return info.TaskID, nil
}
// setRunningFAQImportInfo sets the running task info for a KB
func (s *knowledgeService) setRunningFAQImportInfo(ctx context.Context, kbID string, info *runningFAQImportInfo) error {
key := getFAQImportRunningKey(kbID)
data, err := json.Marshal(info)
if err != nil {
return fmt.Errorf("failed to marshal running info: %w", err)
}
return s.redisClient.Set(ctx, key, data, faqImportProgressTTL).Err()
}
// clearRunningFAQImportTaskID clears the running task ID for a KB
func (s *knowledgeService) clearRunningFAQImportTaskID(ctx context.Context, kbID string) error {
key := getFAQImportRunningKey(kbID)
return s.redisClient.Del(ctx, key).Err()
}
func (s *knowledgeService) chunkToFAQEntry(chunk *types.Chunk, kb *types.KnowledgeBase, tagSeqIDMap map[string]int64) (*types.FAQEntry, error) {
meta, err := chunk.FAQMetadata()
if err != nil {
return nil, err
}
if meta == nil {
meta = &types.FAQChunkMetadata{StandardQuestion: chunk.Content}
}
// 默认使用 all 策略
answerStrategy := meta.AnswerStrategy
if answerStrategy == "" {
answerStrategy = types.AnswerStrategyAll
}
// Get tag seq_id from map
var tagSeqID int64
if chunk.TagID != "" && tagSeqIDMap != nil {
tagSeqID = tagSeqIDMap[chunk.TagID]
}
entry := &types.FAQEntry{
ID: chunk.SeqID,
ChunkID: chunk.ID,
KnowledgeID: chunk.KnowledgeID,
KnowledgeBaseID: chunk.KnowledgeBaseID,
TagID: tagSeqID,
IsEnabled: chunk.IsEnabled,
IsRecommended: chunk.Flags.HasFlag(types.ChunkFlagRecommended),
StandardQuestion: meta.StandardQuestion,
SimilarQuestions: meta.SimilarQuestions,
NegativeQuestions: meta.NegativeQuestions,
Answers: meta.Answers,
AnswerStrategy: answerStrategy,
IndexMode: kb.FAQConfig.IndexMode,
UpdatedAt: chunk.UpdatedAt,
CreatedAt: chunk.CreatedAt,
ChunkType: chunk.ChunkType,
}
return entry, nil
}
func buildFAQChunkContent(meta *types.FAQChunkMetadata, mode types.FAQIndexMode) string {
var builder strings.Builder
builder.WriteString(fmt.Sprintf("Q: %s\n", meta.StandardQuestion))
if len(meta.SimilarQuestions) > 0 {
builder.WriteString("Similar Questions:\n")
for _, q := range meta.SimilarQuestions {
builder.WriteString(fmt.Sprintf("- %s\n", q))
}
}
// 负例不应该包含在 Content 中,因为它们不应该被索引
// 答案根据索引模式决定是否包含
if mode == types.FAQIndexModeQuestionAnswer && len(meta.Answers) > 0 {
builder.WriteString("Answers:\n")
for _, ans := range meta.Answers {
builder.WriteString(fmt.Sprintf("- %s\n", ans))
}
}
return builder.String()
}
// checkFAQQuestionDuplicate 检查标准问和相似问是否与知识库中其他条目重复
// excludeChunkID 用于排除当前正在编辑的条目(更新时使用)
func (s *knowledgeService) checkFAQQuestionDuplicate(
ctx context.Context,
tenantID uint64,
kbID string,
excludeChunkID string,
meta *types.FAQChunkMetadata,
) error {
// 首先检查当前条目自身的相似问是否与标准问重复
for _, q := range meta.SimilarQuestions {
if q == meta.StandardQuestion {
return werrors.NewBadRequestError(fmt.Sprintf("相似问「%s」不能与标准问相同", q))
}
}
// 检查当前条目自身的相似问之间是否有重复
seen := make(map[string]struct{})
for _, q := range meta.SimilarQuestions {
if _, exists := seen[q]; exists {
return werrors.NewBadRequestError(fmt.Sprintf("相似问「%s」重复", q))
}
seen[q] = struct{}{}
}
// 查询知识库中已有的所有FAQ chunks的metadata
existingChunks, err := s.chunkRepo.ListAllFAQChunksWithMetadataByKnowledgeBaseID(ctx, tenantID, kbID)
if err != nil {
return fmt.Errorf("failed to list existing FAQ chunks: %w", err)
}
// 构建已存在的标准问和相似问集合
for _, chunk := range existingChunks {
// 排除当前正在编辑的条目
if chunk.ID == excludeChunkID {
continue
}
existingMeta, err := chunk.FAQMetadata()
if err != nil || existingMeta == nil {
continue
}
// 检查标准问是否重复
if existingMeta.StandardQuestion == meta.StandardQuestion {
return werrors.NewBadRequestError(fmt.Sprintf("标准问「%s」已存在", meta.StandardQuestion))
}
// 检查当前标准问是否与已有相似问重复
for _, q := range existingMeta.SimilarQuestions {
if q == meta.StandardQuestion {
return werrors.NewBadRequestError(fmt.Sprintf("标准问「%s」与已有相似问重复", meta.StandardQuestion))
}
}
// 检查当前相似问是否与已有标准问重复
for _, q := range meta.SimilarQuestions {
if q == existingMeta.StandardQuestion {
return werrors.NewBadRequestError(fmt.Sprintf("相似问「%s」与已有标准问重复", q))
}
}
// 检查当前相似问是否与已有相似问重复
for _, q := range meta.SimilarQuestions {
for _, existingQ := range existingMeta.SimilarQuestions {
if q == existingQ {
return werrors.NewBadRequestError(fmt.Sprintf("相似问「%s」已存在", q))
}
}
}
}
return nil
}
// resolveTagID resolves tag ID (UUID) from payload, prioritizing tag_id (seq_id) over tag_name
// If no tag is specified, creates or finds the "未分类" tag
// Returns the internal UUID of the tag
func (s *knowledgeService) resolveTagID(ctx context.Context, kbID string, payload *types.FAQEntryPayload) (string, error) {
tenantID := ctx.Value(types.TenantIDContextKey).(uint64)
// 如果提供了 tag_id (seq_id),优先使用 tag_id
if payload.TagID != 0 {
tag, err := s.tagRepo.GetBySeqID(ctx, tenantID, payload.TagID)
if err != nil {
return "", fmt.Errorf("failed to find tag by seq_id %d: %w", payload.TagID, err)
}
return tag.ID, nil
}
// 如果提供了 tag_name查找或创建标签
if payload.TagName != "" {
tag, err := s.tagService.FindOrCreateTagByName(ctx, kbID, payload.TagName)
if err != nil {
return "", fmt.Errorf("failed to resolve tag by name '%s': %w", payload.TagName, err)
}
return tag.ID, nil
}
// 都没有提供,使用"未分类"标签
tag, err := s.tagService.FindOrCreateTagByName(ctx, kbID, types.UntaggedTagName)
if err != nil {
return "", fmt.Errorf("failed to get or create default untagged tag: %w", err)
}
return tag.ID, nil
}
func sanitizeFAQEntryPayload(payload *types.FAQEntryPayload) (*types.FAQChunkMetadata, error) {
// 处理 AnswerStrategy默认为 all
answerStrategy := types.AnswerStrategyAll
if payload.AnswerStrategy != nil && *payload.AnswerStrategy != "" {
switch *payload.AnswerStrategy {
case types.AnswerStrategyAll, types.AnswerStrategyRandom:
answerStrategy = *payload.AnswerStrategy
default:
return nil, werrors.NewBadRequestError("answer_strategy 必须是 'all' 或 'random'")
}
}
meta := &types.FAQChunkMetadata{
StandardQuestion: strings.TrimSpace(payload.StandardQuestion),
SimilarQuestions: payload.SimilarQuestions,
NegativeQuestions: payload.NegativeQuestions,
Answers: payload.Answers,
AnswerStrategy: answerStrategy,
Version: 1,
Source: "faq",
}
meta.Normalize()
if meta.StandardQuestion == "" {
return nil, werrors.NewBadRequestError("标准问不能为空")
}
if len(meta.Answers) == 0 {
return nil, werrors.NewBadRequestError("至少提供一个答案")
}
return meta, nil
}
func buildFAQIndexContent(meta *types.FAQChunkMetadata, mode types.FAQIndexMode) string {
var builder strings.Builder
builder.WriteString(meta.StandardQuestion)
for _, q := range meta.SimilarQuestions {
builder.WriteString("\n")
builder.WriteString(q)
}
if mode == types.FAQIndexModeQuestionAnswer {
for _, ans := range meta.Answers {
builder.WriteString("\n")
builder.WriteString(ans)
}
}
return builder.String()
}
// buildFAQIndexInfoList 构建FAQ索引信息列表支持分别索引模式
func (s *knowledgeService) buildFAQIndexInfoList(
ctx context.Context,
kb *types.KnowledgeBase,
chunk *types.Chunk,
) ([]*types.IndexInfo, error) {
indexMode := types.FAQIndexModeQuestionAnswer
questionIndexMode := types.FAQQuestionIndexModeCombined
if kb.FAQConfig != nil {
if kb.FAQConfig.IndexMode != "" {
indexMode = kb.FAQConfig.IndexMode
}
if kb.FAQConfig.QuestionIndexMode != "" {
questionIndexMode = kb.FAQConfig.QuestionIndexMode
}
}
meta, err := chunk.FAQMetadata()
if err != nil {
return nil, err
}
if meta == nil {
meta = &types.FAQChunkMetadata{StandardQuestion: chunk.Content}
}
// 如果是一起索引模式,使用原有逻辑
if questionIndexMode == types.FAQQuestionIndexModeCombined {
content := buildFAQIndexContent(meta, indexMode)
return []*types.IndexInfo{
{
Content: content,
SourceID: chunk.ID,
SourceType: types.ChunkSourceType,
ChunkID: chunk.ID,
KnowledgeID: chunk.KnowledgeID,
KnowledgeBaseID: chunk.KnowledgeBaseID,
KnowledgeType: types.KnowledgeTypeFAQ,
TagID: chunk.TagID,
IsEnabled: chunk.IsEnabled,
IsRecommended: chunk.Flags.HasFlag(types.ChunkFlagRecommended),
},
}, nil
}
// 分别索引模式:为每个问题创建独立的索引项
indexInfoList := make([]*types.IndexInfo, 0)
// 标准问索引项
standardContent := meta.StandardQuestion
if indexMode == types.FAQIndexModeQuestionAnswer && len(meta.Answers) > 0 {
var builder strings.Builder
builder.WriteString(meta.StandardQuestion)
for _, ans := range meta.Answers {
builder.WriteString("\n")
builder.WriteString(ans)
}
standardContent = builder.String()
}
indexInfoList = append(indexInfoList, &types.IndexInfo{
Content: standardContent,
SourceID: chunk.ID,
SourceType: types.ChunkSourceType,
ChunkID: chunk.ID,
KnowledgeID: chunk.KnowledgeID,
KnowledgeBaseID: chunk.KnowledgeBaseID,
KnowledgeType: types.KnowledgeTypeFAQ,
TagID: chunk.TagID,
IsEnabled: chunk.IsEnabled,
IsRecommended: chunk.Flags.HasFlag(types.ChunkFlagRecommended),
})
// 每个相似问创建一个索引项
for i, similarQ := range meta.SimilarQuestions {
similarContent := similarQ
if indexMode == types.FAQIndexModeQuestionAnswer && len(meta.Answers) > 0 {
var builder strings.Builder
builder.WriteString(similarQ)
for _, ans := range meta.Answers {
builder.WriteString("\n")
builder.WriteString(ans)
}
similarContent = builder.String()
}
sourceID := fmt.Sprintf("%s-%d", chunk.ID, i)
indexInfoList = append(indexInfoList, &types.IndexInfo{
Content: similarContent,
SourceID: sourceID,
SourceType: types.ChunkSourceType,
ChunkID: chunk.ID,
KnowledgeID: chunk.KnowledgeID,
KnowledgeBaseID: chunk.KnowledgeBaseID,
KnowledgeType: types.KnowledgeTypeFAQ,
TagID: chunk.TagID,
IsEnabled: chunk.IsEnabled,
IsRecommended: chunk.Flags.HasFlag(types.ChunkFlagRecommended),
})
}
return indexInfoList, nil
}
// incrementalIndexFAQEntry 增量更新FAQ条目的索引
// 只对内容变化的部分进行embedding计算和索引更新跳过未变化的部分
func (s *knowledgeService) incrementalIndexFAQEntry(
ctx context.Context,
kb *types.KnowledgeBase,
knowledge *types.Knowledge,
chunk *types.Chunk,
embeddingModel embedding.Embedder,
oldStandardQuestion string,
oldSimilarQuestions []string,
oldAnswers []string,
newMeta *types.FAQChunkMetadata,
) error {
indexStartTime := time.Now()
tenantInfo := ctx.Value(types.TenantInfoContextKey).(*types.Tenant)
retrieveEngine, err := retriever.NewCompositeRetrieveEngine(s.retrieveEngine, tenantInfo.GetEffectiveEngines())
if err != nil {
return err
}
indexMode := types.FAQIndexModeQuestionAnswer
if kb.FAQConfig != nil && kb.FAQConfig.IndexMode != "" {
indexMode = kb.FAQConfig.IndexMode
}
// 构建旧的内容(用于比较)
buildOldContent := func(question string) string {
if indexMode == types.FAQIndexModeQuestionAnswer && len(oldAnswers) > 0 {
var builder strings.Builder
builder.WriteString(question)
for _, ans := range oldAnswers {
builder.WriteString("\n")
builder.WriteString(ans)
}
return builder.String()
}
return question
}
// 构建新的内容
buildNewContent := func(question string) string {
if indexMode == types.FAQIndexModeQuestionAnswer && len(newMeta.Answers) > 0 {
var builder strings.Builder
builder.WriteString(question)
for _, ans := range newMeta.Answers {
builder.WriteString("\n")
builder.WriteString(ans)
}
return builder.String()
}
return question
}
// 检查答案是否变化
answersChanged := !slices.Equal(oldAnswers, newMeta.Answers)
// 收集需要更新的索引项
var indexInfoToUpdate []*types.IndexInfo
// 1. 检查标准问是否需要更新
oldStdContent := buildOldContent(oldStandardQuestion)
newStdContent := buildNewContent(newMeta.StandardQuestion)
if oldStdContent != newStdContent {
indexInfoToUpdate = append(indexInfoToUpdate, &types.IndexInfo{
Content: newStdContent,
SourceID: chunk.ID,
SourceType: types.ChunkSourceType,
ChunkID: chunk.ID,
KnowledgeID: chunk.KnowledgeID,
KnowledgeBaseID: chunk.KnowledgeBaseID,
KnowledgeType: types.KnowledgeTypeFAQ,
TagID: chunk.TagID,
IsEnabled: chunk.IsEnabled,
IsRecommended: chunk.Flags.HasFlag(types.ChunkFlagRecommended),
})
}
// 2. 检查每个相似问是否需要更新
oldCount := len(oldSimilarQuestions)
newCount := len(newMeta.SimilarQuestions)
for i, newQ := range newMeta.SimilarQuestions {
needUpdate := false
if i >= oldCount {
// 新增的相似问
needUpdate = true
} else {
// 已存在的相似问,检查内容是否变化
oldQ := oldSimilarQuestions[i]
if oldQ != newQ || answersChanged {
needUpdate = true
}
}
if needUpdate {
sourceID := fmt.Sprintf("%s-%d", chunk.ID, i)
indexInfoToUpdate = append(indexInfoToUpdate, &types.IndexInfo{
Content: buildNewContent(newQ),
SourceID: sourceID,
SourceType: types.ChunkSourceType,
ChunkID: chunk.ID,
KnowledgeID: chunk.KnowledgeID,
KnowledgeBaseID: chunk.KnowledgeBaseID,
KnowledgeType: types.KnowledgeTypeFAQ,
TagID: chunk.TagID,
IsEnabled: chunk.IsEnabled,
IsRecommended: chunk.Flags.HasFlag(types.ChunkFlagRecommended),
})
}
}
// 3. 删除多余的旧相似问索引
if oldCount > newCount {
sourceIDsToDelete := make([]string, 0, oldCount-newCount)
for i := newCount; i < oldCount; i++ {
sourceIDsToDelete = append(sourceIDsToDelete, fmt.Sprintf("%s-%d", chunk.ID, i))
}
logger.Debugf(ctx, "incrementalIndexFAQEntry: deleting %d obsolete source IDs", len(sourceIDsToDelete))
if delErr := retrieveEngine.DeleteBySourceIDList(ctx, sourceIDsToDelete, embeddingModel.GetDimensions(), types.KnowledgeTypeFAQ); delErr != nil {
logger.Warnf(ctx, "incrementalIndexFAQEntry: failed to delete obsolete source IDs: %v", delErr)
}
}
// 4. 批量索引需要更新的内容
if len(indexInfoToUpdate) > 0 {
logger.Debugf(ctx, "incrementalIndexFAQEntry: updating %d index entries (skipped %d unchanged)",
len(indexInfoToUpdate), 1+newCount-len(indexInfoToUpdate))
if err := retrieveEngine.BatchIndex(ctx, embeddingModel, indexInfoToUpdate); err != nil {
return err
}
} else {
logger.Debugf(ctx, "incrementalIndexFAQEntry: all %d entries unchanged, skipping index update", 1+newCount)
}
// 5. 更新 knowledge 记录
now := time.Now()
knowledge.UpdatedAt = now
knowledge.ProcessedAt = &now
if err := s.repo.UpdateKnowledge(ctx, knowledge); err != nil {
return err
}
totalDuration := time.Since(indexStartTime)
logger.Debugf(ctx, "incrementalIndexFAQEntry: completed in %v, updated %d/%d entries",
totalDuration, len(indexInfoToUpdate), 1+newCount)
return nil
}
func (s *knowledgeService) indexFAQChunks(ctx context.Context,
kb *types.KnowledgeBase, knowledge *types.Knowledge,
chunks []*types.Chunk, embeddingModel embedding.Embedder,
adjustStorage bool, needDelete bool,
) error {
if len(chunks) == 0 {
return nil
}
indexStartTime := time.Now()
logger.Debugf(ctx, "indexFAQChunks: starting to index %d chunks", len(chunks))
tenantInfo := ctx.Value(types.TenantInfoContextKey).(*types.Tenant)
retrieveEngine, err := retriever.NewCompositeRetrieveEngine(s.retrieveEngine, tenantInfo.GetEffectiveEngines())
if err != nil {
return err
}
// 构建索引信息
buildIndexInfoStartTime := time.Now()
indexInfo := make([]*types.IndexInfo, 0)
chunkIDs := make([]string, 0, len(chunks))
for _, chunk := range chunks {
infoList, err := s.buildFAQIndexInfoList(ctx, kb, chunk)
if err != nil {
return err
}
indexInfo = append(indexInfo, infoList...)
chunkIDs = append(chunkIDs, chunk.ID)
}
buildIndexInfoDuration := time.Since(buildIndexInfoStartTime)
logger.Debugf(
ctx,
"indexFAQChunks: built %d index info entries for %d chunks in %v",
len(indexInfo),
len(chunks),
buildIndexInfoDuration,
)
var size int64
if adjustStorage {
estimateStartTime := time.Now()
size = retrieveEngine.EstimateStorageSize(ctx, embeddingModel, indexInfo)
estimateDuration := time.Since(estimateStartTime)
logger.Debugf(ctx, "indexFAQChunks: estimated storage size %d bytes in %v", size, estimateDuration)
if tenantInfo.StorageQuota > 0 && tenantInfo.StorageUsed+size > tenantInfo.StorageQuota {
return types.NewStorageQuotaExceededError()
}
}
// 删除旧向量
var deleteDuration time.Duration
if needDelete {
deleteStartTime := time.Now()
if err := retrieveEngine.DeleteByChunkIDList(ctx, chunkIDs, embeddingModel.GetDimensions(), types.KnowledgeTypeFAQ); err != nil {
logger.Warnf(ctx, "Delete FAQ vectors failed: %v", err)
}
deleteDuration = time.Since(deleteStartTime)
if deleteDuration > 100*time.Millisecond {
logger.Debugf(ctx, "indexFAQChunks: deleted old vectors for %d chunks in %v", len(chunkIDs), deleteDuration)
}
}
// 批量索引(这里可能是性能瓶颈)
batchIndexStartTime := time.Now()
if err := retrieveEngine.BatchIndex(ctx, embeddingModel, indexInfo); err != nil {
return err
}
batchIndexDuration := time.Since(batchIndexStartTime)
logger.Debugf(ctx, "indexFAQChunks: batch indexed %d index info entries in %v (avg: %v per entry)",
len(indexInfo), batchIndexDuration, batchIndexDuration/time.Duration(len(indexInfo)))
if adjustStorage && size > 0 {
adjustStartTime := time.Now()
if err := s.tenantRepo.AdjustStorageUsed(ctx, tenantInfo.ID, size); err == nil {
tenantInfo.StorageUsed += size
}
knowledge.StorageSize += size
adjustDuration := time.Since(adjustStartTime)
if adjustDuration > 50*time.Millisecond {
logger.Debugf(ctx, "indexFAQChunks: adjusted storage in %v", adjustDuration)
}
}
updateStartTime := time.Now()
now := time.Now()
knowledge.UpdatedAt = now
knowledge.ProcessedAt = &now
err = s.repo.UpdateKnowledge(ctx, knowledge)
updateDuration := time.Since(updateStartTime)
if updateDuration > 50*time.Millisecond {
logger.Debugf(ctx, "indexFAQChunks: updated knowledge in %v", updateDuration)
}
totalDuration := time.Since(indexStartTime)
logger.Debugf(
ctx,
"indexFAQChunks: completed indexing %d chunks in %v (build: %v, delete: %v, batchIndex: %v, update: %v)",
len(chunks),
totalDuration,
buildIndexInfoDuration,
deleteDuration,
batchIndexDuration,
updateDuration,
)
return err
}
func (s *knowledgeService) deleteFAQChunkVectors(ctx context.Context,
kb *types.KnowledgeBase, knowledge *types.Knowledge, chunks []*types.Chunk,
) error {
if len(chunks) == 0 {
return nil
}
embeddingModel, err := s.modelService.GetEmbeddingModel(ctx, kb.EmbeddingModelID)
if err != nil {
return err
}
tenantInfo := ctx.Value(types.TenantInfoContextKey).(*types.Tenant)
retrieveEngine, err := retriever.NewCompositeRetrieveEngine(s.retrieveEngine, tenantInfo.GetEffectiveEngines())
if err != nil {
return err
}
indexInfo := make([]*types.IndexInfo, 0)
chunkIDs := make([]string, 0, len(chunks))
for _, chunk := range chunks {
infoList, err := s.buildFAQIndexInfoList(ctx, kb, chunk)
if err != nil {
return err
}
indexInfo = append(indexInfo, infoList...)
chunkIDs = append(chunkIDs, chunk.ID)
}
size := retrieveEngine.EstimateStorageSize(ctx, embeddingModel, indexInfo)
if err := retrieveEngine.DeleteByChunkIDList(ctx, chunkIDs, embeddingModel.GetDimensions(), types.KnowledgeTypeFAQ); err != nil {
return err
}
if size > 0 {
if err := s.tenantRepo.AdjustStorageUsed(ctx, tenantInfo.ID, -size); err == nil {
tenantInfo.StorageUsed -= size
if tenantInfo.StorageUsed < 0 {
tenantInfo.StorageUsed = 0
}
}
if knowledge.StorageSize >= size {
knowledge.StorageSize -= size
} else {
knowledge.StorageSize = 0
}
}
knowledge.UpdatedAt = time.Now()
return s.repo.UpdateKnowledge(ctx, knowledge)
}
func ensureManualFileName(title string) string {
if title == "" {
return fmt.Sprintf("manual-%s%s", time.Now().Format("20060102-150405"), manualFileExtension)
}
trimmed := strings.TrimSpace(title)
if strings.HasSuffix(strings.ToLower(trimmed), manualFileExtension) {
return trimmed
}
return trimmed + manualFileExtension
}
// sanitizeManualDownloadFilename converts a knowledge title into a safe .md
// download filename. Characters that are illegal or dangerous in HTTP header
// values and file-system paths are removed or replaced; a blank result falls
// back to "untitled".
func sanitizeManualDownloadFilename(title string) string {
safeName := strings.NewReplacer(
"\n", "", "\r", "", "\t", "", "/", "-", "\\", "-", "\"", "'",
).Replace(title)
if strings.TrimSpace(safeName) == "" {
safeName = "untitled"
}
if !strings.HasSuffix(strings.ToLower(safeName), manualFileExtension) {
safeName += manualFileExtension
}
return safeName
}
func (s *knowledgeService) triggerManualProcessing(ctx context.Context,
kb *types.KnowledgeBase, knowledge *types.Knowledge, content string, doSync bool,
) {
clean := strings.TrimSpace(content)
if clean == "" {
return
}
// Resolve remote images: download http(s) images, upload to storage, replace URLs.
// This runs before chunking so that chunks contain stable provider:// URLs.
var resolvedImages []docparser.StoredImage
if s.imageResolver != nil {
fileSvc := s.resolveFileService(ctx, kb)
updatedContent, storedImages, resolveErr := s.imageResolver.ResolveRemoteImages(ctx, clean, fileSvc, knowledge.TenantID)
if resolveErr != nil {
logger.Warnf(ctx, "Remote image resolution partially failed: %v", resolveErr)
}
if len(storedImages) > 0 {
logger.Infof(ctx, "Resolved %d remote images for manual knowledge %s", len(storedImages), knowledge.ID)
clean = updatedContent
resolvedImages = storedImages
}
}
// Manual content is markdown - chunk directly with Go chunker
chunkCfg := chunker.SplitterConfig{
ChunkSize: kb.ChunkingConfig.ChunkSize,
ChunkOverlap: kb.ChunkingConfig.ChunkOverlap,
Separators: kb.ChunkingConfig.Separators,
}
if chunkCfg.ChunkSize <= 0 {
chunkCfg.ChunkSize = 512
}
if chunkCfg.ChunkOverlap <= 0 {
chunkCfg.ChunkOverlap = 50
}
if len(chunkCfg.Separators) == 0 {
chunkCfg.Separators = []string{"\n\n", "\n", "。"}
}
var parsed []types.ParsedChunk
opts := ProcessChunksOptions{
// When the KB has VLM enabled and we resolved remote images, pass them
// through so processChunks will enqueue image:multimodal tasks (OCR + caption).
EnableMultimodel: kb.IsMultimodalEnabled() && len(resolvedImages) > 0,
StoredImages: resolvedImages,
}
if kb.ChunkingConfig.EnableParentChild {
parentCfg, childCfg := buildParentChildConfigs(kb.ChunkingConfig, chunkCfg)
pcResult := chunker.SplitTextParentChild(clean, parentCfg, childCfg)
parsed = make([]types.ParsedChunk, len(pcResult.Children))
for i, c := range pcResult.Children {
parsed[i] = types.ParsedChunk{
Content: c.Content,
Seq: c.Seq,
Start: c.Start,
End: c.End,
ParentIndex: c.ParentIndex,
}
}
parentChunks := make([]types.ParsedParentChunk, len(pcResult.Parents))
for i, p := range pcResult.Parents {
parentChunks[i] = types.ParsedParentChunk{Content: p.Content, Seq: p.Seq, Start: p.Start, End: p.End}
}
opts.ParentChunks = parentChunks
} else {
splitChunks := chunker.SplitText(clean, chunkCfg)
parsed = make([]types.ParsedChunk, len(splitChunks))
for i, c := range splitChunks {
parsed[i] = types.ParsedChunk{
Content: c.Content,
Seq: c.Seq,
Start: c.Start,
End: c.End,
}
}
}
if doSync {
s.processChunks(ctx, kb, knowledge, parsed, opts)
return
}
newCtx := logger.CloneContext(ctx)
go s.processChunks(newCtx, kb, knowledge, parsed, opts)
}
func (s *knowledgeService) cleanupKnowledgeResources(ctx context.Context, knowledge *types.Knowledge) error {
logger.GetLogger(ctx).Infof("Cleaning knowledge resources before manual update, knowledge ID: %s", knowledge.ID)
var cleanupErr error
if knowledge.ParseStatus == types.ManualKnowledgeStatusDraft && knowledge.StorageSize == 0 {
// Draft without indexed data, skip cleanup.
return nil
}
tenantInfo := ctx.Value(types.TenantInfoContextKey).(*types.Tenant)
if knowledge.EmbeddingModelID != "" {
retrieveEngine, err := retriever.NewCompositeRetrieveEngine(
s.retrieveEngine,
tenantInfo.GetEffectiveEngines(),
)
if err != nil {
logger.GetLogger(ctx).WithField("error", err).Error("Failed to init retrieve engine during cleanup")
cleanupErr = errors.Join(cleanupErr, err)
} else {
embeddingModel, modelErr := s.modelService.GetEmbeddingModel(ctx, knowledge.EmbeddingModelID)
if modelErr != nil {
logger.GetLogger(ctx).WithField("error", modelErr).Error("Failed to get embedding model during cleanup")
cleanupErr = errors.Join(cleanupErr, modelErr)
} else {
if err := retrieveEngine.DeleteByKnowledgeIDList(ctx, []string{knowledge.ID}, embeddingModel.GetDimensions(), knowledge.Type); err != nil {
logger.GetLogger(ctx).WithField("error", err).Error("Failed to delete manual knowledge index")
cleanupErr = errors.Join(cleanupErr, err)
}
}
}
}
if err := s.chunkService.DeleteChunksByKnowledgeID(ctx, knowledge.ID); err != nil {
logger.GetLogger(ctx).WithField("error", err).Error("Failed to delete manual knowledge chunks")
cleanupErr = errors.Join(cleanupErr, err)
}
namespace := types.NameSpace{KnowledgeBase: knowledge.KnowledgeBaseID, Knowledge: knowledge.ID}
if err := s.graphEngine.DelGraph(ctx, []types.NameSpace{namespace}); err != nil {
logger.GetLogger(ctx).WithField("error", err).Error("Failed to delete manual knowledge graph data")
cleanupErr = errors.Join(cleanupErr, err)
}
if knowledge.StorageSize > 0 {
tenantInfo.StorageUsed -= knowledge.StorageSize
if tenantInfo.StorageUsed < 0 {
tenantInfo.StorageUsed = 0
}
if err := s.tenantRepo.AdjustStorageUsed(ctx, tenantInfo.ID, -knowledge.StorageSize); err != nil {
logger.GetLogger(ctx).WithField("error", err).Error("Failed to adjust storage usage during manual cleanup")
cleanupErr = errors.Join(cleanupErr, err)
}
knowledge.StorageSize = 0
}
return cleanupErr
}
func (s *knowledgeService) getVLMConfig(ctx context.Context, kb *types.KnowledgeBase) (*types.DocParserVLMConfig, error) {
if kb == nil {
return nil, nil
}
// 兼容老版本:直接使用 ModelName 和 BaseURL
if kb.VLMConfig.ModelName != "" && kb.VLMConfig.BaseURL != "" {
return &types.DocParserVLMConfig{
ModelName: kb.VLMConfig.ModelName,
BaseURL: kb.VLMConfig.BaseURL,
APIKey: kb.VLMConfig.APIKey,
InterfaceType: kb.VLMConfig.InterfaceType,
}, nil
}
// 新版本未启用或无模型ID时返回nil
if !kb.VLMConfig.Enabled || kb.VLMConfig.ModelID == "" {
return nil, nil
}
model, err := s.modelService.GetModelByID(ctx, kb.VLMConfig.ModelID)
if err != nil {
return nil, err
}
interfaceType := model.Parameters.InterfaceType
if interfaceType == "" {
interfaceType = "openai"
}
return &types.DocParserVLMConfig{
ModelName: model.Name,
BaseURL: model.Parameters.BaseURL,
APIKey: model.Parameters.APIKey,
InterfaceType: interfaceType,
}, nil
}
func (s *knowledgeService) buildStorageConfig(ctx context.Context, kb *types.KnowledgeBase) *types.DocParserStorageConfig {
provider := kb.GetStorageProvider()
if provider == "" {
provider = "local"
}
// Backward compatibility: if legacy cos_config has full params for the chosen provider, use them.
sc := &kb.StorageConfig
hasKBFull := false
switch provider {
case "cos":
hasKBFull = sc.SecretID != "" && sc.BucketName != ""
case "minio":
hasKBFull = sc.BucketName != ""
case "local":
hasKBFull = false
}
if hasKBFull {
logger.Infof(ctx, "[storage] buildStorageConfig use legacy kb config: kb=%s provider=%s bucket=%s path_prefix=%s",
kb.ID, provider, sc.BucketName, sc.PathPrefix)
return &types.DocParserStorageConfig{
Provider: strings.ToUpper(provider),
Region: sc.Region,
BucketName: sc.BucketName,
AccessKeyID: sc.SecretID,
SecretAccessKey: sc.SecretKey,
AppID: sc.AppID,
PathPrefix: sc.PathPrefix,
}
}
// Merge from tenant's StorageEngineConfig.
var out types.DocParserStorageConfig
out.Provider = strings.ToUpper(provider)
tenant, _ := ctx.Value(types.TenantInfoContextKey).(*types.Tenant)
if tenant != nil && tenant.StorageEngineConfig != nil {
sec := tenant.StorageEngineConfig
if sec.DefaultProvider != "" && provider == "" {
provider = strings.ToLower(strings.TrimSpace(sec.DefaultProvider))
out.Provider = strings.ToUpper(provider)
}
switch provider {
case "local":
if sec.Local != nil {
out.PathPrefix = sec.Local.PathPrefix
}
case "minio":
if sec.MinIO != nil {
out.BucketName = sec.MinIO.BucketName
out.PathPrefix = sec.MinIO.PathPrefix
if sec.MinIO.Mode == "remote" {
out.Endpoint = sec.MinIO.Endpoint
out.AccessKeyID = sec.MinIO.AccessKeyID
out.SecretAccessKey = sec.MinIO.SecretAccessKey
} else {
out.Endpoint = os.Getenv("MINIO_ENDPOINT")
out.AccessKeyID = os.Getenv("MINIO_ACCESS_KEY_ID")
out.SecretAccessKey = os.Getenv("MINIO_SECRET_ACCESS_KEY")
}
}
case "cos":
if sec.COS != nil {
out.Region = sec.COS.Region
out.BucketName = sec.COS.BucketName
out.AccessKeyID = sec.COS.SecretID
out.SecretAccessKey = sec.COS.SecretKey
out.AppID = sec.COS.AppID
out.PathPrefix = sec.COS.PathPrefix
}
}
}
logger.Infof(ctx, "[storage] buildStorageConfig use merged tenant/global config: kb=%s provider=%s bucket=%s path_prefix=%s endpoint=%s",
kb.ID, strings.ToLower(out.Provider), out.BucketName, out.PathPrefix, out.Endpoint)
return &out
}
// resolveFileService returns the FileService for the given knowledge base,
// based on the KB's StorageProviderConfig (or legacy StorageConfig.Provider) and the tenant's StorageEngineConfig.
// Falls back to the global fileSvc when no tenant-level storage config is found.
func (s *knowledgeService) resolveFileService(ctx context.Context, kb *types.KnowledgeBase) interfaces.FileService {
if kb == nil {
logger.Infof(ctx, "[storage] resolveFileService fallback default: kb=nil")
return s.fileSvc
}
provider := kb.GetStorageProvider()
tenant, _ := ctx.Value(types.TenantInfoContextKey).(*types.Tenant)
if provider == "" && tenant != nil && tenant.StorageEngineConfig != nil {
provider = strings.ToLower(strings.TrimSpace(tenant.StorageEngineConfig.DefaultProvider))
}
if provider == "" || tenant == nil || tenant.StorageEngineConfig == nil {
logger.Infof(ctx, "[storage] resolveFileService fallback default: kb=%s provider=%q tenant_cfg=%v",
kb.ID, provider, tenant != nil && tenant.StorageEngineConfig != nil)
return s.fileSvc
}
sec := tenant.StorageEngineConfig
baseDir := strings.TrimSpace(os.Getenv("LOCAL_STORAGE_BASE_DIR"))
svc, resolvedProvider, err := filesvc.NewFileServiceFromStorageConfig(provider, sec, baseDir)
if err != nil {
logger.Errorf(ctx, "Failed to create %s file service from tenant config: %v, falling back to default", provider, err)
return s.fileSvc
}
logger.Infof(ctx, "[storage] resolveFileService selected: kb=%s provider=%s", kb.ID, resolvedProvider)
return svc
}
// resolveFileServiceForPath is like resolveFileService but adds a safety check:
// if the resolved provider doesn't match what the filePath implies, fall back to
// the provider inferred from the file path. This protects historical data when
// tenant/KB config changes but files were stored under the old provider.
func (s *knowledgeService) resolveFileServiceForPath(ctx context.Context, kb *types.KnowledgeBase, filePath string) interfaces.FileService {
svc := s.resolveFileService(ctx, kb)
if filePath == "" {
return svc
}
inferred := types.InferStorageFromFilePath(filePath)
if inferred == "" {
return svc
}
configured := kb.GetStorageProvider()
if configured == "" {
tenant, _ := ctx.Value(types.TenantInfoContextKey).(*types.Tenant)
if tenant != nil && tenant.StorageEngineConfig != nil {
configured = strings.ToLower(strings.TrimSpace(tenant.StorageEngineConfig.DefaultProvider))
}
}
if configured == "" {
configured = strings.ToLower(strings.TrimSpace(os.Getenv("STORAGE_TYPE")))
}
if configured != "" && configured != inferred {
logger.Warnf(ctx, "[storage] FilePath format mismatch: configured=%s inferred=%s filePath=%s, using global fallback",
configured, inferred, filePath)
return s.fileSvc
}
return svc
}
func IsImageType(fileType string) bool {
switch fileType {
case "jpg", "jpeg", "png", "gif", "webp", "bmp", "svg", "tiff":
return true
default:
return false
}
}
// downloadFileFromURL downloads a remote file to a temp file and returns its binary content.
// payloadFileName and payloadFileType are in/out pointers: if they point to an empty string,
// the function resolves the value from Content-Disposition / URL path and writes it back.
// It does NOT perform SSRF validation — callers are responsible for that.
func downloadFileFromURL(ctx context.Context, fileURL string, payloadFileName, payloadFileType *string) ([]byte, error) {
httpClient := &http.Client{Timeout: 60 * time.Second}
req, err := http.NewRequestWithContext(ctx, http.MethodGet, fileURL, nil)
if err != nil {
return nil, fmt.Errorf("failed to create request for file URL: %w", err)
}
resp, err := httpClient.Do(req)
if err != nil {
return nil, fmt.Errorf("failed to download file from URL: %w", err)
}
defer resp.Body.Close()
if resp.StatusCode != http.StatusOK {
return nil, fmt.Errorf("remote server returned status %d", resp.StatusCode)
}
// Reject oversized files early via Content-Length
if contentLength := resp.ContentLength; contentLength > maxFileURLSize {
return nil, fmt.Errorf("file size %d bytes exceeds limit of %d bytes (10MB)", contentLength, maxFileURLSize)
}
// Resolve fileName: payload > Content-Disposition > URL path
if *payloadFileName == "" {
if cd := resp.Header.Get("Content-Disposition"); cd != "" {
*payloadFileName = extractFileNameFromContentDisposition(cd)
}
}
if *payloadFileName == "" {
*payloadFileName = extractFileNameFromURL(fileURL)
}
if *payloadFileType == "" && *payloadFileName != "" {
*payloadFileType = getFileType(*payloadFileName)
}
// Stream response body into a temp file, capped at maxFileURLSize
tmpFile, err := os.CreateTemp("", "weknora-fileurl-*")
if err != nil {
return nil, fmt.Errorf("failed to create temp file: %w", err)
}
tmpPath := tmpFile.Name()
defer os.Remove(tmpPath)
limiter := &io.LimitedReader{R: resp.Body, N: maxFileURLSize + 1}
written, err := io.Copy(tmpFile, limiter)
tmpFile.Close()
if err != nil {
return nil, fmt.Errorf("failed to write temp file: %w", err)
}
if written > maxFileURLSize {
return nil, fmt.Errorf("file size exceeds limit of 10MB")
}
contentBytes, err := os.ReadFile(tmpPath)
if err != nil {
return nil, fmt.Errorf("failed to read temp file: %w", err)
}
return contentBytes, nil
}
// ProcessManualUpdate handles Asynq manual knowledge update tasks.
// It performs cleanup of old indexes/chunks (when NeedCleanup is true) and re-indexes the content.
func (s *knowledgeService) ProcessManualUpdate(ctx context.Context, t *asynq.Task) error {
var payload types.ManualProcessPayload
if err := json.Unmarshal(t.Payload(), &payload); err != nil {
logger.Errorf(ctx, "failed to unmarshal manual process task payload: %v", err)
return nil
}
ctx = logger.WithRequestID(ctx, payload.RequestId)
ctx = logger.WithField(ctx, "manual_process", payload.KnowledgeID)
ctx = context.WithValue(ctx, types.TenantIDContextKey, payload.TenantID)
tenantInfo, err := s.tenantRepo.GetTenantByID(ctx, payload.TenantID)
if err != nil {
logger.Errorf(ctx, "ProcessManualUpdate: failed to get tenant: %v", err)
return nil
}
ctx = context.WithValue(ctx, types.TenantInfoContextKey, tenantInfo)
knowledge, err := s.repo.GetKnowledgeByID(ctx, payload.TenantID, payload.KnowledgeID)
if err != nil {
logger.Errorf(ctx, "ProcessManualUpdate: failed to get knowledge: %v", err)
return nil
}
if knowledge == nil {
logger.Warnf(ctx, "ProcessManualUpdate: knowledge not found: %s", payload.KnowledgeID)
return nil
}
// Skip if already completed or being deleted
if knowledge.ParseStatus == types.ParseStatusCompleted {
logger.Infof(ctx, "ProcessManualUpdate: already completed, skipping: %s", payload.KnowledgeID)
return nil
}
if knowledge.ParseStatus == types.ParseStatusDeleting {
logger.Infof(ctx, "ProcessManualUpdate: being deleted, skipping: %s", payload.KnowledgeID)
return nil
}
kb, err := s.kbService.GetKnowledgeBaseByID(ctx, payload.KnowledgeBaseID)
if err != nil {
logger.Errorf(ctx, "ProcessManualUpdate: failed to get knowledge base: %v", err)
knowledge.ParseStatus = "failed"
knowledge.ErrorMessage = fmt.Sprintf("failed to get knowledge base: %v", err)
knowledge.UpdatedAt = time.Now()
s.repo.UpdateKnowledge(ctx, knowledge)
return nil
}
// Update status to processing
knowledge.ParseStatus = "processing"
knowledge.UpdatedAt = time.Now()
if err := s.repo.UpdateKnowledge(ctx, knowledge); err != nil {
logger.Errorf(ctx, "ProcessManualUpdate: failed to update status to processing: %v", err)
return nil
}
// Cleanup old resources (indexes, chunks, graph) for update operations
if payload.NeedCleanup {
if err := s.cleanupKnowledgeResources(ctx, knowledge); err != nil {
logger.ErrorWithFields(ctx, err, map[string]interface{}{
"knowledge_id": payload.KnowledgeID,
})
knowledge.ParseStatus = "failed"
knowledge.ErrorMessage = fmt.Sprintf("failed to cleanup old resources: %v", err)
knowledge.UpdatedAt = time.Now()
s.repo.UpdateKnowledge(ctx, knowledge)
return nil
}
}
// Run manual processing (image resolution + chunking + embedding) synchronously within the worker
s.triggerManualProcessing(ctx, kb, knowledge, payload.Content, true)
return nil
}
// ProcessDocument handles Asynq document processing tasks
func (s *knowledgeService) ProcessDocument(ctx context.Context, t *asynq.Task) error {
var payload types.DocumentProcessPayload
if err := json.Unmarshal(t.Payload(), &payload); err != nil {
logger.Errorf(ctx, "failed to unmarshal document process task payload: %v", err)
return nil
}
ctx = logger.WithRequestID(ctx, payload.RequestId)
ctx = logger.WithField(ctx, "document_process", payload.KnowledgeID)
ctx = context.WithValue(ctx, types.TenantIDContextKey, payload.TenantID)
// 获取任务重试信息,用于判断是否是最后一次重试
retryCount, _ := asynq.GetRetryCount(ctx)
maxRetry, _ := asynq.GetMaxRetry(ctx)
isLastRetry := retryCount >= maxRetry
tenantInfo, err := s.tenantRepo.GetTenantByID(ctx, payload.TenantID)
if err != nil {
logger.Errorf(ctx, "failed to get tenant: %v", err)
return nil
}
ctx = context.WithValue(ctx, types.TenantInfoContextKey, tenantInfo)
logger.Infof(ctx, "Processing document task: knowledge_id=%s, file_path=%s, retry=%d/%d",
payload.KnowledgeID, payload.FilePath, retryCount, maxRetry)
// 幂等性检查获取knowledge记录
knowledge, err := s.repo.GetKnowledgeByID(ctx, payload.TenantID, payload.KnowledgeID)
if err != nil {
logger.Errorf(ctx, "failed to get knowledge: %v", err)
return nil
}
if knowledge == nil {
return nil
}
// 检查是否正在删除 - 如果是则直接退出,避免与删除操作冲突
if knowledge.ParseStatus == types.ParseStatusDeleting {
logger.Infof(ctx, "Knowledge is being deleted, aborting processing: %s", payload.KnowledgeID)
return nil
}
// 检查任务状态 - 幂等性处理
if knowledge.ParseStatus == types.ParseStatusCompleted {
logger.Infof(ctx, "Document already completed, skipping: %s", payload.KnowledgeID)
return nil // 幂等:已完成的任务直接返回
}
if knowledge.ParseStatus == types.ParseStatusFailed {
// 检查是否可恢复(例如:超时、临时错误等)
// 对于不可恢复的错误,直接返回
logger.Warnf(
ctx,
"Document processing previously failed: %s, error: %s",
payload.KnowledgeID,
knowledge.ErrorMessage,
)
// 这里可以根据错误类型判断是否可恢复,暂时允许重试
}
// 检查是否有部分处理有chunks但状态不是completed
if knowledge.ParseStatus != "completed" && knowledge.ParseStatus != "pending" &&
knowledge.ParseStatus != "processing" {
// 状态异常,记录日志但继续处理
logger.Warnf(ctx, "Unexpected parse status: %s for knowledge: %s", knowledge.ParseStatus, payload.KnowledgeID)
}
// 获取知识库信息
kb, err := s.kbService.GetKnowledgeBaseByID(ctx, payload.KnowledgeBaseID)
if err != nil {
logger.Errorf(ctx, "failed to get knowledge base: %v", err)
knowledge.ParseStatus = "failed"
knowledge.ErrorMessage = fmt.Sprintf("failed to get knowledge base: %v", err)
knowledge.UpdatedAt = time.Now()
s.repo.UpdateKnowledge(ctx, knowledge)
return nil
}
knowledge.ParseStatus = "processing"
knowledge.UpdatedAt = time.Now()
if err := s.repo.UpdateKnowledge(ctx, knowledge); err != nil {
logger.Errorf(ctx, "failed to update knowledge status to processing: %v", err)
return nil
}
// 检查多模态配置(仅对文件导入)
if payload.FilePath != "" && !payload.EnableMultimodel && IsImageType(payload.FileType) {
logger.GetLogger(ctx).WithField("knowledge_id", knowledge.ID).
WithField("error", ErrImageNotParse).Errorf("processDocument image without enable multimodel")
knowledge.ParseStatus = "failed"
knowledge.ErrorMessage = ErrImageNotParse.Error()
knowledge.UpdatedAt = time.Now()
s.repo.UpdateKnowledge(ctx, knowledge)
return nil
}
// New pipeline: convert -> store images -> chunk -> vectorize -> multimodal tasks
var convertResult *types.ReadResult
var chunks []types.ParsedChunk
if payload.FileURL != "" {
// file_url import: SSRF re-check (防 DNS 重绑定), download, persist, then delegate to convert()
if safe, reason := secutils.IsSSRFSafeURL(payload.FileURL); !safe {
logger.Errorf(ctx, "File URL rejected for SSRF protection in ProcessDocument: %s, reason: %s", payload.FileURL, reason)
knowledge.ParseStatus = "failed"
knowledge.ErrorMessage = "File URL is not allowed for security reasons"
knowledge.UpdatedAt = time.Now()
s.repo.UpdateKnowledge(ctx, knowledge)
return nil
}
resolvedFileName := payload.FileName
resolvedFileType := payload.FileType
contentBytes, err := downloadFileFromURL(ctx, payload.FileURL, &resolvedFileName, &resolvedFileType)
if err != nil {
logger.Errorf(ctx, "Failed to download file from URL: %s, error: %v", payload.FileURL, err)
if isLastRetry {
knowledge.ParseStatus = "failed"
knowledge.ErrorMessage = err.Error()
knowledge.UpdatedAt = time.Now()
s.repo.UpdateKnowledge(ctx, knowledge)
}
return fmt.Errorf("failed to download file from URL: %w", err)
}
if resolvedFileType != "" && !allowedFileURLExtensions[strings.ToLower(resolvedFileType)] {
logger.Errorf(ctx, "Unsupported file type resolved from file URL: %s", resolvedFileType)
knowledge.ParseStatus = "failed"
knowledge.ErrorMessage = fmt.Sprintf("unsupported file type: %s", resolvedFileType)
knowledge.UpdatedAt = time.Now()
s.repo.UpdateKnowledge(ctx, knowledge)
return nil
}
if resolvedFileName != "" && knowledge.FileName == "" {
knowledge.FileName = resolvedFileName
}
if resolvedFileType != "" && knowledge.FileType == "" {
knowledge.FileType = resolvedFileType
s.repo.UpdateKnowledge(ctx, knowledge)
}
fileSvc := s.resolveFileService(ctx, kb)
filePath, err := fileSvc.SaveBytes(ctx, contentBytes, payload.TenantID, resolvedFileName, true)
if err != nil {
if isLastRetry {
knowledge.ParseStatus = "failed"
knowledge.ErrorMessage = err.Error()
knowledge.UpdatedAt = time.Now()
s.repo.UpdateKnowledge(ctx, knowledge)
}
return fmt.Errorf("failed to save downloaded file: %w", err)
}
payload.FilePath = filePath
payload.FileName = resolvedFileName
payload.FileType = resolvedFileType
convertResult, err = s.convert(ctx, payload, kb, knowledge, isLastRetry)
if err != nil {
return err
}
if convertResult == nil {
return nil
}
} else if payload.URL != "" {
// URL import
convertResult, err = s.convert(ctx, payload, kb, knowledge, isLastRetry)
if err != nil {
return err
}
if convertResult == nil {
return nil
}
} else if len(payload.Passages) > 0 {
// Text passage import - direct chunking, no conversion needed
passageChunks := make([]types.ParsedChunk, 0, len(payload.Passages))
start, end := 0, 0
for i, p := range payload.Passages {
if p == "" {
continue
}
end += len([]rune(p))
passageChunks = append(passageChunks, types.ParsedChunk{
Content: p,
Seq: i,
Start: start,
End: end,
})
start = end
}
s.processChunks(ctx, kb, knowledge, passageChunks)
return nil
} else {
// File import
convertResult, err = s.convert(ctx, payload, kb, knowledge, isLastRetry)
if err != nil {
return err
}
if convertResult == nil {
return nil
}
}
// Step 2: Store images and update markdown references
var storedImages []docparser.StoredImage
if s.imageResolver != nil && convertResult != nil {
fileSvc := s.resolveFileService(ctx, kb)
tenantID, _ := ctx.Value(types.TenantIDContextKey).(uint64)
updatedMarkdown, images, resolveErr := s.imageResolver.ResolveAndStore(ctx, convertResult, fileSvc, tenantID)
if resolveErr != nil {
logger.Warnf(ctx, "Image resolution partially failed: %v", resolveErr)
}
if updatedMarkdown != "" {
convertResult.MarkdownContent = updatedMarkdown
}
storedImages = images
logger.Infof(ctx, "Resolved %d images for knowledge %s", len(storedImages), knowledge.ID)
}
// Step 3: Split into chunks using Go chunker
chunkCfg := chunker.SplitterConfig{
ChunkSize: kb.ChunkingConfig.ChunkSize,
ChunkOverlap: kb.ChunkingConfig.ChunkOverlap,
Separators: kb.ChunkingConfig.Separators,
}
if chunkCfg.ChunkSize <= 0 {
chunkCfg.ChunkSize = 512
}
if chunkCfg.ChunkOverlap <= 0 {
chunkCfg.ChunkOverlap = 50
}
if len(chunkCfg.Separators) == 0 {
chunkCfg.Separators = []string{"\n\n", "\n", "。"}
}
processOpts := ProcessChunksOptions{
EnableQuestionGeneration: payload.EnableQuestionGeneration,
QuestionCount: payload.QuestionCount,
EnableMultimodel: payload.EnableMultimodel,
StoredImages: storedImages,
}
if kb.ChunkingConfig.EnableParentChild {
parentCfg, childCfg := buildParentChildConfigs(kb.ChunkingConfig, chunkCfg)
pcResult := chunker.SplitTextParentChild(convertResult.MarkdownContent, parentCfg, childCfg)
chunks = make([]types.ParsedChunk, len(pcResult.Children))
for i, c := range pcResult.Children {
chunks[i] = types.ParsedChunk{
Content: c.Content,
Seq: c.Seq,
Start: c.Start,
End: c.End,
ParentIndex: c.ParentIndex,
}
}
parentChunks := make([]types.ParsedParentChunk, len(pcResult.Parents))
for i, p := range pcResult.Parents {
parentChunks[i] = types.ParsedParentChunk{Content: p.Content, Seq: p.Seq, Start: p.Start, End: p.End}
}
processOpts.ParentChunks = parentChunks
logger.Infof(ctx, "Split document into %d parent + %d child chunks for knowledge %s",
len(pcResult.Parents), len(pcResult.Children), knowledge.ID)
} else {
splitChunks := chunker.SplitText(convertResult.MarkdownContent, chunkCfg)
chunks = make([]types.ParsedChunk, len(splitChunks))
for i, c := range splitChunks {
chunks[i] = types.ParsedChunk{
Content: c.Content,
Seq: c.Seq,
Start: c.Start,
End: c.End,
}
}
logger.Infof(ctx, "Split document into %d chunks for knowledge %s", len(chunks), knowledge.ID)
}
// Step 4: Process chunks (vectorize + index + enqueue async tasks)
s.processChunks(ctx, kb, knowledge, chunks, processOpts)
return nil
}
// convert handles both file and URL reading using a unified ReadRequest.
func (s *knowledgeService) convert(
ctx context.Context,
payload types.DocumentProcessPayload,
kb *types.KnowledgeBase,
knowledge *types.Knowledge,
isLastRetry bool,
) (*types.ReadResult, error) {
isURL := payload.URL != ""
fileType := payload.FileType
overrides := s.getParserEngineOverridesFromContext(ctx)
if isURL {
if safe, reason := secutils.IsSSRFSafeURL(payload.URL); !safe {
logger.Errorf(ctx, "URL rejected for SSRF protection: %s, reason: %s", payload.URL, reason)
knowledge.ParseStatus = "failed"
knowledge.ErrorMessage = "URL is not allowed for security reasons"
knowledge.UpdatedAt = time.Now()
s.repo.UpdateKnowledge(ctx, knowledge)
return nil, nil
}
}
parserEngine := kb.ChunkingConfig.ResolveParserEngine(fileType)
if isURL {
parserEngine = kb.ChunkingConfig.ResolveParserEngine("url")
}
logger.Infof(ctx, "[convert] kb=%s fileType=%s isURL=%v engine=%q rules=%+v",
kb.ID, fileType, isURL, parserEngine, kb.ChunkingConfig.ParserEngineRules)
var reader interfaces.DocReader = s.resolveDocReader(parserEngine, fileType, isURL, overrides)
if reader == nil {
knowledge.ParseStatus = "failed"
knowledge.ErrorMessage = "Document parsing service is not configured. Please use text/paragraph import or set DOCREADER_ADDR."
knowledge.UpdatedAt = time.Now()
s.repo.UpdateKnowledge(ctx, knowledge)
return nil, nil
}
req := &types.ReadRequest{
URL: payload.URL,
Title: knowledge.Title,
ParserEngine: parserEngine,
RequestID: payload.RequestId,
ParserEngineOverrides: overrides,
}
if !isURL {
fileReader, err := s.resolveFileServiceForPath(ctx, kb, payload.FilePath).GetFile(ctx, payload.FilePath)
if err != nil {
return s.failKnowledge(ctx, knowledge, isLastRetry, "failed to get file: %v", err)
}
defer fileReader.Close()
contentBytes, err := io.ReadAll(fileReader)
if err != nil {
return s.failKnowledge(ctx, knowledge, isLastRetry, "failed to read file: %v", err)
}
req.FileContent = contentBytes
req.FileName = payload.FileName
req.FileType = fileType
}
result, err := reader.Read(ctx, req)
if err != nil {
return s.failKnowledge(ctx, knowledge, isLastRetry, "document read failed: %v", err)
}
if result.Error != "" {
knowledge.ParseStatus = "failed"
knowledge.ErrorMessage = result.Error
knowledge.UpdatedAt = time.Now()
s.repo.UpdateKnowledge(ctx, knowledge)
return nil, nil
}
return result, nil
}
// resolveDocReader returns the appropriate DocReader for the given engine.
// Returns nil when the required service is unavailable.
func (s *knowledgeService) resolveDocReader(engine, fileType string, isURL bool, overrides map[string]string) interfaces.DocReader {
switch engine {
case docparser.SimpleEngineName:
return &docparser.SimpleFormatReader{}
case "mineru":
return docparser.NewMinerUReader(overrides)
case "mineru_cloud":
return docparser.NewMinerUCloudReader(overrides)
case "builtin":
// 明确指定使用 builtin 引擎docreader不使用 simple format 兜底
return s.documentReader
default:
// 未指定引擎时的兜底逻辑simple format 使用 Go 原生处理,其他使用 docreader
if !isURL && docparser.IsSimpleFormat(fileType) {
return &docparser.SimpleFormatReader{}
}
return s.documentReader
}
}
// failKnowledge marks knowledge as failed (only on last retry) and returns an error.
func (s *knowledgeService) failKnowledge(
ctx context.Context,
knowledge *types.Knowledge,
isLastRetry bool,
format string,
args ...interface{},
) (*types.ReadResult, error) {
errMsg := fmt.Sprintf(format, args...)
if isLastRetry {
knowledge.ParseStatus = "failed"
knowledge.ErrorMessage = errMsg
knowledge.UpdatedAt = time.Now()
s.repo.UpdateKnowledge(ctx, knowledge)
}
return nil, fmt.Errorf(format, args...)
}
// enqueueImageMultimodalTasks enqueues asynq tasks for multimodal image processing.
func (s *knowledgeService) enqueueImageMultimodalTasks(
ctx context.Context,
knowledge *types.Knowledge,
kb *types.KnowledgeBase,
images []docparser.StoredImage,
chunks []types.ParsedChunk,
) {
if s.task == nil || len(images) == 0 {
return
}
for _, img := range images {
// Match image to the ParsedChunk whose content contains the image URL.
// ChunkID was populated by processChunks with the real DB UUID.
chunkID := ""
for _, c := range chunks {
if strings.Contains(c.Content, img.ServingURL) {
chunkID = c.ChunkID
break
}
}
if chunkID == "" && len(chunks) > 0 {
chunkID = chunks[0].ChunkID
}
payload := types.ImageMultimodalPayload{
TenantID: knowledge.TenantID,
KnowledgeID: knowledge.ID,
KnowledgeBaseID: kb.ID,
ChunkID: chunkID,
ImageURL: img.ServingURL,
EnableOCR: true,
EnableCaption: true,
}
payloadBytes, err := json.Marshal(payload)
if err != nil {
logger.Warnf(ctx, "Failed to marshal image multimodal payload: %v", err)
continue
}
task := asynq.NewTask(types.TypeImageMultimodal, payloadBytes)
if _, err := s.task.Enqueue(task); err != nil {
logger.Warnf(ctx, "Failed to enqueue image multimodal task for %s: %v", img.ServingURL, err)
} else {
logger.Infof(ctx, "Enqueued image:multimodal task for %s", img.ServingURL)
}
}
}
// ProcessFAQImport handles Asynq FAQ import tasks (including dry run mode)
func (s *knowledgeService) ProcessFAQImport(ctx context.Context, t *asynq.Task) error {
var payload types.FAQImportPayload
if err := json.Unmarshal(t.Payload(), &payload); err != nil {
logger.Errorf(ctx, "failed to unmarshal FAQ import task payload: %v", err)
return fmt.Errorf("failed to unmarshal task payload: %w", err)
}
ctx = logger.WithRequestID(ctx, uuid.New().String())
ctx = logger.WithField(ctx, "faq_import", payload.TaskID)
ctx = context.WithValue(ctx, types.TenantIDContextKey, payload.TenantID)
// 获取任务重试信息,用于判断是否是最后一次重试
retryCount, _ := asynq.GetRetryCount(ctx)
maxRetry, _ := asynq.GetMaxRetry(ctx)
isLastRetry := retryCount >= maxRetry
tenantInfo, err := s.tenantRepo.GetTenantByID(ctx, payload.TenantID)
if err != nil {
logger.Errorf(ctx, "failed to get tenant: %v", err)
return nil
}
ctx = context.WithValue(ctx, types.TenantInfoContextKey, tenantInfo)
// 如果 entries 存储在对象存储中,先下载
if payload.EntriesURL != "" && len(payload.Entries) == 0 {
logger.Infof(ctx, "Downloading FAQ entries from object storage: %s", payload.EntriesURL)
reader, err := s.fileSvc.GetFile(ctx, payload.EntriesURL)
if err != nil {
logger.Errorf(ctx, "Failed to download FAQ entries from object storage: %v", err)
return fmt.Errorf("failed to download entries: %w", err)
}
defer reader.Close()
entriesData, err := io.ReadAll(reader)
if err != nil {
logger.Errorf(ctx, "Failed to read FAQ entries data: %v", err)
return fmt.Errorf("failed to read entries data: %w", err)
}
var entries []types.FAQEntryPayload
if err := json.Unmarshal(entriesData, &entries); err != nil {
logger.Errorf(ctx, "Failed to unmarshal FAQ entries: %v", err)
return fmt.Errorf("failed to unmarshal entries: %w", err)
}
payload.Entries = entries
logger.Infof(ctx, "Downloaded %d FAQ entries from object storage", len(entries))
}
logger.Infof(ctx, "Processing FAQ import task: task_id=%s, kb_id=%s, total_entries=%d, dry_run=%v, retry=%d/%d",
payload.TaskID, payload.KBID, len(payload.Entries), payload.DryRun, retryCount, maxRetry)
// 保存原始总数量
originalTotalEntries := len(payload.Entries)
// 初始化进度
// 检查是否已有验证结果(用于重试时跳过验证)
// 注意:必须在保存新 progress 之前查询,否则会被覆盖
existingProgress, _ := s.GetFAQImportProgress(ctx, payload.TaskID)
progress := &types.FAQImportProgress{
TaskID: payload.TaskID,
KBID: payload.KBID,
KnowledgeID: payload.KnowledgeID,
Status: types.FAQImportStatusProcessing,
Progress: 0,
Total: originalTotalEntries,
Processed: 0,
SuccessCount: 0,
FailedCount: 0,
FailedEntries: make([]types.FAQFailedEntry, 0),
SuccessEntries: make([]types.FAQSuccessEntry, 0),
Message: "正在验证条目...",
CreatedAt: time.Now().Unix(),
UpdatedAt: time.Now().Unix(),
DryRun: payload.DryRun,
}
if err := s.saveFAQImportProgress(ctx, progress); err != nil {
logger.Warnf(ctx, "Failed to save initial FAQ import progress: %v", err)
}
var validEntryIndices []int
if existingProgress != nil && len(existingProgress.ValidEntryIndices) > 0 {
// 重试时直接使用之前的验证结果
validEntryIndices = existingProgress.ValidEntryIndices
progress.FailedCount = existingProgress.FailedCount
progress.FailedEntries = existingProgress.FailedEntries
logger.Infof(ctx, "Reusing previous validation result: valid=%d, failed=%d",
len(validEntryIndices), progress.FailedCount)
} else {
// 第一步:执行验证(无论是 dry run 还是 import 模式都需要验证)
validEntryIndices = s.executeFAQDryRunValidation(ctx, &payload, progress)
// 保存验证通过的索引,用于重试时跳过验证
progress.ValidEntryIndices = validEntryIndices
if err := s.saveFAQImportProgress(ctx, progress); err != nil {
logger.Warnf(ctx, "Failed to save validation result: %v", err)
}
logger.Infof(ctx, "FAQ validation completed: total=%d, valid=%d, failed=%d",
originalTotalEntries, len(validEntryIndices), progress.FailedCount)
}
// Dry run 模式:验证完成后直接返回结果
if payload.DryRun {
return s.finalizeFAQValidation(ctx, &payload, progress, originalTotalEntries)
}
// Import 模式:检查是否有有效条目需要导入
if len(validEntryIndices) == 0 {
// 没有有效条目,直接完成
return s.finalizeFAQValidation(ctx, &payload, progress, originalTotalEntries)
}
// 提取有效的条目
validEntries := make([]types.FAQEntryPayload, 0, len(validEntryIndices))
for _, idx := range validEntryIndices {
validEntries = append(validEntries, payload.Entries[idx])
}
// 更新进度消息
progress.Message = fmt.Sprintf("验证完成,开始导入 %d 条有效数据...", len(validEntries))
progress.UpdatedAt = time.Now().Unix()
if err := s.saveFAQImportProgress(ctx, progress); err != nil {
logger.Warnf(ctx, "Failed to update FAQ import progress: %v", err)
}
// 幂等性检查获取knowledge记录FAQ任务使用knowledge ID作为taskID
knowledge, err := s.repo.GetKnowledgeByID(ctx, payload.TenantID, payload.KnowledgeID)
if err != nil {
logger.Errorf(ctx, "failed to get FAQ knowledge: %v", err)
return nil
}
if knowledge == nil {
return nil
}
kb, err := s.kbService.GetKnowledgeBaseByID(ctx, payload.KBID)
if err != nil {
logger.Errorf(ctx, "Failed to get knowledge base: %v", err)
// 如果是最后一次重试,更新状态为失败
if isLastRetry {
if updateErr := s.updateFAQImportProgressStatus(ctx, payload.TaskID, types.FAQImportStatusFailed, 0, originalTotalEntries, 0, "获取知识库失败", err.Error()); updateErr != nil {
logger.Errorf(ctx, "Failed to update task status to failed: %v", updateErr)
}
}
s.cleanupFAQEntriesFileOnFinalFailure(ctx, payload.EntriesURL, retryCount, maxRetry)
return fmt.Errorf("failed to get knowledge base: %w", err)
}
// 检查任务状态 - 幂等性处理(复用之前获取的 existingProgress
var processedCount int
if existingProgress != nil {
if existingProgress.Status == types.FAQImportStatusCompleted {
logger.Infof(ctx, "FAQ import already completed, skipping: %s", payload.TaskID)
return nil // 幂等:已完成的任务直接返回
}
// 获取已处理的数量(注意:这是相对于 validEntries 的索引)
processedCount = existingProgress.Processed - progress.FailedCount // 已处理数 - 验证失败数 = 已导入的有效条目数
if processedCount < 0 {
processedCount = 0
}
logger.Infof(ctx, "Resuming FAQ import from progress: %d/%d (valid entries)", processedCount, len(validEntries))
}
// 幂等性处理清理可能已部分处理的chunks和索引数据
chunksDeleted, err := s.chunkRepo.DeleteUnindexedChunks(ctx, payload.TenantID, payload.KnowledgeID)
if err != nil {
logger.Errorf(ctx, "Failed to delete unindexed chunks: %v", err)
// 如果是最后一次重试,更新状态为失败
if isLastRetry {
if updateErr := s.updateFAQImportProgressStatus(ctx, payload.TaskID, types.FAQImportStatusFailed, 0, originalTotalEntries, 0, "清理未索引数据失败", err.Error()); updateErr != nil {
logger.Errorf(ctx, "Failed to update task status to failed: %v", updateErr)
}
}
s.cleanupFAQEntriesFileOnFinalFailure(ctx, payload.EntriesURL, retryCount, maxRetry)
return fmt.Errorf("failed to delete unindexed chunks: %w", err)
}
if len(chunksDeleted) > 0 {
logger.Infof(ctx, "Deleted unindexed chunks: %d", len(chunksDeleted))
// 删除索引数据
embeddingModel, err := s.modelService.GetEmbeddingModel(ctx, kb.EmbeddingModelID)
if err == nil {
retrieveEngine, err := retriever.NewCompositeRetrieveEngine(
s.retrieveEngine,
tenantInfo.GetEffectiveEngines(),
)
if err == nil {
chunkIDs := make([]string, 0, len(chunksDeleted))
for _, chunk := range chunksDeleted {
chunkIDs = append(chunkIDs, chunk.ID)
}
if err := retrieveEngine.DeleteByChunkIDList(ctx, chunkIDs, embeddingModel.GetDimensions(), types.KnowledgeTypeFAQ); err != nil {
logger.Warnf(ctx, "Failed to delete index data for chunks (may not exist): %v", err)
} else {
logger.Infof(ctx, "Successfully deleted index data for %d chunks", len(chunksDeleted))
}
}
}
}
// 如果已经处理了一部分有效条目,从该位置继续
entriesToImport := validEntries
importMode := payload.Mode
if processedCount > 0 && processedCount < len(validEntries) {
entriesToImport = validEntries[processedCount:]
// 重试场景下,如果之前已经处理了一部分数据,需要切换到 Append 模式
// 因为 Replace 模式的删除操作在第一次运行时已经执行过了
// 如果继续使用 Replace 模式calculateReplaceOperations 会将之前成功导入的数据标记为删除
// 导致数据丢失
if payload.Mode == types.FAQBatchModeReplace {
importMode = types.FAQBatchModeAppend
logger.Infof(ctx, "Switching to Append mode for retry, original mode was Replace")
}
logger.Infof(ctx, "Continuing FAQ import from entry %d, remaining: %d entries", processedCount, len(entriesToImport))
}
// 构建FAQBatchUpsertPayload使用验证通过的有效条目
faqPayload := &types.FAQBatchUpsertPayload{
Entries: entriesToImport,
Mode: importMode,
}
// 执行FAQ导入传入已处理的偏移量用于进度计算
if err := s.executeFAQImport(ctx, payload.TaskID, payload.KBID, faqPayload, payload.TenantID, progress.FailedCount+processedCount, progress); err != nil {
logger.Errorf(ctx, "FAQ import task failed: %s, error: %v", payload.TaskID, err)
// 如果是最后一次重试,更新状态为失败
if isLastRetry {
if updateErr := s.updateFAQImportProgressStatus(ctx, payload.TaskID, types.FAQImportStatusFailed, 0, originalTotalEntries, len(validEntries), "导入失败", err.Error()); updateErr != nil {
logger.Errorf(ctx, "Failed to update task status to failed: %v", updateErr)
}
}
s.cleanupFAQEntriesFileOnFinalFailure(ctx, payload.EntriesURL, retryCount, maxRetry)
return fmt.Errorf("FAQ import failed: %w", err)
}
// 任务成功完成
logger.Infof(ctx, "FAQ import task completed: %s, imported: %d, failed: %d",
payload.TaskID, len(validEntries), progress.FailedCount)
// 最终完成处理(生成失败条目 CSV 等)
return s.finalizeFAQValidation(ctx, &payload, progress, originalTotalEntries)
}
// finalizeFAQValidation 完成 FAQ 验证/导入任务,生成失败条目 CSV如果有
func (s *knowledgeService) finalizeFAQValidation(ctx context.Context, payload *types.FAQImportPayload,
progress *types.FAQImportProgress, originalTotalEntries int,
) error {
// 清理对象存储中的 entries 文件(如果有)
if payload.EntriesURL != "" {
if err := s.fileSvc.DeleteFile(ctx, payload.EntriesURL); err != nil {
logger.Warnf(ctx, "Failed to delete FAQ entries file from object storage: %v", err)
} else {
logger.Infof(ctx, "Deleted FAQ entries file from object storage: %s", payload.EntriesURL)
}
}
progress.UpdatedAt = time.Now().Unix()
// 如果有失败条目,生成 CSV 文件
if len(progress.FailedEntries) > 0 {
csvURL, err := s.generateFailedEntriesCSV(ctx, payload.TenantID, payload.TaskID, progress.FailedEntries)
if err != nil {
logger.Warnf(ctx, "Failed to generate failed entries CSV: %v", err)
} else {
progress.FailedEntriesURL = csvURL
progress.FailedEntries = nil // 清空内联数据,使用 URL
progress.Message += " (失败记录已导出为CSV)"
}
}
// 如果不是 dry run 模式,保存导入结果统计到数据库
if !payload.DryRun {
if err := s.saveFAQImportResultToDatabase(ctx, payload, progress, originalTotalEntries); err != nil {
logger.Warnf(ctx, "Failed to save FAQ import result to database: %v", err)
}
// 只有 replace 模式才清理未使用的 Tag
// append 模式不应删除用户预先创建的空标签
if payload.Mode == types.FAQBatchModeReplace {
deletedTags, err := s.tagRepo.DeleteUnusedTags(ctx, payload.TenantID, payload.KBID)
if err != nil {
logger.Warnf(ctx, "FAQ import task %s: failed to cleanup unused tags: %v", payload.TaskID, err)
} else if deletedTags > 0 {
logger.Infof(ctx, "FAQ import task %s: cleaned up %d unused tags after replace import", payload.TaskID, deletedTags)
}
}
}
// 使用 updateFAQImportProgressStatus 来确保正确清理 running key
// 但是需要先保存其他字段,因为 updateFAQImportProgressStatus 不会保存所有字段
if err := s.saveFAQImportProgress(ctx, progress); err != nil {
logger.Warnf(ctx, "Failed to save final FAQ import progress: %v", err)
}
// 然后调用状态更新来清理 running key
if err := s.updateFAQImportProgressStatus(ctx, payload.TaskID, types.FAQImportStatusCompleted,
100, originalTotalEntries, originalTotalEntries, progress.Message, ""); err != nil {
logger.Warnf(ctx, "Failed to update final FAQ import status: %v", err)
}
logger.Infof(ctx, "FAQ task completed: %s, dry_run=%v, success: %d, failed: %d",
payload.TaskID, payload.DryRun, progress.SuccessCount, progress.FailedCount)
return nil
}
const (
kbCloneProgressKeyPrefix = "kb_clone_progress:"
kbCloneProgressTTL = 24 * time.Hour
)
// getKBCloneProgressKey returns the Redis key for storing KB clone progress
func getKBCloneProgressKey(taskID string) string {
return kbCloneProgressKeyPrefix + taskID
}
const (
faqImportProgressKeyPrefix = "faq_import_progress:"
faqImportRunningKeyPrefix = "faq_import_running:"
faqImportProgressTTL = 3 * time.Hour
)
// getFAQImportProgressKey returns the Redis key for storing FAQ import progress
func getFAQImportProgressKey(taskID string) string {
return faqImportProgressKeyPrefix + taskID
}
// getFAQImportRunningKey returns the Redis key for storing running task ID by KB ID
func getFAQImportRunningKey(kbID string) string {
return faqImportRunningKeyPrefix + kbID
}
// saveFAQImportProgress saves the FAQ import progress to Redis
func (s *knowledgeService) saveFAQImportProgress(ctx context.Context, progress *types.FAQImportProgress) error {
key := getFAQImportProgressKey(progress.TaskID)
progress.UpdatedAt = time.Now().Unix()
data, err := json.Marshal(progress)
if err != nil {
return fmt.Errorf("failed to marshal FAQ import progress: %w", err)
}
return s.redisClient.Set(ctx, key, data, faqImportProgressTTL).Err()
}
// GetFAQImportProgress retrieves the progress of an FAQ import task
func (s *knowledgeService) GetFAQImportProgress(ctx context.Context, taskID string) (*types.FAQImportProgress, error) {
key := getFAQImportProgressKey(taskID)
data, err := s.redisClient.Get(ctx, key).Bytes()
if err != nil {
if errors.Is(err, redis.Nil) {
return nil, werrors.NewNotFoundError("FAQ import task not found")
}
return nil, fmt.Errorf("failed to get FAQ import progress from Redis: %w", err)
}
var progress types.FAQImportProgress
if err := json.Unmarshal(data, &progress); err != nil {
return nil, fmt.Errorf("failed to unmarshal FAQ import progress: %w", err)
}
// If task is completed, enrich with persisted result fields from database
if progress.Status == types.FAQImportStatusCompleted && progress.KnowledgeID != "" {
tenantID := ctx.Value(types.TenantIDContextKey).(uint64)
knowledge, err := s.repo.GetKnowledgeByID(ctx, tenantID, progress.KnowledgeID)
if err == nil && knowledge != nil {
if result, err := knowledge.GetLastFAQImportResult(); err == nil && result != nil {
progress.SkippedCount = result.SkippedCount
progress.ImportMode = result.ImportMode
progress.ImportedAt = result.ImportedAt
progress.DisplayStatus = result.DisplayStatus
progress.ProcessingTime = result.ProcessingTime
}
}
}
return &progress, nil
}
// UpdateLastFAQImportResultDisplayStatus updates the display status of FAQ import result
func (s *knowledgeService) UpdateLastFAQImportResultDisplayStatus(ctx context.Context, kbID string, displayStatus string) error {
// 验证displayStatus参数
if displayStatus != "open" && displayStatus != "close" {
return werrors.NewBadRequestError("invalid display status, must be 'open' or 'close'")
}
// 获取当前租户ID
tenantID := ctx.Value(types.TenantIDContextKey).(uint64)
// 查找FAQ类型的knowledge
knowledgeList, err := s.repo.ListKnowledgeByKnowledgeBaseID(ctx, tenantID, kbID)
if err != nil {
return fmt.Errorf("failed to list knowledge: %w", err)
}
// 查找FAQ类型的knowledge
var faqKnowledge *types.Knowledge
for _, k := range knowledgeList {
if k.Type == types.KnowledgeTypeFAQ {
faqKnowledge = k
break
}
}
if faqKnowledge == nil {
return werrors.NewNotFoundError("FAQ knowledge not found in this knowledge base")
}
// 解析当前的导入结果
result, err := faqKnowledge.GetLastFAQImportResult()
if err != nil {
return fmt.Errorf("failed to parse FAQ import result: %w", err)
}
if result == nil {
return werrors.NewNotFoundError("no FAQ import result found")
}
// 更新显示状态
result.DisplayStatus = displayStatus
// 保存更新后的结果
if err := faqKnowledge.SetLastFAQImportResult(result); err != nil {
return fmt.Errorf("failed to set FAQ import result: %w", err)
}
// 更新数据库
if err := s.repo.UpdateKnowledge(ctx, faqKnowledge); err != nil {
return fmt.Errorf("failed to update knowledge: %w", err)
}
return nil
}
// ProcessKBClone handles Asynq knowledge base clone tasks
func (s *knowledgeService) ProcessKBClone(ctx context.Context, t *asynq.Task) error {
var payload types.KBClonePayload
if err := json.Unmarshal(t.Payload(), &payload); err != nil {
return fmt.Errorf("failed to unmarshal KB clone payload: %w", err)
}
// Add tenant ID to context
ctx = context.WithValue(ctx, types.TenantIDContextKey, payload.TenantID)
// Get tenant info and add to context
tenantInfo, err := s.tenantRepo.GetTenantByID(ctx, payload.TenantID)
if err != nil {
logger.Errorf(ctx, "Failed to get tenant info: %v", err)
return fmt.Errorf("failed to get tenant info: %w", err)
}
ctx = context.WithValue(ctx, types.TenantInfoContextKey, tenantInfo)
// Check if this is the last retry
retryCount, _ := asynq.GetRetryCount(ctx)
maxRetry, _ := asynq.GetMaxRetry(ctx)
isLastRetry := retryCount >= maxRetry
logger.Infof(ctx, "Processing KB clone task: %s, source: %s, target: %s, retry: %d/%d",
payload.TaskID, payload.SourceID, payload.TargetID, retryCount, maxRetry)
// Helper function to handle errors - only mark as failed on last retry
handleError := func(progress *types.KBCloneProgress, err error, message string) {
if isLastRetry {
progress.Status = types.KBCloneStatusFailed
progress.Error = err.Error()
progress.Message = message
progress.UpdatedAt = time.Now().Unix()
_ = s.saveKBCloneProgress(ctx, progress)
}
}
// Update progress to processing
progress := &types.KBCloneProgress{
TaskID: payload.TaskID,
SourceID: payload.SourceID,
TargetID: payload.TargetID,
Status: types.KBCloneStatusProcessing,
Progress: 0,
Message: "Starting knowledge base clone...",
UpdatedAt: time.Now().Unix(),
}
if err := s.saveKBCloneProgress(ctx, progress); err != nil {
logger.Errorf(ctx, "Failed to update KB clone progress: %v", err)
}
// Get source and target knowledge bases
srcKB, dstKB, err := s.kbService.CopyKnowledgeBase(ctx, payload.SourceID, payload.TargetID)
if err != nil {
logger.Errorf(ctx, "Failed to copy knowledge base: %v", err)
handleError(progress, err, "Failed to copy knowledge base configuration")
return err
}
// Use different sync strategies based on knowledge base type
if srcKB.Type == types.KnowledgeBaseTypeFAQ {
return s.cloneFAQKnowledgeBase(ctx, srcKB, dstKB, progress, handleError)
}
// Document type: use Knowledge-level diff based on file_hash
addKnowledge, err := s.repo.AminusB(ctx, srcKB.TenantID, srcKB.ID, dstKB.TenantID, dstKB.ID)
if err != nil {
logger.Errorf(ctx, "Failed to get knowledge to add: %v", err)
handleError(progress, err, "Failed to calculate knowledge difference")
return err
}
delKnowledge, err := s.repo.AminusB(ctx, dstKB.TenantID, dstKB.ID, srcKB.TenantID, srcKB.ID)
if err != nil {
logger.Errorf(ctx, "Failed to get knowledge to delete: %v", err)
handleError(progress, err, "Failed to calculate knowledge difference")
return err
}
totalOperations := len(addKnowledge) + len(delKnowledge)
progress.Total = totalOperations
progress.Message = fmt.Sprintf("Found %d knowledge to add, %d to delete", len(addKnowledge), len(delKnowledge))
progress.UpdatedAt = time.Now().Unix()
_ = s.saveKBCloneProgress(ctx, progress)
logger.Infof(ctx, "Knowledge after update to add: %d, delete: %d", len(addKnowledge), len(delKnowledge))
processedCount := 0
batch := 10
// Delete knowledge in target that doesn't exist in source
g, gctx := errgroup.WithContext(ctx)
for ids := range slices.Chunk(delKnowledge, batch) {
g.Go(func() error {
err := s.DeleteKnowledgeList(gctx, ids)
if err != nil {
logger.Errorf(gctx, "delete partial knowledge %v: %v", ids, err)
return err
}
return nil
})
}
if err := g.Wait(); err != nil {
logger.Errorf(ctx, "delete total knowledge %d: %v", len(delKnowledge), err)
handleError(progress, err, "Failed to delete knowledge")
return err
}
processedCount += len(delKnowledge)
if totalOperations > 0 {
progress.Progress = processedCount * 100 / totalOperations
}
progress.Processed = processedCount
progress.Message = fmt.Sprintf("Deleted %d knowledge, cloning %d...", len(delKnowledge), len(addKnowledge))
progress.UpdatedAt = time.Now().Unix()
_ = s.saveKBCloneProgress(ctx, progress)
// Clone knowledge from source to target
g, gctx = errgroup.WithContext(ctx)
g.SetLimit(batch)
for _, knowledge := range addKnowledge {
g.Go(func() error {
srcKn, err := s.repo.GetKnowledgeByID(gctx, srcKB.TenantID, knowledge)
if err != nil {
logger.Errorf(gctx, "get knowledge %s: %v", knowledge, err)
return err
}
err = s.cloneKnowledge(gctx, srcKn, dstKB)
if err != nil {
logger.Errorf(gctx, "clone knowledge %s: %v", knowledge, err)
return err
}
// Update progress
processedCount++
if totalOperations > 0 {
progress.Progress = processedCount * 100 / totalOperations
}
progress.Processed = processedCount
progress.Message = fmt.Sprintf("Cloned %d/%d knowledge", processedCount-len(delKnowledge), len(addKnowledge))
progress.UpdatedAt = time.Now().Unix()
_ = s.saveKBCloneProgress(ctx, progress)
return nil
})
}
if err := g.Wait(); err != nil {
logger.Errorf(ctx, "add total knowledge %d: %v", len(addKnowledge), err)
handleError(progress, err, "Failed to clone knowledge")
return err
}
// Mark as completed
progress.Status = types.KBCloneStatusCompleted
progress.Progress = 100
progress.Processed = totalOperations
progress.Message = "Knowledge base clone completed successfully"
progress.UpdatedAt = time.Now().Unix()
if err := s.saveKBCloneProgress(ctx, progress); err != nil {
logger.Errorf(ctx, "Failed to update KB clone progress to completed: %v", err)
}
logger.Infof(ctx, "KB clone task completed: %s", payload.TaskID)
return nil
}
// cloneFAQKnowledgeBase handles FAQ knowledge base cloning with chunk-level incremental sync
func (s *knowledgeService) cloneFAQKnowledgeBase(
ctx context.Context,
srcKB, dstKB *types.KnowledgeBase,
progress *types.KBCloneProgress,
handleError func(*types.KBCloneProgress, error, string),
) error {
// Get source FAQ knowledge first (FAQ KB has exactly one Knowledge entry)
srcKnowledgeList, err := s.repo.ListKnowledgeByKnowledgeBaseID(ctx, srcKB.TenantID, srcKB.ID)
if err != nil {
logger.Errorf(ctx, "Failed to get source FAQ knowledge: %v", err)
handleError(progress, err, "Failed to get source FAQ knowledge")
return err
}
if len(srcKnowledgeList) == 0 {
// Source has no FAQ knowledge, nothing to clone
progress.Status = types.KBCloneStatusCompleted
progress.Progress = 100
progress.Message = "Source FAQ knowledge base is empty"
progress.UpdatedAt = time.Now().Unix()
_ = s.saveKBCloneProgress(ctx, progress)
return nil
}
srcKnowledge := srcKnowledgeList[0]
// Get chunk-level differences based on content_hash
chunksToAdd, chunksToDelete, err := s.chunkRepo.FAQChunkDiff(ctx, srcKB.TenantID, srcKB.ID, dstKB.TenantID, dstKB.ID)
if err != nil {
logger.Errorf(ctx, "Failed to calculate FAQ chunk difference: %v", err)
handleError(progress, err, "Failed to calculate FAQ chunk difference")
return err
}
totalOperations := len(chunksToAdd) + len(chunksToDelete)
progress.Total = totalOperations
progress.Message = fmt.Sprintf("Found %d FAQ entries to add, %d to delete", len(chunksToAdd), len(chunksToDelete))
progress.UpdatedAt = time.Now().Unix()
_ = s.saveKBCloneProgress(ctx, progress)
logger.Infof(ctx, "FAQ chunks to add: %d, delete: %d", len(chunksToAdd), len(chunksToDelete))
// If nothing to do, mark as completed
if totalOperations == 0 {
progress.Status = types.KBCloneStatusCompleted
progress.Progress = 100
progress.Message = "FAQ knowledge base is already in sync"
progress.UpdatedAt = time.Now().Unix()
_ = s.saveKBCloneProgress(ctx, progress)
return nil
}
// Get tenant info and initialize retrieve engine
tenantInfo := ctx.Value(types.TenantInfoContextKey).(*types.Tenant)
retrieveEngine, err := retriever.NewCompositeRetrieveEngine(s.retrieveEngine, tenantInfo.GetEffectiveEngines())
if err != nil {
logger.Errorf(ctx, "Failed to init retrieve engine: %v", err)
handleError(progress, err, "Failed to initialize retrieve engine")
return err
}
// Get embedding model
embeddingModel, err := s.modelService.GetEmbeddingModel(ctx, dstKB.EmbeddingModelID)
if err != nil {
logger.Errorf(ctx, "Failed to get embedding model: %v", err)
handleError(progress, err, "Failed to get embedding model")
return err
}
processedCount := 0
// Delete FAQ chunks that don't exist in source
if len(chunksToDelete) > 0 {
// Delete from vector store
if err := retrieveEngine.DeleteByChunkIDList(ctx, chunksToDelete, embeddingModel.GetDimensions(), types.KnowledgeTypeFAQ); err != nil {
logger.Errorf(ctx, "Failed to delete FAQ chunks from vector store: %v", err)
handleError(progress, err, "Failed to delete FAQ entries from vector store")
return err
}
// Delete from database
if err := s.chunkRepo.DeleteChunks(ctx, dstKB.TenantID, chunksToDelete); err != nil {
logger.Errorf(ctx, "Failed to delete FAQ chunks from database: %v", err)
handleError(progress, err, "Failed to delete FAQ entries from database")
return err
}
processedCount += len(chunksToDelete)
if totalOperations > 0 {
progress.Progress = processedCount * 100 / totalOperations
}
progress.Processed = processedCount
progress.Message = fmt.Sprintf("Deleted %d FAQ entries, adding %d...", len(chunksToDelete), len(chunksToAdd))
progress.UpdatedAt = time.Now().Unix()
_ = s.saveKBCloneProgress(ctx, progress)
}
// Get or create the FAQ knowledge entry in destination
dstKnowledge, err := s.getOrCreateFAQKnowledge(ctx, dstKB, srcKnowledge)
if err != nil {
logger.Errorf(ctx, "Failed to get or create FAQ knowledge: %v", err)
handleError(progress, err, "Failed to prepare FAQ knowledge entry")
return err
}
// Clone FAQ chunks from source to destination
batch := 50
tagIDMapping := map[string]string{} // srcTagID -> dstTagID
for i := 0; i < len(chunksToAdd); i += batch {
end := i + batch
if end > len(chunksToAdd) {
end = len(chunksToAdd)
}
batchIDs := chunksToAdd[i:end]
// Get source chunks
srcChunks, err := s.chunkRepo.ListChunksByID(ctx, srcKB.TenantID, batchIDs)
if err != nil {
logger.Errorf(ctx, "Failed to get source FAQ chunks: %v", err)
handleError(progress, err, "Failed to get source FAQ entries")
return err
}
// Create new chunks for destination
newChunks := make([]*types.Chunk, 0, len(srcChunks))
for _, srcChunk := range srcChunks {
// Map TagID to target knowledge base
targetTagID := ""
if srcChunk.TagID != "" {
if mappedTagID, ok := tagIDMapping[srcChunk.TagID]; ok {
targetTagID = mappedTagID
} else {
// Try to find or create the tag in target knowledge base
targetTagID = s.getOrCreateTagInTarget(ctx, srcKB.TenantID, dstKB.TenantID, dstKB.ID, srcChunk.TagID, tagIDMapping)
}
}
newChunk := &types.Chunk{
ID: uuid.New().String(),
TenantID: dstKB.TenantID,
KnowledgeID: dstKnowledge.ID,
KnowledgeBaseID: dstKB.ID,
TagID: targetTagID,
Content: srcChunk.Content,
ChunkIndex: srcChunk.ChunkIndex,
IsEnabled: srcChunk.IsEnabled,
Flags: srcChunk.Flags,
ChunkType: types.ChunkTypeFAQ,
Metadata: srcChunk.Metadata,
ContentHash: srcChunk.ContentHash,
ImageInfo: srcChunk.ImageInfo,
Status: int(types.ChunkStatusStored), // Initially stored, will be indexed
CreatedAt: time.Now(),
UpdatedAt: time.Now(),
}
newChunks = append(newChunks, newChunk)
}
// Save to database
if err := s.chunkRepo.CreateChunks(ctx, newChunks); err != nil {
logger.Errorf(ctx, "Failed to create FAQ chunks: %v", err)
handleError(progress, err, "Failed to create FAQ entries")
return err
}
// Index in vector store using existing method
// This will index standard question + similar questions based on FAQConfig
if err := s.indexFAQChunks(ctx, dstKB, dstKnowledge, newChunks, embeddingModel, false, false); err != nil {
logger.Errorf(ctx, "Failed to index FAQ chunks: %v", err)
handleError(progress, err, "Failed to index FAQ entries")
return err
}
// Update chunk status to indexed
for _, chunk := range newChunks {
chunk.Status = int(types.ChunkStatusIndexed)
}
if err := s.chunkService.UpdateChunks(ctx, newChunks); err != nil {
logger.Warnf(ctx, "Failed to update FAQ chunks status: %v", err)
// Don't fail the whole operation for status update failure
}
processedCount += len(batchIDs)
if totalOperations > 0 {
progress.Progress = processedCount * 100 / totalOperations
}
progress.Processed = processedCount
progress.Message = fmt.Sprintf("Added %d/%d FAQ entries", processedCount-len(chunksToDelete), len(chunksToAdd))
progress.UpdatedAt = time.Now().Unix()
_ = s.saveKBCloneProgress(ctx, progress)
}
// Mark as completed
progress.Status = types.KBCloneStatusCompleted
progress.Progress = 100
progress.Processed = totalOperations
progress.Message = "FAQ knowledge base clone completed successfully"
progress.UpdatedAt = time.Now().Unix()
if err := s.saveKBCloneProgress(ctx, progress); err != nil {
logger.Errorf(ctx, "Failed to update KB clone progress to completed: %v", err)
}
return nil
}
// getOrCreateFAQKnowledge gets or creates the FAQ knowledge entry for a knowledge base
// If srcKnowledge is provided, it will copy relevant fields from source when creating new knowledge
func (s *knowledgeService) getOrCreateFAQKnowledge(ctx context.Context, kb *types.KnowledgeBase, srcKnowledge *types.Knowledge) (*types.Knowledge, error) {
// FAQ knowledge base should have exactly one Knowledge entry
knowledgeList, err := s.repo.ListKnowledgeByKnowledgeBaseID(ctx, kb.TenantID, kb.ID)
if err != nil {
return nil, err
}
if len(knowledgeList) > 0 {
return knowledgeList[0], nil
}
// Create a new FAQ knowledge entry, copying from source if available
knowledge := &types.Knowledge{
ID: uuid.New().String(),
TenantID: kb.TenantID,
KnowledgeBaseID: kb.ID,
Type: types.KnowledgeTypeFAQ,
Title: "FAQ",
ParseStatus: "completed",
EnableStatus: "enabled",
EmbeddingModelID: kb.EmbeddingModelID,
}
// Copy additional fields from source knowledge if available
if srcKnowledge != nil {
knowledge.Title = srcKnowledge.Title
knowledge.Description = srcKnowledge.Description
knowledge.Source = srcKnowledge.Source
knowledge.Metadata = srcKnowledge.Metadata
}
if err := s.repo.CreateKnowledge(ctx, knowledge); err != nil {
return nil, err
}
return knowledge, nil
}
// saveKBCloneProgress saves the KB clone progress to Redis
func (s *knowledgeService) saveKBCloneProgress(ctx context.Context, progress *types.KBCloneProgress) error {
key := getKBCloneProgressKey(progress.TaskID)
data, err := json.Marshal(progress)
if err != nil {
return fmt.Errorf("failed to marshal progress: %w", err)
}
return s.redisClient.Set(ctx, key, data, kbCloneProgressTTL).Err()
}
// SaveKBCloneProgress saves the KB clone progress to Redis (public method for handler use)
func (s *knowledgeService) SaveKBCloneProgress(ctx context.Context, progress *types.KBCloneProgress) error {
return s.saveKBCloneProgress(ctx, progress)
}
// GetKBCloneProgress retrieves the progress of a knowledge base clone task
func (s *knowledgeService) GetKBCloneProgress(ctx context.Context, taskID string) (*types.KBCloneProgress, error) {
key := getKBCloneProgressKey(taskID)
data, err := s.redisClient.Get(ctx, key).Bytes()
if err != nil {
if errors.Is(err, redis.Nil) {
return nil, werrors.NewNotFoundError("KB clone task not found")
}
return nil, fmt.Errorf("failed to get progress from Redis: %w", err)
}
var progress types.KBCloneProgress
if err := json.Unmarshal(data, &progress); err != nil {
return nil, fmt.Errorf("failed to unmarshal progress: %w", err)
}
return &progress, nil
}
// ─── Knowledge Move ─────────────────────────────────────────────────────────
const (
knowledgeMoveProgressKeyPrefix = "knowledge_move_progress:"
knowledgeMoveProgressTTL = 24 * time.Hour
)
func getKnowledgeMoveProgressKey(taskID string) string {
return knowledgeMoveProgressKeyPrefix + taskID
}
func (s *knowledgeService) saveKnowledgeMoveProgress(ctx context.Context, progress *types.KnowledgeMoveProgress) error {
key := getKnowledgeMoveProgressKey(progress.TaskID)
data, err := json.Marshal(progress)
if err != nil {
return fmt.Errorf("failed to marshal move progress: %w", err)
}
return s.redisClient.Set(ctx, key, data, knowledgeMoveProgressTTL).Err()
}
// SaveKnowledgeMoveProgress saves the knowledge move progress to Redis (public method for handler use)
func (s *knowledgeService) SaveKnowledgeMoveProgress(ctx context.Context, progress *types.KnowledgeMoveProgress) error {
return s.saveKnowledgeMoveProgress(ctx, progress)
}
// GetKnowledgeMoveProgress retrieves the progress of a knowledge move task
func (s *knowledgeService) GetKnowledgeMoveProgress(ctx context.Context, taskID string) (*types.KnowledgeMoveProgress, error) {
key := getKnowledgeMoveProgressKey(taskID)
data, err := s.redisClient.Get(ctx, key).Bytes()
if err != nil {
if errors.Is(err, redis.Nil) {
return nil, werrors.NewNotFoundError("Knowledge move task not found")
}
return nil, fmt.Errorf("failed to get move progress from Redis: %w", err)
}
var progress types.KnowledgeMoveProgress
if err := json.Unmarshal(data, &progress); err != nil {
return nil, fmt.Errorf("failed to unmarshal move progress: %w", err)
}
return &progress, nil
}
// ProcessKnowledgeMove handles Asynq knowledge move tasks
func (s *knowledgeService) ProcessKnowledgeMove(ctx context.Context, t *asynq.Task) error {
var payload types.KnowledgeMovePayload
if err := json.Unmarshal(t.Payload(), &payload); err != nil {
return fmt.Errorf("failed to unmarshal knowledge move payload: %w", err)
}
// Add tenant ID to context
ctx = context.WithValue(ctx, types.TenantIDContextKey, payload.TenantID)
// Get tenant info and add to context
tenantInfo, err := s.tenantRepo.GetTenantByID(ctx, payload.TenantID)
if err != nil {
logger.Errorf(ctx, "ProcessKnowledgeMove: failed to get tenant info: %v", err)
return fmt.Errorf("failed to get tenant info: %w", err)
}
ctx = context.WithValue(ctx, types.TenantInfoContextKey, tenantInfo)
// Check if this is the last retry
retryCount, _ := asynq.GetRetryCount(ctx)
maxRetry, _ := asynq.GetMaxRetry(ctx)
isLastRetry := retryCount >= maxRetry
logger.Infof(ctx, "ProcessKnowledgeMove: task=%s, source=%s, target=%s, mode=%s, count=%d, retry=%d/%d",
payload.TaskID, payload.SourceKBID, payload.TargetKBID, payload.Mode, len(payload.KnowledgeIDs), retryCount, maxRetry)
// Helper function to handle errors - only mark as failed on last retry
handleError := func(progress *types.KnowledgeMoveProgress, err error, message string) {
if isLastRetry {
progress.Status = types.KBCloneStatusFailed
progress.Error = err.Error()
progress.Message = message
progress.UpdatedAt = time.Now().Unix()
_ = s.saveKnowledgeMoveProgress(ctx, progress)
}
}
// Update progress to processing
progress := &types.KnowledgeMoveProgress{
TaskID: payload.TaskID,
SourceKBID: payload.SourceKBID,
TargetKBID: payload.TargetKBID,
Status: types.KBCloneStatusProcessing,
Total: len(payload.KnowledgeIDs),
Progress: 0,
Message: "Starting knowledge move...",
UpdatedAt: time.Now().Unix(),
}
_ = s.saveKnowledgeMoveProgress(ctx, progress)
// Get source and target knowledge bases
sourceKB, err := s.kbService.GetKnowledgeBaseByID(ctx, payload.SourceKBID)
if err != nil {
handleError(progress, err, "Failed to get source knowledge base")
return err
}
targetKB, err := s.kbService.GetKnowledgeBaseByID(ctx, payload.TargetKBID)
if err != nil {
handleError(progress, err, "Failed to get target knowledge base")
return err
}
// Validate compatibility
if sourceKB.Type != targetKB.Type {
err := fmt.Errorf("type mismatch: source=%s, target=%s", sourceKB.Type, targetKB.Type)
handleError(progress, err, "Source and target knowledge bases must be the same type")
return err
}
if sourceKB.EmbeddingModelID != targetKB.EmbeddingModelID {
err := fmt.Errorf("embedding model mismatch: source=%s, target=%s", sourceKB.EmbeddingModelID, targetKB.EmbeddingModelID)
handleError(progress, err, "Source and target must use the same embedding model")
return err
}
// Process each knowledge item
for i, knowledgeID := range payload.KnowledgeIDs {
err := s.moveOneKnowledge(ctx, knowledgeID, sourceKB, targetKB, payload.Mode)
if err != nil {
logger.Errorf(ctx, "ProcessKnowledgeMove: failed to move knowledge %s: %v", knowledgeID, err)
progress.Failed++
}
progress.Processed = i + 1
if progress.Total > 0 {
progress.Progress = progress.Processed * 100 / progress.Total
}
progress.Message = fmt.Sprintf("Moved %d/%d knowledge items", progress.Processed, progress.Total)
progress.UpdatedAt = time.Now().Unix()
_ = s.saveKnowledgeMoveProgress(ctx, progress)
}
// Mark as completed
if progress.Failed > 0 && progress.Failed == progress.Total {
progress.Status = types.KBCloneStatusFailed
progress.Message = fmt.Sprintf("Knowledge move failed: all %d items failed", progress.Total)
} else {
progress.Status = types.KBCloneStatusCompleted
progress.Message = fmt.Sprintf("Knowledge move completed: %d/%d succeeded", progress.Processed-progress.Failed, progress.Total)
}
progress.Progress = 100
progress.UpdatedAt = time.Now().Unix()
_ = s.saveKnowledgeMoveProgress(ctx, progress)
logger.Infof(ctx, "ProcessKnowledgeMove: task=%s completed, processed=%d, failed=%d", payload.TaskID, progress.Processed, progress.Failed)
return nil
}
// moveOneKnowledge moves a single knowledge item from source KB to target KB.
func (s *knowledgeService) moveOneKnowledge(
ctx context.Context,
knowledgeID string,
sourceKB, targetKB *types.KnowledgeBase,
mode string,
) error {
tenantID := ctx.Value(types.TenantIDContextKey).(uint64)
// Get the knowledge item
knowledge, err := s.repo.GetKnowledgeByID(ctx, tenantID, knowledgeID)
if err != nil {
return fmt.Errorf("failed to get knowledge %s: %w", knowledgeID, err)
}
// Only move completed items
if knowledge.ParseStatus != types.ParseStatusCompleted {
return fmt.Errorf("knowledge %s is not in completed status (current: %s)", knowledgeID, knowledge.ParseStatus)
}
// Mark as processing during move
knowledge.ParseStatus = types.ParseStatusProcessing
if err := s.repo.UpdateKnowledge(ctx, knowledge); err != nil {
return fmt.Errorf("failed to mark knowledge as processing: %w", err)
}
switch mode {
case "reuse_vectors":
return s.moveKnowledgeReuseVectors(ctx, knowledge, sourceKB, targetKB)
case "reparse":
return s.moveKnowledgeReparse(ctx, knowledge, sourceKB, targetKB)
default:
return fmt.Errorf("unknown move mode: %s", mode)
}
}
// moveKnowledgeReuseVectors moves knowledge by copying vector indices and updating DB references.
func (s *knowledgeService) moveKnowledgeReuseVectors(
ctx context.Context,
knowledge *types.Knowledge,
sourceKB, targetKB *types.KnowledgeBase,
) error {
tenantID := ctx.Value(types.TenantIDContextKey).(uint64)
tenantInfo := ctx.Value(types.TenantInfoContextKey).(*types.Tenant)
// 1. Get old chunk IDs for vector index copy mapping
oldChunks, err := s.chunkRepo.ListChunksByKnowledgeID(ctx, tenantID, knowledge.ID)
if err != nil {
return fmt.Errorf("failed to list chunks: %w", err)
}
// Build identity mapping (same chunk IDs, just moving between KBs)
chunkIDMapping := make(map[string]string, len(oldChunks))
for _, c := range oldChunks {
chunkIDMapping[c.ID] = c.ID
}
// 2. Copy vector indices from source KB to target KB
if len(chunkIDMapping) > 0 && knowledge.EmbeddingModelID != "" {
retrieveEngine, err := retriever.NewCompositeRetrieveEngine(s.retrieveEngine, tenantInfo.GetEffectiveEngines())
if err != nil {
return fmt.Errorf("failed to init retrieve engine: %w", err)
}
embeddingModel, err := s.modelService.GetEmbeddingModel(ctx, knowledge.EmbeddingModelID)
if err != nil {
return fmt.Errorf("failed to get embedding model: %w", err)
}
// Copy indices from source KB to target KB
knowledgeIDMapping := map[string]string{knowledge.ID: knowledge.ID}
if err := retrieveEngine.CopyIndices(ctx, sourceKB.ID, targetKB.ID,
knowledgeIDMapping, chunkIDMapping,
embeddingModel.GetDimensions(), sourceKB.Type,
); err != nil {
return fmt.Errorf("failed to copy indices: %w", err)
}
// Delete indices from source KB
if err := retrieveEngine.DeleteByKnowledgeIDList(ctx, []string{knowledge.ID},
embeddingModel.GetDimensions(), sourceKB.Type,
); err != nil {
logger.Warnf(ctx, "moveKnowledgeReuseVectors: failed to delete old indices for knowledge %s: %v", knowledge.ID, err)
// Non-fatal: indices will be orphaned but won't affect correctness
}
}
// 3. Update chunks' knowledge_base_id in DB
if err := s.chunkRepo.MoveChunksByKnowledgeID(ctx, tenantID, knowledge.ID, targetKB.ID); err != nil {
return fmt.Errorf("failed to move chunks: %w", err)
}
// 4. Update knowledge record
knowledge.KnowledgeBaseID = targetKB.ID
knowledge.TagID = "" // Clear tag since tags are KB-scoped
knowledge.ParseStatus = types.ParseStatusCompleted
knowledge.UpdatedAt = time.Now()
if err := s.repo.UpdateKnowledge(ctx, knowledge); err != nil {
return fmt.Errorf("failed to update knowledge: %w", err)
}
return nil
}
// moveKnowledgeReparse moves knowledge to target KB and re-parses it with target KB's configuration.
func (s *knowledgeService) moveKnowledgeReparse(
ctx context.Context,
knowledge *types.Knowledge,
_, targetKB *types.KnowledgeBase,
) error {
tenantID := ctx.Value(types.TenantIDContextKey).(uint64)
// 1. Clean up existing chunks and vector indices
if err := s.cleanupKnowledgeResources(ctx, knowledge); err != nil {
logger.Warnf(ctx, "moveKnowledgeReparse: cleanup partial error for knowledge %s: %v", knowledge.ID, err)
// Continue - partial cleanup is acceptable
}
// 2. Update knowledge to belong to target KB
knowledge.KnowledgeBaseID = targetKB.ID
knowledge.EmbeddingModelID = targetKB.EmbeddingModelID
knowledge.TagID = "" // Clear tag since tags are KB-scoped
knowledge.ParseStatus = types.ParseStatusPending
knowledge.EnableStatus = "disabled"
knowledge.Description = ""
knowledge.ProcessedAt = nil
knowledge.UpdatedAt = time.Now()
if err := s.repo.UpdateKnowledge(ctx, knowledge); err != nil {
return fmt.Errorf("failed to update knowledge: %w", err)
}
// 3. Enqueue document processing task with target KB's configuration
if knowledge.IsManual() {
meta, err := knowledge.ManualMetadata()
if err != nil || meta == nil {
return fmt.Errorf("failed to get manual metadata for reparse: %w", err)
}
s.triggerManualProcessing(ctx, targetKB, knowledge, meta.Content, false)
return nil
}
if knowledge.FilePath != "" {
enableMultimodel := targetKB.IsMultimodalEnabled()
enableQuestionGeneration := false
questionCount := 3
if targetKB.QuestionGenerationConfig != nil && targetKB.QuestionGenerationConfig.Enabled {
enableQuestionGeneration = true
if targetKB.QuestionGenerationConfig.QuestionCount > 0 {
questionCount = targetKB.QuestionGenerationConfig.QuestionCount
}
}
taskPayload := types.DocumentProcessPayload{
TenantID: tenantID,
KnowledgeID: knowledge.ID,
KnowledgeBaseID: targetKB.ID,
FilePath: knowledge.FilePath,
FileName: knowledge.FileName,
FileType: getFileType(knowledge.FileName),
EnableMultimodel: enableMultimodel,
EnableQuestionGeneration: enableQuestionGeneration,
QuestionCount: questionCount,
}
payloadBytes, err := json.Marshal(taskPayload)
if err != nil {
return fmt.Errorf("failed to marshal document process payload: %w", err)
}
task := asynq.NewTask(types.TypeDocumentProcess, payloadBytes, asynq.Queue("default"), asynq.MaxRetry(3))
info, err := s.task.Enqueue(task)
if err != nil {
return fmt.Errorf("failed to enqueue document process task: %w", err)
}
logger.Infof(ctx, "moveKnowledgeReparse: enqueued reparse task id=%s for knowledge=%s", info.ID, knowledge.ID)
}
return nil
}
// getOrCreateTagInTarget finds or creates a tag in the target knowledge base based on the source tag.
// It looks up the source tag by ID, then tries to find a tag with the same name in the target KB.
// If not found, it creates a new tag with the same properties.
// The mapping is cached in tagIDMapping for subsequent lookups.
func (s *knowledgeService) getOrCreateTagInTarget(
ctx context.Context,
srcTenantID, dstTenantID uint64,
dstKnowledgeBaseID string,
srcTagID string,
tagIDMapping map[string]string,
) string {
// Get source tag
srcTag, err := s.tagRepo.GetByID(ctx, srcTenantID, srcTagID)
if err != nil || srcTag == nil {
logger.Warnf(ctx, "Failed to get source tag %s: %v", srcTagID, err)
tagIDMapping[srcTagID] = "" // Cache empty result to avoid repeated lookups
return ""
}
// Try to find existing tag with same name in target KB
dstTag, err := s.tagRepo.GetByName(ctx, dstTenantID, dstKnowledgeBaseID, srcTag.Name)
if err == nil && dstTag != nil {
tagIDMapping[srcTagID] = dstTag.ID
return dstTag.ID
}
// Create new tag in target KB
// "未分类" tag should have the lowest sort order to appear first
sortOrder := srcTag.SortOrder
if srcTag.Name == types.UntaggedTagName {
sortOrder = -1
}
newTag := &types.KnowledgeTag{
ID: uuid.New().String(),
TenantID: dstTenantID,
KnowledgeBaseID: dstKnowledgeBaseID,
Name: srcTag.Name,
Color: srcTag.Color,
SortOrder: sortOrder,
CreatedAt: time.Now(),
UpdatedAt: time.Now(),
}
if err := s.tagRepo.Create(ctx, newTag); err != nil {
logger.Warnf(ctx, "Failed to create tag %s in target KB: %v", srcTag.Name, err)
tagIDMapping[srcTagID] = "" // Cache empty result
return ""
}
tagIDMapping[srcTagID] = newTag.ID
logger.Infof(ctx, "Created tag %s (ID: %s) in target KB %s", newTag.Name, newTag.ID, dstKnowledgeBaseID)
return newTag.ID
}
// SearchKnowledge searches knowledge items by keyword across the tenant and shared knowledge bases.
// fileTypes: optional list of file extensions to filter by (e.g., ["csv", "xlsx"])
func (s *knowledgeService) SearchKnowledge(ctx context.Context, keyword string, offset, limit int, fileTypes []string) ([]*types.Knowledge, bool, error) {
tenantID, ok := ctx.Value(types.TenantIDContextKey).(uint64)
if !ok {
return nil, false, werrors.NewUnauthorizedError("Tenant ID not found in context")
}
scopes := make([]types.KnowledgeSearchScope, 0)
// Own tenant: document-type knowledge bases
ownKBs, err := s.kbService.ListKnowledgeBases(ctx)
if err == nil {
for _, kb := range ownKBs {
if kb != nil && kb.Type == types.KnowledgeBaseTypeDocument {
scopes = append(scopes, types.KnowledgeSearchScope{TenantID: tenantID, KBID: kb.ID})
}
}
}
// Shared knowledge bases (document type only)
if userIDVal := ctx.Value(types.UserIDContextKey); userIDVal != nil {
if userID, ok := userIDVal.(string); ok && userID != "" {
sharedList, err := s.kbShareService.ListSharedKnowledgeBases(ctx, userID, tenantID)
if err == nil {
for _, info := range sharedList {
if info != nil && info.KnowledgeBase != nil && info.KnowledgeBase.Type == types.KnowledgeBaseTypeDocument {
scopes = append(scopes, types.KnowledgeSearchScope{
TenantID: info.SourceTenantID,
KBID: info.KnowledgeBase.ID,
})
}
}
}
}
}
if len(scopes) == 0 {
return nil, false, nil
}
return s.repo.SearchKnowledgeInScopes(ctx, scopes, keyword, offset, limit, fileTypes)
}
// SearchKnowledgeForScopes searches knowledge within the given scopes (e.g. for shared agent context).
func (s *knowledgeService) SearchKnowledgeForScopes(ctx context.Context, scopes []types.KnowledgeSearchScope, keyword string, offset, limit int, fileTypes []string) ([]*types.Knowledge, bool, error) {
if len(scopes) == 0 {
return nil, false, nil
}
return s.repo.SearchKnowledgeInScopes(ctx, scopes, keyword, offset, limit, fileTypes)
}
// ProcessKnowledgeListDelete handles Asynq knowledge list delete tasks
func (s *knowledgeService) ProcessKnowledgeListDelete(ctx context.Context, t *asynq.Task) error {
var payload types.KnowledgeListDeletePayload
if err := json.Unmarshal(t.Payload(), &payload); err != nil {
logger.Errorf(ctx, "Failed to unmarshal knowledge list delete payload: %v", err)
return err
}
logger.Infof(ctx, "Processing knowledge list delete task for %d knowledge items", len(payload.KnowledgeIDs))
// Get tenant info
tenant, err := s.tenantRepo.GetTenantByID(ctx, payload.TenantID)
if err != nil {
logger.Errorf(ctx, "Failed to get tenant %d: %v", payload.TenantID, err)
return err
}
// Set context values
ctx = context.WithValue(ctx, types.TenantIDContextKey, payload.TenantID)
ctx = context.WithValue(ctx, types.TenantInfoContextKey, tenant)
// Delete knowledge list
if err := s.DeleteKnowledgeList(ctx, payload.KnowledgeIDs); err != nil {
logger.Errorf(ctx, "Failed to delete knowledge list: %v", err)
return err
}
logger.Infof(ctx, "Successfully deleted %d knowledge items", len(payload.KnowledgeIDs))
return nil
}