mirror of
https://github.com/Tencent/WeKnora.git
synced 2026-06-04 13:30:32 +08:00
refactor: Refactor retriever engine checks and expose mapping
This commit is contained in:
@@ -8,6 +8,7 @@ import (
|
||||
|
||||
"github.com/Tencent/WeKnora/internal/config"
|
||||
"github.com/Tencent/WeKnora/internal/logger"
|
||||
"github.com/Tencent/WeKnora/internal/types"
|
||||
"github.com/gin-gonic/gin"
|
||||
"github.com/minio/minio-go/v7"
|
||||
"github.com/minio/minio-go/v7/pkg/credentials"
|
||||
@@ -102,7 +103,7 @@ func (h *SystemHandler) getKeywordIndexEngine() string {
|
||||
keywordEngines := []string{}
|
||||
for _, driver := range drivers {
|
||||
driver = strings.TrimSpace(driver)
|
||||
if driver == "postgres" || driver == "elasticsearch_v7" || driver == "elasticsearch_v8" {
|
||||
if h.supportsRetrieverType(driver, types.KeywordsRetrieverType) {
|
||||
keywordEngines = append(keywordEngines, driver)
|
||||
}
|
||||
}
|
||||
@@ -131,7 +132,7 @@ func (h *SystemHandler) getVectorStoreEngine() string {
|
||||
vectorEngines := []string{}
|
||||
for _, driver := range drivers {
|
||||
driver = strings.TrimSpace(driver)
|
||||
if driver == "postgres" || driver == "elasticsearch_v8" {
|
||||
if h.supportsRetrieverType(driver, types.VectorRetrieverType) {
|
||||
vectorEngines = append(vectorEngines, driver)
|
||||
}
|
||||
}
|
||||
@@ -150,6 +151,27 @@ func (h *SystemHandler) getGraphDatabaseEngine() string {
|
||||
return "Neo4j"
|
||||
}
|
||||
|
||||
// supportsRetrieverType checks if a driver supports a specific retriever type
|
||||
// by looking up the retrieverEngineMapping from types package
|
||||
func (h *SystemHandler) supportsRetrieverType(driver string, retrieverType types.RetrieverType) bool {
|
||||
// Get the mapping of all supported drivers and their capabilities
|
||||
mapping := types.GetRetrieverEngineMapping()
|
||||
|
||||
// Check if the driver exists in the mapping
|
||||
engines, exists := mapping[driver]
|
||||
if !exists {
|
||||
return false
|
||||
}
|
||||
|
||||
// Check if any of the engine configurations support the requested retriever type
|
||||
for _, engine := range engines {
|
||||
if engine.RetrieverType == retrieverType {
|
||||
return true
|
||||
}
|
||||
}
|
||||
return false
|
||||
}
|
||||
|
||||
// isMinioEnabled checks if MinIO is enabled
|
||||
func (h *SystemHandler) isMinioEnabled() bool {
|
||||
// Check if all required MinIO environment variables are set
|
||||
|
||||
@@ -4,10 +4,7 @@ import (
|
||||
"context"
|
||||
"encoding/json"
|
||||
"fmt"
|
||||
"net"
|
||||
"net/http"
|
||||
"net/url"
|
||||
"strings"
|
||||
"time"
|
||||
|
||||
"github.com/Tencent/WeKnora/internal/logger"
|
||||
@@ -61,31 +58,6 @@ type mcpGoClient struct {
|
||||
initialized bool
|
||||
}
|
||||
|
||||
// validateURLScheme validates that authentication endpoints use HTTPS
|
||||
func validateURLScheme(serviceURL string, hasBasicAuth bool, hasTokenAuth bool) error {
|
||||
if !hasBasicAuth && !hasTokenAuth {
|
||||
// No authentication, no HTTPS requirement
|
||||
return nil
|
||||
}
|
||||
|
||||
if serviceURL == "" {
|
||||
return fmt.Errorf("URL is required")
|
||||
}
|
||||
|
||||
// Parse the URL
|
||||
parsedURL, err := url.Parse(serviceURL)
|
||||
if err != nil {
|
||||
return fmt.Errorf("invalid URL: %w", err)
|
||||
}
|
||||
|
||||
// Enforce HTTPS for authenticated endpoints
|
||||
if parsedURL.Scheme != "https" {
|
||||
return fmt.Errorf("HTTPS is required for authenticated connections, got %s scheme", parsedURL.Scheme)
|
||||
}
|
||||
|
||||
return nil
|
||||
}
|
||||
|
||||
// NewMCPClient creates a new MCP client based on the transport type
|
||||
func NewMCPClient(config *ClientConfig) (MCPClient, error) {
|
||||
// Create HTTP client with timeout
|
||||
|
||||
@@ -29,6 +29,12 @@ var retrieverEngineMapping = map[string][]RetrieverEngineParams{
|
||||
},
|
||||
}
|
||||
|
||||
// GetRetrieverEngineMapping returns the retriever engine mapping
|
||||
// This allows other packages to access the driver capabilities
|
||||
func GetRetrieverEngineMapping() map[string][]RetrieverEngineParams {
|
||||
return retrieverEngineMapping
|
||||
}
|
||||
|
||||
// GetDefaultRetrieverEngines returns the default retriever engines based on RETRIEVE_DRIVER env
|
||||
func GetDefaultRetrieverEngines() []RetrieverEngineParams {
|
||||
result := []RetrieverEngineParams{}
|
||||
|
||||
Reference in New Issue
Block a user