mirror of
https://github.com/Tencent/WeKnora.git
synced 2026-06-04 13:30:32 +08:00
feat(agent): implement LLM call timeout and transient error handling
- Introduced a configurable LLM call timeout with a default value, allowing better control over LLM call durations.
- Added logic to retry transient errors (e.g., timeouts, rate limits) up to a specified maximum number of retries, improving robustness in error handling.
- Implemented parameter casting for tool arguments to ensure correct types are used, addressing common LLM quirks.
- Enhanced tool execution error messages with guidance for retrying with a different approach.
- Added validation for configuration values to prevent runtime errors.
This commit is contained in:
@@ -21,11 +21,45 @@ import (
|
||||
)
|
||||
|
||||
const (
|
||||
// llmPerCallTimeout is the maximum time allowed for a single LLM call (stream initiation + full response).
|
||||
// defaultLLMCallTimeout is the default maximum time allowed for a single LLM call.
|
||||
// This prevents a single slow call from consuming the entire pipeline deadline.
|
||||
llmPerCallTimeout = 120 * time.Second
|
||||
// Can be overridden via AgentConfig.LLMCallTimeout.
|
||||
defaultLLMCallTimeout = 120 * time.Second
|
||||
|
||||
// maxLLMRetries is the maximum number of retries for transient LLM errors.
|
||||
maxLLMRetries = 2
|
||||
)
|
||||
|
||||
// transientErrorMarkers are lower-case substrings that indicate a transient
// (retryable) error when they appear anywhere in an error message.
//
// "deadline exceeded" is listed explicitly because context.DeadlineExceeded
// renders as "context deadline exceeded", which contains neither "timeout"
// nor "timed out"; without it a per-call context timeout is never retried.
var transientErrorMarkers = []string{
	"429", "rate limit",
	"500", "502", "503", "504",
	"overloaded", "timeout", "timed out", "deadline exceeded",
	"connection", "server error", "temporarily unavailable",
}

