From 0e19c978aeed192d095fa8aadd98b1978150a374 Mon Sep 17 00:00:00 2001 From: sqkstwj Date: Thu, 30 Apr 2026 10:01:13 +0800 Subject: [PATCH] fix(files): fall back to global file service when tenant storage config is unavailable --- internal/router/router.go | 37 ++++++++-- internal/router/router_files_test.go | 104 +++++++++++++++++++++++++++ 2 files changed, 135 insertions(+), 6 deletions(-) create mode 100644 internal/router/router_files_test.go diff --git a/internal/router/router.go b/internal/router/router.go index ae9a0f86..7d321839 100644 --- a/internal/router/router.go +++ b/internal/router/router.go @@ -35,6 +35,7 @@ type RouterParams struct { dig.In Config *config.Config + FileService interfaces.FileService UserService interfaces.UserService KBService interfaces.KnowledgeBaseService KnowledgeService interfaces.KnowledgeService @@ -120,7 +121,7 @@ func NewRouter(params RouterParams) *gin.Engine { r.Use(middleware.Auth(params.TenantService, params.UserService, params.Config)) // 文件服务:统一代理本地/MinIO/COS/TOS存储后端(需要认证) - serveFiles(r) + serveFiles(r, params.FileService) // Presigned file access: no auth required, signature-verified. servePresignedFiles(r, params.TenantService) @@ -741,7 +742,11 @@ func serveFrontendStatic(r *gin.Engine) { // // Route: // - /files?file_path= -func serveFiles(r *gin.Engine) { +type getRouteRegistrar interface { + GET(string, ...gin.HandlerFunc) gin.IRoutes +} + +func serveFiles(r getRouteRegistrar, globalFileService interfaces.FileService) { baseDir := os.Getenv("LOCAL_STORAGE_BASE_DIR") if baseDir == "" { baseDir = "/data/files" @@ -774,11 +779,31 @@ func serveFiles(r *gin.Engine) { return } - fileSvc, resolvedProvider, err := filesvc.NewFileServiceFromStorageConfig(provider, tenant.StorageEngineConfig, absDir) + var ( + fileSvc interfaces.FileService + resolvedProvider string + err error + ) + + if tenant.StorageEngineConfig != nil { + fileSvc, resolvedProvider, err = filesvc.NewFileServiceFromStorageConfig(provider, tenant.StorageEngineConfig, absDir) + } else { + err = http.ErrMissingFile + } if err != nil { - logger.Warnf(context.Background(), "[Router] /files resolve file service failed: tenant_id=%d provider=%s err=%v", tenant.ID, provider, err) - c.Status(http.StatusBadRequest) - return + globalStorageType := strings.ToLower(strings.TrimSpace(os.Getenv("STORAGE_TYPE"))) + if globalStorageType == "" { + globalStorageType = "local" + } + if provider == globalStorageType && globalFileService != nil { + logger.Warnf(context.Background(), "[Router] /files tenant storage config missing or invalid, fallback to global file service: tenant_id=%d provider=%s err=%v", tenant.ID, provider, err) + fileSvc = globalFileService + resolvedProvider = globalStorageType + } else { + logger.Warnf(context.Background(), "[Router] /files resolve file service failed without fallback: tenant_id=%d provider=%s global_storage_type=%s err=%v", tenant.ID, provider, globalStorageType, err) + c.Status(http.StatusBadRequest) + return + } } reader, err := fileSvc.GetFile(c.Request.Context(), filePath) diff --git a/internal/router/router_files_test.go b/internal/router/router_files_test.go new file mode 100644 index 00000000..19a4647d --- /dev/null +++ b/internal/router/router_files_test.go @@ -0,0 +1,104 @@ +package router + +import ( + "context" + "io" + "mime/multipart" + "net/http" + "net/http/httptest" + "net/url" + "strings" + "testing" + + "github.com/gin-gonic/gin" + + "github.com/Tencent/WeKnora/internal/types" + "github.com/Tencent/WeKnora/internal/types/interfaces" +) + +var _ interfaces.FileService = (*stubFileService)(nil) + +type stubFileService struct { + getFile func(ctx context.Context, filePath string) (io.ReadCloser, error) +} + +func (s *stubFileService) CheckConnectivity(ctx context.Context) error { + return nil +} + +func (s *stubFileService) SaveFile(ctx context.Context, file *multipart.FileHeader, tenantID uint64, knowledgeID string) (string, error) { + panic("unexpected call to SaveFile") +} + +func (s *stubFileService) SaveBytes(ctx context.Context, data []byte, tenantID uint64, fileName string, temp bool) (string, error) { + panic("unexpected call to SaveBytes") +} + +func (s *stubFileService) GetFile(ctx context.Context, filePath string) (io.ReadCloser, error) { + if s.getFile == nil { + panic("unexpected call to GetFile") + } + return s.getFile(ctx, filePath) +} + +func (s *stubFileService) GetFileURL(ctx context.Context, filePath string) (string, error) { + panic("unexpected call to GetFileURL") +} + +func (s *stubFileService) DeleteFile(ctx context.Context, filePath string) error { + panic("unexpected call to DeleteFile") +} + +func TestServeFilesFallsBackToGlobalFileService(t *testing.T) { + gin.SetMode(gin.TestMode) + t.Setenv("STORAGE_TYPE", "local") + + engine := gin.New() + var requestedPath string + serveFiles(engine, &stubFileService{ + getFile: func(ctx context.Context, filePath string) (io.ReadCloser, error) { + requestedPath = filePath + return io.NopCloser(strings.NewReader("fallback-body")), nil + }, + }) + + filePath := "local://docs/example.txt" + req := httptest.NewRequest(http.MethodGet, "/files?file_path="+url.QueryEscape(filePath), nil) + req = req.WithContext(context.WithValue(req.Context(), types.TenantInfoContextKey, &types.Tenant{ID: 42})) + + recorder := httptest.NewRecorder() + engine.ServeHTTP(recorder, req) + + if got, want := recorder.Code, http.StatusOK; got != want { + t.Fatalf("status = %d, want %d", got, want) + } + if requestedPath != filePath { + t.Fatalf("requested path = %q, want %q", requestedPath, filePath) + } + if body := recorder.Body.String(); body != "fallback-body" { + t.Fatalf("body = %q, want %q", body, "fallback-body") + } +} + +func TestServeFilesDoesNotFallbackWhenProviderDoesNotMatchGlobalStorage(t *testing.T) { + gin.SetMode(gin.TestMode) + t.Setenv("STORAGE_TYPE", "minio") + + engine := gin.New() + serveFiles(engine, &stubFileService{ + getFile: func(ctx context.Context, filePath string) (io.ReadCloser, error) { + t.Fatalf("GetFile should not be called for mismatched provider, got %q", filePath) + return nil, nil + }, + }) + + req := httptest.NewRequest(http.MethodGet, "/files?file_path="+url.QueryEscape("local://docs/example.txt"), nil) + req = req.WithContext(context.WithValue(req.Context(), types.TenantInfoContextKey, &types.Tenant{ID: 42})) + + recorder := httptest.NewRecorder() + engine.ServeHTTP(recorder, req) + + if got, want := recorder.Code, http.StatusBadRequest; got != want { + t.Fatalf("status = %d, want %d", got, want) + } +}