From 389900af7e1b7013f9abe4b9d1d0462d07eadfdc Mon Sep 17 00:00:00 2001 From: begoniezhao Date: Mon, 12 Jan 2026 17:34:01 +0800 Subject: [PATCH] refactor: Refactor retriever engine checks and expose mapping --- internal/handler/system.go | 26 ++++++++++++++++++++++++-- internal/mcp/client.go | 28 ---------------------------- internal/types/tenant.go | 6 ++++++ 3 files changed, 30 insertions(+), 30 deletions(-) diff --git a/internal/handler/system.go b/internal/handler/system.go index 3ab2b243..0f42be8e 100644 --- a/internal/handler/system.go +++ b/internal/handler/system.go @@ -8,6 +8,7 @@ import ( "github.com/Tencent/WeKnora/internal/config" "github.com/Tencent/WeKnora/internal/logger" + "github.com/Tencent/WeKnora/internal/types" "github.com/gin-gonic/gin" "github.com/minio/minio-go/v7" "github.com/minio/minio-go/v7/pkg/credentials" @@ -102,7 +103,7 @@ func (h *SystemHandler) getKeywordIndexEngine() string { keywordEngines := []string{} for _, driver := range drivers { driver = strings.TrimSpace(driver) - if driver == "postgres" || driver == "elasticsearch_v7" || driver == "elasticsearch_v8" { + if h.supportsRetrieverType(driver, types.KeywordsRetrieverType) { keywordEngines = append(keywordEngines, driver) } } @@ -131,7 +132,7 @@ func (h *SystemHandler) getVectorStoreEngine() string { vectorEngines := []string{} for _, driver := range drivers { driver = strings.TrimSpace(driver) - if driver == "postgres" || driver == "elasticsearch_v8" { + if h.supportsRetrieverType(driver, types.VectorRetrieverType) { vectorEngines = append(vectorEngines, driver) } } @@ -150,6 +151,27 @@ func (h *SystemHandler) getGraphDatabaseEngine() string { return "Neo4j" } +// supportsRetrieverType checks if a driver supports a specific retriever type +// by looking up the retrieverEngineMapping from types package +func (h *SystemHandler) supportsRetrieverType(driver string, retrieverType types.RetrieverType) bool { + // Get the mapping of all supported drivers and their capabilities + mapping := types.GetRetrieverEngineMapping() + + // Check if the driver exists in the mapping + engines, exists := mapping[driver] + if !exists { + return false + } + + // Check if any of the engine configurations support the requested retriever type + for _, engine := range engines { + if engine.RetrieverType == retrieverType { + return true + } + } + return false +} + // isMinioEnabled checks if MinIO is enabled func (h *SystemHandler) isMinioEnabled() bool { // Check if all required MinIO environment variables are set diff --git a/internal/mcp/client.go b/internal/mcp/client.go index f971412c..70315c7f 100644 --- a/internal/mcp/client.go +++ b/internal/mcp/client.go @@ -4,10 +4,7 @@ import ( "context" "encoding/json" "fmt" - "net" "net/http" - "net/url" - "strings" "time" "github.com/Tencent/WeKnora/internal/logger" @@ -61,31 +58,6 @@ type mcpGoClient struct { initialized bool } -// validateURLScheme validates that authentication endpoints use HTTPS -func validateURLScheme(serviceURL string, hasBasicAuth bool, hasTokenAuth bool) error { - if !hasBasicAuth && !hasTokenAuth { - // No authentication, no HTTPS requirement - return nil - } - - if serviceURL == "" { - return fmt.Errorf("URL is required") - } - - // Parse the URL - parsedURL, err := url.Parse(serviceURL) - if err != nil { - return fmt.Errorf("invalid URL: %w", err) - } - - // Enforce HTTPS for authenticated endpoints - if parsedURL.Scheme != "https" { - return fmt.Errorf("HTTPS is required for authenticated connections, got %s scheme", parsedURL.Scheme) - } - - return nil -} - // NewMCPClient creates a new MCP client based on the transport type func NewMCPClient(config *ClientConfig) (MCPClient, error) { // Create HTTP client with timeout diff --git a/internal/types/tenant.go b/internal/types/tenant.go index 67ce9502..088298a9 100644 --- a/internal/types/tenant.go +++ b/internal/types/tenant.go @@ -29,6 +29,12 @@ var retrieverEngineMapping = map[string][]RetrieverEngineParams{ }, } +// GetRetrieverEngineMapping returns the retriever engine mapping +// This allows other packages to access the driver capabilities +func GetRetrieverEngineMapping() map[string][]RetrieverEngineParams { + return retrieverEngineMapping +} + // GetDefaultRetrieverEngines returns the default retriever engines based on RETRIEVE_DRIVER env func GetDefaultRetrieverEngines() []RetrieverEngineParams { result := []RetrieverEngineParams{}