// isTransientError reports whether err is likely transient and worth retrying.
//
// Classification is a case-insensitive substring match of err.Error() against
// transientErrorMarkers. A nil error is never transient. An error caused by
// context cancellation is never transient either: the caller gave up on the
// request, so retrying cannot succeed, even if the message happens to contain
// a marker (for example a URL with port 5000 matching "500").
func isTransientError(err error) bool {
	if err == nil {
		return false
	}

	msg := strings.ToLower(err.Error())

	// Cancellation wins over every marker; see the doc comment above.
	if strings.Contains(msg, "context canceled") {
		return false
	}

	for _, marker := range transientErrorMarkers {
		if strings.Contains(msg, marker) {
			return true
		}
	}
	return false
}
|
||||
|
||||
// getLLMCallTimeout returns the configured LLM call timeout, falling back to default.
|
||||
func (e *AgentEngine) getLLMCallTimeout() time.Duration {
|
||||
if e.config.LLMCallTimeout > 0 {
|
||||
return time.Duration(e.config.LLMCallTimeout) * time.Second
|
||||
}
|
||||
return defaultLLMCallTimeout
|
||||
}
|
||||
|
||||
// generateEventID generates a unique event ID with type suffix for better traceability
|
||||
func generateEventID(suffix string) string {
|
||||
return fmt.Sprintf("%s-%s", uuid.New().String()[:8], suffix)
|
||||
@@ -305,6 +339,20 @@ func (e *AgentEngine) executeLoop(
|
||||
"tool_cnt": len(tools),
|
||||
})
|
||||
response, err := e.streamThinkingToEventBus(ctx, messages, tools, state.CurrentRound, sessionID)
|
||||
if err != nil && isTransientError(err) {
|
||||
// Retry transient errors (timeout, rate limit, server errors) up to maxLLMRetries times
|
||||
for retry := 1; retry <= maxLLMRetries; retry++ {
|
||||
retryDelay := time.Duration(retry) * time.Second
|
||||
logger.Warnf(ctx, "[Agent][Round-%d] LLM transient error (attempt %d/%d), retrying in %v: %v",
|
||||
state.CurrentRound+1, retry, maxLLMRetries, retryDelay, err)
|
||||
time.Sleep(retryDelay)
|
||||
|
||||
response, err = e.streamThinkingToEventBus(ctx, messages, tools, state.CurrentRound, sessionID)
|
||||
if err == nil || !isTransientError(err) {
|
||||
break
|
||||
}
|
||||
}
|
||||
}
|
||||
if err != nil {
|
||||
logger.Errorf(ctx, "[Agent][Round-%d] LLM call failed: %v", state.CurrentRound+1, err)
|
||||
common.PipelineError(ctx, "Agent", "think_failed", map[string]interface{}{
|
||||
@@ -450,6 +498,18 @@ func (e *AgentEngine) executeLoop(
|
||||
if err := json.Unmarshal([]byte(tc.Function.Arguments), &args); err != nil {
|
||||
logger.Errorf(ctx, "[Agent][Round-%d][Tool-%d/%d] Failed to parse tool arguments: %v",
|
||||
state.CurrentRound+1, i+1, len(response.ToolCalls), err)
|
||||
// Record parse failure as a tool result so the LLM can see the error and adapt
|
||||
parseFailCall := types.ToolCall{
|
||||
ID: tc.ID,
|
||||
Name: tc.Function.Name,
|
||||
Args: map[string]any{"_raw": tc.Function.Arguments},
|
||||
Result: &types.ToolResult{
|
||||
Success: false,
|
||||
Error: fmt.Sprintf("Failed to parse tool arguments: %v", err) +
|
||||
"\n\n[Analyze the error above and try a different approach.]",
|
||||
},
|
||||
}
|
||||
step.ToolCalls = append(step.ToolCalls, parseFailCall)
|
||||
continue
|
||||
}
|
||||
|
||||
@@ -790,7 +850,7 @@ func (e *AgentEngine) streamLLMToEventBus(
|
||||
// guaranteed time window even when the parent context's deadline is almost
|
||||
// exhausted after previous iterations. The shorter of the two deadlines wins,
|
||||
// so the parent pipeline timeout is still respected.
|
||||
llmCtx, llmCancel := context.WithTimeout(ctx, llmPerCallTimeout)
|
||||
llmCtx, llmCancel := context.WithTimeout(ctx, e.getLLMCallTimeout())
|
||||
defer llmCancel()
|
||||
|
||||
stream, err := e.chatModel.ChatStream(llmCtx, messages, opts)
|
||||
|
||||
127
internal/agent/tools/param_cast.go
Normal file
127
internal/agent/tools/param_cast.go
Normal file
@@ -0,0 +1,127 @@
|
||||
package tools
|
||||
|
||||
import (
	"bytes"
	"encoding/json"
	"math"
	"strconv"
	"strings"
)
|
||||
|
||||
// CastParams performs schema-driven type casting on tool arguments.
// LLMs sometimes return incorrect types (e.g., "true" instead of true, "123" instead of 123).
// This function attempts safe conversions based on the JSON Schema definition of the tool's parameters.
//
// Only top-level properties that declare a "type" in the schema are considered.
// The original args are returned unchanged (the very same bytes) when:
//   - schema or args is empty,
//   - schema or args is not valid JSON, or args is not a JSON object,
//   - the schema has no "properties",
//   - no value needed casting,
//   - the cast result cannot be re-encoded.
//
// Numbers are decoded as json.Number rather than float64, so values that are
// not cast (including large integer IDs) are re-encoded exactly as written
// instead of being rounded to the nearest float64.
func CastParams(args json.RawMessage, schema json.RawMessage) json.RawMessage {
	if len(schema) == 0 || len(args) == 0 {
		return args
	}

	var schemaDef map[string]interface{}
	if err := json.Unmarshal(schema, &schemaDef); err != nil {
		return args
	}

	properties, ok := schemaDef["properties"].(map[string]interface{})
	if !ok || len(properties) == 0 {
		return args
	}

	// A streaming decoder stops after the first value and would accept trailing
	// garbage; validate the whole payload first to keep json.Unmarshal strictness.
	if !json.Valid(args) {
		return args
	}
	dec := json.NewDecoder(bytes.NewReader(args))
	dec.UseNumber()
	var argsMap map[string]interface{}
	if err := dec.Decode(&argsMap); err != nil {
		return args
	}

	changed := false
	for key, val := range argsMap {
		prop, ok := properties[key].(map[string]interface{})
		if !ok {
			continue
		}
		targetType, _ := prop["type"].(string)
		if targetType == "" {
			continue
		}

		if newVal, didCast := castValue(val, targetType); didCast {
			argsMap[key] = newVal
			changed = true
		}
	}

	if !changed {
		return args
	}

	result, err := json.Marshal(argsMap)
	if err != nil {
		return args
	}
	return result
}

// castValue attempts to convert val to the expected targetType.
// Returns (newValue, true) if a conversion was made, (val, false) otherwise.
//
// Numeric inputs may arrive either as json.Number (from CastParams) or as
// float64 (from callers that decoded with the default settings); both are
// handled. Conversions performed per targetType:
//   - "boolean": "true"/"1"/"yes" and "false"/"0"/"no" (case-insensitive), numeric 0 and 1.
//   - "integer": base-10 integer strings; whole-valued numbers written in a
//     non-integer form (42.0, 4.2e1) that fit in int64. A number that is
//     already a plain integer literal is left untouched.
//   - "number": strings holding a finite decimal or hexadecimal float.
//     "NaN" and "Inf" are rejected because they cannot be encoded as JSON.
//   - "string": booleans and numbers; json.Number keeps its original spelling.
func castValue(val interface{}, targetType string) (interface{}, bool) {
	switch targetType {
	case "boolean":
		switch v := val.(type) {
		case string:
			switch strings.ToLower(v) {
			case "true", "1", "yes":
				return true, true
			case "false", "0", "no":
				return false, true
			}
		case float64:
			switch v {
			case 0:
				return false, true
			case 1:
				return true, true
			}
		case json.Number:
			if f, err := v.Float64(); err == nil {
				switch f {
				case 0:
					return false, true
				case 1:
					return true, true
				}
			}
		}

	case "integer":
		switch v := val.(type) {
		case string:
			if i, err := strconv.ParseInt(v, 10, 64); err == nil {
				return i, true
			}
		case float64:
			if i, ok := floatToWholeInt64(v); ok {
				return i, true
			}
		case json.Number:
			// Already a valid integer literal: nothing to fix, and leaving it
			// alone avoids a pointless re-encode of the whole argument object.
			if _, err := v.Int64(); err == nil {
				return val, false
			}
			// Forms such as 42.0 or 4.2e1 are rejected by integer decoders;
			// normalise them when they denote a whole number.
			if f, err := v.Float64(); err == nil {
				if i, ok := floatToWholeInt64(f); ok {
					return i, true
				}
			}
		}

	case "number":
		if s, ok := val.(string); ok {
			f, err := strconv.ParseFloat(s, 64)
			// ParseFloat accepts "NaN"/"Inf"; json.Marshal does not, and one
			// such value would make CastParams drop every other cast.
			if err == nil && !math.IsNaN(f) && !math.IsInf(f, 0) {
				return f, true
			}
		}

	case "string":
		// Non-string values -> string (e.g., number or bool passed as non-string)
		switch v := val.(type) {
		case bool:
			return strconv.FormatBool(v), true
		case float64:
			return strconv.FormatFloat(v, 'f', -1, 64), true
		case json.Number:
			return v.String(), true
		case int64:
			return strconv.FormatInt(v, 10), true
		}
	}

	return val, false
}

// floatToWholeInt64 converts f to int64 when f is a whole number inside the
// int64 range. It reports false for fractional values, NaN, infinities and
// out-of-range magnitudes, for which a float-to-int conversion is not
// well defined.
func floatToWholeInt64(f float64) (int64, bool) {
	const (
		lowerBound = -1 << 63 // -2^63, inclusive
		upperBound = 1 << 63  // 2^63, exclusive
	)
	if f != math.Trunc(f) || f < lowerBound || f >= upperBound {
		return 0, false
	}
	return int64(f), true
}
|
||||
81
internal/agent/tools/param_cast_test.go
Normal file
81
internal/agent/tools/param_cast_test.go
Normal file
@@ -0,0 +1,81 @@
|
||||
package tools
|
||||
|
||||
import (
|
||||
"encoding/json"
|
||||
"testing"
|
||||
)
|
||||
|
||||
func TestCastParams_StringToBool(t *testing.T) {
|
||||
schema := json.RawMessage(`{"type":"object","properties":{"enabled":{"type":"boolean"}}}`)
|
||||
args := json.RawMessage(`{"enabled":"true"}`)
|
||||
result := CastParams(args, schema)
|
||||
|
||||
var parsed map[string]interface{}
|
||||
if err := json.Unmarshal(result, &parsed); err != nil {
|
||||
t.Fatal(err)
|
||||
}
|
||||
if parsed["enabled"] != true {
|
||||
t.Errorf("expected true, got %v (%T)", parsed["enabled"], parsed["enabled"])
|
||||
}
|
||||
}
|
||||
|
||||
func TestCastParams_StringToInt(t *testing.T) {
|
||||
schema := json.RawMessage(`{"type":"object","properties":{"count":{"type":"integer"}}}`)
|
||||
args := json.RawMessage(`{"count":"42"}`)
|
||||
result := CastParams(args, schema)
|
||||
|
||||
var parsed map[string]interface{}
|
||||
if err := json.Unmarshal(result, &parsed); err != nil {
|
||||
t.Fatal(err)
|
||||
}
|
||||
// JSON numbers are float64 in Go
|
||||
if parsed["count"] != float64(42) {
|
||||
t.Errorf("expected 42, got %v (%T)", parsed["count"], parsed["count"])
|
||||
}
|
||||
}
|
||||
|
||||
func TestCastParams_StringToFloat(t *testing.T) {
|
||||
schema := json.RawMessage(`{"type":"object","properties":{"score":{"type":"number"}}}`)
|
||||
args := json.RawMessage(`{"score":"3.14"}`)
|
||||
result := CastParams(args, schema)
|
||||
|
||||
var parsed map[string]interface{}
|
||||
if err := json.Unmarshal(result, &parsed); err != nil {
|
||||
t.Fatal(err)
|
||||
}
|
||||
if parsed["score"] != 3.14 {
|
||||
t.Errorf("expected 3.14, got %v", parsed["score"])
|
||||
}
|
||||
}
|
||||
|
||||
func TestCastParams_NoChangeNeeded(t *testing.T) {
|
||||
schema := json.RawMessage(`{"type":"object","properties":{"name":{"type":"string"}}}`)
|
||||
args := json.RawMessage(`{"name":"hello"}`)
|
||||
result := CastParams(args, schema)
|
||||
|
||||
if string(result) != string(args) {
|
||||
t.Errorf("expected no change, got %s", result)
|
||||
}
|
||||
}
|
||||
|
||||
func TestCastParams_NilSchema(t *testing.T) {
|
||||
args := json.RawMessage(`{"foo":"bar"}`)
|
||||
result := CastParams(args, nil)
|
||||
if string(result) != string(args) {
|
||||
t.Errorf("expected no change with nil schema")
|
||||
}
|
||||
}
|
||||
|
||||
func TestCastParams_BoolFalseString(t *testing.T) {
|
||||
schema := json.RawMessage(`{"type":"object","properties":{"flag":{"type":"boolean"}}}`)
|
||||
args := json.RawMessage(`{"flag":"false"}`)
|
||||
result := CastParams(args, schema)
|
||||
|
||||
var parsed map[string]interface{}
|
||||
if err := json.Unmarshal(result, &parsed); err != nil {
|
||||
t.Fatal(err)
|
||||
}
|
||||
if parsed["flag"] != false {
|
||||
t.Errorf("expected false, got %v (%T)", parsed["flag"], parsed["flag"])
|
||||
}
|
||||
}
|
||||
@@ -6,9 +6,13 @@ import (
|
||||
"fmt"
|
||||
|
||||
"github.com/Tencent/WeKnora/internal/common"
|
||||
"github.com/Tencent/WeKnora/internal/logger"
|
||||
"github.com/Tencent/WeKnora/internal/types"
|
||||
)
|
||||
|
||||
// toolErrorHint is appended to tool error messages to guide the LLM to retry with a different approach.
|
||||
const toolErrorHint = "\n\n[Analyze the error above and try a different approach.]"
|
||||
|
||||
// ToolRegistry manages the registration and retrieval of tools
|
||||
type ToolRegistry struct {
|
||||
tools map[string]types.Tool
|
||||
@@ -27,6 +31,8 @@ func NewToolRegistry() *ToolRegistry {
|
||||
func (r *ToolRegistry) RegisterTool(tool types.Tool) {
|
||||
name := tool.Name()
|
||||
if _, exists := r.tools[name]; exists {
|
||||
logger.Warnf(context.Background(),
|
||||
"[ToolRegistry] Duplicate tool registration rejected: %s (first-wins policy)", name)
|
||||
return
|
||||
}
|
||||
r.tools[name] = tool
|
||||
@@ -81,10 +87,14 @@ func (r *ToolRegistry) ExecuteTool(
|
||||
})
|
||||
return &types.ToolResult{
|
||||
Success: false,
|
||||
Error: err.Error(),
|
||||
Error: err.Error() + toolErrorHint,
|
||||
}, err
|
||||
}
|
||||
|
||||
// Cast parameters to match expected schema types before execution.
|
||||
// This handles common LLM quirks like returning "true" instead of true.
|
||||
args = CastParams(args, tool.Parameters())
|
||||
|
||||
result, execErr := tool.Execute(ctx, args)
|
||||
fields := map[string]interface{}{
|
||||
"tool": name,
|
||||
@@ -100,6 +110,10 @@ func (r *ToolRegistry) ExecuteTool(
|
||||
fields["error"] = execErr.Error()
|
||||
common.PipelineError(ctx, "AgentTool", "execute_done", fields)
|
||||
} else if result != nil && !result.Success {
|
||||
// Append error hint to guide LLM to retry with a different approach
|
||||
if result.Error != "" {
|
||||
result.Error = result.Error + toolErrorHint
|
||||
}
|
||||
common.PipelineWarn(ctx, "AgentTool", "execute_done", fields)
|
||||
} else {
|
||||
common.PipelineInfo(ctx, "AgentTool", "execute_done", fields)
|
||||
@@ -108,12 +122,13 @@ func (r *ToolRegistry) ExecuteTool(
|
||||
return result, execErr
|
||||
}
|
||||
|
||||
// Cleanup cleans up all registered tools that implement the Cleanup method
|
||||
// Cleanup cleans up all registered tools that implement the types.Cleanable interface.
|
||||
// This is called at the end of agent sessions to release tool-specific resources.
|
||||
func (r *ToolRegistry) Cleanup(ctx context.Context) {
|
||||
// Check specifically for DataAnalysisTool
|
||||
if tool, exists := r.tools[ToolDataAnalysis]; exists {
|
||||
if dataAnalysisTool, ok := tool.(*DataAnalysisTool); ok {
|
||||
dataAnalysisTool.Cleanup(ctx)
|
||||
for name, tool := range r.tools {
|
||||
if cleanable, ok := tool.(types.Cleanable); ok {
|
||||
logger.Infof(ctx, "[ToolRegistry] Cleaning up tool: %s", name)
|
||||
cleanable.Cleanup(ctx)
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
@@ -113,9 +113,11 @@ func (p *PluginRerank) OnEvent(ctx context.Context,
|
||||
if degradedThreshold < 0.3 {
|
||||
degradedThreshold = 0.3
|
||||
}
|
||||
pipelineInfo(ctx, "Rerank", "threshold_degrade", map[string]interface{}{
|
||||
"original": originalThreshold,
|
||||
"degraded": degradedThreshold,
|
||||
pipelineWarn(ctx, "Rerank", "threshold_degrade", map[string]interface{}{
|
||||
"original": originalThreshold,
|
||||
"degraded": degradedThreshold,
|
||||
"candidate_cnt": len(candidatesToRerank),
|
||||
"reason": "no results above original threshold, retrying with lower threshold",
|
||||
})
|
||||
chatManage.RerankThreshold = degradedThreshold
|
||||
rerankResp = p.rerank(ctx, chatManage, rerankModel, chatManage.RewriteQuery, passages, candidatesToRerank)
|
||||
|
||||
@@ -4,12 +4,50 @@ import (
|
||||
"context"
|
||||
"fmt"
|
||||
"strings"
|
||||
"unicode"
|
||||
|
||||
"github.com/Tencent/WeKnora/internal/logger"
|
||||
"github.com/Tencent/WeKnora/internal/models/chat"
|
||||
"github.com/Tencent/WeKnora/internal/types/interfaces"
|
||||
)
|
||||
|
||||
// estimateStringTokens estimates how many tokens s occupies, weighting CJK
// runes (Han, Hangul, Katakana, Hiragana) at ≈1.5 tokens each and every other
// rune at ≈0.25 tokens (about 4 characters per token).
// This is more accurate than the naive totalChars/4 approach for the mixed
// Chinese-English content common in WeKnora.
//
// The result is truncated toward zero, so very short non-CJK strings
// (fewer than 4 runes) estimate to 0.
func estimateStringTokens(s string) int {
	cjkChars, otherChars := 0, 0
	for _, r := range s {
		if unicode.In(r, unicode.Han, unicode.Hangul, unicode.Katakana, unicode.Hiragana) {
			cjkChars++
		} else {
			otherChars++
		}
	}
	// 1.5*cjk + 0.25*other, kept in integer arithmetic over a common
	// denominator of 4. The previous (cjk*3+other)/2 form weighted non-CJK
	// runes at 0.5 tokens each, twice the documented 0.25.
	return (cjkChars*6 + otherChars) / 4
}
|
||||
|
||||
// estimateMessageTokens estimates the token count for a list of messages.
|
||||
// Accounts for per-message overhead (role markers, special tokens) and tool call metadata.
|
||||
func estimateMessageTokens(messages []chat.Message) int {
|
||||
totalTokens := 0
|
||||
for _, msg := range messages {
|
||||
totalTokens += estimateStringTokens(msg.Role) + estimateStringTokens(msg.Content)
|
||||
// Per-message overhead: role markers, delimiters, special tokens
|
||||
totalTokens += 4
|
||||
if len(msg.ToolCalls) > 0 {
|
||||
for _, tc := range msg.ToolCalls {
|
||||
totalTokens += estimateStringTokens(tc.Function.Name) + estimateStringTokens(tc.Function.Arguments)
|
||||
// Tool call overhead: function call structure tokens
|
||||
totalTokens += 8
|
||||
}
|
||||
}
|
||||
}
|
||||
return totalTokens
|
||||
}
|
||||
|
||||
// slidingWindowStrategy implements CompressionStrategy using sliding window
|
||||
type slidingWindowStrategy struct {
|
||||
recentMessageCount int
|
||||
@@ -64,19 +102,9 @@ func (s *slidingWindowStrategy) Compress(
|
||||
return result, nil
|
||||
}
|
||||
|
||||
// EstimateTokens estimates token count (rough approximation: 4 characters ≈ 1 token)
|
||||
// EstimateTokens estimates token count using CJK-aware heuristics.
|
||||
func (s *slidingWindowStrategy) EstimateTokens(messages []chat.Message) int {
|
||||
totalChars := 0
|
||||
for _, msg := range messages {
|
||||
totalChars += len(msg.Role) + len(msg.Content)
|
||||
// Account for tool calls if present
|
||||
if len(msg.ToolCalls) > 0 {
|
||||
for _, tc := range msg.ToolCalls {
|
||||
totalChars += len(tc.Function.Name) + len(tc.Function.Arguments)
|
||||
}
|
||||
}
|
||||
}
|
||||
return totalChars / 4 // Rough approximation
|
||||
return estimateMessageTokens(messages)
|
||||
}
|
||||
|
||||
// smartCompressionStrategy implements CompressionStrategy using LLM summarization
|
||||
@@ -151,12 +179,10 @@ func (s *smartCompressionStrategy) Compress(
|
||||
// Summarize old messages using LLM
|
||||
summary, err := s.summarizeMessages(ctx, oldMessages)
|
||||
if err != nil {
|
||||
logger.Warnf(ctx, "[SmartCompression] Failed to summarize messages: %v, falling back to old messages", err)
|
||||
// Fallback: return all messages if summarization fails
|
||||
result := make([]chat.Message, 0, len(systemMessages)+len(regularMessages))
|
||||
result = append(result, systemMessages...)
|
||||
result = append(result, regularMessages...)
|
||||
return result, nil
|
||||
logger.Warnf(ctx, "[SmartCompression] Failed to summarize messages: %v, falling back to sliding window", err)
|
||||
// Fallback: use sliding window strategy to at least reduce message count
|
||||
fallback := &slidingWindowStrategy{recentMessageCount: s.recentMessageCount}
|
||||
return fallback.Compress(ctx, messages, maxTokens)
|
||||
}
|
||||
|
||||
// Construct final message list: system + summary + recent
|
||||
@@ -186,7 +212,7 @@ func (s *smartCompressionStrategy) summarizeMessages(ctx context.Context, messag
|
||||
// Build conversation text
|
||||
var sb strings.Builder
|
||||
for i, msg := range messages {
|
||||
sb.WriteString(fmt.Sprintf("[%s] %s\n", msg.Role, msg.Content))
|
||||
fmt.Fprintf(&sb, "[%s] %s\n", msg.Role, msg.Content)
|
||||
if i < len(messages)-1 {
|
||||
sb.WriteString("\n")
|
||||
}
|
||||
@@ -226,17 +252,7 @@ func (s *smartCompressionStrategy) summarizeMessages(ctx context.Context, messag
|
||||
return summary, nil
|
||||
}
|
||||
|
||||
// EstimateTokens estimates token count (rough approximation: 4 characters ≈ 1 token)
|
||||
// EstimateTokens estimates token count using CJK-aware heuristics.
|
||||
func (s *smartCompressionStrategy) EstimateTokens(messages []chat.Message) int {
|
||||
totalChars := 0
|
||||
for _, msg := range messages {
|
||||
totalChars += len(msg.Role) + len(msg.Content)
|
||||
// Account for tool calls if present
|
||||
if len(msg.ToolCalls) > 0 {
|
||||
for _, tc := range msg.ToolCalls {
|
||||
totalChars += len(tc.Function.Name) + len(tc.Function.Arguments)
|
||||
}
|
||||
}
|
||||
}
|
||||
return totalChars / 4 // Rough approximation
|
||||
return estimateMessageTokens(messages)
|
||||
}
|
||||
|
||||
@@ -379,9 +379,58 @@ func LoadConfig() (*Config, error) {
|
||||
resolveBuiltinAgentPromptIDs(cfg.PromptTemplates)
|
||||
}
|
||||
|
||||
// Validate configuration values
|
||||
if err := ValidateConfig(&cfg); err != nil {
|
||||
return nil, err
|
||||
}
|
||||
|
||||
return &cfg, nil
|
||||
}
|
||||
|
||||
// ValidateConfig performs basic validation of the loaded configuration.
|
||||
// It checks for obviously invalid or missing values that would cause runtime failures.
|
||||
func ValidateConfig(cfg *Config) error {
|
||||
var errs []string
|
||||
|
||||
if cfg.Conversation != nil {
|
||||
if cfg.Conversation.EmbeddingTopK < 0 {
|
||||
errs = append(errs, "conversation.embedding_top_k must be >= 0")
|
||||
}
|
||||
if cfg.Conversation.RerankTopK < 0 {
|
||||
errs = append(errs, "conversation.rerank_top_k must be >= 0")
|
||||
}
|
||||
if cfg.Conversation.VectorThreshold < 0 || cfg.Conversation.VectorThreshold > 1 {
|
||||
errs = append(errs, "conversation.vector_threshold must be between 0 and 1")
|
||||
}
|
||||
if cfg.Conversation.RerankThreshold < 0 || cfg.Conversation.RerankThreshold > 1 {
|
||||
errs = append(errs, "conversation.rerank_threshold must be between 0 and 1")
|
||||
}
|
||||
}
|
||||
|
||||
if cfg.KnowledgeBase != nil {
|
||||
if cfg.KnowledgeBase.ChunkSize <= 0 {
|
||||
errs = append(errs, "knowledge_base.chunk_size must be > 0")
|
||||
}
|
||||
if cfg.KnowledgeBase.ChunkOverlap < 0 {
|
||||
errs = append(errs, "knowledge_base.chunk_overlap must be >= 0")
|
||||
}
|
||||
if cfg.KnowledgeBase.ChunkOverlap >= cfg.KnowledgeBase.ChunkSize {
|
||||
errs = append(errs, "knowledge_base.chunk_overlap must be less than chunk_size")
|
||||
}
|
||||
}
|
||||
|
||||
if cfg.Server != nil {
|
||||
if cfg.Server.Port <= 0 || cfg.Server.Port > 65535 {
|
||||
errs = append(errs, "server.port must be between 1 and 65535")
|
||||
}
|
||||
}
|
||||
|
||||
if len(errs) > 0 {
|
||||
return fmt.Errorf("config validation errors: %s", strings.Join(errs, "; "))
|
||||
}
|
||||
return nil
|
||||
}
|
||||
|
||||
// backfillConversationDefaults resolves prompt template ID references
|
||||
// into actual prompt text content. Only xxx_id fields are used;
|
||||
// no fallback to default templates.
|
||||
|
||||
@@ -459,7 +459,7 @@ func (c *RemoteAPIChat) processStream(ctx context.Context, stream *openai.ChatCo
|
||||
for {
|
||||
response, err := stream.Recv()
|
||||
if err != nil {
|
||||
if err.Error() == "EOF" {
|
||||
if err == io.EOF {
|
||||
streamChan <- types.StreamResponse{
|
||||
ResponseType: types.ResponseTypeAnswer,
|
||||
Content: "",
|
||||
@@ -493,7 +493,7 @@ func (c *RemoteAPIChat) processRawHTTPStream(ctx context.Context, resp *http.Res
|
||||
for {
|
||||
event, err := reader.ReadEvent()
|
||||
if err != nil {
|
||||
if err.Error() != "EOF" {
|
||||
if err != io.EOF {
|
||||
logger.Errorf(ctx, "Stream read error: %v", err)
|
||||
streamChan <- types.StreamResponse{
|
||||
ResponseType: types.ResponseTypeError,
|
||||
|
||||
@@ -41,6 +41,8 @@ type AgentConfig struct {
|
||||
|
||||
// Runtime-only fields (not persisted)
|
||||
VLMModelID string `json:"-"` // VLM model ID for tool result image analysis (set from CustomAgent config)
|
||||
// LLM call timeout in seconds (default: 120). Controls the maximum time for a single LLM call.
|
||||
LLMCallTimeout int `json:"llm_call_timeout,omitempty"`
|
||||
}
|
||||
|
||||
// SessionAgentConfig represents session-level agent configuration
|
||||
@@ -137,12 +139,19 @@ type Tool interface {
|
||||
Execute(ctx context.Context, args json.RawMessage) (*ToolResult, error)
|
||||
}
|
||||
|
||||
// Cleanable is an optional capability for tools that hold resources which must
// be released explicitly. A tool opts in simply by providing a Cleanup method;
// registry cleanup (e.g., at the end of an agent session) invokes it on every
// tool that implements this interface.
type Cleanable interface {
	// Cleanup releases the resources held by the tool.
	Cleanup(ctx context.Context)
}
|
||||
|
||||
// ToolResult represents the result of a tool execution
|
||||
type ToolResult struct {
|
||||
Success bool `json:"success"` // Whether the tool executed successfully
|
||||
Output string `json:"output"` // Human-readable output
|
||||
Data map[string]interface{} `json:"data,omitempty"` // Structured data for programmatic use
|
||||
Error string `json:"error,omitempty"` // Error message if execution failed
|
||||
Success bool `json:"success"` // Whether the tool executed successfully
|
||||
Output string `json:"output"` // Human-readable output
|
||||
Data map[string]interface{} `json:"data,omitempty"` // Structured data for programmatic use
|
||||
Error string `json:"error,omitempty"` // Error message if execution failed
|
||||
Images []string `json:"images,omitempty"` // Base64 data URIs from tool (e.g. MCP image content)
|
||||
}
|
||||
|
||||
|
||||
Reference in New Issue
Block a user