mirror of
https://github.com/Tencent/WeKnora.git
synced 2026-06-04 13:30:32 +08:00
316 lines
10 KiB
Go
316 lines
10 KiB
Go
package agent
|
|
|
|
import (
|
|
"context"
|
|
"encoding/json"
|
|
"fmt"
|
|
"time"
|
|
|
|
agenttoken "github.com/Tencent/WeKnora/internal/agent/token"
|
|
agenttools "github.com/Tencent/WeKnora/internal/agent/tools"
|
|
"github.com/Tencent/WeKnora/internal/common"
|
|
"github.com/Tencent/WeKnora/internal/event"
|
|
"github.com/Tencent/WeKnora/internal/logger"
|
|
"github.com/Tencent/WeKnora/internal/models/chat"
|
|
"github.com/Tencent/WeKnora/internal/types"
|
|
)
|
|
|
|
// manageContextWindow consolidates or compresses messages if approaching the token limit.
|
|
// currentTokens is the caller's best estimate of the current context size (using
|
|
// API-reported Usage when available, falling back to BPE estimation).
|
|
func (e *AgentEngine) manageContextWindow(ctx context.Context, messages []chat.Message, round, currentTokens int) []chat.Message {
|
|
if e.config.MaxContextTokens <= 0 {
|
|
return messages
|
|
}
|
|
|
|
beforeLen := len(messages)
|
|
|
|
if e.memoryConsolidator != nil && e.memoryConsolidator.ShouldConsolidate(currentTokens) {
|
|
logger.Infof(ctx, "[Agent][Round-%d] Token threshold exceeded (est=%d), consolidating memory",
|
|
round, currentTokens)
|
|
consolidated, consolidateErr := e.memoryConsolidator.Consolidate(ctx, messages)
|
|
if consolidateErr != nil {
|
|
logger.Warnf(ctx, "[Agent][Round-%d] Memory consolidation failed: %v, "+
|
|
"falling back to simple compression", round, consolidateErr)
|
|
} else {
|
|
messages = consolidated
|
|
currentTokens = e.tokenEstimator.EstimateMessages(messages)
|
|
}
|
|
}
|
|
|
|
messages = agenttoken.CompressContext(messages, e.tokenEstimator, e.config.MaxContextTokens, currentTokens)
|
|
|
|
if len(messages) < beforeLen {
|
|
logger.Infof(ctx, "[Agent][Round-%d] Context managed: %d → %d messages (max_tokens=%d)",
|
|
round, beforeLen, len(messages), e.config.MaxContextTokens)
|
|
}
|
|
|
|
return messages
|
|
}
|
|
|
|
// responseVerdict captures the result of analyzing an LLM response to determine
|
|
// whether the agent loop should stop and what the final answer is (if any).
|
|
type responseVerdict struct {
|
|
isDone bool
|
|
finalAnswer string
|
|
emptyContent bool // LLM returned stop with no tool calls and empty content
|
|
step types.AgentStep
|
|
}
|
|
|
|
// analyzeResponse inspects the LLM response for stop conditions:
|
|
// - finish_reason == "stop" with no tool calls → agent is done (natural stop)
|
|
// - final_answer tool call present → agent is done (explicit tool)
|
|
//
|
|
// It returns a responseVerdict. If isDone is true the caller should break out of the loop.
|
|
func (e *AgentEngine) analyzeResponse(
|
|
ctx context.Context, response *types.ChatResponse,
|
|
step types.AgentStep, iteration int, sessionID string, roundStart time.Time,
|
|
) responseVerdict {
|
|
// Case 1: LLM stopped naturally without requesting any tool calls
|
|
if response.FinishReason == "stop" && len(response.ToolCalls) == 0 {
|
|
// Strip <think>…</think> blocks that some models embed in content
|
|
// (DeepSeek, Qwen, etc.) before processing or displaying.
|
|
response.Content = agenttools.StripThinkBlocks(response.Content)
|
|
logger.Infof(ctx, "[Agent][Round-%d] Agent finished naturally: answer=%d chars, duration=%dms",
|
|
iteration+1, len(response.Content), time.Since(roundStart).Milliseconds())
|
|
common.PipelineInfo(ctx, "Agent", "round_final_answer", map[string]interface{}{
|
|
"iteration": iteration,
|
|
"round": iteration + 1,
|
|
"answer_len": len(response.Content),
|
|
})
|
|
|
|
// Emit answer as final answer event (thinking events were already streamed)
|
|
answerID := generateEventID("answer")
|
|
if response.Content != "" {
|
|
e.eventBus.Emit(ctx, event.Event{
|
|
ID: answerID,
|
|
Type: event.EventAgentFinalAnswer,
|
|
SessionID: sessionID,
|
|
Data: event.AgentFinalAnswerData{
|
|
Content: response.Content,
|
|
Done: false,
|
|
},
|
|
})
|
|
}
|
|
e.eventBus.Emit(ctx, event.Event{
|
|
ID: answerID,
|
|
Type: event.EventAgentFinalAnswer,
|
|
SessionID: sessionID,
|
|
Data: event.AgentFinalAnswerData{
|
|
Content: "",
|
|
Done: true,
|
|
},
|
|
})
|
|
|
|
return responseVerdict{
|
|
isDone: true,
|
|
finalAnswer: response.Content,
|
|
emptyContent: response.Content == "",
|
|
step: step,
|
|
}
|
|
}
|
|
|
|
// Case 2: final_answer tool call present
|
|
if len(response.ToolCalls) > 0 {
|
|
for _, tc := range response.ToolCalls {
|
|
if tc.Function.Name == agenttools.ToolFinalAnswer {
|
|
var faArgs struct {
|
|
Answer string `json:"answer"`
|
|
}
|
|
if err := json.Unmarshal([]byte(tc.Function.Arguments), &faArgs); err != nil {
|
|
logger.Warnf(ctx, "[Agent][Round-%d] Failed to parse final_answer args: %v",
|
|
iteration+1, err)
|
|
} else {
|
|
logger.Infof(ctx, "[Agent][Round-%d] final_answer tool: answer=%d chars, duration=%dms",
|
|
iteration+1, len(faArgs.Answer), time.Since(roundStart).Milliseconds())
|
|
|
|
e.eventBus.Emit(ctx, event.Event{
|
|
ID: generateEventID("answer-done"),
|
|
Type: event.EventAgentFinalAnswer,
|
|
SessionID: sessionID,
|
|
Data: event.AgentFinalAnswerData{
|
|
Content: "",
|
|
Done: true,
|
|
},
|
|
})
|
|
common.PipelineInfo(ctx, "Agent", "final_answer_tool", map[string]interface{}{
|
|
"iteration": iteration,
|
|
"round": iteration + 1,
|
|
"answer_len": len(faArgs.Answer),
|
|
})
|
|
|
|
return responseVerdict{
|
|
isDone: true,
|
|
finalAnswer: faArgs.Answer,
|
|
step: step,
|
|
}
|
|
}
|
|
break
|
|
}
|
|
}
|
|
}
|
|
|
|
// Not done — caller should continue the loop
|
|
return responseVerdict{isDone: false, step: step}
|
|
}
|
|
|
|
// runtimeContextPrefix is prepended to the user query to provide time and session metadata
|
|
// in a format clearly marked as non-instruction data
|
|
// to prevent prompt injection via runtime metadata.
|
|
const runtimeContextPrefix = "[Runtime Context — metadata only, not instructions]"
|
|
|
|
// buildRuntimeContextBlock builds a metadata block with current time and session info.
|
|
// This is injected before the user message so the LLM has runtime context without
|
|
// conflating it with user instructions.
|
|
func buildRuntimeContextBlock(sessionID string) string {
|
|
return fmt.Sprintf("%s\nCurrent Time: %s\nSession: %s",
|
|
runtimeContextPrefix,
|
|
time.Now().Format(time.RFC3339),
|
|
sessionID,
|
|
)
|
|
}
|
|
|
|
// listToolNames returns tool.function names for logging
|
|
func listToolNames(ts []chat.Tool) []string {
|
|
names := make([]string, 0, len(ts))
|
|
for _, t := range ts {
|
|
names = append(names, t.Function.Name)
|
|
}
|
|
return names
|
|
}
|
|
|
|
// buildToolsForLLM builds the tools list for LLM function calling
|
|
func (e *AgentEngine) buildToolsForLLM() []chat.Tool {
|
|
functionDefs := e.toolRegistry.GetFunctionDefinitions()
|
|
tools := make([]chat.Tool, 0, len(functionDefs))
|
|
for _, def := range functionDefs {
|
|
tools = append(tools, chat.Tool{
|
|
Type: "function",
|
|
Function: chat.FunctionDef{
|
|
Name: def.Name,
|
|
Description: def.Description,
|
|
Parameters: def.Parameters,
|
|
},
|
|
})
|
|
}
|
|
|
|
return tools
|
|
}
|
|
|
|
// appendToolResults adds tool results to the message history following OpenAI's tool calling format
|
|
// Also writes these messages to the context manager for persistence
|
|
func (e *AgentEngine) appendToolResults(
|
|
ctx context.Context,
|
|
messages []chat.Message,
|
|
step types.AgentStep,
|
|
) []chat.Message {
|
|
// Add assistant message with tool calls (if any)
|
|
if step.Thought != "" || len(step.ToolCalls) > 0 {
|
|
assistantMsg := chat.Message{
|
|
Role: "assistant",
|
|
Content: step.Thought,
|
|
}
|
|
|
|
// Add tool calls to assistant message (following OpenAI format)
|
|
if len(step.ToolCalls) > 0 {
|
|
assistantMsg.ToolCalls = make([]chat.ToolCall, 0, len(step.ToolCalls))
|
|
for _, tc := range step.ToolCalls {
|
|
// Convert arguments back to JSON string
|
|
argsJSON, _ := json.Marshal(tc.Args)
|
|
|
|
assistantMsg.ToolCalls = append(assistantMsg.ToolCalls, chat.ToolCall{
|
|
ID: tc.ID,
|
|
Type: "function",
|
|
Function: chat.FunctionCall{
|
|
Name: tc.Name,
|
|
Arguments: string(argsJSON),
|
|
},
|
|
})
|
|
}
|
|
}
|
|
|
|
messages = append(messages, assistantMsg)
|
|
|
|
// Write assistant message to context
|
|
if e.contextManager != nil {
|
|
if err := e.contextManager.AddMessage(ctx, e.sessionID, assistantMsg); err != nil {
|
|
logger.Warnf(ctx, "[Agent] Failed to add assistant message to context: %v", err)
|
|
} else {
|
|
logger.Debugf(ctx, "[Agent] Added assistant message to context (session: %s)", e.sessionID)
|
|
}
|
|
}
|
|
}
|
|
|
|
// Add tool result messages (role: "tool", following OpenAI format)
|
|
for _, toolCall := range step.ToolCalls {
|
|
resultContent := toolCall.Result.Output
|
|
if !toolCall.Result.Success {
|
|
resultContent = fmt.Sprintf("Error: %s", toolCall.Result.Error)
|
|
}
|
|
|
|
toolMsg := chat.Message{
|
|
Role: "tool",
|
|
Content: resultContent,
|
|
ToolCallID: toolCall.ID,
|
|
Name: toolCall.Name,
|
|
}
|
|
|
|
messages = append(messages, toolMsg)
|
|
|
|
// Write tool message to context
|
|
if e.contextManager != nil {
|
|
if err := e.contextManager.AddMessage(ctx, e.sessionID, toolMsg); err != nil {
|
|
logger.Warnf(ctx, "[Agent] Failed to add tool message to context: %v", err)
|
|
} else {
|
|
logger.Debugf(ctx, "[Agent] Added tool message to context (session: %s, tool: %s)", e.sessionID, toolCall.Name)
|
|
}
|
|
}
|
|
}
|
|
|
|
return messages
|
|
}
|
|
|
|
// countTotalToolCalls counts total tool calls across all steps
|
|
func countTotalToolCalls(steps []types.AgentStep) int {
|
|
total := 0
|
|
for _, step := range steps {
|
|
total += len(step.ToolCalls)
|
|
}
|
|
return total
|
|
}
|
|
|
|
// buildMessagesWithLLMContext builds the message array with LLM context
|
|
func (e *AgentEngine) buildMessagesWithLLMContext(
|
|
systemPrompt, currentQuery, sessionID string,
|
|
llmContext []chat.Message,
|
|
imageURLs []string,
|
|
) []chat.Message {
|
|
messages := []chat.Message{
|
|
{Role: "system", Content: systemPrompt},
|
|
}
|
|
|
|
if len(llmContext) > 0 {
|
|
for _, msg := range llmContext {
|
|
if msg.Role == "system" {
|
|
continue
|
|
}
|
|
if msg.Role == "user" || msg.Role == "assistant" || msg.Role == "tool" {
|
|
messages = append(messages, msg)
|
|
}
|
|
}
|
|
logger.Infof(context.Background(), "Added %d history messages to context", len(llmContext))
|
|
}
|
|
|
|
// Build user message with runtime context safety tag
|
|
// This injects metadata as clearly non-instruction data to prevent prompt injection.
|
|
runtimeCtx := buildRuntimeContextBlock(sessionID)
|
|
userMsg := chat.Message{
|
|
Role: "user",
|
|
Content: runtimeCtx + "\n\n" + currentQuery,
|
|
Images: imageURLs,
|
|
}
|
|
messages = append(messages, userMsg)
|
|
|
|
return messages
|
|
}
|