mirror of
https://github.com/Tencent/WeKnora.git
synced 2026-06-04 13:30:32 +08:00
Cloning a knowledge base previously copied only the storage path strings (knowledge.FilePath and chunk.ImageInfo.URL), so the source and the clone shared the same physical objects in the storage backend. Once the original file and extracted images are deleted on source removal, the clone is left with dangling references and its document and images become unreadable — data loss that occurs even for same-store clones. Add a CopyFile primitive to the FileService interface and implement it in every backend: server-side CopyObject on the object stores (s3/obs/cos/oss/tos/ks3/minio), io.Copy on local, and a no-op on dummy. Destinations use the knowledge-owned layout and reuse the existing path/object-key guards; a sentinel ErrCrossBackendCopy is returned when the source scheme does not match the backend. Use CopyFile to deep-copy the document file in cloneKnowledge and the extracted images in CloneChunk and cloneFAQKnowledgeBase via a shared cloneChunkImageInfo helper that deduplicates identical image URLs per clone and rewrites them to the new objects. Copied objects are cleaned up best-effort if a clone fails partway through. A clone-time preflight rejects cloning into a target bound to a different storage backend when the tenant pins providers via StorageEngineConfig. Adds unit tests for local CopyFile (independent copy survives source deletion, traversal rejection, cross-backend rejection), cloneChunkImageInfo (empty/multi/dedup/parse-failure/OriginalURL handling), and the storage provider preflight.
316 lines
9.1 KiB
Go
316 lines
9.1 KiB
Go
package file
|
|
|
|
import (
|
|
"context"
|
|
"fmt"
|
|
"io"
|
|
"mime/multipart"
|
|
"os"
|
|
"path/filepath"
|
|
"strings"
|
|
"time"
|
|
|
|
"github.com/Tencent/WeKnora/internal/logger"
|
|
"github.com/Tencent/WeKnora/internal/types/interfaces"
|
|
"github.com/aws/aws-sdk-go-v2/aws"
|
|
"github.com/aws/aws-sdk-go-v2/credentials"
|
|
"github.com/aws/aws-sdk-go-v2/service/s3"
|
|
"github.com/google/uuid"
|
|
)
|
|
|
|
type obsFileService struct {
|
|
client *s3.Client
|
|
bucketName string
|
|
endpoint string
|
|
region string
|
|
pathPrefix string
|
|
proxyDomain string
|
|
}
|
|
|
|
type obsEndpointResolver struct {
|
|
url string
|
|
}
|
|
|
|
func (r *obsEndpointResolver) ResolveEndpoint(region string, options s3.EndpointResolverOptions) (aws.Endpoint, error) {
|
|
return aws.Endpoint{
|
|
URL: r.url,
|
|
SigningRegion: region,
|
|
HostnameImmutable: true,
|
|
}, nil
|
|
}
|
|
|
|
func NewObsFileService(
|
|
endpoint, region, accessKeyID, secretAccessKey, bucketName string,
|
|
pathPrefix string,
|
|
) (interfaces.FileService, error) {
|
|
|
|
client := s3.New(s3.Options{
|
|
Region: region,
|
|
EndpointResolver: &obsEndpointResolver{url: endpoint},
|
|
Credentials: credentials.NewStaticCredentialsProvider(accessKeyID, secretAccessKey, ""),
|
|
UsePathStyle: true,
|
|
})
|
|
|
|
_, err := client.HeadBucket(context.Background(), &s3.HeadBucketInput{
|
|
Bucket: aws.String(bucketName),
|
|
})
|
|
if err != nil {
|
|
_, createErr := client.CreateBucket(context.Background(), &s3.CreateBucketInput{
|
|
Bucket: aws.String(bucketName),
|
|
})
|
|
if createErr != nil {
|
|
fmt.Printf("Warning: bucket %s may not exist or cannot be created: %v\n", bucketName, createErr)
|
|
}
|
|
}
|
|
|
|
proxyDomain := strings.TrimSuffix(os.Getenv("OBS_PROXY_DOMAIN"), "/")
|
|
|
|
return &obsFileService{
|
|
client: client,
|
|
bucketName: bucketName,
|
|
endpoint: endpoint,
|
|
region: region,
|
|
pathPrefix: strings.Trim(pathPrefix, "/"),
|
|
proxyDomain: proxyDomain,
|
|
}, nil
|
|
}
|
|
|
|
func CheckObsConnectivity(ctx context.Context, endpoint, region, accessKey, secretKey, bucketName string) error {
|
|
checkCtx, cancel := context.WithTimeout(ctx, 10*time.Second)
|
|
defer cancel()
|
|
|
|
client := s3.New(s3.Options{
|
|
Region: region,
|
|
EndpointResolver: &obsEndpointResolver{url: endpoint},
|
|
Credentials: credentials.NewStaticCredentialsProvider(accessKey, secretKey, ""),
|
|
UsePathStyle: true,
|
|
})
|
|
|
|
_, err := client.HeadBucket(checkCtx, &s3.HeadBucketInput{
|
|
Bucket: aws.String(bucketName),
|
|
})
|
|
if err != nil {
|
|
return fmt.Errorf("OBS connectivity check failed: %w", err)
|
|
}
|
|
return nil
|
|
}
|
|
|
|
func (s *obsFileService) CheckConnectivity(ctx context.Context) error {
|
|
checkCtx, cancel := context.WithTimeout(ctx, 10*time.Second)
|
|
defer cancel()
|
|
|
|
_, err := s.client.HeadBucket(checkCtx, &s3.HeadBucketInput{
|
|
Bucket: aws.String(s.bucketName),
|
|
})
|
|
if err != nil {
|
|
return fmt.Errorf("OBS connectivity check failed: %w", err)
|
|
}
|
|
return nil
|
|
}
|
|
|
|
func (s *obsFileService) parseObsFilePath(filePath string) (string, error) {
|
|
prefix := s.getPrifix()
|
|
|
|
if strings.HasPrefix(filePath, prefix) {
|
|
rest := strings.TrimPrefix(filePath, prefix)
|
|
// With proxy domain: path is {prefix}/{objectKey} (no bucket name)
|
|
if s.proxyDomain != "" {
|
|
rest = strings.TrimPrefix(rest, "/")
|
|
if rest != "" {
|
|
return rest, nil
|
|
}
|
|
return "", fmt.Errorf("invalid OBS file path: %s", filePath)
|
|
}
|
|
// Without proxy domain: path is {prefix}/{bucketName}/{objectKey}
|
|
parts := strings.SplitN(rest, "/", 2)
|
|
if len(parts) == 2 && parts[0] == s.bucketName && parts[1] != "" {
|
|
return parts[1], nil
|
|
}
|
|
return "", fmt.Errorf("invalid OBS file path: %s", filePath)
|
|
}
|
|
return filePath, nil
|
|
}
|
|
|
|
func (s *obsFileService) getPrifix() string {
|
|
if s.proxyDomain != "" {
|
|
return s.proxyDomain + "/"
|
|
}
|
|
return "obs://"
|
|
}
|
|
|
|
func (s *obsFileService) SaveFile(ctx context.Context,
|
|
file *multipart.FileHeader, tenantID uint64, knowledgeID string,
|
|
) (string, error) {
|
|
ext := filepath.Ext(file.Filename)
|
|
|
|
var objectKey string
|
|
if s.pathPrefix != "" {
|
|
objectKey = fmt.Sprintf("%s/%d/%s/%s%s", s.pathPrefix, tenantID, knowledgeID, uuid.New().String(), ext)
|
|
} else {
|
|
objectKey = fmt.Sprintf("%d/%s/%s%s", tenantID, knowledgeID, uuid.New().String(), ext)
|
|
}
|
|
|
|
src, err := file.Open()
|
|
if err != nil {
|
|
return "", fmt.Errorf("failed to open file: %w", err)
|
|
}
|
|
defer src.Close()
|
|
|
|
contentType := file.Header.Get("Content-Type")
|
|
if contentType == "" {
|
|
contentType = "application/octet-stream"
|
|
}
|
|
|
|
_, err = s.client.PutObject(ctx, &s3.PutObjectInput{
|
|
Bucket: aws.String(s.bucketName),
|
|
Key: aws.String(objectKey),
|
|
Body: src,
|
|
ContentLength: aws.Int64(file.Size),
|
|
ContentType: aws.String(contentType),
|
|
// ACL: "private",
|
|
})
|
|
if err != nil {
|
|
return "", fmt.Errorf("failed to upload file to OBS: %w", err)
|
|
}
|
|
prefix := s.getPrifix()
|
|
if s.proxyDomain != "" {
|
|
return fmt.Sprintf("%s%s", prefix, objectKey), nil
|
|
}
|
|
return fmt.Sprintf("%s%s/%s", prefix, s.bucketName, objectKey), nil
|
|
}
|
|
|
|
func (s *obsFileService) GetFile(ctx context.Context, filePath string) (io.ReadCloser, error) {
|
|
objectKey, err := s.parseObsFilePath(filePath)
|
|
if err != nil {
|
|
return nil, err
|
|
}
|
|
|
|
output, err := s.client.GetObject(ctx, &s3.GetObjectInput{
|
|
Bucket: aws.String(s.bucketName),
|
|
Key: aws.String(objectKey),
|
|
})
|
|
if err != nil {
|
|
return nil, fmt.Errorf("failed to get file from OBS: %w", err)
|
|
}
|
|
|
|
return output.Body, nil
|
|
}
|
|
|
|
func (s *obsFileService) DeleteFile(ctx context.Context, filePath string) error {
|
|
objectKey, err := s.parseObsFilePath(filePath)
|
|
if err != nil {
|
|
return err
|
|
}
|
|
|
|
_, err = s.client.DeleteObject(ctx, &s3.DeleteObjectInput{
|
|
Bucket: aws.String(s.bucketName),
|
|
Key: aws.String(objectKey),
|
|
})
|
|
if err != nil {
|
|
return fmt.Errorf("failed to delete file from OBS: %w", err)
|
|
}
|
|
|
|
return nil
|
|
}
|
|
|
|
func (s *obsFileService) GetFileURL(ctx context.Context, filePath string) (string, error) {
|
|
if strings.HasPrefix(filePath, "http://") || strings.HasPrefix(filePath, "https://") {
|
|
return filePath, nil
|
|
}
|
|
|
|
objectKey, err := s.parseObsFilePath(filePath)
|
|
if err != nil {
|
|
return "", err
|
|
}
|
|
|
|
if s.proxyDomain != "" {
|
|
return s.proxyDomain + "/" + strings.TrimPrefix(objectKey, "/"), nil
|
|
}
|
|
|
|
return fmt.Sprintf("%s/%s/%s", s.endpoint, s.bucketName, strings.TrimPrefix(objectKey, "/")), nil
|
|
}
|
|
|
|
// CopyFile copies an existing OBS object to a new knowledge-owned object using a
|
|
// server-side CopyObject (OBS is S3-compatible). The destination uses the same
|
|
// layout as SaveFile. Returns ErrCrossBackendCopy when srcPath does not belong
|
|
// to this OBS service.
|
|
func (s *obsFileService) CopyFile(ctx context.Context,
|
|
srcPath string, tenantID uint64, knowledgeID string,
|
|
) (string, error) {
|
|
// Reject paths that do not use this service's prefix (proxy domain or obs://).
|
|
// parseObsFilePath falls back to returning the raw input for unknown prefixes,
|
|
// so guard explicitly here to detect cross-backend sources.
|
|
if !strings.HasPrefix(srcPath, s.getPrifix()) {
|
|
return "", fmt.Errorf("obs copy rejected source %q: %w", srcPath, ErrCrossBackendCopy)
|
|
}
|
|
srcKey, err := s.parseObsFilePath(srcPath)
|
|
if err != nil {
|
|
return "", fmt.Errorf("obs copy rejected source %q: %w", srcPath, ErrCrossBackendCopy)
|
|
}
|
|
|
|
ext := filepath.Ext(srcPath)
|
|
var destKey string
|
|
if s.pathPrefix != "" {
|
|
destKey = fmt.Sprintf("%s/%d/%s/%s%s", s.pathPrefix, tenantID, knowledgeID, uuid.New().String(), ext)
|
|
} else {
|
|
destKey = fmt.Sprintf("%d/%s/%s%s", tenantID, knowledgeID, uuid.New().String(), ext)
|
|
}
|
|
|
|
// CopySource is "bucket/key"; the '/' separators must NOT be percent-encoded
|
|
// (url.PathEscape would turn them into %2F and break the bucket/key split).
|
|
_, err = s.client.CopyObject(ctx, &s3.CopyObjectInput{
|
|
Bucket: aws.String(s.bucketName),
|
|
CopySource: aws.String(s.bucketName + "/" + srcKey),
|
|
Key: aws.String(destKey),
|
|
})
|
|
if err != nil {
|
|
return "", fmt.Errorf("failed to copy file in OBS: %w", err)
|
|
}
|
|
|
|
prefix := s.getPrifix()
|
|
var newPath string
|
|
if s.proxyDomain != "" {
|
|
newPath = fmt.Sprintf("%s%s", prefix, destKey)
|
|
} else {
|
|
newPath = fmt.Sprintf("%s%s/%s", prefix, s.bucketName, destKey)
|
|
}
|
|
logger.Infof(ctx, "Copied OBS object %s to %s", srcPath, newPath)
|
|
return newPath, nil
|
|
}
|
|
|
|
func (s *obsFileService) SaveBytes(ctx context.Context, data []byte, tenantID uint64, fileName string, temp bool) (string, error) {
|
|
ext := filepath.Ext(fileName)
|
|
|
|
var objectKey string
|
|
if temp {
|
|
if s.pathPrefix != "" {
|
|
objectKey = fmt.Sprintf("%s/temp/%d/%s%s", s.pathPrefix, tenantID, uuid.New().String(), ext)
|
|
} else {
|
|
objectKey = fmt.Sprintf("temp/%d/%s%s", tenantID, uuid.New().String(), ext)
|
|
}
|
|
} else {
|
|
if s.pathPrefix != "" {
|
|
objectKey = fmt.Sprintf("%s/%d/%s%s", s.pathPrefix, tenantID, uuid.New().String(), ext)
|
|
} else {
|
|
objectKey = fmt.Sprintf("%d/%s%s", tenantID, uuid.New().String(), ext)
|
|
}
|
|
}
|
|
|
|
_, err := s.client.PutObject(ctx, &s3.PutObjectInput{
|
|
Bucket: aws.String(s.bucketName),
|
|
Key: aws.String(objectKey),
|
|
Body: strings.NewReader(string(data)),
|
|
ContentType: aws.String("application/octet-stream"),
|
|
ACL: "public-read",
|
|
})
|
|
if err != nil {
|
|
return "", fmt.Errorf("failed to upload bytes to OBS: %w", err)
|
|
}
|
|
|
|
prefix := s.getPrifix()
|
|
if s.proxyDomain != "" {
|
|
return fmt.Sprintf("%s%s", prefix, objectKey), nil
|
|
}
|
|
return fmt.Sprintf("%s%s/%s", prefix, s.bucketName, objectKey), nil
|
|
}
|