mirror of
https://github.com/Tencent/WeKnora.git
synced 2026-06-04 13:30:32 +08:00
Cloning a knowledge base previously copied only the storage path strings (knowledge.FilePath and chunk.ImageInfo.URL), so the source and the clone shared the same physical objects in the storage backend. Once the original file and extracted images are deleted on source removal, the clone is left with dangling references and its document and images become unreadable — data loss that occurs even for same-store clones. Add a CopyFile primitive to the FileService interface and implement it in every backend: server-side CopyObject on the object stores (s3/obs/cos/oss/tos/ks3/minio), io.Copy on local, and a no-op on dummy. Destinations use the knowledge-owned layout and reuse the existing path/object-key guards; a sentinel ErrCrossBackendCopy is returned when the source scheme does not match the backend. Use CopyFile to deep-copy the document file in cloneKnowledge and the extracted images in CloneChunk and cloneFAQKnowledgeBase via a shared cloneChunkImageInfo helper that deduplicates identical image URLs per clone and rewrites them to the new objects. Copied objects are cleaned up best-effort if a clone fails partway through. A clone-time preflight rejects cloning into a target bound to a different storage backend when the tenant pins providers via StorageEngineConfig. Adds unit tests for local CopyFile (independent copy survives source deletion, traversal rejection, cross-backend rejection), cloneChunkImageInfo (empty/multi/dedup/parse-failure/OriginalURL handling), and the storage provider preflight.
326 lines
9.8 KiB
Go
326 lines
9.8 KiB
Go
package file
|
|
|
|
import (
|
|
"bytes"
|
|
"context"
|
|
"errors"
|
|
"fmt"
|
|
"io"
|
|
"mime/multipart"
|
|
"path/filepath"
|
|
"strings"
|
|
"time"
|
|
|
|
"github.com/Tencent/WeKnora/internal/logger"
|
|
"github.com/Tencent/WeKnora/internal/types/interfaces"
|
|
"github.com/Tencent/WeKnora/internal/utils"
|
|
"github.com/aws/aws-sdk-go-v2/aws"
|
|
"github.com/aws/aws-sdk-go-v2/config"
|
|
"github.com/aws/aws-sdk-go-v2/credentials"
|
|
"github.com/aws/aws-sdk-go-v2/service/s3"
|
|
"github.com/aws/aws-sdk-go-v2/service/s3/types"
|
|
"github.com/google/uuid"
|
|
)
|
|
|
|
// s3FileService AWS S3 file service implementation
|
|
type s3FileService struct {
|
|
client *s3.Client
|
|
bucketName string
|
|
pathPrefix string
|
|
}
|
|
|
|
// newS3Client creates a bare s3FileService with just the SDK client initialised.
|
|
func newS3Client(endpoint, accessKey, secretKey, bucketName, region, pathPrefix string) (*s3FileService, error) {
|
|
var cfg aws.Config
|
|
var err error
|
|
|
|
// Configure AWS SDK
|
|
cfg, err = config.LoadDefaultConfig(context.Background(),
|
|
config.WithRegion(region),
|
|
config.WithCredentialsProvider(credentials.NewStaticCredentialsProvider(accessKey, secretKey, "")),
|
|
)
|
|
|
|
if err != nil {
|
|
return nil, fmt.Errorf("failed to load AWS config: %w", err)
|
|
}
|
|
|
|
// Create S3 client with custom endpoint if provided.
|
|
// For S3-compatible services (non-AWS), use path-style addressing
|
|
// (endpoint/bucket/key) instead of virtual-hosted style (bucket.endpoint/key).
|
|
var client *s3.Client
|
|
if endpoint != "" {
|
|
usePathStyle := !strings.Contains(endpoint, "amazonaws.com")
|
|
client = s3.NewFromConfig(cfg, func(o *s3.Options) {
|
|
o.BaseEndpoint = aws.String(endpoint)
|
|
o.UsePathStyle = usePathStyle
|
|
})
|
|
} else {
|
|
// Standard AWS S3
|
|
client = s3.NewFromConfig(cfg)
|
|
}
|
|
|
|
// Normalize pathPrefix: ensure it ends with '/' if not empty
|
|
if pathPrefix != "" && !strings.HasSuffix(pathPrefix, "/") {
|
|
pathPrefix += "/"
|
|
}
|
|
|
|
return &s3FileService{
|
|
client: client,
|
|
bucketName: bucketName,
|
|
pathPrefix: pathPrefix,
|
|
}, nil
|
|
}
|
|
|
|
// NewS3FileService creates an AWS S3 file service.
|
|
// It verifies that the bucket exists and creates it if missing.
|
|
func NewS3FileService(endpoint,
|
|
accessKey, secretKey, bucketName, region, pathPrefix string,
|
|
) (interfaces.FileService, error) {
|
|
svc, err := newS3Client(endpoint, accessKey, secretKey, bucketName, region, pathPrefix)
|
|
if err != nil {
|
|
return nil, err
|
|
}
|
|
|
|
// Check if bucket exists
|
|
exists, err := svc.bucketExists(context.Background())
|
|
if err != nil {
|
|
return nil, fmt.Errorf("failed to check bucket: %w", err)
|
|
}
|
|
|
|
if !exists {
|
|
if err = svc.createBucket(context.Background()); err != nil {
|
|
return nil, fmt.Errorf("failed to create bucket: %w", err)
|
|
}
|
|
}
|
|
|
|
return svc, nil
|
|
}
|
|
|
|
// bucketExists checks if the bucket exists
|
|
func (s *s3FileService) bucketExists(ctx context.Context) (bool, error) {
|
|
_, err := s.client.HeadBucket(ctx, &s3.HeadBucketInput{
|
|
Bucket: aws.String(s.bucketName),
|
|
})
|
|
if err != nil {
|
|
// Check if the error is a NotFound error
|
|
var notFound *types.NotFound
|
|
if errors.As(err, ¬Found) {
|
|
return false, nil
|
|
}
|
|
return false, err
|
|
}
|
|
return true, nil
|
|
}
|
|
|
|
// createBucket creates a new bucket
|
|
func (s *s3FileService) createBucket(ctx context.Context) error {
|
|
_, err := s.client.CreateBucket(ctx, &s3.CreateBucketInput{
|
|
Bucket: aws.String(s.bucketName),
|
|
})
|
|
return err
|
|
}
|
|
|
|
// CheckConnectivity verifies S3 is reachable and, if a bucket is configured,
|
|
// that the bucket exists. This is a read-only probe — it never creates a bucket.
|
|
func (s *s3FileService) CheckConnectivity(ctx context.Context) error {
|
|
checkCtx, cancel := context.WithTimeout(ctx, 10*time.Second)
|
|
defer cancel()
|
|
|
|
if s.bucketName != "" {
|
|
exists, err := s.bucketExists(checkCtx)
|
|
if err != nil {
|
|
return err
|
|
}
|
|
if !exists {
|
|
return fmt.Errorf("bucket %q does not exist", s.bucketName)
|
|
}
|
|
return nil
|
|
}
|
|
|
|
// List buckets to verify connectivity
|
|
_, err := s.client.ListBuckets(checkCtx, &s3.ListBucketsInput{})
|
|
return err
|
|
}
|
|
|
|
// CheckS3Connectivity tests S3 connectivity using the provided credentials.
|
|
// It creates a temporary service instance internally and delegates to CheckConnectivity.
|
|
func CheckS3Connectivity(ctx context.Context, endpoint, accessKey, secretKey, bucketName, region string) error {
|
|
svc, err := newS3Client(endpoint, accessKey, secretKey, bucketName, region, "")
|
|
if err != nil {
|
|
return err
|
|
}
|
|
return svc.CheckConnectivity(ctx)
|
|
}
|
|
|
|
// parseS3FilePath extracts the object name from a provider scheme: s3://{bucket}/{objectKey}
|
|
func (s *s3FileService) parseS3FilePath(filePath string) (string, error) {
|
|
// Provider scheme format: s3://{bucket}/{objectKey}
|
|
const prefix = "s3://"
|
|
if !strings.HasPrefix(filePath, prefix) {
|
|
return "", fmt.Errorf("invalid S3 file path: %s", filePath)
|
|
}
|
|
rest := strings.TrimPrefix(filePath, prefix)
|
|
parts := strings.SplitN(rest, "/", 2)
|
|
if len(parts) != 2 || parts[0] == "" || parts[1] == "" {
|
|
return "", fmt.Errorf("invalid S3 file path: %s", filePath)
|
|
}
|
|
if parts[0] != s.bucketName {
|
|
return "", fmt.Errorf("bucket mismatch in path: got %s, want %s", parts[0], s.bucketName)
|
|
}
|
|
if err := utils.SafeObjectKey(parts[1]); err != nil {
|
|
return "", fmt.Errorf("invalid file path: %w", err)
|
|
}
|
|
return parts[1], nil
|
|
}
|
|
|
|
// SaveFile saves a file to S3
|
|
func (s *s3FileService) SaveFile(ctx context.Context,
|
|
file *multipart.FileHeader, tenantID uint64, knowledgeID string,
|
|
) (string, error) {
|
|
// Generate object name
|
|
ext := filepath.Ext(file.Filename)
|
|
objectName := fmt.Sprintf("%s%d/%s/%s%s", s.pathPrefix, tenantID, knowledgeID, uuid.New().String(), ext)
|
|
|
|
// Open file
|
|
src, err := file.Open()
|
|
if err != nil {
|
|
return "", fmt.Errorf("failed to open file: %w", err)
|
|
}
|
|
defer src.Close()
|
|
|
|
// Determine content type
|
|
contentType := file.Header.Get("Content-Type")
|
|
if contentType == "" {
|
|
contentType = utils.GetContentTypeByExt(ext)
|
|
}
|
|
|
|
// Upload file to S3
|
|
_, err = s.client.PutObject(ctx, &s3.PutObjectInput{
|
|
Bucket: aws.String(s.bucketName),
|
|
Key: aws.String(objectName),
|
|
Body: src,
|
|
ContentLength: aws.Int64(file.Size),
|
|
ContentType: aws.String(contentType),
|
|
})
|
|
if err != nil {
|
|
return "", fmt.Errorf("failed to upload file to S3: %w", err)
|
|
}
|
|
|
|
return fmt.Sprintf("s3://%s/%s", s.bucketName, objectName), nil
|
|
}
|
|
|
|
// GetFile gets a file from S3
|
|
func (s *s3FileService) GetFile(ctx context.Context, filePath string) (io.ReadCloser, error) {
|
|
objectName, err := s.parseS3FilePath(filePath)
|
|
if err != nil {
|
|
return nil, err
|
|
}
|
|
|
|
resp, err := s.client.GetObject(ctx, &s3.GetObjectInput{
|
|
Bucket: aws.String(s.bucketName),
|
|
Key: aws.String(objectName),
|
|
})
|
|
if err != nil {
|
|
return nil, fmt.Errorf("failed to get file from S3: %w", err)
|
|
}
|
|
|
|
return resp.Body, nil
|
|
}
|
|
|
|
// DeleteFile deletes a file
|
|
func (s *s3FileService) DeleteFile(ctx context.Context, filePath string) error {
|
|
objectName, err := s.parseS3FilePath(filePath)
|
|
if err != nil {
|
|
return err
|
|
}
|
|
|
|
_, err = s.client.DeleteObject(ctx, &s3.DeleteObjectInput{
|
|
Bucket: aws.String(s.bucketName),
|
|
Key: aws.String(objectName),
|
|
})
|
|
if err != nil {
|
|
return fmt.Errorf("failed to delete file: %w", err)
|
|
}
|
|
|
|
return nil
|
|
}
|
|
|
|
// CopyFile copies an existing S3 object to a new knowledge-owned object using a
|
|
// server-side CopyObject (no data leaves S3). The destination uses the same
|
|
// layout as SaveFile. Returns ErrCrossBackendCopy when srcPath is not an s3:// path.
|
|
func (s *s3FileService) CopyFile(ctx context.Context,
|
|
srcPath string, tenantID uint64, knowledgeID string,
|
|
) (string, error) {
|
|
srcKey, err := s.parseS3FilePath(srcPath)
|
|
if err != nil {
|
|
return "", fmt.Errorf("s3 copy rejected source %q: %w", srcPath, ErrCrossBackendCopy)
|
|
}
|
|
|
|
ext := filepath.Ext(srcPath)
|
|
destKey := fmt.Sprintf("%s%d/%s/%s%s", s.pathPrefix, tenantID, knowledgeID, uuid.New().String(), ext)
|
|
|
|
// CopySource is "bucket/key"; the '/' separators must NOT be percent-encoded
|
|
// (url.PathEscape would turn them into %2F and break the bucket/key split).
|
|
// srcKey is already validated by parseS3FilePath -> SafeObjectKey.
|
|
_, err = s.client.CopyObject(ctx, &s3.CopyObjectInput{
|
|
Bucket: aws.String(s.bucketName),
|
|
CopySource: aws.String(s.bucketName + "/" + srcKey),
|
|
Key: aws.String(destKey),
|
|
})
|
|
if err != nil {
|
|
return "", fmt.Errorf("failed to copy file in S3: %w", err)
|
|
}
|
|
|
|
newPath := fmt.Sprintf("s3://%s/%s", s.bucketName, destKey)
|
|
logger.Infof(ctx, "Copied S3 object %s to %s", srcPath, newPath)
|
|
return newPath, nil
|
|
}
|
|
|
|
// SaveBytes saves bytes data to S3 and returns the file path
|
|
// temp parameter is ignored for S3 (no auto-expiration support in this implementation)
|
|
func (s *s3FileService) SaveBytes(ctx context.Context, data []byte, tenantID uint64, fileName string, temp bool) (string, error) {
|
|
safeName, err := utils.SafeFileName(fileName)
|
|
if err != nil {
|
|
return "", fmt.Errorf("invalid file name: %w", err)
|
|
}
|
|
ext := filepath.Ext(safeName)
|
|
objectName := fmt.Sprintf("%s%d/exports/%s%s", s.pathPrefix, tenantID, uuid.New().String(), ext)
|
|
|
|
// Upload bytes to S3
|
|
reader := bytes.NewReader(data)
|
|
_, err = s.client.PutObject(ctx, &s3.PutObjectInput{
|
|
Bucket: aws.String(s.bucketName),
|
|
Key: aws.String(objectName),
|
|
Body: reader,
|
|
ContentLength: aws.Int64(int64(len(data))),
|
|
ContentType: aws.String(utils.GetContentTypeByExt(ext)),
|
|
})
|
|
if err != nil {
|
|
return "", fmt.Errorf("failed to upload bytes to S3: %w", err)
|
|
}
|
|
|
|
return fmt.Sprintf("s3://%s/%s", s.bucketName, objectName), nil
|
|
}
|
|
|
|
// GetFileURL returns a presigned download URL for the file
|
|
func (s *s3FileService) GetFileURL(ctx context.Context, filePath string) (string, error) {
|
|
objectName, err := s.parseS3FilePath(filePath)
|
|
if err != nil {
|
|
return "", err
|
|
}
|
|
|
|
// Create presign client
|
|
presignClient := s3.NewPresignClient(s.client)
|
|
|
|
// Generate presigned URL
|
|
presignedReq, err := presignClient.PresignGetObject(ctx, &s3.GetObjectInput{
|
|
Bucket: aws.String(s.bucketName),
|
|
Key: aws.String(objectName),
|
|
}, s3.WithPresignExpires(24*time.Hour))
|
|
if err != nil {
|
|
return "", fmt.Errorf("failed to generate presigned URL: %w", err)
|
|
}
|
|
|
|
return presignedReq.URL, nil
|
|
}
|