feat(web-search): refactor web search provider

- Updated relevant files to include provider registration, implementation, and metadata.
- Enhanced frontend components to support provider management and configuration.
- Added localization for new provider settings and messages.
- Implemented backend repository methods for CRUD operations on web search providers.
This commit is contained in:
wizardchen
2026-03-31 20:14:26 +08:00
committed by lyingbug
parent c59ae84157
commit a5ca9a2784
46 changed files with 2311 additions and 1499 deletions

View File

@@ -0,0 +1,267 @@
# 添加新的网络搜索引擎
本文档说明如何在 WeKnora 中添加一个新的网络搜索引擎类型(如 Brave Search、Searx 等)。
## 架构概述
```
internal/
├── types/
│ └── web_search_provider.go # 实体定义 + Provider 类型元数据
├── infrastructure/
│ └── web_search/
│ ├── registry.go # Provider 工厂注册表
│ ├── bing.go # Bing 实现
│ ├── google.go # Google 实现
│ ├── duckduckgo.go # DuckDuckGo 实现
│ └── tavily.go # Tavily 实现
├── container/
│ └── container.go # DI 注册registerWebSearchProviders
└── types/interfaces/
└── web_search.go # WebSearchProvider 接口
```
搜索引擎的 API 端点**硬编码**在代码中,不向用户暴露 BaseURL从源头消除 SSRF 风险。
## 步骤
以添加 **Brave Search** 为例。
### 1. 在 `types/web_search_provider.go` 中注册类型常量
```go
const (
WebSearchProviderTypeBing WebSearchProviderType = "bing"
WebSearchProviderTypeGoogle WebSearchProviderType = "google"
WebSearchProviderTypeDuckDuckGo WebSearchProviderType = "duckduckgo"
WebSearchProviderTypeTavily WebSearchProviderType = "tavily"
WebSearchProviderTypeBrave WebSearchProviderType = "brave" // ← 新增
)
```
### 2. 在 `GetWebSearchProviderTypes()` 中添加类型元数据
```go
{
ID: "brave",
Name: "Brave Search",
Free: false,
RequiresAPIKey: true,
Description: "Brave Search API",
DocsURL: "https://brave.com/search/api/",
},
```
字段说明:
| 字段 | 说明 |
| ---------------- | ---------------------------------------------- |
| `ID` | 唯一标识,存入数据库,不可更改 |
| `Name` | 前端展示名称 |
| `Free` | 是否免费 |
| `RequiresAPIKey` | 是否需要 API Key |
| `RequiresEngineID` | 是否需要额外 ID如 Google CSE |
| `Description` | 简短描述 |
| `DocsURL` | 官方文档链接,前端在添加对话框中显示 |
### 3. 创建 Provider 实现
新建 `internal/infrastructure/web_search/brave.go`
```go
package web_search
import (
"context"
"encoding/json"
"fmt"
"io"
"net/http"
"time"
"github.com/Tencent/WeKnora/internal/types"
"github.com/Tencent/WeKnora/internal/types/interfaces"
)
const defaultBraveSearchURL = "https://api.search.brave.com/res/v1/web/search"
type BraveProvider struct {
client *http.Client
apiKey string
}
// NewBraveProvider 从参数创建实例(不读环境变量)
func NewBraveProvider(params types.WebSearchProviderParameters) (interfaces.WebSearchProvider, error) {
if params.APIKey == "" {
return nil, fmt.Errorf("API key is required for Brave provider")
}
return &BraveProvider{
client: &http.Client{Timeout: 10 * time.Second},
apiKey: params.APIKey,
}, nil
}
func BraveProviderTypeInfo() types.WebSearchProviderTypeInfo {
return types.WebSearchProviderTypeInfo{
ID: "brave",
Name: "Brave Search",
Free: false,
RequiresAPIKey: true,
Description: "Brave Search API",
DocsURL: "https://brave.com/search/api/",
}
}
func (p *BraveProvider) Name() string { return "brave" }
func (p *BraveProvider) Search(
ctx context.Context, query string, maxResults int, includeDate bool,
) ([]*types.WebSearchResult, error) {
// 构造请求 — BaseURL 硬编码
req, err := http.NewRequestWithContext(ctx, "GET",
fmt.Sprintf("%s?q=%s&count=%d", defaultBraveSearchURL, query, maxResults), nil)
if err != nil {
return nil, err
}
req.Header.Set("X-Subscription-Token", p.apiKey)
req.Header.Set("Accept", "application/json")
resp, err := p.client.Do(req)
if err != nil {
return nil, err
}
defer resp.Body.Close()
// 解析响应
body, _ := io.ReadAll(resp.Body)
var data braveResponse
if err := json.Unmarshal(body, &data); err != nil {
return nil, err
}
results := make([]*types.WebSearchResult, 0, len(data.Web.Results))
for _, r := range data.Web.Results {
results = append(results, &types.WebSearchResult{
Title: r.Title,
URL: r.URL,
Snippet: r.Description,
Source: "brave",
})
}
return results, nil
}
type braveResponse struct {
Web struct {
Results []struct {
Title string `json:"title"`
URL string `json:"url"`
Description string `json:"description"`
} `json:"results"`
} `json:"web"`
}
```
**关键要求**
1. **构造函数签名**必须为 `func(types.WebSearchProviderParameters) (interfaces.WebSearchProvider, error)`
2. **API 端点硬编码**为常量,不从参数中读取
3. **实现 `interfaces.WebSearchProvider` 接口**`Name()``Search()`
### 4. 在 Service 的参数校验中添加新类型
编辑 `internal/application/service/web_search_provider.go`
```go
func isValidProviderType(provider types.WebSearchProviderType) bool {
switch provider {
case types.WebSearchProviderTypeBing,
types.WebSearchProviderTypeGoogle,
types.WebSearchProviderTypeDuckDuckGo,
types.WebSearchProviderTypeTavily,
types.WebSearchProviderTypeBrave: // ← 新增
return true
default:
return false
}
}
```
### 5. 在 DI 容器中注册
编辑 `internal/container/container.go``registerWebSearchProviders` 函数:
```go
func registerWebSearchProviders(registry *infra_web_search.Registry) {
// ... 已有注册 ...
// Register Brave provider type
registry.Register(infra_web_search.BraveProviderTypeInfo(), infra_web_search.NewBraveProvider)
}
```
### 6. 验证
```bash
# 编译
go build ./...
# 启动后调用 API 验证类型列表
curl http://localhost:8080/api/v1/web-search-providers/types \
-H 'X-API-Key: your_key'
# 创建 Brave 搜索引擎实例
curl -X POST http://localhost:8080/api/v1/web-search-providers \
-H 'X-API-Key: your_key' \
-H 'Content-Type: application/json' \
-d '{
"name": "Brave Search",
"provider": "brave",
"parameters": { "api_key": "BSA..." },
"is_default": true
}'
```
## 需要额外参数的情况
如果新引擎需要 API Key 以外的参数(类似 Google 的 `engine_id`),有两种方式:
### 方式一:使用 `ExtraConfig`
利用 `WebSearchProviderParameters.ExtraConfig` 字段,不需要改类型定义:
```go
func NewFooProvider(params types.WebSearchProviderParameters) (interfaces.WebSearchProvider, error) {
region := params.ExtraConfig["region"]
if region == "" {
region = "us"
}
// ...
}
```
前端在 `GetWebSearchProviderTypes()` 中可以标注需要哪些 extra 字段(后续支持动态表单渲染)。
### 方式二:添加专用字段
如果参数非常通用(比如多个引擎都需要),可以在 `WebSearchProviderParameters` 中添加新字段:
```go
type WebSearchProviderParameters struct {
APIKey string `json:"api_key,omitempty"`
EngineID string `json:"engine_id,omitempty"`
Region string `json:"region,omitempty"` // ← 新增
ExtraConfig map[string]string `json:"extra_config,omitempty"`
}
```
同时在 `WebSearchProviderTypeInfo` 中添加 `RequiresRegion bool` 等字段,前端根据此动态显示输入框。
## 文件变更清单
| 文件 | 操作 |
| ---- | ---- |
| `internal/types/web_search_provider.go` | 添加常量 + 类型元数据 |
| `internal/infrastructure/web_search/brave.go` | **新建** Provider 实现 |
| `internal/application/service/web_search_provider.go` | `isValidProviderType` 加新类型 |
| `internal/container/container.go` | `registerWebSearchProviders` 注册 |

View File

@@ -0,0 +1,73 @@
import { get, post, put, del } from '@/utils/request'
// WebSearchProviderEntity represents a configured web search provider instance
export interface WebSearchProviderEntity {
id?: string
tenant_id?: number
name: string
provider: 'bing' | 'google' | 'duckduckgo' | 'tavily'
description?: string
parameters: {
api_key?: string
engine_id?: string
extra_config?: Record<string, string>
}
is_default?: boolean
created_at?: string
updated_at?: string
}
// WebSearchProviderTypeInfo describes metadata for a provider type
export interface WebSearchProviderTypeInfo {
id: string
name: string
requires_api_key: boolean
requires_engine_id?: boolean
description?: string
docs_url?: string
}
// Create a new web search provider
export function createWebSearchProvider(data: Partial<WebSearchProviderEntity>) {
return post('/api/v1/web-search-providers', data)
}
// List all web search providers for the current tenant
export function listWebSearchProviders() {
return get('/api/v1/web-search-providers')
}
// Get a single web search provider by ID
export function getWebSearchProvider(id: string) {
return get(`/api/v1/web-search-providers/${id}`)
}
// Update an existing web search provider
export function updateWebSearchProvider(id: string, data: Partial<WebSearchProviderEntity>) {
return put(`/api/v1/web-search-providers/${id}`, data)
}
// Delete a web search provider
export function deleteWebSearchProvider(id: string) {
return del(`/api/v1/web-search-providers/${id}`)
}
// Get available provider types (for dynamic form rendering)
export function listWebSearchProviderTypes(): Promise<WebSearchProviderTypeInfo[]> {
return get('/api/v1/web-search-providers/types').then((res: any) => {
if (res.success && res.data) {
return res.data
}
return []
})
}
// Test a web search provider connection.
// If id is provided, tests the existing saved provider.
// If data is provided, tests with raw credentials (no persistence).
export function testWebSearchProvider(id?: string, data?: { provider: string; parameters: any }): Promise<any> {
if (id) {
return post(`/api/v1/web-search-providers/${id}/test`, {})
}
return post('/api/v1/web-search-providers/test', data || {})
}

View File

