mirror of
https://github.com/Tencent/WeKnora.git
synced 2026-06-04 13:30:32 +08:00
feat(agent): enhance message consolidation and context management
- Improved the message consolidation logic to preserve the current turn's user query and all subsequent assistant/tool messages, ensuring better context retention. - Updated the CompressContext function to reflect the new consolidation strategy, maintaining the system prompt and relevant recent messages. - Refactored the context manager to support optional message repository for improved context rebuilding from persistent storage. - Added comprehensive tests to validate the new consolidation behavior and ensure correct message handling across various scenarios.
This commit is contained in:
@@ -293,7 +293,11 @@ func (e *AgentEngine) executeLoop(
|
||||
// the API-reported usage from the previous round plus a BPE delta
|
||||
// for newly appended messages (assistant reply + tool results).
|
||||
currentTokens := e.estimateCurrentTokens(messages)
|
||||
beforeLen := len(messages)
|
||||
messages = e.manageContextWindow(ctx, messages, state.CurrentRound+1, currentTokens)
|
||||
if len(messages) < beforeLen {
|
||||
currentTokens = e.tokenEstimator.EstimateMessages(messages)
|
||||
}
|
||||
|
||||
logger.Infof(ctx, "[Agent][Round-%d/%d] Starting: %d messages, %d tools, est_tokens=%d",
|
||||
state.CurrentRound+1, e.config.MaxIterations, len(messages), len(tools), currentTokens)
|
||||
|
||||
@@ -70,10 +70,11 @@ func (c *Consolidator) ShouldConsolidate(currentTokens int) bool {
|
||||
// Consolidate summarizes older messages and returns a compressed message array.
|
||||
// It preserves:
|
||||
// - The system prompt (first message)
|
||||
// - The current user query (last message)
|
||||
// - Recent messages that fit within the token budget
|
||||
// - The current turn: user query (last user message) and all subsequent
|
||||
// assistant/tool messages belonging to the same turn
|
||||
// - Recent history messages that fit within the token budget
|
||||
//
|
||||
// Older messages are replaced with a summary system message.
|
||||
// Older history messages are replaced with a summary system message.
|
||||
// On LLM failure after maxConsolidationAttempts, falls back to raw text archiving.
|
||||
func (c *Consolidator) Consolidate(
|
||||
ctx context.Context,
|
||||
@@ -84,32 +85,51 @@ func (c *Consolidator) Consolidate(
|
||||
}
|
||||
|
||||
systemMsg := messages[0]
|
||||
lastMsg := messages[len(messages)-1]
|
||||
middle := messages[1 : len(messages)-1]
|
||||
|
||||
// Determine how many messages to keep vs. consolidate.
|
||||
// We want to consolidate enough to bring tokens below the threshold.
|
||||
targetTokens := int(float64(c.maxTokens) * c.threshold * 0.6) // aim for 60% of threshold
|
||||
keepFromEnd := c.findKeepBoundary(middle, targetTokens, &systemMsg, &lastMsg)
|
||||
|
||||
if keepFromEnd >= len(middle) {
|
||||
// Nothing to consolidate
|
||||
// Find the current user query — the last message with role "user".
|
||||
// Everything from this point onward (user query + assistant tool_calls +
|
||||
// tool results) is the active turn and must be preserved intact.
|
||||
lastUserIdx := 0
|
||||
for i := len(messages) - 1; i >= 1; i-- {
|
||||
if messages[i].Role == "user" {
|
||||
lastUserIdx = i
|
||||
break
|
||||
}
|
||||
}
|
||||
if lastUserIdx <= 1 {
|
||||
return messages, nil
|
||||
}
|
||||
|
||||
toConsolidate := middle[:len(middle)-keepFromEnd]
|
||||
toKeep := middle[len(middle)-keepFromEnd:]
|
||||
history := messages[1:lastUserIdx]
|
||||
tail := messages[lastUserIdx:]
|
||||
|
||||
if len(history) < 2 {
|
||||
return messages, nil
|
||||
}
|
||||
|
||||
targetTokens := int(float64(c.maxTokens) * c.threshold * 0.6) // aim for 60% of threshold
|
||||
|
||||
tailTokens := 0
|
||||
for i := range tail {
|
||||
tailTokens += c.estimator.EstimateMessage(&tail[i])
|
||||
}
|
||||
|
||||
keepFromEnd := c.findKeepBoundary(history, targetTokens, &systemMsg, tailTokens)
|
||||
|
||||
if keepFromEnd >= len(history) {
|
||||
return messages, nil
|
||||
}
|
||||
|
||||
toConsolidate := history[:len(history)-keepFromEnd]
|
||||
toKeep := history[len(history)-keepFromEnd:]
|
||||
|
||||
// Try LLM-powered summarization
|
||||
summary, err := c.summarizeWithRetry(ctx, toConsolidate)
|
||||
if err != nil {
|
||||
// Fall back to raw archiving
|
||||
logger.Warnf(ctx, "[MemoryConsolidator] LLM summarization failed after retries, "+
|
||||
"falling back to raw archive: %v", err)
|
||||
summary = c.rawArchive(toConsolidate)
|
||||
}
|
||||
|
||||
// Build consolidated messages
|
||||
summaryMsg := chat.Message{
|
||||
Role: "system",
|
||||
Content: fmt.Sprintf(
|
||||
@@ -118,59 +138,56 @@ func (c *Consolidator) Consolidate(
|
||||
),
|
||||
}
|
||||
|
||||
result := make([]chat.Message, 0, 3+len(toKeep))
|
||||
result := make([]chat.Message, 0, 2+len(toKeep)+len(tail))
|
||||
result = append(result, systemMsg)
|
||||
result = append(result, summaryMsg)
|
||||
result = append(result, toKeep...)
|
||||
result = append(result, lastMsg)
|
||||
result = append(result, tail...)
|
||||
|
||||
logger.Infof(ctx, "[MemoryConsolidator] Consolidated %d messages → summary (%d chars), "+
|
||||
"keeping %d recent messages",
|
||||
len(toConsolidate), len(summary), len(toKeep))
|
||||
"keeping %d history + %d current-turn messages",
|
||||
len(toConsolidate), len(summary), len(toKeep), len(tail))
|
||||
|
||||
return result, nil
|
||||
}
|
||||
|
||||
// findKeepBoundary determines how many messages from the end of middle to keep.
|
||||
// findKeepBoundary determines how many messages from the end of history to keep.
|
||||
// Returns the count of messages to keep (from the end), respecting tool_call/tool_result boundaries.
|
||||
// tailTokens is the token cost of the current-turn tail that is always preserved.
|
||||
func (c *Consolidator) findKeepBoundary(
|
||||
middle []chat.Message,
|
||||
history []chat.Message,
|
||||
targetTokens int,
|
||||
systemMsg, lastMsg *chat.Message,
|
||||
systemMsg *chat.Message,
|
||||
tailTokens int,
|
||||
) int {
|
||||
// Budget for kept messages = target - system - last
|
||||
budget := targetTokens -
|
||||
c.estimator.EstimateMessage(systemMsg) -
|
||||
c.estimator.EstimateMessage(lastMsg) -
|
||||
tailTokens -
|
||||
500 // reserve for summary message
|
||||
|
||||
if budget <= 0 {
|
||||
return 0
|
||||
}
|
||||
|
||||
// Walk from the end, accumulating tokens, respecting tool boundaries
|
||||
tokens := 0
|
||||
keepCount := 0
|
||||
i := len(middle) - 1
|
||||
i := len(history) - 1
|
||||
|
||||
for i >= 0 {
|
||||
msg := middle[i]
|
||||
msg := history[i]
|
||||
msgTokens := c.estimator.EstimateMessage(&msg)
|
||||
|
||||
// If this is a tool result, we must also keep its assistant message
|
||||
if msg.Role == "tool" {
|
||||
// Walk back to find the assistant message with tool_calls
|
||||
groupTokens := msgTokens
|
||||
groupSize := 1
|
||||
j := i - 1
|
||||
for j >= 0 && middle[j].Role == "tool" {
|
||||
groupTokens += c.estimator.EstimateMessage(&middle[j])
|
||||
for j >= 0 && history[j].Role == "tool" {
|
||||
groupTokens += c.estimator.EstimateMessage(&history[j])
|
||||
groupSize++
|
||||
j--
|
||||
}
|
||||
// Include the assistant message
|
||||
if j >= 0 && middle[j].Role == "assistant" {
|
||||
groupTokens += c.estimator.EstimateMessage(&middle[j])
|
||||
if j >= 0 && history[j].Role == "assistant" {
|
||||
groupTokens += c.estimator.EstimateMessage(&history[j])
|
||||
groupSize++
|
||||
}
|
||||
|
||||
|
||||
@@ -1,13 +1,37 @@
|
||||
package memory
|
||||
|
||||
import (
|
||||
"context"
|
||||
"strings"
|
||||
"testing"
|
||||
|
||||
"github.com/Tencent/WeKnora/internal/agent/token"
|
||||
"github.com/Tencent/WeKnora/internal/models/chat"
|
||||
"github.com/Tencent/WeKnora/internal/types"
|
||||
"github.com/stretchr/testify/assert"
|
||||
"github.com/stretchr/testify/require"
|
||||
)
|
||||
|
||||
// stubChat is a minimal chat.Chat implementation for testing consolidation.
|
||||
type stubChat struct {
|
||||
response string
|
||||
err error
|
||||
}
|
||||
|
||||
func (s *stubChat) Chat(_ context.Context, _ []chat.Message, _ *chat.ChatOptions) (*types.ChatResponse, error) {
|
||||
if s.err != nil {
|
||||
return nil, s.err
|
||||
}
|
||||
return &types.ChatResponse{Content: s.response}, nil
|
||||
}
|
||||
|
||||
func (s *stubChat) ChatStream(context.Context, []chat.Message, *chat.ChatOptions) (<-chan types.StreamResponse, error) {
|
||||
return nil, nil
|
||||
}
|
||||
|
||||
func (s *stubChat) GetModelName() string { return "stub" }
|
||||
func (s *stubChat) GetModelID() string { return "stub" }
|
||||
|
||||
func TestConsolidator_ShouldConsolidate(t *testing.T) {
|
||||
est, err := token.NewEstimator()
|
||||
assert.NoError(t, err)
|
||||
@@ -93,3 +117,258 @@ func TestBuildConsolidationPrompt(t *testing.T) {
|
||||
assert.Contains(t, prompt, "web_search")
|
||||
assert.Contains(t, prompt, "**Tool [web_search]**: results here")
|
||||
}
|
||||
|
||||
// ---------- Consolidate() 核心流程测试 ----------
|
||||
|
||||
func TestConsolidate_TooFewMessages(t *testing.T) {
|
||||
est, err := token.NewEstimator()
|
||||
require.NoError(t, err)
|
||||
c := NewConsolidator(&stubChat{response: "summary"}, est, 100, 0)
|
||||
|
||||
msgs := []chat.Message{
|
||||
{Role: "system", Content: "sys"},
|
||||
{Role: "user", Content: "hi"},
|
||||
{Role: "assistant", Content: "hello"},
|
||||
}
|
||||
result, err := c.Consolidate(context.Background(), msgs)
|
||||
require.NoError(t, err)
|
||||
assert.Equal(t, msgs, result, "<=3 messages should be returned unchanged")
|
||||
}
|
||||
|
||||
func TestConsolidate_Round1_UserQueryAtEnd(t *testing.T) {
|
||||
est, err := token.NewEstimator()
|
||||
require.NoError(t, err)
|
||||
|
||||
// Use a stub that returns a known summary.
|
||||
c := NewConsolidator(&stubChat{response: "summary of old history"}, est, 200, 0)
|
||||
|
||||
long := strings.Repeat("old context data ", 200)
|
||||
msgs := []chat.Message{
|
||||
{Role: "system", Content: "You are a helpful assistant"},
|
||||
{Role: "user", Content: long},
|
||||
{Role: "assistant", Content: long},
|
||||
{Role: "user", Content: long},
|
||||
{Role: "assistant", Content: long},
|
||||
{Role: "user", Content: "current question"}, // last user = current query
|
||||
}
|
||||
|
||||
result, err := c.Consolidate(context.Background(), msgs)
|
||||
require.NoError(t, err)
|
||||
|
||||
// System prompt is preserved.
|
||||
assert.Equal(t, "system", result[0].Role)
|
||||
assert.Equal(t, "You are a helpful assistant", result[0].Content)
|
||||
|
||||
// The current user query (last user message) must be preserved at the tail.
|
||||
last := result[len(result)-1]
|
||||
assert.Equal(t, "user", last.Role)
|
||||
assert.Equal(t, "current question", last.Content)
|
||||
|
||||
// A summary message must have been inserted.
|
||||
assert.Equal(t, "system", result[1].Role)
|
||||
assert.Contains(t, result[1].Content, "Memory Summary")
|
||||
|
||||
// Total message count should be fewer than original.
|
||||
assert.Less(t, len(result), len(msgs))
|
||||
}
|
||||
|
||||
func TestConsolidate_Round2Plus_UserQueryNotAtEnd(t *testing.T) {
|
||||
est, err := token.NewEstimator()
|
||||
require.NoError(t, err)
|
||||
|
||||
c := NewConsolidator(&stubChat{response: "consolidated history"}, est, 300, 0)
|
||||
|
||||
longContent := strings.Repeat("verbose content ", 200)
|
||||
|
||||
// Simulates Agent Round 2+:
|
||||
// [system, ...old_history..., user_query, assistant+tools, tool_results]
|
||||
// The user query is NOT the last message.
|
||||
msgs := []chat.Message{
|
||||
{Role: "system", Content: "You are a helpful assistant"},
|
||||
// --- old history (from previous turns) ---
|
||||
{Role: "user", Content: longContent},
|
||||
{Role: "assistant", Content: longContent},
|
||||
{Role: "user", Content: longContent},
|
||||
{Role: "assistant", Content: longContent, ToolCalls: []chat.ToolCall{
|
||||
{ID: "old_call", Function: chat.FunctionCall{Name: "search", Arguments: `{}`}},
|
||||
}},
|
||||
{Role: "tool", Content: longContent, ToolCallID: "old_call", Name: "search"},
|
||||
// --- current turn ---
|
||||
{Role: "user", Content: "what is the weather today?"}, // current user query
|
||||
{Role: "assistant", Content: "let me check", ToolCalls: []chat.ToolCall{
|
||||
{ID: "call_1", Function: chat.FunctionCall{Name: "weather", Arguments: `{"city":"beijing"}`}},
|
||||
}},
|
||||
{Role: "tool", Content: "sunny, 25°C", ToolCallID: "call_1", Name: "weather"},
|
||||
}
|
||||
|
||||
result, err := c.Consolidate(context.Background(), msgs)
|
||||
require.NoError(t, err)
|
||||
|
||||
// System prompt preserved.
|
||||
assert.Equal(t, "system", result[0].Role)
|
||||
|
||||
// The entire current turn tail must be intact (user query + assistant + tool).
|
||||
// Find the user query in result.
|
||||
userQueryIdx := -1
|
||||
for i, m := range result {
|
||||
if m.Role == "user" && m.Content == "what is the weather today?" {
|
||||
userQueryIdx = i
|
||||
break
|
||||
}
|
||||
}
|
||||
require.NotEqual(t, -1, userQueryIdx, "current user query must be preserved")
|
||||
|
||||
// After user query: assistant with tool_calls, then tool result.
|
||||
require.Greater(t, len(result), userQueryIdx+2)
|
||||
assert.Equal(t, "assistant", result[userQueryIdx+1].Role)
|
||||
assert.Equal(t, "let me check", result[userQueryIdx+1].Content)
|
||||
assert.Len(t, result[userQueryIdx+1].ToolCalls, 1)
|
||||
|
||||
assert.Equal(t, "tool", result[userQueryIdx+2].Role)
|
||||
assert.Equal(t, "sunny, 25°C", result[userQueryIdx+2].Content)
|
||||
|
||||
// Summary message exists.
|
||||
hasSummary := false
|
||||
for _, m := range result {
|
||||
if m.Role == "system" && strings.Contains(m.Content, "Memory Summary") {
|
||||
hasSummary = true
|
||||
break
|
||||
}
|
||||
}
|
||||
assert.True(t, hasSummary, "should contain a memory summary message")
|
||||
|
||||
// Message count reduced.
|
||||
assert.Less(t, len(result), len(msgs))
|
||||
}
|
||||
|
||||
func TestConsolidate_Round2Plus_MultipleToolCallsPreserved(t *testing.T) {
|
||||
est, err := token.NewEstimator()
|
||||
require.NoError(t, err)
|
||||
|
||||
c := NewConsolidator(&stubChat{response: "summary"}, est, 300, 0)
|
||||
|
||||
longContent := strings.Repeat("filler ", 300)
|
||||
msgs := []chat.Message{
|
||||
{Role: "system", Content: "sys"},
|
||||
// old history
|
||||
{Role: "user", Content: longContent},
|
||||
{Role: "assistant", Content: longContent},
|
||||
// current turn with parallel tool calls
|
||||
{Role: "user", Content: "do two things"},
|
||||
{Role: "assistant", Content: "ok", ToolCalls: []chat.ToolCall{
|
||||
{ID: "c1", Function: chat.FunctionCall{Name: "toolA", Arguments: `{}`}},
|
||||
{ID: "c2", Function: chat.FunctionCall{Name: "toolB", Arguments: `{}`}},
|
||||
}},
|
||||
{Role: "tool", Content: "resultA", ToolCallID: "c1", Name: "toolA"},
|
||||
{Role: "tool", Content: "resultB", ToolCallID: "c2", Name: "toolB"},
|
||||
}
|
||||
|
||||
result, err := c.Consolidate(context.Background(), msgs)
|
||||
require.NoError(t, err)
|
||||
|
||||
// All three current-turn messages after the user query must be present.
|
||||
var toolNames []string
|
||||
for _, m := range result {
|
||||
if m.Role == "tool" {
|
||||
toolNames = append(toolNames, m.Name)
|
||||
}
|
||||
}
|
||||
assert.Contains(t, toolNames, "toolA")
|
||||
assert.Contains(t, toolNames, "toolB")
|
||||
|
||||
// User query preserved.
|
||||
found := false
|
||||
for _, m := range result {
|
||||
if m.Role == "user" && m.Content == "do two things" {
|
||||
found = true
|
||||
}
|
||||
}
|
||||
assert.True(t, found)
|
||||
}
|
||||
|
||||
func TestConsolidate_NoUserMessage_ReturnUnchanged(t *testing.T) {
|
||||
est, err := token.NewEstimator()
|
||||
require.NoError(t, err)
|
||||
|
||||
c := NewConsolidator(&stubChat{response: "summary"}, est, 100, 0)
|
||||
|
||||
// Edge case: no user message at all (shouldn't happen normally but be defensive).
|
||||
msgs := []chat.Message{
|
||||
{Role: "system", Content: "sys"},
|
||||
{Role: "assistant", Content: "hello"},
|
||||
{Role: "assistant", Content: "world"},
|
||||
{Role: "assistant", Content: "more"},
|
||||
}
|
||||
|
||||
result, err := c.Consolidate(context.Background(), msgs)
|
||||
require.NoError(t, err)
|
||||
assert.Equal(t, msgs, result, "no user message → return unchanged")
|
||||
}
|
||||
|
||||
func TestConsolidate_LLMFailure_FallsBackToRawArchive(t *testing.T) {
|
||||
est, err := token.NewEstimator()
|
||||
require.NoError(t, err)
|
||||
|
||||
failChat := &stubChat{err: assert.AnError}
|
||||
c := NewConsolidator(failChat, est, 200, 0)
|
||||
|
||||
longContent := strings.Repeat("data ", 300)
|
||||
msgs := []chat.Message{
|
||||
{Role: "system", Content: "sys"},
|
||||
{Role: "user", Content: longContent},
|
||||
{Role: "assistant", Content: longContent},
|
||||
{Role: "user", Content: longContent},
|
||||
{Role: "assistant", Content: longContent},
|
||||
{Role: "user", Content: "current"},
|
||||
{Role: "assistant", Content: "thinking", ToolCalls: []chat.ToolCall{
|
||||
{ID: "c1", Function: chat.FunctionCall{Name: "t1"}},
|
||||
}},
|
||||
{Role: "tool", Content: "res", ToolCallID: "c1", Name: "t1"},
|
||||
}
|
||||
|
||||
result, err := c.Consolidate(context.Background(), msgs)
|
||||
require.NoError(t, err)
|
||||
|
||||
// Should still have a summary (raw archive fallback).
|
||||
hasSummary := false
|
||||
for _, m := range result {
|
||||
if m.Role == "system" && strings.Contains(m.Content, "Memory Summary") {
|
||||
hasSummary = true
|
||||
assert.Contains(t, m.Content, "Raw conversation archive",
|
||||
"should be a raw archive when LLM fails")
|
||||
}
|
||||
}
|
||||
assert.True(t, hasSummary)
|
||||
|
||||
// Current turn tail must still be preserved.
|
||||
last := result[len(result)-1]
|
||||
assert.Equal(t, "tool", last.Role)
|
||||
|
||||
userFound := false
|
||||
for _, m := range result {
|
||||
if m.Role == "user" && m.Content == "current" {
|
||||
userFound = true
|
||||
}
|
||||
}
|
||||
assert.True(t, userFound, "current user query must survive even on LLM failure")
|
||||
}
|
||||
|
||||
func TestConsolidate_OnlyCurrentTurn_NothingToConsolidate(t *testing.T) {
|
||||
est, err := token.NewEstimator()
|
||||
require.NoError(t, err)
|
||||
c := NewConsolidator(&stubChat{response: "summary"}, est, 100, 0)
|
||||
|
||||
// Only system + user + assistant + tool → history between system and user is empty.
|
||||
msgs := []chat.Message{
|
||||
{Role: "system", Content: "sys"},
|
||||
{Role: "user", Content: "hello"},
|
||||
{Role: "assistant", Content: "let me help", ToolCalls: []chat.ToolCall{
|
||||
{ID: "c1", Function: chat.FunctionCall{Name: "t1"}},
|
||||
}},
|
||||
{Role: "tool", Content: "done", ToolCallID: "c1", Name: "t1"},
|
||||
}
|
||||
|
||||
result, err := c.Consolidate(context.Background(), msgs)
|
||||
require.NoError(t, err)
|
||||
assert.Equal(t, msgs, result, "nothing to consolidate → unchanged")
|
||||
}
|
||||
|
||||
@@ -8,15 +8,14 @@ import (
|
||||
// When message tokens exceed MaxContextTokens * threshold, old messages are trimmed.
|
||||
const DefaultContextThresholdRatio = 0.8
|
||||
|
||||
// CompressContext trims older messages to bring total token count below the threshold.
|
||||
// currentTokens is the caller's best estimate of the current context size (from API
|
||||
// Usage when available, falling back to BPE estimation). The estimator is still used
|
||||
// internally for per-message-group cost calculation during trimming.
|
||||
//
|
||||
// It preserves:
|
||||
// - The first message (system prompt)
|
||||
// - The last message (current user query)
|
||||
// CompressContext trims older history messages to bring total token count below
|
||||
// the threshold. It preserves:
|
||||
// - The system prompt (first message)
|
||||
// - The current turn: user query (last user message) and all subsequent
|
||||
// assistant/tool messages
|
||||
// - tool_call / tool_result message pairs (never splits them)
|
||||
//
|
||||
// currentTokens is the caller's best estimate of the current context size.
|
||||
func CompressContext(
|
||||
messages []chat.Message,
|
||||
estimator *Estimator,
|
||||
@@ -33,10 +32,24 @@ func CompressContext(
|
||||
}
|
||||
|
||||
systemMsg := messages[0]
|
||||
lastMsg := messages[len(messages)-1]
|
||||
middle := messages[1 : len(messages)-1]
|
||||
|
||||
groups := groupToolMessages(middle)
|
||||
// Find the current user query — the last message with role "user".
|
||||
lastUserIdx := len(messages) - 1
|
||||
for i := len(messages) - 1; i >= 1; i-- {
|
||||
if messages[i].Role == "user" {
|
||||
lastUserIdx = i
|
||||
break
|
||||
}
|
||||
}
|
||||
|
||||
history := messages[1:lastUserIdx]
|
||||
tail := messages[lastUserIdx:]
|
||||
|
||||
if len(history) == 0 {
|
||||
return messages
|
||||
}
|
||||
|
||||
groups := groupToolMessages(history)
|
||||
|
||||
tokensToFree := currentTokens - threshold
|
||||
freed := 0
|
||||
@@ -59,7 +72,7 @@ func CompressContext(
|
||||
for i := removeUpTo; i < len(groups); i++ {
|
||||
remaining = append(remaining, groups[i]...)
|
||||
}
|
||||
remaining = append(remaining, lastMsg)
|
||||
remaining = append(remaining, tail...)
|
||||
|
||||
return remaining
|
||||
}
|
||||
|
||||
@@ -154,6 +154,153 @@ func TestCompressContext(t *testing.T) {
|
||||
result := CompressContext(messages, e, 1, 999)
|
||||
assert.Equal(t, messages, result)
|
||||
})
|
||||
|
||||
t.Run("round2+: user query not at end, preserves current turn tail", func(t *testing.T) {
|
||||
longText := strings.Repeat("filler content ", 200)
|
||||
messages := []chat.Message{
|
||||
{Role: "system", Content: "system prompt"},
|
||||
// old history
|
||||
{Role: "user", Content: longText},
|
||||
{Role: "assistant", Content: longText},
|
||||
{Role: "user", Content: longText},
|
||||
{Role: "assistant", Content: longText},
|
||||
// current turn
|
||||
{Role: "user", Content: "current question"},
|
||||
{Role: "assistant", Content: "let me search", ToolCalls: []chat.ToolCall{
|
||||
{ID: "c1", Function: chat.FunctionCall{Name: "search", Arguments: `{"q":"test"}`}},
|
||||
}},
|
||||
{Role: "tool", Content: "search results", ToolCallID: "c1", Name: "search"},
|
||||
}
|
||||
|
||||
tokens := e.EstimateMessages(messages)
|
||||
result := CompressContext(messages, e, 300, tokens)
|
||||
|
||||
// System prompt preserved.
|
||||
assert.Equal(t, "system", result[0].Role)
|
||||
assert.Equal(t, "system prompt", result[0].Content)
|
||||
|
||||
// Current turn tail must be fully intact.
|
||||
userIdx := -1
|
||||
for i, m := range result {
|
||||
if m.Role == "user" && m.Content == "current question" {
|
||||
userIdx = i
|
||||
break
|
||||
}
|
||||
}
|
||||
require.NotEqual(t, -1, userIdx, "user query must be preserved")
|
||||
require.Greater(t, len(result), userIdx+2, "assistant + tool must follow")
|
||||
assert.Equal(t, "assistant", result[userIdx+1].Role)
|
||||
assert.Equal(t, "let me search", result[userIdx+1].Content)
|
||||
assert.Equal(t, "tool", result[userIdx+2].Role)
|
||||
assert.Equal(t, "search results", result[userIdx+2].Content)
|
||||
|
||||
// Should be shorter than original (old history trimmed).
|
||||
assert.Less(t, len(result), len(messages))
|
||||
})
|
||||
|
||||
t.Run("round2+: multiple tool results after user query all preserved", func(t *testing.T) {
|
||||
longText := strings.Repeat("data ", 300)
|
||||
messages := []chat.Message{
|
||||
{Role: "system", Content: "sys"},
|
||||
{Role: "user", Content: longText},
|
||||
{Role: "assistant", Content: longText},
|
||||
// current turn with parallel tool calls
|
||||
{Role: "user", Content: "do things"},
|
||||
{Role: "assistant", Content: "ok", ToolCalls: []chat.ToolCall{
|
||||
{ID: "c1", Function: chat.FunctionCall{Name: "t1"}},
|
||||
{ID: "c2", Function: chat.FunctionCall{Name: "t2"}},
|
||||
}},
|
||||
{Role: "tool", Content: "res1", ToolCallID: "c1", Name: "t1"},
|
||||
{Role: "tool", Content: "res2", ToolCallID: "c2", Name: "t2"},
|
||||
}
|
||||
|
||||
tokens := e.EstimateMessages(messages)
|
||||
result := CompressContext(messages, e, 200, tokens)
|
||||
|
||||
// Both tool results must be present.
|
||||
var toolNames []string
|
||||
for _, m := range result {
|
||||
if m.Role == "tool" {
|
||||
toolNames = append(toolNames, m.Name)
|
||||
}
|
||||
}
|
||||
assert.Contains(t, toolNames, "t1")
|
||||
assert.Contains(t, toolNames, "t2")
|
||||
|
||||
// User query preserved.
|
||||
found := false
|
||||
for _, m := range result {
|
||||
if m.Content == "do things" {
|
||||
found = true
|
||||
}
|
||||
}
|
||||
assert.True(t, found)
|
||||
})
|
||||
|
||||
t.Run("round2+: no history between system and user query returns unchanged", func(t *testing.T) {
|
||||
messages := []chat.Message{
|
||||
{Role: "system", Content: "sys"},
|
||||
{Role: "user", Content: "hello"},
|
||||
{Role: "assistant", Content: "thinking", ToolCalls: []chat.ToolCall{
|
||||
{ID: "c1", Function: chat.FunctionCall{Name: "t1"}},
|
||||
}},
|
||||
{Role: "tool", Content: "done", ToolCallID: "c1", Name: "t1"},
|
||||
}
|
||||
tokens := e.EstimateMessages(messages)
|
||||
result := CompressContext(messages, e, 10, tokens)
|
||||
assert.Equal(t, messages, result, "no history to trim → unchanged")
|
||||
})
|
||||
|
||||
t.Run("no user message at all returns unchanged", func(t *testing.T) {
|
||||
messages := []chat.Message{
|
||||
{Role: "system", Content: "sys"},
|
||||
{Role: "assistant", Content: strings.Repeat("x", 1000)},
|
||||
{Role: "assistant", Content: strings.Repeat("y", 1000)},
|
||||
}
|
||||
tokens := e.EstimateMessages(messages)
|
||||
result := CompressContext(messages, e, 10, tokens)
|
||||
// lastUserIdx defaults to len-1 (last msg), history = messages[1:last] = [assistant]
|
||||
// This still does its best—the important thing is it doesn't panic.
|
||||
assert.NotNil(t, result)
|
||||
})
|
||||
|
||||
t.Run("round2+: tool pair in history never split", func(t *testing.T) {
|
||||
longText := strings.Repeat("verbose ", 200)
|
||||
messages := []chat.Message{
|
||||
{Role: "system", Content: "sys"},
|
||||
// old turn 1
|
||||
{Role: "user", Content: "old query 1"},
|
||||
{Role: "assistant", Content: longText, ToolCalls: []chat.ToolCall{
|
||||
{ID: "old1", Function: chat.FunctionCall{Name: "old_tool"}},
|
||||
}},
|
||||
{Role: "tool", Content: longText, ToolCallID: "old1", Name: "old_tool"},
|
||||
// old turn 2
|
||||
{Role: "user", Content: "old query 2"},
|
||||
{Role: "assistant", Content: "short reply"},
|
||||
// current turn
|
||||
{Role: "user", Content: "current"},
|
||||
{Role: "assistant", Content: "working", ToolCalls: []chat.ToolCall{
|
||||
{ID: "new1", Function: chat.FunctionCall{Name: "new_tool"}},
|
||||
}},
|
||||
{Role: "tool", Content: "result", ToolCallID: "new1", Name: "new_tool"},
|
||||
}
|
||||
|
||||
tokens := e.EstimateMessages(messages)
|
||||
result := CompressContext(messages, e, 300, tokens)
|
||||
|
||||
// Verify no orphaned tool message: every "tool" msg must be preceded by
|
||||
// an assistant with tool_calls.
|
||||
for i, m := range result {
|
||||
if m.Role == "tool" {
|
||||
require.Greater(t, i, 0, "tool message at index 0 is impossible")
|
||||
prev := result[i-1]
|
||||
isPairedTool := prev.Role == "tool"
|
||||
isPairedAssistant := prev.Role == "assistant" && len(prev.ToolCalls) > 0
|
||||
assert.True(t, isPairedTool || isPairedAssistant,
|
||||
"tool message at %d must be preceded by assistant+tool_calls or another tool, got %s", i, prev.Role)
|
||||
}
|
||||
}
|
||||
})
|
||||
}
|
||||
|
||||
func TestGroupToolMessages(t *testing.T) {
|
||||
|
||||
@@ -1,258 +0,0 @@
|
||||
package llmcontext
|
||||
|
||||
import (
|
||||
"context"
|
||||
"fmt"
|
||||
"strings"
|
||||
"unicode"
|
||||
|
||||
"github.com/Tencent/WeKnora/internal/logger"
|
||||
"github.com/Tencent/WeKnora/internal/models/chat"
|
||||
"github.com/Tencent/WeKnora/internal/types/interfaces"
|
||||
)
|
||||
|
||||
// estimateStringTokens provides a more accurate token estimation by distinguishing
|
||||
// CJK characters (≈1.5 tokens each) from Latin characters (≈4 chars per token).
|
||||
// This is significantly more accurate than the naive totalChars/4 approach,
|
||||
// especially for mixed Chinese-English content common in WeKnora.
|
||||
func estimateStringTokens(s string) int {
|
||||
cjkChars := 0
|
||||
otherChars := 0
|
||||
for _, r := range s {
|
||||
if unicode.Is(unicode.Han, r) || unicode.Is(unicode.Hangul, r) || unicode.Is(unicode.Katakana, r) || unicode.Is(unicode.Hiragana, r) {
|
||||
cjkChars++
|
||||
} else {
|
||||
otherChars++
|
||||
}
|
||||
}
|
||||
// CJK characters average ~1.5 tokens each; Latin ~0.25 tokens per char
|
||||
return (cjkChars*3 + otherChars) / 2
|
||||
}
|
||||
|
||||
// estimateMessageTokens estimates the token count for a list of messages.
|
||||
// Accounts for per-message overhead (role markers, special tokens) and tool call metadata.
|
||||
func estimateMessageTokens(messages []chat.Message) int {
|
||||
totalTokens := 0
|
||||
for _, msg := range messages {
|
||||
totalTokens += estimateStringTokens(msg.Role) + estimateStringTokens(msg.Content)
|
||||
// Per-message overhead: role markers, delimiters, special tokens
|
||||
totalTokens += 4
|
||||
if len(msg.ToolCalls) > 0 {
|
||||
for _, tc := range msg.ToolCalls {
|
||||
totalTokens += estimateStringTokens(tc.Function.Name) + estimateStringTokens(tc.Function.Arguments)
|
||||
// Tool call overhead: function call structure tokens
|
||||
totalTokens += 8
|
||||
}
|
||||
}
|
||||
}
|
||||
return totalTokens
|
||||
}
|
||||
|
||||
// slidingWindowStrategy implements CompressionStrategy using sliding window
|
||||
type slidingWindowStrategy struct {
|
||||
recentMessageCount int
|
||||
}
|
||||
|
||||
// NewSlidingWindowStrategy creates a new sliding window compression strategy
|
||||
func NewSlidingWindowStrategy(recentMessageCount int) interfaces.CompressionStrategy {
|
||||
return &slidingWindowStrategy{
|
||||
recentMessageCount: recentMessageCount,
|
||||
}
|
||||
}
|
||||
|
||||
// Compress implements the sliding window compression
|
||||
// Keeps system messages and the most recent N messages
|
||||
func (s *slidingWindowStrategy) Compress(
|
||||
ctx context.Context,
|
||||
messages []chat.Message,
|
||||
maxTokens int,
|
||||
) ([]chat.Message, error) {
|
||||
if len(messages) <= s.recentMessageCount {
|
||||
return messages, nil
|
||||
}
|
||||
|
||||
// Separate system messages from regular messages
|
||||
var systemMessages []chat.Message
|
||||
var regularMessages []chat.Message
|
||||
|
||||
for _, msg := range messages {
|
||||
if msg.Role == "system" {
|
||||
systemMessages = append(systemMessages, msg)
|
||||
} else {
|
||||
regularMessages = append(regularMessages, msg)
|
||||
}
|
||||
}
|
||||
|
||||
// Keep the most recent N regular messages
|
||||
var keptMessages []chat.Message
|
||||
if len(regularMessages) > s.recentMessageCount {
|
||||
keptMessages = regularMessages[len(regularMessages)-s.recentMessageCount:]
|
||||
} else {
|
||||
keptMessages = regularMessages
|
||||
}
|
||||
|
||||
// Combine: system messages first, then recent messages
|
||||
result := make([]chat.Message, 0, len(systemMessages)+len(keptMessages))
|
||||
result = append(result, systemMessages...)
|
||||
result = append(result, keptMessages...)
|
||||
|
||||
logger.Infof(ctx, "[SlidingWindow] Compressed %d messages to %d messages (kept %d recent + %d system)",
|
||||
len(messages), len(result), len(keptMessages), len(systemMessages))
|
||||
|
||||
return result, nil
|
||||
}
|
||||
|
||||
// EstimateTokens estimates token count using CJK-aware heuristics.
|
||||
func (s *slidingWindowStrategy) EstimateTokens(messages []chat.Message) int {
|
||||
return estimateMessageTokens(messages)
|
||||
}
|
||||
|
||||
// smartCompressionStrategy implements CompressionStrategy using LLM summarization
|
||||
type smartCompressionStrategy struct {
|
||||
recentMessageCount int
|
||||
chatModel chat.Chat
|
||||
summarizeThreshold int // Minimum messages before summarization
|
||||
}
|
||||
|
||||
// NewSmartCompressionStrategy creates a new smart compression strategy
|
||||
func NewSmartCompressionStrategy(
|
||||
recentMessageCount int,
|
||||
chatModel chat.Chat,
|
||||
summarizeThreshold int,
|
||||
) interfaces.CompressionStrategy {
|
||||
return &smartCompressionStrategy{
|
||||
recentMessageCount: recentMessageCount,
|
||||
chatModel: chatModel,
|
||||
summarizeThreshold: summarizeThreshold,
|
||||
}
|
||||
}
|
||||
|
||||
// Compress implements smart compression with LLM summarization
|
||||
// Summarizes old messages and keeps recent messages intact
|
||||
func (s *smartCompressionStrategy) Compress(
|
||||
ctx context.Context,
|
||||
messages []chat.Message,
|
||||
maxTokens int,
|
||||
) ([]chat.Message, error) {
|
||||
if len(messages) <= s.recentMessageCount {
|
||||
return messages, nil
|
||||
}
|
||||
|
||||
// Separate system messages, old messages, and recent messages
|
||||
var systemMessages []chat.Message
|
||||
var oldMessages []chat.Message
|
||||
var recentMessages []chat.Message
|
||||
|
||||
systemCount := 0
|
||||
for _, msg := range messages {
|
||||
if msg.Role == "system" {
|
||||
systemMessages = append(systemMessages, msg)
|
||||
systemCount++
|
||||
}
|
||||
}
|
||||
|
||||
// Get regular messages (non-system)
|
||||
regularMessages := make([]chat.Message, 0, len(messages)-systemCount)
|
||||
for _, msg := range messages {
|
||||
if msg.Role != "system" {
|
||||
regularMessages = append(regularMessages, msg)
|
||||
}
|
||||
}
|
||||
|
||||
// Split regular messages into old and recent
|
||||
if len(regularMessages) > s.recentMessageCount {
|
||||
splitPoint := len(regularMessages) - s.recentMessageCount
|
||||
oldMessages = regularMessages[:splitPoint]
|
||||
recentMessages = regularMessages[splitPoint:]
|
||||
} else {
|
||||
recentMessages = regularMessages
|
||||
}
|
||||
|
||||
// If old messages are few, no need to summarize
|
||||
if len(oldMessages) < s.summarizeThreshold {
|
||||
result := make([]chat.Message, 0, len(systemMessages)+len(regularMessages))
|
||||
result = append(result, systemMessages...)
|
||||
result = append(result, regularMessages...)
|
||||
return result, nil
|
||||
}
|
||||
|
||||
// Summarize old messages using LLM
|
||||
summary, err := s.summarizeMessages(ctx, oldMessages)
|
||||
if err != nil {
|
||||
logger.Warnf(ctx, "[SmartCompression] Failed to summarize messages: %v, falling back to sliding window", err)
|
||||
// Fallback: use sliding window strategy to at least reduce message count
|
||||
fallback := &slidingWindowStrategy{recentMessageCount: s.recentMessageCount}
|
||||
return fallback.Compress(ctx, messages, maxTokens)
|
||||
}
|
||||
|
||||
// Construct final message list: system + summary + recent
|
||||
result := make([]chat.Message, 0, len(systemMessages)+1+len(recentMessages))
|
||||
result = append(result, systemMessages...)
|
||||
result = append(result, chat.Message{
|
||||
Role: "system",
|
||||
Content: fmt.Sprintf("Previous conversation summary:\n%s", summary),
|
||||
})
|
||||
result = append(result, recentMessages...)
|
||||
|
||||
logger.Infof(
|
||||
ctx,
|
||||
"[SmartCompression] Compressed %d messages to %d messages (summarized %d old + kept %d recent + %d system)",
|
||||
len(messages),
|
||||
len(result),
|
||||
len(oldMessages),
|
||||
len(recentMessages),
|
||||
len(systemMessages),
|
||||
)
|
||||
|
||||
return result, nil
|
||||
}
|
||||
|
||||
// summarizeMessages uses LLM to create a summary of old messages
|
||||
func (s *smartCompressionStrategy) summarizeMessages(ctx context.Context, messages []chat.Message) (string, error) {
|
||||
// Build conversation text
|
||||
var sb strings.Builder
|
||||
for i, msg := range messages {
|
||||
fmt.Fprintf(&sb, "[%s] %s\n", msg.Role, msg.Content)
|
||||
if i < len(messages)-1 {
|
||||
sb.WriteString("\n")
|
||||
}
|
||||
}
|
||||
|
||||
// Create summarization prompt
|
||||
summaryPrompt := []chat.Message{
|
||||
{
|
||||
Role: "system",
|
||||
Content: "You are a helpful assistant that summarizes conversations. " +
|
||||
"Provide a concise summary that captures the key points, decisions, and context. " +
|
||||
"Keep the summary brief but informative.",
|
||||
},
|
||||
{
|
||||
Role: "user",
|
||||
Content: fmt.Sprintf("Please summarize the following conversation:\n\n%s", sb.String()),
|
||||
},
|
||||
}
|
||||
|
||||
// Call LLM for summarization
|
||||
response, err := s.chatModel.Chat(ctx, summaryPrompt, &chat.ChatOptions{
|
||||
Temperature: 0.3, // Lower temperature for more consistent summaries
|
||||
MaxTokens: 500, // Limit summary length
|
||||
})
|
||||
if err != nil {
|
||||
return "", fmt.Errorf("failed to generate summary: %w", err)
|
||||
}
|
||||
|
||||
if response == nil || response.Content == "" {
|
||||
return "", fmt.Errorf("no summary generated")
|
||||
}
|
||||
|
||||
summary := response.Content
|
||||
logger.Debugf(ctx, "[SmartCompression] Generated summary (%d chars) from %d messages",
|
||||
len(summary), len(messages))
|
||||
|
||||
return summary, nil
|
||||
}
|
||||
|
||||
// EstimateTokens estimates token count using CJK-aware heuristics.
|
||||
func (s *smartCompressionStrategy) EstimateTokens(messages []chat.Message) int {
|
||||
return estimateMessageTokens(messages)
|
||||
}
|
||||
@@ -3,205 +3,215 @@ package llmcontext
|
||||
import (
|
||||
"context"
|
||||
"fmt"
|
||||
"regexp"
|
||||
"sort"
|
||||
"strings"
|
||||
"time"
|
||||
|
||||
"github.com/Tencent/WeKnora/internal/logger"
|
||||
"github.com/Tencent/WeKnora/internal/models/chat"
|
||||
"github.com/Tencent/WeKnora/internal/types"
|
||||
"github.com/Tencent/WeKnora/internal/types/interfaces"
|
||||
)
|
||||
|
||||
// contextManager implements the ContextManager interface
|
||||
// It handles business logic (compression, token management) and delegates storage to ContextStorage
|
||||
// dbFallbackFetchCount is the number of raw DB messages to fetch when
|
||||
// rebuilding context from persistent storage. This should be generous
|
||||
// because user+assistant messages are paired by RequestID and some
|
||||
// incomplete pairs are discarded.
|
||||
const dbFallbackFetchCount = 200
|
||||
|
||||
var regThinkTags = regexp.MustCompile(`(?s)<think>.*?</think>`)
|
||||
|
||||
// contextManager implements the ContextManager interface.
|
||||
// It is a cache-backed storage layer: messages are persisted per session in
|
||||
// a fast store (Redis / memory). When the cache is empty (e.g. TTL expired),
|
||||
// it falls back to the persistent messages table via MessageService to
|
||||
// rebuild context.
|
||||
//
|
||||
// All LLM-aware compression (summarisation, tool-boundary-aware truncation)
|
||||
// is handled by the Agent Engine's Consolidator before messages are sent to
|
||||
// the model.
|
||||
type contextManager struct {
|
||||
storage ContextStorage // Storage backend (Redis, Memory, etc.)
|
||||
compressionStrategy interfaces.CompressionStrategy // Compression strategy
|
||||
maxTokens int // Maximum tokens allowed in context
|
||||
storage ContextStorage
|
||||
messageRepo interfaces.MessageRepository // optional; enables DB fallback
|
||||
}
|
||||
|
||||
// NewContextManager creates a new context manager with the specified storage and compression strategy
|
||||
func NewContextManager(
|
||||
storage ContextStorage,
|
||||
compressionStrategy interfaces.CompressionStrategy,
|
||||
maxTokens int,
|
||||
) interfaces.ContextManager {
|
||||
// NewContextManager creates a context manager.
|
||||
// messageRepo is optional — when provided, GetContext will reconstruct
|
||||
// history from the DB if the cache is empty.
|
||||
func NewContextManager(storage ContextStorage, messageRepo interfaces.MessageRepository) interfaces.ContextManager {
|
||||
return &contextManager{
|
||||
storage: storage,
|
||||
compressionStrategy: compressionStrategy,
|
||||
maxTokens: maxTokens,
|
||||
storage: storage,
|
||||
messageRepo: messageRepo,
|
||||
}
|
||||
}
|
||||
|
||||
// NewContextManagerWithMemory creates a context manager with in-memory storage (for backward compatibility)
|
||||
func NewContextManagerWithMemory(
|
||||
compressionStrategy interfaces.CompressionStrategy,
|
||||
maxTokens int,
|
||||
) interfaces.ContextManager {
|
||||
return &contextManager{
|
||||
storage: NewMemoryStorage(),
|
||||
compressionStrategy: compressionStrategy,
|
||||
maxTokens: maxTokens,
|
||||
}
|
||||
}
|
||||
|
||||
// AddMessage adds a message to the session context
|
||||
// This method handles the business logic: loading, appending, compression, and saving
|
||||
// AddMessage appends a message to the session context and persists it.
|
||||
func (cm *contextManager) AddMessage(ctx context.Context, sessionID string, message chat.Message) error {
|
||||
logger.Infof(ctx, "[ContextManager][Session-%s] Adding message: role=%s, content_length=%d",
|
||||
sessionID, message.Role, len(message.Content))
|
||||
|
||||
// Log message content preview
|
||||
contentPreview := message.Content
|
||||
if len(contentPreview) > 200 {
|
||||
contentPreview = contentPreview[:200] + "..."
|
||||
}
|
||||
logger.Debugf(ctx, "[ContextManager][Session-%s] Message content preview: %s", sessionID, contentPreview)
|
||||
|
||||
// Load existing messages from storage
|
||||
messages, err := cm.storage.Load(ctx, sessionID)
|
||||
if err != nil {
|
||||
logger.Errorf(ctx, "[ContextManager][Session-%s] Failed to load context: %v", sessionID, err)
|
||||
return fmt.Errorf("failed to load context: %w", err)
|
||||
}
|
||||
|
||||
// Add new message
|
||||
beforeCount := len(messages)
|
||||
messages = append(messages, message)
|
||||
logger.Debugf(ctx, "[ContextManager][Session-%s] Messages count: %d -> %d", sessionID, beforeCount, len(messages))
|
||||
|
||||
// Check if compression is needed
|
||||
tokenCount := cm.compressionStrategy.EstimateTokens(messages)
|
||||
logger.Debugf(ctx, "[ContextManager][Session-%s] Current token count: %d (max: %d)",
|
||||
sessionID, tokenCount, cm.maxTokens)
|
||||
|
||||
if tokenCount > cm.maxTokens {
|
||||
logger.Infof(ctx, "[ContextManager][Session-%s] Context exceeds max tokens (%d > %d), applying compression",
|
||||
sessionID, tokenCount, cm.maxTokens)
|
||||
beforeCompressionCount := len(messages)
|
||||
compressed, err := cm.compressionStrategy.Compress(ctx, messages, cm.maxTokens)
|
||||
if err != nil {
|
||||
logger.Errorf(ctx, "[ContextManager][Session-%s] Failed to compress context: %v", sessionID, err)
|
||||
return fmt.Errorf("failed to compress context: %w", err)
|
||||
}
|
||||
messages = compressed
|
||||
afterTokenCount := cm.compressionStrategy.EstimateTokens(messages)
|
||||
logger.Infof(ctx, "[ContextManager][Session-%s] Context compressed: %d -> %d messages, %d -> %d tokens",
|
||||
sessionID, beforeCompressionCount, len(compressed), tokenCount, afterTokenCount)
|
||||
}
|
||||
|
||||
// Save updated messages to storage
|
||||
if err := cm.storage.Save(ctx, sessionID, messages); err != nil {
|
||||
logger.Errorf(ctx, "[ContextManager][Session-%s] Failed to save context: %v", sessionID, err)
|
||||
return fmt.Errorf("failed to save context: %w", err)
|
||||
}
|
||||
|
||||
logger.Infof(
|
||||
ctx,
|
||||
"[ContextManager][Session-%s] Successfully added message (total: %d messages)",
|
||||
sessionID,
|
||||
len(messages),
|
||||
)
|
||||
logger.Debugf(ctx, "[ContextManager][Session-%s] Message saved (total: %d)", sessionID, len(messages))
|
||||
return nil
|
||||
}
|
||||
|
||||
// GetContext retrieves the current context for a session from storage
|
||||
// GetContext retrieves the stored context for a session.
|
||||
// If the cache is empty and a MessageService is available, it rebuilds
|
||||
// the context from the persistent messages table and warms the cache.
|
||||
func (cm *contextManager) GetContext(ctx context.Context, sessionID string) ([]chat.Message, error) {
|
||||
logger.Infof(ctx, "[ContextManager][Session-%s] Getting context", sessionID)
|
||||
|
||||
// Load messages from storage
|
||||
messages, err := cm.storage.Load(ctx, sessionID)
|
||||
if err != nil {
|
||||
logger.Errorf(ctx, "[ContextManager][Session-%s] Failed to load context: %v", sessionID, err)
|
||||
return nil, fmt.Errorf("failed to load context: %w", err)
|
||||
}
|
||||
|
||||
// Calculate token estimate
|
||||
tokenCount := cm.compressionStrategy.EstimateTokens(messages)
|
||||
|
||||
logger.Infof(ctx, "[ContextManager][Session-%s] Retrieved %d messages (~%d tokens)",
|
||||
sessionID, len(messages), tokenCount)
|
||||
|
||||
// Log message role distribution
|
||||
roleCount := make(map[string]int)
|
||||
for _, msg := range messages {
|
||||
roleCount[msg.Role]++
|
||||
if len(messages) > 0 {
|
||||
logger.Debugf(ctx, "[ContextManager][Session-%s] Cache hit: %d messages", sessionID, len(messages))
|
||||
return messages, nil
|
||||
}
|
||||
logger.Debugf(ctx, "[ContextManager][Session-%s] Message distribution: %v", sessionID, roleCount)
|
||||
|
||||
return messages, nil
|
||||
if cm.messageRepo == nil {
|
||||
return messages, nil
|
||||
}
|
||||
|
||||
// Cache miss — rebuild from DB
|
||||
rebuilt, err := cm.rebuildFromDB(ctx, sessionID)
|
||||
if err != nil {
|
||||
logger.Warnf(ctx, "[ContextManager][Session-%s] Failed to rebuild context from DB: %v", sessionID, err)
|
||||
return []chat.Message{}, nil
|
||||
}
|
||||
|
||||
if len(rebuilt) > 0 {
|
||||
if saveErr := cm.storage.Save(ctx, sessionID, rebuilt); saveErr != nil {
|
||||
logger.Warnf(ctx, "[ContextManager][Session-%s] Failed to warm cache: %v", sessionID, saveErr)
|
||||
}
|
||||
logger.Infof(ctx, "[ContextManager][Session-%s] Rebuilt %d messages from DB", sessionID, len(rebuilt))
|
||||
}
|
||||
|
||||
return rebuilt, nil
|
||||
}
|
||||
|
||||
// ClearContext clears all context for a session from storage
|
||||
func (cm *contextManager) ClearContext(ctx context.Context, sessionID string) error {
|
||||
logger.Infof(ctx, "[ContextManager][Session-%s] Clearing context", sessionID)
|
||||
// rebuildFromDB loads recent messages from the persistent messages table
|
||||
// and converts them into chat.Message pairs (user + assistant).
|
||||
func (cm *contextManager) rebuildFromDB(ctx context.Context, sessionID string) ([]chat.Message, error) {
|
||||
dbMessages, err := cm.messageRepo.GetRecentMessagesBySession(ctx, sessionID, dbFallbackFetchCount)
|
||||
if err != nil {
|
||||
return nil, fmt.Errorf("failed to load messages from DB: %w", err)
|
||||
}
|
||||
if len(dbMessages) == 0 {
|
||||
return nil, nil
|
||||
}
|
||||
|
||||
// Delete from storage
|
||||
// Group by RequestID into Q&A pairs, same logic as chat_pipeline/common.go
|
||||
type pair struct {
|
||||
query string
|
||||
answer string
|
||||
createdAt time.Time
|
||||
}
|
||||
pairMap := make(map[string]*pair)
|
||||
for _, msg := range dbMessages {
|
||||
p, ok := pairMap[msg.RequestID]
|
||||
if !ok {
|
||||
p = &pair{}
|
||||
pairMap[msg.RequestID] = p
|
||||
}
|
||||
switch msg.Role {
|
||||
case "user":
|
||||
p.query = msg.Content
|
||||
p.createdAt = msg.CreatedAt
|
||||
if desc := extractImageCaptions(msg.Images); desc != "" {
|
||||
p.query += "\n\n[用户上传图片内容]\n" + desc
|
||||
}
|
||||
case "assistant":
|
||||
p.answer = regThinkTags.ReplaceAllString(msg.Content, "")
|
||||
}
|
||||
}
|
||||
|
||||
pairs := make([]*pair, 0, len(pairMap))
|
||||
for _, p := range pairMap {
|
||||
if p.query != "" && p.answer != "" {
|
||||
pairs = append(pairs, p)
|
||||
}
|
||||
}
|
||||
|
||||
sort.Slice(pairs, func(i, j int) bool {
|
||||
return pairs[i].createdAt.Before(pairs[j].createdAt)
|
||||
})
|
||||
|
||||
result := make([]chat.Message, 0, len(pairs)*2)
|
||||
for _, p := range pairs {
|
||||
result = append(result,
|
||||
chat.Message{Role: "user", Content: p.query},
|
||||
chat.Message{Role: "assistant", Content: p.answer},
|
||||
)
|
||||
}
|
||||
|
||||
return result, nil
|
||||
}
|
||||
|
||||
// extractImageCaptions concatenates non-empty Caption fields from message
|
||||
// images so that previous turns' image descriptions are included in context.
|
||||
func extractImageCaptions(images types.MessageImages) string {
|
||||
var parts []string
|
||||
for _, img := range images {
|
||||
if img.Caption != "" {
|
||||
parts = append(parts, img.Caption)
|
||||
}
|
||||
}
|
||||
return strings.Join(parts, "\n")
|
||||
}
|
||||
|
||||
// ClearContext removes all context for a session.
|
||||
func (cm *contextManager) ClearContext(ctx context.Context, sessionID string) error {
|
||||
if err := cm.storage.Delete(ctx, sessionID); err != nil {
|
||||
logger.Errorf(ctx, "[ContextManager][Session-%s] Failed to clear context: %v", sessionID, err)
|
||||
return fmt.Errorf("failed to clear context: %w", err)
|
||||
}
|
||||
|
||||
logger.Infof(ctx, "[ContextManager][Session-%s] Context cleared successfully", sessionID)
|
||||
logger.Infof(ctx, "[ContextManager][Session-%s] Context cleared", sessionID)
|
||||
return nil
|
||||
}
|
||||
|
||||
// GetContextStats returns statistics about the context
|
||||
// GetContextStats returns statistics about the stored context.
|
||||
func (cm *contextManager) GetContextStats(ctx context.Context, sessionID string) (*interfaces.ContextStats, error) {
|
||||
// Load messages from storage
|
||||
messages, err := cm.storage.Load(ctx, sessionID)
|
||||
if err != nil {
|
||||
logger.Errorf(ctx, "[ContextManager][Session-%s] Failed to load context for stats: %v", sessionID, err)
|
||||
return nil, fmt.Errorf("failed to load context: %w", err)
|
||||
}
|
||||
|
||||
tokenCount := cm.compressionStrategy.EstimateTokens(messages)
|
||||
|
||||
stats := &interfaces.ContextStats{
|
||||
return &interfaces.ContextStats{
|
||||
MessageCount: len(messages),
|
||||
TokenCount: tokenCount,
|
||||
IsCompressed: false, // We'd need to track this explicitly for accurate reporting
|
||||
OriginalMessageCount: len(messages),
|
||||
}
|
||||
|
||||
logger.Debugf(ctx, "[ContextManager][Session-%s] Context stats: %d messages, ~%d tokens",
|
||||
sessionID, stats.MessageCount, stats.TokenCount)
|
||||
|
||||
return stats, nil
|
||||
}, nil
|
||||
}
|
||||
|
||||
// SetSystemPrompt sets or updates the system prompt for a session
|
||||
// If a system message exists, it will be replaced; otherwise, a new one will be added at the beginning
|
||||
// SetSystemPrompt sets or updates the system prompt for a session.
|
||||
func (cm *contextManager) SetSystemPrompt(ctx context.Context, sessionID string, systemPrompt string) error {
|
||||
logger.Infof(ctx, "[ContextManager][Session-%s] Setting system prompt, length=%d", sessionID, len(systemPrompt))
|
||||
|
||||
// Load existing messages from storage
|
||||
messages, err := cm.storage.Load(ctx, sessionID)
|
||||
if err != nil {
|
||||
logger.Errorf(ctx, "[ContextManager][Session-%s] Failed to load context: %v", sessionID, err)
|
||||
return fmt.Errorf("failed to load context: %w", err)
|
||||
}
|
||||
|
||||
// Create new system message
|
||||
systemMessage := chat.Message{
|
||||
Role: "system",
|
||||
Content: systemPrompt,
|
||||
}
|
||||
|
||||
// Check if first message is a system message
|
||||
if len(messages) > 0 && messages[0].Role == "system" {
|
||||
// Replace existing system message
|
||||
logger.Debugf(ctx, "[ContextManager][Session-%s] Replacing existing system prompt", sessionID)
|
||||
messages[0] = systemMessage
|
||||
} else {
|
||||
// Insert system message at the beginning
|
||||
logger.Debugf(ctx, "[ContextManager][Session-%s] Inserting new system prompt at beginning", sessionID)
|
||||
messages = append([]chat.Message{systemMessage}, messages...)
|
||||
}
|
||||
|
||||
// Save updated messages to storage
|
||||
if err := cm.storage.Save(ctx, sessionID, messages); err != nil {
|
||||
logger.Errorf(ctx, "[ContextManager][Session-%s] Failed to save context: %v", sessionID, err)
|
||||
return fmt.Errorf("failed to save context: %w", err)
|
||||
}
|
||||
|
||||
logger.Infof(ctx, "[ContextManager][Session-%s] System prompt set successfully", sessionID)
|
||||
logger.Debugf(ctx, "[ContextManager][Session-%s] System prompt set (length=%d)", sessionID, len(systemPrompt))
|
||||
return nil
|
||||
}
|
||||
|
||||
@@ -1,78 +1,18 @@
|
||||
package llmcontext
|
||||
|
||||
import (
|
||||
"context"
|
||||
|
||||
"github.com/Tencent/WeKnora/internal/logger"
|
||||
"github.com/Tencent/WeKnora/internal/models/chat"
|
||||
"github.com/Tencent/WeKnora/internal/types"
|
||||
"github.com/Tencent/WeKnora/internal/types/interfaces"
|
||||
)
|
||||
|
||||
const (
|
||||
// Context manager types
|
||||
ContextManagerTypeMemory = "memory"
|
||||
ContextManagerTypeRedis = "redis"
|
||||
|
||||
// Default values
|
||||
DefaultMaxTokens = 128 * 1024 // 128K tokens
|
||||
DefaultRecentMessageCount = 20
|
||||
DefaultSummarizeThreshold = 5
|
||||
DefaultCompressionStrategy = "sliding_window"
|
||||
)
|
||||
|
||||
// NewContextManagerFromConfig creates a ContextManager based on configuration
|
||||
// NewContextManagerFromConfig creates a ContextManager.
|
||||
// messageRepo is optional — when provided, context will be rebuilt from
|
||||
// the persistent messages table if the cache (Redis/memory) is empty.
|
||||
func NewContextManagerFromConfig(
|
||||
contextCfg *types.ContextConfig,
|
||||
storage ContextStorage,
|
||||
chatModel chat.Chat,
|
||||
messageRepo interfaces.MessageRepository,
|
||||
) interfaces.ContextManager {
|
||||
// Use default values if config is nil
|
||||
if contextCfg == nil {
|
||||
logger.Info(context.TODO(), "ContextManager config not found, using default memory-based context manager")
|
||||
strategy := NewSlidingWindowStrategy(DefaultRecentMessageCount)
|
||||
storage := NewMemoryStorage()
|
||||
return NewContextManager(storage, strategy, DefaultMaxTokens)
|
||||
if storage == nil {
|
||||
storage = NewMemoryStorage()
|
||||
}
|
||||
|
||||
// Set default values if not specified
|
||||
maxTokens := contextCfg.MaxTokens
|
||||
if maxTokens == 0 {
|
||||
maxTokens = DefaultMaxTokens
|
||||
}
|
||||
|
||||
recentMessageCount := contextCfg.RecentMessageCount
|
||||
if recentMessageCount == 0 {
|
||||
recentMessageCount = DefaultRecentMessageCount
|
||||
}
|
||||
|
||||
summarizeThreshold := contextCfg.SummarizeThreshold
|
||||
if summarizeThreshold == 0 {
|
||||
summarizeThreshold = DefaultSummarizeThreshold
|
||||
}
|
||||
|
||||
compressionStrategy := contextCfg.CompressionStrategy
|
||||
if compressionStrategy == "" {
|
||||
compressionStrategy = DefaultCompressionStrategy
|
||||
}
|
||||
|
||||
// Create compression strategy
|
||||
var strategy interfaces.CompressionStrategy
|
||||
switch compressionStrategy {
|
||||
case "sliding_window":
|
||||
strategy = NewSlidingWindowStrategy(recentMessageCount)
|
||||
case "smart":
|
||||
if chatModel != nil {
|
||||
strategy = NewSmartCompressionStrategy(recentMessageCount, chatModel, summarizeThreshold)
|
||||
} else {
|
||||
logger.Warn(context.TODO(), "Smart compression requested but no chat model provided, falling back to sliding window")
|
||||
strategy = NewSlidingWindowStrategy(recentMessageCount)
|
||||
}
|
||||
default:
|
||||
logger.Warnf(context.TODO(), "Unknown compression strategy '%s', using sliding window", compressionStrategy)
|
||||
strategy = NewSlidingWindowStrategy(recentMessageCount)
|
||||
}
|
||||
|
||||
// Create context manager with storage and strategy
|
||||
return NewContextManager(storage, strategy, maxTokens)
|
||||
return NewContextManager(storage, messageRepo)
|
||||
}
|
||||
|
||||
@@ -111,7 +111,7 @@ func (s *sessionService) AgentQA(
|
||||
}
|
||||
|
||||
// Get or create contextManager for this session
|
||||
contextManager := s.getContextManagerForSession(ctx, req.Session, summaryModel)
|
||||
contextManager := s.getContextManagerForSession()
|
||||
|
||||
// Set system prompt for the current agent in context manager
|
||||
// This ensures the context uses the correct system prompt when switching agents
|
||||
@@ -265,6 +265,10 @@ func (s *sessionService) buildAgentConfig(
|
||||
agentConfig.SearchTargets = searchTargets
|
||||
logger.Infof(ctx, "Agent search targets built: %d targets", len(searchTargets))
|
||||
|
||||
if agentConfig.MaxContextTokens <= 0 {
|
||||
agentConfig.MaxContextTokens = types.DefaultMaxContextTokens
|
||||
}
|
||||
|
||||
return agentConfig, nil
|
||||
}
|
||||
|
||||
@@ -322,32 +326,9 @@ func (s *sessionService) configureSkillsFromAgent(
|
||||
|
||||
}
|
||||
|
||||
// getContextManagerForSession creates a context manager for the session based on configuration
|
||||
// Returns the configured context manager (tenant-level or session-level) or default
|
||||
func (s *sessionService) getContextManagerForSession(
|
||||
ctx context.Context,
|
||||
session *types.Session,
|
||||
chatModel chat.Chat,
|
||||
) interfaces.ContextManager {
|
||||
// Get tenant to access global context configuration
|
||||
tenant, _ := types.TenantInfoFromContext(ctx)
|
||||
// Determine which context config to use: tenant-level or default
|
||||
var contextConfig *types.ContextConfig
|
||||
if tenant != nil && tenant.ContextConfig != nil {
|
||||
// Use tenant-level configuration
|
||||
contextConfig = tenant.ContextConfig
|
||||
logger.Infof(ctx, "Using tenant-level context config for session %s", session.ID)
|
||||
} else {
|
||||
// Use service's default context manager
|
||||
logger.Debugf(ctx, "Using default context manager for session %s", session.ID)
|
||||
contextConfig = &types.ContextConfig{
|
||||
MaxTokens: llmcontext.DefaultMaxTokens,
|
||||
CompressionStrategy: llmcontext.DefaultCompressionStrategy,
|
||||
RecentMessageCount: llmcontext.DefaultRecentMessageCount,
|
||||
SummarizeThreshold: llmcontext.DefaultSummarizeThreshold,
|
||||
}
|
||||
}
|
||||
return llmcontext.NewContextManagerFromConfig(contextConfig, s.sessionStorage, chatModel)
|
||||
// getContextManagerForSession creates a context manager for the session.
|
||||
func (s *sessionService) getContextManagerForSession() interfaces.ContextManager {
|
||||
return llmcontext.NewContextManagerFromConfig(s.sessionStorage, s.messageRepo)
|
||||
}
|
||||
|
||||
// getContextForSession retrieves LLM context for a session
|
||||
|
||||
@@ -7,6 +7,9 @@ import (
|
||||
"time"
|
||||
)
|
||||
|
||||
// DefaultMaxContextTokens is the default context window budget for agent conversations (200k).
|
||||
const DefaultMaxContextTokens = 200000
|
||||
|
||||
// AgentConfig represents the full agent configuration (used at tenant level and runtime)
|
||||
// This includes all configuration parameters for agent execution
|
||||
type AgentConfig struct {
|
||||
@@ -47,8 +50,8 @@ type AgentConfig struct {
|
||||
// Outputs exceeding this limit are truncated with head + tail preservation.
|
||||
MaxToolOutputChars int `json:"max_tool_output_chars,omitempty"`
|
||||
|
||||
// Maximum context window tokens for the agent (default: 0 = disabled).
|
||||
// When set, the agent compresses older messages to stay within this limit,
|
||||
// Maximum context window tokens for the agent (default: 200000).
|
||||
// The agent compresses older messages to stay within this limit,
|
||||
// preserving tool_call/tool_result pairs.
|
||||
MaxContextTokens int `json:"max_context_tokens,omitempty"`
|
||||
}
|
||||
|
||||
@@ -42,12 +42,3 @@ type ContextStats struct {
|
||||
OriginalMessageCount int `json:"original_message_count"`
|
||||
}
|
||||
|
||||
// CompressionStrategy defines how context should be compressed
|
||||
type CompressionStrategy interface {
|
||||
// Compress compresses messages when context exceeds limits
|
||||
// Returns compressed messages that fit within the limit
|
||||
Compress(ctx context.Context, messages []chat.Message, maxTokens int) ([]chat.Message, error)
|
||||
|
||||
// EstimateTokens estimates token count for messages
|
||||
EstimateTokens(messages []chat.Message) int
|
||||
}
|
||||
|
||||
Reference in New Issue
Block a user