From e309e0bed8d7f6c8b07df59c05718cd0e7b5a84a Mon Sep 17 00:00:00 2001 From: DaWesen <3880255095@qq.com> Date: Sat, 7 Mar 2026 14:01:22 +0800 Subject: [PATCH] =?UTF-8?q?feat(storage):=20=E9=9B=86=E6=88=90S3=E5=AD=98?= =?UTF-8?q?=E5=82=A8=E9=80=82=E9=85=8D=E5=99=A8?= MIME-Version: 1.0 Content-Type: text/plain; charset=UTF-8 Content-Transfer-Encoding: 8bit 添加对AWS S3及兼容存储服务的支持: - 实现完整的S3FileService接口 - 支持文件上传、下载、删除功能 - 添加配置支持和环境变量检查 - 实现连接测试功能 - 遵循与其他存储适配器相同的代码风格 --- .env.example | 21 +- go.mod | 18 + go.sum | 36 ++ helm/values.yaml | 2 +- internal/application/service/file/factory.go | 10 + internal/application/service/file/s3.go | 332 +++++++++++++++++++ internal/container/container.go | 20 ++ internal/handler/system.go | 39 ++- internal/types/tenant.go | 15 +- scripts/check-env.sh | 8 + 10 files changed, 496 insertions(+), 5 deletions(-) create mode 100644 internal/application/service/file/s3.go diff --git a/.env.example b/.env.example index 48112f4d..1c2c04c6 100644 --- a/.env.example +++ b/.env.example @@ -27,7 +27,7 @@ DB_DRIVER=postgres # 向量存储类型(postgres/elasticsearch_v7/elasticsearch_v8/qdrant/milvus/weaviate) RETRIEVE_DRIVER=postgres -# 文件存储类型(local/minio/cos/tos) +# 文件存储类型(local/minio/cos/tos/s3) STORAGE_TYPE=local # 流处理后端(memory/redis) @@ -192,6 +192,25 @@ COS_ENABLE_OLD_DOMAIN=true # 火山引擎TOS临时桶区域(可选,默认与主桶相同) # TOS_TEMP_REGION=your_tos_temp_region +# 如果使用AWS S3作为文件存储,需要配置以下参数 +# AWS S3的访问端点,例如 https://s3.amazonaws.com +# S3_ENDPOINT=https://s3.amazonaws.com + +# AWS S3的区域,例如 us-east-1 +# S3_REGION=us-east-1 + +# AWS S3访问密钥 Access Key +# S3_ACCESS_KEY=your_s3_access_key + +# AWS S3访问密钥 Secret Key +# S3_SECRET_KEY=your_s3_secret_key + +# AWS S3桶名称 +# S3_BUCKET_NAME=your_s3_bucket_name + +# AWS S3可选路径前缀(可选) +# S3_PATH_PREFIX=your_s3_path_prefix + # 如果解析网络连接使用Web代理,需要配置以下参数 # WEB_PROXY=your_web_proxy diff --git a/go.mod b/go.mod index 04db7c2d..0810720a 100644 --- a/go.mod +++ b/go.mod @@ -6,6 +6,10 @@ require ( github.com/JohannesKaufmann/html-to-markdown/v2 
v2.5.0 github.com/PuerkitoBio/goquery v1.10.3 github.com/asg017/sqlite-vec-go-bindings v0.1.6 + github.com/aws/aws-sdk-go-v2 v1.41.3 + github.com/aws/aws-sdk-go-v2/config v1.29.0 + github.com/aws/aws-sdk-go-v2/credentials v1.19.11 + github.com/aws/aws-sdk-go-v2/service/s3 v1.78.0 github.com/chromedp/chromedp v0.14.2 github.com/duckdb/duckdb-go/v2 v2.5.4 github.com/elastic/go-elasticsearch/v7 v7.17.10 @@ -69,6 +73,20 @@ require ( github.com/andybalholm/brotli v1.2.0 // indirect github.com/andybalholm/cascadia v1.3.3 // indirect github.com/apache/arrow-go/v18 v18.4.1 // indirect + github.com/aws/aws-sdk-go-v2/aws/protocol/eventstream v1.7.6 // indirect + github.com/aws/aws-sdk-go-v2/feature/ec2/imds v1.18.19 // indirect + github.com/aws/aws-sdk-go-v2/internal/configsources v1.4.19 // indirect + github.com/aws/aws-sdk-go-v2/internal/endpoints/v2 v2.7.19 // indirect + github.com/aws/aws-sdk-go-v2/internal/ini v1.8.5 // indirect + github.com/aws/aws-sdk-go-v2/internal/v4a v1.4.20 // indirect + github.com/aws/aws-sdk-go-v2/service/internal/accept-encoding v1.13.6 // indirect + github.com/aws/aws-sdk-go-v2/service/internal/checksum v1.9.11 // indirect + github.com/aws/aws-sdk-go-v2/service/internal/presigned-url v1.13.19 // indirect + github.com/aws/aws-sdk-go-v2/service/internal/s3shared v1.19.19 // indirect + github.com/aws/aws-sdk-go-v2/service/sso v1.30.12 // indirect + github.com/aws/aws-sdk-go-v2/service/ssooidc v1.35.16 // indirect + github.com/aws/aws-sdk-go-v2/service/sts v1.41.8 // indirect + github.com/aws/smithy-go v1.24.2 // indirect github.com/bahlo/generic-list-go v0.2.0 // indirect github.com/beorn7/perks v1.0.1 // indirect github.com/blang/semver/v4 v4.0.0 // indirect diff --git a/go.sum b/go.sum index 95a311aa..c0a70887 100644 --- a/go.sum +++ b/go.sum @@ -1397,6 +1397,42 @@ github.com/armon/go-metrics v0.0.0-20180917152333-f0300d1749da/go.mod h1:Q73ZrmV github.com/armon/go-radix v0.0.0-20180808171621-7fddfc383310/go.mod 
h1:ufUuZ+zHj4x4TnLV4JWEpy2hxWSpsRywHrMgIH9cCH8= github.com/asg017/sqlite-vec-go-bindings v0.1.6 h1:Nx0jAzyS38XpkKznJ9xQjFXz2X9tI7KqjwVxV8RNoww= github.com/asg017/sqlite-vec-go-bindings v0.1.6/go.mod h1:A8+cTt/nKFsYCQF6OgzSNpKZrzNo5gQsXBTfsXHXY0Q= +github.com/aws/aws-sdk-go-v2 v1.41.3 h1:4kQ/fa22KjDt13QCy1+bYADvdgcxpfH18f0zP542kZA= +github.com/aws/aws-sdk-go-v2 v1.41.3/go.mod h1:mwsPRE8ceUUpiTgF7QmQIJ7lgsKUPQOUl3o72QBrE1o= +github.com/aws/aws-sdk-go-v2/aws/protocol/eventstream v1.7.6 h1:N4lRUXZpZ1KVEUn6hxtco/1d2lgYhNn1fHkkl8WhlyQ= +github.com/aws/aws-sdk-go-v2/aws/protocol/eventstream v1.7.6/go.mod h1:lyw7GFp3qENLh7kwzf7iMzAxDn+NzjXEAGjKS2UOKqI= +github.com/aws/aws-sdk-go-v2/config v1.29.0 h1:Vk/u4jof33or1qAQLdofpjKV7mQQT7DcUpnYx8kdmxY= +github.com/aws/aws-sdk-go-v2/config v1.29.0/go.mod h1:iXAZK3Gxvpq3tA+B9WaDYpZis7M8KFgdrDPMmHrgbJM= +github.com/aws/aws-sdk-go-v2/credentials v1.19.11 h1:NdV8cwCcAXrCWyxArt58BrvZJ9pZ9Fhf9w6Uh5W3Uyc= +github.com/aws/aws-sdk-go-v2/credentials v1.19.11/go.mod h1:30yY2zqkMPdrvxBqzI9xQCM+WrlrZKSOpSJEsylVU+8= +github.com/aws/aws-sdk-go-v2/feature/ec2/imds v1.18.19 h1:INUvJxmhdEbVulJYHI061k4TVuS3jzzthNvjqvVvTKM= +github.com/aws/aws-sdk-go-v2/feature/ec2/imds v1.18.19/go.mod h1:FpZN2QISLdEBWkayloda+sZjVJL+e9Gl0k1SyTgcswU= +github.com/aws/aws-sdk-go-v2/internal/configsources v1.4.19 h1:/sECfyq2JTifMI2JPyZ4bdRN77zJmr6SrS1eL3augIA= +github.com/aws/aws-sdk-go-v2/internal/configsources v1.4.19/go.mod h1:dMf8A5oAqr9/oxOfLkC/c2LU/uMcALP0Rgn2BD5LWn0= +github.com/aws/aws-sdk-go-v2/internal/endpoints/v2 v2.7.19 h1:AWeJMk33GTBf6J20XJe6qZoRSJo0WfUhsMdUKhoODXE= +github.com/aws/aws-sdk-go-v2/internal/endpoints/v2 v2.7.19/go.mod h1:+GWrYoaAsV7/4pNHpwh1kiNLXkKaSoppxQq9lbH8Ejw= +github.com/aws/aws-sdk-go-v2/internal/ini v1.8.5 h1:clHU5fm//kWS1C2HgtgWxfQbFbx4b6rx+5jzhgX9HrI= +github.com/aws/aws-sdk-go-v2/internal/ini v1.8.5/go.mod h1:O3h0IK87yXci+kg6flUKzJnWeziQUKciKrLjcatSNcY= +github.com/aws/aws-sdk-go-v2/internal/v4a v1.4.20 
h1:qi3e/dmpdONhj1RyIZdi6DKKpDXS5Lb8ftr3p7cyHJc= +github.com/aws/aws-sdk-go-v2/internal/v4a v1.4.20/go.mod h1:V1K+TeJVD5JOk3D9e5tsX2KUdL7BlB+FV6cBhdobN8c= +github.com/aws/aws-sdk-go-v2/service/internal/accept-encoding v1.13.6 h1:XAq62tBTJP/85lFD5oqOOe7YYgWxY9LvWq8plyDvDVg= +github.com/aws/aws-sdk-go-v2/service/internal/accept-encoding v1.13.6/go.mod h1:x0nZssQ3qZSnIcePWLvcoFisRXJzcTVvYpAAdYX8+GI= +github.com/aws/aws-sdk-go-v2/service/internal/checksum v1.9.11 h1:BYf7XNsJMzl4mObARUBUib+j2tf0U//JAAtTnYqvqCw= +github.com/aws/aws-sdk-go-v2/service/internal/checksum v1.9.11/go.mod h1:aEUS4WrNk/+FxkBZZa7tVgp4pGH+kFGW40Y8rCPqt5g= +github.com/aws/aws-sdk-go-v2/service/internal/presigned-url v1.13.19 h1:X1Tow7suZk9UCJHE1Iw9GMZJJl0dAnKXXP1NaSDHwmw= +github.com/aws/aws-sdk-go-v2/service/internal/presigned-url v1.13.19/go.mod h1:/rARO8psX+4sfjUQXp5LLifjUt8DuATZ31WptNJTyQA= +github.com/aws/aws-sdk-go-v2/service/internal/s3shared v1.19.19 h1:JnQeStZvPHFHeyky/7LbMlyQjUa+jIBj36OlWm0pzIk= +github.com/aws/aws-sdk-go-v2/service/internal/s3shared v1.19.19/go.mod h1:HGyasyHvYdFQeJhvDHfH7HXkHh57htcJGKDZ+7z+I24= +github.com/aws/aws-sdk-go-v2/service/s3 v1.78.0 h1:EBm8lXevBWe+kK9VOU/IBeOI189WPRwPUc3LvJK9GOs= +github.com/aws/aws-sdk-go-v2/service/s3 v1.78.0/go.mod h1:4qzsZSzB/KiX2EzDjs9D7A8rI/WGJxZceVJIHqtJjIU= +github.com/aws/aws-sdk-go-v2/service/sso v1.30.12 h1:iSsvB9EtQ09YrsmIc44Heqlx5ByGErqhPK1ZQLppias= +github.com/aws/aws-sdk-go-v2/service/sso v1.30.12/go.mod h1:fEWYKTRGoZNl8tZ77i61/ccwOMJdGxwOhWCkp6TXAr0= +github.com/aws/aws-sdk-go-v2/service/ssooidc v1.35.16 h1:EnUdUqRP1CNzt2DkV67tJx6XDN4xlfBFm+bzeNOQVb0= +github.com/aws/aws-sdk-go-v2/service/ssooidc v1.35.16/go.mod h1:Jic/xv0Rq/pFNCh3WwpH4BEqdbSAl+IyHro8LbibHD8= +github.com/aws/aws-sdk-go-v2/service/sts v1.41.8 h1:XQTQTF75vnug2TXS8m7CVJfC2nniYPZnO1D4Np761Oo= +github.com/aws/aws-sdk-go-v2/service/sts v1.41.8/go.mod h1:Xgx+PR1NUOjNmQY+tRMnouRp83JRM8pRMw/vCaVhPkI= +github.com/aws/smithy-go v1.24.2 
h1:FzA3bu/nt/vDvmnkg+R8Xl46gmzEDam6mZ1hzmwXFng= +github.com/aws/smithy-go v1.24.2/go.mod h1:YE2RhdIuDbA5E5bTdciG9KrW3+TiEONeUWCqxX9i1Fc= github.com/aymerick/raymond v2.0.3-0.20180322193309-b565731e1464+incompatible/go.mod h1:osfaiScAUVup+UC9Nfq76eWqDhXlp+4UYaA8uhTBO6g= github.com/bahlo/generic-list-go v0.2.0 h1:5sz/EEAK+ls5wF+NeqDpk5+iNdMDXrh3z3nPnH1Wvgk= github.com/bahlo/generic-list-go v0.2.0/go.mod h1:2KvAjgMlE5NNynlg/5iLrrCCZ2+5xWbdbCW3pNTGyYg= diff --git a/helm/values.yaml b/helm/values.yaml index adfc7aec..70aea7b8 100644 --- a/helm/values.yaml +++ b/helm/values.yaml @@ -90,7 +90,7 @@ app: GIN_MODE: release # -- Retrieval driver: postgres, elasticsearch_v7, elasticsearch_v8, qdrant RETRIEVE_DRIVER: postgres - # -- Storage type: local, minio, cos, tos + # -- Storage type: local, minio, cos, tos, s3 STORAGE_TYPE: local LOCAL_STORAGE_BASE_DIR: /data/files AUTO_RECOVER_DIRTY: "true" diff --git a/internal/application/service/file/factory.go b/internal/application/service/file/factory.go index 93b07050..5cc8d008 100644 --- a/internal/application/service/file/factory.go +++ b/internal/application/service/file/factory.go @@ -90,6 +90,16 @@ func NewFileServiceFromStorageConfig( } svc, err := NewTosFileService(sec.TOS.Endpoint, sec.TOS.Region, sec.TOS.AccessKey, sec.TOS.SecretKey, sec.TOS.BucketName, sec.TOS.PathPrefix) return svc, p, err + case "s3": + if sec == nil || sec.S3 == nil || sec.S3.Endpoint == "" || sec.S3.Region == "" || sec.S3.AccessKey == "" || sec.S3.SecretKey == "" || sec.S3.BucketName == "" { + return nil, p, fmt.Errorf("incomplete s3 config") + } + pathPrefix := strings.TrimSpace(sec.S3.PathPrefix) + if pathPrefix == "" { + pathPrefix = "weknora/" + } + svc, err := NewS3FileService(sec.S3.Endpoint, sec.S3.AccessKey, sec.S3.SecretKey, sec.S3.BucketName, sec.S3.Region, pathPrefix) + return svc, p, err default: return nil, p, fmt.Errorf("unsupported provider %q", p) diff --git a/internal/application/service/file/s3.go 
b/internal/application/service/file/s3.go
new file mode 100644
index 00000000..b0cd5fd4
--- /dev/null
+++ b/internal/application/service/file/s3.go
@@ -0,0 +1,332 @@
+package file
+
+import (
+	"bytes"
+	"context"
+	"errors"
+	"fmt"
+	"io"
+	"mime/multipart"
+	"path/filepath"
+	"strings"
+	"time"
+
+	"github.com/Tencent/WeKnora/internal/types/interfaces"
+	"github.com/Tencent/WeKnora/internal/utils"
+	"github.com/aws/aws-sdk-go-v2/aws"
+	"github.com/aws/aws-sdk-go-v2/config"
+	"github.com/aws/aws-sdk-go-v2/credentials"
+	"github.com/aws/aws-sdk-go-v2/service/s3"
+	"github.com/aws/aws-sdk-go-v2/service/s3/types"
+	"github.com/google/uuid"
+)
+
+// s3FileService stores files in an S3 (or S3-compatible) bucket.
+type s3FileService struct {
+	client     *s3.Client // SDK client used for all bucket and object calls
+	bucketName string     // bucket every operation targets
+	pathPrefix string     // prepended to generated object keys
+}
+
+// newS3Client builds an s3FileService around a freshly configured SDK client.
+// It only constructs the client; bucket existence is not checked here.
+func newS3Client(endpoint, accessKey, secretKey, bucketName, region, pathPrefix string) (*s3FileService, error) {
+	// Use the caller-supplied key pair as static credentials.
+	staticCreds := credentials.NewStaticCredentialsProvider(accessKey, secretKey, "")
+	cfg, err := config.LoadDefaultConfig(
+		context.Background(),
+		config.WithRegion(region),
+		config.WithCredentialsProvider(staticCreds),
+	)
+	if err != nil {
+		return nil, fmt.Errorf("failed to load AWS config: %w", err)
+	}
+
+	// A non-empty endpoint overrides endpoint resolution with that URL;
+	// with no endpoint the client keeps the SDK's default resolution.
+	var clientOpts []func(*s3.Options)
+	if endpoint != "" {
+		resolver := s3.EndpointResolverFromURL(endpoint)
+		clientOpts = append(clientOpts, s3.WithEndpointResolver(resolver))
+	}
+	client := s3.NewFromConfig(cfg, clientOpts...)
+
+	// Object keys are formed as pathPrefix + rest, so a non-empty prefix
+	// needs a trailing '/'.
+	if pathPrefix != "" && !strings.HasSuffix(pathPrefix, "/") {
+		pathPrefix += "/"
+	}
+
+	svc := &s3FileService{
+		client:     client,
+		bucketName: bucketName,
+		pathPrefix: pathPrefix,
+	}
+	return svc, nil
+}
+
+// NewS3FileService creates an AWS S3 file service.
+// It verifies that the bucket exists and creates it if missing.
+func NewS3FileService(endpoint,
+	accessKey, secretKey, bucketName, region, pathPrefix string,
+) (interfaces.FileService, error) {
+	svc, err := newS3Client(endpoint, accessKey, secretKey, bucketName, region, pathPrefix)
+	if err != nil {
+		return nil, err
+	}
+
+	// Probe first so CreateBucket is only attempted when HeadBucket reports NotFound.
+	exists, err := svc.bucketExists(context.Background())
+	if err != nil {
+		return nil, fmt.Errorf("failed to check bucket: %w", err)
+	}
+
+	if !exists {
+		if err = svc.createBucket(context.Background()); err != nil {
+			return nil, fmt.Errorf("failed to create bucket: %w", err)
+		}
+	}
+
+	return svc, nil
+}
+
+// bucketExists reports whether HeadBucket succeeds; a typed NotFound maps to (false, nil).
+func (s *s3FileService) bucketExists(ctx context.Context) (bool, error) {
+	_, err := s.client.HeadBucket(ctx, &s3.HeadBucketInput{
+		Bucket: aws.String(s.bucketName),
+	})
+	if err != nil {
+		// Only NotFound means "absent"; every other failure is returned to the caller.
+		var notFound *types.NotFound
+		if errors.As(err, &notFound) {
+			return false, nil
+		}
+		return false, err
+	}
+	return true, nil
+}
+
+// createBucket issues CreateBucket for s.bucketName with no location constraint set.
+func (s *s3FileService) createBucket(ctx context.Context) error {
+	_, err := s.client.CreateBucket(ctx, &s3.CreateBucketInput{
+		Bucket: aws.String(s.bucketName),
+	})
+	return err
+}
+
+// CheckConnectivity verifies S3 is reachable and, if a bucket is configured,
+// that the bucket exists. This is a read-only probe — it never creates a bucket.
+func (s *s3FileService) CheckConnectivity(ctx context.Context) error {
+	checkCtx, cancel := context.WithTimeout(ctx, 10*time.Second) // bound the whole probe to 10s
+	defer cancel()
+
+	if s.bucketName != "" {
+		exists, err := s.bucketExists(checkCtx)
+		if err != nil {
+			return err
+		}
+		if !exists {
+			return fmt.Errorf("bucket %q does not exist", s.bucketName)
+		}
+		return nil
+	}
+
+	// No bucket configured: fall back to ListBuckets as the reachability probe.
+	_, err := s.client.ListBuckets(checkCtx, &s3.ListBucketsInput{})
+	return err
+}
+
+// CheckS3Connectivity tests S3 connectivity using the provided credentials.
+// It builds a throwaway client (empty path prefix) and delegates to CheckConnectivity.
+func CheckS3Connectivity(ctx context.Context, endpoint, accessKey, secretKey, bucketName, region string) error {
+	svc, err := newS3Client(endpoint, accessKey, secretKey, bucketName, region, "")
+	if err != nil {
+		return err
+	}
+	return svc.CheckConnectivity(ctx)
+}
+
+// parseS3FilePath returns the object key of an s3://{bucket}/{objectKey} path for this service's bucket.
+func (s *s3FileService) parseS3FilePath(filePath string) (string, error) {
+	// Reject anything without the scheme, without both parts, or naming another bucket.
+	const prefix = "s3://"
+	rest, ok := strings.CutPrefix(filePath, prefix)
+	if !ok {
+		return "", fmt.Errorf("invalid S3 file path: %s", filePath)
+	}
+	bucket, key, found := strings.Cut(rest, "/")
+	if !found || bucket == "" || key == "" {
+		return "", fmt.Errorf("invalid S3 file path: %s", filePath)
+	}
+	if bucket != s.bucketName {
+		return "", fmt.Errorf("bucket mismatch in path: got %s, want %s", bucket, s.bucketName)
+	}
+	if err := utils.SafeObjectKey(key); err != nil {
+		return "", fmt.Errorf("invalid file path: %w", err)
+	}
+	return key, nil
+}
+
+// getContentTypeByExt maps a dotted file extension (any case) to a MIME type; unknown ones get application/octet-stream.
+func getContentTypeByExt(ext string) string {
+	switch strings.ToLower(ext) {
+	case ".csv":
+		return "text/csv; charset=utf-8"
+	case ".json":
+		return "application/json"
+	case ".pdf":
+		return "application/pdf"
+	case ".doc":
+		return "application/msword"
+	case ".docx":
+		return "application/vnd.openxmlformats-officedocument.wordprocessingml.document"
+	case ".xls":
+		return "application/vnd.ms-excel"
+	case ".xlsx":
+		return "application/vnd.openxmlformats-officedocument.spreadsheetml.sheet"
+	case ".ppt":
+		return "application/vnd.ms-powerpoint"
+	case ".pptx":
+		return "application/vnd.openxmlformats-officedocument.presentationml.presentation"
+	case ".txt":
+		return "text/plain; charset=utf-8"
+	case ".md":
+		return "text/markdown"
+	case ".html":
+		return "text/html; charset=utf-8"
+	case ".jpg", ".jpeg":
+		return "image/jpeg"
+	case ".png":
+		return "image/png"
+	case ".gif":
+		return "image/gif"
+	case ".svg":
+		return "image/svg+xml"
+	case ".mp3":
+		return "audio/mpeg"
+	case ".mp4":
+		return "video/mp4"
+	default:
+		return "application/octet-stream"
+	}
+}
+
+// SaveFile uploads a multipart file and returns its s3://{bucket}/{key} path.
+func (s *s3FileService) SaveFile(ctx context.Context,
+	file *multipart.FileHeader, tenantID uint64, knowledgeID string,
+) (string, error) {
+	// Reject the key now if parseS3FilePath would refuse it on every later read or delete.
+	ext := filepath.Ext(file.Filename)
+	objectName := fmt.Sprintf("%s%d/%s/%s%s", s.pathPrefix, tenantID, knowledgeID, uuid.New().String(), ext)
+	if err := utils.SafeObjectKey(objectName); err != nil {
+		return "", fmt.Errorf("invalid object key: %w", err)
+	}
+
+	src, err := file.Open()
+	if err != nil {
+		return "", fmt.Errorf("failed to open file: %w", err)
+	}
+	defer src.Close()
+
+	// Prefer the type declared on the uploaded part; fall back to the extension.
+	contentType := file.Header.Get("Content-Type")
+	if contentType == "" {
+		contentType = getContentTypeByExt(ext)
+	}
+
+	_, err = s.client.PutObject(ctx, &s3.PutObjectInput{
+		Bucket:        aws.String(s.bucketName),
+		Key:           aws.String(objectName),
+		Body:          src,
+		ContentLength: aws.Int64(file.Size),
+		ContentType:   aws.String(contentType),
+	})
+	if err != nil {
+		return "", fmt.Errorf("failed to upload file to S3: %w", err)
+	}
+	return fmt.Sprintf("s3://%s/%s", s.bucketName, objectName), nil
+}
+
+// GetFile gets a file
from S3
+func (s *s3FileService) GetFile(ctx context.Context, filePath string) (io.ReadCloser, error) {
+	objectName, err := s.parseS3FilePath(filePath)
+	if err != nil {
+		return nil, err
+	}
+
+	resp, err := s.client.GetObject(ctx, &s3.GetObjectInput{
+		Bucket: aws.String(s.bucketName),
+		Key:    aws.String(objectName),
+	})
+	if err != nil {
+		return nil, fmt.Errorf("failed to get file from S3: %w", err)
+	}
+
+	return resp.Body, nil
+}
+
+// DeleteFile deletes the object that an s3://{bucket}/{key} path refers to.
+func (s *s3FileService) DeleteFile(ctx context.Context, filePath string) error {
+	objectName, err := s.parseS3FilePath(filePath)
+	if err != nil {
+		return err
+	}
+
+	_, err = s.client.DeleteObject(ctx, &s3.DeleteObjectInput{
+		Bucket: aws.String(s.bucketName),
+		Key:    aws.String(objectName),
+	})
+	if err != nil {
+		return fmt.Errorf("failed to delete file: %w", err)
+	}
+
+	return nil
+}
+
+// SaveBytes uploads data under {prefix}{tenant}/exports/ and returns its s3:// path.
+// temp parameter is ignored for S3 (no auto-expiration support in this implementation)
+func (s *s3FileService) SaveBytes(ctx context.Context, data []byte, tenantID uint64, fileName string, temp bool) (string, error) {
+	safeName, err := utils.SafeFileName(fileName)
+	if err != nil {
+		return "", fmt.Errorf("invalid file name: %w", err)
+	}
+	ext := filepath.Ext(safeName)
+	objectName := fmt.Sprintf("%s%d/exports/%s%s", s.pathPrefix, tenantID, uuid.New().String(), ext)
+
+	// Label the object by its extension (octet-stream if unknown) instead of always as CSV.
+	contentType := getContentTypeByExt(ext)
+	_, err = s.client.PutObject(ctx, &s3.PutObjectInput{
+		Bucket:        aws.String(s.bucketName),
+		Key:           aws.String(objectName),
+		Body:          bytes.NewReader(data),
+		ContentLength: aws.Int64(int64(len(data))),
+		ContentType:   aws.String(contentType),
+	})
+	if err != nil {
+		return "", fmt.Errorf("failed to upload bytes to S3: %w", err)
+	}
+
+	return fmt.Sprintf("s3://%s/%s", s.bucketName, objectName), nil
+}
+
+// GetFileURL returns a presigned download URL for the file
+func (s *s3FileService) GetFileURL(ctx context.Context, filePath
string) (string, error) { + objectName, err := s.parseS3FilePath(filePath) + if err != nil { + return "", err + } + + // Create presign client + presignClient := s3.NewPresignClient(s.client) + + // Generate presigned URL + presignedReq, err := presignClient.PresignGetObject(ctx, &s3.GetObjectInput{ + Bucket: aws.String(s.bucketName), + Key: aws.String(objectName), + }, s3.WithPresignExpires(24*time.Hour)) + if err != nil { + return "", fmt.Errorf("failed to generate presigned URL: %w", err) + } + + return presignedReq.URL, nil +} diff --git a/internal/container/container.go b/internal/container/container.go index 6d64a50d..4475f66e 100644 --- a/internal/container/container.go +++ b/internal/container/container.go @@ -534,6 +534,26 @@ func initFileService(cfg *config.Config) (interfaces.FileService, error) { os.Getenv("TOS_TEMP_BUCKET_NAME"), // 可选:临时桶名称(桶需配置生命周期规则自动过期) os.Getenv("TOS_TEMP_REGION"), // 可选:临时桶 region,默认与主桶相同 ) + case "s3": + if os.Getenv("S3_ENDPOINT") == "" || + os.Getenv("S3_REGION") == "" || + os.Getenv("S3_ACCESS_KEY") == "" || + os.Getenv("S3_SECRET_KEY") == "" || + os.Getenv("S3_BUCKET_NAME") == "" { + return nil, fmt.Errorf("missing S3 configuration") + } + pathPrefix := os.Getenv("S3_PATH_PREFIX") + if pathPrefix == "" { + pathPrefix = "weknora/" + } + return file.NewS3FileService( + os.Getenv("S3_ENDPOINT"), + os.Getenv("S3_ACCESS_KEY"), + os.Getenv("S3_SECRET_KEY"), + os.Getenv("S3_BUCKET_NAME"), + os.Getenv("S3_REGION"), + pathPrefix, + ) case "local": baseDir := os.Getenv("LOCAL_STORAGE_BASE_DIR") if baseDir == "" { diff --git a/internal/handler/system.go b/internal/handler/system.go index 96520808..fdd0f666 100644 --- a/internal/handler/system.go +++ b/internal/handler/system.go @@ -686,10 +686,11 @@ func isBlockedStorageEndpoint(endpoint string) (bool, string) { // StorageCheckRequest is the body for POST /system/storage-engine-check. 
type StorageCheckRequest struct { - Provider string `json:"provider"` // "minio", "cos", or "tos" + Provider string `json:"provider"` // "minio", "cos", "tos", or "s3" MinIO *types.MinIOEngineConfig `json:"minio,omitempty"` COS *types.COSEngineConfig `json:"cos,omitempty"` TOS *types.TOSEngineConfig `json:"tos,omitempty"` + S3 *types.S3EngineConfig `json:"s3,omitempty"` } // StorageCheckResponse is the response for a single-engine connectivity check. @@ -724,6 +725,8 @@ func (h *SystemHandler) CheckStorageEngine(c *gin.Context) { h.checkCOS(c, ctx, req.COS) case "tos": h.checkTOS(c, ctx, req.TOS) + case "s3": + h.checkS3(c, ctx, req.S3) default: c.JSON(200, gin.H{"code": 0, "data": StorageCheckResponse{OK: true, Message: "本地存储无需检测"}}) } @@ -881,3 +884,37 @@ func (h *SystemHandler) checkTOS(c *gin.Context, ctx context.Context, cfg *types } c.JSON(200, gin.H{"code": 0, "data": StorageCheckResponse{OK: true, Message: fmt.Sprintf("连接成功,Bucket「%s」已确认存在", cfg.BucketName)}}) } + +func (h *SystemHandler) checkS3(c *gin.Context, ctx context.Context, cfg *types.S3EngineConfig) { + if cfg == nil { + c.JSON(200, gin.H{"code": 0, "data": StorageCheckResponse{OK: false, Message: "未提供 S3 配置"}}) + return + } + if cfg.Endpoint == "" || cfg.Region == "" || cfg.AccessKey == "" || cfg.SecretKey == "" || cfg.BucketName == "" { + c.JSON(200, gin.H{"code": 0, "data": StorageCheckResponse{OK: false, Message: "Endpoint、Region、Access Key、Secret Key、Bucket 名称不能为空"}}) + return + } + + if blocked, reason := isBlockedStorageEndpoint(cfg.Endpoint); blocked { + logger.Warnf(ctx, "Storage check: S3 endpoint blocked by SSRF protection, endpoint: %s", cfg.Endpoint) + c.JSON(200, gin.H{"code": 0, "data": StorageCheckResponse{OK: false, Message: reason}}) + return + } + + err := file.CheckS3Connectivity(ctx, cfg.Endpoint, cfg.AccessKey, cfg.SecretKey, cfg.BucketName, cfg.Region) + if err != nil { + logger.Errorf(ctx, "Storage check: S3 connectivity failed, bucket: %s, error: %v", cfg.BucketName, err) 
+ errMsg := err.Error() + if strings.Contains(errMsg, "403") { + c.JSON(200, gin.H{"code": 0, "data": StorageCheckResponse{OK: false, Message: "认证失败,请检查 Access Key / Secret Key 是否正确"}}) + return + } + if strings.Contains(errMsg, "404") || strings.Contains(errMsg, "NotFound") { + c.JSON(200, gin.H{"code": 0, "data": StorageCheckResponse{OK: false, Message: fmt.Sprintf("Bucket「%s」不存在,请检查名称和 Region", cfg.BucketName)}}) + return + } + c.JSON(200, gin.H{"code": 0, "data": StorageCheckResponse{OK: false, Message: sanitizeStorageCheckError(err)}}) + return + } + c.JSON(200, gin.H{"code": 0, "data": StorageCheckResponse{OK: true, Message: fmt.Sprintf("连接成功,Bucket「%s」已确认存在", cfg.BucketName)}}) +} diff --git a/internal/types/tenant.go b/internal/types/tenant.go index 59e01db9..5edd49d7 100644 --- a/internal/types/tenant.go +++ b/internal/types/tenant.go @@ -317,14 +317,15 @@ func (c *ParserEngineConfig) Scan(value interface{}) error { return json.Unmarshal(b, c) } -// StorageEngineConfig holds tenant-level storage engine parameters for Local, MinIO, COS, and TOS. +// StorageEngineConfig holds tenant-level storage engine parameters for Local, MinIO, COS, TOS, and S3. // Knowledge bases select which provider to use; parameters are read from here. type StorageEngineConfig struct { - DefaultProvider string `json:"default_provider"` // "local", "minio", "cos", "tos" + DefaultProvider string `json:"default_provider"` // "local", "minio", "cos", "tos", "s3" Local *LocalEngineConfig `json:"local,omitempty"` MinIO *MinIOEngineConfig `json:"minio,omitempty"` COS *COSEngineConfig `json:"cos,omitempty"` TOS *TOSEngineConfig `json:"tos,omitempty"` + S3 *S3EngineConfig `json:"s3,omitempty"` } // LocalEngineConfig is for local file system storage (single-machine deployment only). @@ -364,6 +365,16 @@ type TOSEngineConfig struct { PathPrefix string `json:"path_prefix"` } +// S3EngineConfig is for AWS S3 and S3-compatible object storage. 
+type S3EngineConfig struct { + Endpoint string `json:"endpoint"` + Region string `json:"region"` + AccessKey string `json:"access_key"` + SecretKey string `json:"secret_key"` + BucketName string `json:"bucket_name"` + PathPrefix string `json:"path_prefix"` +} + // Value implements the driver.Valuer interface for StorageEngineConfig func (c *StorageEngineConfig) Value() (driver.Value, error) { if c == nil { diff --git a/scripts/check-env.sh b/scripts/check-env.sh index 592f1687..1daa73e4 100755 --- a/scripts/check-env.sh +++ b/scripts/check-env.sh @@ -97,6 +97,14 @@ if [ "$STORAGE_TYPE" = "tos" ]; then check_var "TOS_BUCKET_NAME" fi +if [ "$STORAGE_TYPE" = "s3" ]; then + check_var "S3_ENDPOINT" + check_var "S3_REGION" + check_var "S3_ACCESS_KEY" + check_var "S3_SECRET_KEY" + check_var "S3_BUCKET_NAME" +fi + echo "" log_info "Redis 配置:" check_var "REDIS_ADDR"