diff --git a/internal/agent/engine.go b/internal/agent/engine.go index b773db3c..7157719b 100644 --- a/internal/agent/engine.go +++ b/internal/agent/engine.go @@ -21,11 +21,45 @@ import ( ) const ( - // llmPerCallTimeout is the maximum time allowed for a single LLM call (stream initiation + full response). + // defaultLLMCallTimeout is the default maximum time allowed for a single LLM call. // This prevents a single slow call from consuming the entire pipeline deadline. - llmPerCallTimeout = 120 * time.Second + // Can be overridden via AgentConfig.LLMCallTimeout. + defaultLLMCallTimeout = 120 * time.Second + + // maxLLMRetries is the maximum number of retries for transient LLM errors. + maxLLMRetries = 2 ) +// transientErrorMarkers are substrings that indicate a transient (retryable) error. +var transientErrorMarkers = []string{ + "429", "rate limit", + "500", "502", "503", "504", + "overloaded", "timeout", "timed out", + "connection", "server error", "temporarily unavailable", +} + +// isTransientError checks whether an error is likely transient and worth retrying. +func isTransientError(err error) bool { + if err == nil { + return false + } + errStr := strings.ToLower(err.Error()) + for _, marker := range transientErrorMarkers { + if strings.Contains(errStr, marker) { + return true + } + } + return false +} + +// getLLMCallTimeout returns the configured LLM call timeout, falling back to default. +func (e *AgentEngine) getLLMCallTimeout() time.Duration { + if e.config.LLMCallTimeout > 0 { + return time.Duration(e.config.LLMCallTimeout) * time.Second + } + return defaultLLMCallTimeout +} + // generateEventID generates a unique event ID with type suffix for better traceability func generateEventID(suffix string) string { return fmt.Sprintf("%s-%s", uuid.New().String()[:8], suffix) @@ -305,6 +339,20 @@ func (e *AgentEngine) executeLoop( "tool_cnt": len(tools), }) response, err := e.streamThinkingToEventBus(ctx, messages, tools, state.CurrentRound, sessionID) + if err != nil && isTransientError(err) { + // Retry transient errors (timeout, rate limit, server errors) up to maxLLMRetries times + for retry := 1; retry <= maxLLMRetries; retry++ { + retryDelay := time.Duration(retry) * time.Second + logger.Warnf(ctx, "[Agent][Round-%d] LLM transient error (attempt %d/%d), retrying in %v: %v", + state.CurrentRound+1, retry, maxLLMRetries, retryDelay, err) + time.Sleep(retryDelay) + + response, err = e.streamThinkingToEventBus(ctx, messages, tools, state.CurrentRound, sessionID) + if err == nil || !isTransientError(err) { + break + } + } + } if err != nil { logger.Errorf(ctx, "[Agent][Round-%d] LLM call failed: %v", state.CurrentRound+1, err) common.PipelineError(ctx, "Agent", "think_failed", map[string]interface{}{ @@ -450,6 +498,18 @@ func (e *AgentEngine) executeLoop( if err := json.Unmarshal([]byte(tc.Function.Arguments), &args); err != nil { logger.Errorf(ctx, "[Agent][Round-%d][Tool-%d/%d] Failed to parse tool arguments: %v", state.CurrentRound+1, i+1, len(response.ToolCalls), err) + // Record parse failure as a tool result so the LLM can see the error and adapt + parseFailCall := types.ToolCall{ + ID: tc.ID, + Name: tc.Function.Name, + Args: map[string]any{"_raw": tc.Function.Arguments}, + Result: &types.ToolResult{ + Success: false, + Error: fmt.Sprintf("Failed to parse tool arguments: %v", err) + + "\n\n[Analyze the error above and try a different approach.]", + }, + } + step.ToolCalls = append(step.ToolCalls, parseFailCall) continue } @@ -790,7 +850,7 @@ func (e *AgentEngine) streamLLMToEventBus( // guaranteed time window even when the parent context's deadline is almost // exhausted after previous iterations. The shorter of the two deadlines wins, // so the parent pipeline timeout is still respected. - llmCtx, llmCancel := context.WithTimeout(ctx, llmPerCallTimeout) + llmCtx, llmCancel := context.WithTimeout(ctx, e.getLLMCallTimeout()) defer llmCancel() stream, err := e.chatModel.ChatStream(llmCtx, messages, opts) diff --git a/internal/agent/tools/param_cast.go b/internal/agent/tools/param_cast.go new file mode 100644 index 00000000..f3da8afd --- /dev/null +++ b/internal/agent/tools/param_cast.go @@ -0,0 +1,127 @@ +package tools + +import ( + "encoding/json" + "strconv" + "strings" +) + +// CastParams performs schema-driven type casting on tool arguments. +// LLMs sometimes return incorrect types (e.g., "true" instead of true, "123" instead of 123). +// This function attempts safe conversions based on the JSON Schema definition of the tool's parameters. +// +// If the schema is nil or cannot be parsed, the original args are returned unchanged. +func CastParams(args json.RawMessage, schema json.RawMessage) json.RawMessage { + if len(schema) == 0 || len(args) == 0 { + return args + } + + var schemaDef map[string]interface{} + if err := json.Unmarshal(schema, &schemaDef); err != nil { + return args + } + + properties, ok := schemaDef["properties"].(map[string]interface{}) + if !ok || len(properties) == 0 { + return args + } + + var argsMap map[string]interface{} + if err := json.Unmarshal(args, &argsMap); err != nil { + return args + } + + changed := false + for key, val := range argsMap { + propDef, exists := properties[key] + if !exists { + continue + } + prop, ok := propDef.(map[string]interface{}) + if !ok { + continue + } + targetType, _ := prop["type"].(string) + if targetType == "" { + continue + } + + newVal, didCast := castValue(val, targetType) + if didCast { + argsMap[key] = newVal + changed = true + } + } + + if !changed { + return args + } + + result, err := json.Marshal(argsMap) + if err != nil { + return args + } + return result +} + +// castValue attempts to convert val to the expected targetType. +// Returns (newValue, true) if a conversion was made, (val, false) otherwise. +func castValue(val interface{}, targetType string) (interface{}, bool) { + switch targetType { + case "boolean": + if s, ok := val.(string); ok { + lower := strings.ToLower(s) + switch lower { + case "true", "1", "yes": + return true, true + case "false", "0", "no": + return false, true + } + } + // JSON number 0/1 -> bool + if n, ok := val.(float64); ok { + if n == 0 { + return false, true + } + if n == 1 { + return true, true + } + } + + case "integer": + if s, ok := val.(string); ok { + if i, err := strconv.ParseInt(s, 10, 64); err == nil { + return i, true + } + } + // JSON numbers are float64 in Go; convert to int if it's a whole number + if f, ok := val.(float64); ok { + if f == float64(int64(f)) { + return int64(f), true + } + } + + case "number": + if s, ok := val.(string); ok { + if f, err := strconv.ParseFloat(s, 64); err == nil { + return f, true + } + } + + case "string": + // Non-string values -> string (e.g., number or bool passed as non-string) + switch v := val.(type) { + case bool: + if v { + return "true", true + } + return "false", true + case float64: + return strconv.FormatFloat(v, 'f', -1, 64), true + case int64: + return strconv.FormatInt(v, 10), true + } + } + + return val, false +} diff --git a/internal/agent/tools/param_cast_test.go b/internal/agent/tools/param_cast_test.go new file mode 100644 index 00000000..36d28251 --- /dev/null +++ b/internal/agent/tools/param_cast_test.go @@ -0,0 +1,81 @@ +package tools + +import ( + "encoding/json" + "testing" +) + +func TestCastParams_StringToBool(t *testing.T) { + schema := json.RawMessage(`{"type":"object","properties":{"enabled":{"type":"boolean"}}}`) + args := json.RawMessage(`{"enabled":"true"}`) + result := CastParams(args, schema) + + var parsed map[string]interface{} + if err := json.Unmarshal(result, &parsed); err != nil { + t.Fatal(err) + } + if parsed["enabled"] != true { + t.Errorf("expected true, got %v (%T)", parsed["enabled"], parsed["enabled"]) + } +} + +func TestCastParams_StringToInt(t *testing.T) { + schema := json.RawMessage(`{"type":"object","properties":{"count":{"type":"integer"}}}`) + args := json.RawMessage(`{"count":"42"}`) + result := CastParams(args, schema) + + var parsed map[string]interface{} + if err := json.Unmarshal(result, &parsed); err != nil { + t.Fatal(err) + } + // JSON numbers are float64 in Go + if parsed["count"] != float64(42) { + t.Errorf("expected 42, got %v (%T)", parsed["count"], parsed["count"]) + } +} + +func TestCastParams_StringToFloat(t *testing.T) { + schema := json.RawMessage(`{"type":"object","properties":{"score":{"type":"number"}}}`) + args := json.RawMessage(`{"score":"3.14"}`) + result := CastParams(args, schema) + + var parsed map[string]interface{} + if err := json.Unmarshal(result, &parsed); err != nil { + t.Fatal(err) + } + if parsed["score"] != 3.14 { + t.Errorf("expected 3.14, got %v", parsed["score"]) + } +} + +func TestCastParams_NoChangeNeeded(t *testing.T) { + schema := json.RawMessage(`{"type":"object","properties":{"name":{"type":"string"}}}`) + args := json.RawMessage(`{"name":"hello"}`) + result := CastParams(args, schema) + + if string(result) != string(args) { + t.Errorf("expected no change, got %s", result) + } +} + +func TestCastParams_NilSchema(t *testing.T) { + args := json.RawMessage(`{"foo":"bar"}`) + result := CastParams(args, nil) + if string(result) != string(args) { + t.Errorf("expected no change with nil schema") + } +} + +func TestCastParams_BoolFalseString(t *testing.T) { + schema := json.RawMessage(`{"type":"object","properties":{"flag":{"type":"boolean"}}}`) + args := json.RawMessage(`{"flag":"false"}`) + result := CastParams(args, schema) + + var parsed map[string]interface{} + if err := json.Unmarshal(result, &parsed); err != nil { + t.Fatal(err) + } + if parsed["flag"] != false { + t.Errorf("expected false, got %v (%T)", parsed["flag"], parsed["flag"]) + } +} diff --git a/internal/agent/tools/registry.go b/internal/agent/tools/registry.go index 7dcb9ca0..3cdfad0b 100644 --- a/internal/agent/tools/registry.go +++ b/internal/agent/tools/registry.go @@ -6,9 +6,13 @@ import ( "fmt" "github.com/Tencent/WeKnora/internal/common" + "github.com/Tencent/WeKnora/internal/logger" "github.com/Tencent/WeKnora/internal/types" ) +// toolErrorHint is appended to tool error messages to guide the LLM to retry with a different approach. +const toolErrorHint = "\n\n[Analyze the error above and try a different approach.]" + // ToolRegistry manages the registration and retrieval of tools type ToolRegistry struct { tools map[string]types.Tool @@ -27,6 +31,8 @@ func NewToolRegistry() *ToolRegistry { func (r *ToolRegistry) RegisterTool(tool types.Tool) { name := tool.Name() if _, exists := r.tools[name]; exists { + logger.Warnf(context.Background(), + "[ToolRegistry] Duplicate tool registration rejected: %s (first-wins policy)", name) return } r.tools[name] = tool @@ -81,10 +87,14 @@ func (r *ToolRegistry) ExecuteTool( }) return &types.ToolResult{ Success: false, - Error: err.Error(), + Error: err.Error() + toolErrorHint, }, err } + // Cast parameters to match expected schema types before execution. + // This handles common LLM quirks like returning "true" instead of true. + args = CastParams(args, tool.Parameters()) + result, execErr := tool.Execute(ctx, args) fields := map[string]interface{}{ "tool": name, @@ -100,6 +110,10 @@ func (r *ToolRegistry) ExecuteTool( fields["error"] = execErr.Error() common.PipelineError(ctx, "AgentTool", "execute_done", fields) } else if result != nil && !result.Success { + // Append error hint to guide LLM to retry with a different approach + if result.Error != "" { + result.Error = result.Error + toolErrorHint + } common.PipelineWarn(ctx, "AgentTool", "execute_done", fields) } else { common.PipelineInfo(ctx, "AgentTool", "execute_done", fields) @@ -108,12 +122,13 @@ func (r *ToolRegistry) ExecuteTool( return result, execErr } -// Cleanup cleans up all registered tools that implement the Cleanup method +// Cleanup cleans up all registered tools that implement the types.Cleanable interface. +// This is called at the end of agent sessions to release tool-specific resources. func (r *ToolRegistry) Cleanup(ctx context.Context) { - // Check specifically for DataAnalysisTool - if tool, exists := r.tools[ToolDataAnalysis]; exists { - if dataAnalysisTool, ok := tool.(*DataAnalysisTool); ok { - dataAnalysisTool.Cleanup(ctx) + for name, tool := range r.tools { + if cleanable, ok := tool.(types.Cleanable); ok { + logger.Infof(ctx, "[ToolRegistry] Cleaning up tool: %s", name) + cleanable.Cleanup(ctx) } } } diff --git a/internal/application/service/chat_pipline/rerank.go b/internal/application/service/chat_pipline/rerank.go index babf372a..c29f2170 100644 --- a/internal/application/service/chat_pipline/rerank.go +++ b/internal/application/service/chat_pipline/rerank.go @@ -113,9 +113,11 @@ func (p *PluginRerank) OnEvent(ctx context.Context, if degradedThreshold < 0.3 { degradedThreshold = 0.3 } - pipelineInfo(ctx, "Rerank", "threshold_degrade", map[string]interface{}{ - "original": originalThreshold, - "degraded": degradedThreshold, + pipelineWarn(ctx, "Rerank", "threshold_degrade", map[string]interface{}{ + "original": originalThreshold, + "degraded": degradedThreshold, + "candidate_cnt": len(candidatesToRerank), + "reason": "no results above original threshold, retrying with lower threshold", }) chatManage.RerankThreshold = degradedThreshold rerankResp = p.rerank(ctx, chatManage, rerankModel, chatManage.RewriteQuery, passages, candidatesToRerank) diff --git a/internal/application/service/llmcontext/compression_strategies.go b/internal/application/service/llmcontext/compression_strategies.go index 9372724d..f888a953 100644 --- a/internal/application/service/llmcontext/compression_strategies.go +++ b/internal/application/service/llmcontext/compression_strategies.go @@ -4,12 +4,50 @@ import ( "context" "fmt" "strings" + "unicode" "github.com/Tencent/WeKnora/internal/logger" "github.com/Tencent/WeKnora/internal/models/chat" "github.com/Tencent/WeKnora/internal/types/interfaces" ) +// estimateStringTokens provides a more accurate token estimation by distinguishing +// CJK characters (≈1.5 tokens each) from Latin characters (≈4 chars per token). +// This is significantly more accurate than the naive totalChars/4 approach, +// especially for mixed Chinese-English content common in WeKnora. +func estimateStringTokens(s string) int { + cjkChars := 0 + otherChars := 0 + for _, r := range s { + if unicode.Is(unicode.Han, r) || unicode.Is(unicode.Hangul, r) || unicode.Is(unicode.Katakana, r) || unicode.Is(unicode.Hiragana, r) { + cjkChars++ + } else { + otherChars++ + } + } + // CJK characters average ~1.5 tokens each; Latin ~0.25 tokens per char + return (cjkChars*3 + otherChars) / 2 +} + +// estimateMessageTokens estimates the token count for a list of messages. +// Accounts for per-message overhead (role markers, special tokens) and tool call metadata. +func estimateMessageTokens(messages []chat.Message) int { + totalTokens := 0 + for _, msg := range messages { + totalTokens += estimateStringTokens(msg.Role) + estimateStringTokens(msg.Content) + // Per-message overhead: role markers, delimiters, special tokens + totalTokens += 4 + if len(msg.ToolCalls) > 0 { + for _, tc := range msg.ToolCalls { + totalTokens += estimateStringTokens(tc.Function.Name) + estimateStringTokens(tc.Function.Arguments) + // Tool call overhead: function call structure tokens + totalTokens += 8 + } + } + } + return totalTokens +} + // slidingWindowStrategy implements CompressionStrategy using sliding window type slidingWindowStrategy struct { recentMessageCount int @@ -64,19 +102,9 @@ func (s *slidingWindowStrategy) Compress( return result, nil } -// EstimateTokens estimates token count (rough approximation: 4 characters ≈ 1 token) +// EstimateTokens estimates token count using CJK-aware heuristics. func (s *slidingWindowStrategy) EstimateTokens(messages []chat.Message) int { - totalChars := 0 - for _, msg := range messages { - totalChars += len(msg.Role) + len(msg.Content) - // Account for tool calls if present - if len(msg.ToolCalls) > 0 { - for _, tc := range msg.ToolCalls { - totalChars += len(tc.Function.Name) + len(tc.Function.Arguments) - } - } - } - return totalChars / 4 // Rough approximation + return estimateMessageTokens(messages) } // smartCompressionStrategy implements CompressionStrategy using LLM summarization @@ -151,12 +179,10 @@ func (s *smartCompressionStrategy) Compress( // Summarize old messages using LLM summary, err := s.summarizeMessages(ctx, oldMessages) if err != nil { - logger.Warnf(ctx, "[SmartCompression] Failed to summarize messages: %v, falling back to old messages", err) - // Fallback: return all messages if summarization fails - result := make([]chat.Message, 0, len(systemMessages)+len(regularMessages)) - result = append(result, systemMessages...) - result = append(result, regularMessages...) - return result, nil + logger.Warnf(ctx, "[SmartCompression] Failed to summarize messages: %v, falling back to sliding window", err) + // Fallback: use sliding window strategy to at least reduce message count + fallback := &slidingWindowStrategy{recentMessageCount: s.recentMessageCount} + return fallback.Compress(ctx, messages, maxTokens) } // Construct final message list: system + summary + recent @@ -186,7 +212,7 @@ func (s *smartCompressionStrategy) summarizeMessages(ctx context.Context, messag // Build conversation text var sb strings.Builder for i, msg := range messages { - sb.WriteString(fmt.Sprintf("[%s] %s\n", msg.Role, msg.Content)) + fmt.Fprintf(&sb, "[%s] %s\n", msg.Role, msg.Content) if i < len(messages)-1 { sb.WriteString("\n") } @@ -226,17 +252,7 @@ func (s *smartCompressionStrategy) summarizeMessages(ctx context.Context, messag return summary, nil } -// EstimateTokens estimates token count (rough approximation: 4 characters ≈ 1 token) +// EstimateTokens estimates token count using CJK-aware heuristics. func (s *smartCompressionStrategy) EstimateTokens(messages []chat.Message) int { - totalChars := 0 - for _, msg := range messages { - totalChars += len(msg.Role) + len(msg.Content) - // Account for tool calls if present - if len(msg.ToolCalls) > 0 { - for _, tc := range msg.ToolCalls { - totalChars += len(tc.Function.Name) + len(tc.Function.Arguments) - } - } - } - return totalChars / 4 // Rough approximation + return estimateMessageTokens(messages) } diff --git a/internal/config/config.go b/internal/config/config.go index 8a161d77..625be8af 100644 --- a/internal/config/config.go +++ b/internal/config/config.go @@ -379,9 +379,58 @@ func LoadConfig() (*Config, error) { resolveBuiltinAgentPromptIDs(cfg.PromptTemplates) } + // Validate configuration values + if err := ValidateConfig(&cfg); err != nil { + return nil, err + } + return &cfg, nil } +// ValidateConfig performs basic validation of the loaded configuration. +// It checks for obviously invalid or missing values that would cause runtime failures. +func ValidateConfig(cfg *Config) error { + var errs []string + + if cfg.Conversation != nil { + if cfg.Conversation.EmbeddingTopK < 0 { + errs = append(errs, "conversation.embedding_top_k must be >= 0") + } + if cfg.Conversation.RerankTopK < 0 { + errs = append(errs, "conversation.rerank_top_k must be >= 0") + } + if cfg.Conversation.VectorThreshold < 0 || cfg.Conversation.VectorThreshold > 1 { + errs = append(errs, "conversation.vector_threshold must be between 0 and 1") + } + if cfg.Conversation.RerankThreshold < 0 || cfg.Conversation.RerankThreshold > 1 { + errs = append(errs, "conversation.rerank_threshold must be between 0 and 1") + } + } + + if cfg.KnowledgeBase != nil { + if cfg.KnowledgeBase.ChunkSize <= 0 { + errs = append(errs, "knowledge_base.chunk_size must be > 0") + } + if cfg.KnowledgeBase.ChunkOverlap < 0 { + errs = append(errs, "knowledge_base.chunk_overlap must be >= 0") + } + if cfg.KnowledgeBase.ChunkOverlap >= cfg.KnowledgeBase.ChunkSize { + errs = append(errs, "knowledge_base.chunk_overlap must be less than chunk_size") + } + } + + if cfg.Server != nil { + if cfg.Server.Port <= 0 || cfg.Server.Port > 65535 { + errs = append(errs, "server.port must be between 1 and 65535") + } + } + + if len(errs) > 0 { + return fmt.Errorf("config validation errors: %s", strings.Join(errs, "; ")) + } + return nil +} + // backfillConversationDefaults resolves prompt template ID references // into actual prompt text content. Only xxx_id fields are used; // no fallback to default templates. diff --git a/internal/models/chat/remote_api.go b/internal/models/chat/remote_api.go index 7f076d01..5641d3b1 100644 --- a/internal/models/chat/remote_api.go +++ b/internal/models/chat/remote_api.go @@ -459,7 +459,7 @@ func (c *RemoteAPIChat) processStream(ctx context.Context, stream *openai.ChatCo for { response, err := stream.Recv() if err != nil { - if err.Error() == "EOF" { + if err == io.EOF { streamChan <- types.StreamResponse{ ResponseType: types.ResponseTypeAnswer, Content: "", @@ -493,7 +493,7 @@ func (c *RemoteAPIChat) processRawHTTPStream(ctx context.Context, resp *http.Res for { event, err := reader.ReadEvent() if err != nil { - if err.Error() != "EOF" { + if err != io.EOF { logger.Errorf(ctx, "Stream read error: %v", err) streamChan <- types.StreamResponse{ ResponseType: types.ResponseTypeError, diff --git a/internal/types/agent.go b/internal/types/agent.go index 0cc29d31..3d60fbbc 100644 --- a/internal/types/agent.go +++ b/internal/types/agent.go @@ -41,6 +41,8 @@ type AgentConfig struct { // Runtime-only fields (not persisted) VLMModelID string `json:"-"` // VLM model ID for tool result image analysis (set from CustomAgent config) + // LLM call timeout in seconds (default: 120). Controls the maximum time for a single LLM call. + LLMCallTimeout int `json:"llm_call_timeout,omitempty"` } // SessionAgentConfig represents session-level agent configuration @@ -137,12 +139,19 @@ type Tool interface { Execute(ctx context.Context, args json.RawMessage) (*ToolResult, error) } +// Cleanable is an optional interface that tools can implement to release resources. +// Tools implementing this interface will have their Cleanup method called during +// registry cleanup (e.g., at the end of an agent session). +type Cleanable interface { + Cleanup(ctx context.Context) +} + // ToolResult represents the result of a tool execution type ToolResult struct { - Success bool `json:"success"` // Whether the tool executed successfully - Output string `json:"output"` // Human-readable output - Data map[string]interface{} `json:"data,omitempty"` // Structured data for programmatic use - Error string `json:"error,omitempty"` // Error message if execution failed + Success bool `json:"success"` // Whether the tool executed successfully + Output string `json:"output"` // Human-readable output + Data map[string]interface{} `json:"data,omitempty"` // Structured data for programmatic use + Error string `json:"error,omitempty"` // Error message if execution failed Images []string `json:"images,omitempty"` // Base64 data URIs from tool (e.g. MCP image content) }