mirror of
https://github.com/Tencent/WeKnora.git
synced 2026-06-04 13:30:32 +08:00
feat(web-search): refactor web search provider
- Updated relevant files to include provider registration, implementation, and metadata. - Enhanced frontend components to support provider management and configuration. - Added localization for new provider settings and messages. - Implemented backend repository methods for CRUD operations on web search providers.
This commit is contained in:
267
docs/添加新的网络搜索引擎.md
Normal file
267
docs/添加新的网络搜索引擎.md
Normal file
@@ -0,0 +1,267 @@
|
||||
# 添加新的网络搜索引擎
|
||||
|
||||
本文档说明如何在 WeKnora 中添加一个新的网络搜索引擎类型(如 Brave Search、Searx 等)。
|
||||
|
||||
## 架构概述
|
||||
|
||||
```
|
||||
internal/
|
||||
├── types/
|
||||
│ └── web_search_provider.go # 实体定义 + Provider 类型元数据
|
||||
├── infrastructure/
|
||||
│ └── web_search/
|
||||
│ ├── registry.go # Provider 工厂注册表
|
||||
│ ├── bing.go # Bing 实现
|
||||
│ ├── google.go # Google 实现
|
||||
│ ├── duckduckgo.go # DuckDuckGo 实现
|
||||
│ └── tavily.go # Tavily 实现
|
||||
├── container/
|
||||
│ └── container.go # DI 注册(registerWebSearchProviders)
|
||||
└── types/interfaces/
|
||||
└── web_search.go # WebSearchProvider 接口
|
||||
```
|
||||
|
||||
搜索引擎的 API 端点**硬编码**在代码中,不向用户暴露 BaseURL,从源头消除 SSRF 风险。
|
||||
|
||||
## 步骤
|
||||
|
||||
以添加 **Brave Search** 为例。
|
||||
|
||||
### 1. 在 `types/web_search_provider.go` 中注册类型常量
|
||||
|
||||
```go
|
||||
const (
|
||||
WebSearchProviderTypeBing WebSearchProviderType = "bing"
|
||||
WebSearchProviderTypeGoogle WebSearchProviderType = "google"
|
||||
WebSearchProviderTypeDuckDuckGo WebSearchProviderType = "duckduckgo"
|
||||
WebSearchProviderTypeTavily WebSearchProviderType = "tavily"
|
||||
WebSearchProviderTypeBrave WebSearchProviderType = "brave" // ← 新增
|
||||
)
|
||||
```
|
||||
|
||||
### 2. 在 `GetWebSearchProviderTypes()` 中添加类型元数据
|
||||
|
||||
```go
|
||||
{
|
||||
ID: "brave",
|
||||
Name: "Brave Search",
|
||||
Free: false,
|
||||
RequiresAPIKey: true,
|
||||
Description: "Brave Search API",
|
||||
DocsURL: "https://brave.com/search/api/",
|
||||
},
|
||||
```
|
||||
|
||||
字段说明:
|
||||
|
||||
| 字段 | 说明 |
|
||||
| ---------------- | ---------------------------------------------- |
|
||||
| `ID` | 唯一标识,存入数据库,不可更改 |
|
||||
| `Name` | 前端展示名称 |
|
||||
| `Free` | 是否免费 |
|
||||
| `RequiresAPIKey` | 是否需要 API Key |
|
||||
| `RequiresEngineID` | 是否需要额外 ID(如 Google CSE) |
|
||||
| `Description` | 简短描述 |
|
||||
| `DocsURL` | 官方文档链接,前端在添加对话框中显示 |
|
||||
|
||||
### 3. 创建 Provider 实现
|
||||
|
||||
新建 `internal/infrastructure/web_search/brave.go`:
|
||||
|
||||
```go
|
||||
package web_search
|
||||
|
||||
import (
|
||||
"context"
|
||||
"encoding/json"
|
||||
"fmt"
|
||||
"io"
|
||||
"net/http"
|
||||
"time"
|
||||
|
||||
"github.com/Tencent/WeKnora/internal/types"
|
||||
"github.com/Tencent/WeKnora/internal/types/interfaces"
|
||||
)
|
||||
|
||||
const defaultBraveSearchURL = "https://api.search.brave.com/res/v1/web/search"
|
||||
|
||||
type BraveProvider struct {
|
||||
client *http.Client
|
||||
apiKey string
|
||||
}
|
||||
|
||||
// NewBraveProvider 从参数创建实例(不读环境变量)
|
||||
func NewBraveProvider(params types.WebSearchProviderParameters) (interfaces.WebSearchProvider, error) {
|
||||
if params.APIKey == "" {
|
||||
return nil, fmt.Errorf("API key is required for Brave provider")
|
||||
}
|
||||
return &BraveProvider{
|
||||
client: &http.Client{Timeout: 10 * time.Second},
|
||||
apiKey: params.APIKey,
|
||||
}, nil
|
||||
}
|
||||
|
||||
func BraveProviderTypeInfo() types.WebSearchProviderTypeInfo {
|
||||
return types.WebSearchProviderTypeInfo{
|
||||
ID: "brave",
|
||||
Name: "Brave Search",
|
||||
Free: false,
|
||||
RequiresAPIKey: true,
|
||||
Description: "Brave Search API",
|
||||
DocsURL: "https://brave.com/search/api/",
|
||||
}
|
||||
}
|
||||
|
||||
func (p *BraveProvider) Name() string { return "brave" }
|
||||
|
||||
func (p *BraveProvider) Search(
|
||||
ctx context.Context, query string, maxResults int, includeDate bool,
|
||||
) ([]*types.WebSearchResult, error) {
|
||||
// 构造请求 — BaseURL 硬编码
|
||||
req, err := http.NewRequestWithContext(ctx, "GET",
|
||||
fmt.Sprintf("%s?q=%s&count=%d", defaultBraveSearchURL, query, maxResults), nil)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
req.Header.Set("X-Subscription-Token", p.apiKey)
|
||||
req.Header.Set("Accept", "application/json")
|
||||
|
||||
resp, err := p.client.Do(req)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
defer resp.Body.Close()
|
||||
|
||||
// 解析响应
|
||||
body, _ := io.ReadAll(resp.Body)
|
||||
var data braveResponse
|
||||
if err := json.Unmarshal(body, &data); err != nil {
|
||||
return nil, err
|
||||
}
|
||||
|
||||
results := make([]*types.WebSearchResult, 0, len(data.Web.Results))
|
||||
for _, r := range data.Web.Results {
|
||||
results = append(results, &types.WebSearchResult{
|
||||
Title: r.Title,
|
||||
URL: r.URL,
|
||||
Snippet: r.Description,
|
||||
Source: "brave",
|
||||
})
|
||||
}
|
||||
return results, nil
|
||||
}
|
||||
|
||||
type braveResponse struct {
|
||||
Web struct {
|
||||
Results []struct {
|
||||
Title string `json:"title"`
|
||||
URL string `json:"url"`
|
||||
Description string `json:"description"`
|
||||
} `json:"results"`
|
||||
} `json:"web"`
|
||||
}
|
||||
```
|
||||
|
||||
**关键要求**:
|
||||
|
||||
1. **构造函数签名**必须为 `func(types.WebSearchProviderParameters) (interfaces.WebSearchProvider, error)`
|
||||
2. **API 端点硬编码**为常量,不从参数中读取
|
||||
3. **实现 `interfaces.WebSearchProvider` 接口**:`Name()` 和 `Search()`
|
||||
|
||||
### 4. 在 Service 的参数校验中添加新类型
|
||||
|
||||
编辑 `internal/application/service/web_search_provider.go`:
|
||||
|
||||
```go
|
||||
func isValidProviderType(provider types.WebSearchProviderType) bool {
|
||||
switch provider {
|
||||
case types.WebSearchProviderTypeBing,
|
||||
types.WebSearchProviderTypeGoogle,
|
||||
types.WebSearchProviderTypeDuckDuckGo,
|
||||
types.WebSearchProviderTypeTavily,
|
||||
types.WebSearchProviderTypeBrave: // ← 新增
|
||||
return true
|
||||
default:
|
||||
return false
|
||||
}
|
||||
}
|
||||
```
|
||||
|
||||
### 5. 在 DI 容器中注册
|
||||
|
||||
编辑 `internal/container/container.go` 的 `registerWebSearchProviders` 函数:
|
||||
|
||||
```go
|
||||
func registerWebSearchProviders(registry *infra_web_search.Registry) {
|
||||
// ... 已有注册 ...
|
||||
|
||||
// Register Brave provider type
|
||||
registry.Register(infra_web_search.BraveProviderTypeInfo(), infra_web_search.NewBraveProvider)
|
||||
}
|
||||
```
|
||||
|
||||
### 6. 验证
|
||||
|
||||
```bash
|
||||
# 编译
|
||||
go build ./...
|
||||
|
||||
# 启动后调用 API 验证类型列表
|
||||
curl http://localhost:8080/api/v1/web-search-providers/types \
|
||||
-H 'X-API-Key: your_key'
|
||||
|
||||
# 创建 Brave 搜索引擎实例
|
||||
curl -X POST http://localhost:8080/api/v1/web-search-providers \
|
||||
-H 'X-API-Key: your_key' \
|
||||
-H 'Content-Type: application/json' \
|
||||
-d '{
|
||||
"name": "Brave Search",
|
||||
"provider": "brave",
|
||||
"parameters": { "api_key": "BSA..." },
|
||||
"is_default": true
|
||||
}'
|
||||
```
|
||||
|
||||
## 需要额外参数的情况
|
||||
|
||||
如果新引擎需要 API Key 以外的参数(类似 Google 的 `engine_id`),有两种方式:
|
||||
|
||||
### 方式一:使用 `ExtraConfig`
|
||||
|
||||
利用 `WebSearchProviderParameters.ExtraConfig` 字段,不需要改类型定义:
|
||||
|
||||
```go
|
||||
func NewFooProvider(params types.WebSearchProviderParameters) (interfaces.WebSearchProvider, error) {
|
||||
region := params.ExtraConfig["region"]
|
||||
if region == "" {
|
||||
region = "us"
|
||||
}
|
||||
// ...
|
||||
}
|
||||
```
|
||||
|
||||
前端在 `GetWebSearchProviderTypes()` 中可以标注需要哪些 extra 字段(后续支持动态表单渲染)。
|
||||
|
||||
### 方式二:添加专用字段
|
||||
|
||||
如果参数非常通用(比如多个引擎都需要),可以在 `WebSearchProviderParameters` 中添加新字段:
|
||||
|
||||
```go
|
||||
type WebSearchProviderParameters struct {
|
||||
APIKey string `json:"api_key,omitempty"`
|
||||
EngineID string `json:"engine_id,omitempty"`
|
||||
Region string `json:"region,omitempty"` // ← 新增
|
||||
ExtraConfig map[string]string `json:"extra_config,omitempty"`
|
||||
}
|
||||
```
|
||||
|
||||
同时在 `WebSearchProviderTypeInfo` 中添加 `RequiresRegion bool` 等字段,前端根据此动态显示输入框。
|
||||
|
||||
## 文件变更清单
|
||||
|
||||
| 文件 | 操作 |
|
||||
| ---- | ---- |
|
||||
| `internal/types/web_search_provider.go` | 添加常量 + 类型元数据 |
|
||||
| `internal/infrastructure/web_search/brave.go` | **新建** Provider 实现 |
|
||||
| `internal/application/service/web_search_provider.go` | `isValidProviderType` 加新类型 |
|
||||
| `internal/container/container.go` | `registerWebSearchProviders` 注册 |
|
||||
73
frontend/src/api/web-search-provider.ts
Normal file
73
frontend/src/api/web-search-provider.ts
Normal file
@@ -0,0 +1,73 @@
|
||||
import { get, post, put, del } from '@/utils/request'
|
||||
|
||||
// WebSearchProviderEntity represents a configured web search provider instance
|
||||
export interface WebSearchProviderEntity {
|
||||
id?: string
|
||||
tenant_id?: number
|
||||
name: string
|
||||
provider: 'bing' | 'google' | 'duckduckgo' | 'tavily'
|
||||
description?: string
|
||||
parameters: {
|
||||
api_key?: string
|
||||
engine_id?: string
|
||||
extra_config?: Record<string, string>
|
||||
}
|
||||
is_default?: boolean
|
||||
created_at?: string
|
||||
updated_at?: string
|
||||
}
|
||||
|
||||
// WebSearchProviderTypeInfo describes metadata for a provider type
|
||||
export interface WebSearchProviderTypeInfo {
|
||||
id: string
|
||||
name: string
|
||||
requires_api_key: boolean
|
||||
requires_engine_id?: boolean
|
||||
description?: string
|
||||
docs_url?: string
|
||||
}
|
||||
|
||||
// Create a new web search provider
|
||||
export function createWebSearchProvider(data: Partial<WebSearchProviderEntity>) {
|
||||
return post('/api/v1/web-search-providers', data)
|
||||
}
|
||||
|
||||
// List all web search providers for the current tenant
|
||||
export function listWebSearchProviders() {
|
||||
return get('/api/v1/web-search-providers')
|
||||
}
|
||||
|
||||
// Get a single web search provider by ID
|
||||
export function getWebSearchProvider(id: string) {
|
||||
return get(`/api/v1/web-search-providers/${id}`)
|
||||
}
|
||||
|
||||
// Update an existing web search provider
|
||||
export function updateWebSearchProvider(id: string, data: Partial<WebSearchProviderEntity>) {
|
||||
return put(`/api/v1/web-search-providers/${id}`, data)
|
||||
}
|
||||
|
||||
// Delete a web search provider
|
||||
export function deleteWebSearchProvider(id: string) {
|
||||
return del(`/api/v1/web-search-providers/${id}`)
|
||||
}
|
||||
|
||||
// Get available provider types (for dynamic form rendering)
|
||||
export function listWebSearchProviderTypes(): Promise<WebSearchProviderTypeInfo[]> {
|
||||
return get('/api/v1/web-search-providers/types').then((res: any) => {
|
||||
if (res.success && res.data) {
|
||||
return res.data
|
||||
}
|
||||
return []
|
||||
})
|
||||
}
|
||||
|
||||
// Test a web search provider connection.
|
||||
// If id is provided, tests the existing saved provider.
|
||||
// If data is provided, tests with raw credentials (no persistence).
|
||||
export function testWebSearchProvider(id?: string, data?: { provider: string; parameters: any }): Promise<any> {
|
||||
if (id) {
|
||||
return post(`/api/v1/web-search-providers/${id}/test`, {})
|
||||
}
|
||||
return post('/api/v1/web-search-providers/test', data || {})
|
||||
}
|
||||
@@ -1,6 +1,7 @@
|
||||
import { get, put } from '@/utils/request'
|
||||
|
||||
// WebSearchProviderConfig represents information about a web search provider
|
||||
// WebSearchProviderConfig represents information about a web search provider type
|
||||
// Deprecated: Use WebSearchProviderTypeInfo from web-search-provider.ts instead
|
||||
export interface WebSearchProviderConfig {
|
||||
id: string
|
||||
name: string
|
||||
@@ -12,7 +13,10 @@ export interface WebSearchProviderConfig {
|
||||
|
||||
// WebSearchConfig represents the web search configuration for a tenant
|
||||
export interface WebSearchConfig {
|
||||
provider: string
|
||||
// New: references a WebSearchProviderEntity by ID
|
||||
default_provider_id?: string
|
||||
// Deprecated: kept for backward compatibility
|
||||
provider?: string
|
||||
api_key?: string
|
||||
max_results: number
|
||||
include_date: boolean
|
||||
@@ -24,7 +28,7 @@ export interface WebSearchConfig {
|
||||
document_fragments?: number
|
||||
}
|
||||
|
||||
// Get web search providers
|
||||
// Get web search provider types (available engines)
|
||||
export function getWebSearchProviders() {
|
||||
return get('/api/v1/web-search/providers')
|
||||
}
|
||||
@@ -38,4 +42,3 @@ export function getTenantWebSearchConfig() {
|
||||
export function updateTenantWebSearchConfig(config: WebSearchConfig) {
|
||||
return put('/api/v1/tenants/kv/web-search-config', config)
|
||||
}
|
||||
|
||||
|
||||
@@ -450,7 +450,11 @@ export default {
|
||||
suggestedPrompts: 'Suggested Prompts',
|
||||
mode: 'Running Mode',
|
||||
webSearch: 'Web Search',
|
||||
webSearchProvider: 'Search Engine',
|
||||
webSearchProviderPlaceholder: 'Use default search engine',
|
||||
webSearchMaxResults: 'Max Search Results',
|
||||
webFetchEnabled: 'Auto-Fetch Page Content',
|
||||
webFetchTopN: 'Pages to Fetch',
|
||||
knowledgeBases: 'Knowledge Bases',
|
||||
allKnowledgeBases: 'All Knowledge Bases',
|
||||
allKnowledgeBasesDesc: 'Agent can access all knowledge bases',
|
||||
@@ -699,6 +703,34 @@ export default {
|
||||
webSearchSettings: {
|
||||
title: 'Web Search Configuration',
|
||||
description: 'Configure web search so answers can include up-to-date information from the internet.',
|
||||
// Provider entity management
|
||||
providersTitle: 'Search Engine Providers',
|
||||
addProvider: 'Add Provider',
|
||||
editProvider: 'Edit Provider',
|
||||
noProviders: 'No search engine providers configured. Click "Add Provider" to get started.',
|
||||
deleteConfirm: 'Are you sure you want to delete this provider?',
|
||||
default: 'Default',
|
||||
providerNameLabel: 'Name',
|
||||
providerNamePlaceholder: 'e.g., Production Bing Search',
|
||||
providerTypeLabel: 'Provider Type',
|
||||
providerDescLabel: 'Notes',
|
||||
providerDescPlaceholder: 'Optional, e.g., for testing',
|
||||
engineIdLabel: 'Engine ID',
|
||||
setAsDefault: 'Set as default',
|
||||
testConnection: 'Test Connection',
|
||||
testing: 'Testing...',
|
||||
free: 'Free',
|
||||
viewDocs: 'View docs for API key',
|
||||
apiKeyUnchanged: 'Leave empty to keep current key',
|
||||
noDescription: "No description provided",
|
||||
noProvidersDesc: "Add a web search provider to enable your agents to retrieve real-time information from the internet.",
|
||||
basicInfo: "Basic Information",
|
||||
credentials: "Credentials",
|
||||
setAsDefaultDesc: "This provider will be used by default when an agent doesn't specify one",
|
||||
// Search behavior
|
||||
searchBehaviorTitle: 'Search Behavior',
|
||||
defaultProviderLabel: 'Default Provider',
|
||||
defaultProviderDescription: 'Select the default search provider for agents that do not specify their own.',
|
||||
providerLabel: 'Search Provider',
|
||||
providerDescription: 'Choose the search engine service used for web search',
|
||||
providerPlaceholder: 'Select a search engine...',
|
||||
@@ -722,7 +754,12 @@ export default {
|
||||
toasts: {
|
||||
loadProvidersFailed: 'Failed to load search providers: {message}',
|
||||
saveSuccess: 'Web search configuration saved',
|
||||
saveFailed: 'Failed to save configuration: {message}'
|
||||
saveFailed: 'Failed to save configuration: {message}',
|
||||
providerCreated: 'Search provider created',
|
||||
providerUpdated: 'Search provider updated',
|
||||
providerDeleted: 'Search provider deleted',
|
||||
testSuccess: 'Connection test succeeded',
|
||||
testFailed: 'Connection test failed',
|
||||
}
|
||||
},
|
||||
chatHistorySettings: {
|
||||
@@ -3178,7 +3215,10 @@ export default {
|
||||
maxIterations: 'Maximum reasoning steps when the Agent executes tasks',
|
||||
kbScope: 'Select the scope of knowledge bases accessible to the agent',
|
||||
webSearch: 'When enabled, the agent can search the internet for information',
|
||||
webSearchProvider: 'Specify a search engine for this agent. Leave empty to use the default.',
|
||||
webSearchMaxResults: 'Maximum number of results returned per search',
|
||||
webFetchEnabled: 'After reranking, auto-fetch full page content from top web results for better answers',
|
||||
webFetchTopN: 'Maximum number of web pages to fetch after reranking',
|
||||
retrievalSection: 'Configure knowledge base retrieval and ranking parameters',
|
||||
queryExpansion: 'Automatically expand query terms to improve recall',
|
||||
embeddingTopK: 'Maximum number of results from vector retrieval',
|
||||
|
||||
@@ -550,6 +550,31 @@ export default {
|
||||
title: "웹 검색 설정",
|
||||
description:
|
||||
"웹 검색 기능을 구성하여 질문에 답변할 때 인터넷에서 실시간 정보를 가져와 지식베이스 내용을 보완합니다",
|
||||
providersTitle: "검색 엔진 프로바이더 설정",
|
||||
addProvider: "프로바이더 추가",
|
||||
editProvider: "프로바이더 편집",
|
||||
noProviders: "검색 엔진 프로바이더가 설정되지 않았습니다. \"프로바이더 추가\"를 클릭하여 시작하세요.",
|
||||
deleteConfirm: "이 프로바이더를 삭제하시겠습니까?",
|
||||
default: "기본",
|
||||
providerNameLabel: "이름",
|
||||
providerNamePlaceholder: "예: 프로덕션 Bing 검색",
|
||||
providerTypeLabel: "프로바이더 유형",
|
||||
providerDescLabel: "설명",
|
||||
engineIdLabel: "엔진 ID",
|
||||
setAsDefault: "기본으로 설정",
|
||||
free: "무료",
|
||||
viewDocs: "문서에서 API 키 확인",
|
||||
apiKeyUnchanged: "현재 키를 유지하려면 비워두세요",
|
||||
testConnection: "연결 테스트",
|
||||
testing: "테스트 중...",
|
||||
noDescription: "설명 없음",
|
||||
noProvidersDesc: "웹 검색 프로바이더를 추가하여 에이전트가 인터넷에서 실시간 정보를 가져올 수 있도록 합니다.",
|
||||
basicInfo: "기본 정보",
|
||||
credentials: "자격 증명",
|
||||
setAsDefaultDesc: "에이전트가 검색 엔진을 지정하지 않은 경우 이 프로바이더가 기본적으로 사용됩니다",
|
||||
searchBehaviorTitle: "검색 동작 설정",
|
||||
defaultProviderLabel: "기본 프로바이더",
|
||||
defaultProviderDescription: "자체 프로바이더를 지정하지 않은 에이전트의 기본 검색 프로바이더를 선택합니다",
|
||||
providerLabel: "검색 엔진 프로바이더",
|
||||
providerDescription: "웹 검색에 사용할 검색 엔진 서비스 선택",
|
||||
providerPlaceholder: "검색 엔진 선택...",
|
||||
@@ -575,6 +600,9 @@ export default {
|
||||
loadProvidersFailed: "검색 엔진 목록 로드 실패: {message}",
|
||||
saveSuccess: "웹 검색 설정이 저장되었습니다",
|
||||
saveFailed: "설정 저장 실패: {message}",
|
||||
providerCreated: "검색 엔진 프로바이더가 생성되었습니다",
|
||||
providerUpdated: "검색 엔진 프로바이더가 업데이트되었습니다",
|
||||
providerDeleted: "검색 엔진 프로바이더가 삭제되었습니다",
|
||||
},
|
||||
},
|
||||
chatHistorySettings: {
|
||||
|
||||
@@ -672,6 +672,31 @@ export default {
|
||||
webSearchSettings: {
|
||||
title: 'Настройки веб-поиска',
|
||||
description: 'Настройте веб-поиск, чтобы ответы могли включать актуальную информацию из интернета.',
|
||||
providersTitle: 'Поисковые провайдеры',
|
||||
addProvider: 'Добавить провайдер',
|
||||
editProvider: 'Редактировать провайдер',
|
||||
noProviders: 'Поисковые провайдеры не настроены. Нажмите «Добавить провайдер», чтобы начать.',
|
||||
deleteConfirm: 'Вы уверены, что хотите удалить этот провайдер?',
|
||||
default: 'По умолчанию',
|
||||
providerNameLabel: 'Название',
|
||||
providerNamePlaceholder: 'Напр., Продакшн Bing Поиск',
|
||||
providerTypeLabel: 'Тип провайдера',
|
||||
providerDescLabel: 'Описание',
|
||||
engineIdLabel: 'ID движка',
|
||||
setAsDefault: 'Установить по умолчанию',
|
||||
free: 'Бесплатно',
|
||||
viewDocs: 'Документация для получения ключа',
|
||||
apiKeyUnchanged: 'Оставьте пустым, чтобы сохранить текущий ключ',
|
||||
testConnection: 'Проверить соединение',
|
||||
testing: 'Тестирование...',
|
||||
noDescription: "Нет описания",
|
||||
noProvidersDesc: "Добавьте провайдера веб-поиска, чтобы позволить вашим агентам получать информацию из Интернета в реальном времени.",
|
||||
basicInfo: "Основная информация",
|
||||
credentials: "Учетные данные",
|
||||
setAsDefaultDesc: "Этот провайдер будет использоваться по умолчанию, если агент не укажет свой",
|
||||
searchBehaviorTitle: 'Поведение поиска',
|
||||
defaultProviderLabel: 'Провайдер по умолчанию',
|
||||
defaultProviderDescription: 'Выберите провайдер поиска по умолчанию для агентов, не указавших собственный.',
|
||||
providerLabel: 'Провайдер поиска',
|
||||
providerDescription: 'Выберите поисковый сервис, используемый для веб-поиска',
|
||||
providerPlaceholder: 'Выберите поисковую систему...',
|
||||
@@ -697,7 +722,10 @@ export default {
|
||||
toasts: {
|
||||
loadProvidersFailed: 'Не удалось загрузить список поисковых провайдеров: {message}',
|
||||
saveSuccess: 'Настройки веб-поиска сохранены',
|
||||
saveFailed: 'Не удалось сохранить настройки: {message}'
|
||||
saveFailed: 'Не удалось сохранить настройки: {message}',
|
||||
providerCreated: 'Поисковый провайдер создан',
|
||||
providerUpdated: 'Поисковый провайдер обновлён',
|
||||
providerDeleted: 'Поисковый провайдер удалён',
|
||||
}
|
||||
},
|
||||
chatHistorySettings: {
|
||||
|
||||
@@ -558,6 +558,34 @@ export default {
|
||||
title: "网络搜索配置",
|
||||
description:
|
||||
"配置网络搜索功能,在回答问题时可以从互联网获取实时信息补充知识库内容",
|
||||
// Provider entity management
|
||||
providersTitle: "搜索引擎配置",
|
||||
addProvider: "添加搜索引擎",
|
||||
editProvider: "编辑搜索引擎",
|
||||
noProviders: "暂无搜索引擎配置,点击「添加搜索引擎」开始配置。",
|
||||
deleteConfirm: "确定要删除此搜索引擎配置吗?",
|
||||
default: "默认",
|
||||
providerNameLabel: "名称",
|
||||
providerNamePlaceholder: "例如:生产环境 Bing 搜索",
|
||||
providerTypeLabel: "引擎类型",
|
||||
providerDescLabel: "备注",
|
||||
providerDescPlaceholder: "可选,如:测试环境用",
|
||||
engineIdLabel: "搜索引擎 ID",
|
||||
setAsDefault: "设为默认",
|
||||
testConnection: "测试连接",
|
||||
testing: "测试中...",
|
||||
free: "免费",
|
||||
viewDocs: "查看文档获取密钥",
|
||||
apiKeyUnchanged: "留空保持当前密钥不变",
|
||||
noDescription: "暂无描述",
|
||||
noProvidersDesc: "添加一个网络搜索引擎,为您的智能体提供实时的互联网信息检索能力。",
|
||||
basicInfo: "基础信息",
|
||||
credentials: "凭证信息",
|
||||
setAsDefaultDesc: "当智能体没有指定特定的搜索引擎时,将默认使用此配置",
|
||||
// Search behavior
|
||||
searchBehaviorTitle: "搜索行为配置",
|
||||
defaultProviderLabel: "默认搜索引擎",
|
||||
defaultProviderDescription: "为未指定搜索引擎的智能体选择默认使用的搜索引擎",
|
||||
providerLabel: "搜索引擎提供商",
|
||||
providerDescription: "选择用于网络搜索的搜索引擎服务",
|
||||
providerPlaceholder: "选择搜索引擎...",
|
||||
@@ -583,6 +611,11 @@ export default {
|
||||
loadProvidersFailed: "加载搜索引擎列表失败: {message}",
|
||||
saveSuccess: "网络搜索配置已保存",
|
||||
saveFailed: "保存配置失败: {message}",
|
||||
providerCreated: "搜索引擎配置已创建",
|
||||
providerUpdated: "搜索引擎配置已更新",
|
||||
providerDeleted: "搜索引擎配置已删除",
|
||||
testSuccess: "连接测试成功",
|
||||
testFailed: "连接测试失败",
|
||||
},
|
||||
},
|
||||
chatHistorySettings: {
|
||||
@@ -1155,7 +1188,11 @@ export default {
|
||||
suggestedPrompts: "推荐问题",
|
||||
mode: "运行模式",
|
||||
webSearch: "网络搜索",
|
||||
webSearchProvider: "搜索引擎",
|
||||
webSearchProviderPlaceholder: "使用默认搜索引擎",
|
||||
webSearchMaxResults: "最大搜索结果数",
|
||||
webFetchEnabled: "自动抓取页面内容",
|
||||
webFetchTopN: "抓取页面数",
|
||||
knowledgeBases: "关联知识库",
|
||||
allKnowledgeBases: "全部知识库",
|
||||
allKnowledgeBasesDesc: "智能体可访问所有知识库",
|
||||
@@ -3176,7 +3213,10 @@ export default {
|
||||
maxIterations: "Agent 执行任务时的最大推理步骤数",
|
||||
kbScope: "选择智能体可访问的知识库范围",
|
||||
webSearch: "启用后智能体可以搜索互联网获取信息",
|
||||
webSearchProvider: "为此智能体指定搜索引擎,留空则使用默认搜索引擎",
|
||||
webSearchMaxResults: "每次搜索返回的最大结果数量",
|
||||
webFetchEnabled: "Rerank 后自动抓取排名靠前的网页完整内容,提升回答质量",
|
||||
webFetchTopN: "Rerank 后最多抓取几个网页的完整内容",
|
||||
retrievalSection: "配置知识库检索和排序的参数",
|
||||
queryExpansion: "自动扩展查询词以提高召回率",
|
||||
embeddingTopK: "向量检索返回的最大结果数量",
|
||||
|
||||
@@ -914,6 +914,32 @@
|
||||
</div>
|
||||
</div>
|
||||
|
||||
<!-- 网络搜索最大结果数 -->
|
||||
<div v-if="formData.config.web_search_enabled" class="setting-row">
|
||||
<div class="setting-info">
|
||||
<label>{{ $t('agent.editor.webSearchProvider') }}</label>
|
||||
<p class="desc">{{ $t('agentEditor.desc.webSearchProvider') }}</p>
|
||||
</div>
|
||||
<div class="setting-control">
|
||||
<t-select
|
||||
v-model="formData.config.web_search_provider_id"
|
||||
clearable
|
||||
:placeholder="$t('agent.editor.webSearchProviderPlaceholder')"
|
||||
style="width: 240px;"
|
||||
>
|
||||
<t-option
|
||||
v-for="p in webSearchProviderList"
|
||||
:key="p.id"
|
||||
:value="p.id"
|
||||
:label="p.name"
|
||||
>
|
||||
<span>{{ p.name }}</span>
|
||||
<t-tag v-if="p.is_default" theme="primary" size="small" style="margin-left: 6px;">{{ $t('common.default') }}</t-tag>
|
||||
</t-option>
|
||||
</t-select>
|
||||
</div>
|
||||
</div>
|
||||
|
||||
<!-- 网络搜索最大结果数 -->
|
||||
<div v-if="formData.config.web_search_enabled" class="setting-row">
|
||||
<div class="setting-info">
|
||||
@@ -927,6 +953,31 @@
|
||||
</div>
|
||||
</div>
|
||||
</div>
|
||||
|
||||
<!-- 自动抓取页面内容 -->
|
||||
<div v-if="formData.config.web_search_enabled" class="setting-row">
|
||||
<div class="setting-info">
|
||||
<label>{{ $t('agent.editor.webFetchEnabled') }}</label>
|
||||
<p class="desc">{{ $t('agentEditor.desc.webFetchEnabled') }}</p>
|
||||
</div>
|
||||
<div class="setting-control">
|
||||
<t-switch v-model="formData.config.web_fetch_enabled" />
|
||||
</div>
|
||||
</div>
|
||||
|
||||
<!-- 抓取页面数 -->
|
||||
<div v-if="formData.config.web_search_enabled && formData.config.web_fetch_enabled" class="setting-row">
|
||||
<div class="setting-info">
|
||||
<label>{{ $t('agent.editor.webFetchTopN') }}</label>
|
||||
<p class="desc">{{ $t('agentEditor.desc.webFetchTopN') }}</p>
|
||||
</div>
|
||||
<div class="setting-control">
|
||||
<div class="slider-wrapper">
|
||||
<t-slider v-model="formData.config.web_fetch_top_n" :min="1" :max="10" />
|
||||
<span class="slider-value">{{ formData.config.web_fetch_top_n }}</span>
|
||||
</div>
|
||||
</div>
|
||||
</div>
|
||||
</div>
|
||||
</div>
|
||||
|
||||
@@ -1164,6 +1215,7 @@ import { listModels, type ModelConfig } from '@/api/model';
|
||||
import { listKnowledgeBases } from '@/api/knowledge-base';
|
||||
import { listMCPServices, type MCPService } from '@/api/mcp-service';
|
||||
import { listSkills, type SkillInfo } from '@/api/skill';
|
||||
import { listWebSearchProviders, type WebSearchProviderEntity } from '@/api/web-search-provider';
|
||||
import { getAgentConfig, getConversationConfig, getStorageEngineStatus, type StorageEngineStatusItem, type PromptTemplate } from '@/api/system';
|
||||
import { useUIStore } from '@/stores/ui';
|
||||
import { useOrganizationStore } from '@/stores/organization';
|
||||
@@ -1195,6 +1247,7 @@ const saving = ref(false);
|
||||
const allModels = ref<ModelConfig[]>([]);
|
||||
const kbOptions = ref<{ label: string; value: string; type?: 'document' | 'faq'; count?: number; shared?: boolean; orgName?: string }[]>([]);
|
||||
const mcpOptions = ref<{ label: string; value: string }[]>([]);
|
||||
const webSearchProviderList = ref<WebSearchProviderEntity[]>([]);
|
||||
const skillOptions = ref<{ name: string; description: string }[]>([]);
|
||||
// 是否允许启用 Skills(取决于后端沙箱是否启用,disabled 时为 false;未请求前为 false 避免闪显)
|
||||
const skillsAvailable = ref(false);
|
||||
@@ -1919,6 +1972,16 @@ const loadDependencies = async () => {
|
||||
console.warn('Failed to load storage engine status', e);
|
||||
}
|
||||
|
||||
// 加载网络搜索引擎配置列表
|
||||
try {
|
||||
const wsRes: any = await listWebSearchProviders();
|
||||
if (wsRes?.data && Array.isArray(wsRes.data)) {
|
||||
webSearchProviderList.value = wsRes.data;
|
||||
}
|
||||
} catch (e) {
|
||||
console.warn('Failed to load web search providers', e);
|
||||
}
|
||||
|
||||
// 加载占位符定义(从统一 API)
|
||||
try {
|
||||
const placeholdersRes = await getPlaceholders();
|
||||
|
||||
@@ -6,380 +6,325 @@
|
||||
</div>
|
||||
|
||||
<div class="settings-group">
|
||||
<!-- 搜索引擎提供商 -->
|
||||
<div class="setting-row">
|
||||
<div class="setting-info">
|
||||
<label>{{ t('webSearchSettings.providerLabel') }}</label>
|
||||
<p class="desc">{{ t('webSearchSettings.providerDescription') }}</p>
|
||||
</div>
|
||||
<div class="setting-control">
|
||||
<t-select
|
||||
v-model="localProvider"
|
||||
:loading="loadingProviders"
|
||||
filterable
|
||||
:placeholder="t('webSearchSettings.providerPlaceholder')"
|
||||
@change="handleProviderChange"
|
||||
@focus="loadProviders"
|
||||
style="width: 280px;"
|
||||
>
|
||||
<t-option
|
||||
v-for="provider in providers"
|
||||
:key="provider.id"
|
||||
:value="provider.id"
|
||||
:label="provider.name"
|
||||
>
|
||||
<div class="provider-option-wrapper">
|
||||
<div class="provider-option">
|
||||
<span class="provider-name">{{ provider.name }}</span>
|
||||
</div>
|
||||
</div>
|
||||
</t-option>
|
||||
</t-select>
|
||||
</div>
|
||||
<div class="section-subheader">
|
||||
<h3>{{ t('webSearchSettings.providersTitle') }}</h3>
|
||||
<t-button theme="primary" size="small" @click="openAddDialog">
|
||||
<template #icon><add-icon /></template>
|
||||
{{ t('webSearchSettings.addProvider') }}
|
||||
</t-button>
|
||||
</div>
|
||||
|
||||
<!-- API 密钥 -->
|
||||
<div v-if="selectedProvider && selectedProvider.requires_api_key" class="setting-row">
|
||||
<div class="setting-info">
|
||||
<label>{{ t('webSearchSettings.apiKeyLabel') }}</label>
|
||||
<p class="desc">{{ t('webSearchSettings.apiKeyDescription') }}</p>
|
||||
</div>
|
||||
<div class="setting-control">
|
||||
<t-input
|
||||
v-model="localAPIKey"
|
||||
type="password"
|
||||
:placeholder="t('webSearchSettings.apiKeyPlaceholder')"
|
||||
@change="handleAPIKeyChange"
|
||||
style="width: 400px;"
|
||||
:show-password="true"
|
||||
/>
|
||||
</div>
|
||||
</div>
|
||||
|
||||
<!-- 最大结果数 -->
|
||||
<div class="setting-row">
|
||||
<div class="setting-info">
|
||||
<label>{{ t('webSearchSettings.maxResultsLabel') }}</label>
|
||||
<p class="desc">{{ t('webSearchSettings.maxResultsDescription') }}</p>
|
||||
</div>
|
||||
<div class="setting-control">
|
||||
<div class="slider-with-value">
|
||||
<t-slider
|
||||
v-model="localMaxResults"
|
||||
:min="1"
|
||||
:max="50"
|
||||
:step="1"
|
||||
:marks="{ 1: '1', 10: '10', 20: '20', 30: '30', 40: '40', 50: '50' }"
|
||||
@change="handleMaxResultsChange"
|
||||
style="width: 200px;"
|
||||
/>
|
||||
<span class="value-display">{{ localMaxResults }}</span>
|
||||
<!-- Provider List -->
|
||||
<div v-if="providerEntities.length > 0" class="provider-list">
|
||||
<div v-for="entity in providerEntities" :key="entity.id" class="provider-item">
|
||||
<div class="item-info">
|
||||
<div class="item-header">
|
||||
<span class="item-name">{{ entity.name }}</span>
|
||||
<t-tag v-if="entity.is_default" theme="primary" size="small" variant="light">
|
||||
{{ t('webSearchSettings.default') }}
|
||||
</t-tag>
|
||||
<t-tag size="small" variant="outline">{{ entity.provider }}</t-tag>
|
||||
</div>
|
||||
<div class="item-desc">{{ entity.description || t('webSearchSettings.noDescription') }}</div>
|
||||
</div>
|
||||
<div class="item-actions">
|
||||
<t-button theme="default" variant="text" size="small" @click="testExistingConnection(entity)" :loading="testingId === entity.id">
|
||||
{{ t('webSearchSettings.testConnection') }}
|
||||
</t-button>
|
||||
<t-button theme="primary" variant="text" size="small" @click="editProvider(entity)">
|
||||
{{ t('common.edit') }}
|
||||
</t-button>
|
||||
<t-popconfirm :content="t('webSearchSettings.deleteConfirm')" @confirm="deleteProvider(entity.id!)">
|
||||
<t-button theme="danger" variant="text" size="small">
|
||||
{{ t('common.delete') }}
|
||||
</t-button>
|
||||
</t-popconfirm>
|
||||
</div>
|
||||
</div>
|
||||
</div>
|
||||
|
||||
<!-- 包含日期 -->
|
||||
<div class="setting-row">
|
||||
<div class="setting-info">
|
||||
<label>{{ t('webSearchSettings.includeDateLabel') }}</label>
|
||||
<p class="desc">{{ t('webSearchSettings.includeDateDescription') }}</p>
|
||||
</div>
|
||||
<div class="setting-control">
|
||||
<t-switch
|
||||
v-model="localIncludeDate"
|
||||
@change="handleIncludeDateChange"
|
||||
/>
|
||||
</div>
|
||||
</div>
|
||||
|
||||
<!-- 压缩方法 -->
|
||||
<div class="setting-row">
|
||||
<div class="setting-info">
|
||||
<label>{{ t('webSearchSettings.compressionLabel') }}</label>
|
||||
<p class="desc">{{ t('webSearchSettings.compressionDescription') }}</p>
|
||||
</div>
|
||||
<div class="setting-control">
|
||||
<t-select
|
||||
v-model="localCompressionMethod"
|
||||
@change="handleCompressionMethodChange"
|
||||
style="width: 280px;"
|
||||
:placeholder="t('webSearchSettings.compressionLabel')"
|
||||
>
|
||||
<t-option value="none" :label="t('webSearchSettings.compressionNone')">
|
||||
{{ t('webSearchSettings.compressionNone') }}
|
||||
</t-option>
|
||||
<t-option value="llm_summary" :label="t('webSearchSettings.compressionSummary')">
|
||||
{{ t('webSearchSettings.compressionSummary') }}
|
||||
</t-option>
|
||||
</t-select>
|
||||
</div>
|
||||
</div>
|
||||
|
||||
<!-- 黑名单 -->
|
||||
<div class="setting-row vertical">
|
||||
<div class="setting-info">
|
||||
<label>{{ t('webSearchSettings.blacklistLabel') }}</label>
|
||||
<p class="desc">{{ t('webSearchSettings.blacklistDescription') }}</p>
|
||||
</div>
|
||||
<div class="setting-control">
|
||||
<t-textarea
|
||||
v-model="localBlacklistText"
|
||||
:placeholder="t('webSearchSettings.blacklistPlaceholder')"
|
||||
:autosize="{ minRows: 4, maxRows: 8 }"
|
||||
@change="handleBlacklistChange"
|
||||
style="width: 500px;"
|
||||
/>
|
||||
</div>
|
||||
<!-- Empty State -->
|
||||
<div v-else class="empty-providers">
|
||||
<p>{{ t('webSearchSettings.noProvidersDesc') }}</p>
|
||||
</div>
|
||||
</div>
|
||||
|
||||
<!-- Add/Edit Dialog -->
|
||||
<t-dialog
|
||||
v-model:visible="showAddProviderDialog"
|
||||
:header="editingProvider ? t('webSearchSettings.editProvider') : t('webSearchSettings.addProvider')"
|
||||
width="520px"
|
||||
:footer="false"
|
||||
destroy-on-close
|
||||
>
|
||||
<div class="dialog-form-container">
|
||||
<t-form :data="providerForm" label-align="top" @submit="saveProvider" class="provider-form">
|
||||
<t-form-item :label="t('webSearchSettings.providerTypeLabel')" name="provider">
|
||||
<t-select v-model="providerForm.provider" :disabled="!!editingProvider" @change="onProviderTypeChange">
|
||||
<t-option v-for="pt in providerTypes" :key="pt.id" :value="pt.id" :label="pt.name">
|
||||
<div class="provider-option">
|
||||
<span>{{ pt.name }}</span>
|
||||
<t-tag v-if="pt.free" theme="success" size="small" variant="light">{{ t('webSearchSettings.free') }}</t-tag>
|
||||
</div>
|
||||
</t-option>
|
||||
</t-select>
|
||||
</t-form-item>
|
||||
|
||||
<t-form-item :label="t('webSearchSettings.providerNameLabel')" name="name">
|
||||
<t-input v-model="providerForm.name" :placeholder="selectedProviderType?.name || t('webSearchSettings.providerNamePlaceholder')" />
|
||||
</t-form-item>
|
||||
|
||||
<t-form-item :label="t('webSearchSettings.providerDescLabel')" name="description">
|
||||
<t-input v-model="providerForm.description" :placeholder="t('webSearchSettings.providerDescPlaceholder')" />
|
||||
</t-form-item>
|
||||
|
||||
<template v-if="selectedProviderType?.requires_api_key || selectedProviderType?.requires_engine_id">
|
||||
<div class="form-divider"></div>
|
||||
|
||||
<div class="credentials-hint" v-if="selectedProviderType?.docs_url">
|
||||
<a :href="selectedProviderType.docs_url" target="_blank" rel="noopener noreferrer">
|
||||
{{ t('webSearchSettings.viewDocs') }} ↗
|
||||
</a>
|
||||
</div>
|
||||
|
||||
<t-form-item v-if="selectedProviderType?.requires_api_key" :label="t('webSearchSettings.apiKeyLabel')" name="parameters.api_key">
|
||||
<t-input
|
||||
v-model="providerForm.parameters.api_key"
|
||||
type="password"
|
||||
:placeholder="editingProvider ? t('webSearchSettings.apiKeyUnchanged') : t('webSearchSettings.apiKeyPlaceholder')"
|
||||
/>
|
||||
</t-form-item>
|
||||
<t-form-item v-if="selectedProviderType?.requires_engine_id" :label="t('webSearchSettings.engineIdLabel')" name="parameters.engine_id">
|
||||
<t-input v-model="providerForm.parameters.engine_id" :placeholder="t('webSearchSettings.engineIdLabel')" />
|
||||
</t-form-item>
|
||||
</template>
|
||||
|
||||
<div class="form-divider"></div>
|
||||
|
||||
<t-form-item :label="t('webSearchSettings.setAsDefault')" name="is_default">
|
||||
<template #help>
|
||||
<div class="switch-help">
|
||||
{{ t('webSearchSettings.setAsDefaultDesc') }}
|
||||
</div>
|
||||
</template>
|
||||
<t-switch v-model="providerForm.is_default" />
|
||||
</t-form-item>
|
||||
|
||||
<div class="dialog-footer">
|
||||
<div class="footer-left">
|
||||
<t-button
|
||||
v-if="selectedProviderType && !selectedProviderType.free"
|
||||
theme="default"
|
||||
variant="outline"
|
||||
:loading="testing"
|
||||
@click="testConnection"
|
||||
>
|
||||
{{ testing ? t('webSearchSettings.testing') : t('webSearchSettings.testConnection') }}
|
||||
</t-button>
|
||||
</div>
|
||||
<div class="footer-right">
|
||||
<t-button theme="default" variant="base" @click="showAddProviderDialog = false">{{ t('common.cancel') }}</t-button>
|
||||
<t-button theme="primary" type="submit" :loading="saving">{{ t('common.save') }}</t-button>
|
||||
</div>
|
||||
</div>
|
||||
</t-form>
|
||||
</div>
|
||||
</t-dialog>
|
||||
</div>
|
||||
</template>
|
||||
|
||||
<script setup lang="ts">
|
||||
import { ref, computed, onMounted, nextTick } from 'vue'
|
||||
import { ref, computed, onMounted } from 'vue'
|
||||
import { MessagePlugin } from 'tdesign-vue-next'
|
||||
import { useI18n } from 'vue-i18n'
|
||||
import { getWebSearchProviders, getTenantWebSearchConfig, updateTenantWebSearchConfig, type WebSearchProviderConfig, type WebSearchConfig } from '@/api/web-search'
|
||||
import { AddIcon } from 'tdesign-icons-vue-next'
|
||||
import {
|
||||
listWebSearchProviders,
|
||||
listWebSearchProviderTypes,
|
||||
createWebSearchProvider,
|
||||
updateWebSearchProvider,
|
||||
deleteWebSearchProvider as deleteWebSearchProviderAPI,
|
||||
testWebSearchProvider,
|
||||
type WebSearchProviderEntity,
|
||||
type WebSearchProviderTypeInfo,
|
||||
} from '@/api/web-search-provider'
|
||||
|
||||
const { t } = useI18n()
|
||||
|
||||
// 本地状态
|
||||
const loadingProviders = ref(false)
|
||||
const providers = ref<WebSearchProviderConfig[]>([])
|
||||
const localProvider = ref<string>('')
|
||||
const localAPIKey = ref<string>('')
|
||||
const localMaxResults = ref<number>(5)
|
||||
const localIncludeDate = ref<boolean>(true)
|
||||
const localCompressionMethod = ref<string>('none')
|
||||
const localBlacklistText = ref<string>('')
|
||||
const isInitializing = ref(true) // 标记是否正在初始化,初始化期间不触发自动保存
|
||||
const initialConfig = ref<WebSearchConfig | null>(null) // 保存初始配置,用于比较是否有变化
|
||||
// ===== State =====
|
||||
const providerEntities = ref<WebSearchProviderEntity[]>([])
|
||||
const providerTypes = ref<WebSearchProviderTypeInfo[]>([])
|
||||
const showAddProviderDialog = ref(false)
|
||||
const editingProvider = ref<WebSearchProviderEntity | null>(null)
|
||||
const testing = ref(false)
|
||||
const testingId = ref<string | null>(null)
|
||||
const saving = ref(false)
|
||||
|
||||
// 计算属性:当前选中的提供商
|
||||
const selectedProvider = computed(() => {
|
||||
return providers.value.find(p => p.id === localProvider.value)
|
||||
const providerForm = ref<{
|
||||
name: string
|
||||
provider: string
|
||||
description: string
|
||||
parameters: { api_key?: string; engine_id?: string }
|
||||
is_default: boolean
|
||||
}>({
|
||||
name: '',
|
||||
provider: 'duckduckgo',
|
||||
description: '',
|
||||
parameters: {},
|
||||
is_default: false,
|
||||
})
|
||||
|
||||
// 加载提供商列表
|
||||
const loadProviders = async () => {
|
||||
if (providers.value.length > 0) {
|
||||
return // 已加载过
|
||||
}
|
||||
|
||||
loadingProviders.value = true
|
||||
// ===== Computed =====
|
||||
const selectedProviderType = computed(() => {
|
||||
return providerTypes.value.find(pt => pt.id === providerForm.value.provider)
|
||||
})
|
||||
|
||||
// ===== Methods =====
|
||||
const onProviderTypeChange = () => {
|
||||
providerForm.value.parameters = {}
|
||||
}
|
||||
|
||||
const loadProviderEntities = async () => {
|
||||
try {
|
||||
const response = await getWebSearchProviders()
|
||||
// request拦截器已经处理了响应,直接使用data字段
|
||||
const response = await listWebSearchProviders()
|
||||
if (response.data && Array.isArray(response.data)) {
|
||||
providers.value = response.data
|
||||
providerEntities.value = response.data
|
||||
}
|
||||
} catch (error: any) {
|
||||
console.error('Failed to load web search providers:', error)
|
||||
const errorMessage = error?.message || t('webSearchSettings.errors.unknown')
|
||||
MessagePlugin.error(t('webSearchSettings.toasts.loadProvidersFailed', { message: errorMessage }))
|
||||
} finally {
|
||||
loadingProviders.value = false
|
||||
} catch (error) {
|
||||
console.error('Failed to load provider entities:', error)
|
||||
}
|
||||
}
|
||||
|
||||
// 加载租户配置
|
||||
const loadTenantConfig = async () => {
|
||||
const loadProviderTypes = async () => {
|
||||
try {
|
||||
const response = await getTenantWebSearchConfig()
|
||||
// request拦截器已经处理了响应,直接使用data字段
|
||||
if (response.data) {
|
||||
const config = response.data
|
||||
// 在设置初始值时,禁用自动保存
|
||||
isInitializing.value = true
|
||||
|
||||
// 保存初始配置的副本(用于后续比较)
|
||||
const blacklist = (config.blacklist || []).join('\n')
|
||||
initialConfig.value = {
|
||||
provider: config.provider || '',
|
||||
api_key: config.api_key === '***' ? '***' : config.api_key || '',
|
||||
max_results: config.max_results || 5,
|
||||
include_date: config.include_date !== undefined ? config.include_date : true,
|
||||
compression_method: config.compression_method || 'none',
|
||||
blacklist: config.blacklist || []
|
||||
}
|
||||
|
||||
// 设置本地状态值
|
||||
localProvider.value = config.provider || ''
|
||||
// API key 在响应中被隐藏,如果是 "***",说明已配置但未返回实际值
|
||||
localAPIKey.value = config.api_key === '***' ? '***' : config.api_key || ''
|
||||
localMaxResults.value = config.max_results || 5
|
||||
localIncludeDate.value = config.include_date !== undefined ? config.include_date : true
|
||||
localCompressionMethod.value = config.compression_method || 'none'
|
||||
localBlacklistText.value = blacklist
|
||||
|
||||
// 等待所有响应式更新完成后再启用自动保存
|
||||
await nextTick()
|
||||
await nextTick()
|
||||
// 使用 setTimeout 确保所有事件都已处理完毕
|
||||
setTimeout(() => {
|
||||
isInitializing.value = false
|
||||
}, 100)
|
||||
} else {
|
||||
// 如果没有配置数据,保存默认配置
|
||||
initialConfig.value = {
|
||||
provider: '',
|
||||
api_key: '',
|
||||
max_results: 5,
|
||||
include_date: true,
|
||||
compression_method: 'none',
|
||||
blacklist: []
|
||||
}
|
||||
await nextTick()
|
||||
setTimeout(() => {
|
||||
isInitializing.value = false
|
||||
}, 100)
|
||||
}
|
||||
} catch (error: any) {
|
||||
console.error('Failed to load tenant web search config:', error)
|
||||
// 如果配置不存在,使用默认值(不显示错误)
|
||||
initialConfig.value = {
|
||||
provider: '',
|
||||
providerTypes.value = await listWebSearchProviderTypes()
|
||||
} catch (error) {
|
||||
console.error('Failed to load provider types:', error)
|
||||
}
|
||||
}
|
||||
|
||||
const openAddDialog = () => {
|
||||
editingProvider.value = null
|
||||
providerForm.value = {
|
||||
name: '',
|
||||
provider: providerTypes.value[0]?.id || 'duckduckgo',
|
||||
description: '',
|
||||
parameters: {},
|
||||
is_default: providerEntities.value.length === 0
|
||||
}
|
||||
showAddProviderDialog.value = true
|
||||
}
|
||||
|
||||
const editProvider = (entity: WebSearchProviderEntity) => {
|
||||
editingProvider.value = entity
|
||||
providerForm.value = {
|
||||
name: entity.name,
|
||||
provider: entity.provider,
|
||||
description: entity.description || '',
|
||||
parameters: {
|
||||
api_key: '',
|
||||
max_results: 5,
|
||||
include_date: true,
|
||||
compression_method: 'none',
|
||||
blacklist: []
|
||||
}
|
||||
await nextTick()
|
||||
setTimeout(() => {
|
||||
isInitializing.value = false
|
||||
}, 100)
|
||||
engine_id: entity.parameters?.engine_id || '',
|
||||
},
|
||||
is_default: entity.is_default || false,
|
||||
}
|
||||
showAddProviderDialog.value = true
|
||||
}
|
||||
|
||||
// 检查配置是否有变化
|
||||
const hasConfigChanged = (): boolean => {
|
||||
if (!initialConfig.value) {
|
||||
return true // 如果没有初始配置,认为有变化
|
||||
}
|
||||
|
||||
const blacklist = localBlacklistText.value
|
||||
.split('\n')
|
||||
.map(line => line.trim())
|
||||
.filter(line => line.length > 0)
|
||||
|
||||
const currentConfig: WebSearchConfig = {
|
||||
provider: localProvider.value,
|
||||
api_key: localAPIKey.value,
|
||||
max_results: localMaxResults.value,
|
||||
include_date: localIncludeDate.value,
|
||||
compression_method: localCompressionMethod.value,
|
||||
blacklist: blacklist
|
||||
}
|
||||
|
||||
// 比较配置是否有变化(忽略 API key 的 '***' 占位符)
|
||||
const initial = initialConfig.value
|
||||
if (currentConfig.provider !== initial.provider) return true
|
||||
if (currentConfig.api_key !== initial.api_key &&
|
||||
!(currentConfig.api_key === '***' && initial.api_key === '***')) return true
|
||||
if (currentConfig.max_results !== initial.max_results) return true
|
||||
if (currentConfig.include_date !== initial.include_date) return true
|
||||
if (currentConfig.compression_method !== initial.compression_method) return true
|
||||
|
||||
// 比较黑名单数组
|
||||
const currentBlacklist = blacklist.sort().join(',')
|
||||
const initialBlacklist = (initial.blacklist || []).sort().join(',')
|
||||
if (currentBlacklist !== initialBlacklist) return true
|
||||
|
||||
return false
|
||||
}
|
||||
|
||||
// 保存配置
|
||||
const saveConfig = async () => {
|
||||
// 如果配置没有变化,不保存
|
||||
if (!hasConfigChanged()) {
|
||||
const saveProvider = async ({ validateResult, firstError }: any) => {
|
||||
if (validateResult !== true && validateResult !== undefined) {
|
||||
MessagePlugin.warning(firstError || 'Please check the form fields')
|
||||
return
|
||||
}
|
||||
|
||||
saving.value = true
|
||||
try {
|
||||
const blacklist = localBlacklistText.value
|
||||
.split('\n')
|
||||
.map(line => line.trim())
|
||||
.filter(line => line.length > 0)
|
||||
|
||||
const config: WebSearchConfig = {
|
||||
provider: localProvider.value,
|
||||
api_key: localAPIKey.value,
|
||||
max_results: localMaxResults.value,
|
||||
include_date: localIncludeDate.value,
|
||||
compression_method: localCompressionMethod.value,
|
||||
blacklist: blacklist
|
||||
const data: Partial<WebSearchProviderEntity> = {
|
||||
name: providerForm.value.name.trim() || selectedProviderType.value?.name || providerForm.value.provider,
|
||||
provider: providerForm.value.provider as any,
|
||||
description: providerForm.value.description,
|
||||
parameters: { ...providerForm.value.parameters },
|
||||
is_default: providerForm.value.is_default,
|
||||
}
|
||||
|
||||
await updateTenantWebSearchConfig(config)
|
||||
|
||||
// 更新初始配置,避免重复保存
|
||||
initialConfig.value = {
|
||||
provider: config.provider,
|
||||
api_key: config.api_key,
|
||||
max_results: config.max_results,
|
||||
include_date: config.include_date,
|
||||
compression_method: config.compression_method,
|
||||
blacklist: [...config.blacklist]
|
||||
if (editingProvider.value && !data.parameters!.api_key) {
|
||||
delete data.parameters!.api_key
|
||||
}
|
||||
|
||||
MessagePlugin.success(t('webSearchSettings.toasts.saveSuccess'))
|
||||
|
||||
if (editingProvider.value) {
|
||||
await updateWebSearchProvider(editingProvider.value.id!, data)
|
||||
MessagePlugin.success(t('webSearchSettings.toasts.providerUpdated'))
|
||||
} else {
|
||||
await createWebSearchProvider(data)
|
||||
MessagePlugin.success(t('webSearchSettings.toasts.providerCreated'))
|
||||
}
|
||||
showAddProviderDialog.value = false
|
||||
await loadProviderEntities()
|
||||
} catch (error: any) {
|
||||
console.error('Failed to save web search config:', error)
|
||||
const errorMessage = error?.message || t('webSearchSettings.errors.unknown')
|
||||
MessagePlugin.error(t('webSearchSettings.toasts.saveFailed', { message: errorMessage }))
|
||||
throw error
|
||||
MessagePlugin.error(error?.message || 'Failed to save provider')
|
||||
} finally {
|
||||
saving.value = false
|
||||
}
|
||||
}
|
||||
|
||||
// 防抖保存
|
||||
let saveTimer: number | null = null
|
||||
const debouncedSave = () => {
|
||||
// 初始化期间不触发自动保存
|
||||
if (isInitializing.value) {
|
||||
return
|
||||
const deleteProvider = async (id: string) => {
|
||||
try {
|
||||
await deleteWebSearchProviderAPI(id)
|
||||
MessagePlugin.success(t('webSearchSettings.toasts.providerDeleted'))
|
||||
await loadProviderEntities()
|
||||
} catch (error: any) {
|
||||
MessagePlugin.error(error?.message || 'Failed to delete provider')
|
||||
}
|
||||
if (saveTimer) {
|
||||
clearTimeout(saveTimer)
|
||||
}
|
||||
|
||||
const testConnection = async () => {
|
||||
testing.value = true
|
||||
try {
|
||||
const data = {
|
||||
provider: providerForm.value.provider,
|
||||
parameters: { ...providerForm.value.parameters },
|
||||
}
|
||||
|
||||
if (editingProvider.value && !data.parameters.api_key) {
|
||||
const res = await testWebSearchProvider(editingProvider.value.id!)
|
||||
if (res.success) {
|
||||
MessagePlugin.success(t('webSearchSettings.toasts.testSuccess'))
|
||||
} else {
|
||||
MessagePlugin.error(res.error || t('webSearchSettings.toasts.testFailed'))
|
||||
}
|
||||
} else {
|
||||
const res = await testWebSearchProvider(undefined, data)
|
||||
if (res.success) {
|
||||
MessagePlugin.success(t('webSearchSettings.toasts.testSuccess'))
|
||||
} else {
|
||||
MessagePlugin.error(res.error || t('webSearchSettings.toasts.testFailed'))
|
||||
}
|
||||
}
|
||||
} catch (error: any) {
|
||||
MessagePlugin.error(error?.message || t('webSearchSettings.toasts.testFailed'))
|
||||
} finally {
|
||||
testing.value = false
|
||||
}
|
||||
saveTimer = window.setTimeout(() => {
|
||||
saveConfig().catch(() => {
|
||||
// 错误已在 saveConfig 中处理
|
||||
})
|
||||
}, 500)
|
||||
}
|
||||
|
||||
// 处理变化
|
||||
const handleProviderChange = () => {
|
||||
debouncedSave()
|
||||
const testExistingConnection = async (entity: WebSearchProviderEntity) => {
|
||||
testingId.value = entity.id!
|
||||
try {
|
||||
const res = await testWebSearchProvider(entity.id!)
|
||||
if (res.success) {
|
||||
MessagePlugin.success(t('webSearchSettings.toasts.testSuccess'))
|
||||
} else {
|
||||
MessagePlugin.error(res.error || t('webSearchSettings.toasts.testFailed'))
|
||||
}
|
||||
} catch (error: any) {
|
||||
MessagePlugin.error(error?.message || t('webSearchSettings.toasts.testFailed'))
|
||||
} finally {
|
||||
testingId.value = null
|
||||
}
|
||||
}
|
||||
|
||||
const handleAPIKeyChange = () => {
|
||||
debouncedSave()
|
||||
}
|
||||
|
||||
const handleMaxResultsChange = () => {
|
||||
debouncedSave()
|
||||
}
|
||||
|
||||
const handleIncludeDateChange = () => {
|
||||
debouncedSave()
|
||||
}
|
||||
|
||||
const handleCompressionMethodChange = () => {
|
||||
debouncedSave()
|
||||
}
|
||||
|
||||
const handleBlacklistChange = () => {
|
||||
debouncedSave()
|
||||
}
|
||||
|
||||
// 初始化
|
||||
// ===== Init =====
|
||||
onMounted(async () => {
|
||||
isInitializing.value = true
|
||||
await loadProviders()
|
||||
await loadTenantConfig()
|
||||
// loadTenantConfig 内部已经处理了 isInitializing,这里不需要再设置
|
||||
await Promise.all([loadProviderTypes(), loadProviderEntities()])
|
||||
})
|
||||
</script>
|
||||
|
||||
@@ -409,139 +354,130 @@ onMounted(async () => {
|
||||
.settings-group {
|
||||
display: flex;
|
||||
flex-direction: column;
|
||||
gap: 0;
|
||||
}
|
||||
|
||||
.setting-row {
|
||||
.section-subheader {
|
||||
display: flex;
|
||||
align-items: flex-start;
|
||||
align-items: center;
|
||||
justify-content: space-between;
|
||||
padding: 20px 0;
|
||||
border-bottom: 1px solid var(--td-component-stroke);
|
||||
margin-bottom: 16px;
|
||||
|
||||
&:last-child {
|
||||
border-bottom: none;
|
||||
}
|
||||
|
||||
&.vertical {
|
||||
flex-direction: column;
|
||||
gap: 12px;
|
||||
|
||||
.setting-control {
|
||||
width: 100%;
|
||||
max-width: 100%;
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
.setting-info {
|
||||
flex: 1;
|
||||
max-width: 65%;
|
||||
padding-right: 24px;
|
||||
|
||||
label {
|
||||
font-size: 15px;
|
||||
font-weight: 500;
|
||||
h3 {
|
||||
font-size: 16px;
|
||||
font-weight: 600;
|
||||
color: var(--td-text-color-primary);
|
||||
display: block;
|
||||
margin-bottom: 4px;
|
||||
}
|
||||
|
||||
.desc {
|
||||
font-size: 13px;
|
||||
color: var(--td-text-color-secondary);
|
||||
margin: 0;
|
||||
line-height: 1.5;
|
||||
}
|
||||
}
|
||||
|
||||
.setting-control {
|
||||
flex-shrink: 0;
|
||||
min-width: 280px;
|
||||
.provider-list {
|
||||
display: flex;
|
||||
justify-content: flex-end;
|
||||
align-items: center;
|
||||
flex-direction: column;
|
||||
gap: 8px;
|
||||
}
|
||||
|
||||
.slider-with-value {
|
||||
.provider-item {
|
||||
display: flex;
|
||||
align-items: center;
|
||||
gap: 12px;
|
||||
justify-content: space-between;
|
||||
padding: 14px 16px;
|
||||
background: var(--td-bg-color-container);
|
||||
border: 1px solid var(--td-component-stroke);
|
||||
border-radius: 8px;
|
||||
transition: all 0.2s ease;
|
||||
|
||||
&:hover {
|
||||
border-color: var(--td-brand-color);
|
||||
}
|
||||
}
|
||||
|
||||
.value-display {
|
||||
min-width: 40px;
|
||||
text-align: right;
|
||||
.item-info {
|
||||
display: flex;
|
||||
flex-direction: column;
|
||||
gap: 4px;
|
||||
}
|
||||
|
||||
.item-header {
|
||||
display: flex;
|
||||
align-items: center;
|
||||
gap: 8px;
|
||||
}
|
||||
|
||||
.item-name {
|
||||
font-size: 14px;
|
||||
font-weight: 500;
|
||||
color: var(--td-text-color-primary);
|
||||
}
|
||||
|
||||
.provider-option-wrapper {
|
||||
.item-desc {
|
||||
font-size: 13px;
|
||||
color: var(--td-text-color-secondary);
|
||||
}
|
||||
|
||||
.item-actions {
|
||||
display: flex;
|
||||
flex-direction: column;
|
||||
gap: 4px;
|
||||
padding: 2px 0;
|
||||
align-items: center;
|
||||
}
|
||||
|
||||
.empty-providers {
|
||||
padding: 32px;
|
||||
text-align: center;
|
||||
color: var(--td-text-color-placeholder);
|
||||
border: 1px dashed var(--td-component-stroke);
|
||||
border-radius: 8px;
|
||||
font-size: 14px;
|
||||
}
|
||||
|
||||
.dialog-form-container {
|
||||
margin-top: 12px;
|
||||
}
|
||||
|
||||
.provider-option {
|
||||
display: flex;
|
||||
align-items: center;
|
||||
justify-content: space-between;
|
||||
gap: 8px;
|
||||
flex-wrap: wrap;
|
||||
}
|
||||
|
||||
.provider-name {
|
||||
font-weight: 500;
|
||||
font-size: 14px;
|
||||
color: var(--td-text-color-primary);
|
||||
flex-shrink: 0;
|
||||
}
|
||||
|
||||
.provider-tags {
|
||||
display: flex;
|
||||
align-items: center;
|
||||
gap: 4px;
|
||||
flex-wrap: wrap;
|
||||
flex-shrink: 0;
|
||||
width: 100%;
|
||||
}
|
||||
|
||||
.provider-desc {
|
||||
.form-divider {
|
||||
height: 1px;
|
||||
background: var(--td-component-border);
|
||||
margin: 20px 0;
|
||||
}
|
||||
|
||||
.credentials-hint {
|
||||
margin-bottom: 12px;
|
||||
font-size: 13px;
|
||||
|
||||
a {
|
||||
color: var(--td-brand-color);
|
||||
text-decoration: none;
|
||||
|
||||
&:hover {
|
||||
text-decoration: underline;
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
.switch-help {
|
||||
font-size: 12px;
|
||||
color: var(--td-text-color-placeholder);
|
||||
color: var(--td-text-color-secondary);
|
||||
margin-top: 4px;
|
||||
line-height: 1.4;
|
||||
margin-top: 2px;
|
||||
}
|
||||
|
||||
/* 修复下拉项描述与条目重叠:让选项支持多行自适应高度 */
|
||||
:deep(.t-select-option) {
|
||||
height: auto;
|
||||
align-items: flex-start;
|
||||
padding-top: 6px;
|
||||
padding-bottom: 6px;
|
||||
}
|
||||
|
||||
:deep(.t-select-option__content) {
|
||||
white-space: normal;
|
||||
}
|
||||
|
||||
</style>
|
||||
<style lang="less">
|
||||
.t-select__dropdown .t-select-option {
|
||||
height: auto;
|
||||
align-items: flex-start;
|
||||
padding-top: 6px;
|
||||
padding-bottom: 6px;
|
||||
}
|
||||
.t-select__dropdown .t-select-option__content {
|
||||
white-space: normal;
|
||||
}
|
||||
.t-select__dropdown .provider-option-wrapper {
|
||||
.dialog-footer {
|
||||
display: flex;
|
||||
flex-direction: column;
|
||||
gap: 4px;
|
||||
padding: 2px 0;
|
||||
justify-content: space-between;
|
||||
align-items: center;
|
||||
margin-top: 32px;
|
||||
padding-top: 20px;
|
||||
border-top: 1px solid var(--td-component-border);
|
||||
|
||||
.footer-right {
|
||||
display: flex;
|
||||
gap: 12px;
|
||||
}
|
||||
}
|
||||
</style>
|
||||
|
||||
|
||||
@@ -82,6 +82,7 @@ type WebSearchTool struct {
|
||||
webSearchStateService interfaces.WebSearchStateService
|
||||
sessionID string
|
||||
maxResults int
|
||||
providerID string // WebSearchProviderEntity ID (resolved from agent config or tenant default)
|
||||
}
|
||||
|
||||
// NewWebSearchTool creates a new web search tool
|
||||
@@ -92,6 +93,7 @@ func NewWebSearchTool(
|
||||
webSearchStateService interfaces.WebSearchStateService,
|
||||
sessionID string,
|
||||
maxResults int,
|
||||
providerID string,
|
||||
) *WebSearchTool {
|
||||
tool := webSearchTool
|
||||
tool.description = fmt.Sprintf(tool.description, maxResults, maxResults)
|
||||
@@ -104,6 +106,7 @@ func NewWebSearchTool(
|
||||
webSearchStateService: webSearchStateService,
|
||||
sessionID: sessionID,
|
||||
maxResults: maxResults,
|
||||
providerID: providerID,
|
||||
}
|
||||
}
|
||||
|
||||
@@ -150,7 +153,7 @@ func (t *WebSearchTool) Execute(ctx context.Context, args json.RawMessage) (*typ
|
||||
|
||||
// Get tenant info from context (same approach as search.go)
|
||||
tenant := ctx.Value(types.TenantInfoContextKey).(*types.Tenant)
|
||||
if tenant == nil || tenant.WebSearchConfig == nil || tenant.WebSearchConfig.Provider == "" {
|
||||
if tenant == nil || tenant.WebSearchConfig == nil {
|
||||
logger.Errorf(ctx, "[Tool][WebSearch] Web search not configured for tenant %d", tenantID)
|
||||
return &types.ToolResult{
|
||||
Success: false,
|
||||
@@ -158,6 +161,9 @@ func (t *WebSearchTool) Execute(ctx context.Context, args json.RawMessage) (*typ
|
||||
}, fmt.Errorf("web search is not configured for tenant %d", tenantID)
|
||||
}
|
||||
|
||||
// Resolve provider ID: tool-level (set from agent config, which already resolved default)
|
||||
resolvedProviderID := t.providerID
|
||||
|
||||
// Create a copy of web search config with maxResults from agent config
|
||||
searchConfig := *tenant.WebSearchConfig
|
||||
searchConfig.MaxResults = t.maxResults
|
||||
@@ -165,11 +171,11 @@ func (t *WebSearchTool) Execute(ctx context.Context, args json.RawMessage) (*typ
|
||||
// Perform web search
|
||||
logger.Infof(
|
||||
ctx,
|
||||
"[Tool][WebSearch] Performing web search with provider: %s, maxResults: %d",
|
||||
searchConfig.Provider,
|
||||
"[Tool][WebSearch] Performing web search with providerID: %s, maxResults: %d",
|
||||
resolvedProviderID,
|
||||
searchConfig.MaxResults,
|
||||
)
|
||||
webResults, err := t.webSearchService.Search(ctx, &searchConfig, query)
|
||||
webResults, err := t.webSearchService.Search(ctx, resolvedProviderID, &searchConfig, query)
|
||||
if err != nil {
|
||||
logger.Errorf(ctx, "[Tool][WebSearch] Web search failed: %v", err)
|
||||
return &types.ToolResult{
|
||||
|
||||
89
internal/application/repository/web_search_provider.go
Normal file
89
internal/application/repository/web_search_provider.go
Normal file
@@ -0,0 +1,89 @@
|
||||
package repository
|
||||
|
||||
import (
|
||||
"context"
|
||||
"errors"
|
||||
|
||||
"github.com/Tencent/WeKnora/internal/types"
|
||||
"github.com/Tencent/WeKnora/internal/types/interfaces"
|
||||
"gorm.io/gorm"
|
||||
)
|
||||
|
||||
// webSearchProviderRepository implements the WebSearchProviderRepository interface
|
||||
type webSearchProviderRepository struct {
|
||||
db *gorm.DB
|
||||
}
|
||||
|
||||
// NewWebSearchProviderRepository creates a new web search provider repository
|
||||
func NewWebSearchProviderRepository(db *gorm.DB) interfaces.WebSearchProviderRepository {
|
||||
return &webSearchProviderRepository{db: db}
|
||||
}
|
||||
|
||||
// Create creates a new web search provider
|
||||
func (r *webSearchProviderRepository) Create(ctx context.Context, provider *types.WebSearchProviderEntity) error {
|
||||
return r.db.WithContext(ctx).Create(provider).Error
|
||||
}
|
||||
|
||||
// GetByID retrieves a web search provider by ID within a tenant scope
|
||||
func (r *webSearchProviderRepository) GetByID(ctx context.Context, tenantID uint64, id string) (*types.WebSearchProviderEntity, error) {
|
||||
var provider types.WebSearchProviderEntity
|
||||
if err := r.db.WithContext(ctx).Where(
|
||||
"id = ? AND tenant_id = ?", id, tenantID,
|
||||
).First(&provider).Error; err != nil {
|
||||
if errors.Is(err, gorm.ErrRecordNotFound) {
|
||||
return nil, nil
|
||||
}
|
||||
return nil, err
|
||||
}
|
||||
return &provider, nil
|
||||
}
|
||||
|
||||
// GetDefault retrieves the default provider (is_default=true) for a tenant, or nil if none.
|
||||
func (r *webSearchProviderRepository) GetDefault(ctx context.Context, tenantID uint64) (*types.WebSearchProviderEntity, error) {
|
||||
var provider types.WebSearchProviderEntity
|
||||
if err := r.db.WithContext(ctx).Where(
|
||||
"tenant_id = ? AND is_default = ?", tenantID, true,
|
||||
).First(&provider).Error; err != nil {
|
||||
if errors.Is(err, gorm.ErrRecordNotFound) {
|
||||
return nil, nil
|
||||
}
|
||||
return nil, err
|
||||
}
|
||||
return &provider, nil
|
||||
}
|
||||
|
||||
// List lists all web search providers for a tenant
|
||||
func (r *webSearchProviderRepository) List(ctx context.Context, tenantID uint64) ([]*types.WebSearchProviderEntity, error) {
|
||||
var providers []*types.WebSearchProviderEntity
|
||||
if err := r.db.WithContext(ctx).Where(
|
||||
"tenant_id = ?", tenantID,
|
||||
).Order("created_at ASC").Find(&providers).Error; err != nil {
|
||||
return nil, err
|
||||
}
|
||||
return providers, nil
|
||||
}
|
||||
|
||||
// Update updates a web search provider
|
||||
func (r *webSearchProviderRepository) Update(ctx context.Context, provider *types.WebSearchProviderEntity) error {
|
||||
return r.db.WithContext(ctx).Model(&types.WebSearchProviderEntity{}).Where(
|
||||
"id = ? AND tenant_id = ?", provider.ID, provider.TenantID,
|
||||
).Select("*").Updates(provider).Error
|
||||
}
|
||||
|
||||
// Delete soft-deletes a web search provider
|
||||
func (r *webSearchProviderRepository) Delete(ctx context.Context, tenantID uint64, id string) error {
|
||||
return r.db.WithContext(ctx).Where(
|
||||
"id = ? AND tenant_id = ?", id, tenantID,
|
||||
).Delete(&types.WebSearchProviderEntity{}).Error
|
||||
}
|
||||
|
||||
// ClearDefault clears the default flag for all providers of a tenant, optionally excluding one
|
||||
func (r *webSearchProviderRepository) ClearDefault(ctx context.Context, tenantID uint64, excludeID string) error {
|
||||
query := r.db.WithContext(ctx).Model(&types.WebSearchProviderEntity{}).Where(
|
||||
"tenant_id = ? AND is_default = ?", tenantID, true,
|
||||
)
|
||||
if excludeID != "" {
|
||||
query = query.Where("id != ?", excludeID)
|
||||
}
|
||||
return query.Update("is_default", false).Error
|
||||
}
|
||||
@@ -404,8 +404,9 @@ func (s *agentService) registerTools(
|
||||
s.webSearchStateService,
|
||||
sessionID,
|
||||
config.WebSearchMaxResults,
|
||||
config.WebSearchProviderID,
|
||||
)
|
||||
logger.Infof(ctx, "Registered web_search tool for session: %s, maxResults: %d", sessionID, config.WebSearchMaxResults)
|
||||
logger.Infof(ctx, "Registered web_search tool for session: %s, maxResults: %d, providerID: %s", sessionID, config.WebSearchMaxResults, config.WebSearchProviderID)
|
||||
|
||||
case tools.ToolWebFetch:
|
||||
toolToRegister = tools.NewWebFetchTool(chatModel)
|
||||
|
||||
@@ -15,14 +15,15 @@ import (
|
||||
|
||||
// PluginSearch implements search functionality for chat pipeline
|
||||
type PluginSearch struct {
|
||||
knowledgeBaseService interfaces.KnowledgeBaseService
|
||||
knowledgeService interfaces.KnowledgeService
|
||||
chunkService interfaces.ChunkService
|
||||
config *config.Config
|
||||
webSearchService interfaces.WebSearchService
|
||||
tenantService interfaces.TenantService
|
||||
sessionService interfaces.SessionService
|
||||
webSearchStateService interfaces.WebSearchStateService
|
||||
knowledgeBaseService interfaces.KnowledgeBaseService
|
||||
knowledgeService interfaces.KnowledgeService
|
||||
chunkService interfaces.ChunkService
|
||||
config *config.Config
|
||||
webSearchService interfaces.WebSearchService
|
||||
tenantService interfaces.TenantService
|
||||
sessionService interfaces.SessionService
|
||||
webSearchStateService interfaces.WebSearchStateService
|
||||
webSearchProviderRepo interfaces.WebSearchProviderRepository
|
||||
}
|
||||
|
||||
func NewPluginSearch(eventManager *EventManager,
|
||||
@@ -34,6 +35,7 @@ func NewPluginSearch(eventManager *EventManager,
|
||||
tenantService interfaces.TenantService,
|
||||
sessionService interfaces.SessionService,
|
||||
webSearchStateService interfaces.WebSearchStateService,
|
||||
webSearchProviderRepo interfaces.WebSearchProviderRepository,
|
||||
) *PluginSearch {
|
||||
res := &PluginSearch{
|
||||
knowledgeBaseService: knowledgeBaseService,
|
||||
@@ -44,6 +46,7 @@ func NewPluginSearch(eventManager *EventManager,
|
||||
tenantService: tenantService,
|
||||
sessionService: sessionService,
|
||||
webSearchStateService: webSearchStateService,
|
||||
webSearchProviderRepo: webSearchProviderRepo,
|
||||
}
|
||||
eventManager.Register(res)
|
||||
return res
|
||||
@@ -617,18 +620,21 @@ func (p *PluginSearch) searchWebIfEnabled(ctx context.Context, chatManage *types
|
||||
return nil
|
||||
}
|
||||
tenant, _ := types.TenantInfoFromContext(ctx)
|
||||
if tenant == nil || tenant.WebSearchConfig == nil || tenant.WebSearchConfig.Provider == "" {
|
||||
if tenant == nil || tenant.WebSearchConfig == nil {
|
||||
pipelineWarn(ctx, "Search", "web_config_missing", map[string]interface{}{
|
||||
"tenant_id": chatManage.TenantID,
|
||||
})
|
||||
return nil
|
||||
}
|
||||
|
||||
// Use provider ID already resolved by session layer (agent config > tenant default)
|
||||
providerID := chatManage.WebSearchProviderID
|
||||
|
||||
pipelineInfo(ctx, "Search", "web_request", map[string]interface{}{
|
||||
"tenant_id": chatManage.TenantID,
|
||||
"provider": tenant.WebSearchConfig.Provider,
|
||||
"tenant_id": chatManage.TenantID,
|
||||
"provider_id": providerID,
|
||||
})
|
||||
webResults, err := p.webSearchService.Search(ctx, tenant.WebSearchConfig, chatManage.RewriteQuery)
|
||||
webResults, err := p.webSearchService.Search(ctx, providerID, tenant.WebSearchConfig, chatManage.RewriteQuery)
|
||||
if err != nil {
|
||||
pipelineWarn(ctx, "Search", "web_search_error", map[string]interface{}{
|
||||
"tenant_id": chatManage.TenantID,
|
||||
|
||||
@@ -40,6 +40,7 @@ func NewPluginSearchParallel(
|
||||
tenantService interfaces.TenantService,
|
||||
sessionService interfaces.SessionService,
|
||||
webSearchStateService interfaces.WebSearchStateService,
|
||||
webSearchProviderRepo interfaces.WebSearchProviderRepository,
|
||||
graphRepository interfaces.RetrieveGraphRepository,
|
||||
chunkRepository interfaces.ChunkRepository,
|
||||
knowledgeRepository interfaces.KnowledgeRepository,
|
||||
@@ -54,6 +55,7 @@ func NewPluginSearchParallel(
|
||||
tenantService: tenantService,
|
||||
sessionService: sessionService,
|
||||
webSearchStateService: webSearchStateService,
|
||||
webSearchProviderRepo: webSearchProviderRepo,
|
||||
}
|
||||
|
||||
searchEntityPlugin := &PluginSearchEntity{
|
||||
|
||||
113
internal/application/service/chat_pipeline/web_fetch.go
Normal file
113
internal/application/service/chat_pipeline/web_fetch.go
Normal file
@@ -0,0 +1,113 @@
|
||||
package chatpipeline
|
||||
|
||||
import (
|
||||
"context"
|
||||
"strings"
|
||||
"sync"
|
||||
|
||||
"github.com/Tencent/WeKnora/internal/infrastructure/web_fetch"
|
||||
"github.com/Tencent/WeKnora/internal/logger"
|
||||
"github.com/Tencent/WeKnora/internal/types"
|
||||
)
|
||||
|
||||
// PluginWebFetch fetches full page content for reranked web search results.
|
||||
// It runs between CHUNK_RERANK and CHUNK_MERGE, replacing snippet content
|
||||
// with the full page text for the top N web results.
|
||||
type PluginWebFetch struct{}
|
||||
|
||||
// NewPluginWebFetch creates and registers a new PluginWebFetch instance
|
||||
func NewPluginWebFetch(eventManager *EventManager) *PluginWebFetch {
|
||||
res := &PluginWebFetch{}
|
||||
eventManager.Register(res)
|
||||
return res
|
||||
}
|
||||
|
||||
// ActivationEvents returns the event types this plugin handles
|
||||
func (p *PluginWebFetch) ActivationEvents() []types.EventType {
|
||||
return []types.EventType{types.WEB_FETCH}
|
||||
}
|
||||
|
||||
// OnEvent handles the WEB_FETCH event
|
||||
func (p *PluginWebFetch) OnEvent(
|
||||
ctx context.Context,
|
||||
eventType types.EventType,
|
||||
chatManage *types.ChatManage,
|
||||
next func() *PluginError,
|
||||
) *PluginError {
|
||||
if !chatManage.WebFetchEnabled || !chatManage.WebSearchEnabled {
|
||||
pipelineInfo(ctx, "WebFetch", "skip", map[string]any{"reason": "disabled"})
|
||||
return next()
|
||||
}
|
||||
|
||||
topN := chatManage.WebFetchTopN
|
||||
if topN <= 0 {
|
||||
topN = 3
|
||||
}
|
||||
|
||||
// Find web search results in reranked results
|
||||
var webResults []*types.SearchResult
|
||||
for _, r := range chatManage.RerankResult {
|
||||
if strings.ToLower(r.KnowledgeSource) == "web_search" {
|
||||
webResults = append(webResults, r)
|
||||
if len(webResults) >= topN {
|
||||
break
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
if len(webResults) == 0 {
|
||||
pipelineInfo(ctx, "WebFetch", "skip", map[string]any{"reason": "no_web_results"})
|
||||
return next()
|
||||
}
|
||||
|
||||
logger.Infof(ctx, "[PIPELINE] stage=WebFetch action=start count=%d", len(webResults))
|
||||
|
||||
// Fetch in parallel
|
||||
type fetchResult struct {
|
||||
idx int
|
||||
content string
|
||||
err error
|
||||
}
|
||||
results := make([]fetchResult, len(webResults))
|
||||
var wg sync.WaitGroup
|
||||
|
||||
for i, r := range webResults {
|
||||
wg.Add(1)
|
||||
go func(idx int, sr *types.SearchResult) {
|
||||
defer wg.Done()
|
||||
fetchURL := sr.ID // web search results use URL as ID
|
||||
if fetchURL == "" {
|
||||
return
|
||||
}
|
||||
content, err := web_fetch.FetchURLContent(ctx, fetchURL)
|
||||
results[idx] = fetchResult{idx: idx, content: content, err: err}
|
||||
}(i, r)
|
||||
}
|
||||
wg.Wait()
|
||||
|
||||
// Replace snippet content with fetched full content
|
||||
fetchedCount := 0
|
||||
for _, fr := range results {
|
||||
if fr.err != nil {
|
||||
logger.Warnf(ctx, "[PIPELINE] stage=WebFetch action=fetch_failed url=%s err=%v",
|
||||
webResults[fr.idx].ID, fr.err)
|
||||
continue
|
||||
}
|
||||
if fr.content == "" {
|
||||
continue
|
||||
}
|
||||
// Truncate to reasonable size for LLM context
|
||||
content := fr.content
|
||||
if len(content) > 8000 {
|
||||
content = content[:8000] + "\n...(truncated)"
|
||||
}
|
||||
webResults[fr.idx].Content = content
|
||||
fetchedCount++
|
||||
}
|
||||
|
||||
pipelineInfo(ctx, "WebFetch", "complete", map[string]any{
|
||||
"fetched": fetchedCount,
|
||||
"total": len(webResults),
|
||||
})
|
||||
return next()
|
||||
}
|
||||
@@ -36,9 +36,10 @@ type sessionService struct {
|
||||
sessionStorage llmcontext.ContextStorage // Session storage
|
||||
knowledgeService interfaces.KnowledgeService // Service for knowledge operations
|
||||
chunkService interfaces.ChunkService // Service for chunk operations
|
||||
webSearchStateRepo interfaces.WebSearchStateService // Service for web search state
|
||||
kbShareService interfaces.KBShareService // Service for KB sharing operations
|
||||
memoryService interfaces.MemoryService // Service for memory operations
|
||||
webSearchStateRepo interfaces.WebSearchStateService // Service for web search state
|
||||
webSearchProviderRepo interfaces.WebSearchProviderRepository // Repository for web search provider entities
|
||||
kbShareService interfaces.KBShareService // Service for KB sharing operations
|
||||
memoryService interfaces.MemoryService // Service for memory operations
|
||||
}
|
||||
|
||||
// NewSessionService creates a new session service instance with all required dependencies
|
||||
@@ -54,24 +55,26 @@ func NewSessionService(cfg *config.Config,
|
||||
agentService interfaces.AgentService,
|
||||
sessionStorage llmcontext.ContextStorage,
|
||||
webSearchStateRepo interfaces.WebSearchStateService,
|
||||
webSearchProviderRepo interfaces.WebSearchProviderRepository,
|
||||
kbShareService interfaces.KBShareService,
|
||||
memoryService interfaces.MemoryService,
|
||||
) interfaces.SessionService {
|
||||
return &sessionService{
|
||||
cfg: cfg,
|
||||
sessionRepo: sessionRepo,
|
||||
messageRepo: messageRepo,
|
||||
knowledgeBaseService: knowledgeBaseService,
|
||||
knowledgeService: knowledgeService,
|
||||
chunkService: chunkService,
|
||||
modelService: modelService,
|
||||
tenantService: tenantService,
|
||||
eventManager: eventManager,
|
||||
agentService: agentService,
|
||||
sessionStorage: sessionStorage,
|
||||
webSearchStateRepo: webSearchStateRepo,
|
||||
kbShareService: kbShareService,
|
||||
memoryService: memoryService,
|
||||
cfg: cfg,
|
||||
sessionRepo: sessionRepo,
|
||||
messageRepo: messageRepo,
|
||||
knowledgeBaseService: knowledgeBaseService,
|
||||
knowledgeService: knowledgeService,
|
||||
chunkService: chunkService,
|
||||
modelService: modelService,
|
||||
tenantService: tenantService,
|
||||
eventManager: eventManager,
|
||||
agentService: agentService,
|
||||
sessionStorage: sessionStorage,
|
||||
webSearchStateRepo: webSearchStateRepo,
|
||||
webSearchProviderRepo: webSearchProviderRepo,
|
||||
kbShareService: kbShareService,
|
||||
memoryService: memoryService,
|
||||
}
|
||||
}
|
||||
|
||||
|
||||
@@ -209,6 +209,7 @@ func (s *sessionService) buildAgentConfig(
|
||||
Temperature: customAgent.Config.Temperature,
|
||||
WebSearchEnabled: customAgent.Config.WebSearchEnabled && req.WebSearchEnabled,
|
||||
WebSearchMaxResults: customAgent.Config.WebSearchMaxResults,
|
||||
WebSearchProviderID: customAgent.Config.WebSearchProviderID,
|
||||
MultiTurnEnabled: customAgent.Config.MultiTurnEnabled,
|
||||
HistoryTurns: customAgent.Config.HistoryTurns,
|
||||
MCPSelectionMode: customAgent.Config.MCPSelectionMode,
|
||||
@@ -247,6 +248,13 @@ func (s *sessionService) buildAgentConfig(
|
||||
}
|
||||
}
|
||||
|
||||
// Resolve web search provider ID: agent-level > tenant default (is_default=true)
|
||||
if agentConfig.WebSearchProviderID == "" {
|
||||
if defaultProvider, err := s.webSearchProviderRepo.GetDefault(ctx, tenantInfo.ID); err == nil && defaultProvider != nil {
|
||||
agentConfig.WebSearchProviderID = defaultProvider.ID
|
||||
}
|
||||
}
|
||||
|
||||
logger.Infof(ctx, "Merged agent config from tenant %d and session %s", tenantInfo.ID, req.Session.ID)
|
||||
|
||||
// Log knowledge bases if present
|
||||
|
||||
@@ -114,6 +114,9 @@ func (s *sessionService) KnowledgeQA(
|
||||
RewritePromptSystem: s.cfg.Conversation.RewritePromptSystem,
|
||||
RewritePromptUser: s.cfg.Conversation.RewritePromptUser,
|
||||
WebSearchEnabled: req.WebSearchEnabled,
|
||||
WebSearchProviderID: s.resolveWebSearchProviderID(ctx, req, retrievalTenantID),
|
||||
WebFetchEnabled: s.resolveWebFetchEnabled(req),
|
||||
WebFetchTopN: s.resolveWebFetchTopN(req),
|
||||
TenantID: retrievalTenantID,
|
||||
Images: req.ImageURLs,
|
||||
VLMModelID: vlmModelID,
|
||||
@@ -162,6 +165,7 @@ func (s *sessionService) KnowledgeQA(
|
||||
Add(types.QUERY_UNDERSTAND).
|
||||
Add(types.CHUNK_SEARCH_PARALLEL).
|
||||
Add(types.CHUNK_RERANK).
|
||||
AddIf(req.WebSearchEnabled, types.WEB_FETCH).
|
||||
Add(types.CHUNK_MERGE).
|
||||
Add(types.FILTER_TOP_K).
|
||||
Add(types.DATA_ANALYSIS).
|
||||
@@ -798,3 +802,35 @@ func (s *sessionService) emitFallbackAnswer(ctx context.Context, chatManage *typ
|
||||
logger.Infof(ctx, "Fallback answer event emitted successfully")
|
||||
}
|
||||
}
|
||||
|
||||
// resolveWebSearchProviderID returns the web search provider ID to use for a pipeline request.
|
||||
// Priority: agent config > tenant default (is_default=true)
|
||||
func (s *sessionService) resolveWebSearchProviderID(ctx context.Context, req *types.QARequest, tenantID uint64) string {
|
||||
// 1. Agent-level override
|
||||
if req.CustomAgent != nil && req.CustomAgent.Config.WebSearchProviderID != "" {
|
||||
return req.CustomAgent.Config.WebSearchProviderID
|
||||
}
|
||||
// 2. Tenant default
|
||||
if s.webSearchProviderRepo != nil {
|
||||
if defaultProvider, err := s.webSearchProviderRepo.GetDefault(ctx, tenantID); err == nil && defaultProvider != nil {
|
||||
return defaultProvider.ID
|
||||
}
|
||||
}
|
||||
return ""
|
||||
}
|
||||
|
||||
// resolveWebFetchEnabled returns whether auto web fetch is enabled for this request.
|
||||
func (s *sessionService) resolveWebFetchEnabled(req *types.QARequest) bool {
|
||||
if req.CustomAgent != nil {
|
||||
return req.CustomAgent.Config.WebFetchEnabled
|
||||
}
|
||||
return false
|
||||
}
|
||||
|
||||
// resolveWebFetchTopN returns how many pages to fetch after rerank.
|
||||
func (s *sessionService) resolveWebFetchTopN(req *types.QARequest) int {
|
||||
if req.CustomAgent != nil && req.CustomAgent.Config.WebFetchTopN > 0 {
|
||||
return req.CustomAgent.Config.WebFetchTopN
|
||||
}
|
||||
return 3
|
||||
}
|
||||
|
||||
@@ -7,7 +7,7 @@ import (
|
||||
"strings"
|
||||
"time"
|
||||
|
||||
"github.com/Tencent/WeKnora/internal/application/service/web_search"
|
||||
infra_web_search "github.com/Tencent/WeKnora/internal/infrastructure/web_search"
|
||||
"github.com/Tencent/WeKnora/internal/config"
|
||||
"github.com/Tencent/WeKnora/internal/logger"
|
||||
"github.com/Tencent/WeKnora/internal/searchutil"
|
||||
@@ -15,10 +15,117 @@ import (
|
||||
"github.com/Tencent/WeKnora/internal/types/interfaces"
|
||||
)
|
||||
|
||||
// WebSearchService provides web search functionality
|
||||
// WebSearchService provides web search functionality.
|
||||
// It resolves provider configurations from the database and creates provider
|
||||
// instances on-demand via the infrastructure registry.
|
||||
type WebSearchService struct {
|
||||
providers map[string]interfaces.WebSearchProvider
|
||||
timeout int
|
||||
registry *infra_web_search.Registry
|
||||
providerRepo interfaces.WebSearchProviderRepository
|
||||
timeout int
|
||||
}
|
||||
|
||||
// NewWebSearchService creates a new web search service.
|
||||
// The registry holds provider type factories; the providerRepo loads tenant-specific configurations.
|
||||
func NewWebSearchService(
|
||||
cfg *config.Config,
|
||||
registry *infra_web_search.Registry,
|
||||
providerRepo interfaces.WebSearchProviderRepository,
|
||||
) (interfaces.WebSearchService, error) {
|
||||
timeout := 10 // default timeout in seconds
|
||||
if cfg.WebSearch != nil && cfg.WebSearch.Timeout > 0 {
|
||||
timeout = cfg.WebSearch.Timeout
|
||||
}
|
||||
|
||||
return &WebSearchService{
|
||||
registry: registry,
|
||||
providerRepo: providerRepo,
|
||||
timeout: timeout,
|
||||
}, nil
|
||||
}
|
||||
|
||||
// Search performs web search using the provider entity identified by providerID.
|
||||
// If providerID is empty, it falls back to the deprecated config.Provider field for backward compatibility.
|
||||
func (s *WebSearchService) Search(
|
||||
ctx context.Context,
|
||||
providerID string,
|
||||
config *types.WebSearchConfig,
|
||||
query string,
|
||||
) ([]*types.WebSearchResult, error) {
|
||||
if config == nil {
|
||||
return nil, fmt.Errorf("web search config is required")
|
||||
}
|
||||
|
||||
// Resolve the provider
|
||||
searchProvider, err := s.resolveProvider(ctx, providerID, config)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
|
||||
// Set timeout
|
||||
timeout := time.Duration(s.timeout) * time.Second
|
||||
if timeout == 0 {
|
||||
timeout = 10 * time.Second
|
||||
}
|
||||
|
||||
ctx, cancel := context.WithTimeout(ctx, timeout)
|
||||
defer cancel()
|
||||
|
||||
// Perform search
|
||||
results, err := searchProvider.Search(ctx, query, config.MaxResults, config.IncludeDate)
|
||||
if err != nil {
|
||||
return nil, fmt.Errorf("web search failed: %w", err)
|
||||
}
|
||||
|
||||
// Apply blacklist filtering
|
||||
results = s.filterBlacklist(results, config.Blacklist)
|
||||
|
||||
return results, nil
|
||||
}
|
||||
|
||||
// resolveProvider resolves a WebSearchProvider instance from either:
|
||||
// 1. A provider entity ID (new path) — loads from DB, creates via registry
|
||||
// 2. The deprecated config.Provider field (backward compatibility) — creates with empty params
|
||||
func (s *WebSearchService) resolveProvider(
|
||||
ctx context.Context,
|
||||
providerID string,
|
||||
cfg *types.WebSearchConfig,
|
||||
) (interfaces.WebSearchProvider, error) {
|
||||
// New path: load provider entity from DB
|
||||
if providerID != "" {
|
||||
tenantID, ok := types.TenantIDFromContext(ctx)
|
||||
if !ok {
|
||||
return nil, fmt.Errorf("tenant ID not found in context")
|
||||
}
|
||||
|
||||
entity, err := s.providerRepo.GetByID(ctx, tenantID, providerID)
|
||||
if err != nil {
|
||||
return nil, fmt.Errorf("failed to load web search provider %s: %w", providerID, err)
|
||||
}
|
||||
if entity == nil {
|
||||
return nil, fmt.Errorf("web search provider not found: %s", providerID)
|
||||
}
|
||||
|
||||
provider, err := s.registry.CreateProvider(string(entity.Provider), entity.Parameters)
|
||||
if err != nil {
|
||||
return nil, fmt.Errorf("failed to create provider %s (%s): %w", entity.Name, entity.Provider, err)
|
||||
}
|
||||
return provider, nil
|
||||
}
|
||||
|
||||
// Backward compatibility: use the deprecated config.Provider field
|
||||
if cfg.Provider != "" {
|
||||
logger.Warnf(ctx, "Using deprecated WebSearchConfig.Provider field: %s. Please migrate to WebSearchProviderEntity.", cfg.Provider)
|
||||
params := types.WebSearchProviderParameters{
|
||||
APIKey: cfg.APIKey,
|
||||
}
|
||||
provider, err := s.registry.CreateProvider(cfg.Provider, params)
|
||||
if err != nil {
|
||||
return nil, fmt.Errorf("web search provider %s is not available: %w", cfg.Provider, err)
|
||||
}
|
||||
return provider, nil
|
||||
}
|
||||
|
||||
return nil, fmt.Errorf("no web search provider configured")
|
||||
}
|
||||
|
||||
// CompressWithRAG performs RAG-based compression using a temporary, hidden knowledge base.
|
||||
@@ -242,72 +349,6 @@ func stripMarker(content string) string {
|
||||
return content
|
||||
}
|
||||
|
||||
// Search performs web search using the specified provider
|
||||
// This method implements the interface expected by PluginSearch
|
||||
func (s *WebSearchService) Search(
|
||||
ctx context.Context,
|
||||
config *types.WebSearchConfig,
|
||||
query string,
|
||||
) ([]*types.WebSearchResult, error) {
|
||||
if config == nil {
|
||||
return nil, fmt.Errorf("web search config is required")
|
||||
}
|
||||
|
||||
provider, ok := s.providers[config.Provider]
|
||||
if !ok {
|
||||
return nil, fmt.Errorf("web search provider %s is not available", config.Provider)
|
||||
}
|
||||
|
||||
// Set timeout
|
||||
timeout := time.Duration(s.timeout) * time.Second
|
||||
if timeout == 0 {
|
||||
timeout = 10 * time.Second
|
||||
}
|
||||
|
||||
ctx, cancel := context.WithTimeout(ctx, timeout)
|
||||
defer cancel()
|
||||
|
||||
// Perform search
|
||||
results, err := provider.Search(ctx, query, config.MaxResults, config.IncludeDate)
|
||||
if err != nil {
|
||||
return nil, fmt.Errorf("web search failed: %w", err)
|
||||
}
|
||||
|
||||
// Apply blacklist filtering
|
||||
results = s.filterBlacklist(results, config.Blacklist)
|
||||
|
||||
// Apply compression if needed
|
||||
if config.CompressionMethod != "none" && config.CompressionMethod != "" {
|
||||
// Compression will be handled later in the integration layer
|
||||
// For now, we just return the results
|
||||
}
|
||||
|
||||
return results, nil
|
||||
}
|
||||
|
||||
// NewWebSearchService creates a new web search service
|
||||
func NewWebSearchService(cfg *config.Config, registry *web_search.Registry) (interfaces.WebSearchService, error) {
|
||||
timeout := 10 // default timeout
|
||||
if cfg.WebSearch != nil && cfg.WebSearch.Timeout > 0 {
|
||||
timeout = cfg.WebSearch.Timeout
|
||||
}
|
||||
|
||||
// Create all registered providers
|
||||
providers, err := registry.CreateAllProviders()
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
|
||||
for id := range providers {
|
||||
logger.Infof(context.Background(), "Initialized web search provider: %s", id)
|
||||
}
|
||||
|
||||
return &WebSearchService{
|
||||
providers: providers,
|
||||
timeout: timeout,
|
||||
}, nil
|
||||
}
|
||||
|
||||
// filterBlacklist filters results based on blacklist rules
|
||||
func (s *WebSearchService) filterBlacklist(
|
||||
results []*types.WebSearchResult,
|
||||
|
||||
@@ -1,149 +0,0 @@
|
||||
package web_search
|
||||
|
||||
import (
|
||||
"context"
|
||||
"encoding/json"
|
||||
"net/http"
|
||||
"net/http/httptest"
|
||||
"os"
|
||||
"testing"
|
||||
"time"
|
||||
|
||||
"github.com/stretchr/testify/assert"
|
||||
"github.com/stretchr/testify/require"
|
||||
)
|
||||
|
||||
func setBingEnv(apiKey string) {
|
||||
os.Setenv("BING_SEARCH_API_KEY", apiKey)
|
||||
}
|
||||
|
||||
func unsetBingEnv() {
|
||||
os.Unsetenv("BING_SEARCH_API_KEY")
|
||||
}
|
||||
|
||||
func TestNewBingProvider(t *testing.T) {
|
||||
setBingEnv("test-api-key")
|
||||
defer unsetBingEnv()
|
||||
|
||||
provider, err := NewBingProvider()
|
||||
require.NoError(t, err)
|
||||
assert.NotNil(t, provider)
|
||||
}
|
||||
|
||||
func TestBingProvider_Search(t *testing.T) {
|
||||
mockResponse := map[string]interface{}{
|
||||
"_type": "SearchResponse",
|
||||
"webPages": map[string]interface{}{
|
||||
"webSearchUrl": "https://www.bing.com/search?q=test",
|
||||
"totalEstimatedMatches": 1000,
|
||||
"value": []map[string]interface{}{
|
||||
{
|
||||
"id": "result-1",
|
||||
"name": "Test Result 1",
|
||||
"url": "https://example.com/1",
|
||||
"isFamilyFriendly": true,
|
||||
"displayUrl": "example.com/1",
|
||||
"snippet": "This is a test snippet 1",
|
||||
"dateLastCrawled": time.Now().Format(time.RFC3339),
|
||||
},
|
||||
{
|
||||
"id": "result-2",
|
||||
"name": "Test Result 2",
|
||||
"url": "https://example.com/2",
|
||||
"isFamilyFriendly": true,
|
||||
"displayUrl": "example.com/2",
|
||||
"snippet": "This is a test snippet 2",
|
||||
"dateLastCrawled": time.Now().Format(time.RFC3339),
|
||||
},
|
||||
},
|
||||
},
|
||||
}
|
||||
|
||||
server := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
|
||||
if r.Method != "GET" {
|
||||
w.WriteHeader(http.StatusMethodNotAllowed)
|
||||
return
|
||||
}
|
||||
|
||||
if r.Header.Get("Ocp-Apim-Subscription-Key") != "test-api-key" {
|
||||
w.WriteHeader(http.StatusUnauthorized)
|
||||
return
|
||||
}
|
||||
|
||||
query := r.URL.Query().Get("q")
|
||||
if query == "" {
|
||||
w.WriteHeader(http.StatusBadRequest)
|
||||
return
|
||||
}
|
||||
|
||||
w.Header().Set("Content-Type", "application/json")
|
||||
json.NewEncoder(w).Encode(mockResponse)
|
||||
}))
|
||||
defer server.Close()
|
||||
|
||||
provider := &BingProvider{
|
||||
client: server.Client(),
|
||||
baseURL: server.URL,
|
||||
apiKey: "test-api-key",
|
||||
}
|
||||
|
||||
t.Run("Successful search", func(t *testing.T) {
|
||||
ctx := context.Background()
|
||||
results, err := provider.Search(ctx, "test query", 10, true)
|
||||
require.NoError(t, err)
|
||||
assert.Len(t, results, 2)
|
||||
assert.Equal(t, "Test Result 1", results[0].Title)
|
||||
assert.Equal(t, "https://example.com/1", results[0].URL)
|
||||
assert.Equal(t, "bing", results[0].Source)
|
||||
})
|
||||
|
||||
t.Run("Empty query", func(t *testing.T) {
|
||||
ctx := context.Background()
|
||||
results, err := provider.Search(ctx, "", 10, true)
|
||||
assert.Error(t, err)
|
||||
assert.Nil(t, results)
|
||||
assert.Contains(t, err.Error(), "query is empty")
|
||||
})
|
||||
}
|
||||
|
||||
func TestBingProvider_Search_Error(t *testing.T) {
|
||||
server := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
|
||||
w.WriteHeader(http.StatusInternalServerError)
|
||||
}))
|
||||
defer server.Close()
|
||||
|
||||
provider := &BingProvider{
|
||||
client: server.Client(),
|
||||
baseURL: server.URL,
|
||||
apiKey: "test-api-key",
|
||||
}
|
||||
|
||||
t.Run("Server error", func(t *testing.T) {
|
||||
ctx := context.Background()
|
||||
results, err := provider.Search(ctx, "test query", 10, true)
|
||||
assert.Error(t, err)
|
||||
assert.Nil(t, results)
|
||||
})
|
||||
}
|
||||
|
||||
func TestBingProvider_Search_InvalidJSON(t *testing.T) {
|
||||
server := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
|
||||
w.Header().Set("Content-Type", "application/json")
|
||||
w.Write([]byte("invalid json"))
|
||||
}))
|
||||
defer server.Close()
|
||||
|
||||
provider := &BingProvider{
|
||||
client: server.Client(),
|
||||
baseURL: server.URL,
|
||||
apiKey: "test-api-key",
|
||||
}
|
||||
|
||||
t.Run("Invalid JSON response", func(t *testing.T) {
|
||||
ctx := context.Background()
|
||||
results, err := provider.Search(ctx, "test query", 10, true)
|
||||
assert.Error(t, err)
|
||||
assert.Nil(t, results)
|
||||
assert.Contains(t, err.Error(), "failed to unmarshal response")
|
||||
})
|
||||
}
|
||||
@@ -1,280 +0,0 @@
|
||||
package web_search
|
||||
|
||||
import (
|
||||
"context"
|
||||
"encoding/json"
|
||||
"net/http"
|
||||
"net/http/httptest"
|
||||
"net/url"
|
||||
"strings"
|
||||
"testing"
|
||||
"time"
|
||||
)
|
||||
|
||||
// testRoundTripper rewrites outgoing requests that target DuckDuckGo hosts
|
||||
// to the provided test server, preserving path and query.
|
||||
type testRoundTripper struct {
|
||||
base *url.URL
|
||||
next http.RoundTripper
|
||||
}
|
||||
|
||||
func (t *testRoundTripper) RoundTrip(req *http.Request) (*http.Response, error) {
|
||||
// Only rewrite requests to duckduckgo hosts used by the provider
|
||||
if req.URL.Host == "html.duckduckgo.com" || req.URL.Host == "api.duckduckgo.com" {
|
||||
cloned := *req
|
||||
u := *req.URL
|
||||
u.Scheme = t.base.Scheme
|
||||
u.Host = t.base.Host
|
||||
// Keep original path; our test server handlers should register for the same paths.
|
||||
cloned.URL = &u
|
||||
req = &cloned
|
||||
}
|
||||
return t.next.RoundTrip(req)
|
||||
}
|
||||
|
||||
func newTestClient(ts *httptest.Server) *http.Client {
|
||||
baseURL, _ := url.Parse(ts.URL)
|
||||
return &http.Client{
|
||||
Timeout: 5 * time.Second,
|
||||
Transport: &testRoundTripper{
|
||||
base: baseURL,
|
||||
next: http.DefaultTransport,
|
||||
},
|
||||
}
|
||||
}
|
||||
|
||||
func TestDuckDuckGoProvider_Name(t *testing.T) {
|
||||
p, _ := NewDuckDuckGoProvider()
|
||||
if p.Name() != "duckduckgo" {
|
||||
t.Fatalf("expected provider name duckduckgo, got %s", p.Name())
|
||||
}
|
||||
}
|
||||
|
||||
func TestDuckDuckGoProvider(t *testing.T) {
|
||||
// Minimal HTML page with two results, matching selectors used in searchHTML
|
||||
html := `
|
||||
<html>
|
||||
<body>
|
||||
<div class="web-result">
|
||||
<a class="result__a" href="https://duckduckgo.com/l/?uddg=https%3A%2F%2Fexample.com%2Fpage1&rut=">Example One</a>
|
||||
<div class="result__snippet">Snippet one</div>
|
||||
</div>
|
||||
<div class="web-result">
|
||||
<a class="result__a" href="//duckduckgo.com/l/?uddg=https%3A%2F%2Fexample.org%2Fpage2&rut=">Example Two</a>
|
||||
<div class="result__snippet">Snippet two</div>
|
||||
</div>
|
||||
</body>
|
||||
</html>`
|
||||
|
||||
ts := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
|
||||
// Provider requests GET https://html.duckduckgo.com/html/?q=...&kl=...
|
||||
if r.URL.Path == "/html/" {
|
||||
w.WriteHeader(http.StatusOK)
|
||||
_, _ = w.Write([]byte(html))
|
||||
return
|
||||
}
|
||||
t.Fatalf("unexpected request path: %s", r.URL.Path)
|
||||
}))
|
||||
defer ts.Close()
|
||||
|
||||
// Build provider and inject our test client
|
||||
prov, _ := NewDuckDuckGoProvider()
|
||||
dp := prov.(*DuckDuckGoProvider)
|
||||
if dp == nil {
|
||||
t.Fatalf("failed to build provider")
|
||||
}
|
||||
dp.client = newTestClient(ts)
|
||||
|
||||
ctx := context.Background()
|
||||
results, err := dp.Search(ctx, "weknora", 5, false)
|
||||
if err != nil {
|
||||
t.Fatalf("unexpected error: %v", err)
|
||||
}
|
||||
if len(results) != 2 {
|
||||
t.Fatalf("expected 2 results, got %d", len(results))
|
||||
}
|
||||
if results[0].Title != "Example One" || !strings.HasPrefix(results[0].URL, "https://example.com/") ||
|
||||
results[0].Snippet != "Snippet one" {
|
||||
t.Fatalf("unexpected first result: %+v", results[0])
|
||||
}
|
||||
if results[1].Title != "Example Two" || !strings.HasPrefix(results[1].URL, "https://example.org/") ||
|
||||
results[1].Snippet != "Snippet two" {
|
||||
t.Fatalf("unexpected second result: %+v", results[1])
|
||||
}
|
||||
}
|
||||
|
||||
func TestDuckDuckGoProvider_Fallback(t *testing.T) {
|
||||
// Simulate HTML returning non-OK to force API fallback, then a minimal API JSON
|
||||
apiResp := struct {
|
||||
AbstractText string `json:"AbstractText"`
|
||||
AbstractURL string `json:"AbstractURL"`
|
||||
Heading string `json:"Heading"`
|
||||
Results []struct {
|
||||
FirstURL string `json:"FirstURL"`
|
||||
Text string `json:"Text"`
|
||||
} `json:"Results"`
|
||||
}{
|
||||
AbstractText: "Abstract snippet",
|
||||
AbstractURL: "https://example.com/abstract",
|
||||
Heading: "Abstract Heading",
|
||||
Results: []struct {
|
||||
FirstURL string `json:"FirstURL"`
|
||||
Text string `json:"Text"`
|
||||
}{
|
||||
{FirstURL: "https://example.net/x", Text: "Title X - Detail X"},
|
||||
},
|
||||
}
|
||||
|
||||
ts := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
|
||||
switch r.URL.Path {
|
||||
case "/html/":
|
||||
// Force fallback by returning 500
|
||||
w.WriteHeader(http.StatusInternalServerError)
|
||||
default:
|
||||
// API endpoint path "/"
|
||||
w.Header().Set("Content-Type", "application/json")
|
||||
enc := json.NewEncoder(w)
|
||||
_ = enc.Encode(apiResp)
|
||||
}
|
||||
}))
|
||||
defer ts.Close()
|
||||
|
||||
prov, _ := NewDuckDuckGoProvider()
|
||||
dp := prov.(*DuckDuckGoProvider)
|
||||
if dp == nil {
|
||||
t.Fatalf("failed to build provider")
|
||||
}
|
||||
dp.client = newTestClient(ts)
|
||||
|
||||
ctx := context.Background()
|
||||
results, err := dp.Search(ctx, "weknora", 3, false)
|
||||
if err != nil {
|
||||
t.Fatalf("unexpected error: %v", err)
|
||||
}
|
||||
if len(results) == 0 {
|
||||
t.Fatalf("expected some results from API fallback")
|
||||
}
|
||||
if results[0].URL != "https://example.com/abstract" || results[0].Title != "Abstract Heading" {
|
||||
t.Fatalf("unexpected first API result: %+v", results[0])
|
||||
}
|
||||
}
|
||||
|
||||
// TestDuckDuckGoProvider_Search_Real tests the DuckDuckGo provider against the real DuckDuckGo service.
|
||||
// This is an integration test that requires network connectivity.
|
||||
// Run with: go test -v -run TestDuckDuckGoProvider_Search_Real ./internal/application/service/web_search
|
||||
func TestDuckDuckGoProvider_Search_Real(t *testing.T) {
|
||||
// Skip if running in CI without network access (optional check)
|
||||
if testing.Short() {
|
||||
t.Skip("Skipping real DuckDuckGo integration test in short mode")
|
||||
}
|
||||
|
||||
ctx := context.Background()
|
||||
provider, err := NewDuckDuckGoProvider()
|
||||
if err != nil {
|
||||
t.Fatalf("Failed to create DuckDuckGo provider: %v", err)
|
||||
}
|
||||
if provider == nil {
|
||||
t.Fatalf("failed to build provider")
|
||||
}
|
||||
|
||||
// Test with a simple, general query that should return results
|
||||
query := "Go programming language"
|
||||
maxResults := 5
|
||||
|
||||
results, err := provider.Search(ctx, query, maxResults, false)
|
||||
if err != nil {
|
||||
t.Fatalf("Search failed: %v", err)
|
||||
}
|
||||
|
||||
// Verify we got results
|
||||
if len(results) == 0 {
|
||||
t.Fatal("Expected at least one search result, got 0")
|
||||
}
|
||||
|
||||
t.Logf("Received %d results for query: %s", len(results), query)
|
||||
|
||||
// Verify result structure
|
||||
for i, result := range results {
|
||||
if result == nil {
|
||||
t.Fatalf("Result[%d]: is nil", i)
|
||||
}
|
||||
if result.Title == "" {
|
||||
t.Errorf("Result[%d]: Title is empty", i)
|
||||
}
|
||||
if result.URL == "" {
|
||||
t.Errorf("Result[%d]: URL is empty", i)
|
||||
}
|
||||
if !strings.HasPrefix(result.URL, "http://") && !strings.HasPrefix(result.URL, "https://") {
|
||||
t.Errorf("Result[%d]: URL is not valid (should start with http:// or https://): %s", i, result.URL)
|
||||
}
|
||||
if result.Source != "duckduckgo" {
|
||||
t.Errorf("Result[%d]: Source should be 'duckduckgo', got '%s'", i, result.Source)
|
||||
}
|
||||
|
||||
t.Logf("Result[%d]: Title=%s, URL=%s, Snippet=%s", i, result.Title, result.URL, result.Snippet)
|
||||
}
|
||||
|
||||
// Verify we don't exceed maxResults
|
||||
if len(results) > maxResults {
|
||||
t.Errorf("Got %d results, expected at most %d", len(results), maxResults)
|
||||
}
|
||||
|
||||
// Test with maxResults limit
|
||||
limitedResults, err := provider.Search(ctx, query, 2, false)
|
||||
if err != nil {
|
||||
t.Fatalf("Search with limit failed: %v", err)
|
||||
}
|
||||
if len(limitedResults) > 2 {
|
||||
t.Errorf("Got %d results with maxResults=2, expected at most 2", len(limitedResults))
|
||||
}
|
||||
}
|
||||
|
||||
// TestDuckDuckGo_SearchChinese tests the DuckDuckGo provider with Chinese query.
|
||||
// This verifies the Chinese language parameter (kl=cn-zh) works correctly.
|
||||
func TestDuckDuckGo_SearchChinese(t *testing.T) {
|
||||
if testing.Short() {
|
||||
t.Skip("Skipping real DuckDuckGo integration test in short mode")
|
||||
}
|
||||
|
||||
ctx := context.Background()
|
||||
provider, err := NewDuckDuckGoProvider()
|
||||
if err != nil {
|
||||
t.Fatalf("Failed to create DuckDuckGo provider: %v", err)
|
||||
}
|
||||
if provider == nil {
|
||||
t.Fatalf("failed to build provider")
|
||||
}
|
||||
|
||||
// Test with a Chinese query
|
||||
query := "WeKnora 企业级RAG框架 介绍 文档"
|
||||
maxResults := 3
|
||||
|
||||
results, err := provider.Search(ctx, query, maxResults, false)
|
||||
if err != nil {
|
||||
t.Fatalf("Search failed: %v", err)
|
||||
}
|
||||
|
||||
if len(results) == 0 {
|
||||
t.Log("Warning: No results returned for Chinese query, but this might be expected")
|
||||
return
|
||||
}
|
||||
|
||||
t.Logf("Received %d results for Chinese query: %s", len(results), query)
|
||||
|
||||
// Verify result structure
|
||||
for i, result := range results {
|
||||
if result == nil {
|
||||
t.Fatalf("Result[%d]: is nil", i)
|
||||
}
|
||||
if result.Title == "" {
|
||||
t.Errorf("Result[%d]: Title is empty", i)
|
||||
}
|
||||
if result.URL == "" {
|
||||
t.Errorf("Result[%d]: URL is empty", i)
|
||||
}
|
||||
if result.Source != "duckduckgo" {
|
||||
t.Errorf("Result[%d]: Source should be 'duckduckgo', got '%s'", i, result.Source)
|
||||
}
|
||||
t.Logf("Result[%d]: Title=%s, URL=%s", i, result.Title, result.URL)
|
||||
}
|
||||
}
|
||||
@@ -1,276 +0,0 @@
|
||||
package web_search
|
||||
|
||||
import (
|
||||
"context"
|
||||
"encoding/json"
|
||||
"fmt"
|
||||
"net/http"
|
||||
"net/http/httptest"
|
||||
"os"
|
||||
"strings"
|
||||
"testing"
|
||||
)
|
||||
|
||||
func setGoogleEnv(apiURL string) {
|
||||
os.Setenv("GOOGLE_SEARCH_API_URL", apiURL)
|
||||
}
|
||||
|
||||
func unsetGoogleEnv() {
|
||||
os.Unsetenv("GOOGLE_SEARCH_API_URL")
|
||||
}
|
||||
|
||||
func TestNewGoogleProvider(t *testing.T) {
|
||||
testCases := []struct {
|
||||
name string
|
||||
apiURL string
|
||||
expected error
|
||||
}{
|
||||
{
|
||||
name: "valid config",
|
||||
apiURL: "https://customsearch.googleapis.com/customsearch/v1?api_key=test&engine_id=test",
|
||||
expected: nil,
|
||||
},
|
||||
{
|
||||
name: "missing engine id",
|
||||
apiURL: "https://customsearch.googleapis.com/customsearch/v1?api_key=test",
|
||||
expected: fmt.Errorf("engine_id is empty"),
|
||||
},
|
||||
{
|
||||
name: "missing api key",
|
||||
apiURL: "https://customsearch.googleapis.com/customsearch/v1?engine_id=test",
|
||||
expected: fmt.Errorf("api_key is empty"),
|
||||
},
|
||||
}
|
||||
for _, tc := range testCases {
|
||||
t.Run(tc.name, func(t *testing.T) {
|
||||
setGoogleEnv(tc.apiURL)
|
||||
defer unsetGoogleEnv()
|
||||
_, err := NewGoogleProvider()
|
||||
|
||||
if tc.expected == nil {
|
||||
if err != nil {
|
||||
t.Fatalf("expected no error, got %v", err)
|
||||
}
|
||||
} else {
|
||||
if err == nil {
|
||||
t.Fatalf("expected error %v, got nil", tc.expected)
|
||||
}
|
||||
if !strings.Contains(err.Error(), tc.expected.Error()) {
|
||||
t.Fatalf("expected error %v, got %v", tc.expected, err)
|
||||
}
|
||||
}
|
||||
})
|
||||
}
|
||||
}
|
||||
|
||||
func TestGoogleProvider_Name(t *testing.T) {
|
||||
setGoogleEnv("https://customsearch.googleapis.com/customsearch/v1?api_key=test&engine_id=test")
|
||||
defer unsetGoogleEnv()
|
||||
p, err := NewGoogleProvider()
|
||||
if err != nil {
|
||||
t.Fatalf("failed to create Google provider: %v", err)
|
||||
}
|
||||
if p.Name() != "google" {
|
||||
t.Fatalf("expected provider name google, got %s", p.Name())
|
||||
}
|
||||
}
|
||||
|
||||
func TestGoogleProvider_Search(t *testing.T) {
|
||||
mockResponse := map[string]interface{}{
|
||||
"items": []map[string]interface{}{
|
||||
{
|
||||
"title": "Example Search Result One",
|
||||
"link": "https://example.com/page1",
|
||||
"snippet": "This is the first search result snippet describing the content.",
|
||||
},
|
||||
{
|
||||
"title": "Example Search Result Two",
|
||||
"link": "https://example.org/page2",
|
||||
"snippet": "This is the second search result snippet with more details.",
|
||||
},
|
||||
},
|
||||
}
|
||||
|
||||
ts := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
|
||||
if r.URL.Path != "/customsearch/v1" {
|
||||
t.Fatalf("unexpected request path: %s", r.URL.Path)
|
||||
}
|
||||
|
||||
query := r.URL.Query().Get("q")
|
||||
if query != "weknora" {
|
||||
t.Fatalf("unexpected query: %s", query)
|
||||
}
|
||||
|
||||
cx := r.URL.Query().Get("cx")
|
||||
if cx != "test-engine-id" {
|
||||
t.Fatalf("unexpected engine ID: %s", cx)
|
||||
}
|
||||
|
||||
num := r.URL.Query().Get("num")
|
||||
if num != "5" {
|
||||
t.Fatalf("unexpected num parameter: %s", num)
|
||||
}
|
||||
|
||||
w.Header().Set("Content-Type", "application/json")
|
||||
w.WriteHeader(http.StatusOK)
|
||||
enc := json.NewEncoder(w)
|
||||
_ = enc.Encode(mockResponse)
|
||||
}))
|
||||
defer ts.Close()
|
||||
|
||||
setGoogleEnv(fmt.Sprintf("%s/customsearch/v1?api_key=test-key&engine_id=test-engine-id", ts.URL))
|
||||
defer unsetGoogleEnv()
|
||||
prov, err := NewGoogleProvider()
|
||||
if err != nil {
|
||||
t.Fatalf("failed to create Google provider: %v", err)
|
||||
}
|
||||
|
||||
gp := prov.(*GoogleProvider)
|
||||
if gp == nil {
|
||||
t.Fatalf("failed to cast to GoogleProvider")
|
||||
}
|
||||
|
||||
ctx := context.Background()
|
||||
results, err := prov.Search(ctx, "weknora", 5, false)
|
||||
if err != nil {
|
||||
t.Fatalf("unexpected error: %v", err)
|
||||
}
|
||||
|
||||
if len(results) != 2 {
|
||||
t.Fatalf("expected 2 results, got %d", len(results))
|
||||
}
|
||||
|
||||
if results[0].Title != "Example Search Result One" ||
|
||||
results[0].URL != "https://example.com/page1" ||
|
||||
results[0].Snippet != "This is the first search result snippet describing the content." ||
|
||||
results[0].Source != "google" {
|
||||
t.Fatalf("unexpected first result: %+v", results[0])
|
||||
}
|
||||
|
||||
if results[1].Title != "Example Search Result Two" ||
|
||||
results[1].URL != "https://example.org/page2" ||
|
||||
results[1].Snippet != "This is the second search result snippet with more details." ||
|
||||
results[1].Source != "google" {
|
||||
t.Fatalf("unexpected second result: %+v", results[1])
|
||||
}
|
||||
}
|
||||
|
||||
func TestGoogleProvider_Search_EmptyQuery(t *testing.T) {
|
||||
setGoogleEnv("https://customsearch.googleapis.com/customsearch/v1?api_key=test&engine_id=test")
|
||||
defer unsetGoogleEnv()
|
||||
prov, err := NewGoogleProvider()
|
||||
if err != nil {
|
||||
t.Fatalf("failed to create Google provider: %v", err)
|
||||
}
|
||||
|
||||
ctx := context.Background()
|
||||
results, err := prov.Search(ctx, "", 5, false)
|
||||
if err == nil {
|
||||
t.Fatal("expected error for empty query, got nil")
|
||||
}
|
||||
if !strings.Contains(err.Error(), "query is empty") {
|
||||
t.Fatalf("expected 'query is empty' error, got: %v", err)
|
||||
}
|
||||
if results != nil {
|
||||
t.Fatalf("expected nil results for empty query, got: %v", results)
|
||||
}
|
||||
}
|
||||
|
||||
func TestGoogleProvider_Search_NoResults(t *testing.T) {
|
||||
mockResponse := map[string]interface{}{
|
||||
"items": []map[string]interface{}{},
|
||||
}
|
||||
|
||||
ts := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
|
||||
w.Header().Set("Content-Type", "application/json")
|
||||
w.WriteHeader(http.StatusOK)
|
||||
enc := json.NewEncoder(w)
|
||||
_ = enc.Encode(mockResponse)
|
||||
}))
|
||||
defer ts.Close()
|
||||
|
||||
setGoogleEnv(fmt.Sprintf("%s/customsearch/v1?api_key=test-key&engine_id=test-engine-id", ts.URL))
|
||||
defer unsetGoogleEnv()
|
||||
prov, err := NewGoogleProvider()
|
||||
if err != nil {
|
||||
t.Fatalf("failed to create Google provider: %v", err)
|
||||
}
|
||||
|
||||
ctx := context.Background()
|
||||
results, err := prov.Search(ctx, "nonexistent", 5, false)
|
||||
if err != nil {
|
||||
t.Fatalf("unexpected error: %v", err)
|
||||
}
|
||||
|
||||
if len(results) != 0 {
|
||||
t.Fatalf("expected 0 results, got %d", len(results))
|
||||
}
|
||||
}
|
||||
|
||||
func TestGoogleProvider_Search_ErrorResponse(t *testing.T) {
|
||||
ts := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
|
||||
w.WriteHeader(http.StatusInternalServerError)
|
||||
w.Write([]byte("Internal Server Error"))
|
||||
}))
|
||||
defer ts.Close()
|
||||
|
||||
setGoogleEnv(fmt.Sprintf("%s/customsearch/v1?api_key=test-key&engine_id=test-engine-id", ts.URL))
|
||||
defer unsetGoogleEnv()
|
||||
prov, err := NewGoogleProvider()
|
||||
if err != nil {
|
||||
t.Fatalf("failed to create Google provider: %v", err)
|
||||
}
|
||||
|
||||
ctx := context.Background()
|
||||
results, err := prov.Search(ctx, "test", 5, false)
|
||||
if err == nil {
|
||||
t.Fatal("expected error for server error response, got nil")
|
||||
}
|
||||
if results != nil {
|
||||
t.Fatalf("expected nil results for error response, got: %v", results)
|
||||
}
|
||||
}
|
||||
|
||||
func TestGoogleProvider_Search_MaxResults(t *testing.T) {
|
||||
mockResponse := map[string]interface{}{
|
||||
"items": []map[string]interface{}{
|
||||
{"title": "Result 1", "link": "https://example.com/1", "snippet": "Snippet 1"},
|
||||
{"title": "Result 2", "link": "https://example.com/2", "snippet": "Snippet 2"},
|
||||
{"title": "Result 3", "link": "https://example.com/3", "snippet": "Snippet 3"},
|
||||
},
|
||||
}
|
||||
|
||||
ts := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
|
||||
num := r.URL.Query().Get("num")
|
||||
if num != "2" {
|
||||
t.Fatalf("expected num=2, got %s", num)
|
||||
}
|
||||
|
||||
w.Header().Set("Content-Type", "application/json")
|
||||
w.WriteHeader(http.StatusOK)
|
||||
enc := json.NewEncoder(w)
|
||||
_ = enc.Encode(mockResponse)
|
||||
}))
|
||||
defer ts.Close()
|
||||
|
||||
setGoogleEnv(fmt.Sprintf("%s/customsearch/v1?api_key=test-key&engine_id=test-engine-id", ts.URL))
|
||||
defer unsetGoogleEnv()
|
||||
prov, err := NewGoogleProvider()
|
||||
if err != nil {
|
||||
t.Fatalf("failed to create Google provider: %v", err)
|
||||
}
|
||||
|
||||
ctx := context.Background()
|
||||
results, err := prov.Search(ctx, "test", 2, false)
|
||||
if err != nil {
|
||||
t.Fatalf("unexpected error: %v", err)
|
||||
}
|
||||
|
||||
if len(results) != 3 {
|
||||
t.Fatalf("expected 3 results, got %d", len(results))
|
||||
}
|
||||
|
||||
if results[0].Title != "Result 1" || results[1].Title != "Result 2" || results[2].Title != "Result 3" {
|
||||
t.Fatalf("unexpected results order or content")
|
||||
}
|
||||
}
|
||||
@@ -1,88 +0,0 @@
|
||||
package web_search
|
||||
|
||||
import (
|
||||
"fmt"
|
||||
"sync"
|
||||
|
||||
"github.com/Tencent/WeKnora/internal/types"
|
||||
"github.com/Tencent/WeKnora/internal/types/interfaces"
|
||||
)
|
||||
|
||||
// ProviderFactory creates a new web search provider instance
|
||||
type ProviderFactory func() (interfaces.WebSearchProvider, error)
|
||||
|
||||
// ProviderRegistration holds provider metadata and factory
|
||||
type ProviderRegistration struct {
|
||||
Info types.WebSearchProviderInfo
|
||||
Factory ProviderFactory
|
||||
}
|
||||
|
||||
// Registry manages web search provider registrations
|
||||
type Registry struct {
|
||||
providers map[string]*ProviderRegistration
|
||||
mu sync.RWMutex
|
||||
}
|
||||
|
||||
// NewRegistry creates a new web search provider registry
|
||||
func NewRegistry() *Registry {
|
||||
return &Registry{
|
||||
providers: make(map[string]*ProviderRegistration),
|
||||
}
|
||||
}
|
||||
|
||||
// Register registers a web search provider
|
||||
func (r *Registry) Register(info types.WebSearchProviderInfo, factory ProviderFactory) {
|
||||
r.mu.Lock()
|
||||
defer r.mu.Unlock()
|
||||
r.providers[info.ID] = &ProviderRegistration{
|
||||
Info: info,
|
||||
Factory: factory,
|
||||
}
|
||||
}
|
||||
|
||||
// GetRegistration returns the registration for a provider
|
||||
func (r *Registry) GetRegistration(id string) (*ProviderRegistration, bool) {
|
||||
r.mu.RLock()
|
||||
defer r.mu.RUnlock()
|
||||
reg, ok := r.providers[id]
|
||||
return reg, ok
|
||||
}
|
||||
|
||||
// GetAllProviderInfos returns info for all registered providers
|
||||
func (r *Registry) GetAllProviderInfos() []types.WebSearchProviderInfo {
|
||||
r.mu.RLock()
|
||||
defer r.mu.RUnlock()
|
||||
infos := make([]types.WebSearchProviderInfo, 0, len(r.providers))
|
||||
for _, reg := range r.providers {
|
||||
infos = append(infos, reg.Info)
|
||||
}
|
||||
return infos
|
||||
}
|
||||
|
||||
// CreateProvider creates a provider instance by ID
|
||||
func (r *Registry) CreateProvider(id string) (interfaces.WebSearchProvider, error) {
|
||||
r.mu.RLock()
|
||||
reg, ok := r.providers[id]
|
||||
r.mu.RUnlock()
|
||||
if !ok {
|
||||
return nil, fmt.Errorf("web search provider %s not registered", id)
|
||||
}
|
||||
return reg.Factory()
|
||||
}
|
||||
|
||||
// CreateAllProviders creates instances of all registered providers
|
||||
func (r *Registry) CreateAllProviders() (map[string]interfaces.WebSearchProvider, error) {
|
||||
r.mu.RLock()
|
||||
defer r.mu.RUnlock()
|
||||
|
||||
providers := make(map[string]interfaces.WebSearchProvider)
|
||||
for id, reg := range r.providers {
|
||||
provider, err := reg.Factory()
|
||||
if err != nil {
|
||||
// Skip providers that fail to initialize (e.g., missing API keys)
|
||||
continue
|
||||
}
|
||||
providers[id] = provider
|
||||
}
|
||||
return providers, nil
|
||||
}
|
||||
108
internal/application/service/web_search_provider.go
Normal file
108
internal/application/service/web_search_provider.go
Normal file
@@ -0,0 +1,108 @@
|
||||
package service
|
||||
|
||||
import (
|
||||
"context"
|
||||
"fmt"
|
||||
|
||||
"github.com/Tencent/WeKnora/internal/logger"
|
||||
"github.com/Tencent/WeKnora/internal/types"
|
||||
"github.com/Tencent/WeKnora/internal/types/interfaces"
|
||||
)
|
||||
|
||||
// webSearchProviderService implements interfaces.WebSearchProviderService
|
||||
type webSearchProviderService struct {
|
||||
repo interfaces.WebSearchProviderRepository
|
||||
}
|
||||
|
||||
// NewWebSearchProviderService creates a new web search provider service
|
||||
func NewWebSearchProviderService(repo interfaces.WebSearchProviderRepository) interfaces.WebSearchProviderService {
|
||||
return &webSearchProviderService{repo: repo}
|
||||
}
|
||||
|
||||
// CreateProvider creates a new web search provider configuration.
|
||||
func (s *webSearchProviderService) CreateProvider(ctx context.Context, provider *types.WebSearchProviderEntity) error {
|
||||
if provider.TenantID == 0 {
|
||||
return fmt.Errorf("tenant ID is required")
|
||||
}
|
||||
|
||||
if !isValidProviderType(provider.Provider) {
|
||||
return fmt.Errorf("invalid provider type: %s", provider.Provider)
|
||||
}
|
||||
|
||||
if err := validateProviderParameters(provider.Provider, provider.Parameters); err != nil {
|
||||
return err
|
||||
}
|
||||
|
||||
if provider.IsDefault {
|
||||
if err := s.repo.ClearDefault(ctx, provider.TenantID, ""); err != nil {
|
||||
logger.Warnf(ctx, "Failed to clear default providers: %v", err)
|
||||
}
|
||||
}
|
||||
|
||||
logger.Infof(ctx, "Creating web search provider: tenant=%d, name=%s, type=%s", provider.TenantID, provider.Name, provider.Provider)
|
||||
return s.repo.Create(ctx, provider)
|
||||
}
|
||||
|
||||
// UpdateProvider updates an existing provider.
|
||||
func (s *webSearchProviderService) UpdateProvider(ctx context.Context, provider *types.WebSearchProviderEntity) error {
|
||||
if provider.TenantID == 0 {
|
||||
return fmt.Errorf("tenant ID is required")
|
||||
}
|
||||
|
||||
// Validate provider type if set
|
||||
if provider.Provider != "" && !isValidProviderType(provider.Provider) {
|
||||
return fmt.Errorf("invalid provider type: %s", provider.Provider)
|
||||
}
|
||||
|
||||
if provider.IsDefault {
|
||||
if err := s.repo.ClearDefault(ctx, provider.TenantID, provider.ID); err != nil {
|
||||
logger.Warnf(ctx, "Failed to clear default providers: %v", err)
|
||||
}
|
||||
}
|
||||
|
||||
logger.Infof(ctx, "Updating web search provider: tenant=%d, id=%s", provider.TenantID, provider.ID)
|
||||
return s.repo.Update(ctx, provider)
|
||||
}
|
||||
|
||||
// DeleteProvider deletes a provider by tenant + id.
|
||||
func (s *webSearchProviderService) DeleteProvider(ctx context.Context, tenantID uint64, id string) error {
|
||||
logger.Infof(ctx, "Deleting web search provider: tenant=%d, id=%s", tenantID, id)
|
||||
return s.repo.Delete(ctx, tenantID, id)
|
||||
}
|
||||
|
||||
// isValidProviderType checks if the given provider type is supported
|
||||
func isValidProviderType(provider types.WebSearchProviderType) bool {
|
||||
switch provider {
|
||||
case types.WebSearchProviderTypeBing,
|
||||
types.WebSearchProviderTypeGoogle,
|
||||
types.WebSearchProviderTypeDuckDuckGo,
|
||||
types.WebSearchProviderTypeTavily:
|
||||
return true
|
||||
default:
|
||||
return false
|
||||
}
|
||||
}
|
||||
|
||||
// validateProviderParameters validates required parameters for each provider type
|
||||
func validateProviderParameters(provider types.WebSearchProviderType, params types.WebSearchProviderParameters) error {
|
||||
switch provider {
|
||||
case types.WebSearchProviderTypeBing:
|
||||
if params.APIKey == "" {
|
||||
return fmt.Errorf("API key is required for Bing provider")
|
||||
}
|
||||
case types.WebSearchProviderTypeGoogle:
|
||||
if params.APIKey == "" {
|
||||
return fmt.Errorf("API key is required for Google provider")
|
||||
}
|
||||
if params.EngineID == "" {
|
||||
return fmt.Errorf("engine ID is required for Google provider")
|
||||
}
|
||||
case types.WebSearchProviderTypeTavily:
|
||||
if params.APIKey == "" {
|
||||
return fmt.Errorf("API key is required for Tavily provider")
|
||||
}
|
||||
case types.WebSearchProviderTypeDuckDuckGo:
|
||||
// No API key required
|
||||
}
|
||||
return nil
|
||||
}
|
||||
@@ -47,7 +47,7 @@ import (
|
||||
"github.com/Tencent/WeKnora/internal/application/service/llmcontext"
|
||||
memoryService "github.com/Tencent/WeKnora/internal/application/service/memory"
|
||||
"github.com/Tencent/WeKnora/internal/application/service/retriever"
|
||||
"github.com/Tencent/WeKnora/internal/application/service/web_search"
|
||||
infra_web_search "github.com/Tencent/WeKnora/internal/infrastructure/web_search"
|
||||
"github.com/Tencent/WeKnora/internal/config"
|
||||
"github.com/Tencent/WeKnora/internal/database"
|
||||
"github.com/Tencent/WeKnora/internal/datasource"
|
||||
@@ -180,9 +180,11 @@ func BuildContainer(container *dig.Container) *dig.Container {
|
||||
|
||||
// Web search service (needed by AgentService)
|
||||
logger.Debugf(ctx, "[Container] Registering web search registry and providers...")
|
||||
must(container.Provide(web_search.NewRegistry))
|
||||
must(container.Provide(infra_web_search.NewRegistry))
|
||||
must(container.Invoke(registerWebSearchProviders))
|
||||
must(container.Provide(repository.NewWebSearchProviderRepository))
|
||||
must(container.Provide(service.NewWebSearchService))
|
||||
must(container.Provide(service.NewWebSearchProviderService))
|
||||
|
||||
// Agent service layer (requires event bus, web search service)
|
||||
// SessionService is passed as parameter to CreateAgentEngine method when creating AgentService
|
||||
@@ -219,6 +221,7 @@ func BuildContainer(container *dig.Container) *dig.Container {
|
||||
must(container.Provide(chatpipeline.NewEventManager))
|
||||
must(container.Invoke(chatpipeline.NewPluginSearch))
|
||||
must(container.Invoke(chatpipeline.NewPluginRerank))
|
||||
must(container.Invoke(chatpipeline.NewPluginWebFetch))
|
||||
must(container.Invoke(chatpipeline.NewPluginMerge))
|
||||
must(container.Invoke(chatpipeline.NewPluginDataAnalysis))
|
||||
must(container.Invoke(chatpipeline.NewPluginIntoChatMessage))
|
||||
@@ -250,6 +253,7 @@ func BuildContainer(container *dig.Container) *dig.Container {
|
||||
must(container.Provide(handler.NewSystemHandler))
|
||||
must(container.Provide(handler.NewMCPServiceHandler))
|
||||
must(container.Provide(handler.NewWebSearchHandler))
|
||||
must(container.Provide(handler.NewWebSearchProviderHandler))
|
||||
must(container.Provide(handler.NewCustomAgentHandler))
|
||||
must(container.Provide(service.NewSkillService))
|
||||
must(container.Provide(handler.NewSkillHandler))
|
||||
@@ -951,27 +955,14 @@ func NewDuckDB() (*sql.DB, error) {
|
||||
return sqlDB, nil
|
||||
}
|
||||
|
||||
// registerWebSearchProviders registers all web search providers to the registry
|
||||
func registerWebSearchProviders(registry *web_search.Registry) {
|
||||
// Register DuckDuckGo provider
|
||||
registry.Register(web_search.DuckDuckGoProviderInfo(), func() (interfaces.WebSearchProvider, error) {
|
||||
return web_search.NewDuckDuckGoProvider()
|
||||
})
|
||||
|
||||
// Register Google provider
|
||||
registry.Register(web_search.GoogleProviderInfo(), func() (interfaces.WebSearchProvider, error) {
|
||||
return web_search.NewGoogleProvider()
|
||||
})
|
||||
|
||||
// Register Bing provider
|
||||
registry.Register(web_search.BingProviderInfo(), func() (interfaces.WebSearchProvider, error) {
|
||||
return web_search.NewBingProvider()
|
||||
})
|
||||
|
||||
// Register Tavily provider
|
||||
registry.Register(web_search.TavilyProviderInfo(), func() (interfaces.WebSearchProvider, error) {
|
||||
return web_search.NewTavilyProvider()
|
||||
})
|
||||
// registerWebSearchProviders registers all web search provider types to the registry.
|
||||
// Each provider type is registered with its factory function that accepts parameters.
|
||||
// Provider instances are created on-demand when tenants configure them.
|
||||
func registerWebSearchProviders(registry *infra_web_search.Registry) {
|
||||
registry.Register("duckduckgo", infra_web_search.NewDuckDuckGoProvider)
|
||||
registry.Register("google", infra_web_search.NewGoogleProvider)
|
||||
registry.Register("bing", infra_web_search.NewBingProvider)
|
||||
registry.Register("tavily", infra_web_search.NewTavilyProvider)
|
||||
}
|
||||
|
||||
// registerIMAdapterFactories registers adapter factories for each IM platform
|
||||
|
||||
@@ -3,42 +3,22 @@ package handler
|
||||
import (
|
||||
"net/http"
|
||||
|
||||
"github.com/Tencent/WeKnora/internal/application/service/web_search"
|
||||
"github.com/Tencent/WeKnora/internal/logger"
|
||||
"github.com/Tencent/WeKnora/internal/types"
|
||||
"github.com/gin-gonic/gin"
|
||||
)
|
||||
|
||||
// WebSearchHandler handles web search related requests
|
||||
type WebSearchHandler struct {
|
||||
registry *web_search.Registry
|
||||
}
|
||||
// WebSearchHandler handles legacy web search related requests
|
||||
type WebSearchHandler struct{}
|
||||
|
||||
// NewWebSearchHandler creates a new web search handler
|
||||
func NewWebSearchHandler(registry *web_search.Registry) *WebSearchHandler {
|
||||
return &WebSearchHandler{
|
||||
registry: registry,
|
||||
}
|
||||
func NewWebSearchHandler() *WebSearchHandler {
|
||||
return &WebSearchHandler{}
|
||||
}
|
||||
|
||||
// GetProviders returns the list of available web search providers
|
||||
// @Summary Get available web search providers
|
||||
// @Description Returns the list of available web search providers from configuration
|
||||
// @Tags web-search
|
||||
// @Accept json
|
||||
// @Produce json
|
||||
// @Success 200 {object} map[string]interface{} "List of providers"
|
||||
// @Security Bearer
|
||||
// @Security ApiKeyAuth
|
||||
// @Router /web-search/providers [get]
|
||||
// GetProviders returns the list of available web search provider types
|
||||
func (h *WebSearchHandler) GetProviders(c *gin.Context) {
|
||||
ctx := c.Request.Context()
|
||||
logger.Info(ctx, "Getting web search providers")
|
||||
|
||||
providers := h.registry.GetAllProviderInfos()
|
||||
|
||||
logger.Infof(ctx, "Returning %d web search providers", len(providers))
|
||||
c.JSON(http.StatusOK, gin.H{
|
||||
"success": true,
|
||||
"data": providers,
|
||||
"data": types.GetWebSearchProviderTypes(),
|
||||
})
|
||||
}
|
||||
|
||||
319
internal/handler/web_search_provider.go
Normal file
319
internal/handler/web_search_provider.go
Normal file
@@ -0,0 +1,319 @@
|
||||
package handler
|
||||
|
||||
import (
|
||||
"context"
|
||||
"fmt"
|
||||
"net/http"
|
||||
|
||||
"github.com/Tencent/WeKnora/internal/errors"
|
||||
infra_web_search "github.com/Tencent/WeKnora/internal/infrastructure/web_search"
|
||||
"github.com/Tencent/WeKnora/internal/logger"
|
||||
"github.com/Tencent/WeKnora/internal/types"
|
||||
"github.com/Tencent/WeKnora/internal/types/interfaces"
|
||||
secutils "github.com/Tencent/WeKnora/internal/utils"
|
||||
"github.com/gin-gonic/gin"
|
||||
)
|
||||
|
||||
// WebSearchProviderHandler handles HTTP requests for web search provider CRUD
|
||||
type WebSearchProviderHandler struct {
|
||||
repo interfaces.WebSearchProviderRepository
|
||||
service interfaces.WebSearchProviderService
|
||||
registry *infra_web_search.Registry
|
||||
}
|
||||
|
||||
// NewWebSearchProviderHandler creates a new handler
|
||||
func NewWebSearchProviderHandler(
|
||||
repo interfaces.WebSearchProviderRepository,
|
||||
service interfaces.WebSearchProviderService,
|
||||
registry *infra_web_search.Registry,
|
||||
) *WebSearchProviderHandler {
|
||||
return &WebSearchProviderHandler{repo: repo, service: service, registry: registry}
|
||||
}
|
||||
|
||||
// --- request DTOs ---
|
||||
|
||||
// CreateProviderRequest defines the request body for creating a provider
|
||||
type CreateProviderRequest struct {
|
||||
Name string `json:"name" binding:"required"`
|
||||
Provider types.WebSearchProviderType `json:"provider" binding:"required"`
|
||||
Description string `json:"description"`
|
||||
Parameters types.WebSearchProviderParameters `json:"parameters"`
|
||||
IsDefault bool `json:"is_default"`
|
||||
}
|
||||
|
||||
// UpdateProviderRequest defines the request body for updating a provider
|
||||
type UpdateProviderRequest struct {
|
||||
Name string `json:"name"`
|
||||
Description string `json:"description"`
|
||||
Parameters types.WebSearchProviderParameters `json:"parameters"`
|
||||
IsDefault bool `json:"is_default"`
|
||||
}
|
||||
|
||||
// --- helpers ---
|
||||
|
||||
// getTenantID extracts tenant ID from gin context (set by auth middleware).
|
||||
func (h *WebSearchProviderHandler) getTenantID(c *gin.Context) uint64 {
|
||||
return c.GetUint64(types.TenantIDContextKey.String())
|
||||
}
|
||||
|
||||
// getOwnedProvider loads a provider and verifies it belongs to the given tenant.
|
||||
// Returns (nil, status, msg) on failure so callers can respond immediately.
|
||||
func (h *WebSearchProviderHandler) getOwnedProvider(
|
||||
ctx context.Context, tenantID uint64, id string,
|
||||
) (*types.WebSearchProviderEntity, int, string) {
|
||||
provider, err := h.repo.GetByID(ctx, tenantID, id)
|
||||
if err != nil {
|
||||
return nil, http.StatusInternalServerError, "failed to query provider"
|
||||
}
|
||||
if provider == nil {
|
||||
return nil, http.StatusNotFound, "web search provider not found"
|
||||
}
|
||||
return provider, http.StatusOK, ""
|
||||
}
|
||||
|
||||
// --- endpoints ---
|
||||
|
||||
// CreateProvider creates a new web search provider
|
||||
func (h *WebSearchProviderHandler) CreateProvider(c *gin.Context) {
|
||||
ctx := c.Request.Context()
|
||||
|
||||
tenantID := h.getTenantID(c)
|
||||
if tenantID == 0 {
|
||||
c.JSON(http.StatusUnauthorized, gin.H{"success": false, "error": "unauthorized: tenant context missing"})
|
||||
return
|
||||
}
|
||||
|
||||
var req CreateProviderRequest
|
||||
if err := c.ShouldBindJSON(&req); err != nil {
|
||||
logger.Warnf(ctx, "Invalid create provider request: %v", err)
|
||||
c.Error(errors.NewBadRequestError(err.Error()))
|
||||
return
|
||||
}
|
||||
|
||||
logger.Infof(ctx, "Creating web search provider: tenant=%d, name=%s, type=%s",
|
||||
tenantID, secutils.SanitizeForLog(req.Name), secutils.SanitizeForLog(string(req.Provider)))
|
||||
|
||||
provider := &types.WebSearchProviderEntity{
|
||||
TenantID: tenantID,
|
||||
Name: secutils.SanitizeForLog(req.Name),
|
||||
Provider: req.Provider,
|
||||
Description: secutils.SanitizeForLog(req.Description),
|
||||
Parameters: req.Parameters,
|
||||
IsDefault: req.IsDefault,
|
||||
}
|
||||
|
||||
if err := h.service.CreateProvider(ctx, provider); err != nil {
|
||||
logger.Warnf(ctx, "Failed to create web search provider: %v", err)
|
||||
c.Error(errors.NewInternalServerError(err.Error()))
|
||||
return
|
||||
}
|
||||
|
||||
c.JSON(http.StatusCreated, gin.H{
|
||||
"success": true,
|
||||
"data": provider,
|
||||
})
|
||||
}
|
||||
|
||||
// ListProviders lists all web search providers for the current tenant
|
||||
func (h *WebSearchProviderHandler) ListProviders(c *gin.Context) {
|
||||
ctx := c.Request.Context()
|
||||
|
||||
tenantID := h.getTenantID(c)
|
||||
if tenantID == 0 {
|
||||
c.JSON(http.StatusUnauthorized, gin.H{"success": false, "error": "unauthorized: tenant context missing"})
|
||||
return
|
||||
}
|
||||
|
||||
providers, err := h.repo.List(ctx, tenantID)
|
||||
if err != nil {
|
||||
logger.Warnf(ctx, "Failed to list web search providers: %v", err)
|
||||
c.Error(errors.NewInternalServerError(err.Error()))
|
||||
return
|
||||
}
|
||||
|
||||
c.JSON(http.StatusOK, gin.H{
|
||||
"success": true,
|
||||
"data": providers,
|
||||
})
|
||||
}
|
||||
|
||||
// GetProvider retrieves a single web search provider by ID
|
||||
func (h *WebSearchProviderHandler) GetProvider(c *gin.Context) {
|
||||
ctx := c.Request.Context()
|
||||
|
||||
tenantID := h.getTenantID(c)
|
||||
if tenantID == 0 {
|
||||
c.JSON(http.StatusUnauthorized, gin.H{"success": false, "error": "unauthorized: tenant context missing"})
|
||||
return
|
||||
}
|
||||
|
||||
id := c.Param("id")
|
||||
provider, status, msg := h.getOwnedProvider(ctx, tenantID, id)
|
||||
if status != http.StatusOK {
|
||||
c.JSON(status, gin.H{"success": false, "error": msg})
|
||||
return
|
||||
}
|
||||
|
||||
c.JSON(http.StatusOK, gin.H{
|
||||
"success": true,
|
||||
"data": provider,
|
||||
})
|
||||
}
|
||||
|
||||
// UpdateProvider updates a web search provider
|
||||
func (h *WebSearchProviderHandler) UpdateProvider(c *gin.Context) {
|
||||
ctx := c.Request.Context()
|
||||
|
||||
tenantID := h.getTenantID(c)
|
||||
if tenantID == 0 {
|
||||
c.JSON(http.StatusUnauthorized, gin.H{"success": false, "error": "unauthorized: tenant context missing"})
|
||||
return
|
||||
}
|
||||
|
||||
id := c.Param("id")
|
||||
|
||||
// Ownership check
|
||||
existing, status, msg := h.getOwnedProvider(ctx, tenantID, id)
|
||||
if status != http.StatusOK {
|
||||
c.JSON(status, gin.H{"success": false, "error": msg})
|
||||
return
|
||||
}
|
||||
|
||||
var req UpdateProviderRequest
|
||||
if err := c.ShouldBindJSON(&req); err != nil {
|
||||
c.Error(errors.NewBadRequestError(err.Error()))
|
||||
return
|
||||
}
|
||||
|
||||
// Build updated entity, keeping immutable fields from existing
|
||||
provider := &types.WebSearchProviderEntity{
|
||||
ID: id,
|
||||
TenantID: tenantID,
|
||||
Name: secutils.SanitizeForLog(req.Name),
|
||||
Provider: existing.Provider, // Provider type is immutable after creation
|
||||
Description: secutils.SanitizeForLog(req.Description),
|
||||
Parameters: req.Parameters,
|
||||
IsDefault: req.IsDefault,
|
||||
}
|
||||
|
||||
if err := h.service.UpdateProvider(ctx, provider); err != nil {
|
||||
logger.Warnf(ctx, "Failed to update web search provider %s: %v", id, err)
|
||||
c.Error(errors.NewInternalServerError(err.Error()))
|
||||
return
|
||||
}
|
||||
|
||||
// Re-fetch to get the full stored state
|
||||
updated, _ := h.repo.GetByID(ctx, tenantID, id)
|
||||
if updated != nil {
|
||||
c.JSON(http.StatusOK, gin.H{"success": true, "data": updated})
|
||||
} else {
|
||||
c.JSON(http.StatusOK, gin.H{"success": true})
|
||||
}
|
||||
}
|
||||
|
||||
// DeleteProvider deletes a web search provider
|
||||
func (h *WebSearchProviderHandler) DeleteProvider(c *gin.Context) {
|
||||
ctx := c.Request.Context()
|
||||
|
||||
tenantID := h.getTenantID(c)
|
||||
if tenantID == 0 {
|
||||
c.JSON(http.StatusUnauthorized, gin.H{"success": false, "error": "unauthorized: tenant context missing"})
|
||||
return
|
||||
}
|
||||
|
||||
id := c.Param("id")
|
||||
|
||||
// Ownership check
|
||||
if _, status, msg := h.getOwnedProvider(ctx, tenantID, id); status != http.StatusOK {
|
||||
c.JSON(status, gin.H{"success": false, "error": msg})
|
||||
return
|
||||
}
|
||||
|
||||
if err := h.service.DeleteProvider(ctx, tenantID, id); err != nil {
|
||||
logger.Warnf(ctx, "Failed to delete web search provider %s: %v", id, err)
|
||||
c.Error(errors.NewInternalServerError(err.Error()))
|
||||
return
|
||||
}
|
||||
|
||||
c.JSON(http.StatusOK, gin.H{"success": true})
|
||||
}
|
||||
|
||||
// ListProviderTypes returns available provider types and their parameter requirements
|
||||
func (h *WebSearchProviderHandler) ListProviderTypes(c *gin.Context) {
|
||||
c.JSON(http.StatusOK, gin.H{
|
||||
"success": true,
|
||||
"data": types.GetWebSearchProviderTypes(),
|
||||
})
|
||||
}
|
||||
|
||||
// TestProviderByID tests an existing saved provider by performing a sample search
|
||||
func (h *WebSearchProviderHandler) TestProviderByID(c *gin.Context) {
|
||||
ctx := c.Request.Context()
|
||||
|
||||
tenantID := h.getTenantID(c)
|
||||
if tenantID == 0 {
|
||||
c.JSON(http.StatusUnauthorized, gin.H{"success": false, "error": "unauthorized: tenant context missing"})
|
||||
return
|
||||
}
|
||||
|
||||
id := c.Param("id")
|
||||
provider, status, msg := h.getOwnedProvider(ctx, tenantID, id)
|
||||
if status != http.StatusOK {
|
||||
c.JSON(status, gin.H{"success": false, "error": msg})
|
||||
return
|
||||
}
|
||||
|
||||
if err := h.doTestSearch(ctx, string(provider.Provider), provider.Parameters); err != nil {
|
||||
logger.Warnf(ctx, "Web search provider test failed: %v", err)
|
||||
c.JSON(http.StatusOK, gin.H{"success": false, "error": err.Error()})
|
||||
return
|
||||
}
|
||||
|
||||
c.JSON(http.StatusOK, gin.H{"success": true})
|
||||
}
|
||||
|
||||
// TestProviderRequest defines the body for testing raw credentials
|
||||
type TestProviderRequest struct {
|
||||
Provider string `json:"provider" binding:"required"`
|
||||
Parameters types.WebSearchProviderParameters `json:"parameters"`
|
||||
}
|
||||
|
||||
// TestProviderRaw tests a provider with raw credentials (no persistence)
|
||||
func (h *WebSearchProviderHandler) TestProviderRaw(c *gin.Context) {
|
||||
ctx := c.Request.Context()
|
||||
|
||||
var req TestProviderRequest
|
||||
if err := c.ShouldBindJSON(&req); err != nil {
|
||||
c.Error(errors.NewBadRequestError(err.Error()))
|
||||
return
|
||||
}
|
||||
|
||||
if err := h.doTestSearch(ctx, req.Provider, req.Parameters); err != nil {
|
||||
logger.Warnf(ctx, "Web search provider test failed: %v", err)
|
||||
c.JSON(http.StatusOK, gin.H{"success": false, "error": err.Error()})
|
||||
return
|
||||
}
|
||||
|
||||
c.JSON(http.StatusOK, gin.H{"success": true})
|
||||
}
|
||||
|
||||
// doTestSearch creates a temporary provider and runs a simple test query
|
||||
func (h *WebSearchProviderHandler) doTestSearch(ctx context.Context, providerType string, params types.WebSearchProviderParameters) error {
|
||||
logger.Infof(ctx, "[WebSearch][Test] testing provider type=%s", providerType)
|
||||
searchProvider, err := h.registry.CreateProvider(providerType, params)
|
||||
if err != nil {
|
||||
logger.Warnf(ctx, "[WebSearch][Test] failed to create provider: %v", err)
|
||||
return fmt.Errorf("failed to create provider: %w", err)
|
||||
}
|
||||
results, err := searchProvider.Search(ctx, "test", 1, false)
|
||||
if err != nil {
|
||||
logger.Warnf(ctx, "[WebSearch][Test] search failed: %v", err)
|
||||
return err
|
||||
}
|
||||
if len(results) == 0 {
|
||||
logger.Warnf(ctx, "[WebSearch][Test] search returned 0 results — API key or configuration may be invalid")
|
||||
return fmt.Errorf("search returned 0 results, please verify your API key and configuration")
|
||||
}
|
||||
logger.Infof(ctx, "[WebSearch][Test] succeeded: type=%s, results=%d", providerType, len(results))
|
||||
return nil
|
||||
}
|
||||
171
internal/infrastructure/web_fetch/fetcher.go
Normal file
171
internal/infrastructure/web_fetch/fetcher.go
Normal file
@@ -0,0 +1,171 @@
|
||||
// Package web_fetch provides a public URL content fetcher with SSRF protection.
|
||||
// It extracts core logic from the agent WebFetchTool so it can be used by the chat pipeline.
|
||||
package web_fetch
|
||||
|
||||
import (
|
||||
"context"
|
||||
"crypto/tls"
|
||||
"fmt"
|
||||
"io"
|
||||
"net"
|
||||
"net/http"
|
||||
"net/url"
|
||||
"strings"
|
||||
"time"
|
||||
|
||||
"github.com/PuerkitoBio/goquery"
|
||||
"github.com/Tencent/WeKnora/internal/logger"
|
||||
"github.com/Tencent/WeKnora/internal/utils"
|
||||
)
|
||||
|
||||
const (
|
||||
fetchTimeout = 15 * time.Second
|
||||
maxBodySize = 100 * 1024 // 100KB
|
||||
)
|
||||
|
||||
// FetchURLContent fetches a URL and returns its text content (HTML converted to clean text).
|
||||
// Includes SSRF validation, DNS pinning, browser-like headers, and content size limits.
|
||||
func FetchURLContent(ctx context.Context, rawURL string) (string, error) {
|
||||
if rawURL == "" {
|
||||
return "", fmt.Errorf("url is empty")
|
||||
}
|
||||
|
||||
// SSRF validation
|
||||
if safe, reason := utils.IsSSRFSafeURL(rawURL); !safe {
|
||||
return "", fmt.Errorf("URL rejected: %s", reason)
|
||||
}
|
||||
|
||||
u, err := url.Parse(rawURL)
|
||||
if err != nil {
|
||||
return "", fmt.Errorf("invalid URL: %w", err)
|
||||
}
|
||||
hostname := u.Hostname()
|
||||
port := u.Port()
|
||||
if port == "" {
|
||||
if u.Scheme == "https" {
|
||||
port = "443"
|
||||
} else {
|
||||
port = "80"
|
||||
}
|
||||
}
|
||||
|
||||
// DNS pinning: resolve once, use pinned IP
|
||||
ips, err := net.DefaultResolver.LookupIP(context.Background(), "ip", hostname)
|
||||
if err != nil || len(ips) == 0 {
|
||||
return "", fmt.Errorf("DNS lookup failed for %s: %w", hostname, err)
|
||||
}
|
||||
var pinnedIP net.IP
|
||||
for _, ip := range ips {
|
||||
if utils.IsPublicIP(ip) {
|
||||
pinnedIP = ip
|
||||
break
|
||||
}
|
||||
}
|
||||
if pinnedIP == nil {
|
||||
return "", fmt.Errorf("no public IP for host %s", hostname)
|
||||
}
|
||||
|
||||
// Build request with pinned IP
|
||||
hostPort := net.JoinHostPort(pinnedIP.String(), port)
|
||||
fetchURL := *u
|
||||
fetchURL.Host = hostPort
|
||||
|
||||
ctx, cancel := context.WithTimeout(ctx, fetchTimeout)
|
||||
defer cancel()
|
||||
|
||||
req, err := http.NewRequestWithContext(ctx, http.MethodGet, fetchURL.String(), nil)
|
||||
if err != nil {
|
||||
return "", err
|
||||
}
|
||||
req.Host = hostname
|
||||
|
||||
// Browser-like headers to reduce 403 rejections.
|
||||
// These match a real Chrome browser fingerprint.
|
||||
req.Header.Set("User-Agent",
|
||||
"Mozilla/5.0 (Macintosh; Intel Mac OS X 10_15_7) AppleWebKit/537.36 (KHTML, like Gecko) Chrome/131.0.0.0 Safari/537.36")
|
||||
req.Header.Set("Accept",
|
||||
"text/html,application/xhtml+xml,application/xml;q=0.9,image/avif,image/webp,image/apng,*/*;q=0.8,application/signed-exchange;v=b3;q=0.7")
|
||||
req.Header.Set("Accept-Language", "zh-CN,zh;q=0.9,en;q=0.8,en-GB;q=0.7,en-US;q=0.6")
|
||||
req.Header.Set("Accept-Encoding", "identity") // no gzip to simplify reading
|
||||
req.Header.Set("Cache-Control", "no-cache")
|
||||
req.Header.Set("Pragma", "no-cache")
|
||||
req.Header.Set("Sec-Ch-Ua", `"Chromium";v="131", "Not_A Brand";v="24"`)
|
||||
req.Header.Set("Sec-Ch-Ua-Mobile", "?0")
|
||||
req.Header.Set("Sec-Ch-Ua-Platform", `"macOS"`)
|
||||
req.Header.Set("Sec-Fetch-Dest", "document")
|
||||
req.Header.Set("Sec-Fetch-Mode", "navigate")
|
||||
req.Header.Set("Sec-Fetch-Site", "none")
|
||||
req.Header.Set("Sec-Fetch-User", "?1")
|
||||
req.Header.Set("Upgrade-Insecure-Requests", "1")
|
||||
req.Header.Set("Referer", u.Scheme+"://"+hostname+"/")
|
||||
|
||||
// Custom transport: TLS ServerName for certificate validation with pinned IP.
|
||||
client := &http.Client{
|
||||
Timeout: fetchTimeout,
|
||||
Transport: &http.Transport{
|
||||
TLSClientConfig: &tls.Config{
|
||||
ServerName: hostname,
|
||||
},
|
||||
},
|
||||
// Follow redirects (default behavior), up to 10 hops
|
||||
}
|
||||
resp, err := client.Do(req)
|
||||
if err != nil {
|
||||
return "", fmt.Errorf("fetch failed: %w", err)
|
||||
}
|
||||
defer resp.Body.Close()
|
||||
|
||||
if resp.StatusCode != http.StatusOK {
|
||||
return "", fmt.Errorf("HTTP %d %s", resp.StatusCode, resp.Status)
|
||||
}
|
||||
|
||||
body, err := io.ReadAll(io.LimitReader(resp.Body, maxBodySize))
|
||||
if err != nil {
|
||||
return "", fmt.Errorf("read failed: %w", err)
|
||||
}
|
||||
|
||||
text := htmlToText(string(body))
|
||||
logger.Infof(ctx, "[WebFetch] fetched %s → %d chars", rawURL, len(text))
|
||||
return text, nil
|
||||
}
|
||||
|
||||
// htmlToText extracts clean text from HTML, removing scripts/styles/nav.
|
||||
func htmlToText(html string) string {
|
||||
doc, err := goquery.NewDocumentFromReader(strings.NewReader(html))
|
||||
if err != nil {
|
||||
return stripTags(html)
|
||||
}
|
||||
doc.Find("script, style, nav, footer, header, iframe, noscript, svg, img").Remove()
|
||||
|
||||
var sb strings.Builder
|
||||
doc.Find("body").Each(func(i int, s *goquery.Selection) {
|
||||
sb.WriteString(s.Text())
|
||||
})
|
||||
text := sb.String()
|
||||
|
||||
// Normalize whitespace: collapse blank lines
|
||||
lines := strings.Split(text, "\n")
|
||||
var cleaned []string
|
||||
for _, line := range lines {
|
||||
line = strings.TrimSpace(line)
|
||||
if line != "" {
|
||||
cleaned = append(cleaned, line)
|
||||
}
|
||||
}
|
||||
return strings.Join(cleaned, "\n")
|
||||
}
|
||||
|
||||
func stripTags(s string) string {
|
||||
var sb strings.Builder
|
||||
inTag := false
|
||||
for _, r := range s {
|
||||
if r == '<' {
|
||||
inTag = true
|
||||
} else if r == '>' {
|
||||
inTag = false
|
||||
} else if !inTag {
|
||||
sb.WriteRune(r)
|
||||
}
|
||||
}
|
||||
return sb.String()
|
||||
}
|
||||
@@ -7,22 +7,21 @@ import (
|
||||
"io"
|
||||
"net/http"
|
||||
"net/url"
|
||||
"os"
|
||||
"strconv"
|
||||
"time"
|
||||
|
||||
"github.com/Tencent/WeKnora/internal/logger"
|
||||
"github.com/Tencent/WeKnora/internal/types"
|
||||
"github.com/Tencent/WeKnora/internal/types/interfaces"
|
||||
)
|
||||
|
||||
const (
|
||||
// defaultBingSearchURL is the default Bing search API URL.
|
||||
// Reference: https://learn.microsoft.com/en-us/previous-versions/bing/search-apis/bing-web-search/reference/endpoints
|
||||
// defaultBingSearchURL is the hardcoded Bing search API URL.
|
||||
// Not configurable by tenants — prevents SSRF.
|
||||
defaultBingSearchURL = "https://api.bing.microsoft.com/v7.0/search"
|
||||
)
|
||||
|
||||
var (
|
||||
// defaultUserAgentHeader for PC. https://learn.microsoft.com/en-us/previous-versions/bing/search-apis/bing-web-search/reference/headers
|
||||
defaultUserAgentHeader = "Mozilla/5.0 (Macintosh; Intel Mac OS X 10_15_7) AppleWebKit/537.36 (KHTML, like Gecko) Chrome/139.0.0.0 Safari/537.36"
|
||||
defaultBingTimeout = 10 * time.Second
|
||||
)
|
||||
@@ -50,33 +49,21 @@ type BingProvider struct {
|
||||
apiKey string
|
||||
}
|
||||
|
||||
// NewBingProvider creates a new Bing provider
|
||||
func NewBingProvider() (interfaces.WebSearchProvider, error) {
|
||||
apiKey := os.Getenv("BING_SEARCH_API_KEY")
|
||||
if len(apiKey) == 0 {
|
||||
return nil, fmt.Errorf("BING_SEARCH_API_KEY is not set")
|
||||
// NewBingProvider creates a new Bing provider from parameters (no environment variables).
|
||||
func NewBingProvider(params types.WebSearchProviderParameters) (interfaces.WebSearchProvider, error) {
|
||||
if params.APIKey == "" {
|
||||
return nil, fmt.Errorf("API key is required for Bing provider")
|
||||
}
|
||||
client := &http.Client{
|
||||
Timeout: defaultBingTimeout,
|
||||
}
|
||||
return &BingProvider{
|
||||
client: client,
|
||||
baseURL: defaultBingSearchURL,
|
||||
apiKey: apiKey,
|
||||
baseURL: defaultBingSearchURL, // Hardcoded — not tenant-configurable
|
||||
apiKey: params.APIKey,
|
||||
}, nil
|
||||
}
|
||||
|
||||
// BingProviderInfo returns the provider info for registration
|
||||
func BingProviderInfo() types.WebSearchProviderInfo {
|
||||
return types.WebSearchProviderInfo{
|
||||
ID: "bing",
|
||||
Name: "Bing",
|
||||
Free: false,
|
||||
RequiresAPIKey: true,
|
||||
Description: "Bing Search API",
|
||||
}
|
||||
}
|
||||
|
||||
// Name returns the provider name
|
||||
func (p *BingProvider) Name() string {
|
||||
return "bing"
|
||||
@@ -92,11 +79,18 @@ func (p *BingProvider) Search(
|
||||
if len(query) == 0 {
|
||||
return nil, fmt.Errorf("query is empty")
|
||||
}
|
||||
logger.Infof(ctx, "[WebSearch][Bing] query=%q maxResults=%d url=%s", query, maxResults, p.baseURL)
|
||||
req, err := p.buildParams(ctx, query, maxResults, includeDate)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
return p.doSearch(ctx, req)
|
||||
results, err := p.doSearch(ctx, req)
|
||||
if err != nil {
|
||||
logger.Warnf(ctx, "[WebSearch][Bing] failed: %v", err)
|
||||
return nil, err
|
||||
}
|
||||
logger.Infof(ctx, "[WebSearch][Bing] returned %d results", len(results))
|
||||
return results, nil
|
||||
}
|
||||
|
||||
func (p *BingProvider) doSearch(ctx context.Context, req *http.Request) ([]*types.WebSearchResult, error) {
|
||||
@@ -110,6 +104,12 @@ func (p *BingProvider) doSearch(ctx context.Context, req *http.Request) ([]*type
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
|
||||
if resp.StatusCode != http.StatusOK {
|
||||
logger.Warnf(ctx, "[WebSearch][Bing] API returned status %d: %s", resp.StatusCode, string(body))
|
||||
return nil, fmt.Errorf("bing API returned status %d: %s", resp.StatusCode, string(body))
|
||||
}
|
||||
|
||||
var respData bingSearchResponse
|
||||
if err := json.Unmarshal(body, &respData); err != nil {
|
||||
return nil, fmt.Errorf("failed to unmarshal response: %w", err)
|
||||
@@ -128,7 +128,6 @@ func (p *BingProvider) doSearch(ctx context.Context, req *http.Request) ([]*type
|
||||
}
|
||||
|
||||
// bingSearchResponse defines the response structure for Bing search API.
|
||||
// ref: https://learn.microsoft.com/en-us/previous-versions/bing/search-apis/bing-web-search/quickstarts/rest/go
|
||||
type bingSearchResponse struct {
|
||||
Type string `json:"_type"`
|
||||
QueryContext struct {
|
||||
@@ -183,8 +182,6 @@ type bingSearchResponse struct {
|
||||
} `json:"rankingResponse"`
|
||||
}
|
||||
|
||||
// buildParams builds the request parameters for Bing search API.
|
||||
// ref: https://learn.microsoft.com/en-us/previous-versions/bing/search-apis/bing-web-search/quickstarts/rest/go
|
||||
func (p *BingProvider) buildParams(ctx context.Context, query string, maxResults int, includeDate bool) (*http.Request, error) {
|
||||
params := url.Values{}
|
||||
params.Set("q", query)
|
||||
@@ -22,8 +22,9 @@ type DuckDuckGoProvider struct {
|
||||
client *http.Client
|
||||
}
|
||||
|
||||
// NewDuckDuckGoProvider creates a new DuckDuckGo provider
|
||||
func NewDuckDuckGoProvider() (interfaces.WebSearchProvider, error) {
|
||||
// NewDuckDuckGoProvider creates a new DuckDuckGo provider.
|
||||
// DuckDuckGo is free and requires no API key or configuration.
|
||||
func NewDuckDuckGoProvider(params types.WebSearchProviderParameters) (interfaces.WebSearchProvider, error) {
|
||||
return &DuckDuckGoProvider{
|
||||
client: &http.Client{
|
||||
Timeout: 30 * time.Second,
|
||||
@@ -31,17 +32,6 @@ func NewDuckDuckGoProvider() (interfaces.WebSearchProvider, error) {
|
||||
}, nil
|
||||
}
|
||||
|
||||
// DuckDuckGoProviderInfo returns the provider info for registration
|
||||
func DuckDuckGoProviderInfo() types.WebSearchProviderInfo {
|
||||
return types.WebSearchProviderInfo{
|
||||
ID: "duckduckgo",
|
||||
Name: "DuckDuckGo",
|
||||
Free: true,
|
||||
RequiresAPIKey: false,
|
||||
Description: "DuckDuckGo Search API",
|
||||
}
|
||||
}
|
||||
|
||||
// Name returns the provider name
|
||||
func (p *DuckDuckGoProvider) Name() string {
|
||||
return "duckduckgo"
|
||||
@@ -82,7 +72,6 @@ func (p *DuckDuckGoProvider) searchHTML(
|
||||
baseURL := "https://html.duckduckgo.com/html/"
|
||||
params := url.Values{}
|
||||
params.Set("q", query)
|
||||
// Prefer Chinese results if applicable; otherwise DDG will auto-detect
|
||||
params.Set("kl", "cn-zh")
|
||||
|
||||
reqURL := baseURL + "?" + params.Encode()
|
||||
@@ -90,13 +79,11 @@ func (p *DuckDuckGoProvider) searchHTML(
|
||||
if err != nil {
|
||||
return nil, fmt.Errorf("failed to create request: %w", err)
|
||||
}
|
||||
// Use a realistic UA to avoid blocks
|
||||
req.Header.Set(
|
||||
"User-Agent",
|
||||
"Mozilla/5.0 (Windows NT 10.0; Win64; x64) AppleWebKit/537.36 (KHTML, like Gecko) Chrome/120.0.0.0 Safari/537.36",
|
||||
)
|
||||
|
||||
// print curl of request
|
||||
curlCommand := fmt.Sprintf(
|
||||
"curl -X GET '%s' -H 'User-Agent: Mozilla/5.0 (Windows NT 10.0; Win64; x64) AppleWebKit/537.36 (KHTML, like Gecko) Chrome/120.0.0.0 Safari/537.36'",
|
||||
req.URL.String(),
|
||||
@@ -119,7 +106,6 @@ func (p *DuckDuckGoProvider) searchHTML(
|
||||
}
|
||||
|
||||
results := make([]*types.WebSearchResult, 0, maxResults)
|
||||
// Structure based on DDG HTML page
|
||||
doc.Find(".web-result").Each(func(i int, s *goquery.Selection) {
|
||||
if len(results) >= maxResults {
|
||||
return
|
||||
@@ -3,12 +3,11 @@ package web_search
|
||||
import (
|
||||
"context"
|
||||
"fmt"
|
||||
"net/url"
|
||||
"os"
|
||||
|
||||
"google.golang.org/api/customsearch/v1"
|
||||
"google.golang.org/api/option"
|
||||
|
||||
"github.com/Tencent/WeKnora/internal/logger"
|
||||
"github.com/Tencent/WeKnora/internal/types"
|
||||
"github.com/Tencent/WeKnora/internal/types/interfaces"
|
||||
)
|
||||
@@ -18,54 +17,32 @@ type GoogleProvider struct {
|
||||
srv *customsearch.Service
|
||||
apiKey string
|
||||
engineID string
|
||||
baseURL string
|
||||
}
|
||||
|
||||
// NewGoogleProvider creates a new Google provider
|
||||
func NewGoogleProvider() (interfaces.WebSearchProvider, error) {
|
||||
apiURL := os.Getenv("GOOGLE_SEARCH_API_URL")
|
||||
if apiURL == "" {
|
||||
return nil, fmt.Errorf("GOOGLE_SEARCH_API_URL environment variable is not set")
|
||||
// NewGoogleProvider creates a new Google provider from parameters (no environment variables).
|
||||
// The API endpoint is the official Google Custom Search endpoint — not tenant-configurable.
|
||||
func NewGoogleProvider(params types.WebSearchProviderParameters) (interfaces.WebSearchProvider, error) {
|
||||
if params.APIKey == "" {
|
||||
return nil, fmt.Errorf("API key is required for Google provider")
|
||||
}
|
||||
if params.EngineID == "" {
|
||||
return nil, fmt.Errorf("engine ID is required for Google provider")
|
||||
}
|
||||
|
||||
u, err := url.Parse(apiURL)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
clientOpts := []option.ClientOption{
|
||||
option.WithAPIKey(params.APIKey),
|
||||
}
|
||||
engineID := u.Query().Get("engine_id")
|
||||
if engineID == "" {
|
||||
return nil, fmt.Errorf("engine_id is empty")
|
||||
}
|
||||
apiKey := u.Query().Get("api_key")
|
||||
if apiKey == "" {
|
||||
return nil, fmt.Errorf("api_key is empty")
|
||||
}
|
||||
clientOpts := make([]option.ClientOption, 0)
|
||||
clientOpts = append(clientOpts, option.WithAPIKey(apiKey))
|
||||
clientOpts = append(clientOpts, option.WithEndpoint(u.Scheme+"://"+u.Host))
|
||||
srv, err := customsearch.NewService(context.Background(), clientOpts...)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
return &GoogleProvider{
|
||||
srv: srv,
|
||||
apiKey: apiKey,
|
||||
engineID: engineID,
|
||||
baseURL: apiURL,
|
||||
apiKey: params.APIKey,
|
||||
engineID: params.EngineID,
|
||||
}, nil
|
||||
}
|
||||
|
||||
// GoogleProviderInfo returns the provider info for registration
|
||||
func GoogleProviderInfo() types.WebSearchProviderInfo {
|
||||
return types.WebSearchProviderInfo{
|
||||
ID: "google",
|
||||
Name: "Google",
|
||||
Free: false,
|
||||
RequiresAPIKey: true,
|
||||
Description: "Google Custom Search API",
|
||||
}
|
||||
}
|
||||
|
||||
// Name returns the provider name
|
||||
func (p *GoogleProvider) Name() string {
|
||||
return "google"
|
||||
@@ -81,6 +58,7 @@ func (p *GoogleProvider) Search(
|
||||
if len(query) == 0 {
|
||||
return nil, fmt.Errorf("query is empty")
|
||||
}
|
||||
logger.Infof(ctx, "[WebSearch][Google] query=%q maxResults=%d engineID=%s", query, maxResults, p.engineID)
|
||||
cseCall := p.srv.Cse.List().Context(ctx).Cx(p.engineID).Q(query)
|
||||
|
||||
if maxResults > 0 {
|
||||
@@ -92,6 +70,7 @@ func (p *GoogleProvider) Search(
|
||||
|
||||
resp, err := cseCall.Do()
|
||||
if err != nil {
|
||||
logger.Warnf(ctx, "[WebSearch][Google] failed: %v", err)
|
||||
return nil, err
|
||||
}
|
||||
results := make([]*types.WebSearchResult, 0)
|
||||
@@ -104,5 +83,6 @@ func (p *GoogleProvider) Search(
|
||||
}
|
||||
results = append(results, result)
|
||||
}
|
||||
logger.Infof(ctx, "[WebSearch][Google] returned %d results", len(results))
|
||||
return results, nil
|
||||
}
|
||||
45
internal/infrastructure/web_search/registry.go
Normal file
45
internal/infrastructure/web_search/registry.go
Normal file
@@ -0,0 +1,45 @@
|
||||
package web_search
|
||||
|
||||
import (
|
||||
"fmt"
|
||||
"sync"
|
||||
|
||||
"github.com/Tencent/WeKnora/internal/types"
|
||||
"github.com/Tencent/WeKnora/internal/types/interfaces"
|
||||
)
|
||||
|
||||
// ProviderFactory creates a new web search provider instance from parameters.
|
||||
type ProviderFactory func(params types.WebSearchProviderParameters) (interfaces.WebSearchProvider, error)
|
||||
|
||||
// Registry manages web search provider type registrations.
|
||||
// It maps provider type IDs (e.g., "bing", "google") to their factory functions.
|
||||
// Instances are created on-demand with tenant-specific parameters.
|
||||
type Registry struct {
|
||||
factories map[string]ProviderFactory
|
||||
mu sync.RWMutex
|
||||
}
|
||||
|
||||
// NewRegistry creates a new web search provider registry
|
||||
func NewRegistry() *Registry {
|
||||
return &Registry{
|
||||
factories: make(map[string]ProviderFactory),
|
||||
}
|
||||
}
|
||||
|
||||
// Register registers a provider type factory by ID
|
||||
func (r *Registry) Register(id string, factory ProviderFactory) {
|
||||
r.mu.Lock()
|
||||
defer r.mu.Unlock()
|
||||
r.factories[id] = factory
|
||||
}
|
||||
|
||||
// CreateProvider creates a provider instance by type with the given parameters.
|
||||
func (r *Registry) CreateProvider(providerType string, params types.WebSearchProviderParameters) (interfaces.WebSearchProvider, error) {
|
||||
r.mu.RLock()
|
||||
factory, ok := r.factories[providerType]
|
||||
r.mu.RUnlock()
|
||||
if !ok {
|
||||
return nil, fmt.Errorf("web search provider type %s not registered", providerType)
|
||||
}
|
||||
return factory(params)
|
||||
}
|
||||
@@ -7,14 +7,16 @@ import (
|
||||
"fmt"
|
||||
"io"
|
||||
"net/http"
|
||||
"os"
|
||||
"time"
|
||||
|
||||
"github.com/Tencent/WeKnora/internal/logger"
|
||||
"github.com/Tencent/WeKnora/internal/types"
|
||||
"github.com/Tencent/WeKnora/internal/types/interfaces"
|
||||
)
|
||||
|
||||
const (
|
||||
// defaultTavilySearchURL is the hardcoded Tavily API URL.
|
||||
// Not configurable by tenants — prevents SSRF.
|
||||
defaultTavilySearchURL = "https://api.tavily.com/search"
|
||||
)
|
||||
|
||||
@@ -29,11 +31,10 @@ type TavilyProvider struct {
|
||||
apiKey string
|
||||
}
|
||||
|
||||
// NewTavilyProvider creates a new Tavily provider
|
||||
func NewTavilyProvider() (interfaces.WebSearchProvider, error) {
|
||||
apiKey := os.Getenv("TAVILY_API_KEY")
|
||||
if len(apiKey) == 0 {
|
||||
return nil, fmt.Errorf("TAVILY_API_KEY is not set")
|
||||
// NewTavilyProvider creates a new Tavily provider from parameters (no environment variables).
|
||||
func NewTavilyProvider(params types.WebSearchProviderParameters) (interfaces.WebSearchProvider, error) {
|
||||
if params.APIKey == "" {
|
||||
return nil, fmt.Errorf("API key is required for Tavily provider")
|
||||
}
|
||||
client := &http.Client{
|
||||
Timeout: defaultTavilyTimeout,
|
||||
@@ -41,21 +42,10 @@ func NewTavilyProvider() (interfaces.WebSearchProvider, error) {
|
||||
return &TavilyProvider{
|
||||
client: client,
|
||||
baseURL: defaultTavilySearchURL,
|
||||
apiKey: apiKey,
|
||||
apiKey: params.APIKey,
|
||||
}, nil
|
||||
}
|
||||
|
||||
// TavilyProviderInfo returns the provider info for registration
|
||||
func TavilyProviderInfo() types.WebSearchProviderInfo {
|
||||
return types.WebSearchProviderInfo{
|
||||
ID: "tavily",
|
||||
Name: "Tavily",
|
||||
Free: false,
|
||||
RequiresAPIKey: true,
|
||||
Description: "Tavily Search API",
|
||||
}
|
||||
}
|
||||
|
||||
// Name returns the provider name
|
||||
func (p *TavilyProvider) Name() string {
|
||||
return "tavily"
|
||||
@@ -71,6 +61,7 @@ func (p *TavilyProvider) Search(
|
||||
if len(query) == 0 {
|
||||
return nil, fmt.Errorf("query is empty")
|
||||
}
|
||||
logger.Infof(ctx, "[WebSearch][Tavily] query=%q maxResults=%d url=%s", query, maxResults, p.baseURL)
|
||||
|
||||
reqBody := tavilySearchRequest{
|
||||
APIKey: p.apiKey,
|
||||
@@ -97,6 +88,7 @@ func (p *TavilyProvider) Search(
|
||||
|
||||
if resp.StatusCode != http.StatusOK {
|
||||
respBody, _ := io.ReadAll(resp.Body)
|
||||
logger.Warnf(ctx, "[WebSearch][Tavily] API returned status %d: %s", resp.StatusCode, string(respBody))
|
||||
return nil, fmt.Errorf("tavily API returned status %d: %s", resp.StatusCode, string(respBody))
|
||||
}
|
||||
|
||||
@@ -125,6 +117,7 @@ func (p *TavilyProvider) Search(
|
||||
}
|
||||
results = append(results, result)
|
||||
}
|
||||
logger.Infof(ctx, "[WebSearch][Tavily] returned %d results", len(results))
|
||||
return results, nil
|
||||
}
|
||||
|
||||
@@ -53,7 +53,8 @@ type RouterParams struct {
|
||||
InitializationHandler *handler.InitializationHandler
|
||||
SystemHandler *handler.SystemHandler
|
||||
MCPServiceHandler *handler.MCPServiceHandler
|
||||
WebSearchHandler *handler.WebSearchHandler
|
||||
WebSearchHandler *handler.WebSearchHandler
|
||||
WebSearchProviderHandler *handler.WebSearchProviderHandler
|
||||
FAQHandler *handler.FAQHandler
|
||||
TagHandler *handler.TagHandler
|
||||
CustomAgentHandler *handler.CustomAgentHandler
|
||||
@@ -137,6 +138,7 @@ func NewRouter(params RouterParams) *gin.Engine {
|
||||
RegisterSystemRoutes(v1, params.SystemHandler)
|
||||
RegisterMCPServiceRoutes(v1, params.MCPServiceHandler)
|
||||
RegisterWebSearchRoutes(v1, params.WebSearchHandler)
|
||||
RegisterWebSearchProviderRoutes(v1, params.WebSearchProviderHandler)
|
||||
RegisterCustomAgentRoutes(v1, params.CustomAgentHandler)
|
||||
RegisterSkillRoutes(v1, params.SkillHandler)
|
||||
RegisterOrganizationRoutes(v1, params.OrganizationHandler)
|
||||
@@ -477,6 +479,25 @@ func RegisterWebSearchRoutes(r *gin.RouterGroup, webSearchHandler *handler.WebSe
|
||||
}
|
||||
}
|
||||
|
||||
// RegisterWebSearchProviderRoutes registers CRUD routes for web search provider configurations
|
||||
func RegisterWebSearchProviderRoutes(r *gin.RouterGroup, h *handler.WebSearchProviderHandler) {
|
||||
providers := r.Group("/web-search-providers")
|
||||
{
|
||||
// List available provider types (metadata for UI forms)
|
||||
providers.GET("/types", h.ListProviderTypes)
|
||||
// Test with raw credentials (no persistence)
|
||||
providers.POST("/test", h.TestProviderRaw)
|
||||
// CRUD
|
||||
providers.POST("", h.CreateProvider)
|
||||
providers.GET("", h.ListProviders)
|
||||
providers.GET("/:id", h.GetProvider)
|
||||
providers.PUT("/:id", h.UpdateProvider)
|
||||
providers.DELETE("/:id", h.DeleteProvider)
|
||||
// Test existing saved provider
|
||||
providers.POST("/:id/test", h.TestProviderByID)
|
||||
}
|
||||
}
|
||||
|
||||
// RegisterCustomAgentRoutes registers custom agent routes
|
||||
func RegisterCustomAgentRoutes(r *gin.RouterGroup, agentHandler *handler.CustomAgentHandler) {
|
||||
agents := r.Group("/agents")
|
||||
|
||||
@@ -64,7 +64,7 @@ func ConvertWebSearchResults(
|
||||
result := &types.SearchResult{
|
||||
ID: chunkID,
|
||||
Content: content,
|
||||
KnowledgeID: "",
|
||||
KnowledgeID: chunkID, // Use URL as KnowledgeID so each web result stays independent during merge
|
||||
ChunkIndex: 0,
|
||||
KnowledgeTitle: webResult.Title,
|
||||
StartAt: 0,
|
||||
|
||||
@@ -25,6 +25,7 @@ type AgentConfig struct {
|
||||
UseCustomSystemPrompt bool `json:"use_custom_system_prompt"` // Whether to use custom system prompt instead of default
|
||||
WebSearchEnabled bool `json:"web_search_enabled"` // Whether web search tool is enabled
|
||||
WebSearchMaxResults int `json:"web_search_max_results"` // Maximum number of web search results (default: 5)
|
||||
WebSearchProviderID string `json:"web_search_provider_id,omitempty"` // WebSearchProviderEntity ID (resolved from agent config)
|
||||
MultiTurnEnabled bool `json:"multi_turn_enabled"` // Whether multi-turn conversation is enabled
|
||||
HistoryTurns int `json:"history_turns"` // Number of history turns to keep in context
|
||||
SearchTargets SearchTargets `json:"-"` // Pre-computed unified search targets (runtime only)
|
||||
|
||||
@@ -46,9 +46,12 @@ type PipelineRequest struct {
|
||||
ChatModelSupportsVision bool `json:"-"`
|
||||
|
||||
// Misc request-scoped config
|
||||
TenantID uint64 `json:"-"`
|
||||
WebSearchEnabled bool `json:"-"`
|
||||
Language string `json:"-"`
|
||||
TenantID uint64 `json:"-"`
|
||||
WebSearchEnabled bool `json:"-"`
|
||||
WebSearchProviderID string `json:"-"` // Resolved from agent config or tenant default
|
||||
WebFetchEnabled bool `json:"-"` // Auto-fetch full page content for web search results after rerank
|
||||
WebFetchTopN int `json:"-"` // Max pages to fetch (default 3)
|
||||
Language string `json:"-"`
|
||||
}
|
||||
|
||||
// QueryIntent represents the classified intent of a user query.
|
||||
@@ -183,6 +186,9 @@ func (c *ChatManage) Clone() *ChatManage {
|
||||
ChatModelSupportsVision: c.ChatModelSupportsVision,
|
||||
TenantID: c.TenantID,
|
||||
WebSearchEnabled: c.WebSearchEnabled,
|
||||
WebSearchProviderID: c.WebSearchProviderID,
|
||||
WebFetchEnabled: c.WebFetchEnabled,
|
||||
WebFetchTopN: c.WebFetchTopN,
|
||||
Language: c.Language,
|
||||
},
|
||||
PipelineState: PipelineState{
|
||||
@@ -205,6 +211,7 @@ const (
|
||||
CHUNK_SEARCH_PARALLEL EventType = "chunk_search_parallel"
|
||||
ENTITY_SEARCH EventType = "entity_search"
|
||||
CHUNK_RERANK EventType = "chunk_rerank"
|
||||
WEB_FETCH EventType = "web_fetch"
|
||||
CHUNK_MERGE EventType = "chunk_merge"
|
||||
DATA_ANALYSIS EventType = "data_analysis"
|
||||
INTO_CHAT_MESSAGE EventType = "into_chat_message"
|
||||
|
||||
@@ -141,6 +141,13 @@ type CustomAgentConfig struct {
|
||||
WebSearchEnabled bool `yaml:"web_search_enabled" json:"web_search_enabled"`
|
||||
// Maximum web search results
|
||||
WebSearchMaxResults int `yaml:"web_search_max_results" json:"web_search_max_results"`
|
||||
// WebSearchProviderID references a specific WebSearchProviderEntity.
|
||||
// If empty, the tenant's default provider (is_default=true) is used.
|
||||
WebSearchProviderID string `yaml:"web_search_provider_id" json:"web_search_provider_id,omitempty"`
|
||||
// Whether to auto-fetch full page content for reranked web search results
|
||||
WebFetchEnabled bool `yaml:"web_fetch_enabled" json:"web_fetch_enabled"`
|
||||
// Max number of pages to fetch after rerank (default: 3)
|
||||
WebFetchTopN int `yaml:"web_fetch_top_n" json:"web_fetch_top_n,omitempty"`
|
||||
|
||||
// ===== Multi-turn Conversation Settings =====
|
||||
// Whether multi-turn conversation is enabled
|
||||
|
||||
@@ -16,8 +16,9 @@ type WebSearchProvider interface {
|
||||
|
||||
// WebSearchService defines the interface for web search services
|
||||
type WebSearchService interface {
|
||||
// Search performs a web search
|
||||
Search(ctx context.Context, config *types.WebSearchConfig, query string) ([]*types.WebSearchResult, error)
|
||||
// Search performs a web search using the provider entity identified by providerID.
|
||||
// If providerID is empty, it falls back to the deprecated config.Provider field for backward compatibility.
|
||||
Search(ctx context.Context, providerID string, config *types.WebSearchConfig, query string) ([]*types.WebSearchResult, error)
|
||||
// CompressWithRAG performs RAG-based compression using a temporary, hidden knowledge base
|
||||
// The temporary knowledge base is deleted after use. The UI will not list it due to repo filtering.
|
||||
CompressWithRAG(ctx context.Context, sessionID string, tempKBID string, questions []string,
|
||||
|
||||
39
internal/types/interfaces/web_search_provider.go
Normal file
39
internal/types/interfaces/web_search_provider.go
Normal file
@@ -0,0 +1,39 @@
|
||||
package interfaces
|
||||
|
||||
import (
|
||||
"context"
|
||||
|
||||
"github.com/Tencent/WeKnora/internal/types"
|
||||
)
|
||||
|
||||
// WebSearchProviderRepository defines the repository interface for web search provider CRUD
|
||||
type WebSearchProviderRepository interface {
|
||||
// Create creates a new web search provider
|
||||
Create(ctx context.Context, provider *types.WebSearchProviderEntity) error
|
||||
// GetByID retrieves a web search provider by ID within a tenant scope
|
||||
GetByID(ctx context.Context, tenantID uint64, id string) (*types.WebSearchProviderEntity, error)
|
||||
// GetDefault retrieves the default provider (is_default=true) for a tenant, or nil if none.
|
||||
GetDefault(ctx context.Context, tenantID uint64) (*types.WebSearchProviderEntity, error)
|
||||
// List lists all web search providers for a tenant
|
||||
List(ctx context.Context, tenantID uint64) ([]*types.WebSearchProviderEntity, error)
|
||||
// Update updates a web search provider
|
||||
Update(ctx context.Context, provider *types.WebSearchProviderEntity) error
|
||||
// Delete deletes a web search provider (soft delete)
|
||||
Delete(ctx context.Context, tenantID uint64, id string) error
|
||||
// ClearDefault clears the default flag for all providers of a tenant, optionally excluding one
|
||||
ClearDefault(ctx context.Context, tenantID uint64, excludeID string) error
|
||||
}
|
||||
|
||||
// WebSearchProviderService defines the service interface for web search provider management.
|
||||
// Tenant isolation is enforced by the handler layer (getOwned pattern).
|
||||
// Service methods operate on entities whose TenantID is already verified.
|
||||
type WebSearchProviderService interface {
|
||||
// CreateProvider creates a new web search provider.
|
||||
// provider.TenantID must be set by the caller (handler).
|
||||
CreateProvider(ctx context.Context, provider *types.WebSearchProviderEntity) error
|
||||
// UpdateProvider updates an existing provider.
|
||||
// provider.TenantID must be set by the caller (handler) for the repository WHERE clause.
|
||||
UpdateProvider(ctx context.Context, provider *types.WebSearchProviderEntity) error
|
||||
// DeleteProvider deletes a provider by tenant + id.
|
||||
DeleteProvider(ctx context.Context, tenantID uint64, id string) error
|
||||
}
|
||||
@@ -8,8 +8,11 @@ import (
|
||||
|
||||
// WebSearchConfig represents the web search configuration for a tenant
|
||||
type WebSearchConfig struct {
|
||||
Provider string `json:"provider"` // 搜索引擎提供商ID
|
||||
APIKey string `json:"api_key"` // API密钥(如果需要)
|
||||
// Deprecated: Use WebSearchProviderEntity.Parameters.APIKey instead.
|
||||
Provider string `json:"provider,omitempty"`
|
||||
// Deprecated: Use WebSearchProviderEntity.Parameters.APIKey instead.
|
||||
APIKey string `json:"api_key,omitempty"`
|
||||
|
||||
MaxResults int `json:"max_results"` // 最大搜索结果数
|
||||
IncludeDate bool `json:"include_date"` // 是否包含日期
|
||||
CompressionMethod string `json:"compression_method"` // 压缩方法:none, summary, extract, rag
|
||||
|
||||
157
internal/types/web_search_provider.go
Normal file
157
internal/types/web_search_provider.go
Normal file
@@ -0,0 +1,157 @@
|
||||
package types
|
||||
|
||||
import (
|
||||
"database/sql/driver"
|
||||
"encoding/json"
|
||||
"time"
|
||||
|
||||
"github.com/Tencent/WeKnora/internal/utils"
|
||||
"github.com/google/uuid"
|
||||
"gorm.io/gorm"
|
||||
)
|
||||
|
||||
// WebSearchProviderType represents the type of web search provider
|
||||
type WebSearchProviderType string
|
||||
|
||||
const (
|
||||
WebSearchProviderTypeBing WebSearchProviderType = "bing"
|
||||
WebSearchProviderTypeGoogle WebSearchProviderType = "google"
|
||||
WebSearchProviderTypeDuckDuckGo WebSearchProviderType = "duckduckgo"
|
||||
WebSearchProviderTypeTavily WebSearchProviderType = "tavily"
|
||||
)
|
||||
|
||||
// WebSearchProviderEntity represents a configured web search provider instance for a tenant.
|
||||
// This is a CRUD entity stored in the database, similar to the Model entity.
|
||||
// Each tenant can create multiple provider configurations (e.g., "Production Bing", "Test Google").
|
||||
// Agents reference these by ID.
|
||||
type WebSearchProviderEntity struct {
|
||||
// Unique identifier (UUID, auto-generated)
|
||||
ID string `yaml:"id" json:"id" gorm:"type:varchar(36);primaryKey"`
|
||||
// Tenant ID for scoping
|
||||
TenantID uint64 `yaml:"tenant_id" json:"tenant_id"`
|
||||
// User-friendly name, e.g., "Production Bing Search"
|
||||
Name string `yaml:"name" json:"name" gorm:"type:varchar(255);not null"`
|
||||
// Provider type: bing, google, duckduckgo, tavily
|
||||
Provider WebSearchProviderType `yaml:"provider" json:"provider" gorm:"type:varchar(50);not null"`
|
||||
// Description
|
||||
Description string `yaml:"description" json:"description" gorm:"type:text"`
|
||||
// Provider-specific parameters (API key, engine ID, etc.) stored as encrypted JSON
|
||||
Parameters WebSearchProviderParameters `yaml:"parameters" json:"parameters" gorm:"type:json"`
|
||||
// Whether this is the default provider for the tenant
|
||||
IsDefault bool `yaml:"is_default" json:"is_default" gorm:"default:false"`
|
||||
// Timestamps
|
||||
CreatedAt time.Time `yaml:"created_at" json:"created_at"`
|
||||
UpdatedAt time.Time `yaml:"updated_at" json:"updated_at"`
|
||||
DeletedAt gorm.DeletedAt `yaml:"deleted_at" json:"deleted_at" gorm:"index"`
|
||||
}
|
||||
|
||||
// TableName returns the table name for WebSearchProviderEntity
|
||||
func (WebSearchProviderEntity) TableName() string {
|
||||
return "web_search_providers"
|
||||
}
|
||||
|
||||
// BeforeCreate is a GORM hook that runs before creating a new record.
|
||||
// Automatically generates a UUID for new providers.
|
||||
func (e *WebSearchProviderEntity) BeforeCreate(tx *gorm.DB) (err error) {
|
||||
if e.ID == "" {
|
||||
e.ID = uuid.New().String()
|
||||
}
|
||||
return nil
|
||||
}
|
||||
|
||||
// WebSearchProviderParameters holds provider-specific configuration.
|
||||
// API keys are encrypted at rest using AES-GCM.
|
||||
// BaseURL is intentionally NOT included — each provider type uses a hardcoded
|
||||
// official API endpoint to prevent SSRF attacks.
|
||||
type WebSearchProviderParameters struct {
|
||||
// API key for the search provider (encrypted in DB)
|
||||
APIKey string `yaml:"api_key" json:"api_key,omitempty"`
|
||||
// Google Custom Search Engine ID (only for Google provider)
|
||||
EngineID string `yaml:"engine_id" json:"engine_id,omitempty"`
|
||||
// Provider-specific extra configuration for future extensibility
|
||||
ExtraConfig map[string]string `yaml:"extra_config" json:"extra_config,omitempty"`
|
||||
}
|
||||
|
||||
// Value implements the driver.Valuer interface.
|
||||
// Encrypts APIKey before persisting to database.
|
||||
func (p WebSearchProviderParameters) Value() (driver.Value, error) {
|
||||
if key := utils.GetAESKey(); key != nil && p.APIKey != "" {
|
||||
if encrypted, err := utils.EncryptAESGCM(p.APIKey, key); err == nil {
|
||||
p.APIKey = encrypted
|
||||
}
|
||||
}
|
||||
return json.Marshal(p)
|
||||
}
|
||||
|
||||
// Scan implements the sql.Scanner interface.
|
||||
// Decrypts APIKey after loading from database.
|
||||
func (p *WebSearchProviderParameters) Scan(value interface{}) error {
|
||||
if value == nil {
|
||||
return nil
|
||||
}
|
||||
b, ok := value.([]byte)
|
||||
if !ok {
|
||||
return nil
|
||||
}
|
||||
if err := json.Unmarshal(b, p); err != nil {
|
||||
return err
|
||||
}
|
||||
if key := utils.GetAESKey(); key != nil && p.APIKey != "" {
|
||||
if decrypted, err := utils.DecryptAESGCM(p.APIKey, key); err == nil {
|
||||
p.APIKey = decrypted
|
||||
}
|
||||
}
|
||||
return nil
|
||||
}
|
||||
|
||||
// WebSearchProviderTypeInfo describes the metadata of a provider type.
|
||||
// Used by the GET /types endpoint so the frontend can dynamically render forms.
|
||||
type WebSearchProviderTypeInfo struct {
|
||||
// Provider type identifier
|
||||
ID string `json:"id"`
|
||||
// Human-readable name
|
||||
Name string `json:"name"`
|
||||
// Whether the provider requires an API key
|
||||
RequiresAPIKey bool `json:"requires_api_key"`
|
||||
// Whether the provider requires an engine ID (e.g., Google CSE)
|
||||
RequiresEngineID bool `json:"requires_engine_id"`
|
||||
// Description
|
||||
Description string `json:"description"`
|
||||
// URL to the provider's official website or documentation for obtaining credentials
|
||||
DocsURL string `json:"docs_url,omitempty"`
|
||||
}
|
||||
|
||||
// GetWebSearchProviderTypes returns metadata for all supported provider types.
|
||||
func GetWebSearchProviderTypes() []WebSearchProviderTypeInfo {
|
||||
return []WebSearchProviderTypeInfo{
|
||||
{
|
||||
ID: "duckduckgo",
|
||||
Name: "DuckDuckGo",
|
||||
RequiresAPIKey: false,
|
||||
Description: "DuckDuckGo Search (free, no API key required)",
|
||||
DocsURL: "https://duckduckgo.com/",
|
||||
},
|
||||
{
|
||||
ID: "bing",
|
||||
Name: "Bing",
|
||||
RequiresAPIKey: true,
|
||||
Description: "Bing Search API (requires API key from Azure)",
|
||||
DocsURL: "https://learn.microsoft.com/en-us/bing/search-apis/bing-web-search/overview",
|
||||
},
|
||||
{
|
||||
ID: "google",
|
||||
Name: "Google",
|
||||
RequiresAPIKey: true,
|
||||
RequiresEngineID: true,
|
||||
Description: "Google Custom Search API (requires API key and engine ID)",
|
||||
DocsURL: "https://developers.google.com/custom-search/v1/overview",
|
||||
},
|
||||
{
|
||||
ID: "tavily",
|
||||
Name: "Tavily",
|
||||
RequiresAPIKey: true,
|
||||
Description: "Tavily Search API (requires API key)",
|
||||
DocsURL: "https://tavily.com/",
|
||||
},
|
||||
}
|
||||
}
|
||||
@@ -56,23 +56,5 @@ CREATE INDEX IF NOT EXISTS idx_sync_logs_tenant_id ON sync_logs (tenant_id);
|
||||
CREATE INDEX IF NOT EXISTS idx_sync_logs_status ON sync_logs (status);
|
||||
CREATE INDEX IF NOT EXISTS idx_sync_logs_started_at ON sync_logs (started_at);
|
||||
|
||||
-- Trigger function to auto-update updated_at column
|
||||
CREATE OR REPLACE FUNCTION update_updated_at_column()
|
||||
RETURNS TRIGGER AS $$
|
||||
BEGIN
|
||||
NEW.updated_at = CURRENT_TIMESTAMP;
|
||||
RETURN NEW;
|
||||
END;
|
||||
$$ LANGUAGE plpgsql;
|
||||
|
||||
CREATE TRIGGER trg_data_sources_updated_at
|
||||
BEFORE UPDATE ON data_sources
|
||||
FOR EACH ROW
|
||||
EXECUTE FUNCTION update_updated_at_column();
|
||||
|
||||
CREATE TRIGGER trg_sync_logs_updated_at
|
||||
BEFORE UPDATE ON sync_logs
|
||||
FOR EACH ROW
|
||||
EXECUTE FUNCTION update_updated_at_column();
|
||||
|
||||
DO $$ BEGIN RAISE NOTICE '[Migration 000028] data_sources and sync_logs tables created successfully'; END $$;
|
||||
@@ -0,0 +1,3 @@
|
||||
-- Rollback migration: 000029_web_search_providers
|
||||
DROP TRIGGER IF EXISTS trg_web_search_providers_updated_at ON web_search_providers;
|
||||
DROP TABLE IF EXISTS web_search_providers;
|
||||
31
migrations/versioned/000030_web_search_providers.up.sql
Normal file
31
migrations/versioned/000030_web_search_providers.up.sql
Normal file
@@ -0,0 +1,31 @@
|
||||
-- Migration: 000029_web_search_providers
|
||||
-- Description: Create web_search_providers table for tenant-specific search engine configurations
|
||||
DO $$ BEGIN RAISE NOTICE '[Migration 000029] Creating web_search_providers table'; END $$;
|
||||
|
||||
-- Create web_search_providers table for managing tenant search engine configurations
|
||||
-- Each row represents a configured search provider instance (e.g., "Production Bing", "Test Google")
|
||||
-- Agents reference these by ID via custom_agents.config.web_search_provider_id
|
||||
CREATE TABLE IF NOT EXISTS web_search_providers (
|
||||
id VARCHAR(36) NOT NULL PRIMARY KEY,
|
||||
tenant_id BIGINT NOT NULL,
|
||||
name VARCHAR(255) NOT NULL,
|
||||
provider VARCHAR(50) NOT NULL,
|
||||
description TEXT,
|
||||
parameters JSONB,
|
||||
is_default BOOLEAN DEFAULT false,
|
||||
created_at TIMESTAMP DEFAULT CURRENT_TIMESTAMP,
|
||||
updated_at TIMESTAMP DEFAULT CURRENT_TIMESTAMP,
|
||||
deleted_at TIMESTAMP NULL
|
||||
);
|
||||
|
||||
CREATE INDEX IF NOT EXISTS idx_web_search_providers_tenant_id ON web_search_providers (tenant_id);
|
||||
CREATE INDEX IF NOT EXISTS idx_web_search_providers_provider ON web_search_providers (provider);
|
||||
CREATE INDEX IF NOT EXISTS idx_web_search_providers_deleted_at ON web_search_providers (deleted_at);
|
||||
|
||||
-- Auto-update updated_at column (reuses the function from migration 000028)
|
||||
CREATE TRIGGER trg_web_search_providers_updated_at
|
||||
BEFORE UPDATE ON web_search_providers
|
||||
FOR EACH ROW
|
||||
EXECUTE FUNCTION update_updated_at_column();
|
||||
|
||||
DO $$ BEGIN RAISE NOTICE '[Migration 000029] web_search_providers table created successfully'; END $$;
|
||||
Reference in New Issue
Block a user