@@ -1,6 +1,7 @@
import { get, put } from '@/utils/request'
// WebSearchProviderConfig represents information about a web search provider
// WebSearchProviderConfig represents information about a web search provider type
// Deprecated: Use WebSearchProviderTypeInfo from web-search-provider.ts instead
export interface WebSearchProviderConfig {
id: string
name: string
@@ -12,7 +13,10 @@ export interface WebSearchProviderConfig {
// WebSearchConfig represents the web search configuration for a tenant
export interface WebSearchConfig {
provider: string
// New: references a WebSearchProviderEntity by ID
default_provider_id?: string
// Deprecated: kept for backward compatibility
provider?: string
api_key?: string
max_results: number
include_date: boolean
@@ -24,7 +28,7 @@ export interface WebSearchConfig {
document_fragments?: number
}
// Get web search providers
// Get web search provider types (available engines)
export function getWebSearchProviders() {
return get('/api/v1/web-search/providers')
}
@@ -38,4 +42,3 @@ export function getTenantWebSearchConfig() {
export function updateTenantWebSearchConfig(config: WebSearchConfig) {
return put('/api/v1/tenants/kv/web-search-config', config)
}

View File

@@ -450,7 +450,11 @@ export default {
suggestedPrompts: 'Suggested Prompts',
mode: 'Running Mode',
webSearch: 'Web Search',
webSearchProvider: 'Search Engine',
webSearchProviderPlaceholder: 'Use default search engine',
webSearchMaxResults: 'Max Search Results',
webFetchEnabled: 'Auto-Fetch Page Content',
webFetchTopN: 'Pages to Fetch',
knowledgeBases: 'Knowledge Bases',
allKnowledgeBases: 'All Knowledge Bases',
allKnowledgeBasesDesc: 'Agent can access all knowledge bases',
@@ -699,6 +703,34 @@ export default {
webSearchSettings: {
title: 'Web Search Configuration',
description: 'Configure web search so answers can include up-to-date information from the internet.',
// Provider entity management
providersTitle: 'Search Engine Providers',
addProvider: 'Add Provider',
editProvider: 'Edit Provider',
noProviders: 'No search engine providers configured. Click "Add Provider" to get started.',
deleteConfirm: 'Are you sure you want to delete this provider?',
default: 'Default',
providerNameLabel: 'Name',
providerNamePlaceholder: 'e.g., Production Bing Search',
providerTypeLabel: 'Provider Type',
providerDescLabel: 'Notes',
providerDescPlaceholder: 'Optional, e.g., for testing',
engineIdLabel: 'Engine ID',
setAsDefault: 'Set as default',
testConnection: 'Test Connection',
testing: 'Testing...',
free: 'Free',
viewDocs: 'View docs for API key',
apiKeyUnchanged: 'Leave empty to keep current key',
noDescription: "No description provided",
noProvidersDesc: "Add a web search provider to enable your agents to retrieve real-time information from the internet.",
basicInfo: "Basic Information",
credentials: "Credentials",
setAsDefaultDesc: "This provider will be used by default when an agent doesn't specify one",
// Search behavior
searchBehaviorTitle: 'Search Behavior',
defaultProviderLabel: 'Default Provider',
defaultProviderDescription: 'Select the default search provider for agents that do not specify their own.',
providerLabel: 'Search Provider',
providerDescription: 'Choose the search engine service used for web search',
providerPlaceholder: 'Select a search engine...',
@@ -722,7 +754,12 @@ export default {
toasts: {
loadProvidersFailed: 'Failed to load search providers: {message}',
saveSuccess: 'Web search configuration saved',
saveFailed: 'Failed to save configuration: {message}'
saveFailed: 'Failed to save configuration: {message}',
providerCreated: 'Search provider created',
providerUpdated: 'Search provider updated',
providerDeleted: 'Search provider deleted',
testSuccess: 'Connection test succeeded',
testFailed: 'Connection test failed',
}
},
chatHistorySettings: {
@@ -3178,7 +3215,10 @@ export default {
maxIterations: 'Maximum reasoning steps when the Agent executes tasks',
kbScope: 'Select the scope of knowledge bases accessible to the agent',
webSearch: 'When enabled, the agent can search the internet for information',
webSearchProvider: 'Specify a search engine for this agent. Leave empty to use the default.',
webSearchMaxResults: 'Maximum number of results returned per search',
webFetchEnabled: 'After reranking, auto-fetch full page content from top web results for better answers',
webFetchTopN: 'Maximum number of web pages to fetch after reranking',
retrievalSection: 'Configure knowledge base retrieval and ranking parameters',
queryExpansion: 'Automatically expand query terms to improve recall',
embeddingTopK: 'Maximum number of results from vector retrieval',

View File

@@ -550,6 +550,31 @@ export default {
title: "웹 검색 설정",
description:
"웹 검색 기능을 구성하여 질문에 답변할 때 인터넷에서 실시간 정보를 가져와 지식베이스 내용을 보완합니다",
providersTitle: "검색 엔진 프로바이더 설정",
addProvider: "프로바이더 추가",
editProvider: "프로바이더 편집",
noProviders: "검색 엔진 프로바이더가 설정되지 않았습니다. \"프로바이더 추가\"를 클릭하여 시작하세요.",
deleteConfirm: "이 프로바이더를 삭제하시겠습니까?",
default: "기본",
providerNameLabel: "이름",
providerNamePlaceholder: "예: 프로덕션 Bing 검색",
providerTypeLabel: "프로바이더 유형",
providerDescLabel: "설명",
engineIdLabel: "엔진 ID",
setAsDefault: "기본으로 설정",
free: "무료",
viewDocs: "문서에서 API 키 확인",
apiKeyUnchanged: "현재 키를 유지하려면 비워두세요",
testConnection: "연결 테스트",
testing: "테스트 중...",
noDescription: "설명 없음",
noProvidersDesc: "웹 검색 프로바이더를 추가하여 에이전트가 인터넷에서 실시간 정보를 가져올 수 있도록 합니다.",
basicInfo: "기본 정보",
credentials: "자격 증명",
setAsDefaultDesc: "에이전트가 검색 엔진을 지정하지 않은 경우 이 프로바이더가 기본적으로 사용됩니다",
searchBehaviorTitle: "검색 동작 설정",
defaultProviderLabel: "기본 프로바이더",
defaultProviderDescription: "자체 프로바이더를 지정하지 않은 에이전트의 기본 검색 프로바이더를 선택합니다",
providerLabel: "검색 엔진 프로바이더",
providerDescription: "웹 검색에 사용할 검색 엔진 서비스 선택",
providerPlaceholder: "검색 엔진 선택...",
@@ -575,6 +600,9 @@ export default {
loadProvidersFailed: "검색 엔진 목록 로드 실패: {message}",
saveSuccess: "웹 검색 설정이 저장되었습니다",
saveFailed: "설정 저장 실패: {message}",
providerCreated: "검색 엔진 프로바이더가 생성되었습니다",
providerUpdated: "검색 엔진 프로바이더가 업데이트되었습니다",
providerDeleted: "검색 엔진 프로바이더가 삭제되었습니다",
},
},
chatHistorySettings: {

View File

@@ -672,6 +672,31 @@ export default {
webSearchSettings: {
title: 'Настройки веб-поиска',
description: 'Настройте веб-поиск, чтобы ответы могли включать актуальную информацию из интернета.',
providersTitle: 'Поисковые провайдеры',
addProvider: 'Добавить провайдер',
editProvider: 'Редактировать провайдер',
noProviders: 'Поисковые провайдеры не настроены. Нажмите «Добавить провайдер», чтобы начать.',
deleteConfirm: 'Вы уверены, что хотите удалить этот провайдер?',
default: 'По умолчанию',
providerNameLabel: 'Название',
providerNamePlaceholder: 'Напр., Продакшн Bing Поиск',
providerTypeLabel: 'Тип провайдера',
providerDescLabel: 'Описание',
engineIdLabel: 'ID движка',
setAsDefault: 'Установить по умолчанию',
free: 'Бесплатно',
viewDocs: 'Документация для получения ключа',
apiKeyUnchanged: 'Оставьте пустым, чтобы сохранить текущий ключ',
testConnection: 'Проверить соединение',
testing: 'Тестирование...',
noDescription: "Нет описания",
noProvidersDesc: "Добавьте провайдера веб-поиска, чтобы позволить вашим агентам получать информацию из Интернета в реальном времени.",
basicInfo: "Основная информация",
credentials: "Учетные данные",
setAsDefaultDesc: "Этот провайдер будет использоваться по умолчанию, если агент не укажет свой",
searchBehaviorTitle: 'Поведение поиска',
defaultProviderLabel: 'Провайдер по умолчанию',
defaultProviderDescription: 'Выберите провайдер поиска по умолчанию для агентов, не указавших собственный.',
providerLabel: 'Провайдер поиска',
providerDescription: 'Выберите поисковый сервис, используемый для веб-поиска',
providerPlaceholder: 'Выберите поисковую систему...',
@@ -697,7 +722,10 @@ export default {
toasts: {
loadProvidersFailed: 'Не удалось загрузить список поисковых провайдеров: {message}',
saveSuccess: 'Настройки веб-поиска сохранены',
saveFailed: 'Не удалось сохранить настройки: {message}'
saveFailed: 'Не удалось сохранить настройки: {message}',
providerCreated: 'Поисковый провайдер создан',
providerUpdated: 'Поисковый провайдер обновлён',
providerDeleted: 'Поисковый провайдер удалён',
}
},
chatHistorySettings: {

View File

@@ -558,6 +558,34 @@ export default {
title: "网络搜索配置",
description:
"配置网络搜索功能,在回答问题时可以从互联网获取实时信息补充知识库内容",
// Provider entity management
providersTitle: "搜索引擎配置",
addProvider: "添加搜索引擎",
editProvider: "编辑搜索引擎",
noProviders: "暂无搜索引擎配置,点击「添加搜索引擎」开始配置。",
deleteConfirm: "确定要删除此搜索引擎配置吗?",
default: "默认",
providerNameLabel: "名称",
providerNamePlaceholder: "例如:生产环境 Bing 搜索",
providerTypeLabel: "引擎类型",
providerDescLabel: "备注",
providerDescPlaceholder: "可选,如:测试环境用",
engineIdLabel: "搜索引擎 ID",
setAsDefault: "设为默认",
testConnection: "测试连接",
testing: "测试中...",
free: "免费",
viewDocs: "查看文档获取密钥",
apiKeyUnchanged: "留空保持当前密钥不变",
noDescription: "暂无描述",
noProvidersDesc: "添加一个网络搜索引擎,为您的智能体提供实时的互联网信息检索能力。",
basicInfo: "基础信息",
credentials: "凭证信息",
setAsDefaultDesc: "当智能体没有指定特定的搜索引擎时,将默认使用此配置",
// Search behavior
searchBehaviorTitle: "搜索行为配置",
defaultProviderLabel: "默认搜索引擎",
defaultProviderDescription: "为未指定搜索引擎的智能体选择默认使用的搜索引擎",
providerLabel: "搜索引擎提供商",
providerDescription: "选择用于网络搜索的搜索引擎服务",
providerPlaceholder: "选择搜索引擎...",
@@ -583,6 +611,11 @@ export default {
loadProvidersFailed: "加载搜索引擎列表失败: {message}",
saveSuccess: "网络搜索配置已保存",
saveFailed: "保存配置失败: {message}",
providerCreated: "搜索引擎配置已创建",
providerUpdated: "搜索引擎配置已更新",
providerDeleted: "搜索引擎配置已删除",
testSuccess: "连接测试成功",
testFailed: "连接测试失败",
},
},
chatHistorySettings: {
@@ -1155,7 +1188,11 @@ export default {
suggestedPrompts: "推荐问题",
mode: "运行模式",
webSearch: "网络搜索",
webSearchProvider: "搜索引擎",
webSearchProviderPlaceholder: "使用默认搜索引擎",
webSearchMaxResults: "最大搜索结果数",
webFetchEnabled: "自动抓取页面内容",
webFetchTopN: "抓取页面数",
knowledgeBases: "关联知识库",
allKnowledgeBases: "全部知识库",
allKnowledgeBasesDesc: "智能体可访问所有知识库",
@@ -3176,7 +3213,10 @@ export default {
maxIterations: "Agent 执行任务时的最大推理步骤数",
kbScope: "选择智能体可访问的知识库范围",
webSearch: "启用后智能体可以搜索互联网获取信息",
webSearchProvider: "为此智能体指定搜索引擎,留空则使用默认搜索引擎",
webSearchMaxResults: "每次搜索返回的最大结果数量",
webFetchEnabled: "Rerank 后自动抓取排名靠前的网页完整内容,提升回答质量",
webFetchTopN: "Rerank 后最多抓取几个网页的完整内容",
retrievalSection: "配置知识库检索和排序的参数",
queryExpansion: "自动扩展查询词以提高召回率",
embeddingTopK: "向量检索返回的最大结果数量",

View File

@@ -914,6 +914,32 @@
</div>
</div>
<!-- 网络搜索最大结果数 -->
<div v-if="formData.config.web_search_enabled" class="setting-row">
<div class="setting-info">
<label>{{ $t('agent.editor.webSearchProvider') }}</label>
<p class="desc">{{ $t('agentEditor.desc.webSearchProvider') }}</p>
</div>
<div class="setting-control">
<t-select
v-model="formData.config.web_search_provider_id"
clearable
:placeholder="$t('agent.editor.webSearchProviderPlaceholder')"
style="width: 240px;"
>
<t-option
v-for="p in webSearchProviderList"
:key="p.id"
:value="p.id"
:label="p.name"
>
<span>{{ p.name }}</span>
<t-tag v-if="p.is_default" theme="primary" size="small" style="margin-left: 6px;">{{ $t('common.default') }}</t-tag>
</t-option>
</t-select>
</div>
</div>
<!-- 网络搜索最大结果数 -->
<div v-if="formData.config.web_search_enabled" class="setting-row">
<div class="setting-info">
@@ -927,6 +953,31 @@
</div>
</div>
</div>
<!-- 自动抓取页面内容 -->
<div v-if="formData.config.web_search_enabled" class="setting-row">
<div class="setting-info">
<label>{{ $t('agent.editor.webFetchEnabled') }}</label>
<p class="desc">{{ $t('agentEditor.desc.webFetchEnabled') }}</p>
</div>
<div class="setting-control">
<t-switch v-model="formData.config.web_fetch_enabled" />
</div>
</div>
<!-- 抓取页面数 -->
<div v-if="formData.config.web_search_enabled && formData.config.web_fetch_enabled" class="setting-row">
<div class="setting-info">
<label>{{ $t('agent.editor.webFetchTopN') }}</label>
<p class="desc">{{ $t('agentEditor.desc.webFetchTopN') }}</p>
</div>
<div class="setting-control">
<div class="slider-wrapper">
<t-slider v-model="formData.config.web_fetch_top_n" :min="1" :max="10" />
<span class="slider-value">{{ formData.config.web_fetch_top_n }}</span>
</div>
</div>
</div>
</div>
</div>
@@ -1164,6 +1215,7 @@ import { listModels, type ModelConfig } from '@/api/model';
import { listKnowledgeBases } from '@/api/knowledge-base';
import { listMCPServices, type MCPService } from '@/api/mcp-service';
import { listSkills, type SkillInfo } from '@/api/skill';
import { listWebSearchProviders, type WebSearchProviderEntity } from '@/api/web-search-provider';
import { getAgentConfig, getConversationConfig, getStorageEngineStatus, type StorageEngineStatusItem, type PromptTemplate } from '@/api/system';
import { useUIStore } from '@/stores/ui';
import { useOrganizationStore } from '@/stores/organization';
@@ -1195,6 +1247,7 @@ const saving = ref(false);
const allModels = ref<ModelConfig[]>([]);
const kbOptions = ref<{ label: string; value: string; type?: 'document' | 'faq'; count?: number; shared?: boolean; orgName?: string }[]>([]);
const mcpOptions = ref<{ label: string; value: string }[]>([]);
const webSearchProviderList = ref<WebSearchProviderEntity[]>([]);
const skillOptions = ref<{ name: string; description: string }[]>([]);
// 是否允许启用 Skills取决于后端沙箱是否启用disabled 时为 false未请求前为 false 避免闪显)
const skillsAvailable = ref(false);
@@ -1919,6 +1972,16 @@ const loadDependencies = async () => {
console.warn('Failed to load storage engine status', e);
}
// 加载网络搜索引擎配置列表
try {
const wsRes: any = await listWebSearchProviders();
if (wsRes?.data && Array.isArray(wsRes.data)) {
webSearchProviderList.value = wsRes.data;
}
} catch (e) {
console.warn('Failed to load web search providers', e);
}
// 加载占位符定义(从统一 API
try {
const placeholdersRes = await getPlaceholders();

View File

@@ -6,380 +6,325 @@
</div>
<div class="settings-group">
<!-- 搜索引擎提供商 -->
<div class="setting-row">
<div class="setting-info">
<label>{{ t('webSearchSettings.providerLabel') }}</label>
<p class="desc">{{ t('webSearchSettings.providerDescription') }}</p>
</div>
<div class="setting-control">
<t-select
v-model="localProvider"
:loading="loadingProviders"
filterable
:placeholder="t('webSearchSettings.providerPlaceholder')"
@change="handleProviderChange"
@focus="loadProviders"
style="width: 280px;"
>
<t-option
v-for="provider in providers"
:key="provider.id"
:value="provider.id"
:label="provider.name"
>
<div class="provider-option-wrapper">
<div class="provider-option">
<span class="provider-name">{{ provider.name }}</span>
</div>
</div>
</t-option>
</t-select>
</div>
<div class="section-subheader">
<h3>{{ t('webSearchSettings.providersTitle') }}</h3>
<t-button theme="primary" size="small" @click="openAddDialog">
<template #icon><add-icon /></template>
{{ t('webSearchSettings.addProvider') }}
</t-button>
</div>
<!-- API 密钥 -->
<div v-if="selectedProvider && selectedProvider.requires_api_key" class="setting-row">
<div class="setting-info">
<label>{{ t('webSearchSettings.apiKeyLabel') }}</label>
<p class="desc">{{ t('webSearchSettings.apiKeyDescription') }}</p>
</div>
<div class="setting-control">
<t-input
v-model="localAPIKey"
type="password"
:placeholder="t('webSearchSettings.apiKeyPlaceholder')"
@change="handleAPIKeyChange"
style="width: 400px;"
:show-password="true"
/>
</div>
</div>
<!-- 最大结果数 -->
<div class="setting-row">
<div class="setting-info">
<label>{{ t('webSearchSettings.maxResultsLabel') }}</label>
<p class="desc">{{ t('webSearchSettings.maxResultsDescription') }}</p>
</div>
<div class="setting-control">
<div class="slider-with-value">
<t-slider
v-model="localMaxResults"
:min="1"
:max="50"
:step="1"
:marks="{ 1: '1', 10: '10', 20: '20', 30: '30', 40: '40', 50: '50' }"
@change="handleMaxResultsChange"
style="width: 200px;"
/>
<span class="value-display">{{ localMaxResults }}</span>
<!-- Provider List -->
<div v-if="providerEntities.length > 0" class="provider-list">
<div v-for="entity in providerEntities" :key="entity.id" class="provider-item">
<div class="item-info">
<div class="item-header">
<span class="item-name">{{ entity.name }}</span>
<t-tag v-if="entity.is_default" theme="primary" size="small" variant="light">
{{ t('webSearchSettings.default') }}
</t-tag>
<t-tag size="small" variant="outline">{{ entity.provider }}</t-tag>
</div>
<div class="item-desc">{{ entity.description || t('webSearchSettings.noDescription') }}</div>
</div>
<div class="item-actions">
<t-button theme="default" variant="text" size="small" @click="testExistingConnection(entity)" :loading="testingId === entity.id">
{{ t('webSearchSettings.testConnection') }}
</t-button>
<t-button theme="primary" variant="text" size="small" @click="editProvider(entity)">
{{ t('common.edit') }}
</t-button>
<t-popconfirm :content="t('webSearchSettings.deleteConfirm')" @confirm="deleteProvider(entity.id!)">
<t-button theme="danger" variant="text" size="small">
{{ t('common.delete') }}
</t-button>
</t-popconfirm>
</div>
</div>
</div>
<!-- 包含日期 -->
<div class="setting-row">
<div class="setting-info">
<label>{{ t('webSearchSettings.includeDateLabel') }}</label>
<p class="desc">{{ t('webSearchSettings.includeDateDescription') }}</p>
</div>
<div class="setting-control">
<t-switch
v-model="localIncludeDate"
@change="handleIncludeDateChange"
/>
</div>
</div>
<!-- 压缩方法 -->
<div class="setting-row">
<div class="setting-info">
<label>{{ t('webSearchSettings.compressionLabel') }}</label>
<p class="desc">{{ t('webSearchSettings.compressionDescription') }}</p>
</div>
<div class="setting-control">
<t-select
v-model="localCompressionMethod"
@change="handleCompressionMethodChange"
style="width: 280px;"
:placeholder="t('webSearchSettings.compressionLabel')"
>
<t-option value="none" :label="t('webSearchSettings.compressionNone')">
{{ t('webSearchSettings.compressionNone') }}
</t-option>
<t-option value="llm_summary" :label="t('webSearchSettings.compressionSummary')">
{{ t('webSearchSettings.compressionSummary') }}
</t-option>
</t-select>
</div>
</div>
<!-- 黑名单 -->
<div class="setting-row vertical">
<div class="setting-info">
<label>{{ t('webSearchSettings.blacklistLabel') }}</label>
<p class="desc">{{ t('webSearchSettings.blacklistDescription') }}</p>
</div>
<div class="setting-control">
<t-textarea
v-model="localBlacklistText"
:placeholder="t('webSearchSettings.blacklistPlaceholder')"
:autosize="{ minRows: 4, maxRows: 8 }"
@change="handleBlacklistChange"
style="width: 500px;"
/>
</div>
<!-- Empty State -->
<div v-else class="empty-providers">
<p>{{ t('webSearchSettings.noProvidersDesc') }}</p>
</div>
</div>
<!-- Add/Edit Dialog -->
<t-dialog
v-model:visible="showAddProviderDialog"
:header="editingProvider ? t('webSearchSettings.editProvider') : t('webSearchSettings.addProvider')"
width="520px"
:footer="false"
destroy-on-close
>
<div class="dialog-form-container">
<t-form :data="providerForm" label-align="top" @submit="saveProvider" class="provider-form">
<t-form-item :label="t('webSearchSettings.providerTypeLabel')" name="provider">
<t-select v-model="providerForm.provider" :disabled="!!editingProvider" @change="onProviderTypeChange">
<t-option v-for="pt in providerTypes" :key="pt.id" :value="pt.id" :label="pt.name">
<div class="provider-option">
<span>{{ pt.name }}</span>
<t-tag v-if="pt.free" theme="success" size="small" variant="light">{{ t('webSearchSettings.free') }}</t-tag>
</div>
</t-option>
</t-select>
</t-form-item>
<t-form-item :label="t('webSearchSettings.providerNameLabel')" name="name">
<t-input v-model="providerForm.name" :placeholder="selectedProviderType?.name || t('webSearchSettings.providerNamePlaceholder')" />
</t-form-item>
<t-form-item :label="t('webSearchSettings.providerDescLabel')" name="description">
<t-input v-model="providerForm.description" :placeholder="t('webSearchSettings.providerDescPlaceholder')" />
</t-form-item>
<template v-if="selectedProviderType?.requires_api_key || selectedProviderType?.requires_engine_id">
<div class="form-divider"></div>
<div class="credentials-hint" v-if="selectedProviderType?.docs_url">
<a :href="selectedProviderType.docs_url" target="_blank" rel="noopener noreferrer">
{{ t('webSearchSettings.viewDocs') }}
</a>
</div>
<t-form-item v-if="selectedProviderType?.requires_api_key" :label="t('webSearchSettings.apiKeyLabel')" name="parameters.api_key">
<t-input
v-model="providerForm.parameters.api_key"
type="password"
:placeholder="editingProvider ? t('webSearchSettings.apiKeyUnchanged') : t('webSearchSettings.apiKeyPlaceholder')"
/>
</t-form-item>
<t-form-item v-if="selectedProviderType?.requires_engine_id" :label="t('webSearchSettings.engineIdLabel')" name="parameters.engine_id">
<t-input v-model="providerForm.parameters.engine_id" :placeholder="t('webSearchSettings.engineIdLabel')" />
</t-form-item>
</template>
<div class="form-divider"></div>
<t-form-item :label="t('webSearchSettings.setAsDefault')" name="is_default">
<template #help>
<div class="switch-help">
{{ t('webSearchSettings.setAsDefaultDesc') }}
</div>
</template>
<t-switch v-model="providerForm.is_default" />
</t-form-item>
<div class="dialog-footer">
<div class="footer-left">
<t-button
v-if="selectedProviderType && !selectedProviderType.free"
theme="default"
variant="outline"
:loading="testing"
@click="testConnection"
>
{{ testing ? t('webSearchSettings.testing') : t('webSearchSettings.testConnection') }}
</t-button>
</div>
<div class="footer-right">
<t-button theme="default" variant="base" @click="showAddProviderDialog = false">{{ t('common.cancel') }}</t-button>
<t-button theme="primary" type="submit" :loading="saving">{{ t('common.save') }}</t-button>
</div>
</div>
</t-form>
</div>
</t-dialog>
</div>
</template>
<script setup lang="ts">
import { ref, computed, onMounted, nextTick } from 'vue'
import { ref, computed, onMounted } from 'vue'
import { MessagePlugin } from 'tdesign-vue-next'
import { useI18n } from 'vue-i18n'
import { getWebSearchProviders, getTenantWebSearchConfig, updateTenantWebSearchConfig, type WebSearchProviderConfig, type WebSearchConfig } from '@/api/web-search'
import { AddIcon } from 'tdesign-icons-vue-next'
import {
listWebSearchProviders,
listWebSearchProviderTypes,
createWebSearchProvider,
updateWebSearchProvider,
deleteWebSearchProvider as deleteWebSearchProviderAPI,
testWebSearchProvider,
type WebSearchProviderEntity,
type WebSearchProviderTypeInfo,
} from '@/api/web-search-provider'
const { t } = useI18n()
// 本地状态
const loadingProviders = ref(false)
const providers = ref<WebSearchProviderConfig[]>([])
const localProvider = ref<string>('')
const localAPIKey = ref<string>('')
const localMaxResults = ref<number>(5)
const localIncludeDate = ref<boolean>(true)
const localCompressionMethod = ref<string>('none')
const localBlacklistText = ref<string>('')
const isInitializing = ref(true) // 标记是否正在初始化,初始化期间不触发自动保存
const initialConfig = ref<WebSearchConfig | null>(null) // 保存初始配置,用于比较是否有变化
// ===== State =====
const providerEntities = ref<WebSearchProviderEntity[]>([])
const providerTypes = ref<WebSearchProviderTypeInfo[]>([])
const showAddProviderDialog = ref(false)
const editingProvider = ref<WebSearchProviderEntity | null>(null)
const testing = ref(false)
const testingId = ref<string | null>(null)
const saving = ref(false)
// 计算属性:当前选中的提供商
const selectedProvider = computed(() => {
return providers.value.find(p => p.id === localProvider.value)
const providerForm = ref<{
name: string
provider: string
description: string
parameters: { api_key?: string; engine_id?: string }
is_default: boolean
}>({
name: '',
provider: 'duckduckgo',
description: '',
parameters: {},
is_default: false,
})
// 加载提供商列表
const loadProviders = async () => {
if (providers.value.length > 0) {
return // 已加载过
}
loadingProviders.value = true
// ===== Computed =====
const selectedProviderType = computed(() => {
return providerTypes.value.find(pt => pt.id === providerForm.value.provider)
})
// ===== Methods =====
const onProviderTypeChange = () => {
providerForm.value.parameters = {}
}
const loadProviderEntities = async () => {
try {
const response = await getWebSearchProviders()
// request拦截器已经处理了响应直接使用data字段
const response = await listWebSearchProviders()
if (response.data && Array.isArray(response.data)) {
providers.value = response.data
providerEntities.value = response.data
}
} catch (error: any) {
console.error('Failed to load web search providers:', error)
const errorMessage = error?.message || t('webSearchSettings.errors.unknown')
MessagePlugin.error(t('webSearchSettings.toasts.loadProvidersFailed', { message: errorMessage }))
} finally {
loadingProviders.value = false
} catch (error) {
console.error('Failed to load provider entities:', error)
}
}
// 加载租户配置
const loadTenantConfig = async () => {
const loadProviderTypes = async () => {
try {
const response = await getTenantWebSearchConfig()
// request拦截器已经处理了响应直接使用data字段
if (response.data) {
const config = response.data
// 在设置初始值时,禁用自动保存
isInitializing.value = true
// 保存初始配置的副本(用于后续比较)
const blacklist = (config.blacklist || []).join('\n')
initialConfig.value = {
provider: config.provider || '',
api_key: config.api_key === '***' ? '***' : config.api_key || '',
max_results: config.max_results || 5,
include_date: config.include_date !== undefined ? config.include_date : true,
compression_method: config.compression_method || 'none',
blacklist: config.blacklist || []
}
// 设置本地状态值
localProvider.value = config.provider || ''
// API key 在响应中被隐藏,如果是 "***",说明已配置但未返回实际值
localAPIKey.value = config.api_key === '***' ? '***' : config.api_key || ''
localMaxResults.value = config.max_results || 5
localIncludeDate.value = config.include_date !== undefined ? config.include_date : true
localCompressionMethod.value = config.compression_method || 'none'
localBlacklistText.value = blacklist
// 等待所有响应式更新完成后再启用自动保存
await nextTick()
await nextTick()
// 使用 setTimeout 确保所有事件都已处理完毕
setTimeout(() => {
isInitializing.value = false
}, 100)
} else {
// 如果没有配置数据,保存默认配置
initialConfig.value = {
provider: '',
api_key: '',
max_results: 5,
include_date: true,
compression_method: 'none',
blacklist: []
}
await nextTick()
setTimeout(() => {
isInitializing.value = false
}, 100)
}
} catch (error: any) {
console.error('Failed to load tenant web search config:', error)
// 如果配置不存在,使用默认值(不显示错误)
initialConfig.value = {
provider: '',
providerTypes.value = await listWebSearchProviderTypes()
} catch (error) {
console.error('Failed to load provider types:', error)
}
}
const openAddDialog = () => {
editingProvider.value = null
providerForm.value = {
name: '',
provider: providerTypes.value[0]?.id || 'duckduckgo',
description: '',
parameters: {},
is_default: providerEntities.value.length === 0
}
showAddProviderDialog.value = true
}
const editProvider = (entity: WebSearchProviderEntity) => {
editingProvider.value = entity
providerForm.value = {
name: entity.name,
provider: entity.provider,
description: entity.description || '',
parameters: {
api_key: '',
max_results: 5,
include_date: true,
compression_method: 'none',
blacklist: []
}
await nextTick()
setTimeout(() => {
isInitializing.value = false
}, 100)
engine_id: entity.parameters?.engine_id || '',
},
is_default: entity.is_default || false,
}
showAddProviderDialog.value = true
}
// 检查配置是否有变化
const hasConfigChanged = (): boolean => {
if (!initialConfig.value) {
return true // 如果没有初始配置,认为有变化
}
const blacklist = localBlacklistText.value
.split('\n')
.map(line => line.trim())
.filter(line => line.length > 0)
const currentConfig: WebSearchConfig = {
provider: localProvider.value,
api_key: localAPIKey.value,
max_results: localMaxResults.value,
include_date: localIncludeDate.value,
compression_method: localCompressionMethod.value,
blacklist: blacklist
}
// 比较配置是否有变化(忽略 API key 的 '***' 占位符)
const initial = initialConfig.value
if (currentConfig.provider !== initial.provider) return true
if (currentConfig.api_key !== initial.api_key &&
!(currentConfig.api_key === '***' && initial.api_key === '***')) return true
if (currentConfig.max_results !== initial.max_results) return true
if (currentConfig.include_date !== initial.include_date) return true
if (currentConfig.compression_method !== initial.compression_method) return true
// 比较黑名单数组
const currentBlacklist = blacklist.sort().join(',')
const initialBlacklist = (initial.blacklist || []).sort().join(',')
if (currentBlacklist !== initialBlacklist) return true
return false
}
// 保存配置
const saveConfig = async () => {
// 如果配置没有变化,不保存
if (!hasConfigChanged()) {
const saveProvider = async ({ validateResult, firstError }: any) => {
if (validateResult !== true && validateResult !== undefined) {
MessagePlugin.warning(firstError || 'Please check the form fields')
return
}
saving.value = true
try {
const blacklist = localBlacklistText.value
.split('\n')
.map(line => line.trim())
.filter(line => line.length > 0)
const config: WebSearchConfig = {
provider: localProvider.value,
api_key: localAPIKey.value,
max_results: localMaxResults.value,
include_date: localIncludeDate.value,
compression_method: localCompressionMethod.value,
blacklist: blacklist
const data: Partial<WebSearchProviderEntity> = {
name: providerForm.value.name.trim() || selectedProviderType.value?.name || providerForm.value.provider,
provider: providerForm.value.provider as any,
description: providerForm.value.description,
parameters: { ...providerForm.value.parameters },
is_default: providerForm.value.is_default,
}
await updateTenantWebSearchConfig(config)
// 更新初始配置,避免重复保存
initialConfig.value = {
provider: config.provider,
api_key: config.api_key,
max_results: config.max_results,
include_date: config.include_date,
compression_method: config.compression_method,
blacklist: [...config.blacklist]
if (editingProvider.value && !data.parameters!.api_key) {
delete data.parameters!.api_key
}
MessagePlugin.success(t('webSearchSettings.toasts.saveSuccess'))
if (editingProvider.value) {
await updateWebSearchProvider(editingProvider.value.id!, data)
MessagePlugin.success(t('webSearchSettings.toasts.providerUpdated'))
} else {
await createWebSearchProvider(data)
MessagePlugin.success(t('webSearchSettings.toasts.providerCreated'))
}
showAddProviderDialog.value = false
await loadProviderEntities()
} catch (error: any) {
console.error('Failed to save web search config:', error)
const errorMessage = error?.message || t('webSearchSettings.errors.unknown')
MessagePlugin.error(t('webSearchSettings.toasts.saveFailed', { message: errorMessage }))
throw error
MessagePlugin.error(error?.message || 'Failed to save provider')
} finally {
saving.value = false
}
}
// 防抖保存
let saveTimer: number | null = null
const debouncedSave = () => {
// 初始化期间不触发自动保存
if (isInitializing.value) {
return
const deleteProvider = async (id: string) => {
try {
await deleteWebSearchProviderAPI(id)
MessagePlugin.success(t('webSearchSettings.toasts.providerDeleted'))
await loadProviderEntities()
} catch (error: any) {
MessagePlugin.error(error?.message || 'Failed to delete provider')
}
if (saveTimer) {
clearTimeout(saveTimer)
}
const testConnection = async () => {
testing.value = true
try {
const data = {
provider: providerForm.value.provider,
parameters: { ...providerForm.value.parameters },
}
if (editingProvider.value && !data.parameters.api_key) {
const res = await testWebSearchProvider(editingProvider.value.id!)
if (res.success) {
MessagePlugin.success(t('webSearchSettings.toasts.testSuccess'))
} else {
MessagePlugin.error(res.error || t('webSearchSettings.toasts.testFailed'))
}
} else {
const res = await testWebSearchProvider(undefined, data)
if (res.success) {
MessagePlugin.success(t('webSearchSettings.toasts.testSuccess'))
} else {
MessagePlugin.error(res.error || t('webSearchSettings.toasts.testFailed'))
}
}
} catch (error: any) {
MessagePlugin.error(error?.message || t('webSearchSettings.toasts.testFailed'))
} finally {
testing.value = false
}
saveTimer = window.setTimeout(() => {
saveConfig().catch(() => {
// 错误已在 saveConfig 中处理
})
}, 500)
}
// 处理变化
const handleProviderChange = () => {
debouncedSave()
const testExistingConnection = async (entity: WebSearchProviderEntity) => {
testingId.value = entity.id!
try {
const res = await testWebSearchProvider(entity.id!)
if (res.success) {
MessagePlugin.success(t('webSearchSettings.toasts.testSuccess'))
} else {
MessagePlugin.error(res.error || t('webSearchSettings.toasts.testFailed'))
}
} catch (error: any) {
MessagePlugin.error(error?.message || t('webSearchSettings.toasts.testFailed'))
} finally {
testingId.value = null
}
}
const handleAPIKeyChange = () => {
debouncedSave()
}
const handleMaxResultsChange = () => {
debouncedSave()
}
const handleIncludeDateChange = () => {
debouncedSave()
}
const handleCompressionMethodChange = () => {
debouncedSave()
}
const handleBlacklistChange = () => {
debouncedSave()
}
// 初始化
// ===== Init =====
onMounted(async () => {
isInitializing.value = true
await loadProviders()
await loadTenantConfig()
// loadTenantConfig 内部已经处理了 isInitializing这里不需要再设置
await Promise.all([loadProviderTypes(), loadProviderEntities()])
})
</script>
@@ -409,139 +354,130 @@ onMounted(async () => {
.settings-group {
display: flex;
flex-direction: column;
gap: 0;
}
.setting-row {
.section-subheader {
display: flex;
align-items: flex-start;
align-items: center;
justify-content: space-between;
padding: 20px 0;
border-bottom: 1px solid var(--td-component-stroke);
margin-bottom: 16px;
&:last-child {
border-bottom: none;
}
&.vertical {
flex-direction: column;
gap: 12px;
.setting-control {
width: 100%;
max-width: 100%;
}
}
}
.setting-info {
flex: 1;
max-width: 65%;
padding-right: 24px;
label {
font-size: 15px;
font-weight: 500;
h3 {
font-size: 16px;
font-weight: 600;
color: var(--td-text-color-primary);
display: block;
margin-bottom: 4px;
}
.desc {
font-size: 13px;
color: var(--td-text-color-secondary);
margin: 0;
line-height: 1.5;
}
}
.setting-control {
flex-shrink: 0;
min-width: 280px;
.provider-list {
display: flex;
justify-content: flex-end;
align-items: center;
flex-direction: column;
gap: 8px;
}
.slider-with-value {
.provider-item {
display: flex;
align-items: center;
gap: 12px;
justify-content: space-between;
padding: 14px 16px;
background: var(--td-bg-color-container);
border: 1px solid var(--td-component-stroke);
border-radius: 8px;
transition: all 0.2s ease;
&:hover {
border-color: var(--td-brand-color);
}
}
.value-display {
min-width: 40px;
text-align: right;
.item-info {
display: flex;
flex-direction: column;
gap: 4px;
}
.item-header {
display: flex;
align-items: center;
gap: 8px;
}
.item-name {
font-size: 14px;
font-weight: 500;
color: var(--td-text-color-primary);
}
.provider-option-wrapper {
.item-desc {
font-size: 13px;
color: var(--td-text-color-secondary);
}
.item-actions {
display: flex;
flex-direction: column;
gap: 4px;
padding: 2px 0;
align-items: center;
}
.empty-providers {
padding: 32px;
text-align: center;
color: var(--td-text-color-placeholder);
border: 1px dashed var(--td-component-stroke);
border-radius: 8px;
font-size: 14px;
}
.dialog-form-container {
margin-top: 12px;
}
.provider-option {
display: flex;
align-items: center;
justify-content: space-between;
gap: 8px;
flex-wrap: wrap;
}
.provider-name {
font-weight: 500;
font-size: 14px;
color: var(--td-text-color-primary);
flex-shrink: 0;
}
.provider-tags {
display: flex;
align-items: center;
gap: 4px;
flex-wrap: wrap;
flex-shrink: 0;
width: 100%;
}
.provider-desc {
.form-divider {
height: 1px;
background: var(--td-component-border);
margin: 20px 0;
}
.credentials-hint {
margin-bottom: 12px;
font-size: 13px;
a {
color: var(--td-brand-color);
text-decoration: none;
&:hover {
text-decoration: underline;
}
}
}
.switch-help {
font-size: 12px;
color: var(--td-text-color-placeholder);
color: var(--td-text-color-secondary);
margin-top: 4px;
line-height: 1.4;
margin-top: 2px;
}
/* 修复下拉项描述与条目重叠:让选项支持多行自适应高度 */
:deep(.t-select-option) {
height: auto;
align-items: flex-start;
padding-top: 6px;
padding-bottom: 6px;
}
:deep(.t-select-option__content) {
white-space: normal;
}
</style>
<style lang="less">
.t-select__dropdown .t-select-option {
height: auto;
align-items: flex-start;
padding-top: 6px;
padding-bottom: 6px;
}
.t-select__dropdown .t-select-option__content {
white-space: normal;
}
.t-select__dropdown .provider-option-wrapper {
.dialog-footer {
display: flex;
flex-direction: column;
gap: 4px;
padding: 2px 0;
justify-content: space-between;
align-items: center;
margin-top: 32px;
padding-top: 20px;
border-top: 1px solid var(--td-component-border);
.footer-right {
display: flex;
gap: 12px;
}
}
</style>

View File

@@ -82,6 +82,7 @@ type WebSearchTool struct {
webSearchStateService interfaces.WebSearchStateService
sessionID string
maxResults int
providerID string // WebSearchProviderEntity ID (resolved from agent config or tenant default)
}
// NewWebSearchTool creates a new web search tool
@@ -92,6 +93,7 @@ func NewWebSearchTool(
webSearchStateService interfaces.WebSearchStateService,
sessionID string,
maxResults int,
providerID string,
) *WebSearchTool {
tool := webSearchTool
tool.description = fmt.Sprintf(tool.description, maxResults, maxResults)
@@ -104,6 +106,7 @@ func NewWebSearchTool(
webSearchStateService: webSearchStateService,
sessionID: sessionID,
maxResults: maxResults,
providerID: providerID,
}
}
@@ -150,7 +153,7 @@ func (t *WebSearchTool) Execute(ctx context.Context, args json.RawMessage) (*typ
// Get tenant info from context (same approach as search.go)
tenant := ctx.Value(types.TenantInfoContextKey).(*types.Tenant)
if tenant == nil || tenant.WebSearchConfig == nil || tenant.WebSearchConfig.Provider == "" {
if tenant == nil || tenant.WebSearchConfig == nil {
logger.Errorf(ctx, "[Tool][WebSearch] Web search not configured for tenant %d", tenantID)
return &types.ToolResult{
Success: false,
@@ -158,6 +161,9 @@ func (t *WebSearchTool) Execute(ctx context.Context, args json.RawMessage) (*typ
}, fmt.Errorf("web search is not configured for tenant %d", tenantID)
}
// Resolve provider ID: tool-level (set from agent config, which already resolved default)
resolvedProviderID := t.providerID
// Create a copy of web search config with maxResults from agent config
searchConfig := *tenant.WebSearchConfig
searchConfig.MaxResults = t.maxResults
@@ -165,11 +171,11 @@ func (t *WebSearchTool) Execute(ctx context.Context, args json.RawMessage) (*typ
// Perform web search
logger.Infof(
ctx,
"[Tool][WebSearch] Performing web search with provider: %s, maxResults: %d",
searchConfig.Provider,
"[Tool][WebSearch] Performing web search with providerID: %s, maxResults: %d",
resolvedProviderID,
searchConfig.MaxResults,
)
webResults, err := t.webSearchService.Search(ctx, &searchConfig, query)
webResults, err := t.webSearchService.Search(ctx, resolvedProviderID, &searchConfig, query)
if err != nil {
logger.Errorf(ctx, "[Tool][WebSearch] Web search failed: %v", err)
return &types.ToolResult{

View File

@@ -0,0 +1,89 @@
package repository
import (
"context"
"errors"
"github.com/Tencent/WeKnora/internal/types"
"github.com/Tencent/WeKnora/internal/types/interfaces"
"gorm.io/gorm"
)
// webSearchProviderRepository implements the WebSearchProviderRepository interface
type webSearchProviderRepository struct {
db *gorm.DB
}
// NewWebSearchProviderRepository creates a new web search provider repository
func NewWebSearchProviderRepository(db *gorm.DB) interfaces.WebSearchProviderRepository {
return &webSearchProviderRepository{db: db}
}
// Create creates a new web search provider
func (r *webSearchProviderRepository) Create(ctx context.Context, provider *types.WebSearchProviderEntity) error {
return r.db.WithContext(ctx).Create(provider).Error
}
// GetByID retrieves a web search provider by ID within a tenant scope
func (r *webSearchProviderRepository) GetByID(ctx context.Context, tenantID uint64, id string) (*types.WebSearchProviderEntity, error) {
var provider types.WebSearchProviderEntity
if err := r.db.WithContext(ctx).Where(
"id = ? AND tenant_id = ?", id, tenantID,
).First(&provider).Error; err != nil {
if errors.Is(err, gorm.ErrRecordNotFound) {
return nil, nil
}
return nil, err
}
return &provider, nil
}
// GetDefault retrieves the default provider (is_default=true) for a tenant, or nil if none.
func (r *webSearchProviderRepository) GetDefault(ctx context.Context, tenantID uint64) (*types.WebSearchProviderEntity, error) {
var provider types.WebSearchProviderEntity
if err := r.db.WithContext(ctx).Where(
"tenant_id = ? AND is_default = ?", tenantID, true,
).First(&provider).Error; err != nil {
if errors.Is(err, gorm.ErrRecordNotFound) {
return nil, nil
}
return nil, err
}
return &provider, nil
}
// List lists all web search providers for a tenant
func (r *webSearchProviderRepository) List(ctx context.Context, tenantID uint64) ([]*types.WebSearchProviderEntity, error) {
var providers []*types.WebSearchProviderEntity
if err := r.db.WithContext(ctx).Where(
"tenant_id = ?", tenantID,
).Order("created_at ASC").Find(&providers).Error; err != nil {
return nil, err
}
return providers, nil
}
// Update updates a web search provider
func (r *webSearchProviderRepository) Update(ctx context.Context, provider *types.WebSearchProviderEntity) error {
return r.db.WithContext(ctx).Model(&types.WebSearchProviderEntity{}).Where(
"id = ? AND tenant_id = ?", provider.ID, provider.TenantID,
).Select("*").Updates(provider).Error
}
// Delete soft-deletes a web search provider
func (r *webSearchProviderRepository) Delete(ctx context.Context, tenantID uint64, id string) error {
return r.db.WithContext(ctx).Where(
"id = ? AND tenant_id = ?", id, tenantID,
).Delete(&types.WebSearchProviderEntity{}).Error
}
// ClearDefault clears the default flag for all providers of a tenant, optionally excluding one
func (r *webSearchProviderRepository) ClearDefault(ctx context.Context, tenantID uint64, excludeID string) error {
query := r.db.WithContext(ctx).Model(&types.WebSearchProviderEntity{}).Where(
"tenant_id = ? AND is_default = ?", tenantID, true,
)
if excludeID != "" {
query = query.Where("id != ?", excludeID)
}
return query.Update("is_default", false).Error
}

View File

@@ -404,8 +404,9 @@ func (s *agentService) registerTools(
s.webSearchStateService,
sessionID,
config.WebSearchMaxResults,
config.WebSearchProviderID,
)
logger.Infof(ctx, "Registered web_search tool for session: %s, maxResults: %d", sessionID, config.WebSearchMaxResults)
logger.Infof(ctx, "Registered web_search tool for session: %s, maxResults: %d, providerID: %s", sessionID, config.WebSearchMaxResults, config.WebSearchProviderID)
case tools.ToolWebFetch:
toolToRegister = tools.NewWebFetchTool(chatModel)

View File

@@ -15,14 +15,15 @@ import (
// PluginSearch implements search functionality for chat pipeline
type PluginSearch struct {
knowledgeBaseService interfaces.KnowledgeBaseService
knowledgeService interfaces.KnowledgeService
chunkService interfaces.ChunkService
config *config.Config
webSearchService interfaces.WebSearchService
tenantService interfaces.TenantService
sessionService interfaces.SessionService
webSearchStateService interfaces.WebSearchStateService
knowledgeBaseService interfaces.KnowledgeBaseService
knowledgeService interfaces.KnowledgeService
chunkService interfaces.ChunkService
config *config.Config
webSearchService interfaces.WebSearchService
tenantService interfaces.TenantService
sessionService interfaces.SessionService
webSearchStateService interfaces.WebSearchStateService
webSearchProviderRepo interfaces.WebSearchProviderRepository
}
func NewPluginSearch(eventManager *EventManager,
@@ -34,6 +35,7 @@ func NewPluginSearch(eventManager *EventManager,
tenantService interfaces.TenantService,
sessionService interfaces.SessionService,
webSearchStateService interfaces.WebSearchStateService,
webSearchProviderRepo interfaces.WebSearchProviderRepository,
) *PluginSearch {
res := &PluginSearch{
knowledgeBaseService: knowledgeBaseService,
@@ -44,6 +46,7 @@ func NewPluginSearch(eventManager *EventManager,
tenantService: tenantService,
sessionService: sessionService,
webSearchStateService: webSearchStateService,
webSearchProviderRepo: webSearchProviderRepo,
}
eventManager.Register(res)
return res
@@ -617,18 +620,21 @@ func (p *PluginSearch) searchWebIfEnabled(ctx context.Context, chatManage *types
return nil
}
tenant, _ := types.TenantInfoFromContext(ctx)
if tenant == nil || tenant.WebSearchConfig == nil || tenant.WebSearchConfig.Provider == "" {
if tenant == nil || tenant.WebSearchConfig == nil {
pipelineWarn(ctx, "Search", "web_config_missing", map[string]interface{}{
"tenant_id": chatManage.TenantID,
})
return nil
}
// Use provider ID already resolved by session layer (agent config > tenant default)
providerID := chatManage.WebSearchProviderID
pipelineInfo(ctx, "Search", "web_request", map[string]interface{}{
"tenant_id": chatManage.TenantID,
"provider": tenant.WebSearchConfig.Provider,
"tenant_id": chatManage.TenantID,
"provider_id": providerID,
})
webResults, err := p.webSearchService.Search(ctx, tenant.WebSearchConfig, chatManage.RewriteQuery)
webResults, err := p.webSearchService.Search(ctx, providerID, tenant.WebSearchConfig, chatManage.RewriteQuery)
if err != nil {
pipelineWarn(ctx, "Search", "web_search_error", map[string]interface{}{
"tenant_id": chatManage.TenantID,

View File

@@ -40,6 +40,7 @@ func NewPluginSearchParallel(
tenantService interfaces.TenantService,
sessionService interfaces.SessionService,
webSearchStateService interfaces.WebSearchStateService,
webSearchProviderRepo interfaces.WebSearchProviderRepository,
graphRepository interfaces.RetrieveGraphRepository,
chunkRepository interfaces.ChunkRepository,
knowledgeRepository interfaces.KnowledgeRepository,
@@ -54,6 +55,7 @@ func NewPluginSearchParallel(
tenantService: tenantService,
sessionService: sessionService,
webSearchStateService: webSearchStateService,
webSearchProviderRepo: webSearchProviderRepo,
}
searchEntityPlugin := &PluginSearchEntity{

View File

@@ -0,0 +1,113 @@
package chatpipeline
import (
"context"
"strings"
"sync"
"github.com/Tencent/WeKnora/internal/infrastructure/web_fetch"
"github.com/Tencent/WeKnora/internal/logger"
"github.com/Tencent/WeKnora/internal/types"
)
// PluginWebFetch fetches full page content for reranked web search results.
// It runs between CHUNK_RERANK and CHUNK_MERGE, replacing snippet content
// with the full page text for the top N web results.
type PluginWebFetch struct{}
// NewPluginWebFetch creates and registers a new PluginWebFetch instance
func NewPluginWebFetch(eventManager *EventManager) *PluginWebFetch {
res := &PluginWebFetch{}
eventManager.Register(res)
return res
}
// ActivationEvents returns the event types this plugin handles
func (p *PluginWebFetch) ActivationEvents() []types.EventType {
return []types.EventType{types.WEB_FETCH}
}
// OnEvent handles the WEB_FETCH event
func (p *PluginWebFetch) OnEvent(
ctx context.Context,
eventType types.EventType,
chatManage *types.ChatManage,
next func() *PluginError,
) *PluginError {
if !chatManage.WebFetchEnabled || !chatManage.WebSearchEnabled {
pipelineInfo(ctx, "WebFetch", "skip", map[string]any{"reason": "disabled"})
return next()
}
topN := chatManage.WebFetchTopN
if topN <= 0 {
topN = 3
}
// Find web search results in reranked results
var webResults []*types.SearchResult
for _, r := range chatManage.RerankResult {
if strings.ToLower(r.KnowledgeSource) == "web_search" {
webResults = append(webResults, r)
if len(webResults) >= topN {
break
}
}
}
if len(webResults) == 0 {
pipelineInfo(ctx, "WebFetch", "skip", map[string]any{"reason": "no_web_results"})
return next()
}
logger.Infof(ctx, "[PIPELINE] stage=WebFetch action=start count=%d", len(webResults))
// Fetch in parallel
type fetchResult struct {
idx int
content string
err error
}
results := make([]fetchResult, len(webResults))
var wg sync.WaitGroup
for i, r := range webResults {
wg.Add(1)
go func(idx int, sr *types.SearchResult) {
defer wg.Done()
fetchURL := sr.ID // web search results use URL as ID
if fetchURL == "" {
return
}
content, err := web_fetch.FetchURLContent(ctx, fetchURL)
results[idx] = fetchResult{idx: idx, content: content, err: err}
}(i, r)
}
wg.Wait()
// Replace snippet content with fetched full content
fetchedCount := 0
for _, fr := range results {
if fr.err != nil {
logger.Warnf(ctx, "[PIPELINE] stage=WebFetch action=fetch_failed url=%s err=%v",
webResults[fr.idx].ID, fr.err)
continue
}
if fr.content == "" {
continue
}
// Truncate to reasonable size for LLM context
content := fr.content
if len(content) > 8000 {
content = content[:8000] + "\n...(truncated)"
}
webResults[fr.idx].Content = content
fetchedCount++
}
pipelineInfo(ctx, "WebFetch", "complete", map[string]any{
"fetched": fetchedCount,
"total": len(webResults),
})
return next()
}

View File

@@ -36,9 +36,10 @@ type sessionService struct {
sessionStorage llmcontext.ContextStorage // Session storage
knowledgeService interfaces.KnowledgeService // Service for knowledge operations
chunkService interfaces.ChunkService // Service for chunk operations
webSearchStateRepo interfaces.WebSearchStateService // Service for web search state
kbShareService interfaces.KBShareService // Service for KB sharing operations
memoryService interfaces.MemoryService // Service for memory operations
webSearchStateRepo interfaces.WebSearchStateService // Service for web search state
webSearchProviderRepo interfaces.WebSearchProviderRepository // Repository for web search provider entities
kbShareService interfaces.KBShareService // Service for KB sharing operations
memoryService interfaces.MemoryService // Service for memory operations
}
// NewSessionService creates a new session service instance with all required dependencies
@@ -54,24 +55,26 @@ func NewSessionService(cfg *config.Config,
agentService interfaces.AgentService,
sessionStorage llmcontext.ContextStorage,
webSearchStateRepo interfaces.WebSearchStateService,
webSearchProviderRepo interfaces.WebSearchProviderRepository,
kbShareService interfaces.KBShareService,
memoryService interfaces.MemoryService,
) interfaces.SessionService {
return &sessionService{
cfg: cfg,
sessionRepo: sessionRepo,
messageRepo: messageRepo,
knowledgeBaseService: knowledgeBaseService,
knowledgeService: knowledgeService,
chunkService: chunkService,
modelService: modelService,
tenantService: tenantService,
eventManager: eventManager,
agentService: agentService,
sessionStorage: sessionStorage,
webSearchStateRepo: webSearchStateRepo,
kbShareService: kbShareService,
memoryService: memoryService,
cfg: cfg,
sessionRepo: sessionRepo,
messageRepo: messageRepo,
knowledgeBaseService: knowledgeBaseService,
knowledgeService: knowledgeService,
chunkService: chunkService,
modelService: modelService,
tenantService: tenantService,
eventManager: eventManager,
agentService: agentService,
sessionStorage: sessionStorage,
webSearchStateRepo: webSearchStateRepo,
webSearchProviderRepo: webSearchProviderRepo,
kbShareService: kbShareService,
memoryService: memoryService,
}
}

View File

@@ -209,6 +209,7 @@ func (s *sessionService) buildAgentConfig(
Temperature: customAgent.Config.Temperature,
WebSearchEnabled: customAgent.Config.WebSearchEnabled && req.WebSearchEnabled,
WebSearchMaxResults: customAgent.Config.WebSearchMaxResults,
WebSearchProviderID: customAgent.Config.WebSearchProviderID,
MultiTurnEnabled: customAgent.Config.MultiTurnEnabled,
HistoryTurns: customAgent.Config.HistoryTurns,
MCPSelectionMode: customAgent.Config.MCPSelectionMode,
@@ -247,6 +248,13 @@ func (s *sessionService) buildAgentConfig(
}
}
// Resolve web search provider ID: agent-level > tenant default (is_default=true)
if agentConfig.WebSearchProviderID == "" {
if defaultProvider, err := s.webSearchProviderRepo.GetDefault(ctx, tenantInfo.ID); err == nil && defaultProvider != nil {
agentConfig.WebSearchProviderID = defaultProvider.ID
}
}
logger.Infof(ctx, "Merged agent config from tenant %d and session %s", tenantInfo.ID, req.Session.ID)
// Log knowledge bases if present

View File

@@ -114,6 +114,9 @@ func (s *sessionService) KnowledgeQA(
RewritePromptSystem: s.cfg.Conversation.RewritePromptSystem,
RewritePromptUser: s.cfg.Conversation.RewritePromptUser,
WebSearchEnabled: req.WebSearchEnabled,
WebSearchProviderID: s.resolveWebSearchProviderID(ctx, req, retrievalTenantID),
WebFetchEnabled: s.resolveWebFetchEnabled(req),
WebFetchTopN: s.resolveWebFetchTopN(req),
TenantID: retrievalTenantID,
Images: req.ImageURLs,
VLMModelID: vlmModelID,
@@ -162,6 +165,7 @@ func (s *sessionService) KnowledgeQA(
Add(types.QUERY_UNDERSTAND).
Add(types.CHUNK_SEARCH_PARALLEL).
Add(types.CHUNK_RERANK).
AddIf(req.WebSearchEnabled, types.WEB_FETCH).
Add(types.CHUNK_MERGE).
Add(types.FILTER_TOP_K).
Add(types.DATA_ANALYSIS).
@@ -798,3 +802,35 @@ func (s *sessionService) emitFallbackAnswer(ctx context.Context, chatManage *typ
logger.Infof(ctx, "Fallback answer event emitted successfully")
}
}
// resolveWebSearchProviderID returns the web search provider ID to use for a pipeline request.
// Priority: agent config > tenant default (is_default=true)
func (s *sessionService) resolveWebSearchProviderID(ctx context.Context, req *types.QARequest, tenantID uint64) string {
// 1. Agent-level override
if req.CustomAgent != nil && req.CustomAgent.Config.WebSearchProviderID != "" {
return req.CustomAgent.Config.WebSearchProviderID
}
// 2. Tenant default
if s.webSearchProviderRepo != nil {
if defaultProvider, err := s.webSearchProviderRepo.GetDefault(ctx, tenantID); err == nil && defaultProvider != nil {
return defaultProvider.ID
}
}
return ""
}
// resolveWebFetchEnabled returns whether auto web fetch is enabled for this request.
func (s *sessionService) resolveWebFetchEnabled(req *types.QARequest) bool {
if req.CustomAgent != nil {
return req.CustomAgent.Config.WebFetchEnabled
}
return false
}
// resolveWebFetchTopN returns how many pages to fetch after rerank.
func (s *sessionService) resolveWebFetchTopN(req *types.QARequest) int {
if req.CustomAgent != nil && req.CustomAgent.Config.WebFetchTopN > 0 {
return req.CustomAgent.Config.WebFetchTopN
}
return 3
}

View File

@@ -7,7 +7,7 @@ import (
"strings"
"time"
"github.com/Tencent/WeKnora/internal/application/service/web_search"
infra_web_search "github.com/Tencent/WeKnora/internal/infrastructure/web_search"
"github.com/Tencent/WeKnora/internal/config"
"github.com/Tencent/WeKnora/internal/logger"
"github.com/Tencent/WeKnora/internal/searchutil"
@@ -15,10 +15,117 @@ import (
"github.com/Tencent/WeKnora/internal/types/interfaces"
)
// WebSearchService provides web search functionality
// WebSearchService provides web search functionality.
// It resolves provider configurations from the database and creates provider
// instances on-demand via the infrastructure registry.
type WebSearchService struct {
providers map[string]interfaces.WebSearchProvider
timeout int
registry *infra_web_search.Registry
providerRepo interfaces.WebSearchProviderRepository
timeout int
}
// NewWebSearchService creates a new web search service.
// The registry holds provider type factories; the providerRepo loads tenant-specific configurations.
func NewWebSearchService(
cfg *config.Config,
registry *infra_web_search.Registry,
providerRepo interfaces.WebSearchProviderRepository,
) (interfaces.WebSearchService, error) {
timeout := 10 // default timeout in seconds
if cfg.WebSearch != nil && cfg.WebSearch.Timeout > 0 {
timeout = cfg.WebSearch.Timeout
}
return &WebSearchService{
registry: registry,
providerRepo: providerRepo,
timeout: timeout,
}, nil
}
// Search performs web search using the provider entity identified by providerID.
// If providerID is empty, it falls back to the deprecated config.Provider field for backward compatibility.
func (s *WebSearchService) Search(
ctx context.Context,
providerID string,
config *types.WebSearchConfig,
query string,
) ([]*types.WebSearchResult, error) {
if config == nil {
return nil, fmt.Errorf("web search config is required")
}
// Resolve the provider
searchProvider, err := s.resolveProvider(ctx, providerID, config)
if err != nil {
return nil, err
}
// Set timeout
timeout := time.Duration(s.timeout) * time.Second
if timeout == 0 {
timeout = 10 * time.Second
}
ctx, cancel := context.WithTimeout(ctx, timeout)
defer cancel()
// Perform search
results, err := searchProvider.Search(ctx, query, config.MaxResults, config.IncludeDate)
if err != nil {
return nil, fmt.Errorf("web search failed: %w", err)
}
// Apply blacklist filtering
results = s.filterBlacklist(results, config.Blacklist)
return results, nil
}
// resolveProvider resolves a WebSearchProvider instance from either:
// 1. A provider entity ID (new path) — loads from DB, creates via registry
// 2. The deprecated config.Provider field (backward compatibility) — creates with empty params
func (s *WebSearchService) resolveProvider(
ctx context.Context,
providerID string,
cfg *types.WebSearchConfig,
) (interfaces.WebSearchProvider, error) {
// New path: load provider entity from DB
if providerID != "" {
tenantID, ok := types.TenantIDFromContext(ctx)
if !ok {
return nil, fmt.Errorf("tenant ID not found in context")
}
entity, err := s.providerRepo.GetByID(ctx, tenantID, providerID)
if err != nil {
return nil, fmt.Errorf("failed to load web search provider %s: %w", providerID, err)
}
if entity == nil {
return nil, fmt.Errorf("web search provider not found: %s", providerID)
}
provider, err := s.registry.CreateProvider(string(entity.Provider), entity.Parameters)
if err != nil {
return nil, fmt.Errorf("failed to create provider %s (%s): %w", entity.Name, entity.Provider, err)
}
return provider, nil
}
// Backward compatibility: use the deprecated config.Provider field
if cfg.Provider != "" {
logger.Warnf(ctx, "Using deprecated WebSearchConfig.Provider field: %s. Please migrate to WebSearchProviderEntity.", cfg.Provider)
params := types.WebSearchProviderParameters{
APIKey: cfg.APIKey,
}
provider, err := s.registry.CreateProvider(cfg.Provider, params)
if err != nil {
return nil, fmt.Errorf("web search provider %s is not available: %w", cfg.Provider, err)
}
return provider, nil
}
return nil, fmt.Errorf("no web search provider configured")
}
// CompressWithRAG performs RAG-based compression using a temporary, hidden knowledge base.
@@ -242,72 +349,6 @@ func stripMarker(content string) string {
return content
}
// Search performs web search using the specified provider
// This method implements the interface expected by PluginSearch
func (s *WebSearchService) Search(
ctx context.Context,
config *types.WebSearchConfig,
query string,
) ([]*types.WebSearchResult, error) {
if config == nil {
return nil, fmt.Errorf("web search config is required")
}
provider, ok := s.providers[config.Provider]
if !ok {
return nil, fmt.Errorf("web search provider %s is not available", config.Provider)
}
// Set timeout
timeout := time.Duration(s.timeout) * time.Second
if timeout == 0 {
timeout = 10 * time.Second
}
ctx, cancel := context.WithTimeout(ctx, timeout)
defer cancel()
// Perform search
results, err := provider.Search(ctx, query, config.MaxResults, config.IncludeDate)
if err != nil {
return nil, fmt.Errorf("web search failed: %w", err)
}
// Apply blacklist filtering
results = s.filterBlacklist(results, config.Blacklist)
// Apply compression if needed
if config.CompressionMethod != "none" && config.CompressionMethod != "" {
// Compression will be handled later in the integration layer
// For now, we just return the results
}
return results, nil
}
// NewWebSearchService creates a new web search service
func NewWebSearchService(cfg *config.Config, registry *web_search.Registry) (interfaces.WebSearchService, error) {
timeout := 10 // default timeout
if cfg.WebSearch != nil && cfg.WebSearch.Timeout > 0 {
timeout = cfg.WebSearch.Timeout
}
// Create all registered providers
providers, err := registry.CreateAllProviders()
if err != nil {
return nil, err
}
for id := range providers {
logger.Infof(context.Background(), "Initialized web search provider: %s", id)
}
return &WebSearchService{
providers: providers,
timeout: timeout,
}, nil
}
// filterBlacklist filters results based on blacklist rules
func (s *WebSearchService) filterBlacklist(
results []*types.WebSearchResult,

View File

@@ -1,149 +0,0 @@
package web_search
import (
"context"
"encoding/json"
"net/http"
"net/http/httptest"
"os"
"testing"
"time"
"github.com/stretchr/testify/assert"
"github.com/stretchr/testify/require"
)
func setBingEnv(apiKey string) {
os.Setenv("BING_SEARCH_API_KEY", apiKey)
}
func unsetBingEnv() {
os.Unsetenv("BING_SEARCH_API_KEY")
}
func TestNewBingProvider(t *testing.T) {
setBingEnv("test-api-key")
defer unsetBingEnv()
provider, err := NewBingProvider()
require.NoError(t, err)
assert.NotNil(t, provider)
}
func TestBingProvider_Search(t *testing.T) {
mockResponse := map[string]interface{}{
"_type": "SearchResponse",
"webPages": map[string]interface{}{
"webSearchUrl": "https://www.bing.com/search?q=test",
"totalEstimatedMatches": 1000,
"value": []map[string]interface{}{
{
"id": "result-1",
"name": "Test Result 1",
"url": "https://example.com/1",
"isFamilyFriendly": true,
"displayUrl": "example.com/1",
"snippet": "This is a test snippet 1",
"dateLastCrawled": time.Now().Format(time.RFC3339),
},
{
"id": "result-2",
"name": "Test Result 2",
"url": "https://example.com/2",
"isFamilyFriendly": true,
"displayUrl": "example.com/2",
"snippet": "This is a test snippet 2",
"dateLastCrawled": time.Now().Format(time.RFC3339),
},
},
},
}
server := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
if r.Method != "GET" {
w.WriteHeader(http.StatusMethodNotAllowed)
return
}
if r.Header.Get("Ocp-Apim-Subscription-Key") != "test-api-key" {
w.WriteHeader(http.StatusUnauthorized)
return
}
query := r.URL.Query().Get("q")
if query == "" {
w.WriteHeader(http.StatusBadRequest)
return
}
w.Header().Set("Content-Type", "application/json")
json.NewEncoder(w).Encode(mockResponse)
}))
defer server.Close()
provider := &BingProvider{
client: server.Client(),
baseURL: server.URL,
apiKey: "test-api-key",
}
t.Run("Successful search", func(t *testing.T) {
ctx := context.Background()
results, err := provider.Search(ctx, "test query", 10, true)
require.NoError(t, err)
assert.Len(t, results, 2)
assert.Equal(t, "Test Result 1", results[0].Title)
assert.Equal(t, "https://example.com/1", results[0].URL)
assert.Equal(t, "bing", results[0].Source)
})
t.Run("Empty query", func(t *testing.T) {
ctx := context.Background()
results, err := provider.Search(ctx, "", 10, true)
assert.Error(t, err)
assert.Nil(t, results)
assert.Contains(t, err.Error(), "query is empty")
})
}
func TestBingProvider_Search_Error(t *testing.T) {
server := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
w.WriteHeader(http.StatusInternalServerError)
}))
defer server.Close()
provider := &BingProvider{
client: server.Client(),
baseURL: server.URL,
apiKey: "test-api-key",
}
t.Run("Server error", func(t *testing.T) {
ctx := context.Background()
results, err := provider.Search(ctx, "test query", 10, true)
assert.Error(t, err)
assert.Nil(t, results)
})
}
func TestBingProvider_Search_InvalidJSON(t *testing.T) {
server := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
w.Header().Set("Content-Type", "application/json")
w.Write([]byte("invalid json"))
}))
defer server.Close()
provider := &BingProvider{
client: server.Client(),
baseURL: server.URL,
apiKey: "test-api-key",
}
t.Run("Invalid JSON response", func(t *testing.T) {
ctx := context.Background()
results, err := provider.Search(ctx, "test query", 10, true)
assert.Error(t, err)
assert.Nil(t, results)
assert.Contains(t, err.Error(), "failed to unmarshal response")
})
}

View File

@@ -1,280 +0,0 @@
package web_search
import (
"context"
"encoding/json"
"net/http"
"net/http/httptest"
"net/url"
"strings"
"testing"
"time"
)
// testRoundTripper rewrites outgoing requests that target DuckDuckGo hosts
// to the provided test server, preserving path and query.
type testRoundTripper struct {
base *url.URL
next http.RoundTripper
}
func (t *testRoundTripper) RoundTrip(req *http.Request) (*http.Response, error) {
// Only rewrite requests to duckduckgo hosts used by the provider
if req.URL.Host == "html.duckduckgo.com" || req.URL.Host == "api.duckduckgo.com" {
cloned := *req
u := *req.URL
u.Scheme = t.base.Scheme
u.Host = t.base.Host
// Keep original path; our test server handlers should register for the same paths.
cloned.URL = &u
req = &cloned
}
return t.next.RoundTrip(req)
}
func newTestClient(ts *httptest.Server) *http.Client {
baseURL, _ := url.Parse(ts.URL)
return &http.Client{
Timeout: 5 * time.Second,
Transport: &testRoundTripper{
base: baseURL,
next: http.DefaultTransport,
},
}
}
func TestDuckDuckGoProvider_Name(t *testing.T) {
p, _ := NewDuckDuckGoProvider()
if p.Name() != "duckduckgo" {
t.Fatalf("expected provider name duckduckgo, got %s", p.Name())
}
}
func TestDuckDuckGoProvider(t *testing.T) {
// Minimal HTML page with two results, matching selectors used in searchHTML
html := `
<html>
<body>
<div class="web-result">
<a class="result__a" href="https://duckduckgo.com/l/?uddg=https%3A%2F%2Fexample.com%2Fpage1&rut=">Example One</a>
<div class="result__snippet">Snippet one</div>
</div>
<div class="web-result">
<a class="result__a" href="//duckduckgo.com/l/?uddg=https%3A%2F%2Fexample.org%2Fpage2&rut=">Example Two</a>
<div class="result__snippet">Snippet two</div>
</div>
</body>
</html>`
ts := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
// Provider requests GET https://html.duckduckgo.com/html/?q=...&kl=...
if r.URL.Path == "/html/" {
w.WriteHeader(http.StatusOK)
_, _ = w.Write([]byte(html))
return
}
t.Fatalf("unexpected request path: %s", r.URL.Path)
}))
defer ts.Close()
// Build provider and inject our test client
prov, _ := NewDuckDuckGoProvider()
dp := prov.(*DuckDuckGoProvider)
if dp == nil {
t.Fatalf("failed to build provider")
}
dp.client = newTestClient(ts)
ctx := context.Background()
results, err := dp.Search(ctx, "weknora", 5, false)
if err != nil {
t.Fatalf("unexpected error: %v", err)
}
if len(results) != 2 {
t.Fatalf("expected 2 results, got %d", len(results))
}
if results[0].Title != "Example One" || !strings.HasPrefix(results[0].URL, "https://example.com/") ||
results[0].Snippet != "Snippet one" {
t.Fatalf("unexpected first result: %+v", results[0])
}
if results[1].Title != "Example Two" || !strings.HasPrefix(results[1].URL, "https://example.org/") ||
results[1].Snippet != "Snippet two" {
t.Fatalf("unexpected second result: %+v", results[1])
}
}
func TestDuckDuckGoProvider_Fallback(t *testing.T) {
// Simulate HTML returning non-OK to force API fallback, then a minimal API JSON
apiResp := struct {
AbstractText string `json:"AbstractText"`
AbstractURL string `json:"AbstractURL"`
Heading string `json:"Heading"`
Results []struct {
FirstURL string `json:"FirstURL"`
Text string `json:"Text"`
} `json:"Results"`
}{
AbstractText: "Abstract snippet",
AbstractURL: "https://example.com/abstract",
Heading: "Abstract Heading",
Results: []struct {
FirstURL string `json:"FirstURL"`
Text string `json:"Text"`
}{
{FirstURL: "https://example.net/x", Text: "Title X - Detail X"},
},
}
ts := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
switch r.URL.Path {
case "/html/":
// Force fallback by returning 500
w.WriteHeader(http.StatusInternalServerError)
default:
// API endpoint path "/"
w.Header().Set("Content-Type", "application/json")
enc := json.NewEncoder(w)
_ = enc.Encode(apiResp)
}
}))
defer ts.Close()
prov, _ := NewDuckDuckGoProvider()
dp := prov.(*DuckDuckGoProvider)
if dp == nil {
t.Fatalf("failed to build provider")
}
dp.client = newTestClient(ts)
ctx := context.Background()
results, err := dp.Search(ctx, "weknora", 3, false)
if err != nil {
t.Fatalf("unexpected error: %v", err)
}
if len(results) == 0 {
t.Fatalf("expected some results from API fallback")
}
if results[0].URL != "https://example.com/abstract" || results[0].Title != "Abstract Heading" {
t.Fatalf("unexpected first API result: %+v", results[0])
}
}
// TestDuckDuckGoProvider_Search_Real tests the DuckDuckGo provider against the real DuckDuckGo service.
// This is an integration test that requires network connectivity.
// Run with: go test -v -run TestDuckDuckGoProvider_Search_Real ./internal/application/service/web_search
func TestDuckDuckGoProvider_Search_Real(t *testing.T) {
// Skip if running in CI without network access (optional check)
if testing.Short() {
t.Skip("Skipping real DuckDuckGo integration test in short mode")
}
ctx := context.Background()
provider, err := NewDuckDuckGoProvider()
if err != nil {
t.Fatalf("Failed to create DuckDuckGo provider: %v", err)
}
if provider == nil {
t.Fatalf("failed to build provider")
}
// Test with a simple, general query that should return results
query := "Go programming language"
maxResults := 5
results, err := provider.Search(ctx, query, maxResults, false)
if err != nil {
t.Fatalf("Search failed: %v", err)
}
// Verify we got results
if len(results) == 0 {
t.Fatal("Expected at least one search result, got 0")
}
t.Logf("Received %d results for query: %s", len(results), query)
// Verify result structure
for i, result := range results {
if result == nil {
t.Fatalf("Result[%d]: is nil", i)
}
if result.Title == "" {
t.Errorf("Result[%d]: Title is empty", i)
}
if result.URL == "" {
t.Errorf("Result[%d]: URL is empty", i)
}
if !strings.HasPrefix(result.URL, "http://") && !strings.HasPrefix(result.URL, "https://") {
t.Errorf("Result[%d]: URL is not valid (should start with http:// or https://): %s", i, result.URL)
}
if result.Source != "duckduckgo" {
t.Errorf("Result[%d]: Source should be 'duckduckgo', got '%s'", i, result.Source)
}
t.Logf("Result[%d]: Title=%s, URL=%s, Snippet=%s", i, result.Title, result.URL, result.Snippet)
}
// Verify we don't exceed maxResults
if len(results) > maxResults {
t.Errorf("Got %d results, expected at most %d", len(results), maxResults)
}
// Test with maxResults limit
limitedResults, err := provider.Search(ctx, query, 2, false)
if err != nil {
t.Fatalf("Search with limit failed: %v", err)
}
if len(limitedResults) > 2 {
t.Errorf("Got %d results with maxResults=2, expected at most 2", len(limitedResults))
}
}
// TestDuckDuckGo_SearchChinese tests the DuckDuckGo provider with Chinese query.
// This verifies the Chinese language parameter (kl=cn-zh) works correctly.
func TestDuckDuckGo_SearchChinese(t *testing.T) {
if testing.Short() {
t.Skip("Skipping real DuckDuckGo integration test in short mode")
}
ctx := context.Background()
provider, err := NewDuckDuckGoProvider()
if err != nil {
t.Fatalf("Failed to create DuckDuckGo provider: %v", err)
}
if provider == nil {
t.Fatalf("failed to build provider")
}
// Test with a Chinese query
query := "WeKnora 企业级RAG框架 介绍 文档"
maxResults := 3
results, err := provider.Search(ctx, query, maxResults, false)
if err != nil {
t.Fatalf("Search failed: %v", err)
}
if len(results) == 0 {
t.Log("Warning: No results returned for Chinese query, but this might be expected")
return
}
t.Logf("Received %d results for Chinese query: %s", len(results), query)
// Verify result structure
for i, result := range results {
if result == nil {
t.Fatalf("Result[%d]: is nil", i)
}
if result.Title == "" {
t.Errorf("Result[%d]: Title is empty", i)
}
if result.URL == "" {
t.Errorf("Result[%d]: URL is empty", i)
}
if result.Source != "duckduckgo" {
t.Errorf("Result[%d]: Source should be 'duckduckgo', got '%s'", i, result.Source)
}
t.Logf("Result[%d]: Title=%s, URL=%s", i, result.Title, result.URL)
}
}

View File

@@ -1,276 +0,0 @@
package web_search
import (
"context"
"encoding/json"
"fmt"
"net/http"
"net/http/httptest"
"os"
"strings"
"testing"
)
func setGoogleEnv(apiURL string) {
os.Setenv("GOOGLE_SEARCH_API_URL", apiURL)
}
func unsetGoogleEnv() {
os.Unsetenv("GOOGLE_SEARCH_API_URL")
}
func TestNewGoogleProvider(t *testing.T) {
testCases := []struct {
name string
apiURL string
expected error
}{
{
name: "valid config",
apiURL: "https://customsearch.googleapis.com/customsearch/v1?api_key=test&engine_id=test",
expected: nil,
},
{
name: "missing engine id",
apiURL: "https://customsearch.googleapis.com/customsearch/v1?api_key=test",
expected: fmt.Errorf("engine_id is empty"),
},
{
name: "missing api key",
apiURL: "https://customsearch.googleapis.com/customsearch/v1?engine_id=test",
expected: fmt.Errorf("api_key is empty"),
},
}
for _, tc := range testCases {
t.Run(tc.name, func(t *testing.T) {
setGoogleEnv(tc.apiURL)
defer unsetGoogleEnv()
_, err := NewGoogleProvider()
if tc.expected == nil {
if err != nil {
t.Fatalf("expected no error, got %v", err)
}
} else {
if err == nil {
t.Fatalf("expected error %v, got nil", tc.expected)
}
if !strings.Contains(err.Error(), tc.expected.Error()) {
t.Fatalf("expected error %v, got %v", tc.expected, err)
}
}
})
}
}
func TestGoogleProvider_Name(t *testing.T) {
setGoogleEnv("https://customsearch.googleapis.com/customsearch/v1?api_key=test&engine_id=test")
defer unsetGoogleEnv()
p, err := NewGoogleProvider()
if err != nil {
t.Fatalf("failed to create Google provider: %v", err)
}
if p.Name() != "google" {
t.Fatalf("expected provider name google, got %s", p.Name())
}
}
func TestGoogleProvider_Search(t *testing.T) {
mockResponse := map[string]interface{}{
"items": []map[string]interface{}{
{
"title": "Example Search Result One",
"link": "https://example.com/page1",
"snippet": "This is the first search result snippet describing the content.",
},
{
"title": "Example Search Result Two",
"link": "https://example.org/page2",
"snippet": "This is the second search result snippet with more details.",
},
},
}
ts := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
if r.URL.Path != "/customsearch/v1" {
t.Fatalf("unexpected request path: %s", r.URL.Path)
}
query := r.URL.Query().Get("q")
if query != "weknora" {
t.Fatalf("unexpected query: %s", query)
}
cx := r.URL.Query().Get("cx")
if cx != "test-engine-id" {
t.Fatalf("unexpected engine ID: %s", cx)
}
num := r.URL.Query().Get("num")
if num != "5" {
t.Fatalf("unexpected num parameter: %s", num)
}
w.Header().Set("Content-Type", "application/json")
w.WriteHeader(http.StatusOK)
enc := json.NewEncoder(w)
_ = enc.Encode(mockResponse)
}))
defer ts.Close()
setGoogleEnv(fmt.Sprintf("%s/customsearch/v1?api_key=test-key&engine_id=test-engine-id", ts.URL))
defer unsetGoogleEnv()
prov, err := NewGoogleProvider()
if err != nil {
t.Fatalf("failed to create Google provider: %v", err)
}
gp := prov.(*GoogleProvider)
if gp == nil {
t.Fatalf("failed to cast to GoogleProvider")
}
ctx := context.Background()
results, err := prov.Search(ctx, "weknora", 5, false)
if err != nil {
t.Fatalf("unexpected error: %v", err)
}
if len(results) != 2 {
t.Fatalf("expected 2 results, got %d", len(results))
}
if results[0].Title != "Example Search Result One" ||
results[0].URL != "https://example.com/page1" ||
results[0].Snippet != "This is the first search result snippet describing the content." ||
results[0].Source != "google" {
t.Fatalf("unexpected first result: %+v", results[0])
}
if results[1].Title != "Example Search Result Two" ||
results[1].URL != "https://example.org/page2" ||
results[1].Snippet != "This is the second search result snippet with more details." ||
results[1].Source != "google" {
t.Fatalf("unexpected second result: %+v", results[1])
}
}
func TestGoogleProvider_Search_EmptyQuery(t *testing.T) {
setGoogleEnv("https://customsearch.googleapis.com/customsearch/v1?api_key=test&engine_id=test")
defer unsetGoogleEnv()
prov, err := NewGoogleProvider()
if err != nil {
t.Fatalf("failed to create Google provider: %v", err)
}
ctx := context.Background()
results, err := prov.Search(ctx, "", 5, false)
if err == nil {
t.Fatal("expected error for empty query, got nil")
}
if !strings.Contains(err.Error(), "query is empty") {
t.Fatalf("expected 'query is empty' error, got: %v", err)
}
if results != nil {
t.Fatalf("expected nil results for empty query, got: %v", results)
}
}
func TestGoogleProvider_Search_NoResults(t *testing.T) {
mockResponse := map[string]interface{}{
"items": []map[string]interface{}{},
}
ts := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
w.Header().Set("Content-Type", "application/json")
w.WriteHeader(http.StatusOK)
enc := json.NewEncoder(w)
_ = enc.Encode(mockResponse)
}))
defer ts.Close()
setGoogleEnv(fmt.Sprintf("%s/customsearch/v1?api_key=test-key&engine_id=test-engine-id", ts.URL))
defer unsetGoogleEnv()
prov, err := NewGoogleProvider()
if err != nil {
t.Fatalf("failed to create Google provider: %v", err)
}
ctx := context.Background()
results, err := prov.Search(ctx, "nonexistent", 5, false)
if err != nil {
t.Fatalf("unexpected error: %v", err)
}
if len(results) != 0 {
t.Fatalf("expected 0 results, got %d", len(results))
}
}
func TestGoogleProvider_Search_ErrorResponse(t *testing.T) {
ts := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
w.WriteHeader(http.StatusInternalServerError)
w.Write([]byte("Internal Server Error"))
}))
defer ts.Close()
setGoogleEnv(fmt.Sprintf("%s/customsearch/v1?api_key=test-key&engine_id=test-engine-id", ts.URL))
defer unsetGoogleEnv()
prov, err := NewGoogleProvider()
if err != nil {
t.Fatalf("failed to create Google provider: %v", err)
}
ctx := context.Background()
results, err := prov.Search(ctx, "test", 5, false)
if err == nil {
t.Fatal("expected error for server error response, got nil")
}
if results != nil {
t.Fatalf("expected nil results for error response, got: %v", results)
}
}
func TestGoogleProvider_Search_MaxResults(t *testing.T) {
mockResponse := map[string]interface{}{
"items": []map[string]interface{}{
{"title": "Result 1", "link": "https://example.com/1", "snippet": "Snippet 1"},
{"title": "Result 2", "link": "https://example.com/2", "snippet": "Snippet 2"},
{"title": "Result 3", "link": "https://example.com/3", "snippet": "Snippet 3"},
},
}
ts := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
num := r.URL.Query().Get("num")
if num != "2" {
t.Fatalf("expected num=2, got %s", num)
}
w.Header().Set("Content-Type", "application/json")
w.WriteHeader(http.StatusOK)
enc := json.NewEncoder(w)
_ = enc.Encode(mockResponse)
}))
defer ts.Close()
setGoogleEnv(fmt.Sprintf("%s/customsearch/v1?api_key=test-key&engine_id=test-engine-id", ts.URL))
defer unsetGoogleEnv()
prov, err := NewGoogleProvider()
if err != nil {
t.Fatalf("failed to create Google provider: %v", err)
}
ctx := context.Background()
results, err := prov.Search(ctx, "test", 2, false)
if err != nil {
t.Fatalf("unexpected error: %v", err)
}
if len(results) != 3 {
t.Fatalf("expected 3 results, got %d", len(results))
}
if results[0].Title != "Result 1" || results[1].Title != "Result 2" || results[2].Title != "Result 3" {
t.Fatalf("unexpected results order or content")
}
}

View File

@@ -1,88 +0,0 @@
package web_search
import (
"fmt"
"sync"
"github.com/Tencent/WeKnora/internal/types"
"github.com/Tencent/WeKnora/internal/types/interfaces"
)
// ProviderFactory creates a new web search provider instance
type ProviderFactory func() (interfaces.WebSearchProvider, error)
// ProviderRegistration holds provider metadata and factory
type ProviderRegistration struct {
Info types.WebSearchProviderInfo
Factory ProviderFactory
}
// Registry manages web search provider registrations
type Registry struct {
providers map[string]*ProviderRegistration
mu sync.RWMutex
}
// NewRegistry creates a new web search provider registry
func NewRegistry() *Registry {
return &Registry{
providers: make(map[string]*ProviderRegistration),
}
}
// Register registers a web search provider
func (r *Registry) Register(info types.WebSearchProviderInfo, factory ProviderFactory) {
r.mu.Lock()
defer r.mu.Unlock()
r.providers[info.ID] = &ProviderRegistration{
Info: info,
Factory: factory,
}
}
// GetRegistration returns the registration for a provider
func (r *Registry) GetRegistration(id string) (*ProviderRegistration, bool) {
r.mu.RLock()
defer r.mu.RUnlock()
reg, ok := r.providers[id]
return reg, ok
}
// GetAllProviderInfos returns info for all registered providers
func (r *Registry) GetAllProviderInfos() []types.WebSearchProviderInfo {
r.mu.RLock()
defer r.mu.RUnlock()
infos := make([]types.WebSearchProviderInfo, 0, len(r.providers))
for _, reg := range r.providers {
infos = append(infos, reg.Info)
}
return infos
}
// CreateProvider creates a provider instance by ID
func (r *Registry) CreateProvider(id string) (interfaces.WebSearchProvider, error) {
r.mu.RLock()
reg, ok := r.providers[id]
r.mu.RUnlock()
if !ok {
return nil, fmt.Errorf("web search provider %s not registered", id)
}
return reg.Factory()
}
// CreateAllProviders creates instances of all registered providers
func (r *Registry) CreateAllProviders() (map[string]interfaces.WebSearchProvider, error) {
r.mu.RLock()
defer r.mu.RUnlock()
providers := make(map[string]interfaces.WebSearchProvider)
for id, reg := range r.providers {
provider, err := reg.Factory()
if err != nil {
// Skip providers that fail to initialize (e.g., missing API keys)
continue
}
providers[id] = provider
}
return providers, nil
}

View File

@@ -0,0 +1,108 @@
package service
import (
"context"
"fmt"
"github.com/Tencent/WeKnora/internal/logger"
"github.com/Tencent/WeKnora/internal/types"
"github.com/Tencent/WeKnora/internal/types/interfaces"
)
// webSearchProviderService implements interfaces.WebSearchProviderService
type webSearchProviderService struct {
repo interfaces.WebSearchProviderRepository
}
// NewWebSearchProviderService creates a new web search provider service
func NewWebSearchProviderService(repo interfaces.WebSearchProviderRepository) interfaces.WebSearchProviderService {
return &webSearchProviderService{repo: repo}
}
// CreateProvider creates a new web search provider configuration.
func (s *webSearchProviderService) CreateProvider(ctx context.Context, provider *types.WebSearchProviderEntity) error {
if provider.TenantID == 0 {
return fmt.Errorf("tenant ID is required")
}
if !isValidProviderType(provider.Provider) {
return fmt.Errorf("invalid provider type: %s", provider.Provider)
}
if err := validateProviderParameters(provider.Provider, provider.Parameters); err != nil {
return err
}
if provider.IsDefault {
if err := s.repo.ClearDefault(ctx, provider.TenantID, ""); err != nil {
logger.Warnf(ctx, "Failed to clear default providers: %v", err)
}
}
logger.Infof(ctx, "Creating web search provider: tenant=%d, name=%s, type=%s", provider.TenantID, provider.Name, provider.Provider)
return s.repo.Create(ctx, provider)
}
// UpdateProvider updates an existing provider.
func (s *webSearchProviderService) UpdateProvider(ctx context.Context, provider *types.WebSearchProviderEntity) error {
if provider.TenantID == 0 {
return fmt.Errorf("tenant ID is required")
}
// Validate provider type if set
if provider.Provider != "" && !isValidProviderType(provider.Provider) {
return fmt.Errorf("invalid provider type: %s", provider.Provider)
}
if provider.IsDefault {
if err := s.repo.ClearDefault(ctx, provider.TenantID, provider.ID); err != nil {
logger.Warnf(ctx, "Failed to clear default providers: %v", err)
}
}
logger.Infof(ctx, "Updating web search provider: tenant=%d, id=%s", provider.TenantID, provider.ID)
return s.repo.Update(ctx, provider)
}
// DeleteProvider deletes a provider by tenant + id.
func (s *webSearchProviderService) DeleteProvider(ctx context.Context, tenantID uint64, id string) error {
logger.Infof(ctx, "Deleting web search provider: tenant=%d, id=%s", tenantID, id)
return s.repo.Delete(ctx, tenantID, id)
}
// isValidProviderType checks if the given provider type is supported
func isValidProviderType(provider types.WebSearchProviderType) bool {
switch provider {
case types.WebSearchProviderTypeBing,
types.WebSearchProviderTypeGoogle,
types.WebSearchProviderTypeDuckDuckGo,
types.WebSearchProviderTypeTavily:
return true
default:
return false
}
}
// validateProviderParameters validates required parameters for each provider type
func validateProviderParameters(provider types.WebSearchProviderType, params types.WebSearchProviderParameters) error {
switch provider {
case types.WebSearchProviderTypeBing:
if params.APIKey == "" {
return fmt.Errorf("API key is required for Bing provider")
}
case types.WebSearchProviderTypeGoogle:
if params.APIKey == "" {
return fmt.Errorf("API key is required for Google provider")
}
if params.EngineID == "" {
return fmt.Errorf("engine ID is required for Google provider")
}
case types.WebSearchProviderTypeTavily:
if params.APIKey == "" {
return fmt.Errorf("API key is required for Tavily provider")
}
case types.WebSearchProviderTypeDuckDuckGo:
// No API key required
}
return nil
}

View File

@@ -47,7 +47,7 @@ import (
"github.com/Tencent/WeKnora/internal/application/service/llmcontext"
memoryService "github.com/Tencent/WeKnora/internal/application/service/memory"
"github.com/Tencent/WeKnora/internal/application/service/retriever"
"github.com/Tencent/WeKnora/internal/application/service/web_search"
infra_web_search "github.com/Tencent/WeKnora/internal/infrastructure/web_search"
"github.com/Tencent/WeKnora/internal/config"
"github.com/Tencent/WeKnora/internal/database"
"github.com/Tencent/WeKnora/internal/datasource"
@@ -180,9 +180,11 @@ func BuildContainer(container *dig.Container) *dig.Container {
// Web search service (needed by AgentService)
logger.Debugf(ctx, "[Container] Registering web search registry and providers...")
must(container.Provide(web_search.NewRegistry))
must(container.Provide(infra_web_search.NewRegistry))
must(container.Invoke(registerWebSearchProviders))
must(container.Provide(repository.NewWebSearchProviderRepository))
must(container.Provide(service.NewWebSearchService))
must(container.Provide(service.NewWebSearchProviderService))
// Agent service layer (requires event bus, web search service)
// SessionService is passed as parameter to CreateAgentEngine method when creating AgentService
@@ -219,6 +221,7 @@ func BuildContainer(container *dig.Container) *dig.Container {
must(container.Provide(chatpipeline.NewEventManager))
must(container.Invoke(chatpipeline.NewPluginSearch))
must(container.Invoke(chatpipeline.NewPluginRerank))
must(container.Invoke(chatpipeline.NewPluginWebFetch))
must(container.Invoke(chatpipeline.NewPluginMerge))
must(container.Invoke(chatpipeline.NewPluginDataAnalysis))
must(container.Invoke(chatpipeline.NewPluginIntoChatMessage))
@@ -250,6 +253,7 @@ func BuildContainer(container *dig.Container) *dig.Container {
must(container.Provide(handler.NewSystemHandler))
must(container.Provide(handler.NewMCPServiceHandler))
must(container.Provide(handler.NewWebSearchHandler))
must(container.Provide(handler.NewWebSearchProviderHandler))
must(container.Provide(handler.NewCustomAgentHandler))
must(container.Provide(service.NewSkillService))
must(container.Provide(handler.NewSkillHandler))
@@ -951,27 +955,14 @@ func NewDuckDB() (*sql.DB, error) {
return sqlDB, nil
}
// registerWebSearchProviders registers all web search providers to the registry
func registerWebSearchProviders(registry *web_search.Registry) {
// Register DuckDuckGo provider
registry.Register(web_search.DuckDuckGoProviderInfo(), func() (interfaces.WebSearchProvider, error) {
return web_search.NewDuckDuckGoProvider()
})
// Register Google provider
registry.Register(web_search.GoogleProviderInfo(), func() (interfaces.WebSearchProvider, error) {
return web_search.NewGoogleProvider()
})
// Register Bing provider
registry.Register(web_search.BingProviderInfo(), func() (interfaces.WebSearchProvider, error) {
return web_search.NewBingProvider()
})
// Register Tavily provider
registry.Register(web_search.TavilyProviderInfo(), func() (interfaces.WebSearchProvider, error) {
return web_search.NewTavilyProvider()
})
// registerWebSearchProviders registers all web search provider types to the registry.
// Each provider type is registered with its factory function that accepts parameters.
// Provider instances are created on-demand when tenants configure them.
func registerWebSearchProviders(registry *infra_web_search.Registry) {
registry.Register("duckduckgo", infra_web_search.NewDuckDuckGoProvider)
registry.Register("google", infra_web_search.NewGoogleProvider)
registry.Register("bing", infra_web_search.NewBingProvider)
registry.Register("tavily", infra_web_search.NewTavilyProvider)
}
// registerIMAdapterFactories registers adapter factories for each IM platform

View File

@@ -3,42 +3,22 @@ package handler
import (
"net/http"
"github.com/Tencent/WeKnora/internal/application/service/web_search"
"github.com/Tencent/WeKnora/internal/logger"
"github.com/Tencent/WeKnora/internal/types"
"github.com/gin-gonic/gin"
)
// WebSearchHandler handles web search related requests
type WebSearchHandler struct {
registry *web_search.Registry
}
// WebSearchHandler handles legacy web search related requests
type WebSearchHandler struct{}
// NewWebSearchHandler creates a new web search handler
func NewWebSearchHandler(registry *web_search.Registry) *WebSearchHandler {
return &WebSearchHandler{
registry: registry,
}
func NewWebSearchHandler() *WebSearchHandler {
return &WebSearchHandler{}
}
// GetProviders returns the list of available web search providers
// @Summary Get available web search providers
// @Description Returns the list of available web search providers from configuration
// @Tags web-search
// @Accept json
// @Produce json
// @Success 200 {object} map[string]interface{} "List of providers"
// @Security Bearer
// @Security ApiKeyAuth
// @Router /web-search/providers [get]
// GetProviders returns the list of available web search provider types
func (h *WebSearchHandler) GetProviders(c *gin.Context) {
ctx := c.Request.Context()
logger.Info(ctx, "Getting web search providers")
providers := h.registry.GetAllProviderInfos()
logger.Infof(ctx, "Returning %d web search providers", len(providers))
c.JSON(http.StatusOK, gin.H{
"success": true,
"data": providers,
"data": types.GetWebSearchProviderTypes(),
})
}

View File

@@ -0,0 +1,319 @@
package handler
import (
"context"
"fmt"
"net/http"
"github.com/Tencent/WeKnora/internal/errors"
infra_web_search "github.com/Tencent/WeKnora/internal/infrastructure/web_search"
"github.com/Tencent/WeKnora/internal/logger"
"github.com/Tencent/WeKnora/internal/types"
"github.com/Tencent/WeKnora/internal/types/interfaces"
secutils "github.com/Tencent/WeKnora/internal/utils"
"github.com/gin-gonic/gin"
)
// WebSearchProviderHandler handles HTTP requests for web search provider CRUD
type WebSearchProviderHandler struct {
repo interfaces.WebSearchProviderRepository
service interfaces.WebSearchProviderService
registry *infra_web_search.Registry
}
// NewWebSearchProviderHandler creates a new handler
func NewWebSearchProviderHandler(
repo interfaces.WebSearchProviderRepository,
service interfaces.WebSearchProviderService,
registry *infra_web_search.Registry,
) *WebSearchProviderHandler {
return &WebSearchProviderHandler{repo: repo, service: service, registry: registry}
}
// --- request DTOs ---
// CreateProviderRequest defines the request body for creating a provider
type CreateProviderRequest struct {
Name string `json:"name" binding:"required"`
Provider types.WebSearchProviderType `json:"provider" binding:"required"`
Description string `json:"description"`
Parameters types.WebSearchProviderParameters `json:"parameters"`
IsDefault bool `json:"is_default"`
}
// UpdateProviderRequest defines the request body for updating a provider
type UpdateProviderRequest struct {
Name string `json:"name"`
Description string `json:"description"`
Parameters types.WebSearchProviderParameters `json:"parameters"`
IsDefault bool `json:"is_default"`
}
// --- helpers ---
// getTenantID extracts tenant ID from gin context (set by auth middleware).
func (h *WebSearchProviderHandler) getTenantID(c *gin.Context) uint64 {
return c.GetUint64(types.TenantIDContextKey.String())
}
// getOwnedProvider loads a provider and verifies it belongs to the given tenant.
// Returns (nil, status, msg) on failure so callers can respond immediately.
func (h *WebSearchProviderHandler) getOwnedProvider(
ctx context.Context, tenantID uint64, id string,
) (*types.WebSearchProviderEntity, int, string) {
provider, err := h.repo.GetByID(ctx, tenantID, id)
if err != nil {
return nil, http.StatusInternalServerError, "failed to query provider"
}
if provider == nil {
return nil, http.StatusNotFound, "web search provider not found"
}
return provider, http.StatusOK, ""
}
// --- endpoints ---
// CreateProvider creates a new web search provider
func (h *WebSearchProviderHandler) CreateProvider(c *gin.Context) {
ctx := c.Request.Context()
tenantID := h.getTenantID(c)
if tenantID == 0 {
c.JSON(http.StatusUnauthorized, gin.H{"success": false, "error": "unauthorized: tenant context missing"})
return
}
var req CreateProviderRequest
if err := c.ShouldBindJSON(&req); err != nil {
logger.Warnf(ctx, "Invalid create provider request: %v", err)
c.Error(errors.NewBadRequestError(err.Error()))
return
}
logger.Infof(ctx, "Creating web search provider: tenant=%d, name=%s, type=%s",
tenantID, secutils.SanitizeForLog(req.Name), secutils.SanitizeForLog(string(req.Provider)))
provider := &types.WebSearchProviderEntity{
TenantID: tenantID,
Name: secutils.SanitizeForLog(req.Name),
Provider: req.Provider,
Description: secutils.SanitizeForLog(req.Description),
Parameters: req.Parameters,
IsDefault: req.IsDefault,
}
if err := h.service.CreateProvider(ctx, provider); err != nil {
logger.Warnf(ctx, "Failed to create web search provider: %v", err)
c.Error(errors.NewInternalServerError(err.Error()))
return
}
c.JSON(http.StatusCreated, gin.H{
"success": true,
"data": provider,
})
}
// ListProviders lists all web search providers for the current tenant
func (h *WebSearchProviderHandler) ListProviders(c *gin.Context) {
ctx := c.Request.Context()
tenantID := h.getTenantID(c)
if tenantID == 0 {
c.JSON(http.StatusUnauthorized, gin.H{"success": false, "error": "unauthorized: tenant context missing"})
return
}
providers, err := h.repo.List(ctx, tenantID)
if err != nil {
logger.Warnf(ctx, "Failed to list web search providers: %v", err)
c.Error(errors.NewInternalServerError(err.Error()))
return
}
c.JSON(http.StatusOK, gin.H{
"success": true,
"data": providers,
})
}
// GetProvider retrieves a single web search provider by ID
func (h *WebSearchProviderHandler) GetProvider(c *gin.Context) {
ctx := c.Request.Context()
tenantID := h.getTenantID(c)
if tenantID == 0 {
c.JSON(http.StatusUnauthorized, gin.H{"success": false, "error": "unauthorized: tenant context missing"})
return
}
id := c.Param("id")
provider, status, msg := h.getOwnedProvider(ctx, tenantID, id)
if status != http.StatusOK {
c.JSON(status, gin.H{"success": false, "error": msg})
return
}
c.JSON(http.StatusOK, gin.H{
"success": true,
"data": provider,
})
}
// UpdateProvider updates a web search provider
func (h *WebSearchProviderHandler) UpdateProvider(c *gin.Context) {
ctx := c.Request.Context()
tenantID := h.getTenantID(c)
if tenantID == 0 {
c.JSON(http.StatusUnauthorized, gin.H{"success": false, "error": "unauthorized: tenant context missing"})
return
}
id := c.Param("id")
// Ownership check
existing, status, msg := h.getOwnedProvider(ctx, tenantID, id)
if status != http.StatusOK {
c.JSON(status, gin.H{"success": false, "error": msg})
return
}
var req UpdateProviderRequest
if err := c.ShouldBindJSON(&req); err != nil {
c.Error(errors.NewBadRequestError(err.Error()))
return
}
// Build updated entity, keeping immutable fields from existing
provider := &types.WebSearchProviderEntity{
ID: id,
TenantID: tenantID,
Name: secutils.SanitizeForLog(req.Name),
Provider: existing.Provider, // Provider type is immutable after creation
Description: secutils.SanitizeForLog(req.Description),
Parameters: req.Parameters,
IsDefault: req.IsDefault,
}
if err := h.service.UpdateProvider(ctx, provider); err != nil {
logger.Warnf(ctx, "Failed to update web search provider %s: %v", id, err)
c.Error(errors.NewInternalServerError(err.Error()))
return
}
// Re-fetch to get the full stored state
updated, _ := h.repo.GetByID(ctx, tenantID, id)
if updated != nil {
c.JSON(http.StatusOK, gin.H{"success": true, "data": updated})
} else {
c.JSON(http.StatusOK, gin.H{"success": true})
}
}
// DeleteProvider deletes a web search provider
func (h *WebSearchProviderHandler) DeleteProvider(c *gin.Context) {
ctx := c.Request.Context()
tenantID := h.getTenantID(c)
if tenantID == 0 {
c.JSON(http.StatusUnauthorized, gin.H{"success": false, "error": "unauthorized: tenant context missing"})
return
}
id := c.Param("id")
// Ownership check
if _, status, msg := h.getOwnedProvider(ctx, tenantID, id); status != http.StatusOK {
c.JSON(status, gin.H{"success": false, "error": msg})
return
}
if err := h.service.DeleteProvider(ctx, tenantID, id); err != nil {
logger.Warnf(ctx, "Failed to delete web search provider %s: %v", id, err)
c.Error(errors.NewInternalServerError(err.Error()))
return
}
c.JSON(http.StatusOK, gin.H{"success": true})
}
// ListProviderTypes returns available provider types and their parameter requirements
func (h *WebSearchProviderHandler) ListProviderTypes(c *gin.Context) {
c.JSON(http.StatusOK, gin.H{
"success": true,
"data": types.GetWebSearchProviderTypes(),
})
}
// TestProviderByID tests an existing saved provider by performing a sample search
func (h *WebSearchProviderHandler) TestProviderByID(c *gin.Context) {
ctx := c.Request.Context()
tenantID := h.getTenantID(c)
if tenantID == 0 {
c.JSON(http.StatusUnauthorized, gin.H{"success": false, "error": "unauthorized: tenant context missing"})
return
}
id := c.Param("id")
provider, status, msg := h.getOwnedProvider(ctx, tenantID, id)
if status != http.StatusOK {
c.JSON(status, gin.H{"success": false, "error": msg})
return
}
if err := h.doTestSearch(ctx, string(provider.Provider), provider.Parameters); err != nil {
logger.Warnf(ctx, "Web search provider test failed: %v", err)
c.JSON(http.StatusOK, gin.H{"success": false, "error": err.Error()})
return
}
c.JSON(http.StatusOK, gin.H{"success": true})
}
// TestProviderRequest defines the body for testing raw credentials
type TestProviderRequest struct {
Provider string `json:"provider" binding:"required"`
Parameters types.WebSearchProviderParameters `json:"parameters"`
}
// TestProviderRaw tests a provider with raw credentials (no persistence)
func (h *WebSearchProviderHandler) TestProviderRaw(c *gin.Context) {
ctx := c.Request.Context()
var req TestProviderRequest
if err := c.ShouldBindJSON(&req); err != nil {
c.Error(errors.NewBadRequestError(err.Error()))
return
}
if err := h.doTestSearch(ctx, req.Provider, req.Parameters); err != nil {
logger.Warnf(ctx, "Web search provider test failed: %v", err)
c.JSON(http.StatusOK, gin.H{"success": false, "error": err.Error()})
return
}
c.JSON(http.StatusOK, gin.H{"success": true})
}
// doTestSearch creates a temporary provider and runs a simple test query
func (h *WebSearchProviderHandler) doTestSearch(ctx context.Context, providerType string, params types.WebSearchProviderParameters) error {
logger.Infof(ctx, "[WebSearch][Test] testing provider type=%s", providerType)
searchProvider, err := h.registry.CreateProvider(providerType, params)
if err != nil {
logger.Warnf(ctx, "[WebSearch][Test] failed to create provider: %v", err)
return fmt.Errorf("failed to create provider: %w", err)
}
results, err := searchProvider.Search(ctx, "test", 1, false)
if err != nil {
logger.Warnf(ctx, "[WebSearch][Test] search failed: %v", err)
return err
}
if len(results) == 0 {
logger.Warnf(ctx, "[WebSearch][Test] search returned 0 results — API key or configuration may be invalid")
return fmt.Errorf("search returned 0 results, please verify your API key and configuration")
}
logger.Infof(ctx, "[WebSearch][Test] succeeded: type=%s, results=%d", providerType, len(results))
return nil
}

View File

@@ -0,0 +1,171 @@
// Package web_fetch provides a public URL content fetcher with SSRF protection.
// It extracts core logic from the agent WebFetchTool so it can be used by the chat pipeline.
package web_fetch
import (
"context"
"crypto/tls"
"fmt"
"io"
"net"
"net/http"
"net/url"
"strings"
"time"
"github.com/PuerkitoBio/goquery"
"github.com/Tencent/WeKnora/internal/logger"
"github.com/Tencent/WeKnora/internal/utils"
)
const (
fetchTimeout = 15 * time.Second
maxBodySize = 100 * 1024 // 100KB
)
// FetchURLContent fetches a URL and returns its text content (HTML converted to clean text).
// Includes SSRF validation, DNS pinning, browser-like headers, and content size limits.
func FetchURLContent(ctx context.Context, rawURL string) (string, error) {
if rawURL == "" {
return "", fmt.Errorf("url is empty")
}
// SSRF validation
if safe, reason := utils.IsSSRFSafeURL(rawURL); !safe {
return "", fmt.Errorf("URL rejected: %s", reason)
}
u, err := url.Parse(rawURL)
if err != nil {
return "", fmt.Errorf("invalid URL: %w", err)
}
hostname := u.Hostname()
port := u.Port()
if port == "" {
if u.Scheme == "https" {
port = "443"
} else {
port = "80"
}
}
// DNS pinning: resolve once, use pinned IP
ips, err := net.DefaultResolver.LookupIP(context.Background(), "ip", hostname)
if err != nil || len(ips) == 0 {
return "", fmt.Errorf("DNS lookup failed for %s: %w", hostname, err)
}
var pinnedIP net.IP
for _, ip := range ips {
if utils.IsPublicIP(ip) {
pinnedIP = ip
break
}
}
if pinnedIP == nil {
return "", fmt.Errorf("no public IP for host %s", hostname)
}
// Build request with pinned IP
hostPort := net.JoinHostPort(pinnedIP.String(), port)
fetchURL := *u
fetchURL.Host = hostPort
ctx, cancel := context.WithTimeout(ctx, fetchTimeout)
defer cancel()
req, err := http.NewRequestWithContext(ctx, http.MethodGet, fetchURL.String(), nil)
if err != nil {
return "", err
}
req.Host = hostname
// Browser-like headers to reduce 403 rejections.
// These match a real Chrome browser fingerprint.
req.Header.Set("User-Agent",
"Mozilla/5.0 (Macintosh; Intel Mac OS X 10_15_7) AppleWebKit/537.36 (KHTML, like Gecko) Chrome/131.0.0.0 Safari/537.36")
req.Header.Set("Accept",
"text/html,application/xhtml+xml,application/xml;q=0.9,image/avif,image/webp,image/apng,*/*;q=0.8,application/signed-exchange;v=b3;q=0.7")
req.Header.Set("Accept-Language", "zh-CN,zh;q=0.9,en;q=0.8,en-GB;q=0.7,en-US;q=0.6")
req.Header.Set("Accept-Encoding", "identity") // no gzip to simplify reading
req.Header.Set("Cache-Control", "no-cache")
req.Header.Set("Pragma", "no-cache")
req.Header.Set("Sec-Ch-Ua", `"Chromium";v="131", "Not_A Brand";v="24"`)
req.Header.Set("Sec-Ch-Ua-Mobile", "?0")
req.Header.Set("Sec-Ch-Ua-Platform", `"macOS"`)
req.Header.Set("Sec-Fetch-Dest", "document")
req.Header.Set("Sec-Fetch-Mode", "navigate")
req.Header.Set("Sec-Fetch-Site", "none")
req.Header.Set("Sec-Fetch-User", "?1")
req.Header.Set("Upgrade-Insecure-Requests", "1")
req.Header.Set("Referer", u.Scheme+"://"+hostname+"/")
// Custom transport: TLS ServerName for certificate validation with pinned IP.
client := &http.Client{
Timeout: fetchTimeout,
Transport: &http.Transport{
TLSClientConfig: &tls.Config{
ServerName: hostname,
},
},
// Follow redirects (default behavior), up to 10 hops
}
resp, err := client.Do(req)
if err != nil {
return "", fmt.Errorf("fetch failed: %w", err)
}
defer resp.Body.Close()
if resp.StatusCode != http.StatusOK {
return "", fmt.Errorf("HTTP %d %s", resp.StatusCode, resp.Status)
}
body, err := io.ReadAll(io.LimitReader(resp.Body, maxBodySize))
if err != nil {
return "", fmt.Errorf("read failed: %w", err)
}
text := htmlToText(string(body))
logger.Infof(ctx, "[WebFetch] fetched %s → %d chars", rawURL, len(text))
return text, nil
}
// htmlToText extracts clean text from HTML, removing scripts/styles/nav.
func htmlToText(html string) string {
doc, err := goquery.NewDocumentFromReader(strings.NewReader(html))
if err != nil {
return stripTags(html)
}
doc.Find("script, style, nav, footer, header, iframe, noscript, svg, img").Remove()
var sb strings.Builder
doc.Find("body").Each(func(i int, s *goquery.Selection) {
sb.WriteString(s.Text())
})
text := sb.String()
// Normalize whitespace: collapse blank lines
lines := strings.Split(text, "\n")
var cleaned []string
for _, line := range lines {
line = strings.TrimSpace(line)
if line != "" {
cleaned = append(cleaned, line)
}
}
return strings.Join(cleaned, "\n")
}
func stripTags(s string) string {
var sb strings.Builder
inTag := false
for _, r := range s {
if r == '<' {
inTag = true
} else if r == '>' {
inTag = false
} else if !inTag {
sb.WriteRune(r)
}
}
return sb.String()
}

View File

@@ -7,22 +7,21 @@ import (
"io"
"net/http"
"net/url"
"os"
"strconv"
"time"
"github.com/Tencent/WeKnora/internal/logger"
"github.com/Tencent/WeKnora/internal/types"
"github.com/Tencent/WeKnora/internal/types/interfaces"
)
const (
// defaultBingSearchURL is the default Bing search API URL.
// Reference: https://learn.microsoft.com/en-us/previous-versions/bing/search-apis/bing-web-search/reference/endpoints
// defaultBingSearchURL is the hardcoded Bing search API URL.
// Not configurable by tenants — prevents SSRF.
defaultBingSearchURL = "https://api.bing.microsoft.com/v7.0/search"
)
var (
// defaultUserAgentHeader for PC. https://learn.microsoft.com/en-us/previous-versions/bing/search-apis/bing-web-search/reference/headers
defaultUserAgentHeader = "Mozilla/5.0 (Macintosh; Intel Mac OS X 10_15_7) AppleWebKit/537.36 (KHTML, like Gecko) Chrome/139.0.0.0 Safari/537.36"
defaultBingTimeout = 10 * time.Second
)
@@ -50,33 +49,21 @@ type BingProvider struct {
apiKey string
}
// NewBingProvider creates a new Bing provider
func NewBingProvider() (interfaces.WebSearchProvider, error) {
apiKey := os.Getenv("BING_SEARCH_API_KEY")
if len(apiKey) == 0 {
return nil, fmt.Errorf("BING_SEARCH_API_KEY is not set")
// NewBingProvider creates a new Bing provider from parameters (no environment variables).
func NewBingProvider(params types.WebSearchProviderParameters) (interfaces.WebSearchProvider, error) {
if params.APIKey == "" {
return nil, fmt.Errorf("API key is required for Bing provider")
}
client := &http.Client{
Timeout: defaultBingTimeout,
}
return &BingProvider{
client: client,
baseURL: defaultBingSearchURL,
apiKey: apiKey,
baseURL: defaultBingSearchURL, // Hardcoded — not tenant-configurable
apiKey: params.APIKey,
}, nil
}
// BingProviderInfo returns the provider info for registration
func BingProviderInfo() types.WebSearchProviderInfo {
return types.WebSearchProviderInfo{
ID: "bing",
Name: "Bing",
Free: false,
RequiresAPIKey: true,
Description: "Bing Search API",
}
}
// Name returns the provider name
func (p *BingProvider) Name() string {
return "bing"
@@ -92,11 +79,18 @@ func (p *BingProvider) Search(
if len(query) == 0 {
return nil, fmt.Errorf("query is empty")
}
logger.Infof(ctx, "[WebSearch][Bing] query=%q maxResults=%d url=%s", query, maxResults, p.baseURL)
req, err := p.buildParams(ctx, query, maxResults, includeDate)
if err != nil {
return nil, err
}
return p.doSearch(ctx, req)
results, err := p.doSearch(ctx, req)
if err != nil {
logger.Warnf(ctx, "[WebSearch][Bing] failed: %v", err)
return nil, err
}
logger.Infof(ctx, "[WebSearch][Bing] returned %d results", len(results))
return results, nil
}
func (p *BingProvider) doSearch(ctx context.Context, req *http.Request) ([]*types.WebSearchResult, error) {
@@ -110,6 +104,12 @@ func (p *BingProvider) doSearch(ctx context.Context, req *http.Request) ([]*type
if err != nil {
return nil, err
}
if resp.StatusCode != http.StatusOK {
logger.Warnf(ctx, "[WebSearch][Bing] API returned status %d: %s", resp.StatusCode, string(body))
return nil, fmt.Errorf("bing API returned status %d: %s", resp.StatusCode, string(body))
}
var respData bingSearchResponse
if err := json.Unmarshal(body, &respData); err != nil {
return nil, fmt.Errorf("failed to unmarshal response: %w", err)
@@ -128,7 +128,6 @@ func (p *BingProvider) doSearch(ctx context.Context, req *http.Request) ([]*type
}
// bingSearchResponse defines the response structure for Bing search API.
// ref: https://learn.microsoft.com/en-us/previous-versions/bing/search-apis/bing-web-search/quickstarts/rest/go
type bingSearchResponse struct {
Type string `json:"_type"`
QueryContext struct {
@@ -183,8 +182,6 @@ type bingSearchResponse struct {
} `json:"rankingResponse"`
}
// buildParams builds the request parameters for Bing search API.
// ref: https://learn.microsoft.com/en-us/previous-versions/bing/search-apis/bing-web-search/quickstarts/rest/go
func (p *BingProvider) buildParams(ctx context.Context, query string, maxResults int, includeDate bool) (*http.Request, error) {
params := url.Values{}
params.Set("q", query)

View File

@@ -22,8 +22,9 @@ type DuckDuckGoProvider struct {
client *http.Client
}
// NewDuckDuckGoProvider creates a new DuckDuckGo provider
func NewDuckDuckGoProvider() (interfaces.WebSearchProvider, error) {
// NewDuckDuckGoProvider creates a new DuckDuckGo provider.
// DuckDuckGo is free and requires no API key or configuration.
func NewDuckDuckGoProvider(params types.WebSearchProviderParameters) (interfaces.WebSearchProvider, error) {
return &DuckDuckGoProvider{
client: &http.Client{
Timeout: 30 * time.Second,
@@ -31,17 +32,6 @@ func NewDuckDuckGoProvider() (interfaces.WebSearchProvider, error) {
}, nil
}
// DuckDuckGoProviderInfo returns the provider info for registration
func DuckDuckGoProviderInfo() types.WebSearchProviderInfo {
return types.WebSearchProviderInfo{
ID: "duckduckgo",
Name: "DuckDuckGo",
Free: true,
RequiresAPIKey: false,
Description: "DuckDuckGo Search API",
}
}
// Name returns the provider name
func (p *DuckDuckGoProvider) Name() string {
return "duckduckgo"
@@ -82,7 +72,6 @@ func (p *DuckDuckGoProvider) searchHTML(
baseURL := "https://html.duckduckgo.com/html/"
params := url.Values{}
params.Set("q", query)
// Prefer Chinese results if applicable; otherwise DDG will auto-detect
params.Set("kl", "cn-zh")
reqURL := baseURL + "?" + params.Encode()
@@ -90,13 +79,11 @@ func (p *DuckDuckGoProvider) searchHTML(
if err != nil {
return nil, fmt.Errorf("failed to create request: %w", err)
}
// Use a realistic UA to avoid blocks
req.Header.Set(
"User-Agent",
"Mozilla/5.0 (Windows NT 10.0; Win64; x64) AppleWebKit/537.36 (KHTML, like Gecko) Chrome/120.0.0.0 Safari/537.36",
)
// print curl of request
curlCommand := fmt.Sprintf(
"curl -X GET '%s' -H 'User-Agent: Mozilla/5.0 (Windows NT 10.0; Win64; x64) AppleWebKit/537.36 (KHTML, like Gecko) Chrome/120.0.0.0 Safari/537.36'",
req.URL.String(),
@@ -119,7 +106,6 @@ func (p *DuckDuckGoProvider) searchHTML(
}
results := make([]*types.WebSearchResult, 0, maxResults)
// Structure based on DDG HTML page
doc.Find(".web-result").Each(func(i int, s *goquery.Selection) {
if len(results) >= maxResults {
return

View File

@@ -3,12 +3,11 @@ package web_search
import (
"context"
"fmt"
"net/url"
"os"
"google.golang.org/api/customsearch/v1"
"google.golang.org/api/option"
"github.com/Tencent/WeKnora/internal/logger"
"github.com/Tencent/WeKnora/internal/types"
"github.com/Tencent/WeKnora/internal/types/interfaces"
)
@@ -18,54 +17,32 @@ type GoogleProvider struct {
srv *customsearch.Service
apiKey string
engineID string
baseURL string
}
// NewGoogleProvider creates a new Google provider
func NewGoogleProvider() (interfaces.WebSearchProvider, error) {
apiURL := os.Getenv("GOOGLE_SEARCH_API_URL")
if apiURL == "" {
return nil, fmt.Errorf("GOOGLE_SEARCH_API_URL environment variable is not set")
// NewGoogleProvider creates a new Google provider from parameters (no environment variables).
// The API endpoint is the official Google Custom Search endpoint — not tenant-configurable.
func NewGoogleProvider(params types.WebSearchProviderParameters) (interfaces.WebSearchProvider, error) {
if params.APIKey == "" {
return nil, fmt.Errorf("API key is required for Google provider")
}
if params.EngineID == "" {
return nil, fmt.Errorf("engine ID is required for Google provider")
}
u, err := url.Parse(apiURL)
if err != nil {
return nil, err
clientOpts := []option.ClientOption{
option.WithAPIKey(params.APIKey),
}
engineID := u.Query().Get("engine_id")
if engineID == "" {
return nil, fmt.Errorf("engine_id is empty")
}
apiKey := u.Query().Get("api_key")
if apiKey == "" {
return nil, fmt.Errorf("api_key is empty")
}
clientOpts := make([]option.ClientOption, 0)
clientOpts = append(clientOpts, option.WithAPIKey(apiKey))
clientOpts = append(clientOpts, option.WithEndpoint(u.Scheme+"://"+u.Host))
srv, err := customsearch.NewService(context.Background(), clientOpts...)
if err != nil {
return nil, err
}
return &GoogleProvider{
srv: srv,
apiKey: apiKey,
engineID: engineID,
baseURL: apiURL,
apiKey: params.APIKey,
engineID: params.EngineID,
}, nil
}
// GoogleProviderInfo returns the provider info for registration
func GoogleProviderInfo() types.WebSearchProviderInfo {
return types.WebSearchProviderInfo{
ID: "google",
Name: "Google",
Free: false,
RequiresAPIKey: true,
Description: "Google Custom Search API",
}
}
// Name returns the provider name
func (p *GoogleProvider) Name() string {
return "google"
@@ -81,6 +58,7 @@ func (p *GoogleProvider) Search(
if len(query) == 0 {
return nil, fmt.Errorf("query is empty")
}
logger.Infof(ctx, "[WebSearch][Google] query=%q maxResults=%d engineID=%s", query, maxResults, p.engineID)
cseCall := p.srv.Cse.List().Context(ctx).Cx(p.engineID).Q(query)
if maxResults > 0 {
@@ -92,6 +70,7 @@ func (p *GoogleProvider) Search(
resp, err := cseCall.Do()
if err != nil {
logger.Warnf(ctx, "[WebSearch][Google] failed: %v", err)
return nil, err
}
results := make([]*types.WebSearchResult, 0)
@@ -104,5 +83,6 @@ func (p *GoogleProvider) Search(
}
results = append(results, result)
}
logger.Infof(ctx, "[WebSearch][Google] returned %d results", len(results))
return results, nil
}

View File

@@ -0,0 +1,45 @@
package web_search
import (
"fmt"
"sync"
"github.com/Tencent/WeKnora/internal/types"
"github.com/Tencent/WeKnora/internal/types/interfaces"
)
// ProviderFactory creates a new web search provider instance from parameters.
type ProviderFactory func(params types.WebSearchProviderParameters) (interfaces.WebSearchProvider, error)
// Registry manages web search provider type registrations.
// It maps provider type IDs (e.g., "bing", "google") to their factory functions.
// Instances are created on-demand with tenant-specific parameters.
type Registry struct {
factories map[string]ProviderFactory
mu sync.RWMutex
}
// NewRegistry creates a new web search provider registry
func NewRegistry() *Registry {
return &Registry{
factories: make(map[string]ProviderFactory),
}
}
// Register registers a provider type factory by ID
func (r *Registry) Register(id string, factory ProviderFactory) {
r.mu.Lock()
defer r.mu.Unlock()
r.factories[id] = factory
}
// CreateProvider creates a provider instance by type with the given parameters.
func (r *Registry) CreateProvider(providerType string, params types.WebSearchProviderParameters) (interfaces.WebSearchProvider, error) {
r.mu.RLock()
factory, ok := r.factories[providerType]
r.mu.RUnlock()
if !ok {
return nil, fmt.Errorf("web search provider type %s not registered", providerType)
}
return factory(params)
}

View File

@@ -7,14 +7,16 @@ import (
"fmt"
"io"
"net/http"
"os"
"time"
"github.com/Tencent/WeKnora/internal/logger"
"github.com/Tencent/WeKnora/internal/types"
"github.com/Tencent/WeKnora/internal/types/interfaces"
)
const (
// defaultTavilySearchURL is the hardcoded Tavily API URL.
// Not configurable by tenants — prevents SSRF.
defaultTavilySearchURL = "https://api.tavily.com/search"
)
@@ -29,11 +31,10 @@ type TavilyProvider struct {
apiKey string
}
// NewTavilyProvider creates a new Tavily provider
func NewTavilyProvider() (interfaces.WebSearchProvider, error) {
apiKey := os.Getenv("TAVILY_API_KEY")
if len(apiKey) == 0 {
return nil, fmt.Errorf("TAVILY_API_KEY is not set")
// NewTavilyProvider creates a new Tavily provider from parameters (no environment variables).
func NewTavilyProvider(params types.WebSearchProviderParameters) (interfaces.WebSearchProvider, error) {
if params.APIKey == "" {
return nil, fmt.Errorf("API key is required for Tavily provider")
}
client := &http.Client{
Timeout: defaultTavilyTimeout,
@@ -41,21 +42,10 @@ func NewTavilyProvider() (interfaces.WebSearchProvider, error) {
return &TavilyProvider{
client: client,
baseURL: defaultTavilySearchURL,
apiKey: apiKey,
apiKey: params.APIKey,
}, nil
}
// TavilyProviderInfo returns the provider info for registration
func TavilyProviderInfo() types.WebSearchProviderInfo {
return types.WebSearchProviderInfo{
ID: "tavily",
Name: "Tavily",
Free: false,
RequiresAPIKey: true,
Description: "Tavily Search API",
}
}
// Name returns the provider name
func (p *TavilyProvider) Name() string {
return "tavily"
@@ -71,6 +61,7 @@ func (p *TavilyProvider) Search(
if len(query) == 0 {
return nil, fmt.Errorf("query is empty")
}
logger.Infof(ctx, "[WebSearch][Tavily] query=%q maxResults=%d url=%s", query, maxResults, p.baseURL)
reqBody := tavilySearchRequest{
APIKey: p.apiKey,
@@ -97,6 +88,7 @@ func (p *TavilyProvider) Search(
if resp.StatusCode != http.StatusOK {
respBody, _ := io.ReadAll(resp.Body)
logger.Warnf(ctx, "[WebSearch][Tavily] API returned status %d: %s", resp.StatusCode, string(respBody))
return nil, fmt.Errorf("tavily API returned status %d: %s", resp.StatusCode, string(respBody))
}
@@ -125,6 +117,7 @@ func (p *TavilyProvider) Search(
}
results = append(results, result)
}
logger.Infof(ctx, "[WebSearch][Tavily] returned %d results", len(results))
return results, nil
}

View File

@@ -53,7 +53,8 @@ type RouterParams struct {
InitializationHandler *handler.InitializationHandler
SystemHandler *handler.SystemHandler
MCPServiceHandler *handler.MCPServiceHandler
WebSearchHandler *handler.WebSearchHandler
WebSearchHandler *handler.WebSearchHandler
WebSearchProviderHandler *handler.WebSearchProviderHandler
FAQHandler *handler.FAQHandler
TagHandler *handler.TagHandler
CustomAgentHandler *handler.CustomAgentHandler
@@ -137,6 +138,7 @@ func NewRouter(params RouterParams) *gin.Engine {
RegisterSystemRoutes(v1, params.SystemHandler)
RegisterMCPServiceRoutes(v1, params.MCPServiceHandler)
RegisterWebSearchRoutes(v1, params.WebSearchHandler)
RegisterWebSearchProviderRoutes(v1, params.WebSearchProviderHandler)
RegisterCustomAgentRoutes(v1, params.CustomAgentHandler)
RegisterSkillRoutes(v1, params.SkillHandler)
RegisterOrganizationRoutes(v1, params.OrganizationHandler)
@@ -477,6 +479,25 @@ func RegisterWebSearchRoutes(r *gin.RouterGroup, webSearchHandler *handler.WebSe
}
}
// RegisterWebSearchProviderRoutes registers CRUD routes for web search provider configurations
func RegisterWebSearchProviderRoutes(r *gin.RouterGroup, h *handler.WebSearchProviderHandler) {
providers := r.Group("/web-search-providers")
{
// List available provider types (metadata for UI forms)
providers.GET("/types", h.ListProviderTypes)
// Test with raw credentials (no persistence)
providers.POST("/test", h.TestProviderRaw)
// CRUD
providers.POST("", h.CreateProvider)
providers.GET("", h.ListProviders)
providers.GET("/:id", h.GetProvider)
providers.PUT("/:id", h.UpdateProvider)
providers.DELETE("/:id", h.DeleteProvider)
// Test existing saved provider
providers.POST("/:id/test", h.TestProviderByID)
}
}
// RegisterCustomAgentRoutes registers custom agent routes
func RegisterCustomAgentRoutes(r *gin.RouterGroup, agentHandler *handler.CustomAgentHandler) {
agents := r.Group("/agents")

View File

@@ -64,7 +64,7 @@ func ConvertWebSearchResults(
result := &types.SearchResult{
ID: chunkID,
Content: content,
KnowledgeID: "",
KnowledgeID: chunkID, // Use URL as KnowledgeID so each web result stays independent during merge
ChunkIndex: 0,
KnowledgeTitle: webResult.Title,
StartAt: 0,

View File

@@ -25,6 +25,7 @@ type AgentConfig struct {
UseCustomSystemPrompt bool `json:"use_custom_system_prompt"` // Whether to use custom system prompt instead of default
WebSearchEnabled bool `json:"web_search_enabled"` // Whether web search tool is enabled
WebSearchMaxResults int `json:"web_search_max_results"` // Maximum number of web search results (default: 5)
WebSearchProviderID string `json:"web_search_provider_id,omitempty"` // WebSearchProviderEntity ID (resolved from agent config)
MultiTurnEnabled bool `json:"multi_turn_enabled"` // Whether multi-turn conversation is enabled
HistoryTurns int `json:"history_turns"` // Number of history turns to keep in context
SearchTargets SearchTargets `json:"-"` // Pre-computed unified search targets (runtime only)

View File

@@ -46,9 +46,12 @@ type PipelineRequest struct {
ChatModelSupportsVision bool `json:"-"`
// Misc request-scoped config
TenantID uint64 `json:"-"`
WebSearchEnabled bool `json:"-"`
Language string `json:"-"`
TenantID uint64 `json:"-"`
WebSearchEnabled bool `json:"-"`
WebSearchProviderID string `json:"-"` // Resolved from agent config or tenant default
WebFetchEnabled bool `json:"-"` // Auto-fetch full page content for web search results after rerank
WebFetchTopN int `json:"-"` // Max pages to fetch (default 3)
Language string `json:"-"`
}
// QueryIntent represents the classified intent of a user query.
@@ -183,6 +186,9 @@ func (c *ChatManage) Clone() *ChatManage {
ChatModelSupportsVision: c.ChatModelSupportsVision,
TenantID: c.TenantID,
WebSearchEnabled: c.WebSearchEnabled,
WebSearchProviderID: c.WebSearchProviderID,
WebFetchEnabled: c.WebFetchEnabled,
WebFetchTopN: c.WebFetchTopN,
Language: c.Language,
},
PipelineState: PipelineState{
@@ -205,6 +211,7 @@ const (
CHUNK_SEARCH_PARALLEL EventType = "chunk_search_parallel"
ENTITY_SEARCH EventType = "entity_search"
CHUNK_RERANK EventType = "chunk_rerank"
WEB_FETCH EventType = "web_fetch"
CHUNK_MERGE EventType = "chunk_merge"
DATA_ANALYSIS EventType = "data_analysis"
INTO_CHAT_MESSAGE EventType = "into_chat_message"

View File

@@ -141,6 +141,13 @@ type CustomAgentConfig struct {
WebSearchEnabled bool `yaml:"web_search_enabled" json:"web_search_enabled"`
// Maximum web search results
WebSearchMaxResults int `yaml:"web_search_max_results" json:"web_search_max_results"`
// WebSearchProviderID references a specific WebSearchProviderEntity.
// If empty, the tenant's default provider (is_default=true) is used.
WebSearchProviderID string `yaml:"web_search_provider_id" json:"web_search_provider_id,omitempty"`
// Whether to auto-fetch full page content for reranked web search results
WebFetchEnabled bool `yaml:"web_fetch_enabled" json:"web_fetch_enabled"`
// Max number of pages to fetch after rerank (default: 3)
WebFetchTopN int `yaml:"web_fetch_top_n" json:"web_fetch_top_n,omitempty"`
// ===== Multi-turn Conversation Settings =====
// Whether multi-turn conversation is enabled

View File

@@ -16,8 +16,9 @@ type WebSearchProvider interface {
// WebSearchService defines the interface for web search services
type WebSearchService interface {
// Search performs a web search
Search(ctx context.Context, config *types.WebSearchConfig, query string) ([]*types.WebSearchResult, error)
// Search performs a web search using the provider entity identified by providerID.
// If providerID is empty, it falls back to the deprecated config.Provider field for backward compatibility.
Search(ctx context.Context, providerID string, config *types.WebSearchConfig, query string) ([]*types.WebSearchResult, error)
// CompressWithRAG performs RAG-based compression using a temporary, hidden knowledge base
// The temporary knowledge base is deleted after use. The UI will not list it due to repo filtering.
CompressWithRAG(ctx context.Context, sessionID string, tempKBID string, questions []string,

View File

@@ -0,0 +1,39 @@
package interfaces
import (
"context"
"github.com/Tencent/WeKnora/internal/types"
)
// WebSearchProviderRepository defines the repository interface for web search provider CRUD
type WebSearchProviderRepository interface {
// Create creates a new web search provider
Create(ctx context.Context, provider *types.WebSearchProviderEntity) error
// GetByID retrieves a web search provider by ID within a tenant scope
GetByID(ctx context.Context, tenantID uint64, id string) (*types.WebSearchProviderEntity, error)
// GetDefault retrieves the default provider (is_default=true) for a tenant, or nil if none.
GetDefault(ctx context.Context, tenantID uint64) (*types.WebSearchProviderEntity, error)
// List lists all web search providers for a tenant
List(ctx context.Context, tenantID uint64) ([]*types.WebSearchProviderEntity, error)
// Update updates a web search provider
Update(ctx context.Context, provider *types.WebSearchProviderEntity) error
// Delete deletes a web search provider (soft delete)
Delete(ctx context.Context, tenantID uint64, id string) error
// ClearDefault clears the default flag for all providers of a tenant, optionally excluding one
ClearDefault(ctx context.Context, tenantID uint64, excludeID string) error
}
// WebSearchProviderService defines the service interface for web search provider management.
// Tenant isolation is enforced by the handler layer (getOwned pattern).
// Service methods operate on entities whose TenantID is already verified.
type WebSearchProviderService interface {
// CreateProvider creates a new web search provider.
// provider.TenantID must be set by the caller (handler).
CreateProvider(ctx context.Context, provider *types.WebSearchProviderEntity) error
// UpdateProvider updates an existing provider.
// provider.TenantID must be set by the caller (handler) for the repository WHERE clause.
UpdateProvider(ctx context.Context, provider *types.WebSearchProviderEntity) error
// DeleteProvider deletes a provider by tenant + id.
DeleteProvider(ctx context.Context, tenantID uint64, id string) error
}

View File

@@ -8,8 +8,11 @@ import (
// WebSearchConfig represents the web search configuration for a tenant
type WebSearchConfig struct {
Provider string `json:"provider"` // 搜索引擎提供商ID
APIKey string `json:"api_key"` // API密钥如果需要
// Deprecated: Use WebSearchProviderEntity.Parameters.APIKey instead.
Provider string `json:"provider,omitempty"`
// Deprecated: Use WebSearchProviderEntity.Parameters.APIKey instead.
APIKey string `json:"api_key,omitempty"`
MaxResults int `json:"max_results"` // 最大搜索结果数
IncludeDate bool `json:"include_date"` // 是否包含日期
CompressionMethod string `json:"compression_method"` // 压缩方法none, summary, extract, rag

View File

@@ -0,0 +1,157 @@
package types
import (
"database/sql/driver"
"encoding/json"
"time"
"github.com/Tencent/WeKnora/internal/utils"
"github.com/google/uuid"
"gorm.io/gorm"
)
// WebSearchProviderType represents the type of web search provider
type WebSearchProviderType string
const (
WebSearchProviderTypeBing WebSearchProviderType = "bing"
WebSearchProviderTypeGoogle WebSearchProviderType = "google"
WebSearchProviderTypeDuckDuckGo WebSearchProviderType = "duckduckgo"
WebSearchProviderTypeTavily WebSearchProviderType = "tavily"
)
// WebSearchProviderEntity represents a configured web search provider instance for a tenant.
// This is a CRUD entity stored in the database, similar to the Model entity.
// Each tenant can create multiple provider configurations (e.g., "Production Bing", "Test Google").
// Agents reference these by ID.
type WebSearchProviderEntity struct {
// Unique identifier (UUID, auto-generated)
ID string `yaml:"id" json:"id" gorm:"type:varchar(36);primaryKey"`
// Tenant ID for scoping
TenantID uint64 `yaml:"tenant_id" json:"tenant_id"`
// User-friendly name, e.g., "Production Bing Search"
Name string `yaml:"name" json:"name" gorm:"type:varchar(255);not null"`
// Provider type: bing, google, duckduckgo, tavily
Provider WebSearchProviderType `yaml:"provider" json:"provider" gorm:"type:varchar(50);not null"`
// Description
Description string `yaml:"description" json:"description" gorm:"type:text"`
// Provider-specific parameters (API key, engine ID, etc.) stored as encrypted JSON
Parameters WebSearchProviderParameters `yaml:"parameters" json:"parameters" gorm:"type:json"`
// Whether this is the default provider for the tenant
IsDefault bool `yaml:"is_default" json:"is_default" gorm:"default:false"`
// Timestamps
CreatedAt time.Time `yaml:"created_at" json:"created_at"`
UpdatedAt time.Time `yaml:"updated_at" json:"updated_at"`
DeletedAt gorm.DeletedAt `yaml:"deleted_at" json:"deleted_at" gorm:"index"`
}
// TableName returns the table name for WebSearchProviderEntity
func (WebSearchProviderEntity) TableName() string {
return "web_search_providers"
}
// BeforeCreate is a GORM hook that runs before creating a new record.
// Automatically generates a UUID for new providers.
func (e *WebSearchProviderEntity) BeforeCreate(tx *gorm.DB) (err error) {
if e.ID == "" {
e.ID = uuid.New().String()
}
return nil
}
// WebSearchProviderParameters holds provider-specific configuration.
// API keys are encrypted at rest using AES-GCM.
// BaseURL is intentionally NOT included — each provider type uses a hardcoded
// official API endpoint to prevent SSRF attacks.
type WebSearchProviderParameters struct {
// API key for the search provider (encrypted in DB)
APIKey string `yaml:"api_key" json:"api_key,omitempty"`
// Google Custom Search Engine ID (only for Google provider)
EngineID string `yaml:"engine_id" json:"engine_id,omitempty"`
// Provider-specific extra configuration for future extensibility
ExtraConfig map[string]string `yaml:"extra_config" json:"extra_config,omitempty"`
}
// Value implements the driver.Valuer interface.
// Encrypts APIKey before persisting to database.
func (p WebSearchProviderParameters) Value() (driver.Value, error) {
if key := utils.GetAESKey(); key != nil && p.APIKey != "" {
if encrypted, err := utils.EncryptAESGCM(p.APIKey, key); err == nil {
p.APIKey = encrypted
}
}
return json.Marshal(p)
}
// Scan implements the sql.Scanner interface.
// Decrypts APIKey after loading from database.
func (p *WebSearchProviderParameters) Scan(value interface{}) error {
if value == nil {
return nil
}
b, ok := value.([]byte)
if !ok {
return nil
}
if err := json.Unmarshal(b, p); err != nil {
return err
}
if key := utils.GetAESKey(); key != nil && p.APIKey != "" {
if decrypted, err := utils.DecryptAESGCM(p.APIKey, key); err == nil {
p.APIKey = decrypted
}
}
return nil
}
// WebSearchProviderTypeInfo describes the metadata of a provider type.
// Used by the GET /types endpoint so the frontend can dynamically render forms.
type WebSearchProviderTypeInfo struct {
// Provider type identifier
ID string `json:"id"`
// Human-readable name
Name string `json:"name"`
// Whether the provider requires an API key
RequiresAPIKey bool `json:"requires_api_key"`
// Whether the provider requires an engine ID (e.g., Google CSE)
RequiresEngineID bool `json:"requires_engine_id"`
// Description
Description string `json:"description"`
// URL to the provider's official website or documentation for obtaining credentials
DocsURL string `json:"docs_url,omitempty"`
}
// GetWebSearchProviderTypes returns metadata for all supported provider types.
func GetWebSearchProviderTypes() []WebSearchProviderTypeInfo {
return []WebSearchProviderTypeInfo{
{
ID: "duckduckgo",
Name: "DuckDuckGo",
RequiresAPIKey: false,
Description: "DuckDuckGo Search (free, no API key required)",
DocsURL: "https://duckduckgo.com/",
},
{
ID: "bing",
Name: "Bing",
RequiresAPIKey: true,
Description: "Bing Search API (requires API key from Azure)",
DocsURL: "https://learn.microsoft.com/en-us/bing/search-apis/bing-web-search/overview",
},
{
ID: "google",
Name: "Google",
RequiresAPIKey: true,
RequiresEngineID: true,
Description: "Google Custom Search API (requires API key and engine ID)",
DocsURL: "https://developers.google.com/custom-search/v1/overview",
},
{
ID: "tavily",
Name: "Tavily",
RequiresAPIKey: true,
Description: "Tavily Search API (requires API key)",
DocsURL: "https://tavily.com/",
},
}
}

View File

@@ -56,23 +56,5 @@ CREATE INDEX IF NOT EXISTS idx_sync_logs_tenant_id ON sync_logs (tenant_id);
CREATE INDEX IF NOT EXISTS idx_sync_logs_status ON sync_logs (status);
CREATE INDEX IF NOT EXISTS idx_sync_logs_started_at ON sync_logs (started_at);
-- Trigger function to auto-update updated_at column
CREATE OR REPLACE FUNCTION update_updated_at_column()
RETURNS TRIGGER AS $$
BEGIN
NEW.updated_at = CURRENT_TIMESTAMP;
RETURN NEW;
END;
$$ LANGUAGE plpgsql;
CREATE TRIGGER trg_data_sources_updated_at
BEFORE UPDATE ON data_sources
FOR EACH ROW
EXECUTE FUNCTION update_updated_at_column();
CREATE TRIGGER trg_sync_logs_updated_at
BEFORE UPDATE ON sync_logs
FOR EACH ROW
EXECUTE FUNCTION update_updated_at_column();
DO $$ BEGIN RAISE NOTICE '[Migration 000028] data_sources and sync_logs tables created successfully'; END $$;

View File

@@ -0,0 +1,3 @@
-- Rollback migration: 000029_web_search_providers
DROP TRIGGER IF EXISTS trg_web_search_providers_updated_at ON web_search_providers;
DROP TABLE IF EXISTS web_search_providers;

View File

@@ -0,0 +1,31 @@
-- Migration: 000029_web_search_providers
-- Description: Create web_search_providers table for tenant-specific search engine configurations
DO $$ BEGIN RAISE NOTICE '[Migration 000029] Creating web_search_providers table'; END $$;
-- Create web_search_providers table for managing tenant search engine configurations
-- Each row represents a configured search provider instance (e.g., "Production Bing", "Test Google")
-- Agents reference these by ID via custom_agents.config.web_search_provider_id
CREATE TABLE IF NOT EXISTS web_search_providers (
id VARCHAR(36) NOT NULL PRIMARY KEY,
tenant_id BIGINT NOT NULL,
name VARCHAR(255) NOT NULL,
provider VARCHAR(50) NOT NULL,
description TEXT,
parameters JSONB,
is_default BOOLEAN DEFAULT false,
created_at TIMESTAMP DEFAULT CURRENT_TIMESTAMP,
updated_at TIMESTAMP DEFAULT CURRENT_TIMESTAMP,
deleted_at TIMESTAMP NULL
);
CREATE INDEX IF NOT EXISTS idx_web_search_providers_tenant_id ON web_search_providers (tenant_id);
CREATE INDEX IF NOT EXISTS idx_web_search_providers_provider ON web_search_providers (provider);
CREATE INDEX IF NOT EXISTS idx_web_search_providers_deleted_at ON web_search_providers (deleted_at);
-- Auto-update updated_at column (reuses the function from migration 000028)
CREATE TRIGGER trg_web_search_providers_updated_at
BEFORE UPDATE ON web_search_providers
FOR EACH ROW
EXECUTE FUNCTION update_updated_at_column();
DO $$ BEGIN RAISE NOTICE '[Migration 000029] web_search_providers table created successfully'; END $$;