feat(agent): implement LLM call timeout and transient error handling

- Introduced a configurable LLM call timeout with a default value, allowing for better control over LLM call durations.
- Added logic to retry transient errors (e.g., timeouts, rate limits) up to a specified maximum number of retries, improving robustness in error handling.
- Implemented parameter casting for tool arguments to ensure correct types are used, addressing common LLM quirks.
- Enhanced tool execution error messages with guidance for retrying with different approaches.
- Added validation for configuration values to prevent runtime errors.
This commit is contained in:
wizardchen
2026-03-19 20:39:38 +08:00
committed by lyingbug
parent 00ac14aabb
commit e936e0b347
9 changed files with 408 additions and 49 deletions

View File

@@ -21,11 +21,45 @@ import (
)
const (
// llmPerCallTimeout is the maximum time allowed for a single LLM call (stream initiation + full response).
// defaultLLMCallTimeout is the default maximum time allowed for a single LLM call.
// This prevents a single slow call from consuming the entire pipeline deadline.
llmPerCallTimeout = 120 * time.Second
// Can be overridden via AgentConfig.LLMCallTimeout.
defaultLLMCallTimeout = 120 * time.Second
// maxLLMRetries is the maximum number of retries for transient LLM errors.
maxLLMRetries = 2
)
// transientErrorMarkers are substrings that indicate a transient (retryable) error.
var transientErrorMarkers = []string{
"429", "rate limit",
"500", "502", "503", "504",
"overloaded", "timeout", "timed out",
"connection", "server error", "temporarily unavailable",
}
// isTransientError checks whether an error is likely transient and worth retrying.
func isTransientError(err error) bool {
if err == nil {
return false
}
errStr := strings.ToLower(err.Error())
for _, marker := range transientErrorMarkers {
if strings.Contains(errStr, marker) {
return true
}
}
return false
}
// getLLMCallTimeout returns the configured LLM call timeout, falling back to default.
func (e *AgentEngine) getLLMCallTimeout() time.Duration {
if e.config.LLMCallTimeout > 0 {
return time.Duration(e.config.LLMCallTimeout) * time.Second
}
return defaultLLMCallTimeout
}
// generateEventID generates a unique event ID with type suffix for better traceability
func generateEventID(suffix string) string {
return fmt.Sprintf("%s-%s", uuid.New().String()[:8], suffix)
@@ -305,6 +339,20 @@ func (e *AgentEngine) executeLoop(
"tool_cnt": len(tools),
})
response, err := e.streamThinkingToEventBus(ctx, messages, tools, state.CurrentRound, sessionID)
if err != nil && isTransientError(err) {
// Retry transient errors (timeout, rate limit, server errors) up to maxLLMRetries times
for retry := 1; retry <= maxLLMRetries; retry++ {
retryDelay := time.Duration(retry) * time.Second
logger.Warnf(ctx, "[Agent][Round-%d] LLM transient error (attempt %d/%d), retrying in %v: %v",
state.CurrentRound+1, retry, maxLLMRetries, retryDelay, err)
time.Sleep(retryDelay)
response, err = e.streamThinkingToEventBus(ctx, messages, tools, state.CurrentRound, sessionID)
if err == nil || !isTransientError(err) {
break
}
}
}
if err != nil {
logger.Errorf(ctx, "[Agent][Round-%d] LLM call failed: %v", state.CurrentRound+1, err)
common.PipelineError(ctx, "Agent", "think_failed", map[string]interface{}{
@@ -450,6 +498,18 @@ func (e *AgentEngine) executeLoop(
if err := json.Unmarshal([]byte(tc.Function.Arguments), &args); err != nil {
logger.Errorf(ctx, "[Agent][Round-%d][Tool-%d/%d] Failed to parse tool arguments: %v",
state.CurrentRound+1, i+1, len(response.ToolCalls), err)
// Record parse failure as a tool result so the LLM can see the error and adapt
parseFailCall := types.ToolCall{
ID: tc.ID,
Name: tc.Function.Name,
Args: map[string]any{"_raw": tc.Function.Arguments},
Result: &types.ToolResult{
Success: false,
Error: fmt.Sprintf("Failed to parse tool arguments: %v", err) +
"\n\n[Analyze the error above and try a different approach.]",
},
}
step.ToolCalls = append(step.ToolCalls, parseFailCall)
continue
}
@@ -790,7 +850,7 @@ func (e *AgentEngine) streamLLMToEventBus(
// guaranteed time window even when the parent context's deadline is almost
// exhausted after previous iterations. The shorter of the two deadlines wins,
// so the parent pipeline timeout is still respected.
llmCtx, llmCancel := context.WithTimeout(ctx, llmPerCallTimeout)
llmCtx, llmCancel := context.WithTimeout(ctx, e.getLLMCallTimeout())
defer llmCancel()
stream, err := e.chatModel.ChatStream(llmCtx, messages, opts)

View File

@@ -0,0 +1,127 @@
package tools
import (
"encoding/json"
"strconv"
"strings"
)
// CastParams performs schema-driven type casting on tool arguments.
// LLMs sometimes return incorrect types (e.g., "true" instead of true, "123" instead of 123).
// This function attempts safe conversions based on the JSON Schema definition of the tool's parameters.
//
// If the schema is nil or cannot be parsed, the original args are returned unchanged.
func CastParams(args json.RawMessage, schema json.RawMessage) json.RawMessage {
if len(schema) == 0 || len(args) == 0 {
return args
}
var schemaDef map[string]interface{}
if err := json.Unmarshal(schema, &schemaDef); err != nil {
return args
}
properties, ok := schemaDef["properties"].(map[string]interface{})
if !ok || len(properties) == 0 {
return args
}
var argsMap map[string]interface{}
if err := json.Unmarshal(args, &argsMap); err != nil {
return args
}
changed := false
for key, val := range argsMap {
propDef, exists := properties[key]
if !exists {
continue
}
prop, ok := propDef.(map[string]interface{})
if !ok {
continue
}
targetType, _ := prop["type"].(string)
if targetType == "" {
continue
}
newVal, didCast := castValue(val, targetType)
if didCast {
argsMap[key] = newVal
changed = true
}
}
if !changed {
return args
}
result, err := json.Marshal(argsMap)
if err != nil {
return args
}
return result
}
// castValue attempts to convert val to the expected targetType.
// Returns (newValue, true) if a conversion was made, (val, false) otherwise.
func castValue(val interface{}, targetType string) (interface{}, bool) {
switch targetType {
case "boolean":
if s, ok := val.(string); ok {
lower := strings.ToLower(s)
switch lower {
case "true", "1", "yes":
return true, true
case "false", "0", "no":
return false, true
}
}
// JSON number 0/1 -> bool
if n, ok := val.(float64); ok {
if n == 0 {
return false, true
}
if n == 1 {
return true, true
}
}
case "integer":
if s, ok := val.(string); ok {
if i, err := strconv.ParseInt(s, 10, 64); err == nil {
return i, true
}
}
// JSON numbers are float64 in Go; convert to int if it's a whole number
if f, ok := val.(float64); ok {
if f == float64(int64(f)) {
return int64(f), true
}
}
case "number":
if s, ok := val.(string); ok {
if f, err := strconv.ParseFloat(s, 64); err == nil {
return f, true
}
}
case "string":
// Non-string values -> string (e.g., number or bool passed as non-string)
switch v := val.(type) {
case bool:
if v {
return "true", true
}
return "false", true
case float64:
return strconv.FormatFloat(v, 'f', -1, 64), true
case int64:
return strconv.FormatInt(v, 10), true
}
}
return val, false
}

View File

@@ -0,0 +1,81 @@
package tools
import (
"encoding/json"
"testing"
)
func TestCastParams_StringToBool(t *testing.T) {
schema := json.RawMessage(`{"type":"object","properties":{"enabled":{"type":"boolean"}}}`)
args := json.RawMessage(`{"enabled":"true"}`)
result := CastParams(args, schema)
var parsed map[string]interface{}
if err := json.Unmarshal(result, &parsed); err != nil {
t.Fatal(err)
}
if parsed["enabled"] != true {
t.Errorf("expected true, got %v (%T)", parsed["enabled"], parsed["enabled"])
}
}
func TestCastParams_StringToInt(t *testing.T) {
schema := json.RawMessage(`{"type":"object","properties":{"count":{"type":"integer"}}}`)
args := json.RawMessage(`{"count":"42"}`)
result := CastParams(args, schema)
var parsed map[string]interface{}
if err := json.Unmarshal(result, &parsed); err != nil {
t.Fatal(err)
}
// JSON numbers are float64 in Go
if parsed["count"] != float64(42) {
t.Errorf("expected 42, got %v (%T)", parsed["count"], parsed["count"])
}
}
func TestCastParams_StringToFloat(t *testing.T) {
schema := json.RawMessage(`{"type":"object","properties":{"score":{"type":"number"}}}`)
args := json.RawMessage(`{"score":"3.14"}`)
result := CastParams(args, schema)
var parsed map[string]interface{}
if err := json.Unmarshal(result, &parsed); err != nil {
t.Fatal(err)
}
if parsed["score"] != 3.14 {
t.Errorf("expected 3.14, got %v", parsed["score"])
}
}
func TestCastParams_NoChangeNeeded(t *testing.T) {
schema := json.RawMessage(`{"type":"object","properties":{"name":{"type":"string"}}}`)
args := json.RawMessage(`{"name":"hello"}`)
result := CastParams(args, schema)
if string(result) != string(args) {
t.Errorf("expected no change, got %s", result)
}
}
func TestCastParams_NilSchema(t *testing.T) {
args := json.RawMessage(`{"foo":"bar"}`)
result := CastParams(args, nil)
if string(result) != string(args) {
t.Errorf("expected no change with nil schema")
}
}
func TestCastParams_BoolFalseString(t *testing.T) {
schema := json.RawMessage(`{"type":"object","properties":{"flag":{"type":"boolean"}}}`)
args := json.RawMessage(`{"flag":"false"}`)
result := CastParams(args, schema)
var parsed map[string]interface{}
if err := json.Unmarshal(result, &parsed); err != nil {
t.Fatal(err)
}
if parsed["flag"] != false {
t.Errorf("expected false, got %v (%T)", parsed["flag"], parsed["flag"])
}
}

View File

@@ -6,9 +6,13 @@ import (
"fmt"
"github.com/Tencent/WeKnora/internal/common"
"github.com/Tencent/WeKnora/internal/logger"
"github.com/Tencent/WeKnora/internal/types"
)
// toolErrorHint is appended to tool error messages to guide the LLM to retry with a different approach.
const toolErrorHint = "\n\n[Analyze the error above and try a different approach.]"
// ToolRegistry manages the registration and retrieval of tools
type ToolRegistry struct {
tools map[string]types.Tool
@@ -27,6 +31,8 @@ func NewToolRegistry() *ToolRegistry {
func (r *ToolRegistry) RegisterTool(tool types.Tool) {
name := tool.Name()
if _, exists := r.tools[name]; exists {
logger.Warnf(context.Background(),
"[ToolRegistry] Duplicate tool registration rejected: %s (first-wins policy)", name)
return
}
r.tools[name] = tool
@@ -81,10 +87,14 @@ func (r *ToolRegistry) ExecuteTool(
})
return &types.ToolResult{
Success: false,
Error: err.Error(),
Error: err.Error() + toolErrorHint,
}, err
}
// Cast parameters to match expected schema types before execution.
// This handles common LLM quirks like returning "true" instead of true.
args = CastParams(args, tool.Parameters())
result, execErr := tool.Execute(ctx, args)
fields := map[string]interface{}{
"tool": name,
@@ -100,6 +110,10 @@ func (r *ToolRegistry) ExecuteTool(
fields["error"] = execErr.Error()
common.PipelineError(ctx, "AgentTool", "execute_done", fields)
} else if result != nil && !result.Success {
// Append error hint to guide LLM to retry with a different approach
if result.Error != "" {
result.Error = result.Error + toolErrorHint
}
common.PipelineWarn(ctx, "AgentTool", "execute_done", fields)
} else {
common.PipelineInfo(ctx, "AgentTool", "execute_done", fields)
@@ -108,12 +122,13 @@ func (r *ToolRegistry) ExecuteTool(
return result, execErr
}
// Cleanup cleans up all registered tools that implement the Cleanup method
// Cleanup cleans up all registered tools that implement the types.Cleanable interface.
// This is called at the end of agent sessions to release tool-specific resources.
func (r *ToolRegistry) Cleanup(ctx context.Context) {
// Check specifically for DataAnalysisTool
if tool, exists := r.tools[ToolDataAnalysis]; exists {
if dataAnalysisTool, ok := tool.(*DataAnalysisTool); ok {
dataAnalysisTool.Cleanup(ctx)
for name, tool := range r.tools {
if cleanable, ok := tool.(types.Cleanable); ok {
logger.Infof(ctx, "[ToolRegistry] Cleaning up tool: %s", name)
cleanable.Cleanup(ctx)
}
}
}

View File

@@ -113,9 +113,11 @@ func (p *PluginRerank) OnEvent(ctx context.Context,
if degradedThreshold < 0.3 {
degradedThreshold = 0.3
}
pipelineInfo(ctx, "Rerank", "threshold_degrade", map[string]interface{}{
"original": originalThreshold,
"degraded": degradedThreshold,
pipelineWarn(ctx, "Rerank", "threshold_degrade", map[string]interface{}{
"original": originalThreshold,
"degraded": degradedThreshold,
"candidate_cnt": len(candidatesToRerank),
"reason": "no results above original threshold, retrying with lower threshold",
})
chatManage.RerankThreshold = degradedThreshold
rerankResp = p.rerank(ctx, chatManage, rerankModel, chatManage.RewriteQuery, passages, candidatesToRerank)

View File

@@ -4,12 +4,50 @@ import (
"context"
"fmt"
"strings"
"unicode"
"github.com/Tencent/WeKnora/internal/logger"
"github.com/Tencent/WeKnora/internal/models/chat"
"github.com/Tencent/WeKnora/internal/types/interfaces"
)
// estimateStringTokens provides a more accurate token estimation by distinguishing
// CJK characters (≈1.5 tokens each) from Latin characters (≈4 chars per token).
// This is significantly more accurate than the naive totalChars/4 approach,
// especially for mixed Chinese-English content common in WeKnora.
func estimateStringTokens(s string) int {
cjkChars := 0
otherChars := 0
for _, r := range s {
if unicode.Is(unicode.Han, r) || unicode.Is(unicode.Hangul, r) || unicode.Is(unicode.Katakana, r) || unicode.Is(unicode.Hiragana, r) {
cjkChars++
} else {
otherChars++
}
}
// CJK characters average ~1.5 tokens each; Latin ~0.25 tokens per char
return (cjkChars*3 + otherChars) / 2
}
// estimateMessageTokens estimates the token count for a list of messages.
// Accounts for per-message overhead (role markers, special tokens) and tool call metadata.
func estimateMessageTokens(messages []chat.Message) int {
totalTokens := 0
for _, msg := range messages {
totalTokens += estimateStringTokens(msg.Role) + estimateStringTokens(msg.Content)
// Per-message overhead: role markers, delimiters, special tokens
totalTokens += 4
if len(msg.ToolCalls) > 0 {
for _, tc := range msg.ToolCalls {
totalTokens += estimateStringTokens(tc.Function.Name) + estimateStringTokens(tc.Function.Arguments)
// Tool call overhead: function call structure tokens
totalTokens += 8
}
}
}
return totalTokens
}
// slidingWindowStrategy implements CompressionStrategy using sliding window
type slidingWindowStrategy struct {
recentMessageCount int
@@ -64,19 +102,9 @@ func (s *slidingWindowStrategy) Compress(
return result, nil
}
// EstimateTokens estimates token count (rough approximation: 4 characters ≈ 1 token)
// EstimateTokens estimates token count using CJK-aware heuristics.
func (s *slidingWindowStrategy) EstimateTokens(messages []chat.Message) int {
totalChars := 0
for _, msg := range messages {
totalChars += len(msg.Role) + len(msg.Content)
// Account for tool calls if present
if len(msg.ToolCalls) > 0 {
for _, tc := range msg.ToolCalls {
totalChars += len(tc.Function.Name) + len(tc.Function.Arguments)
}
}
}
return totalChars / 4 // Rough approximation
return estimateMessageTokens(messages)
}
// smartCompressionStrategy implements CompressionStrategy using LLM summarization
@@ -151,12 +179,10 @@ func (s *smartCompressionStrategy) Compress(
// Summarize old messages using LLM
summary, err := s.summarizeMessages(ctx, oldMessages)
if err != nil {
logger.Warnf(ctx, "[SmartCompression] Failed to summarize messages: %v, falling back to old messages", err)
// Fallback: return all messages if summarization fails
result := make([]chat.Message, 0, len(systemMessages)+len(regularMessages))
result = append(result, systemMessages...)
result = append(result, regularMessages...)
return result, nil
logger.Warnf(ctx, "[SmartCompression] Failed to summarize messages: %v, falling back to sliding window", err)
// Fallback: use sliding window strategy to at least reduce message count
fallback := &slidingWindowStrategy{recentMessageCount: s.recentMessageCount}
return fallback.Compress(ctx, messages, maxTokens)
}
// Construct final message list: system + summary + recent
@@ -186,7 +212,7 @@ func (s *smartCompressionStrategy) summarizeMessages(ctx context.Context, messag
// Build conversation text
var sb strings.Builder
for i, msg := range messages {
sb.WriteString(fmt.Sprintf("[%s] %s\n", msg.Role, msg.Content))
fmt.Fprintf(&sb, "[%s] %s\n", msg.Role, msg.Content)
if i < len(messages)-1 {
sb.WriteString("\n")
}
@@ -226,17 +252,7 @@ func (s *smartCompressionStrategy) summarizeMessages(ctx context.Context, messag
return summary, nil
}
// EstimateTokens estimates token count (rough approximation: 4 characters ≈ 1 token)
// EstimateTokens estimates token count using CJK-aware heuristics.
func (s *smartCompressionStrategy) EstimateTokens(messages []chat.Message) int {
totalChars := 0
for _, msg := range messages {
totalChars += len(msg.Role) + len(msg.Content)
// Account for tool calls if present
if len(msg.ToolCalls) > 0 {
for _, tc := range msg.ToolCalls {
totalChars += len(tc.Function.Name) + len(tc.Function.Arguments)
}
}
}
return totalChars / 4 // Rough approximation
return estimateMessageTokens(messages)
}

View File

@@ -379,9 +379,58 @@ func LoadConfig() (*Config, error) {
resolveBuiltinAgentPromptIDs(cfg.PromptTemplates)
}
// Validate configuration values
if err := ValidateConfig(&cfg); err != nil {
return nil, err
}
return &cfg, nil
}
// ValidateConfig performs basic validation of the loaded configuration.
// It checks for obviously invalid or missing values that would cause runtime failures.
func ValidateConfig(cfg *Config) error {
var errs []string
if cfg.Conversation != nil {
if cfg.Conversation.EmbeddingTopK < 0 {
errs = append(errs, "conversation.embedding_top_k must be >= 0")
}
if cfg.Conversation.RerankTopK < 0 {
errs = append(errs, "conversation.rerank_top_k must be >= 0")
}
if cfg.Conversation.VectorThreshold < 0 || cfg.Conversation.VectorThreshold > 1 {
errs = append(errs, "conversation.vector_threshold must be between 0 and 1")
}
if cfg.Conversation.RerankThreshold < 0 || cfg.Conversation.RerankThreshold > 1 {
errs = append(errs, "conversation.rerank_threshold must be between 0 and 1")
}
}
if cfg.KnowledgeBase != nil {
if cfg.KnowledgeBase.ChunkSize <= 0 {
errs = append(errs, "knowledge_base.chunk_size must be > 0")
}
if cfg.KnowledgeBase.ChunkOverlap < 0 {
errs = append(errs, "knowledge_base.chunk_overlap must be >= 0")
}
if cfg.KnowledgeBase.ChunkOverlap >= cfg.KnowledgeBase.ChunkSize {
errs = append(errs, "knowledge_base.chunk_overlap must be less than chunk_size")
}
}
if cfg.Server != nil {
if cfg.Server.Port <= 0 || cfg.Server.Port > 65535 {
errs = append(errs, "server.port must be between 1 and 65535")
}
}
if len(errs) > 0 {
return fmt.Errorf("config validation errors: %s", strings.Join(errs, "; "))
}
return nil
}
// backfillConversationDefaults resolves prompt template ID references
// into actual prompt text content. Only xxx_id fields are used;
// no fallback to default templates.

View File

@@ -459,7 +459,7 @@ func (c *RemoteAPIChat) processStream(ctx context.Context, stream *openai.ChatCo
for {
response, err := stream.Recv()
if err != nil {
if err.Error() == "EOF" {
if err == io.EOF {
streamChan <- types.StreamResponse{
ResponseType: types.ResponseTypeAnswer,
Content: "",
@@ -493,7 +493,7 @@ func (c *RemoteAPIChat) processRawHTTPStream(ctx context.Context, resp *http.Res
for {
event, err := reader.ReadEvent()
if err != nil {
if err.Error() != "EOF" {
if err != io.EOF {
logger.Errorf(ctx, "Stream read error: %v", err)
streamChan <- types.StreamResponse{
ResponseType: types.ResponseTypeError,

View File

@@ -41,6 +41,8 @@ type AgentConfig struct {
// Runtime-only fields (not persisted)
VLMModelID string `json:"-"` // VLM model ID for tool result image analysis (set from CustomAgent config)
// LLM call timeout in seconds (default: 120). Controls the maximum time for a single LLM call.
LLMCallTimeout int `json:"llm_call_timeout,omitempty"`
}
// SessionAgentConfig represents session-level agent configuration
@@ -137,12 +139,19 @@ type Tool interface {
Execute(ctx context.Context, args json.RawMessage) (*ToolResult, error)
}
// Cleanable is an optional interface that tools can implement to release resources.
// Tools implementing this interface will have their Cleanup method called during
// registry cleanup (e.g., at the end of an agent session).
type Cleanable interface {
Cleanup(ctx context.Context)
}
// ToolResult represents the result of a tool execution
type ToolResult struct {
Success bool `json:"success"` // Whether the tool executed successfully
Output string `json:"output"` // Human-readable output
Data map[string]interface{} `json:"data,omitempty"` // Structured data for programmatic use
Error string `json:"error,omitempty"` // Error message if execution failed
Success bool `json:"success"` // Whether the tool executed successfully
Output string `json:"output"` // Human-readable output
Data map[string]interface{} `json:"data,omitempty"` // Structured data for programmatic use
Error string `json:"error,omitempty"` // Error message if execution failed
Images []string `json:"images,omitempty"` // Base64 data URIs from tool (e.g. MCP image content)
}