From 5241dbc39ed98942cd1bd4ea05f2cb01b5fbc3ef Mon Sep 17 00:00:00 2001 From: wizardchen Date: Tue, 24 Feb 2026 10:51:12 +0800 Subject: [PATCH] feat(security): implement path and filename validation utilities - Added `SafePathUnderBase` to prevent path traversal by ensuring file paths remain within a specified base directory. - Introduced `SafeFileName` to validate and sanitize file names, disallowing path traversal and empty names. - Implemented `SafeObjectKey` to validate object keys for storage, ensuring they do not contain path traversal sequences. - Updated file handling methods in `cos.go`, `local.go`, and `minio.go` to utilize these new validation utilities, enhancing security against invalid file paths and names. This update improves the robustness of file operations by enforcing strict validation rules, thereby mitigating potential security risks. --- internal/application/service/file/cos.go | 21 ++++++++- internal/application/service/file/local.go | 36 ++++++++++++--- internal/application/service/file/minio.go | 24 +++++++--- internal/utils/security.go | 52 ++++++++++++++++++++++ 4 files changed, 120 insertions(+), 13 deletions(-) diff --git a/internal/application/service/file/cos.go b/internal/application/service/file/cos.go index 99bc4221..94163174 100644 --- a/internal/application/service/file/cos.go +++ b/internal/application/service/file/cos.go @@ -13,6 +13,7 @@ import ( "time" "github.com/Tencent/WeKnora/internal/types/interfaces" + "github.com/Tencent/WeKnora/internal/utils" "github.com/google/uuid" "github.com/tencentyun/cos-go-sdk-v5" ) @@ -99,6 +100,9 @@ func (s *cosFileService) SaveFile(ctx context.Context, // GetFile retrieves a file from COS storage by its path URL func (s *cosFileService) GetFile(ctx context.Context, filePathUrl string) (io.ReadCloser, error) { objectName := strings.TrimPrefix(filePathUrl, s.bucketURL) + if err := utils.SafeObjectKey(objectName); err != nil { + return nil, fmt.Errorf("invalid file path: %w", err) + } resp, err := s.client.Object.Get(ctx, objectName, nil) if err != nil { return nil, fmt.Errorf("failed to get file from COS: %w", err) @@ -109,6 +113,9 @@ func (s *cosFileService) GetFile(ctx context.Context, filePathUrl string) (io.Re // DeleteFile removes a file from COS storage func (s *cosFileService) DeleteFile(ctx context.Context, filePath string) error { objectName := strings.TrimPrefix(filePath, s.bucketURL) + if err := utils.SafeObjectKey(objectName); err != nil { + return fmt.Errorf("invalid file path: %w", err) + } _, err := s.client.Object.Delete(ctx, objectName) if err != nil { return fmt.Errorf("failed to delete file: %w", err) @@ -120,7 +127,11 @@ func (s *cosFileService) DeleteFile(ctx context.Context, filePath string) error // If temp is true and temp bucket is configured, saves to temp bucket (with lifecycle auto-expiration) // Otherwise saves to main bucket func (s *cosFileService) SaveBytes(ctx context.Context, data []byte, tenantID uint64, fileName string, temp bool) (string, error) { - ext := filepath.Ext(fileName) + safeName, err := utils.SafeFileName(fileName) + if err != nil { + return "", fmt.Errorf("invalid file name: %w", err) + } + ext := filepath.Ext(safeName) reader := bytes.NewReader(data) // 如果请求写入临时桶且临时桶已配置 @@ -135,7 +146,7 @@ func (s *cosFileService) SaveBytes(ctx context.Context, data []byte, tenantID ui // 写入主桶 objectName := fmt.Sprintf("%s/%d/exports/%s%s", s.cosPathPrefix, tenantID, uuid.New().String(), ext) - _, err := s.client.Object.Put(ctx, objectName, reader, nil) + _, err = s.client.Object.Put(ctx, objectName, reader, nil) if err != nil { return "", fmt.Errorf("failed to upload bytes to COS: %w", err) } @@ -148,6 +159,9 @@ func (s *cosFileService) GetFileURL(ctx context.Context, filePath string) (strin // 判断文件属于哪个桶 if s.tempClient != nil && strings.HasPrefix(filePath, s.tempBucketURL) { objectName := strings.TrimPrefix(filePath, s.tempBucketURL) + if err := utils.SafeObjectKey(objectName); err != nil { + return "", fmt.Errorf("invalid file path: %w", err) + } // Generate presigned URL (valid for 24 hours) presignedURL, err := s.tempClient.Object.GetPresignedURL(ctx, http.MethodGet, objectName, s.tempClient.GetCredential().SecretID, s.tempClient.GetCredential().SecretKey, 24*time.Hour, nil) if err != nil { @@ -157,6 +171,9 @@ func (s *cosFileService) GetFileURL(ctx context.Context, filePath string) (strin } objectName := strings.TrimPrefix(filePath, s.bucketURL) + if err := utils.SafeObjectKey(objectName); err != nil { + return "", fmt.Errorf("invalid file path: %w", err) + } // Generate presigned URL (valid for 24 hours) presignedURL, err := s.client.Object.GetPresignedURL(ctx, http.MethodGet, objectName, s.client.GetCredential().SecretID, s.client.GetCredential().SecretKey, 24*time.Hour, nil) if err != nil { diff --git a/internal/application/service/file/local.go b/internal/application/service/file/local.go index 07e5b781..c40533b8 100644 --- a/internal/application/service/file/local.go +++ b/internal/application/service/file/local.go @@ -11,6 +11,7 @@ import ( "github.com/Tencent/WeKnora/internal/logger" "github.com/Tencent/WeKnora/internal/types/interfaces" + secutils "github.com/Tencent/WeKnora/internal/utils" ) // localFileService implements the FileService interface for local file system storage @@ -37,6 +38,10 @@ func (s *localFileService) SaveFile(ctx context.Context, // Create storage directory with tenant and knowledge ID dir := filepath.Join(s.baseDir, fmt.Sprintf("%d", tenantID), knowledgeID) + if _, err := secutils.SafePathUnderBase(s.baseDir, dir); err != nil { + logger.Errorf(ctx, "Path traversal denied for SaveFile dir: %v", err) + return "", fmt.Errorf("invalid path: %w", err) + } logger.Infof(ctx, "Creating directory: %s", dir) if err := os.MkdirAll(dir, 0o755); err != nil { logger.Errorf(ctx, "Failed to create directory: %v", err) @@ -80,11 +85,17 @@ func (s *localFileService) SaveFile(ctx context.Context, // GetFile retrieves a file from the local file system by its path // Returns a ReadCloser for reading the file content +// 路径必须在 baseDir 下,防止路径遍历(如 ../../) func (s *localFileService) GetFile(ctx context.Context, filePath string) (io.ReadCloser, error) { logger.Infof(ctx, "Getting file: %s", filePath) - // Open the file for reading - file, err := os.Open(filePath) + resolved, err := secutils.SafePathUnderBase(s.baseDir, filePath) + if err != nil { + logger.Errorf(ctx, "Path traversal denied for GetFile: %v", err) + return nil, fmt.Errorf("invalid file path: %w", err) + } + + file, err := os.Open(resolved) if err != nil { logger.Errorf(ctx, "Failed to open file: %v", err) return nil, fmt.Errorf("failed to open file: %w", err) @@ -96,11 +107,17 @@ func (s *localFileService) GetFile(ctx context.Context, filePath string) (io.Rea // DeleteFile removes a file from the local file system // Returns an error if deletion fails +// 路径必须在 baseDir 下,防止路径遍历(如 ../../) func (s *localFileService) DeleteFile(ctx context.Context, filePath string) error { logger.Infof(ctx, "Deleting file: %s", filePath) - // Remove the file - err := os.Remove(filePath) + resolved, err := secutils.SafePathUnderBase(s.baseDir, filePath) + if err != nil { + logger.Errorf(ctx, "Path traversal denied for DeleteFile: %v", err) + return fmt.Errorf("invalid file path: %w", err) + } + + err = os.Remove(resolved) if err != nil { logger.Errorf(ctx, "Failed to delete file: %v", err) return fmt.Errorf("failed to delete file: %w", err) @@ -112,9 +129,16 @@ func (s *localFileService) DeleteFile(ctx context.Context, filePath string) erro // SaveBytes saves bytes data to a file and returns the file path // temp parameter is ignored for local storage (no auto-expiration support) +// fileName 仅允许安全文件名,禁止路径遍历(如 ../../) func (s *localFileService) SaveBytes(ctx context.Context, data []byte, tenantID uint64, fileName string, temp bool) (string, error) { logger.Infof(ctx, "Saving bytes data: fileName=%s, size=%d, tenantID=%d, temp=%v", fileName, len(data), tenantID, temp) + safeName, err := secutils.SafeFileName(fileName) + if err != nil { + logger.Errorf(ctx, "Invalid fileName for SaveBytes: %v", err) + return "", fmt.Errorf("invalid file name: %w", err) + } + // Create storage directory with tenant ID dir := filepath.Join(s.baseDir, fmt.Sprintf("%d", tenantID), "exports") if err := os.MkdirAll(dir, 0o755); err != nil { @@ -123,8 +147,8 @@ func (s *localFileService) SaveBytes(ctx context.Context, data []byte, tenantID } // Generate unique filename using timestamp - ext := filepath.Ext(fileName) - baseName := fileName[:len(fileName)-len(ext)] + ext := filepath.Ext(safeName) + baseName := safeName[:len(safeName)-len(ext)] uniqueFileName := fmt.Sprintf("%s_%d%s", baseName, time.Now().UnixNano(), ext) filePath := filepath.Join(dir, uniqueFileName) diff --git a/internal/application/service/file/minio.go b/internal/application/service/file/minio.go index 0f0753af..0a3e93ff 100644 --- a/internal/application/service/file/minio.go +++ b/internal/application/service/file/minio.go @@ -10,6 +10,7 @@ import ( "time" "github.com/Tencent/WeKnora/internal/types/interfaces" + "github.com/Tencent/WeKnora/internal/utils" "github.com/google/uuid" "github.com/minio/minio-go/v7" "github.com/minio/minio-go/v7/pkg/credentials" @@ -90,9 +91,12 @@ func (s *minioFileService) GetFile(ctx context.Context, filePath string) (io.Rea // Extract object name objectName := filePath[9+len(s.bucketName):] - if objectName[0] == '/' { + if len(objectName) > 0 && objectName[0] == '/' { objectName = objectName[1:] } + if err := utils.SafeObjectKey(objectName); err != nil { + return nil, fmt.Errorf("invalid file path: %w", err) + } // Get object obj, err := s.client.GetObject(ctx, s.bucketName, objectName, minio.GetObjectOptions{}) @@ -113,9 +117,12 @@ func (s *minioFileService) DeleteFile(ctx context.Context, filePath string) erro // Extract object name objectName := filePath[9+len(s.bucketName):] - if objectName[0] == '/' { + if len(objectName) > 0 && objectName[0] == '/' { objectName = objectName[1:] } + if err := utils.SafeObjectKey(objectName); err != nil { + return fmt.Errorf("invalid file path: %w", err) + } // Delete object err := s.client.RemoveObject(ctx, s.bucketName, objectName, minio.RemoveObjectOptions{ @@ -131,12 +138,16 @@ func (s *minioFileService) DeleteFile(ctx context.Context, filePath string) erro // SaveBytes saves bytes data to MinIO and returns the file path // temp parameter is ignored for MinIO (no auto-expiration support in this implementation) func (s *minioFileService) SaveBytes(ctx context.Context, data []byte, tenantID uint64, fileName string, temp bool) (string, error) { - ext := filepath.Ext(fileName) + safeName, err := utils.SafeFileName(fileName) + if err != nil { + return "", fmt.Errorf("invalid file name: %w", err) + } + ext := filepath.Ext(safeName) objectName := fmt.Sprintf("%d/exports/%s%s", tenantID, uuid.New().String(), ext) // Upload bytes to MinIO reader := bytes.NewReader(data) - _, err := s.client.PutObject(ctx, s.bucketName, objectName, reader, int64(len(data)), minio.PutObjectOptions{ + _, err = s.client.PutObject(ctx, s.bucketName, objectName, reader, int64(len(data)), minio.PutObjectOptions{ ContentType: "text/csv; charset=utf-8", }) if err != nil { @@ -155,9 +166,12 @@ func (s *minioFileService) GetFileURL(ctx context.Context, filePath string) (str // Extract object name objectName := filePath[9+len(s.bucketName):] - if objectName[0] == '/' { + if len(objectName) > 0 && objectName[0] == '/' { objectName = objectName[1:] } + if err := utils.SafeObjectKey(objectName); err != nil { + return "", fmt.Errorf("invalid file path: %w", err) + } // Generate presigned URL (valid for 24 hours) presignedURL, err := s.client.PresignedGetObject(ctx, s.bucketName, objectName, 24*time.Hour, nil) diff --git a/internal/utils/security.go b/internal/utils/security.go index 15924c8c..4a61ae66 100644 --- a/internal/utils/security.go +++ b/internal/utils/security.go @@ -7,6 +7,7 @@ import ( "net" "net/http" "net/url" + "path/filepath" "regexp" "strings" "time" @@ -95,6 +96,57 @@ func ValidateInput(input string) (string, bool) { return strings.TrimSpace(input), true } +// SafePathUnderBase 校验 filePath 是否落在 baseDir 下,防止路径遍历(如 ../../)。 +// 返回规范化的绝对路径;若路径逃逸出 baseDir 则返回错误。 +func SafePathUnderBase(baseDir, filePath string) (string, error) { + if baseDir == "" || filePath == "" { + return "", fmt.Errorf("baseDir and filePath cannot be empty") + } + absBase, err := filepath.Abs(filepath.Clean(baseDir)) + if err != nil { + return "", fmt.Errorf("invalid base dir: %w", err) + } + absPath, err := filepath.Abs(filepath.Clean(filePath)) + if err != nil { + return "", fmt.Errorf("invalid file path: %w", err) + } + sep := string(filepath.Separator) + if absPath != absBase && !strings.HasPrefix(absPath, absBase+sep) { + return "", fmt.Errorf("path traversal denied: path is outside base directory") + } + return absPath, nil +} + +// SafeFileName 校验并返回安全的“仅文件名”部分,防止路径遍历。 +// 仅保留最后一个路径成分,禁止 ".."、空名或仅含点,用于 SaveBytes 等场景。 +func SafeFileName(fileName string) (string, error) { + if fileName == "" { + return "", fmt.Errorf("fileName cannot be empty") + } + base := filepath.Base(filepath.Clean(fileName)) + if base == "" || base == "." || base == ".." { + return "", fmt.Errorf("invalid fileName: path traversal or empty name") + } + if strings.Contains(base, "..") { + return "", fmt.Errorf("invalid fileName: contains path traversal") + } + if len(base) > 255 { + return "", fmt.Errorf("fileName too long") + } + return base, nil +} + +// SafeObjectKey 校验对象存储的 key(如 COS/MinIO objectName),禁止包含 ".." 等路径遍历 +func SafeObjectKey(objectKey string) error { + if objectKey == "" { + return fmt.Errorf("object key cannot be empty") + } + if strings.Contains(objectKey, "..") { + return fmt.Errorf("object key contains path traversal") + } + return nil +} + // IsValidURL 验证 URL 是否安全 func IsValidURL(url string) bool { if url == "" {