refactor: web search provider registry pattern

This commit is contained in:
wizardchen
2026-01-19 20:28:31 +08:00
committed by lyingbug
parent b74545b007
commit fad72fc272
8 changed files with 159 additions and 103 deletions

View File

@@ -577,27 +577,6 @@ extract:
with_no_tag: |
请随机生成一段文本,内容请自由发挥,字数在 [50-200] 之间。
# WebSearch 配置
web_search:
# 可用搜索引擎列表
providers:
- id: "duckduckgo"
name: "DuckDuckGo"
free: true
requires_api_key: false
description: "DuckDuckGo API"
# 默认配置
default:
provider: "duckduckgo"
max_results: 5
include_date: true
compression_method: "none"
blacklist: []
# 全局超时设置
timeout: 10
# 租户配置
tenant:

View File

@@ -18,7 +18,7 @@ import (
// WebSearchService provides web search functionality
type WebSearchService struct {
providers map[string]interfaces.WebSearchProvider
config *config.WebSearchConfig
timeout int
}
// CompressWithRAG performs RAG-based compression using a temporary, hidden knowledge base.
@@ -259,7 +259,7 @@ func (s *WebSearchService) Search(
}
// Set timeout
timeout := time.Duration(s.config.Timeout) * time.Second
timeout := time.Duration(s.timeout) * time.Second
if timeout == 0 {
timeout = 10 * time.Second
}
@@ -286,37 +286,26 @@ func (s *WebSearchService) Search(
}
// NewWebSearchService creates a new web search service
func NewWebSearchService(cfg *config.Config) (interfaces.WebSearchService, error) {
if cfg.WebSearch == nil {
return nil, fmt.Errorf("web search config is not available")
func NewWebSearchService(cfg *config.Config, registry *web_search.Registry) (interfaces.WebSearchService, error) {
timeout := 10 // default timeout
if cfg.WebSearch != nil && cfg.WebSearch.Timeout > 0 {
timeout = cfg.WebSearch.Timeout
}
service := &WebSearchService{
providers: make(map[string]interfaces.WebSearchProvider),
config: cfg.WebSearch,
// Create all registered providers
providers, err := registry.CreateAllProviders()
if err != nil {
return nil, err
}
// Initialize providers based on config
for _, providerConfig := range cfg.WebSearch.Providers {
var provider interfaces.WebSearchProvider
var err error
switch providerConfig.ID {
case "duckduckgo":
provider, err = web_search.NewDuckDuckGoProvider(providerConfig)
case "google":
provider, err = web_search.NewGoogleProvider(providerConfig)
default:
return nil, fmt.Errorf("unknown web search provider: %s", providerConfig.ID)
}
if err != nil {
return nil, fmt.Errorf("failed to initialize provider %s: %v", providerConfig.ID, err)
}
service.providers[providerConfig.ID] = provider
logger.Infof(context.Background(), "Initialized web search provider: %s", providerConfig.ID)
for id := range providers {
logger.Infof(context.Background(), "Initialized web search provider: %s", id)
}
return service, nil
return &WebSearchService{
providers: providers,
timeout: timeout,
}, nil
}
// filterBlacklist filters results based on blacklist rules

View File

@@ -11,7 +11,6 @@ import (
"time"
"github.com/PuerkitoBio/goquery"
"github.com/Tencent/WeKnora/internal/config"
"github.com/Tencent/WeKnora/internal/logger"
"github.com/Tencent/WeKnora/internal/types"
"github.com/Tencent/WeKnora/internal/types/interfaces"
@@ -24,7 +23,7 @@ type DuckDuckGoProvider struct {
}
// NewDuckDuckGoProvider creates a new DuckDuckGo provider
func NewDuckDuckGoProvider(_ config.WebSearchProviderConfig) (interfaces.WebSearchProvider, error) {
func NewDuckDuckGoProvider() (interfaces.WebSearchProvider, error) {
return &DuckDuckGoProvider{
client: &http.Client{
Timeout: 30 * time.Second,
@@ -32,6 +31,17 @@ func NewDuckDuckGoProvider(_ config.WebSearchProviderConfig) (interfaces.WebSear
}, nil
}
// DuckDuckGoProviderInfo returns the metadata under which the DuckDuckGo
// provider is registered: a free provider with ID "duckduckgo" that does not
// require an API key.
func DuckDuckGoProviderInfo() types.WebSearchProviderInfo {
	var info types.WebSearchProviderInfo
	info.ID = "duckduckgo"
	info.Name = "DuckDuckGo"
	info.Free = true
	info.RequiresAPIKey = false
	info.Description = "DuckDuckGo Search API"
	return info
}
// Name returns the provider name
func (p *DuckDuckGoProvider) Name() string {
return "duckduckgo"

View File

@@ -4,11 +4,11 @@ import (
"context"
"fmt"
"net/url"
"os"
"google.golang.org/api/customsearch/v1"
"google.golang.org/api/option"
"github.com/Tencent/WeKnora/internal/config"
"github.com/Tencent/WeKnora/internal/types"
"github.com/Tencent/WeKnora/internal/types/interfaces"
)
@@ -22,8 +22,13 @@ type GoogleProvider struct {
}
// NewGoogleProvider creates a new Google provider
func NewGoogleProvider(cfg config.WebSearchProviderConfig) (interfaces.WebSearchProvider, error) {
u, err := url.Parse(cfg.APIURL)
func NewGoogleProvider() (interfaces.WebSearchProvider, error) {
apiURL := os.Getenv("GOOGLE_SEARCH_API_URL")
if apiURL == "" {
return nil, fmt.Errorf("GOOGLE_SEARCH_API_URL environment variable is not set")
}
u, err := url.Parse(apiURL)
if err != nil {
return nil, err
}
@@ -46,10 +51,21 @@ func NewGoogleProvider(cfg config.WebSearchProviderConfig) (interfaces.WebSearch
srv: srv,
apiKey: apiKey,
engineID: engineID,
baseURL: cfg.APIURL,
baseURL: apiURL,
}, nil
}
// GoogleProviderInfo returns the metadata under which the Google provider is
// registered: a non-free provider with ID "google" that requires an API key.
func GoogleProviderInfo() types.WebSearchProviderInfo {
	var info types.WebSearchProviderInfo
	info.ID = "google"
	info.Name = "Google"
	info.Free = false
	info.RequiresAPIKey = true
	info.Description = "Google Custom Search API"
	return info
}
// Name returns the provider name
func (p *GoogleProvider) Name() string {
return "google"

View File

@@ -0,0 +1,88 @@
package web_search
import (
	"fmt"
	"sort"
	"sync"

	"github.com/Tencent/WeKnora/internal/types"
	"github.com/Tencent/WeKnora/internal/types/interfaces"
)
// ProviderFactory constructs a new web search provider instance.
// It takes no arguments and returns either the provider or the error that
// prevented its construction.
type ProviderFactory func() (interfaces.WebSearchProvider, error)
// ProviderRegistration pairs a provider's descriptive metadata with the
// factory used to build instances of it. The Registry keeps one registration
// per provider ID.
type ProviderRegistration struct {
// Info is the metadata describing the provider; Info.ID is the registry key.
Info types.WebSearchProviderInfo
// Factory builds a new instance of the provider.
Factory ProviderFactory
}
// Registry holds the set of registered web search providers, keyed by
// provider ID. Its methods take the embedded lock around every access to the
// registration map.
type Registry struct {
	// mu guards providers.
	mu sync.RWMutex
	// providers maps a provider ID to its registration.
	providers map[string]*ProviderRegistration
}
// NewRegistry returns an empty Registry whose registration map is already
// allocated and ready to accept providers.
func NewRegistry() *Registry {
	registry := new(Registry)
	registry.providers = map[string]*ProviderRegistration{}
	return registry
}
// Register adds a web search provider to the registry under info.ID.
//
// Registering an ID that is already present replaces the earlier
// registration. The backing map is allocated on first use, so a zero-value
// Registry (one not built with NewRegistry) can be registered into without
// panicking.
func (r *Registry) Register(info types.WebSearchProviderInfo, factory ProviderFactory) {
	r.mu.Lock()
	defer r.mu.Unlock()

	// Writing to a nil map panics; allocate lazily for zero-value registries.
	if r.providers == nil {
		r.providers = make(map[string]*ProviderRegistration)
	}
	r.providers[info.ID] = &ProviderRegistration{
		Info:    info,
		Factory: factory,
	}
}
// GetRegistration looks up the registration stored for id.
// The boolean result reports whether a registration with that ID exists; when
// it is false the returned pointer is nil.
func (r *Registry) GetRegistration(id string) (*ProviderRegistration, bool) {
	r.mu.RLock()
	registration, found := r.providers[id]
	r.mu.RUnlock()
	return registration, found
}
// GetAllProviderInfos returns the metadata of every registered provider,
// sorted by provider ID.
//
// Map iteration order is randomized in Go, so the result is sorted to give
// callers a stable listing from one call to the next. The returned slice is
// freshly allocated and never nil.
func (r *Registry) GetAllProviderInfos() []types.WebSearchProviderInfo {
	r.mu.RLock()
	defer r.mu.RUnlock()

	infos := make([]types.WebSearchProviderInfo, 0, len(r.providers))
	for _, reg := range r.providers {
		infos = append(infos, reg.Info)
	}
	// IDs are the map keys and therefore unique, so this ordering is total.
	sort.Slice(infos, func(i, j int) bool { return infos[i].ID < infos[j].ID })
	return infos
}
// CreateProvider builds a new instance of the provider registered under id.
//
// It returns an error when no provider is registered under id, when the
// registration carries no factory, or when the factory itself fails. The
// registry lock is released before the factory runs, so a factory is free to
// call back into the registry.
func (r *Registry) CreateProvider(id string) (interfaces.WebSearchProvider, error) {
	r.mu.RLock()
	reg, ok := r.providers[id]
	r.mu.RUnlock()

	if !ok {
		return nil, fmt.Errorf("web search provider %s not registered", id)
	}
	// Register accepts a nil factory; calling it would panic, so report it.
	if reg == nil || reg.Factory == nil {
		return nil, fmt.Errorf("web search provider %s has no factory", id)
	}
	return reg.Factory()
}
// CreateAllProviders builds an instance of every registered provider and
// returns them keyed by provider ID.
//
// Construction is best-effort: a provider whose factory returns an error
// (e.g., missing API keys) or whose registration has no factory is left out of
// the result rather than failing the whole call. The returned error is
// therefore always nil at present.
// NOTE(review): skipped providers are not reported to the caller; consider
// surfacing the reasons if callers need to diagnose a missing provider.
func (r *Registry) CreateAllProviders() (map[string]interfaces.WebSearchProvider, error) {
	// Snapshot the factories under the lock and invoke them after releasing
	// it, matching CreateProvider. Factories are arbitrary callbacks; running
	// them with the lock held would block writers for their full duration
	// and could deadlock if one calls back into the registry.
	r.mu.RLock()
	factories := make(map[string]ProviderFactory, len(r.providers))
	for id, reg := range r.providers {
		if reg == nil || reg.Factory == nil {
			// Calling a nil factory would panic; treat it as unavailable.
			continue
		}
		factories[id] = reg.Factory
	}
	r.mu.RUnlock()

	providers := make(map[string]interfaces.WebSearchProvider, len(factories))
	for id, factory := range factories {
		provider, err := factory()
		if err != nil {
			// Skip providers that fail to initialize (e.g., missing API keys).
			continue
		}
		providers[id] = provider
	}
	return providers, nil
}

View File

@@ -279,26 +279,5 @@ func loadPromptTemplates(configDir string) (*PromptTemplatesConfig, error) {
// WebSearchConfig represents the web search configuration
type WebSearchConfig struct {
Providers []WebSearchProviderConfig `yaml:"providers" json:"providers"`
Default WebSearchDefaultConfig `yaml:"default" json:"default"`
Timeout int `yaml:"timeout" json:"timeout"` // 超时时间(秒)
}
// WebSearchProviderConfig represents configuration for a web search provider
type WebSearchProviderConfig struct {
ID string `yaml:"id" json:"id"`
Name string `yaml:"name" json:"name"`
Free bool `yaml:"free" json:"free"`
RequiresAPIKey bool `yaml:"requires_api_key" json:"requires_api_key"`
Description string `yaml:"description,omitempty" json:"description,omitempty"`
APIURL string `yaml:"api_url,omitempty" json:"api_url,omitempty"`
}
// WebSearchDefaultConfig represents the default web search configuration
type WebSearchDefaultConfig struct {
Provider string `yaml:"provider" json:"provider"`
MaxResults int `yaml:"max_results" json:"max_results"`
IncludeDate bool `yaml:"include_date" json:"include_date"`
CompressionMethod string `yaml:"compression_method" json:"compression_method"`
Blacklist []string `yaml:"blacklist" json:"blacklist"`
Timeout int `yaml:"timeout" json:"timeout"` // 超时时间(秒)
}

View File

@@ -37,6 +37,7 @@ import (
"github.com/Tencent/WeKnora/internal/application/service/file"
"github.com/Tencent/WeKnora/internal/application/service/llmcontext"
"github.com/Tencent/WeKnora/internal/application/service/retriever"
"github.com/Tencent/WeKnora/internal/application/service/web_search"
"github.com/Tencent/WeKnora/internal/config"
"github.com/Tencent/WeKnora/internal/database"
"github.com/Tencent/WeKnora/internal/event"
@@ -138,7 +139,9 @@ func BuildContainer(container *dig.Container) *dig.Container {
must(container.Provide(service.NewCustomAgentService))
// Web search service (needed by AgentService)
logger.Debugf(ctx, "[Container] Registering web search service...")
logger.Debugf(ctx, "[Container] Registering web search registry and providers...")
must(container.Provide(web_search.NewRegistry))
must(container.Invoke(registerWebSearchProviders))
must(container.Provide(service.NewWebSearchService))
// Agent service layer (requires event bus, web search service)
@@ -647,3 +650,16 @@ func NewDuckDB() (*sql.DB, error) {
return sqlDB, nil
}
// registerWebSearchProviders adds the built-in web search providers
// (DuckDuckGo and Google) to the given registry, pairing each provider's
// metadata with its constructor.
func registerWebSearchProviders(registry *web_search.Registry) {
	// The constructors already have the ProviderFactory signature, so they are
	// used as factories directly.
	builtins := []web_search.ProviderRegistration{
		{Info: web_search.DuckDuckGoProviderInfo(), Factory: web_search.NewDuckDuckGoProvider},
		{Info: web_search.GoogleProviderInfo(), Factory: web_search.NewGoogleProvider},
	}
	for _, builtin := range builtins {
		registry.Register(builtin.Info, builtin.Factory)
	}
}

View File

@@ -3,21 +3,20 @@ package handler
import (
"net/http"
"github.com/Tencent/WeKnora/internal/config"
"github.com/Tencent/WeKnora/internal/application/service/web_search"
"github.com/Tencent/WeKnora/internal/logger"
"github.com/Tencent/WeKnora/internal/types"
"github.com/gin-gonic/gin"
)
// WebSearchHandler handles web search related requests
type WebSearchHandler struct {
cfg *config.Config
registry *web_search.Registry
}
// NewWebSearchHandler creates a new web search handler
func NewWebSearchHandler(cfg *config.Config) *WebSearchHandler {
func NewWebSearchHandler(registry *web_search.Registry) *WebSearchHandler {
return &WebSearchHandler{
cfg: cfg,
registry: registry,
}
}
@@ -35,27 +34,7 @@ func (h *WebSearchHandler) GetProviders(c *gin.Context) {
ctx := c.Request.Context()
logger.Info(ctx, "Getting web search providers")
if h.cfg.WebSearch == nil || len(h.cfg.WebSearch.Providers) == 0 {
logger.Warn(ctx, "No web search providers configured")
c.JSON(http.StatusOK, gin.H{
"success": true,
"data": []types.WebSearchProviderInfo{},
})
return
}
// Convert config providers to API response format
providers := make([]types.WebSearchProviderInfo, 0, len(h.cfg.WebSearch.Providers))
for _, provider := range h.cfg.WebSearch.Providers {
providers = append(providers, types.WebSearchProviderInfo{
ID: provider.ID,
Name: provider.Name,
Free: provider.Free,
RequiresAPIKey: provider.RequiresAPIKey,
Description: provider.Description,
APIURL: provider.APIURL,
})
}
providers := h.registry.GetAllProviderInfos()
logger.Infof(ctx, "Returning %d web search providers", len(providers))
c.JSON(http.StatusOK, gin.H{