refactor: Refactor retriever engine checks and expose mapping

This commit is contained in:
begoniezhao
2026-01-12 17:34:01 +08:00
parent e023c26af8
commit 389900af7e
3 changed files with 30 additions and 30 deletions

View File

@@ -8,6 +8,7 @@ import (
"github.com/Tencent/WeKnora/internal/config"
"github.com/Tencent/WeKnora/internal/logger"
"github.com/Tencent/WeKnora/internal/types"
"github.com/gin-gonic/gin"
"github.com/minio/minio-go/v7"
"github.com/minio/minio-go/v7/pkg/credentials"
@@ -102,7 +103,7 @@ func (h *SystemHandler) getKeywordIndexEngine() string {
keywordEngines := []string{}
for _, driver := range drivers {
driver = strings.TrimSpace(driver)
if driver == "postgres" || driver == "elasticsearch_v7" || driver == "elasticsearch_v8" {
if h.supportsRetrieverType(driver, types.KeywordsRetrieverType) {
keywordEngines = append(keywordEngines, driver)
}
}
@@ -131,7 +132,7 @@ func (h *SystemHandler) getVectorStoreEngine() string {
vectorEngines := []string{}
for _, driver := range drivers {
driver = strings.TrimSpace(driver)
if driver == "postgres" || driver == "elasticsearch_v8" {
if h.supportsRetrieverType(driver, types.VectorRetrieverType) {
vectorEngines = append(vectorEngines, driver)
}
}
@@ -150,6 +151,27 @@ func (h *SystemHandler) getGraphDatabaseEngine() string {
return "Neo4j"
}
// supportsRetrieverType checks if a driver supports a specific retriever type
// by looking up the retrieverEngineMapping from types package
func (h *SystemHandler) supportsRetrieverType(driver string, retrieverType types.RetrieverType) bool {
// Get the mapping of all supported drivers and their capabilities
mapping := types.GetRetrieverEngineMapping()
// Check if the driver exists in the mapping
engines, exists := mapping[driver]
if !exists {
return false
}
// Check if any of the engine configurations support the requested retriever type
for _, engine := range engines {
if engine.RetrieverType == retrieverType {
return true
}
}
return false
}
// isMinioEnabled checks if MinIO is enabled
func (h *SystemHandler) isMinioEnabled() bool {
// Check if all required MinIO environment variables are set

View File

@@ -4,10 +4,7 @@ import (
"context"
"encoding/json"
"fmt"
"net"
"net/http"
"net/url"
"strings"
"time"
"github.com/Tencent/WeKnora/internal/logger"
@@ -61,31 +58,6 @@ type mcpGoClient struct {
initialized bool
}
// validateURLScheme validates that authentication endpoints use HTTPS
func validateURLScheme(serviceURL string, hasBasicAuth bool, hasTokenAuth bool) error {
if !hasBasicAuth && !hasTokenAuth {
// No authentication, no HTTPS requirement
return nil
}
if serviceURL == "" {
return fmt.Errorf("URL is required")
}
// Parse the URL
parsedURL, err := url.Parse(serviceURL)
if err != nil {
return fmt.Errorf("invalid URL: %w", err)
}
// Enforce HTTPS for authenticated endpoints
if parsedURL.Scheme != "https" {
return fmt.Errorf("HTTPS is required for authenticated connections, got %s scheme", parsedURL.Scheme)
}
return nil
}
// NewMCPClient creates a new MCP client based on the transport type
func NewMCPClient(config *ClientConfig) (MCPClient, error) {
// Create HTTP client with timeout

View File

@@ -29,6 +29,12 @@ var retrieverEngineMapping = map[string][]RetrieverEngineParams{
},
}
// GetRetrieverEngineMapping returns the retriever engine mapping
// This allows other packages to access the driver capabilities
func GetRetrieverEngineMapping() map[string][]RetrieverEngineParams {
return retrieverEngineMapping
}
// GetDefaultRetrieverEngines returns the default retriever engines based on RETRIEVE_DRIVER env
func GetDefaultRetrieverEngines() []RetrieverEngineParams {
result := []RetrieverEngineParams{}