mirror of
https://github.com/Tencent/WeKnora.git
synced 2026-06-04 13:30:32 +08:00
722 lines
30 KiB
Go
722 lines
30 KiB
Go
package config
|
||
|
||
import (
|
||
"fmt"
|
||
"os"
|
||
"path/filepath"
|
||
"regexp"
|
||
"strings"
|
||
"time"
|
||
|
||
"github.com/Tencent/WeKnora/internal/types"
|
||
"github.com/go-viper/mapstructure/v2"
|
||
"github.com/spf13/viper"
|
||
"gopkg.in/yaml.v3"
|
||
)
|
||
|
||
// Config 应用程序总配置
|
||
type Config struct {
|
||
Conversation *ConversationConfig `yaml:"conversation" json:"conversation"`
|
||
Server *ServerConfig `yaml:"server" json:"server"`
|
||
KnowledgeBase *KnowledgeBaseConfig `yaml:"knowledge_base" json:"knowledge_base"`
|
||
Tenant *TenantConfig `yaml:"tenant" json:"tenant"`
|
||
OIDCAuth *OIDCAuthConfig `yaml:"oidc_auth" json:"oidc_auth"`
|
||
Models []ModelConfig `yaml:"models" json:"models"`
|
||
VectorDatabase *VectorDatabaseConfig `yaml:"vector_database" json:"vector_database"`
|
||
DocReader *DocReaderConfig `yaml:"docreader" json:"docreader"`
|
||
StreamManager *StreamManagerConfig `yaml:"stream_manager" json:"stream_manager"`
|
||
ExtractManager *ExtractManagerConfig `yaml:"extract" json:"extract"`
|
||
WebSearch *WebSearchConfig `yaml:"web_search" json:"web_search"`
|
||
PromptTemplates *PromptTemplatesConfig `yaml:"prompt_templates" json:"prompt_templates"`
|
||
IM *IMConfig `yaml:"im" json:"im"`
|
||
}
|
||
|
||
// IMConfig configures the IM integration service.
|
||
// All fields are optional — zero values fall back to built-in defaults so
|
||
// existing deployments need no config changes.
|
||
type IMConfig struct {
|
||
// Workers is the number of concurrent QA worker goroutines per instance.
|
||
// Default: 5.
|
||
Workers int `yaml:"workers" json:"workers"`
|
||
// GlobalMaxWorkers is the maximum number of QA requests that can execute
|
||
// concurrently across ALL instances. Enforced via a Redis counter; when the
|
||
// global limit is reached, local workers wait until a slot opens.
|
||
// Requires Redis — ignored in single-instance mode.
|
||
// 0 (default) means no global limit.
|
||
GlobalMaxWorkers int `yaml:"global_max_workers" json:"global_max_workers"`
|
||
// MaxQueueSize is the maximum number of pending QA requests per instance.
|
||
// Default: 50.
|
||
MaxQueueSize int `yaml:"max_queue_size" json:"max_queue_size"`
|
||
// MaxPerUser limits how many requests a single user can have queued globally.
|
||
// Default: 3.
|
||
MaxPerUser int `yaml:"max_per_user" json:"max_per_user"`
|
||
// RateLimitWindow is the sliding window duration for per-user rate limiting.
|
||
// Default: 60s.
|
||
RateLimitWindow time.Duration `yaml:"rate_limit_window" json:"rate_limit_window"`
|
||
// RateLimitMax is the maximum number of requests allowed per window per user.
|
||
// Default: 10.
|
||
RateLimitMax int `yaml:"rate_limit_max" json:"rate_limit_max"`
|
||
}
|
||
|
||
// DocReaderConfig configures the document parser client (gRPC or HTTP).
|
||
type DocReaderConfig struct {
|
||
// Addr: for gRPC it is the server address (e.g. "localhost:50051"); for HTTP it is the base URL (e.g. "http://localhost:8080").
|
||
Addr string `yaml:"addr" json:"addr"`
|
||
// Transport: "grpc" (default) or "http"
|
||
Transport string `yaml:"transport" json:"transport"`
|
||
}
|
||
|
||
type VectorDatabaseConfig struct {
|
||
Driver string `yaml:"driver" json:"driver"`
|
||
}
|
||
|
||
// ConversationConfig 对话服务配置
|
||
type ConversationConfig struct {
|
||
MaxRounds int `yaml:"max_rounds" json:"max_rounds"`
|
||
KeywordThreshold float64 `yaml:"keyword_threshold" json:"keyword_threshold"`
|
||
EmbeddingTopK int `yaml:"embedding_top_k" json:"embedding_top_k"`
|
||
VectorThreshold float64 `yaml:"vector_threshold" json:"vector_threshold"`
|
||
RerankTopK int `yaml:"rerank_top_k" json:"rerank_top_k"`
|
||
RerankThreshold float64 `yaml:"rerank_threshold" json:"rerank_threshold"`
|
||
FallbackStrategy string `yaml:"fallback_strategy" json:"fallback_strategy"`
|
||
FallbackResponse string `yaml:"fallback_response" json:"fallback_response"`
|
||
EnableRewrite bool `yaml:"enable_rewrite" json:"enable_rewrite"`
|
||
EnableQueryExpansion bool `yaml:"enable_query_expansion" json:"enable_query_expansion"`
|
||
EnableRerank bool `yaml:"enable_rerank" json:"enable_rerank"`
|
||
Summary *SummaryConfig `yaml:"summary" json:"summary"`
|
||
|
||
// Prompt template ID fields — resolved to text by backfillConversationDefaults
|
||
FallbackPromptID string `yaml:"fallback_prompt_id" json:"fallback_prompt_id"`
|
||
RewritePromptID string `yaml:"rewrite_prompt_id" json:"rewrite_prompt_id"`
|
||
GenerateSessionTitlePromptID string `yaml:"generate_session_title_prompt_id" json:"generate_session_title_prompt_id"`
|
||
GenerateSummaryPromptID string `yaml:"generate_summary_prompt_id" json:"generate_summary_prompt_id"`
|
||
ExtractEntitiesPromptID string `yaml:"extract_entities_prompt_id" json:"extract_entities_prompt_id"`
|
||
ExtractRelationshipsPromptID string `yaml:"extract_relationships_prompt_id" json:"extract_relationships_prompt_id"`
|
||
GenerateQuestionsPromptID string `yaml:"generate_questions_prompt_id" json:"generate_questions_prompt_id"`
|
||
|
||
// Resolved prompt text fields (populated by backfill, not from YAML)
|
||
FallbackPrompt string `yaml:"-" json:"fallback_prompt"`
|
||
RewritePromptSystem string `yaml:"-" json:"rewrite_prompt_system"`
|
||
RewritePromptUser string `yaml:"-" json:"rewrite_prompt_user"`
|
||
GenerateSessionTitlePrompt string `yaml:"-" json:"generate_session_title_prompt"`
|
||
GenerateSummaryPrompt string `yaml:"-" json:"generate_summary_prompt"`
|
||
ExtractEntitiesPrompt string `yaml:"-" json:"extract_entities_prompt"`
|
||
ExtractRelationshipsPrompt string `yaml:"-" json:"extract_relationships_prompt"`
|
||
GenerateQuestionsPrompt string `yaml:"-" json:"generate_questions_prompt"`
|
||
|
||
// IntentSystemPrompts maps intent values (e.g. "greeting", "chitchat") to
|
||
// system prompt text. Populated by backfill from IntentPrompts templates.
|
||
IntentSystemPrompts map[string]string `yaml:"-" json:"-"`
|
||
}
|
||
|
||
// SummaryConfig 摘要配置
|
||
type SummaryConfig struct {
|
||
MaxTokens int `yaml:"max_tokens" json:"max_tokens"`
|
||
RepeatPenalty float64 `yaml:"repeat_penalty" json:"repeat_penalty"`
|
||
TopK int `yaml:"top_k" json:"top_k"`
|
||
TopP float64 `yaml:"top_p" json:"top_p"`
|
||
FrequencyPenalty float64 `yaml:"frequency_penalty" json:"frequency_penalty"`
|
||
PresencePenalty float64 `yaml:"presence_penalty" json:"presence_penalty"`
|
||
Temperature float64 `yaml:"temperature" json:"temperature"`
|
||
Seed int `yaml:"seed" json:"seed"`
|
||
MaxCompletionTokens int `yaml:"max_completion_tokens" json:"max_completion_tokens"`
|
||
NoMatchPrefix string `yaml:"no_match_prefix" json:"no_match_prefix"`
|
||
Thinking *bool `yaml:"thinking" json:"thinking"`
|
||
|
||
// Prompt template ID fields — resolved to text by backfillConversationDefaults
|
||
PromptID string `yaml:"prompt_id" json:"prompt_id"`
|
||
ContextTemplateID string `yaml:"context_template_id" json:"context_template_id"`
|
||
|
||
// Resolved prompt text fields (populated by backfill, not from YAML)
|
||
Prompt string `yaml:"-" json:"prompt"`
|
||
ContextTemplate string `yaml:"-" json:"context_template"`
|
||
}
|
||
|
||
// ServerConfig 服务器配置
|
||
type ServerConfig struct {
|
||
Port int `yaml:"port" json:"port"`
|
||
Host string `yaml:"host" json:"host"`
|
||
LogPath string `yaml:"log_path" json:"log_path"`
|
||
ShutdownTimeout time.Duration `yaml:"shutdown_timeout" json:"shutdown_timeout" default:"30s"`
|
||
}
|
||
|
||
// KnowledgeBaseConfig 知识库配置
|
||
type KnowledgeBaseConfig struct {
|
||
ChunkSize int `yaml:"chunk_size" json:"chunk_size"`
|
||
ChunkOverlap int `yaml:"chunk_overlap" json:"chunk_overlap"`
|
||
SplitMarkers []string `yaml:"split_markers" json:"split_markers"`
|
||
KeepSeparator bool `yaml:"keep_separator" json:"keep_separator"`
|
||
ImageProcessing *ImageProcessingConfig `yaml:"image_processing" json:"image_processing"`
|
||
}
|
||
|
||
// ImageProcessingConfig 图像处理配置
|
||
type ImageProcessingConfig struct {
|
||
EnableMultimodal bool `yaml:"enable_multimodal" json:"enable_multimodal"`
|
||
}
|
||
|
||
// TenantConfig 租户配置
|
||
type TenantConfig struct {
|
||
DefaultSessionName string `yaml:"default_session_name" json:"default_session_name"`
|
||
DefaultSessionTitle string `yaml:"default_session_title" json:"default_session_title"`
|
||
DefaultSessionDescription string `yaml:"default_session_description" json:"default_session_description"`
|
||
// EnableCrossTenantAccess enables cross-tenant access for users with permission
|
||
EnableCrossTenantAccess bool `yaml:"enable_cross_tenant_access" json:"enable_cross_tenant_access"`
|
||
}
|
||
|
||
type OIDCUserInfoMapping struct {
|
||
Username string `yaml:"username" json:"username"`
|
||
Email string `yaml:"email" json:"email"`
|
||
}
|
||
|
||
type OIDCAuthConfig struct {
|
||
Enable bool `yaml:"enable" json:"enable"`
|
||
IssuerURL string `yaml:"issuer_url" json:"issuer_url"`
|
||
DiscoveryURL string `yaml:"discovery_url" json:"discovery_url"`
|
||
ProviderDisplayName string `yaml:"provider_display_name" json:"provider_display_name"`
|
||
ClientID string `yaml:"client_id" json:"client_id"`
|
||
ClientSecret string `yaml:"client_secret" json:"-"`
|
||
AuthorizationEndpoint string `yaml:"authorization_endpoint" json:"authorization_endpoint"`
|
||
TokenEndpoint string `yaml:"token_endpoint" json:"token_endpoint"`
|
||
UserInfoEndpoint string `yaml:"user_info_endpoint" json:"user_info_endpoint"`
|
||
Scopes []string `yaml:"scopes" json:"scopes"`
|
||
UserInfoMapping *OIDCUserInfoMapping `yaml:"user_info_mapping" json:"user_info_mapping"`
|
||
}
|
||
|
||
// PromptTemplateI18n holds localized name and description for a prompt template.
|
||
type PromptTemplateI18n struct {
|
||
Name string `yaml:"name" json:"name"`
|
||
Description string `yaml:"description" json:"description"`
|
||
}
|
||
|
||
// PromptTemplate 提示词模板
|
||
//
|
||
// 字段设计:每个模板最多由两部分组成 —— 系统侧 (content) 和用户侧 (user)。
|
||
// - content: 主要内容 / 系统 Prompt(所有模板都使用此字段)
|
||
// - user: 用户侧 Prompt(仅在需要 system+user 配对的模板中使用,如 rewrite、keywords_extraction)
|
||
// - i18n: 多语言 name/description,键为 locale(如 "zh-CN"、"en-US"、"ko-KR"),后端根据请求语言替换 Name/Description 再返回
|
||
type PromptTemplate struct {
|
||
ID string `yaml:"id" json:"id"`
|
||
Name string `yaml:"name" json:"name"`
|
||
Description string `yaml:"description" json:"description"`
|
||
Content string `yaml:"content" json:"content"`
|
||
User string `yaml:"user" json:"user,omitempty"`
|
||
HasKnowledgeBase bool `yaml:"has_knowledge_base" json:"has_knowledge_base,omitempty"`
|
||
HasWebSearch bool `yaml:"has_web_search" json:"has_web_search,omitempty"`
|
||
Default bool `yaml:"default" json:"default,omitempty"`
|
||
Mode string `yaml:"mode" json:"mode,omitempty"`
|
||
I18n map[string]PromptTemplateI18n `yaml:"i18n" json:"-"`
|
||
}
|
||
|
||
// PromptTemplatesConfig 提示词模板配置
|
||
//
|
||
// 每种 Prompt 类型对应一个 YAML 文件,所有模板都在同一个字段(文件)中管理。
|
||
// 每个模板使用 content (system prompt) + user (user prompt) 两个字段。
|
||
type PromptTemplatesConfig struct {
|
||
SystemPrompt []PromptTemplate `yaml:"system_prompt" json:"system_prompt"`
|
||
ContextTemplate []PromptTemplate `yaml:"context_template" json:"context_template"`
|
||
// Rewrite 合并了前端可选模板和运行时默认模板,每个模板同时包含 content + user
|
||
Rewrite []PromptTemplate `yaml:"rewrite" json:"rewrite"`
|
||
// Fallback 合并了固定回复模板和模型兜底 prompt(通过 mode:"model" 区分)
|
||
Fallback []PromptTemplate `yaml:"fallback" json:"fallback"`
|
||
|
||
GenerateSessionTitle []PromptTemplate `yaml:"generate_session_title" json:"generate_session_title,omitempty"`
|
||
GenerateSummary []PromptTemplate `yaml:"generate_summary" json:"generate_summary,omitempty"`
|
||
KeywordsExtraction []PromptTemplate `yaml:"keywords_extraction" json:"keywords_extraction,omitempty"`
|
||
AgentSystemPrompt []PromptTemplate `yaml:"agent_system_prompt" json:"agent_system_prompt,omitempty"`
|
||
GraphExtraction []PromptTemplate `yaml:"graph_extraction" json:"graph_extraction,omitempty"`
|
||
GenerateQuestions []PromptTemplate `yaml:"generate_questions" json:"generate_questions,omitempty"`
|
||
// IntentPrompts holds per-intent system prompt overrides (template ID = intent value).
|
||
IntentPrompts []PromptTemplate `yaml:"intent_prompts" json:"intent_prompts,omitempty"`
|
||
}
|
||
|
||
// DefaultTemplate returns the first template marked as default in the list,
|
||
// or the first template if none is marked, or nil if the list is empty.
|
||
func DefaultTemplate(templates []PromptTemplate) *PromptTemplate {
|
||
for i := range templates {
|
||
if templates[i].Default {
|
||
return &templates[i]
|
||
}
|
||
}
|
||
if len(templates) > 0 {
|
||
return &templates[0]
|
||
}
|
||
return nil
|
||
}
|
||
|
||
// DefaultTemplateByMode returns the default template filtered by mode.
|
||
func DefaultTemplateByMode(templates []PromptTemplate, mode string) *PromptTemplate {
|
||
for i := range templates {
|
||
if templates[i].Mode == mode && templates[i].Default {
|
||
return &templates[i]
|
||
}
|
||
}
|
||
for i := range templates {
|
||
if templates[i].Mode == mode {
|
||
return &templates[i]
|
||
}
|
||
}
|
||
return DefaultTemplate(templates)
|
||
}
|
||
|
||
// LocalizeTemplates returns a deep copy of the template list with Name and
|
||
// Description replaced according to the given locale. Fallback chain:
|
||
// locale → primary language (e.g. "zh" from "zh-CN") → original Name/Description.
|
||
// The returned slice is safe to serialise directly; it never mutates the original.
|
||
func LocalizeTemplates(templates []PromptTemplate, locale string) []PromptTemplate {
|
||
if len(templates) == 0 {
|
||
return templates
|
||
}
|
||
out := make([]PromptTemplate, len(templates))
|
||
copy(out, templates)
|
||
for i := range out {
|
||
if len(out[i].I18n) == 0 {
|
||
continue
|
||
}
|
||
// Try exact match first (e.g. "zh-CN"), then primary subtag (e.g. "zh")
|
||
l10n, ok := out[i].I18n[locale]
|
||
if !ok {
|
||
if idx := strings.IndexByte(locale, '-'); idx > 0 {
|
||
l10n, ok = out[i].I18n[locale[:idx]]
|
||
}
|
||
}
|
||
if !ok {
|
||
continue
|
||
}
|
||
if l10n.Name != "" {
|
||
out[i].Name = l10n.Name
|
||
}
|
||
if l10n.Description != "" {
|
||
out[i].Description = l10n.Description
|
||
}
|
||
}
|
||
return out
|
||
}
|
||
|
||
// ModelConfig 模型配置
|
||
type ModelConfig struct {
|
||
Type string `yaml:"type" json:"type"`
|
||
Source string `yaml:"source" json:"source"`
|
||
ModelName string `yaml:"model_name" json:"model_name"`
|
||
Parameters map[string]interface{} `yaml:"parameters" json:"parameters"`
|
||
}
|
||
|
||
// StreamManagerConfig 流管理器配置
|
||
type StreamManagerConfig struct {
|
||
Type string `yaml:"type" json:"type"` // 类型: "memory" 或 "redis"
|
||
Redis RedisConfig `yaml:"redis" json:"redis"` // Redis配置
|
||
CleanupTimeout time.Duration `yaml:"cleanup_timeout" json:"cleanup_timeout"` // 清理超时,单位秒
|
||
}
|
||
|
||
// RedisConfig Redis配置
|
||
type RedisConfig struct {
|
||
Address string `yaml:"address" json:"address"` // Redis地址
|
||
Username string `yaml:"username" json:"username"` // Redis用户名
|
||
Password string `yaml:"password" json:"password"` // Redis密码
|
||
DB int `yaml:"db" json:"db"` // Redis数据库
|
||
Prefix string `yaml:"prefix" json:"prefix"` // 键前缀
|
||
TTL time.Duration `yaml:"ttl" json:"ttl"` // 过期时间(小时)
|
||
}
|
||
|
||
// ExtractManagerConfig 抽取管理器配置
|
||
type ExtractManagerConfig struct {
|
||
ExtractGraph *types.PromptTemplateStructured `yaml:"extract_graph" json:"extract_graph"`
|
||
ExtractEntity *types.PromptTemplateStructured `yaml:"extract_entity" json:"extract_entity"`
|
||
FabriText *FebriText `yaml:"fabri_text" json:"fabri_text"`
|
||
}
|
||
|
||
type FebriText struct {
|
||
WithTag string `yaml:"with_tag" json:"with_tag"`
|
||
WithNoTag string `yaml:"with_no_tag" json:"with_no_tag"`
|
||
}
|
||
|
||
// LoadConfig 从配置文件加载配置
|
||
func LoadConfig() (*Config, error) {
|
||
// 设置配置文件名和路径
|
||
viper.SetConfigName("config") // 配置文件名称(不带扩展名)
|
||
viper.SetConfigType("yaml") // 配置文件类型
|
||
viper.AddConfigPath(".") // 当前目录
|
||
viper.AddConfigPath("./config") // config子目录
|
||
viper.AddConfigPath("$HOME/.appname") // 用户目录
|
||
viper.AddConfigPath("/etc/appname/") // etc目录
|
||
|
||
// 启用环境变量替换
|
||
viper.AutomaticEnv()
|
||
viper.SetEnvKeyReplacer(strings.NewReplacer(".", "_"))
|
||
|
||
// 读取配置文件
|
||
if err := viper.ReadInConfig(); err != nil {
|
||
return nil, fmt.Errorf("error reading config file: %w", err)
|
||
}
|
||
|
||
// 替换配置中的环境变量引用
|
||
configFileContent, err := os.ReadFile(viper.ConfigFileUsed())
|
||
if err != nil {
|
||
return nil, fmt.Errorf("error reading config file content: %w", err)
|
||
}
|
||
|
||
// 替换${ENV_VAR}格式的环境变量引用
|
||
re := regexp.MustCompile(`\${([^}]+)}`)
|
||
result := re.ReplaceAllStringFunc(string(configFileContent), func(match string) string {
|
||
// 提取环境变量名称(去掉${}部分)
|
||
envVar := match[2 : len(match)-1]
|
||
// 获取环境变量值,如果不存在则保持原样
|
||
if value := os.Getenv(envVar); value != "" {
|
||
return value
|
||
}
|
||
return match
|
||
})
|
||
|
||
// 使用处理后的配置内容
|
||
viper.ReadConfig(strings.NewReader(result))
|
||
|
||
// 解析配置到结构体
|
||
var cfg Config
|
||
if err := viper.Unmarshal(&cfg, func(dc *mapstructure.DecoderConfig) {
|
||
dc.TagName = "yaml"
|
||
}); err != nil {
|
||
return nil, fmt.Errorf("unable to decode config into struct: %w", err)
|
||
}
|
||
fmt.Printf("Using configuration file: %s\n", viper.ConfigFileUsed())
|
||
|
||
// 加载提示词模板(从目录或配置文件)
|
||
configDir := filepath.Dir(viper.ConfigFileUsed())
|
||
promptTemplates, err := loadPromptTemplates(configDir)
|
||
if err != nil {
|
||
fmt.Printf("Warning: failed to load prompt templates from directory: %v\n", err)
|
||
// 如果目录加载失败,使用配置文件中的模板(如果有)
|
||
} else if promptTemplates != nil {
|
||
cfg.PromptTemplates = promptTemplates
|
||
}
|
||
|
||
// Back-fill conversation config from prompt templates defaults
|
||
// (so config.yaml can omit large prompt blocks and rely on template files)
|
||
if cfg.PromptTemplates != nil && cfg.Conversation != nil {
|
||
backfillConversationDefaults(&cfg)
|
||
}
|
||
|
||
// Load built-in agent definitions (i18n-aware) from builtin_agents.yaml
|
||
if err := types.LoadBuiltinAgentsConfig(configDir); err != nil {
|
||
fmt.Printf("Warning: failed to load builtin agents config: %v\n", err)
|
||
}
|
||
|
||
// Resolve prompt template ID references in builtin agent configs
|
||
// (e.g. system_prompt_id -> actual content from agent_system_prompt.yaml)
|
||
if cfg.PromptTemplates != nil {
|
||
resolveBuiltinAgentPromptIDs(cfg.PromptTemplates)
|
||
}
|
||
|
||
// Validate configuration values
|
||
applyOIDCEnvOverrides(&cfg)
|
||
|
||
if err := ValidateConfig(&cfg); err != nil {
|
||
return nil, err
|
||
}
|
||
|
||
return &cfg, nil
|
||
}
|
||
|
||
// ValidateConfig performs basic validation of the loaded configuration.
|
||
// It checks for obviously invalid or missing values that would cause runtime failures.
|
||
func ValidateConfig(cfg *Config) error {
|
||
var errs []string
|
||
|
||
if cfg.OIDCAuth != nil && cfg.OIDCAuth.Enable {
|
||
if strings.TrimSpace(cfg.OIDCAuth.ClientID) == "" {
|
||
errs = append(errs, "oidc_auth.client_id is required when OIDC is enabled")
|
||
}
|
||
if strings.TrimSpace(cfg.OIDCAuth.ClientSecret) == "" {
|
||
errs = append(errs, "oidc_auth.client_secret is required when OIDC is enabled")
|
||
}
|
||
if strings.TrimSpace(cfg.OIDCAuth.DiscoveryURL) == "" &&
|
||
(strings.TrimSpace(cfg.OIDCAuth.AuthorizationEndpoint) == "" || strings.TrimSpace(cfg.OIDCAuth.TokenEndpoint) == "") {
|
||
errs = append(errs, "oidc_auth.discovery_url or both oidc_auth.authorization_endpoint and oidc_auth.token_endpoint are required when OIDC is enabled")
|
||
}
|
||
}
|
||
|
||
if cfg.Conversation != nil {
|
||
if cfg.Conversation.EmbeddingTopK < 0 {
|
||
errs = append(errs, "conversation.embedding_top_k must be >= 0")
|
||
}
|
||
if cfg.Conversation.RerankTopK < 0 {
|
||
errs = append(errs, "conversation.rerank_top_k must be >= 0")
|
||
}
|
||
if cfg.Conversation.VectorThreshold < 0 || cfg.Conversation.VectorThreshold > 1 {
|
||
errs = append(errs, "conversation.vector_threshold must be between 0 and 1")
|
||
}
|
||
if cfg.Conversation.RerankThreshold < -10 || cfg.Conversation.RerankThreshold > 10 {
|
||
errs = append(errs, "conversation.rerank_threshold must be between -10 and 10")
|
||
}
|
||
}
|
||
|
||
if cfg.KnowledgeBase != nil {
|
||
if cfg.KnowledgeBase.ChunkSize <= 0 {
|
||
errs = append(errs, "knowledge_base.chunk_size must be > 0")
|
||
}
|
||
if cfg.KnowledgeBase.ChunkOverlap < 0 {
|
||
errs = append(errs, "knowledge_base.chunk_overlap must be >= 0")
|
||
}
|
||
if cfg.KnowledgeBase.ChunkOverlap >= cfg.KnowledgeBase.ChunkSize {
|
||
errs = append(errs, "knowledge_base.chunk_overlap must be less than chunk_size")
|
||
}
|
||
}
|
||
|
||
if cfg.Server != nil {
|
||
if cfg.Server.Port <= 0 || cfg.Server.Port > 65535 {
|
||
errs = append(errs, "server.port must be between 1 and 65535")
|
||
}
|
||
}
|
||
|
||
if len(errs) > 0 {
|
||
return fmt.Errorf("config validation errors: %s", strings.Join(errs, "; "))
|
||
}
|
||
return nil
|
||
}
|
||
|
||
func applyOIDCEnvOverrides(cfg *Config) {
|
||
if cfg.OIDCAuth == nil {
|
||
cfg.OIDCAuth = &OIDCAuthConfig{}
|
||
}
|
||
if cfg.OIDCAuth.UserInfoMapping == nil {
|
||
cfg.OIDCAuth.UserInfoMapping = &OIDCUserInfoMapping{}
|
||
}
|
||
|
||
if value := strings.TrimSpace(os.Getenv("OIDC_AUTH_ENABLE")); value != "" {
|
||
cfg.OIDCAuth.Enable = strings.EqualFold(value, "true")
|
||
}
|
||
if value := strings.TrimSpace(os.Getenv("OIDC_AUTH_ISSUER_URL")); value != "" {
|
||
cfg.OIDCAuth.IssuerURL = value
|
||
}
|
||
if value := strings.TrimSpace(os.Getenv("OIDC_AUTH_DISCOVERY_URL")); value != "" {
|
||
cfg.OIDCAuth.DiscoveryURL = value
|
||
}
|
||
if value := strings.TrimSpace(os.Getenv("OIDC_AUTH_PROVIDER_DISPLAY_NAME")); value != "" {
|
||
cfg.OIDCAuth.ProviderDisplayName = value
|
||
}
|
||
if value := strings.TrimSpace(os.Getenv("OIDC_AUTH_CLIENT_ID")); value != "" {
|
||
cfg.OIDCAuth.ClientID = value
|
||
}
|
||
if value := strings.TrimSpace(os.Getenv("OIDC_AUTH_CLIENT_SECRET")); value != "" {
|
||
cfg.OIDCAuth.ClientSecret = value
|
||
}
|
||
if value := strings.TrimSpace(os.Getenv("OIDC_AUTH_AUTHORIZATION_ENDPOINT")); value != "" {
|
||
cfg.OIDCAuth.AuthorizationEndpoint = value
|
||
}
|
||
if value := strings.TrimSpace(os.Getenv("OIDC_AUTH_TOKEN_ENDPOINT")); value != "" {
|
||
cfg.OIDCAuth.TokenEndpoint = value
|
||
}
|
||
if value := strings.TrimSpace(os.Getenv("OIDC_AUTH_USER_INFO_ENDPOINT")); value != "" {
|
||
cfg.OIDCAuth.UserInfoEndpoint = value
|
||
}
|
||
if value := strings.TrimSpace(os.Getenv("OIDC_AUTH_SCOPES")); value != "" {
|
||
cfg.OIDCAuth.Scopes = strings.Fields(strings.ReplaceAll(value, ",", " "))
|
||
}
|
||
if value := strings.TrimSpace(os.Getenv("OIDC_USER_INFO_MAPPING_USER_NAME")); value != "" {
|
||
cfg.OIDCAuth.UserInfoMapping.Username = value
|
||
}
|
||
if value := strings.TrimSpace(os.Getenv("OIDC_USER_INFO_MAPPING_EMAIL")); value != "" {
|
||
cfg.OIDCAuth.UserInfoMapping.Email = value
|
||
}
|
||
|
||
if cfg.OIDCAuth.ProviderDisplayName == "" {
|
||
cfg.OIDCAuth.ProviderDisplayName = "OIDC"
|
||
}
|
||
if len(cfg.OIDCAuth.Scopes) == 0 {
|
||
cfg.OIDCAuth.Scopes = []string{"openid", "profile", "email"}
|
||
}
|
||
if cfg.OIDCAuth.UserInfoMapping.Username == "" {
|
||
cfg.OIDCAuth.UserInfoMapping.Username = "name"
|
||
}
|
||
if cfg.OIDCAuth.UserInfoMapping.Email == "" {
|
||
cfg.OIDCAuth.UserInfoMapping.Email = "email"
|
||
}
|
||
if cfg.OIDCAuth.DiscoveryURL == "" && cfg.OIDCAuth.IssuerURL != "" {
|
||
cfg.OIDCAuth.DiscoveryURL = strings.TrimRight(cfg.OIDCAuth.IssuerURL, "/") + "/.well-known/openid-configuration"
|
||
}
|
||
}
|
||
|
||
// backfillConversationDefaults resolves prompt template ID references
|
||
// into actual prompt text content. Only xxx_id fields are used;
|
||
// no fallback to default templates.
|
||
func backfillConversationDefaults(cfg *Config) {
|
||
pt := cfg.PromptTemplates
|
||
conv := cfg.Conversation
|
||
|
||
if conv.FallbackPromptID != "" {
|
||
if t := FindTemplateByID(pt, conv.FallbackPromptID); t != nil {
|
||
conv.FallbackPrompt = t.Content
|
||
} else {
|
||
fmt.Printf("Warning: fallback_prompt_id %q not found\n", conv.FallbackPromptID)
|
||
}
|
||
}
|
||
if conv.RewritePromptID != "" {
|
||
if t := FindTemplateByID(pt, conv.RewritePromptID); t != nil {
|
||
conv.RewritePromptSystem = t.Content
|
||
conv.RewritePromptUser = t.User
|
||
} else {
|
||
fmt.Printf("Warning: rewrite_prompt_id %q not found\n", conv.RewritePromptID)
|
||
}
|
||
}
|
||
if conv.GenerateSessionTitlePromptID != "" {
|
||
if t := FindTemplateByID(pt, conv.GenerateSessionTitlePromptID); t != nil {
|
||
conv.GenerateSessionTitlePrompt = t.Content
|
||
} else {
|
||
fmt.Printf("Warning: generate_session_title_prompt_id %q not found\n", conv.GenerateSessionTitlePromptID)
|
||
}
|
||
}
|
||
if conv.GenerateSummaryPromptID != "" {
|
||
if t := FindTemplateByID(pt, conv.GenerateSummaryPromptID); t != nil {
|
||
conv.GenerateSummaryPrompt = t.Content
|
||
} else {
|
||
fmt.Printf("Warning: generate_summary_prompt_id %q not found\n", conv.GenerateSummaryPromptID)
|
||
}
|
||
}
|
||
if conv.ExtractEntitiesPromptID != "" {
|
||
if t := FindTemplateByID(pt, conv.ExtractEntitiesPromptID); t != nil {
|
||
conv.ExtractEntitiesPrompt = t.Content
|
||
} else {
|
||
fmt.Printf("Warning: extract_entities_prompt_id %q not found\n", conv.ExtractEntitiesPromptID)
|
||
}
|
||
}
|
||
if conv.ExtractRelationshipsPromptID != "" {
|
||
if t := FindTemplateByID(pt, conv.ExtractRelationshipsPromptID); t != nil {
|
||
conv.ExtractRelationshipsPrompt = t.Content
|
||
} else {
|
||
fmt.Printf("Warning: extract_relationships_prompt_id %q not found\n", conv.ExtractRelationshipsPromptID)
|
||
}
|
||
}
|
||
if conv.GenerateQuestionsPromptID != "" {
|
||
if t := FindTemplateByID(pt, conv.GenerateQuestionsPromptID); t != nil {
|
||
conv.GenerateQuestionsPrompt = t.Content
|
||
} else {
|
||
fmt.Printf("Warning: generate_questions_prompt_id %q not found\n", conv.GenerateQuestionsPromptID)
|
||
}
|
||
}
|
||
if conv.Summary != nil {
|
||
if conv.Summary.PromptID != "" {
|
||
if t := FindTemplateByID(pt, conv.Summary.PromptID); t != nil {
|
||
conv.Summary.Prompt = t.Content
|
||
} else {
|
||
fmt.Printf("Warning: summary.prompt_id %q not found\n", conv.Summary.PromptID)
|
||
}
|
||
}
|
||
if conv.Summary.ContextTemplateID != "" {
|
||
if t := FindTemplateByID(pt, conv.Summary.ContextTemplateID); t != nil {
|
||
conv.Summary.ContextTemplate = t.Content
|
||
} else {
|
||
fmt.Printf("Warning: summary.context_template_id %q not found\n", conv.Summary.ContextTemplateID)
|
||
}
|
||
}
|
||
}
|
||
|
||
// Build intent→system-prompt map from IntentPrompts templates.
|
||
// Template ID must equal the QueryIntent string value (e.g. "greeting").
|
||
if len(pt.IntentPrompts) > 0 {
|
||
conv.IntentSystemPrompts = make(map[string]string, len(pt.IntentPrompts))
|
||
for _, t := range pt.IntentPrompts {
|
||
if t.ID != "" && t.Content != "" {
|
||
conv.IntentSystemPrompts[t.ID] = t.Content
|
||
}
|
||
}
|
||
}
|
||
}
|
||
|
||
// FindTemplateByID searches across all template lists for a template with the given ID.
|
||
// It returns the template if found, or nil otherwise.
|
||
func FindTemplateByID(pt *PromptTemplatesConfig, id string) *PromptTemplate {
|
||
if pt == nil || id == "" {
|
||
return nil
|
||
}
|
||
// Search all template collections
|
||
for _, list := range [][]PromptTemplate{
|
||
pt.SystemPrompt,
|
||
pt.ContextTemplate,
|
||
pt.Rewrite,
|
||
pt.Fallback,
|
||
pt.GenerateSessionTitle,
|
||
pt.GenerateSummary,
|
||
pt.KeywordsExtraction,
|
||
pt.AgentSystemPrompt,
|
||
pt.GraphExtraction,
|
||
pt.GenerateQuestions,
|
||
pt.IntentPrompts,
|
||
} {
|
||
for i := range list {
|
||
if list[i].ID == id {
|
||
return &list[i]
|
||
}
|
||
}
|
||
}
|
||
return nil
|
||
}
|
||
|
||
// resolveBuiltinAgentPromptIDs resolves system_prompt_id and context_template_id
|
||
// references in builtin agent configs by looking up the actual content from
|
||
// prompt template YAML files.
|
||
func resolveBuiltinAgentPromptIDs(pt *PromptTemplatesConfig) {
|
||
types.ResolveBuiltinAgentPromptRefs(func(id string) string {
|
||
if t := FindTemplateByID(pt, id); t != nil {
|
||
return t.Content
|
||
}
|
||
return ""
|
||
})
|
||
}
|
||
|
||
// promptTemplateFile 用于解析模板文件
|
||
type promptTemplateFile struct {
|
||
Templates []PromptTemplate `yaml:"templates"`
|
||
}
|
||
|
||
// loadPromptTemplates 从目录加载提示词模板
|
||
func loadPromptTemplates(configDir string) (*PromptTemplatesConfig, error) {
|
||
templatesDir := filepath.Join(configDir, "prompt_templates")
|
||
|
||
// 检查目录是否存在
|
||
if _, err := os.Stat(templatesDir); os.IsNotExist(err) {
|
||
return nil, nil // 目录不存在,返回nil让调用者使用配置文件中的模板
|
||
}
|
||
|
||
config := &PromptTemplatesConfig{}
|
||
|
||
// 定义模板文件映射
|
||
templateFiles := map[string]*[]PromptTemplate{
|
||
"system_prompt.yaml": &config.SystemPrompt,
|
||
"context_template.yaml": &config.ContextTemplate,
|
||
"rewrite.yaml": &config.Rewrite,
|
||
"fallback.yaml": &config.Fallback,
|
||
"generate_session_title.yaml": &config.GenerateSessionTitle,
|
||
"generate_summary.yaml": &config.GenerateSummary,
|
||
"keywords_extraction.yaml": &config.KeywordsExtraction,
|
||
"agent_system_prompt.yaml": &config.AgentSystemPrompt,
|
||
"graph_extraction.yaml": &config.GraphExtraction,
|
||
"generate_questions.yaml": &config.GenerateQuestions,
|
||
"intent_prompts.yaml": &config.IntentPrompts,
|
||
}
|
||
|
||
// 加载每个模板文件
|
||
for filename, target := range templateFiles {
|
||
filePath := filepath.Join(templatesDir, filename)
|
||
if _, err := os.Stat(filePath); os.IsNotExist(err) {
|
||
continue // 文件不存在,跳过
|
||
}
|
||
|
||
data, err := os.ReadFile(filePath)
|
||
if err != nil {
|
||
return nil, fmt.Errorf("failed to read %s: %w", filename, err)
|
||
}
|
||
|
||
var file promptTemplateFile
|
||
if err := yaml.Unmarshal(data, &file); err != nil {
|
||
return nil, fmt.Errorf("failed to parse %s: %w", filename, err)
|
||
}
|
||
|
||
*target = file.Templates
|
||
}
|
||
|
||
return config, nil
|
||
}
|
||
|
||
// WebSearchConfig represents the web search configuration
|
||
type WebSearchConfig struct {
|
||
Timeout int `yaml:"timeout" json:"timeout"` // 超时时间(秒)
|
||
}
|