diff --git a/config/config.yaml b/config/config.yaml index 7cdb6e31..dc74d7dc 100644 --- a/config/config.yaml +++ b/config/config.yaml @@ -577,27 +577,6 @@ extract: with_no_tag: | 请随机生成一段文本,内容请自由发挥,字数在 [50-200] 之间。 -# WebSearch 配置 -web_search: - # 可用搜索引擎列表 - providers: - - id: "duckduckgo" - name: "DuckDuckGo" - free: true - requires_api_key: false - description: "DuckDuckGo API" - - - # 默认配置 - default: - provider: "duckduckgo" - max_results: 5 - include_date: true - compression_method: "none" - blacklist: [] - - # 全局超时设置 - timeout: 10 # 租户配置 tenant: diff --git a/internal/application/service/web_search.go b/internal/application/service/web_search.go index f00f262a..78751cfc 100644 --- a/internal/application/service/web_search.go +++ b/internal/application/service/web_search.go @@ -18,7 +18,7 @@ import ( // WebSearchService provides web search functionality type WebSearchService struct { providers map[string]interfaces.WebSearchProvider - config *config.WebSearchConfig + timeout int } // CompressWithRAG performs RAG-based compression using a temporary, hidden knowledge base. @@ -259,7 +259,7 @@ func (s *WebSearchService) Search( } // Set timeout - timeout := time.Duration(s.config.Timeout) * time.Second + timeout := time.Duration(s.timeout) * time.Second if timeout == 0 { timeout = 10 * time.Second } @@ -286,37 +286,26 @@ func (s *WebSearchService) Search( } // NewWebSearchService creates a new web search service -func NewWebSearchService(cfg *config.Config) (interfaces.WebSearchService, error) { - if cfg.WebSearch == nil { - return nil, fmt.Errorf("web search config is not available") +func NewWebSearchService(cfg *config.Config, registry *web_search.Registry) (interfaces.WebSearchService, error) { + timeout := 10 // default timeout + if cfg.WebSearch != nil && cfg.WebSearch.Timeout > 0 { + timeout = cfg.WebSearch.Timeout } - service := &WebSearchService{ - providers: make(map[string]interfaces.WebSearchProvider), - config: cfg.WebSearch, + // Create all registered providers + providers, err := registry.CreateAllProviders() + if err != nil { + return nil, err } - // Initialize providers based on config - for _, providerConfig := range cfg.WebSearch.Providers { - var provider interfaces.WebSearchProvider - var err error - - switch providerConfig.ID { - case "duckduckgo": - provider, err = web_search.NewDuckDuckGoProvider(providerConfig) - case "google": - provider, err = web_search.NewGoogleProvider(providerConfig) - default: - return nil, fmt.Errorf("unknown web search provider: %s", providerConfig.ID) - } - if err != nil { - return nil, fmt.Errorf("failed to initialize provider %s: %v", providerConfig.ID, err) - } - service.providers[providerConfig.ID] = provider - logger.Infof(context.Background(), "Initialized web search provider: %s", providerConfig.ID) + for id := range providers { + logger.Infof(context.Background(), "Initialized web search provider: %s", id) } - return service, nil + return &WebSearchService{ + providers: providers, + timeout: timeout, + }, nil } // filterBlacklist filters results based on blacklist rules diff --git a/internal/application/service/web_search/duckduckgo.go b/internal/application/service/web_search/duckduckgo.go index d9b59a6e..dc040f74 100644 --- a/internal/application/service/web_search/duckduckgo.go +++ b/internal/application/service/web_search/duckduckgo.go @@ -11,7 +11,6 @@ import ( "time" "github.com/PuerkitoBio/goquery" - "github.com/Tencent/WeKnora/internal/config" "github.com/Tencent/WeKnora/internal/logger" "github.com/Tencent/WeKnora/internal/types" "github.com/Tencent/WeKnora/internal/types/interfaces" @@ -24,7 +23,7 @@ type DuckDuckGoProvider struct { } // NewDuckDuckGoProvider creates a new DuckDuckGo provider -func NewDuckDuckGoProvider(_ config.WebSearchProviderConfig) (interfaces.WebSearchProvider, error) { +func NewDuckDuckGoProvider() (interfaces.WebSearchProvider, error) { return &DuckDuckGoProvider{ client: &http.Client{ Timeout: 30 * time.Second, @@ -32,6 +31,17 @@ func NewDuckDuckGoProvider(_ config.WebSearchProviderConfig) (interfaces.WebSear }, nil } +// DuckDuckGoProviderInfo returns the provider info for registration +func DuckDuckGoProviderInfo() types.WebSearchProviderInfo { + return types.WebSearchProviderInfo{ + ID: "duckduckgo", + Name: "DuckDuckGo", + Free: true, + RequiresAPIKey: false, + Description: "DuckDuckGo Search API", + } +} + // Name returns the provider name func (p *DuckDuckGoProvider) Name() string { return "duckduckgo" diff --git a/internal/application/service/web_search/google.go b/internal/application/service/web_search/google.go index a241c737..66b6c474 100644 --- a/internal/application/service/web_search/google.go +++ b/internal/application/service/web_search/google.go @@ -4,11 +4,11 @@ import ( "context" "fmt" "net/url" + "os" "google.golang.org/api/customsearch/v1" "google.golang.org/api/option" - "github.com/Tencent/WeKnora/internal/config" "github.com/Tencent/WeKnora/internal/types" "github.com/Tencent/WeKnora/internal/types/interfaces" ) @@ -22,8 +22,13 @@ type GoogleProvider struct { } // NewGoogleProvider creates a new Google provider -func NewGoogleProvider(cfg config.WebSearchProviderConfig) (interfaces.WebSearchProvider, error) { - u, err := url.Parse(cfg.APIURL) +func NewGoogleProvider() (interfaces.WebSearchProvider, error) { + apiURL := os.Getenv("GOOGLE_SEARCH_API_URL") + if apiURL == "" { + return nil, fmt.Errorf("GOOGLE_SEARCH_API_URL environment variable is not set") + } + + u, err := url.Parse(apiURL) if err != nil { return nil, err } @@ -46,10 +51,21 @@ func NewGoogleProvider(cfg config.WebSearchProviderConfig) (interfaces.WebSearch srv: srv, apiKey: apiKey, engineID: engineID, - baseURL: cfg.APIURL, + baseURL: apiURL, }, nil } +// GoogleProviderInfo returns the provider info for registration +func GoogleProviderInfo() types.WebSearchProviderInfo { + return types.WebSearchProviderInfo{ + ID: "google", + Name: "Google", + Free: false, + RequiresAPIKey: true, + Description: "Google Custom Search API", + } +} + // Name returns the provider name func (p *GoogleProvider) Name() string { return "google" diff --git a/internal/application/service/web_search/registry.go b/internal/application/service/web_search/registry.go new file mode 100644 index 00000000..48bf9c8e --- /dev/null +++ b/internal/application/service/web_search/registry.go @@ -0,0 +1,88 @@ +package web_search + +import ( + "fmt" + "sync" + + "github.com/Tencent/WeKnora/internal/types" + "github.com/Tencent/WeKnora/internal/types/interfaces" +) + +// ProviderFactory creates a new web search provider instance +type ProviderFactory func() (interfaces.WebSearchProvider, error) + +// ProviderRegistration holds provider metadata and factory +type ProviderRegistration struct { + Info types.WebSearchProviderInfo + Factory ProviderFactory +} + +// Registry manages web search provider registrations +type Registry struct { + providers map[string]*ProviderRegistration + mu sync.RWMutex +} + +// NewRegistry creates a new web search provider registry +func NewRegistry() *Registry { + return &Registry{ + providers: make(map[string]*ProviderRegistration), + } +} + +// Register registers a web search provider +func (r *Registry) Register(info types.WebSearchProviderInfo, factory ProviderFactory) { + r.mu.Lock() + defer r.mu.Unlock() + r.providers[info.ID] = &ProviderRegistration{ + Info: info, + Factory: factory, + } +} + +// GetRegistration returns the registration for a provider +func (r *Registry) GetRegistration(id string) (*ProviderRegistration, bool) { + r.mu.RLock() + defer r.mu.RUnlock() + reg, ok := r.providers[id] + return reg, ok +} + +// GetAllProviderInfos returns info for all registered providers +func (r *Registry) GetAllProviderInfos() []types.WebSearchProviderInfo { + r.mu.RLock() + defer r.mu.RUnlock() + infos := make([]types.WebSearchProviderInfo, 0, len(r.providers)) + for _, reg := range r.providers { + infos = append(infos, reg.Info) + } + return infos +} + +// CreateProvider creates a provider instance by ID +func (r *Registry) CreateProvider(id string) (interfaces.WebSearchProvider, error) { + r.mu.RLock() + reg, ok := r.providers[id] + r.mu.RUnlock() + if !ok { + return nil, fmt.Errorf("web search provider %s not registered", id) + } + return reg.Factory() +} + +// CreateAllProviders creates instances of all registered providers +func (r *Registry) CreateAllProviders() (map[string]interfaces.WebSearchProvider, error) { + r.mu.RLock() + defer r.mu.RUnlock() + + providers := make(map[string]interfaces.WebSearchProvider) + for id, reg := range r.providers { + provider, err := reg.Factory() + if err != nil { + // Skip providers that fail to initialize (e.g., missing API keys) + continue + } + providers[id] = provider + } + return providers, nil +} diff --git a/internal/config/config.go b/internal/config/config.go index c06af39d..759f6a03 100644 --- a/internal/config/config.go +++ b/internal/config/config.go @@ -279,26 +279,5 @@ func loadPromptTemplates(configDir string) (*PromptTemplatesConfig, error) { // WebSearchConfig represents the web search configuration type WebSearchConfig struct { - Providers []WebSearchProviderConfig `yaml:"providers" json:"providers"` - Default WebSearchDefaultConfig `yaml:"default" json:"default"` - Timeout int `yaml:"timeout" json:"timeout"` // 超时时间(秒) -} - -// WebSearchProviderConfig represents configuration for a web search provider -type WebSearchProviderConfig struct { - ID string `yaml:"id" json:"id"` - Name string `yaml:"name" json:"name"` - Free bool `yaml:"free" json:"free"` - RequiresAPIKey bool `yaml:"requires_api_key" json:"requires_api_key"` - Description string `yaml:"description,omitempty" json:"description,omitempty"` - APIURL string `yaml:"api_url,omitempty" json:"api_url,omitempty"` -} - -// WebSearchDefaultConfig represents the default web search configuration -type WebSearchDefaultConfig struct { - Provider string `yaml:"provider" json:"provider"` - MaxResults int `yaml:"max_results" json:"max_results"` - IncludeDate bool `yaml:"include_date" json:"include_date"` - CompressionMethod string `yaml:"compression_method" json:"compression_method"` - Blacklist []string `yaml:"blacklist" json:"blacklist"` + Timeout int `yaml:"timeout" json:"timeout"` // 超时时间(秒) } diff --git a/internal/container/container.go b/internal/container/container.go index 6dc18a8d..3ef01674 100644 --- a/internal/container/container.go +++ b/internal/container/container.go @@ -37,6 +37,7 @@ import ( "github.com/Tencent/WeKnora/internal/application/service/file" "github.com/Tencent/WeKnora/internal/application/service/llmcontext" "github.com/Tencent/WeKnora/internal/application/service/retriever" + "github.com/Tencent/WeKnora/internal/application/service/web_search" "github.com/Tencent/WeKnora/internal/config" "github.com/Tencent/WeKnora/internal/database" "github.com/Tencent/WeKnora/internal/event" @@ -138,7 +139,9 @@ func BuildContainer(container *dig.Container) *dig.Container { must(container.Provide(service.NewCustomAgentService)) // Web search service (needed by AgentService) - logger.Debugf(ctx, "[Container] Registering web search service...") + logger.Debugf(ctx, "[Container] Registering web search registry and providers...") + must(container.Provide(web_search.NewRegistry)) + must(container.Invoke(registerWebSearchProviders)) must(container.Provide(service.NewWebSearchService)) // Agent service layer (requires event bus, web search service) @@ -647,3 +650,16 @@ func NewDuckDB() (*sql.DB, error) { return sqlDB, nil } + +// registerWebSearchProviders registers all web search providers to the registry +func registerWebSearchProviders(registry *web_search.Registry) { + // Register DuckDuckGo provider + registry.Register(web_search.DuckDuckGoProviderInfo(), func() (interfaces.WebSearchProvider, error) { + return web_search.NewDuckDuckGoProvider() + }) + + // Register Google provider + registry.Register(web_search.GoogleProviderInfo(), func() (interfaces.WebSearchProvider, error) { + return web_search.NewGoogleProvider() + }) +} diff --git a/internal/handler/web_search.go b/internal/handler/web_search.go index 7df81ca0..0e0d9bcf 100644 --- a/internal/handler/web_search.go +++ b/internal/handler/web_search.go @@ -3,21 +3,20 @@ package handler import ( "net/http" - "github.com/Tencent/WeKnora/internal/config" + "github.com/Tencent/WeKnora/internal/application/service/web_search" "github.com/Tencent/WeKnora/internal/logger" - "github.com/Tencent/WeKnora/internal/types" "github.com/gin-gonic/gin" ) // WebSearchHandler handles web search related requests type WebSearchHandler struct { - cfg *config.Config + registry *web_search.Registry } // NewWebSearchHandler creates a new web search handler -func NewWebSearchHandler(cfg *config.Config) *WebSearchHandler { +func NewWebSearchHandler(registry *web_search.Registry) *WebSearchHandler { return &WebSearchHandler{ - cfg: cfg, + registry: registry, } } @@ -35,27 +34,7 @@ func (h *WebSearchHandler) GetProviders(c *gin.Context) { ctx := c.Request.Context() logger.Info(ctx, "Getting web search providers") - if h.cfg.WebSearch == nil || len(h.cfg.WebSearch.Providers) == 0 { - logger.Warn(ctx, "No web search providers configured") - c.JSON(http.StatusOK, gin.H{ - "success": true, - "data": []types.WebSearchProviderInfo{}, - }) - return - } - - // Convert config providers to API response format - providers := make([]types.WebSearchProviderInfo, 0, len(h.cfg.WebSearch.Providers)) - for _, provider := range h.cfg.WebSearch.Providers { - providers = append(providers, types.WebSearchProviderInfo{ - ID: provider.ID, - Name: provider.Name, - Free: provider.Free, - RequiresAPIKey: provider.RequiresAPIKey, - Description: provider.Description, - APIURL: provider.APIURL, - }) - } + providers := h.registry.GetAllProviderInfos() logger.Infof(ctx, "Returning %d web search providers", len(providers)) c.JSON(http.StatusOK, gin.H{