feat(security): implement path and filename validation utilities

- Added `SafePathUnderBase` to prevent path traversal by ensuring file paths remain within a specified base directory.
- Introduced `SafeFileName` to validate and sanitize file names, disallowing path traversal and empty names.
- Implemented `SafeObjectKey` to validate object keys for storage, ensuring they do not contain path traversal sequences.
- Updated file handling methods in `cos.go`, `local.go`, and `minio.go` to utilize these new validation utilities, enhancing security against invalid file paths and names.

This update improves the robustness of file operations by enforcing strict validation rules, thereby mitigating potential security risks.
This commit is contained in:
wizardchen
2026-02-24 10:51:12 +08:00
committed by lyingbug
parent d28a370931
commit 5241dbc39e
4 changed files with 120 additions and 13 deletions

View File

@@ -13,6 +13,7 @@ import (
"time"
"github.com/Tencent/WeKnora/internal/types/interfaces"
"github.com/Tencent/WeKnora/internal/utils"
"github.com/google/uuid"
"github.com/tencentyun/cos-go-sdk-v5"
)
@@ -99,6 +100,9 @@ func (s *cosFileService) SaveFile(ctx context.Context,
// GetFile retrieves a file from COS storage by its path URL
func (s *cosFileService) GetFile(ctx context.Context, filePathUrl string) (io.ReadCloser, error) {
objectName := strings.TrimPrefix(filePathUrl, s.bucketURL)
if err := utils.SafeObjectKey(objectName); err != nil {
return nil, fmt.Errorf("invalid file path: %w", err)
}
resp, err := s.client.Object.Get(ctx, objectName, nil)
if err != nil {
return nil, fmt.Errorf("failed to get file from COS: %w", err)
@@ -109,6 +113,9 @@ func (s *cosFileService) GetFile(ctx context.Context, filePathUrl string) (io.Re
// DeleteFile removes a file from COS storage
func (s *cosFileService) DeleteFile(ctx context.Context, filePath string) error {
objectName := strings.TrimPrefix(filePath, s.bucketURL)
if err := utils.SafeObjectKey(objectName); err != nil {
return fmt.Errorf("invalid file path: %w", err)
}
_, err := s.client.Object.Delete(ctx, objectName)
if err != nil {
return fmt.Errorf("failed to delete file: %w", err)
@@ -120,7 +127,11 @@ func (s *cosFileService) DeleteFile(ctx context.Context, filePath string) error
// If temp is true and temp bucket is configured, saves to temp bucket (with lifecycle auto-expiration)
// Otherwise saves to main bucket
func (s *cosFileService) SaveBytes(ctx context.Context, data []byte, tenantID uint64, fileName string, temp bool) (string, error) {
ext := filepath.Ext(fileName)
safeName, err := utils.SafeFileName(fileName)
if err != nil {
return "", fmt.Errorf("invalid file name: %w", err)
}
ext := filepath.Ext(safeName)
reader := bytes.NewReader(data)
// 如果请求写入临时桶且临时桶已配置
@@ -135,7 +146,7 @@ func (s *cosFileService) SaveBytes(ctx context.Context, data []byte, tenantID ui
// 写入主桶
objectName := fmt.Sprintf("%s/%d/exports/%s%s", s.cosPathPrefix, tenantID, uuid.New().String(), ext)
_, err := s.client.Object.Put(ctx, objectName, reader, nil)
_, err = s.client.Object.Put(ctx, objectName, reader, nil)
if err != nil {
return "", fmt.Errorf("failed to upload bytes to COS: %w", err)
}
@@ -148,6 +159,9 @@ func (s *cosFileService) GetFileURL(ctx context.Context, filePath string) (strin
// 判断文件属于哪个桶
if s.tempClient != nil && strings.HasPrefix(filePath, s.tempBucketURL) {
objectName := strings.TrimPrefix(filePath, s.tempBucketURL)
if err := utils.SafeObjectKey(objectName); err != nil {
return "", fmt.Errorf("invalid file path: %w", err)
}
// Generate presigned URL (valid for 24 hours)
presignedURL, err := s.tempClient.Object.GetPresignedURL(ctx, http.MethodGet, objectName, s.tempClient.GetCredential().SecretID, s.tempClient.GetCredential().SecretKey, 24*time.Hour, nil)
if err != nil {
@@ -157,6 +171,9 @@ func (s *cosFileService) GetFileURL(ctx context.Context, filePath string) (strin
}
objectName := strings.TrimPrefix(filePath, s.bucketURL)
if err := utils.SafeObjectKey(objectName); err != nil {
return "", fmt.Errorf("invalid file path: %w", err)
}
// Generate presigned URL (valid for 24 hours)
presignedURL, err := s.client.Object.GetPresignedURL(ctx, http.MethodGet, objectName, s.client.GetCredential().SecretID, s.client.GetCredential().SecretKey, 24*time.Hour, nil)
if err != nil {

View File

@@ -11,6 +11,7 @@ import (
"github.com/Tencent/WeKnora/internal/logger"
"github.com/Tencent/WeKnora/internal/types/interfaces"
secutils "github.com/Tencent/WeKnora/internal/utils"
)
// localFileService implements the FileService interface for local file system storage
@@ -37,6 +38,10 @@ func (s *localFileService) SaveFile(ctx context.Context,
// Create storage directory with tenant and knowledge ID
dir := filepath.Join(s.baseDir, fmt.Sprintf("%d", tenantID), knowledgeID)
if _, err := secutils.SafePathUnderBase(s.baseDir, dir); err != nil {
logger.Errorf(ctx, "Path traversal denied for SaveFile dir: %v", err)
return "", fmt.Errorf("invalid path: %w", err)
}
logger.Infof(ctx, "Creating directory: %s", dir)
if err := os.MkdirAll(dir, 0o755); err != nil {
logger.Errorf(ctx, "Failed to create directory: %v", err)
@@ -80,11 +85,17 @@ func (s *localFileService) SaveFile(ctx context.Context,
// GetFile retrieves a file from the local file system by its path
// Returns a ReadCloser for reading the file content
// 路径必须在 baseDir 下,防止路径遍历(如 ../../
func (s *localFileService) GetFile(ctx context.Context, filePath string) (io.ReadCloser, error) {
logger.Infof(ctx, "Getting file: %s", filePath)
// Open the file for reading
file, err := os.Open(filePath)
resolved, err := secutils.SafePathUnderBase(s.baseDir, filePath)
if err != nil {
logger.Errorf(ctx, "Path traversal denied for GetFile: %v", err)
return nil, fmt.Errorf("invalid file path: %w", err)
}
file, err := os.Open(resolved)
if err != nil {
logger.Errorf(ctx, "Failed to open file: %v", err)
return nil, fmt.Errorf("failed to open file: %w", err)
@@ -96,11 +107,17 @@ func (s *localFileService) GetFile(ctx context.Context, filePath string) (io.Rea
// DeleteFile removes a file from the local file system
// Returns an error if deletion fails
// 路径必须在 baseDir 下,防止路径遍历(如 ../../
func (s *localFileService) DeleteFile(ctx context.Context, filePath string) error {
logger.Infof(ctx, "Deleting file: %s", filePath)
// Remove the file
err := os.Remove(filePath)
resolved, err := secutils.SafePathUnderBase(s.baseDir, filePath)
if err != nil {
logger.Errorf(ctx, "Path traversal denied for DeleteFile: %v", err)
return fmt.Errorf("invalid file path: %w", err)
}
err = os.Remove(resolved)
if err != nil {
logger.Errorf(ctx, "Failed to delete file: %v", err)
return fmt.Errorf("failed to delete file: %w", err)
@@ -112,9 +129,16 @@ func (s *localFileService) DeleteFile(ctx context.Context, filePath string) erro
// SaveBytes saves bytes data to a file and returns the file path
// temp parameter is ignored for local storage (no auto-expiration support)
// fileName 仅允许安全文件名,禁止路径遍历(如 ../../
func (s *localFileService) SaveBytes(ctx context.Context, data []byte, tenantID uint64, fileName string, temp bool) (string, error) {
logger.Infof(ctx, "Saving bytes data: fileName=%s, size=%d, tenantID=%d, temp=%v", fileName, len(data), tenantID, temp)
safeName, err := secutils.SafeFileName(fileName)
if err != nil {
logger.Errorf(ctx, "Invalid fileName for SaveBytes: %v", err)
return "", fmt.Errorf("invalid file name: %w", err)
}
// Create storage directory with tenant ID
dir := filepath.Join(s.baseDir, fmt.Sprintf("%d", tenantID), "exports")
if err := os.MkdirAll(dir, 0o755); err != nil {
@@ -123,8 +147,8 @@ func (s *localFileService) SaveBytes(ctx context.Context, data []byte, tenantID
}
// Generate unique filename using timestamp
ext := filepath.Ext(fileName)
baseName := fileName[:len(fileName)-len(ext)]
ext := filepath.Ext(safeName)
baseName := safeName[:len(safeName)-len(ext)]
uniqueFileName := fmt.Sprintf("%s_%d%s", baseName, time.Now().UnixNano(), ext)
filePath := filepath.Join(dir, uniqueFileName)

View File

@@ -10,6 +10,7 @@ import (
"time"
"github.com/Tencent/WeKnora/internal/types/interfaces"
"github.com/Tencent/WeKnora/internal/utils"
"github.com/google/uuid"
"github.com/minio/minio-go/v7"
"github.com/minio/minio-go/v7/pkg/credentials"
@@ -90,9 +91,12 @@ func (s *minioFileService) GetFile(ctx context.Context, filePath string) (io.Rea
// Extract object name
objectName := filePath[9+len(s.bucketName):]
if objectName[0] == '/' {
if len(objectName) > 0 && objectName[0] == '/' {
objectName = objectName[1:]
}
if err := utils.SafeObjectKey(objectName); err != nil {
return nil, fmt.Errorf("invalid file path: %w", err)
}
// Get object
obj, err := s.client.GetObject(ctx, s.bucketName, objectName, minio.GetObjectOptions{})
@@ -113,9 +117,12 @@ func (s *minioFileService) DeleteFile(ctx context.Context, filePath string) erro
// Extract object name
objectName := filePath[9+len(s.bucketName):]
if objectName[0] == '/' {
if len(objectName) > 0 && objectName[0] == '/' {
objectName = objectName[1:]
}
if err := utils.SafeObjectKey(objectName); err != nil {
return fmt.Errorf("invalid file path: %w", err)
}
// Delete object
err := s.client.RemoveObject(ctx, s.bucketName, objectName, minio.RemoveObjectOptions{
@@ -131,12 +138,16 @@ func (s *minioFileService) DeleteFile(ctx context.Context, filePath string) erro
// SaveBytes saves bytes data to MinIO and returns the file path
// temp parameter is ignored for MinIO (no auto-expiration support in this implementation)
func (s *minioFileService) SaveBytes(ctx context.Context, data []byte, tenantID uint64, fileName string, temp bool) (string, error) {
ext := filepath.Ext(fileName)
safeName, err := utils.SafeFileName(fileName)
if err != nil {
return "", fmt.Errorf("invalid file name: %w", err)
}
ext := filepath.Ext(safeName)
objectName := fmt.Sprintf("%d/exports/%s%s", tenantID, uuid.New().String(), ext)
// Upload bytes to MinIO
reader := bytes.NewReader(data)
_, err := s.client.PutObject(ctx, s.bucketName, objectName, reader, int64(len(data)), minio.PutObjectOptions{
_, err = s.client.PutObject(ctx, s.bucketName, objectName, reader, int64(len(data)), minio.PutObjectOptions{
ContentType: "text/csv; charset=utf-8",
})
if err != nil {
@@ -155,9 +166,12 @@ func (s *minioFileService) GetFileURL(ctx context.Context, filePath string) (str
// Extract object name
objectName := filePath[9+len(s.bucketName):]
if objectName[0] == '/' {
if len(objectName) > 0 && objectName[0] == '/' {
objectName = objectName[1:]
}
if err := utils.SafeObjectKey(objectName); err != nil {
return "", fmt.Errorf("invalid file path: %w", err)
}
// Generate presigned URL (valid for 24 hours)
presignedURL, err := s.client.PresignedGetObject(ctx, s.bucketName, objectName, 24*time.Hour, nil)

View File

@@ -7,6 +7,7 @@ import (
"net"
"net/http"
"net/url"
"path/filepath"
"regexp"
"strings"
"time"
@@ -95,6 +96,57 @@ func ValidateInput(input string) (string, bool) {
return strings.TrimSpace(input), true
}
// SafePathUnderBase 校验 filePath 是否落在 baseDir 下,防止路径遍历(如 ../../)。
// 返回规范化的绝对路径;若路径逃逸出 baseDir 则返回错误。
func SafePathUnderBase(baseDir, filePath string) (string, error) {
if baseDir == "" || filePath == "" {
return "", fmt.Errorf("baseDir and filePath cannot be empty")
}
absBase, err := filepath.Abs(filepath.Clean(baseDir))
if err != nil {
return "", fmt.Errorf("invalid base dir: %w", err)
}
absPath, err := filepath.Abs(filepath.Clean(filePath))
if err != nil {
return "", fmt.Errorf("invalid file path: %w", err)
}
sep := string(filepath.Separator)
if absPath != absBase && !strings.HasPrefix(absPath, absBase+sep) {
return "", fmt.Errorf("path traversal denied: path is outside base directory")
}
return absPath, nil
}
// SafeFileName 校验并返回安全的“仅文件名”部分,防止路径遍历。
// 仅保留最后一个路径成分,禁止 ".."、空名或仅含点,用于 SaveBytes 等场景。
func SafeFileName(fileName string) (string, error) {
if fileName == "" {
return "", fmt.Errorf("fileName cannot be empty")
}
base := filepath.Base(filepath.Clean(fileName))
if base == "" || base == "." || base == ".." {
return "", fmt.Errorf("invalid fileName: path traversal or empty name")
}
if strings.Contains(base, "..") {
return "", fmt.Errorf("invalid fileName: contains path traversal")
}
if len(base) > 255 {
return "", fmt.Errorf("fileName too long")
}
return base, nil
}
// SafeObjectKey 校验对象存储的 key如 COS/MinIO objectName禁止包含 ".." 等路径遍历
func SafeObjectKey(objectKey string) error {
if objectKey == "" {
return fmt.Errorf("object key cannot be empty")
}
if strings.Contains(objectKey, "..") {
return fmt.Errorf("object key contains path traversal")
}
return nil
}
// IsValidURL 验证 URL 是否安全
func IsValidURL(url string) bool {
if url == "" {