diff --git a/cli/cmd/agent/invoke.go b/cli/cmd/agent/invoke.go index 692b4730..b1fceda0 100644 --- a/cli/cmd/agent/invoke.go +++ b/cli/cmd/agent/invoke.go @@ -13,14 +13,15 @@ import ( "github.com/Tencent/WeKnora/cli/internal/cmdutil" "github.com/Tencent/WeKnora/cli/internal/format" "github.com/Tencent/WeKnora/cli/internal/iostreams" + "github.com/Tencent/WeKnora/cli/internal/sse" sdk "github.com/Tencent/WeKnora/client" ) // agentInvokeFields enumerates fields surfaced for `--json` discovery on // `agent invoke`. Matches invokeData below — single-shot result envelope -// with the agent's final answer plus the trace (references, tool calls). +// with the agent's final answer plus the trace (references, tool events). var agentInvokeFields = []string{ - "answer", "references", "tool_calls", "thinking", + "answer", "references", "tool_events", "thinking", "session_id", "agent_id", "query", } @@ -45,25 +46,15 @@ type InvokeService interface { AgentQAStreamWithRequest(ctx context.Context, sessionID string, req *sdk.AgentQARequest, cb sdk.AgentEventCallback) error } -// toolCallTrace mirrors a single SSE `tool_call` event so agents can see -// which tools the WeKnora-side agent invoked (and their results). Only the -// fields the server actually emits are captured. -type toolCallTrace struct { - ID string `json:"id"` - Name string `json:"name,omitempty"` - Result string `json:"result,omitempty"` - Data map[string]any `json:"data,omitempty"` -} - // invokeData is the JSON envelope payload. type invokeData struct { - Answer string `json:"answer"` - References []*sdk.SearchResult `json:"references"` - ToolCalls []toolCallTrace `json:"tool_calls,omitempty"` - Thinking string `json:"thinking,omitempty"` - SessionID string `json:"session_id"` - AgentID string `json:"agent_id"` - Query string `json:"query"` + Answer string `json:"answer"` + References []*sdk.SearchResult `json:"references"` + ToolEvents []sse.AgentToolEvent `json:"tool_events,omitempty"` + Thinking string `json:"thinking,omitempty"` + SessionID string `json:"session_id"` + AgentID string `json:"agent_id"` + Query string `json:"query"` } // NewCmdInvoke builds `weknora agent invoke ""`. @@ -152,12 +143,12 @@ func runInvoke(ctx context.Context, opts *InvokeOptions, jopts *cmdutil.JSONOpti Channel: "api", } - acc := newAgentAccumulator() + acc := &sse.AgentAccumulator{} cb := func(r *sdk.AgentStreamResponse) error { if streamMode && r != nil && r.ResponseType == sdk.AgentResponseTypeAnswer && r.Content != "" { _, _ = iostreams.IO.Out.Write([]byte(r.Content)) } - acc.append(r) + acc.Append(r) return nil } @@ -169,7 +160,7 @@ func runInvoke(ctx context.Context, opts *InvokeOptions, jopts *cmdutil.JSONOpti if errors.Is(streamErr, context.Canceled) || errors.Is(ctx.Err(), context.Canceled) { return cmdutil.Wrapf(cmdutil.CodeUserAborted, streamErr, "agent invoke cancelled") } - if acc.answer.Len() > 0 && !acc.done { + if acc.Answer() != "" && !acc.Done() { return cmdutil.Wrapf(cmdutil.CodeSSEStreamAborted, streamErr, "stream aborted before completion") } return cmdutil.WrapHTTP(streamErr, "agent-chat stream") @@ -177,17 +168,17 @@ func runInvoke(ctx context.Context, opts *InvokeOptions, jopts *cmdutil.JSONOpti // Server closed cleanly but never sent a Done event — treat as aborted // so agents don't silently emit a truncated answer as ok=true. - if !acc.done { + if !acc.Done() { return cmdutil.NewError(cmdutil.CodeSSEStreamAborted, "stream ended without a terminal event") } - answer := acc.answer.String() + answer := acc.Answer() if jsonOut { data := invokeData{ Answer: answer, - References: acc.references, - ToolCalls: acc.toolCalls, - Thinking: acc.thinking.String(), + References: acc.References, + ToolEvents: acc.ToolEvents, + Thinking: acc.Thinking(), SessionID: sessionID, AgentID: opts.AgentID, Query: opts.Query, @@ -210,98 +201,24 @@ func runInvoke(ctx context.Context, opts *InvokeOptions, jopts *cmdutil.JSONOpti fmt.Fprintln(out) } } - renderToolTrace(out, acc.toolCalls) - renderReferences(out, acc.references) + renderToolTrace(out, acc.ToolEvents) + format.WriteReferences(out, acc.References) return nil } -// agentAccumulator buffers an AgentQAStream callback sequence. Distinct -// from internal/sse.Accumulator because the agent event model is wider -// (thinking / tool_call / tool_result / answer / references / reflection / -// error) and uses a flat r.Done bool instead of the -// ResponseType=complete sentinel that KnowledgeQAStream emits. -type agentAccumulator struct { - answer strings.Builder - thinking strings.Builder - references []*sdk.SearchResult - toolCalls []toolCallTrace - done bool -} - -func newAgentAccumulator() *agentAccumulator { return &agentAccumulator{} } - -func (a *agentAccumulator) append(r *sdk.AgentStreamResponse) { - if r == nil || a.done { - return - } - switch r.ResponseType { - case sdk.AgentResponseTypeAnswer: - if r.Content != "" { - a.answer.WriteString(r.Content) - } - case sdk.AgentResponseTypeThinking, sdk.AgentResponseTypeReflection: - if r.Content != "" { - a.thinking.WriteString(r.Content) - } - case sdk.AgentResponseTypeReferences: - if r.KnowledgeReferences != nil { - a.references = r.KnowledgeReferences - } - case sdk.AgentResponseTypeToolCall, sdk.AgentResponseTypeToolResult: - a.toolCalls = append(a.toolCalls, toolCallTrace{ - ID: r.ID, - Name: string(r.ResponseType), - Result: r.Content, - Data: r.Data, - }) - } - // References can also arrive on the terminal frame. - if r.KnowledgeReferences != nil && a.references == nil { - a.references = r.KnowledgeReferences - } - if r.Done { - a.done = true - } -} - -// renderToolTrace prints a compact tool-call footer in human mode. Skipped -// when the agent invoked no tools — silent beats an empty banner. -func renderToolTrace(w io.Writer, calls []toolCallTrace) { - if len(calls) == 0 { +// renderToolTrace prints a compact tool-event footer in human mode. +// Skipped when the agent emitted no tool events — silent beats an empty +// banner. +func renderToolTrace(w io.Writer, events []sse.AgentToolEvent) { + if len(events) == 0 { return } fmt.Fprintln(w) fmt.Fprintln(w, "──── Tool trace ────") - for i, c := range calls { - fmt.Fprintf(w, "[%d] %s", i+1, c.Name) - if c.Result != "" { - fmt.Fprintf(w, " %s", truncateInline(c.Result, 80)) - } - fmt.Fprintln(w) - } -} - -// renderReferences mirrors chat.go's references footer for parity. -func renderReferences(w io.Writer, refs []*sdk.SearchResult) { - if len(refs) == 0 { - return - } - fmt.Fprintln(w) - fmt.Fprintln(w, "──── References ────") - for i, r := range refs { - if r == nil { - continue - } - title := r.KnowledgeTitle - if title == "" { - title = r.KnowledgeFilename - } - if title == "" { - title = r.KnowledgeID - } - fmt.Fprintf(w, "[%d] %s", i+1, title) - if r.Score > 0 { - fmt.Fprintf(w, " score=%.3f", r.Score) + for i, e := range events { + fmt.Fprintf(w, "[%d] %s", i+1, e.Kind) + if e.Result != "" { + fmt.Fprintf(w, " %s", truncateInline(e.Result, 80)) } fmt.Fprintln(w) } diff --git a/cli/cmd/agent/invoke_test.go b/cli/cmd/agent/invoke_test.go index c6250b63..8391900f 100644 --- a/cli/cmd/agent/invoke_test.go +++ b/cli/cmd/agent/invoke_test.go @@ -150,7 +150,7 @@ func (c *createSessionTracker) CreateSession(ctx context.Context, req *sdk.Creat return c.InvokeService.CreateSession(ctx, req) } -func TestInvoke_ToolCallsCaptured(t *testing.T) { +func TestInvoke_ToolEventsCaptured(t *testing.T) { out, _ := iostreams.SetForTest(t) svc := &scriptedInvokeSvc{events: []*sdk.AgentStreamResponse{ toolCallEvent("call_1", "knowledge_search"), @@ -167,11 +167,11 @@ func TestInvoke_ToolCallsCaptured(t *testing.T) { if err := json.Unmarshal(out.Bytes(), &env); err != nil { t.Fatalf("parse: %v", err) } - if len(env.Data.ToolCalls) != 1 { - t.Fatalf("expected 1 tool call, got %d", len(env.Data.ToolCalls)) + if len(env.Data.ToolEvents) != 1 { + t.Fatalf("expected 1 tool call, got %d", len(env.Data.ToolEvents)) } - if env.Data.ToolCalls[0].ID != "call_1" { - t.Errorf("tool_calls[0].id = %q, want call_1", env.Data.ToolCalls[0].ID) + if env.Data.ToolEvents[0].ID != "call_1" { + t.Errorf("tool_calls[0].id = %q, want call_1", env.Data.ToolEvents[0].ID) } } diff --git a/cli/cmd/chat/chat.go b/cli/cmd/chat/chat.go index 4e336595..61690760 100644 --- a/cli/cmd/chat/chat.go +++ b/cli/cmd/chat/chat.go @@ -23,7 +23,6 @@ import ( "context" "errors" "fmt" - "io" "strings" "github.com/spf13/cobra" @@ -258,37 +257,9 @@ func runChat(ctx context.Context, opts *Options, jopts *cmdutil.JSONOptions, svc fmt.Fprintln(out) } } - renderReferences(out, references) + format.WriteReferences(out, references) return nil } -// renderReferences prints a compact human-readable references block. -// Skipped entirely when the server returned no references — agent-friendly -// silence beats an empty banner. -func renderReferences(w io.Writer, refs []*sdk.SearchResult) { - if len(refs) == 0 { - return - } - fmt.Fprintln(w) - fmt.Fprintln(w, "──── References ────") - for i, r := range refs { - if r == nil { - continue - } - title := r.KnowledgeTitle - if title == "" { - title = r.KnowledgeFilename - } - if title == "" { - title = r.KnowledgeID - } - fmt.Fprintf(w, "[%d] %s", i+1, title) - if r.Score > 0 { - fmt.Fprintf(w, " score=%.3f", r.Score) - } - fmt.Fprintln(w) - } -} - // compile-time check: the production SDK client implements chatService. var _ chatService = (*sdk.Client)(nil) diff --git a/cli/cmd/mcp/mcp.go b/cli/cmd/mcp/mcp.go new file mode 100644 index 00000000..4f67e2b7 --- /dev/null +++ b/cli/cmd/mcp/mcp.go @@ -0,0 +1,40 @@ +// Package mcpcmd holds the `weknora mcp` command tree. +// +// MCP (Model Context Protocol; https://spec.modelcontextprotocol.io/) is the +// JSON-RPC 2.0 wire protocol agentic IDEs (Claude Code, Cursor, Continue, +// Zed) and runtimes (Anthropic Reference, Stripe MCP) use to call external +// tools. `weknora mcp serve` exposes a curated subset of the CLI's read +// surface as MCP tools so an IDE-side agent can list / view / search / chat +// against the user's active WeKnora context without shelling out to the CLI +// per call. +// +// Package name is `mcpcmd` to avoid shadowing `cli/internal/mcp` (the +// transport-and-handlers implementation). Same naming hygiene as +// `agentcmd` / `sessioncmd`. +package mcpcmd + +import ( + "github.com/spf13/cobra" + + "github.com/Tencent/WeKnora/cli/internal/cmdutil" +) + +// NewCmd builds the `weknora mcp` parent. Called from cli/cmd/root.go. +func NewCmd(f *cmdutil.Factory) *cobra.Command { + cmd := &cobra.Command{ + Use: "mcp", + Short: "Run weknora as a Model Context Protocol server", + Long: `Exposes weknora's read surface as MCP tools so agentic IDE clients +(Claude Code, Cursor, Continue, Zed) can call them over JSON-RPC. + +Initial tool surface is read-only and curated: kb_list / kb_view / +doc_list / doc_view / doc_download / search_chunks / chat / agent_list / +agent_invoke. Destructive verbs (create / delete / upload) are deliberately +excluded — the agent should ask the user before mutating; the CLI's +exit-10 protocol covers that path.`, + Args: cobra.NoArgs, + Run: func(c *cobra.Command, _ []string) { _ = c.Help() }, + } + cmd.AddCommand(NewCmdServe(f)) + return cmd +} diff --git a/cli/cmd/mcp/serve.go b/cli/cmd/mcp/serve.go new file mode 100644 index 00000000..54118e2e --- /dev/null +++ b/cli/cmd/mcp/serve.go @@ -0,0 +1,47 @@ +package mcpcmd + +import ( + "github.com/spf13/cobra" + + "github.com/Tencent/WeKnora/cli/internal/agent" + "github.com/Tencent/WeKnora/cli/internal/cmdutil" + mcpserver "github.com/Tencent/WeKnora/cli/internal/mcp" +) + +// NewCmdServe builds `weknora mcp serve`. Currently stdio-only; HTTP +// (streamable / SSE) is roadmap 5-7. +func NewCmdServe(f *cmdutil.Factory) *cobra.Command { + cmd := &cobra.Command{ + Use: "serve", + Short: "Run an MCP server over stdio", + Long: `Speaks JSON-RPC 2.0 on stdin/stdout to an MCP client. Logs go to +stderr; the data channel is reserved for protocol traffic. + +Authentication is inherited from the active context (or --context). The +server eagerly resolves the SDK client at startup — if no context is +configured, the process exits with auth.unauthenticated before any MCP +handshake. This way an IDE-side agent sees a clear failure mode rather +than a server that handshakes successfully then errors on every tool. + +To use with Claude Code, add to ~/.claude/mcp_servers.json: + + { + "weknora": { + "command": "weknora", + "args": ["mcp", "serve"] + } + }`, + Args: cobra.NoArgs, + RunE: func(c *cobra.Command, _ []string) error { + // Eagerly construct the SDK client. Surfaces auth / + // configuration problems before any MCP handshake. + cli, err := f.Client() + if err != nil { + return err + } + return mcpserver.RunStdio(c.Context(), cli) + }, + } + agent.SetAgentHelp(cmd, "Long-lived stdio MCP server. Reads JSON-RPC requests from stdin, writes responses to stdout, logs to stderr. Surfaces 9 read-only tools (kb_list/kb_view/doc_list/doc_view/doc_download/search_chunks/chat/agent_list/agent_invoke); destructive verbs are intentionally excluded. Auth is inherited from the active context — to switch, exit and re-launch with --context. Long tools (chat / agent_invoke) accumulate the LLM stream server-side and return a single CallToolResult — no MCP streaming-content extension, per spec 2025-06-18 (tools.mdx).") + return cmd +} diff --git a/cli/cmd/root.go b/cli/cmd/root.go index 65d0b847..7431060f 100644 --- a/cli/cmd/root.go +++ b/cli/cmd/root.go @@ -17,6 +17,7 @@ import ( "github.com/Tencent/WeKnora/cli/cmd/doctor" "github.com/Tencent/WeKnora/cli/cmd/kb" linkcmd "github.com/Tencent/WeKnora/cli/cmd/link" + mcpcmd "github.com/Tencent/WeKnora/cli/cmd/mcp" "github.com/Tencent/WeKnora/cli/cmd/search" sessioncmd "github.com/Tencent/WeKnora/cli/cmd/session" "github.com/Tencent/WeKnora/cli/internal/agent" @@ -177,6 +178,7 @@ hybrid searches against a WeKnora server from your shell or an AI agent.`, cmd.AddCommand(chatcmd.NewCmd(f)) cmd.AddCommand(sessioncmd.NewCmd(f)) cmd.AddCommand(agentcmd.NewCmd(f)) + cmd.AddCommand(mcpcmd.NewCmd(f)) return cmd } diff --git a/cli/cmd/session/list.go b/cli/cmd/session/list.go index b0106c0a..4b943783 100644 --- a/cli/cmd/session/list.go +++ b/cli/cmd/session/list.go @@ -119,7 +119,15 @@ func runList(ctx context.Context, opts *ListOptions, jopts *cmdutil.JSONOptions, } if jopts.Enabled() { - meta := &format.Meta{HasMore: opts.Page*opts.PageSize < total} + // When --since is active, has_more is meaningless: the server's + // total counts all sessions but the page is now client-side + // filtered. An agent walking pages would see has_more=true even + // when no later page can contain matching items. Drop the flag + // rather than mislead. The agent_help string documents this. + meta := &format.Meta{} + if since == 0 { + meta.HasMore = opts.Page*opts.PageSize < total + } return format.WriteEnvelopeFiltered( iostreams.IO.Out, format.Success(listResult{Items: items}, meta), diff --git a/cli/go.mod b/cli/go.mod index 1ad044d4..c1ef376c 100644 --- a/cli/go.mod +++ b/cli/go.mod @@ -1,8 +1,6 @@ module github.com/Tencent/WeKnora/cli -go 1.24.2 - -toolchain go1.24.4 +go 1.25.0 require ( github.com/Tencent/WeKnora/client v0.0.0-00010101000000-000000000000 @@ -10,6 +8,7 @@ require ( github.com/itchyny/gojq v0.12.19 github.com/mattn/go-isatty v0.0.22 github.com/mattn/go-runewidth v0.0.23 + github.com/modelcontextprotocol/go-sdk v1.6.0 github.com/spf13/cobra v1.10.2 github.com/spf13/pflag v1.0.10 github.com/stretchr/testify v1.11.1 @@ -36,6 +35,7 @@ require ( github.com/dustin/go-humanize v1.0.1 // indirect github.com/erikgeiser/coninput v0.0.0-20211004153227-1c3628e74d0f // indirect github.com/godbus/dbus/v5 v5.2.2 // indirect + github.com/google/jsonschema-go v0.4.3 // indirect github.com/inconshreveable/mousetrap v1.1.0 // indirect github.com/itchyny/timefmt-go v0.1.8 // indirect github.com/lucasb-eyer/go-colorful v1.2.0 // indirect @@ -46,9 +46,13 @@ require ( github.com/muesli/termenv v0.16.0 // indirect github.com/pmezard/go-difflib v1.0.0 // indirect github.com/rivo/uniseg v0.4.7 // indirect + github.com/segmentio/asm v1.1.3 // indirect + github.com/segmentio/encoding v0.5.4 // indirect github.com/xo/terminfo v0.0.0-20220910002029-abceb7e1c41e // indirect + github.com/yosida95/uritemplate/v3 v3.0.2 // indirect + golang.org/x/oauth2 v0.35.0 // indirect golang.org/x/sync v0.15.0 // indirect - golang.org/x/sys v0.38.0 // indirect + golang.org/x/sys v0.41.0 // indirect golang.org/x/text v0.23.0 // indirect ) diff --git a/cli/go.sum b/cli/go.sum index 44f05d34..d88b317d 100644 --- a/cli/go.sum +++ b/cli/go.sum @@ -53,6 +53,12 @@ github.com/erikgeiser/coninput v0.0.0-20211004153227-1c3628e74d0f h1:Y/CXytFA4m6 github.com/erikgeiser/coninput v0.0.0-20211004153227-1c3628e74d0f/go.mod h1:vw97MGsxSvLiUE2X8qFplwetxpGLQrlU1Q9AUEIzCaM= github.com/godbus/dbus/v5 v5.2.2 h1:TUR3TgtSVDmjiXOgAAyaZbYmIeP3DPkld3jgKGV8mXQ= github.com/godbus/dbus/v5 v5.2.2/go.mod h1:3AAv2+hPq5rdnr5txxxRwiGjPXamgoIHgz9FPBfOp3c= +github.com/golang-jwt/jwt/v5 v5.3.1 h1:kYf81DTWFe7t+1VvL7eS+jKFVWaUnK9cB1qbwn63YCY= +github.com/golang-jwt/jwt/v5 v5.3.1/go.mod h1:fxCRLWMO43lRc8nhHWY6LGqRcf+1gQWArsqaEUEa5bE= +github.com/google/go-cmp v0.7.0 h1:wk8382ETsv4JYUZwIsn6YpYiWiBsYLSJiTsyBybVuN8= +github.com/google/go-cmp v0.7.0/go.mod h1:pXiqmnSA92OHEEa9HXL2W4E7lf9JzCmGVUdgjX3N/iU= +github.com/google/jsonschema-go v0.4.3 h1:/DBOLZTfDow7pe2GmaJNhltueGTtDKICi8V8p+DQPd0= +github.com/google/jsonschema-go v0.4.3/go.mod h1:r5quNTdLOYEz95Ru18zA0ydNbBuYoo9tgaYcxEYhJVE= github.com/inconshreveable/mousetrap v1.1.0 h1:wN+x4NVGpMsO7ErUn/mUI3vEoE6Jt13X2s0bqwp9tc8= github.com/inconshreveable/mousetrap v1.1.0/go.mod h1:vpF70FUmC8bwa3OWnCshd2FqLfsEA9PFc4w1p2J65bw= github.com/itchyny/gojq v0.12.19 h1:ttXA0XCLEMoaLOz5lSeFOZ6u6Q3QxmG46vfgI4O0DEs= @@ -69,6 +75,8 @@ github.com/mattn/go-runewidth v0.0.23 h1:7ykA0T0jkPpzSvMS5i9uoNn2Xy3R383f9HDx3Ry github.com/mattn/go-runewidth v0.0.23/go.mod h1:XBkDxAl56ILZc9knddidhrOlY5R/pDhgLpndooCuJAs= github.com/mitchellh/hashstructure/v2 v2.0.2 h1:vGKWl0YJqUNxE8d+h8f6NJLcCJrgbhC4NcD46KavDd4= github.com/mitchellh/hashstructure/v2 v2.0.2/go.mod h1:MG3aRVU/N29oo/V/IhBX8GR/zz4kQkprJgF2EVszyDE= +github.com/modelcontextprotocol/go-sdk v1.6.0 h1:PPLS3kn7WtOEnR+Af4X5H96SG0qSab8R/ZQT/HkhPkY= +github.com/modelcontextprotocol/go-sdk v1.6.0/go.mod h1:kzm3kzFL1/+AziGOE0nUs3gvPoNxMCvkxokMkuFapXQ= github.com/muesli/ansi v0.0.0-20230316100256-276c6243b2f6 h1:ZK8zHtRHOkbHy6Mmr5D264iyp3TiX5OmNcI5cIARiQI= github.com/muesli/ansi v0.0.0-20230316100256-276c6243b2f6/go.mod h1:CJlz5H+gyd6CUWT45Oy4q24RdLyn7Md9Vj2/ldJBSIo= github.com/muesli/cancelreader v0.2.2 h1:3I4Kt4BQjOR54NavqnDogx/MIoWBFa0StPA8ELUXHmA= @@ -80,6 +88,10 @@ github.com/pmezard/go-difflib v1.0.0/go.mod h1:iKH77koFhYxTK1pcRnkKkqfTogsbg7gZN github.com/rivo/uniseg v0.4.7 h1:WUdvkW8uEhrYfLC4ZzdpI2ztxP1I582+49Oc5Mq64VQ= github.com/rivo/uniseg v0.4.7/go.mod h1:FN3SvrM+Zdj16jyLfmOkMNblXMcoc8DfTHruCPUcx88= github.com/russross/blackfriday/v2 v2.1.0/go.mod h1:+Rmxgy9KzJVeS9/2gXHxylqXiyQDYRxCVz55jmeOWTM= +github.com/segmentio/asm v1.1.3 h1:WM03sfUOENvvKexOLp+pCqgb/WDjsi7EK8gIsICtzhc= +github.com/segmentio/asm v1.1.3/go.mod h1:Ld3L4ZXGNcSLRg4JBsZ3//1+f/TjYl0Mzen/DQy1EJg= +github.com/segmentio/encoding v0.5.4 h1:OW1VRern8Nw6ITAtwSZ7Idrl3MXCFwXHPgqESYfvNt0= +github.com/segmentio/encoding v0.5.4/go.mod h1:HS1ZKa3kSN32ZHVZ7ZLPLXWvOVIiZtyJnO1gPH1sKt0= github.com/spf13/cobra v1.10.2 h1:DMTTonx5m65Ic0GOoRY2c16WCbHxOOw6xxezuLaBpcU= github.com/spf13/cobra v1.10.2/go.mod h1:7C1pvHqHw5A4vrJfjNwvOdzYu0Gml16OCs2GRiTUUS4= github.com/spf13/pflag v1.0.9/go.mod h1:McXfInJRrz4CZXVZOBLb0bTZqETkiAhM9Iw0y3An2Bg= @@ -91,18 +103,24 @@ github.com/stretchr/testify v1.11.1 h1:7s2iGBzp5EwR7/aIZr8ao5+dra3wiQyKjjFuvgVKu github.com/stretchr/testify v1.11.1/go.mod h1:wZwfW3scLgRK+23gO65QZefKpKQRnfz6sD981Nm4B6U= github.com/xo/terminfo v0.0.0-20220910002029-abceb7e1c41e h1:JVG44RsyaB9T2KIHavMF/ppJZNG9ZpyihvCd0w101no= github.com/xo/terminfo v0.0.0-20220910002029-abceb7e1c41e/go.mod h1:RbqR21r5mrJuqunuUZ/Dhy/avygyECGrLceyNeo4LiM= +github.com/yosida95/uritemplate/v3 v3.0.2 h1:Ed3Oyj9yrmi9087+NczuL5BwkIc4wvTb5zIM+UJPGz4= +github.com/yosida95/uritemplate/v3 v3.0.2/go.mod h1:ILOh0sOhIJR3+L/8afwt/kE++YT040gmv5BQTMR2HP4= github.com/zalando/go-keyring v0.2.8 h1:6sD/Ucpl7jNq10rM2pgqTs0sZ9V3qMrqfIIy5YPccHs= github.com/zalando/go-keyring v0.2.8/go.mod h1:tsMo+VpRq5NGyKfxoBVjCuMrG47yj8cmakZDO5QGii0= go.yaml.in/yaml/v3 v3.0.4/go.mod h1:DhzuOOF2ATzADvBadXxruRBLzYTpT36CKvDb3+aBEFg= golang.org/x/exp v0.0.0-20231006140011-7918f672742d h1:jtJma62tbqLibJ5sFQz8bKtEM8rJBtfilJ2qTU199MI= golang.org/x/exp v0.0.0-20231006140011-7918f672742d/go.mod h1:ldy0pHrwJyGW56pPQzzkH36rKxoZW1tw7ZJpeKx+hdo= +golang.org/x/oauth2 v0.35.0 h1:Mv2mzuHuZuY2+bkyWXIHMfhNdJAdwW3FuWeCPYN5GVQ= +golang.org/x/oauth2 v0.35.0/go.mod h1:lzm5WQJQwKZ3nwavOZ3IS5Aulzxi68dUSgRHujetwEA= golang.org/x/sync v0.15.0 h1:KWH3jNZsfyT6xfAfKiz6MRNmd46ByHDYaZ7KSkCtdW8= golang.org/x/sync v0.15.0/go.mod h1:1dzgHSNfp02xaA81J2MS99Qcpr2w7fw1gpm99rleRqA= golang.org/x/sys v0.0.0-20210809222454-d867a43fc93e/go.mod h1:oPkhp1MJrh7nUepCBck5+mAzfO9JrbApNNgaTdGDITg= -golang.org/x/sys v0.38.0 h1:3yZWxaJjBmCWXqhN1qh02AkOnCQ1poK6oF+a7xWL6Gc= -golang.org/x/sys v0.38.0/go.mod h1:OgkHotnGiDImocRcuBABYBEXf8A9a87e/uXjp9XT3ks= +golang.org/x/sys v0.41.0 h1:Ivj+2Cp/ylzLiEU89QhWblYnOE9zerudt9Ftecq2C6k= +golang.org/x/sys v0.41.0/go.mod h1:OgkHotnGiDImocRcuBABYBEXf8A9a87e/uXjp9XT3ks= golang.org/x/text v0.23.0 h1:D71I7dUrlY+VX0gQShAThNGHFxZ13dGLBHQLVl1mJlY= golang.org/x/text v0.23.0/go.mod h1:/BLNzu4aZCJ1+kcD0DNRotWKage4q2rGVAg4o22unh4= +golang.org/x/tools v0.42.0 h1:uNgphsn75Tdz5Ji2q36v/nsFSfR/9BRFvqhGBaJGd5k= +golang.org/x/tools v0.42.0/go.mod h1:Ma6lCIwGZvHK6XtgbswSoWroEkhugApmsXyrUmBhfr0= gopkg.in/check.v1 v0.0.0-20161208181325-20d25e280405 h1:yhCVgyC4o1eVCa2tZl7eS0r+SDo693bJlVdllGtEeKM= gopkg.in/check.v1 v0.0.0-20161208181325-20d25e280405/go.mod h1:Co6ibVJAznAaIkqp8huTwlJQCZ016jof/cbN4VW5Yz0= gopkg.in/yaml.v3 v3.0.1 h1:fxVm/GzAzEWqLHuvctI91KS9hhNmmWOoWu0XTYJS7CA= diff --git a/cli/internal/format/references.go b/cli/internal/format/references.go new file mode 100644 index 00000000..1bfdf040 --- /dev/null +++ b/cli/internal/format/references.go @@ -0,0 +1,37 @@ +package format + +import ( + "fmt" + "io" + + sdk "github.com/Tencent/WeKnora/client" +) + +// WriteReferences renders the compact references footer used by chat and +// agent invoke: a horizontal rule, one numbered line per reference, +// best-available title + optional score. Skipped entirely when refs is +// empty — agent-friendly silence beats an empty banner. +func WriteReferences(w io.Writer, refs []*sdk.SearchResult) { + if len(refs) == 0 { + return + } + fmt.Fprintln(w) + fmt.Fprintln(w, "──── References ────") + for i, r := range refs { + if r == nil { + continue + } + title := r.KnowledgeTitle + if title == "" { + title = r.KnowledgeFilename + } + if title == "" { + title = r.KnowledgeID + } + fmt.Fprintf(w, "[%d] %s", i+1, title) + if r.Score > 0 { + fmt.Fprintf(w, " score=%.3f", r.Score) + } + fmt.Fprintln(w) + } +} diff --git a/cli/internal/mcp/server.go b/cli/internal/mcp/server.go new file mode 100644 index 00000000..ad99aee8 --- /dev/null +++ b/cli/internal/mcp/server.go @@ -0,0 +1,59 @@ +// Package mcp wires the curated weknora tool set to an +// modelcontextprotocol/go-sdk server. RunStdio is the entry point invoked +// by `weknora mcp serve`. +// +// Design notes: +// +// - Tool surface is hand-curated rather than auto-derived from the cobra +// tree (which would expose auth/link/completion/destructive verbs that +// don't belong on an agent-callable surface). +// - Long-running tools (chat / agent_invoke) accumulate the LLM SSE +// stream server-side and return a single CallToolResult — MCP spec +// 2025-06-18 does not define streamed tool-result content, so this is +// the canonical pattern (see Anthropic reference `everything` server +// and Stripe @stripe/mcp). +// - Handlers receive ctx for cancellation; mid-LLM-stream cancellation +// propagates to the SDK via context, which closes the SSE connection. +package mcp + +import ( + "context" + "fmt" + + mcpsdk "github.com/modelcontextprotocol/go-sdk/mcp" + + "github.com/Tencent/WeKnora/cli/internal/build" +) + +// ServiceClient bundles the SDK methods the tool registry needs. *sdk.Client +// satisfies it; tests substitute a fake to exercise the tool handlers +// in-process without standing up a real WeKnora server. +// +// Embedding the full SDK Client would couple every tool test to every SDK +// method; declaring the narrow surface here keeps the seam tight. +type ServiceClient interface { + knowledgeBaseService + knowledgeService + chatService + agentService +} + +// RunStdio constructs the MCP server, registers the curated 9 tools, and +// blocks reading JSON-RPC from stdin until the client disconnects or ctx +// is cancelled. Returns the underlying transport error (if any); the cobra +// RunE caller maps it through the usual cmdutil exit-code path. +func RunStdio(ctx context.Context, svc ServiceClient) error { + v, _, _ := build.Info() + server := mcpsdk.NewServer( + &mcpsdk.Implementation{ + Name: "weknora", + Version: v, + }, + nil, + ) + registerTools(server, svc) + if err := server.Run(ctx, &mcpsdk.StdioTransport{}); err != nil { + return fmt.Errorf("mcp serve: %w", err) + } + return nil +} diff --git a/cli/internal/mcp/tools.go b/cli/internal/mcp/tools.go new file mode 100644 index 00000000..d58387b2 --- /dev/null +++ b/cli/internal/mcp/tools.go @@ -0,0 +1,463 @@ +package mcp + +import ( + "context" + "encoding/base64" + "fmt" + "io" + "strings" + + mcpsdk "github.com/modelcontextprotocol/go-sdk/mcp" + + "github.com/Tencent/WeKnora/cli/internal/sse" + sdk "github.com/Tencent/WeKnora/client" +) + +// Narrow per-domain service interfaces. ServiceClient (server.go) embeds +// them all; *sdk.Client satisfies the union implicitly. + +type knowledgeBaseService interface { + ListKnowledgeBases(ctx context.Context) ([]sdk.KnowledgeBase, error) + GetKnowledgeBase(ctx context.Context, id string) (*sdk.KnowledgeBase, error) +} + +type knowledgeService interface { + ListKnowledgeWithFilter(ctx context.Context, kbID string, page, pageSize int, filter sdk.KnowledgeListFilter) ([]sdk.Knowledge, int64, error) + GetKnowledge(ctx context.Context, knowledgeID string) (*sdk.Knowledge, error) + OpenKnowledgeFile(ctx context.Context, knowledgeID string) (string, io.ReadCloser, error) + HybridSearch(ctx context.Context, kbID string, params *sdk.SearchParams) ([]*sdk.SearchResult, error) +} + +type chatService interface { + CreateSession(ctx context.Context, req *sdk.CreateSessionRequest) (*sdk.Session, error) + KnowledgeQAStream(ctx context.Context, sessionID string, req *sdk.KnowledgeQARequest, cb func(*sdk.StreamResponse) error) error +} + +type agentService interface { + ListAgents(ctx context.Context) ([]sdk.Agent, error) + GetAgent(ctx context.Context, agentID string) (*sdk.Agent, error) + AgentQAStreamWithRequest(ctx context.Context, sessionID string, req *sdk.AgentQARequest, cb sdk.AgentEventCallback) error +} + +// agentInvokeService composes the two SDK methods agent_invoke needs +// (CreateSession for the auto-session path + AgentQAStreamWithRequest +// for the run itself). Declared here alongside the per-domain +// interfaces above so ServiceClient (server.go) — which embeds the +// four domain interfaces — also satisfies it. +type agentInvokeService interface { + CreateSession(ctx context.Context, req *sdk.CreateSessionRequest) (*sdk.Session, error) + AgentQAStreamWithRequest(ctx context.Context, sessionID string, req *sdk.AgentQARequest, cb sdk.AgentEventCallback) error +} + +// registerTools wires the curated 9 tools onto server. Adding a tool here +// is a deliberate API expansion — the agent-callable surface is the +// reason this CLI ships an MCP server, not its CLI command list, so this +// list must be maintained by hand (see also AGENTS.md mcp serve section). +func registerTools(server *mcpsdk.Server, svc ServiceClient) { + addKBList(server, svc) + addKBView(server, svc) + addDocList(server, svc) + addDocView(server, svc) + addDocDownload(server, svc) + addSearchChunks(server, svc) + addChat(server, svc) + addAgentList(server, svc) + addAgentInvoke(server, svc) +} + +// ---- kb_list ------------------------------------------------------------- + +type kbListInput struct{} + +type kbListOutput struct { + Items []sdk.KnowledgeBase `json:"items"` +} + +func addKBList(server *mcpsdk.Server, svc knowledgeBaseService) { + mcpsdk.AddTool(server, &mcpsdk.Tool{ + Name: "kb_list", + Description: "List all knowledge bases visible to the active WeKnora tenant. No arguments. Returns items[]: each item carries id, name, description, knowledge_count, is_pinned, updated_at — useful for selecting a kb_id to pass to other tools.", + }, func(ctx context.Context, _ *mcpsdk.CallToolRequest, _ kbListInput) (*mcpsdk.CallToolResult, kbListOutput, error) { + items, err := svc.ListKnowledgeBases(ctx) + if err != nil { + return nil, kbListOutput{}, fmt.Errorf("list knowledge bases: %w", err) + } + if items == nil { + items = []sdk.KnowledgeBase{} + } + return nil, kbListOutput{Items: items}, nil + }) +} + +// ---- kb_view ------------------------------------------------------------- + +type kbViewInput struct { + KBID string `json:"kb_id" jsonschema:"knowledge base ID"` +} + +func addKBView(server *mcpsdk.Server, svc knowledgeBaseService) { + mcpsdk.AddTool(server, &mcpsdk.Tool{ + Name: "kb_view", + Description: "Fetch a knowledge base by ID. Returns the full record including chunking config, embedding/summary model IDs, knowledge_count, and chunk_count.", + }, func(ctx context.Context, _ *mcpsdk.CallToolRequest, in kbViewInput) (*mcpsdk.CallToolResult, *sdk.KnowledgeBase, error) { + if in.KBID == "" { + return nil, nil, fmt.Errorf("kb_id is required") + } + kb, err := svc.GetKnowledgeBase(ctx, in.KBID) + if err != nil { + return nil, nil, fmt.Errorf("get knowledge base: %w", err) + } + return nil, kb, nil + }) +} + +// ---- doc_list ------------------------------------------------------------ + +type docListInput struct { + KBID string `json:"kb_id" jsonschema:"knowledge base ID"` + Page int `json:"page,omitempty" jsonschema:"1-indexed page number; defaults to 1"` + PageSize int `json:"page_size,omitempty" jsonschema:"items per page (1..1000); defaults to 20"` + Status string `json:"status,omitempty" jsonschema:"filter by parse status: pending | processing | completed | failed"` +} + +type docListOutput struct { + Items []sdk.Knowledge `json:"items"` + Page int `json:"page"` + PageSize int `json:"page_size"` + Total int64 `json:"total"` +} + +func addDocList(server *mcpsdk.Server, svc knowledgeService) { + mcpsdk.AddTool(server, &mcpsdk.Tool{ + Name: "doc_list", + Description: "List documents in a knowledge base, with pagination and optional parse-status filter. Returns items[] with id, file_name, title, parse_status, size, updated_at — plus the page/total metadata.", + }, func(ctx context.Context, _ *mcpsdk.CallToolRequest, in docListInput) (*mcpsdk.CallToolResult, docListOutput, error) { + if in.KBID == "" { + return nil, docListOutput{}, fmt.Errorf("kb_id is required") + } + page := in.Page + if page < 1 { + page = 1 + } + size := in.PageSize + if size < 1 { + size = 20 + } + if size > 1000 { + return nil, docListOutput{}, fmt.Errorf("page_size must be in 1..1000") + } + items, total, err := svc.ListKnowledgeWithFilter(ctx, in.KBID, page, size, + sdk.KnowledgeListFilter{ParseStatus: in.Status}) + if err != nil { + return nil, docListOutput{}, fmt.Errorf("list documents: %w", err) + } + if items == nil { + items = []sdk.Knowledge{} + } + return nil, docListOutput{Items: items, Page: page, PageSize: size, Total: total}, nil + }) +} + +// ---- doc_view ------------------------------------------------------------ + +type docViewInput struct { + KnowledgeID string `json:"knowledge_id" jsonschema:"document (knowledge entry) ID"` +} + +func addDocView(server *mcpsdk.Server, svc knowledgeService) { + mcpsdk.AddTool(server, &mcpsdk.Tool{ + Name: "doc_view", + Description: "Fetch a single document by ID. Returns the Knowledge record (file_name, title, type, parse_status, size, embedding_model_id, source URL if any, etc.).", + }, func(ctx context.Context, _ *mcpsdk.CallToolRequest, in docViewInput) (*mcpsdk.CallToolResult, *sdk.Knowledge, error) { + if in.KnowledgeID == "" { + return nil, nil, fmt.Errorf("knowledge_id is required") + } + k, err := svc.GetKnowledge(ctx, in.KnowledgeID) + if err != nil { + return nil, nil, fmt.Errorf("get knowledge: %w", err) + } + return nil, k, nil + }) +} + +// ---- doc_download -------------------------------------------------------- + +type docDownloadInput struct { + KnowledgeID string `json:"knowledge_id" jsonschema:"document (knowledge entry) ID"` +} + +type docDownloadOutput struct { + KnowledgeID string `json:"knowledge_id"` + FileName string `json:"file_name"` + Bytes int `json:"bytes"` + // Content is the file contents (UTF-8 if text, base64 if the SDK + // reports a binary-looking blob). For binary, agents should decode + // before consuming. + Content string `json:"content"` + IsBase64 bool `json:"is_base64"` +} + +// maxDocDownloadBytes caps the per-call payload to keep an agent's context +// window safe; agents needing larger documents should chunk via doc_view + +// search_chunks. 1 MiB matches a typical LLM context-window budget for +// inline content (~250k tokens) while remaining cheap to serialize. +const maxDocDownloadBytes = 1 << 20 + +func addDocDownload(server *mcpsdk.Server, svc knowledgeService) { + mcpsdk.AddTool(server, &mcpsdk.Tool{ + Name: "doc_download", + Description: "Download a document's raw bytes by ID. Capped at 1 MiB per call — for larger documents, use search_chunks to find the relevant excerpts. is_base64 reports whether content was base64-encoded (heuristic: presence of NUL byte in the first 512 bytes).", + }, func(ctx context.Context, _ *mcpsdk.CallToolRequest, in docDownloadInput) (*mcpsdk.CallToolResult, docDownloadOutput, error) { + if in.KnowledgeID == "" { + return nil, docDownloadOutput{}, fmt.Errorf("knowledge_id is required") + } + name, body, err := svc.OpenKnowledgeFile(ctx, in.KnowledgeID) + if err != nil { + return nil, docDownloadOutput{}, fmt.Errorf("open knowledge file: %w", err) + } + defer body.Close() + buf, err := io.ReadAll(io.LimitReader(body, maxDocDownloadBytes+1)) + if err != nil { + return nil, docDownloadOutput{}, fmt.Errorf("read knowledge file: %w", err) + } + if len(buf) > maxDocDownloadBytes { + return nil, docDownloadOutput{}, fmt.Errorf("document exceeds the %d-byte per-call cap; use search_chunks for excerpts", maxDocDownloadBytes) + } + content, isBase64 := encodeDownload(buf) + return nil, docDownloadOutput{ + KnowledgeID: in.KnowledgeID, + FileName: name, + Bytes: len(buf), + Content: content, + IsBase64: isBase64, + }, nil + }) +} + +// ---- search_chunks ------------------------------------------------------- + +type searchChunksInput struct { + KBID string `json:"kb_id" jsonschema:"knowledge base ID to search"` + Query string `json:"query" jsonschema:"natural-language search query"` + Limit int `json:"limit,omitempty" jsonschema:"client-side cap on results (1..1000); defaults to 10"` + VectorThreshold float64 `json:"vector_threshold,omitempty" jsonschema:"minimum vector similarity (0..1)"` + KeywordThreshold float64 `json:"keyword_threshold,omitempty" jsonschema:"minimum keyword score (0..1)"` +} + +type searchChunksOutput struct { + Results []*sdk.SearchResult `json:"results"` +} + +func addSearchChunks(server *mcpsdk.Server, svc knowledgeService) { + // Out = any: SDK output schema would derive from searchChunksOutput, + // which embeds *sdk.SearchResult — and SearchResult.Metadata is a + // nilable map[string]any that violates the auto-generated + // type=object constraint when empty. Skipping derivation by using + // `any` keeps the structured JSON shape identical while bypassing + // the over-eager validator. Same pattern applied to chat / agent_invoke + // below. + mcpsdk.AddTool(server, &mcpsdk.Tool{ + Name: "search_chunks", + Description: "Hybrid (vector + keyword) retrieval against a knowledge base. Returns the top chunks ranked by RRF; use this before chat to ground an answer in cited context. Results include knowledge_id, content, score — feed back into chat as context or display directly.", + }, func(ctx context.Context, _ *mcpsdk.CallToolRequest, in searchChunksInput) (*mcpsdk.CallToolResult, any, error) { + if in.KBID == "" { + return nil, nil, fmt.Errorf("kb_id is required") + } + if strings.TrimSpace(in.Query) == "" { + return nil, nil, fmt.Errorf("query cannot be empty") + } + limit := in.Limit + if limit < 1 { + limit = 10 + } + if limit > 1000 { + return nil, nil, fmt.Errorf("limit must be in 1..1000") + } + results, err := svc.HybridSearch(ctx, in.KBID, &sdk.SearchParams{ + QueryText: in.Query, + VectorThreshold: in.VectorThreshold, + KeywordThreshold: in.KeywordThreshold, + }) + if err != nil { + return nil, nil, fmt.Errorf("hybrid search: %w", err) + } + if len(results) > limit { + results = results[:limit] + } + if results == nil { + results = []*sdk.SearchResult{} + } + return nil, searchChunksOutput{Results: results}, nil + }) +} + +// ---- chat ---------------------------------------------------------------- + +type chatInput struct { + KBID string `json:"kb_id" jsonschema:"knowledge base ID to chat against"` + Query string `json:"query" jsonschema:"user query"` + SessionID string `json:"session_id,omitempty" jsonschema:"existing session to continue; auto-created when empty"` +} + +type chatOutput struct { + Answer string `json:"answer"` + References []*sdk.SearchResult `json:"references"` + SessionID string `json:"session_id"` + AssistantMessageID string `json:"assistant_message_id,omitempty"` +} + +func addChat(server *mcpsdk.Server, svc chatService) { + mcpsdk.AddTool(server, &mcpsdk.Tool{ + Name: "chat", + Description: "Stream a RAG answer from the LLM, grounded in the given knowledge base. The SSE stream is accumulated server-side; this tool returns the full answer + references + session_id once the stream completes. Pass session_id to continue a multi-turn conversation; otherwise a fresh session is auto-created.", + }, func(ctx context.Context, _ *mcpsdk.CallToolRequest, in chatInput) (*mcpsdk.CallToolResult, any, error) { + if in.KBID == "" { + return nil, nil, fmt.Errorf("kb_id is required") + } + if strings.TrimSpace(in.Query) == "" { + return nil, nil, fmt.Errorf("query cannot be empty") + } + sessionID := in.SessionID + if sessionID == "" { + sess, err := svc.CreateSession(ctx, &sdk.CreateSessionRequest{Title: "weknora mcp chat"}) + if err != nil { + return nil, nil, fmt.Errorf("create chat session: %w", err) + } + sessionID = sess.ID + } + req := &sdk.KnowledgeQARequest{ + Query: in.Query, + KnowledgeBaseIDs: []string{in.KBID}, + AgentEnabled: false, + Channel: "api", + } + acc := &sse.Accumulator{} + streamErr := svc.KnowledgeQAStream(ctx, sessionID, req, func(r *sdk.StreamResponse) error { + acc.Append(r) + return nil + }) + if streamErr != nil { + return nil, nil, fmt.Errorf("knowledge qa stream: %w", streamErr) + } + if !acc.Done() { + return nil, nil, fmt.Errorf("stream ended without a terminal event") + } + sid := acc.SessionID + if sid == "" { + sid = sessionID + } + return nil, chatOutput{ + Answer: acc.Result(), + References: acc.References, + SessionID: sid, + AssistantMessageID: acc.AssistantMessageID, + }, nil + }) +} + +// ---- agent_list ---------------------------------------------------------- + +type agentListInput struct{} + +type agentListOutput struct { + Items []sdk.Agent `json:"items"` +} + +func addAgentList(server *mcpsdk.Server, svc agentService) { + mcpsdk.AddTool(server, &mcpsdk.Tool{ + Name: "agent_list", + Description: "List the tenant's custom agents. Returns items[] with id, name, description, is_builtin — use to discover an agent_id before agent_invoke.", + }, func(ctx context.Context, _ *mcpsdk.CallToolRequest, _ agentListInput) (*mcpsdk.CallToolResult, agentListOutput, error) { + items, err := svc.ListAgents(ctx) + if err != nil { + return nil, agentListOutput{}, fmt.Errorf("list agents: %w", err) + } + if items == nil { + items = []sdk.Agent{} + } + return nil, agentListOutput{Items: items}, nil + }) +} + +// ---- agent_invoke -------------------------------------------------------- + +type agentInvokeInput struct { + AgentID string `json:"agent_id" jsonschema:"custom agent ID"` + Query string `json:"query" jsonschema:"user query"` + SessionID string `json:"session_id,omitempty" jsonschema:"existing session to continue; auto-created when empty"` +} + +type agentInvokeOutput struct { + Answer string `json:"answer"` + References []*sdk.SearchResult `json:"references"` + ToolEvents []sse.AgentToolEvent `json:"tool_events,omitempty"` + Thinking string `json:"thinking,omitempty"` + SessionID string `json:"session_id"` + AgentID string `json:"agent_id"` +} + +func addAgentInvoke(server *mcpsdk.Server, svc agentInvokeService) { + mcpsdk.AddTool(server, &mcpsdk.Tool{ + Name: "agent_invoke", + Description: "Run a query through a custom agent (system prompt + tool allow-list + KB scope). The agent's SSE stream is accumulated server-side; this tool returns the final answer plus the trace (references, tool_events, thinking).", + }, func(ctx context.Context, _ *mcpsdk.CallToolRequest, in agentInvokeInput) (*mcpsdk.CallToolResult, any, error) { + if in.AgentID == "" { + return nil, nil, fmt.Errorf("agent_id is required") + } + if strings.TrimSpace(in.Query) == "" { + return nil, nil, fmt.Errorf("query cannot be empty") + } + acc := &sse.AgentAccumulator{} + req := &sdk.AgentQARequest{ + Query: in.Query, + AgentEnabled: true, + AgentID: in.AgentID, + Channel: "api", + } + // Auto-create session if not supplied. Sessions are agent- + // agnostic at creation (Q3 — verified against server source). + sessionID := in.SessionID + if sessionID == "" { + sess, err := svc.CreateSession(ctx, &sdk.CreateSessionRequest{Title: "weknora mcp agent_invoke"}) + if err != nil { + return nil, nil, fmt.Errorf("create chat session: %w", err) + } + sessionID = sess.ID + } + streamErr := svc.AgentQAStreamWithRequest(ctx, sessionID, req, func(r *sdk.AgentStreamResponse) error { + acc.Append(r) + return nil + }) + if streamErr != nil { + return nil, nil, fmt.Errorf("agent-chat stream: %w", streamErr) + } + if !acc.Done() { + return nil, nil, fmt.Errorf("stream ended without a terminal event") + } + return nil, agentInvokeOutput{ + Answer: acc.Answer(), + References: acc.References, + ToolEvents: acc.ToolEvents, + Thinking: acc.Thinking(), + SessionID: sessionID, + AgentID: in.AgentID, + }, nil + }) +} + +// encodeDownload returns (content, isBase64). Heuristic: if the first 512 +// bytes contain a NUL, treat as binary. Otherwise it's UTF-8-ish text. +// Matches what /usr/bin/file's "binary" heuristic does at a coarse level — +// good enough to spare an agent from base64-decoding obvious text. +func encodeDownload(buf []byte) (string, bool) { + probe := buf + if len(probe) > 512 { + probe = probe[:512] + } + for _, b := range probe { + if b == 0 { + return base64.StdEncoding.EncodeToString(buf), true + } + } + return string(buf), false +} diff --git a/cli/internal/mcp/tools_test.go b/cli/internal/mcp/tools_test.go new file mode 100644 index 00000000..ce791b79 --- /dev/null +++ b/cli/internal/mcp/tools_test.go @@ -0,0 +1,407 @@ +package mcp + +import ( + "context" + "encoding/json" + "errors" + "io" + "strings" + "testing" + "time" + + mcpsdk "github.com/modelcontextprotocol/go-sdk/mcp" + + sdk "github.com/Tencent/WeKnora/client" +) + +// fakeSvc implements every narrow service interface ServiceClient embeds. +// Each method records the last call args; per-test setup populates the +// return values it wants to assert against. +type fakeSvc struct { + listKBs []sdk.KnowledgeBase + listKBsErr error + getKB *sdk.KnowledgeBase + getKBErr error + listDocs []sdk.Knowledge + listDocsTotal int64 + listDocsErr error + getDoc *sdk.Knowledge + getDocErr error + openDocName string + openDocBody io.ReadCloser + openDocErr error + hybridResults []*sdk.SearchResult + hybridErr error + createSess *sdk.Session + createSessErr error + kbStreamEvents []*sdk.StreamResponse + kbStreamErr error + agents []sdk.Agent + agentsErr error + agent *sdk.Agent + agentErr error + agentEvents []*sdk.AgentStreamResponse + agentStreamErr error + // Captured args: + calls struct { + listKBs int + kbViewID string + docListKBID string + docListFilter sdk.KnowledgeListFilter + docViewID string + openDocID string + hybridKBID string + hybridParams *sdk.SearchParams + createSessReq *sdk.CreateSessionRequest + kbQAReq *sdk.KnowledgeQARequest + kbQASess string + agentListN int + agentViewID string + agentReq *sdk.AgentQARequest + agentSess string + } +} + +func (f *fakeSvc) ListKnowledgeBases(_ context.Context) ([]sdk.KnowledgeBase, error) { + f.calls.listKBs++ + return f.listKBs, f.listKBsErr +} +func (f *fakeSvc) GetKnowledgeBase(_ context.Context, id string) (*sdk.KnowledgeBase, error) { + f.calls.kbViewID = id + return f.getKB, f.getKBErr +} +func (f *fakeSvc) ListKnowledgeWithFilter(_ context.Context, kbID string, _, _ int, filter sdk.KnowledgeListFilter) ([]sdk.Knowledge, int64, error) { + f.calls.docListKBID = kbID + f.calls.docListFilter = filter + return f.listDocs, f.listDocsTotal, f.listDocsErr +} +func (f *fakeSvc) GetKnowledge(_ context.Context, id string) (*sdk.Knowledge, error) { + f.calls.docViewID = id + return f.getDoc, f.getDocErr +} +func (f *fakeSvc) OpenKnowledgeFile(_ context.Context, id string) (string, io.ReadCloser, error) { + f.calls.openDocID = id + return f.openDocName, f.openDocBody, f.openDocErr +} +func (f *fakeSvc) HybridSearch(_ context.Context, kbID string, p *sdk.SearchParams) ([]*sdk.SearchResult, error) { + f.calls.hybridKBID, f.calls.hybridParams = kbID, p + return f.hybridResults, f.hybridErr +} +func (f *fakeSvc) CreateSession(_ context.Context, req *sdk.CreateSessionRequest) (*sdk.Session, error) { + f.calls.createSessReq = req + if f.createSess == nil && f.createSessErr == nil { + return &sdk.Session{ID: "sess_auto"}, nil + } + return f.createSess, f.createSessErr +} +func (f *fakeSvc) KnowledgeQAStream(_ context.Context, sess string, req *sdk.KnowledgeQARequest, cb func(*sdk.StreamResponse) error) error { + f.calls.kbQASess, f.calls.kbQAReq = sess, req + for _, e := range f.kbStreamEvents { + if err := cb(e); err != nil { + return err + } + } + return f.kbStreamErr +} +func (f *fakeSvc) ListAgents(_ context.Context) ([]sdk.Agent, error) { + f.calls.agentListN++ + return f.agents, f.agentsErr +} +func (f *fakeSvc) GetAgent(_ context.Context, id string) (*sdk.Agent, error) { + f.calls.agentViewID = id + return f.agent, f.agentErr +} +func (f *fakeSvc) AgentQAStreamWithRequest(_ context.Context, sess string, req *sdk.AgentQARequest, cb sdk.AgentEventCallback) error { + f.calls.agentSess, f.calls.agentReq = sess, req + for _, e := range f.agentEvents { + if err := cb(e); err != nil { + return err + } + } + return f.agentStreamErr +} + +// newTestServer wires svc to an in-process MCP server and returns a +// connected client session ready to CallTool against it. +func newTestServer(t *testing.T, svc ServiceClient) (*mcpsdk.ClientSession, context.CancelFunc) { + t.Helper() + ctx, cancel := context.WithCancel(context.Background()) + server := mcpsdk.NewServer(&mcpsdk.Implementation{Name: "weknora-test", Version: "v0.0.0-test"}, nil) + registerTools(server, svc) + + st, ct := mcpsdk.NewInMemoryTransports() + serverSession, err := server.Connect(ctx, st, nil) + if err != nil { + cancel() + t.Fatalf("server.Connect: %v", err) + } + client := mcpsdk.NewClient(&mcpsdk.Implementation{Name: "test-client", Version: "v0.0.0"}, nil) + clientSession, err := client.Connect(ctx, ct, nil) + if err != nil { + _ = serverSession.Close() + cancel() + t.Fatalf("client.Connect: %v", err) + } + t.Cleanup(func() { + _ = clientSession.Close() + _ = serverSession.Close() + cancel() + }) + return clientSession, cancel +} + +// callTool invokes name with args and returns the parsed structured output. +func callTool(t *testing.T, c *mcpsdk.ClientSession, name string, args any, out any) *mcpsdk.CallToolResult { + t.Helper() + ctx, cancel := context.WithTimeout(context.Background(), 5*time.Second) + defer cancel() + res, err := c.CallTool(ctx, &mcpsdk.CallToolParams{Name: name, Arguments: args}) + if err != nil { + t.Fatalf("CallTool(%s): %v", name, err) + } + if res.IsError { + if len(res.Content) > 0 { + t.Fatalf("tool %s returned error: %+v", name, res.Content) + } + t.Fatalf("tool %s returned error (no content)", name) + } + if out != nil && res.StructuredContent != nil { + b, _ := json.Marshal(res.StructuredContent) + if err := json.Unmarshal(b, out); err != nil { + t.Fatalf("decode %s output: %v\nraw=%s", name, err, b) + } + } + return res +} + +func TestTool_ListsRegistered(t *testing.T) { + c, _ := newTestServer(t, &fakeSvc{}) + ctx, cancel := context.WithTimeout(context.Background(), 2*time.Second) + defer cancel() + res, err := c.ListTools(ctx, nil) + if err != nil { + t.Fatalf("ListTools: %v", err) + } + want := []string{"kb_list", "kb_view", "doc_list", "doc_view", "doc_download", "search_chunks", "chat", "agent_list", "agent_invoke"} + got := map[string]bool{} + for _, tool := range res.Tools { + got[tool.Name] = true + } + for _, name := range want { + if !got[name] { + t.Errorf("missing tool %q in ListTools response", name) + } + } + if len(res.Tools) != len(want) { + t.Errorf("registered %d tools, want exactly %d (no scope creep)", len(res.Tools), len(want)) + } +} + +func TestTool_KBList(t *testing.T) { + svc := &fakeSvc{listKBs: []sdk.KnowledgeBase{{ID: "kb1", Name: "Marketing"}}} + c, _ := newTestServer(t, svc) + var out kbListOutput + callTool(t, c, "kb_list", map[string]any{}, &out) + if len(out.Items) != 1 || out.Items[0].ID != "kb1" { + t.Errorf("got %+v", out) + } +} + +func TestTool_KBView_RequiresKBID(t *testing.T) { + c, _ := newTestServer(t, &fakeSvc{}) + ctx, cancel := context.WithTimeout(context.Background(), 2*time.Second) + defer cancel() + res, err := c.CallTool(ctx, &mcpsdk.CallToolParams{Name: "kb_view", Arguments: map[string]any{}}) + if err != nil { + t.Fatalf("unexpected transport error: %v", err) + } + if !res.IsError { + t.Fatal("expected IsError=true on missing kb_id") + } +} + +func TestTool_KBView(t *testing.T) { + svc := &fakeSvc{getKB: &sdk.KnowledgeBase{ID: "kb_x", Name: "Eng"}} + c, _ := newTestServer(t, svc) + var out sdk.KnowledgeBase + callTool(t, c, "kb_view", map[string]any{"kb_id": "kb_x"}, &out) + if out.ID != "kb_x" || out.Name != "Eng" { + t.Errorf("got %+v", out) + } + if svc.calls.kbViewID != "kb_x" { + t.Errorf("kb_id not forwarded: %s", svc.calls.kbViewID) + } +} + +func TestTool_DocList_DefaultPagination(t *testing.T) { + svc := &fakeSvc{listDocs: []sdk.Knowledge{{ID: "k1"}}, listDocsTotal: 1} + c, _ := newTestServer(t, svc) + var out docListOutput + callTool(t, c, "doc_list", map[string]any{"kb_id": "kb_x"}, &out) + if out.Page != 1 || out.PageSize != 20 { + t.Errorf("default pagination not applied: %+v", out) + } + if svc.calls.docListKBID != "kb_x" { + t.Errorf("kb_id not forwarded: %s", svc.calls.docListKBID) + } +} + +func TestTool_DocList_StatusFilter_Forwarded(t *testing.T) { + svc := &fakeSvc{} + c, _ := newTestServer(t, svc) + callTool(t, c, "doc_list", map[string]any{"kb_id": "kb_x", "status": "failed"}, nil) + if svc.calls.docListFilter.ParseStatus != "failed" { + t.Errorf("status not forwarded as filter.ParseStatus: %+v", svc.calls.docListFilter) + } +} + +func TestTool_DocView(t *testing.T) { + svc := &fakeSvc{getDoc: &sdk.Knowledge{ID: "k1", FileName: "a.pdf"}} + c, _ := newTestServer(t, svc) + var out sdk.Knowledge + callTool(t, c, "doc_view", map[string]any{"knowledge_id": "k1"}, &out) + if out.ID != "k1" { + t.Errorf("got %+v", out) + } +} + +func TestTool_DocDownload_Text(t *testing.T) { + svc := &fakeSvc{ + openDocName: "notes.txt", + openDocBody: io.NopCloser(strings.NewReader("hello world")), + } + c, _ := newTestServer(t, svc) + var out docDownloadOutput + callTool(t, c, "doc_download", map[string]any{"knowledge_id": "k1"}, &out) + if out.Content != "hello world" { + t.Errorf("content = %q", out.Content) + } + if out.IsBase64 { + t.Error("text content should not be base64-encoded") + } +} + +func TestTool_DocDownload_BinaryBase64(t *testing.T) { + // First 512 bytes contain a NUL → encodeDownload returns base64. + bin := []byte{0x00, 0x01, 0x02, 0x03} + svc := &fakeSvc{ + openDocName: "blob.bin", + openDocBody: io.NopCloser(strings.NewReader(string(bin))), + } + c, _ := newTestServer(t, svc) + var out docDownloadOutput + callTool(t, c, "doc_download", map[string]any{"knowledge_id": "k1"}, &out) + if !out.IsBase64 { + t.Errorf("binary should be base64; got is_base64=%v content=%q", out.IsBase64, out.Content) + } +} + +func TestTool_SearchChunks(t *testing.T) { + svc := &fakeSvc{hybridResults: []*sdk.SearchResult{{KnowledgeID: "k1", Score: 0.9}}} + c, _ := newTestServer(t, svc) + var out searchChunksOutput + callTool(t, c, "search_chunks", map[string]any{"kb_id": "kb_x", "query": "what is RAG"}, &out) + if len(out.Results) != 1 || out.Results[0].KnowledgeID != "k1" { + t.Errorf("got %+v", out) + } +} + +func TestTool_SearchChunks_LimitCap(t *testing.T) { + // 5 results, limit 3 → 3 returned. + svc := &fakeSvc{} + for i := 0; i < 5; i++ { + svc.hybridResults = append(svc.hybridResults, &sdk.SearchResult{KnowledgeID: "k", Score: float64(i)}) + } + c, _ := newTestServer(t, svc) + var out searchChunksOutput + callTool(t, c, "search_chunks", map[string]any{"kb_id": "kb_x", "query": "x", "limit": 3}, &out) + if len(out.Results) != 3 { + t.Errorf("limit not honored: got %d, want 3", len(out.Results)) + } +} + +func TestTool_Chat_AccumulateAnswerAndReferences(t *testing.T) { + svc := &fakeSvc{ + kbStreamEvents: []*sdk.StreamResponse{ + {Content: "Hello "}, + {Content: "world."}, + {KnowledgeReferences: []*sdk.SearchResult{{KnowledgeID: "k1"}}}, + {ResponseType: sdk.ResponseTypeComplete}, + }, + } + c, _ := newTestServer(t, svc) + var out chatOutput + callTool(t, c, "chat", map[string]any{"kb_id": "kb_x", "query": "ping"}, &out) + if out.Answer != "Hello world." { + t.Errorf("answer = %q", out.Answer) + } + if len(out.References) != 1 || out.References[0].KnowledgeID != "k1" { + t.Errorf("references missing: %+v", out.References) + } + if out.SessionID != "sess_auto" { + t.Errorf("session_id = %q, want sess_auto", out.SessionID) + } +} + +func TestTool_Chat_ExistingSessionSkipsCreate(t *testing.T) { + svc := &fakeSvc{ + kbStreamEvents: []*sdk.StreamResponse{{ResponseType: sdk.ResponseTypeComplete}}, + } + c, _ := newTestServer(t, svc) + callTool(t, c, "chat", map[string]any{"kb_id": "kb_x", "query": "x", "session_id": "sess_existing"}, nil) + if svc.calls.createSessReq != nil { + t.Error("CreateSession should not fire when session_id is supplied") + } + if svc.calls.kbQASess != "sess_existing" { + t.Errorf("session id not forwarded to QA stream: %s", svc.calls.kbQASess) + } +} + +func TestTool_AgentList(t *testing.T) { + svc := &fakeSvc{agents: []sdk.Agent{{ID: "ag1", Name: "Research"}}} + c, _ := newTestServer(t, svc) + var out agentListOutput + callTool(t, c, "agent_list", map[string]any{}, &out) + if len(out.Items) != 1 || out.Items[0].ID != "ag1" { + t.Errorf("got %+v", out) + } +} + +func TestTool_AgentInvoke(t *testing.T) { + svc := &fakeSvc{ + agentEvents: []*sdk.AgentStreamResponse{ + {ResponseType: sdk.AgentResponseTypeAnswer, Content: "result"}, + {ResponseType: sdk.AgentResponseTypeToolCall, ID: "c1", Content: "knowledge_search"}, + {Done: true}, + }, + } + c, _ := newTestServer(t, svc) + var out agentInvokeOutput + callTool(t, c, "agent_invoke", map[string]any{"agent_id": "ag1", "query": "x"}, &out) + if out.Answer != "result" { + t.Errorf("answer = %q", out.Answer) + } + if len(out.ToolEvents) != 1 { + t.Errorf("tool_calls len = %d, want 1", len(out.ToolEvents)) + } + if out.AgentID != "ag1" { + t.Errorf("agent_id = %q", out.AgentID) + } +} + +func TestTool_AgentInvoke_StreamAbort(t *testing.T) { + svc := &fakeSvc{ + agentEvents: []*sdk.AgentStreamResponse{{ResponseType: sdk.AgentResponseTypeAnswer, Content: "partial"}}, + agentStreamErr: errors.New("connection reset"), + } + c, _ := newTestServer(t, svc) + ctx, cancel := context.WithTimeout(context.Background(), 2*time.Second) + defer cancel() + res, err := c.CallTool(ctx, &mcpsdk.CallToolParams{Name: "agent_invoke", Arguments: map[string]any{"agent_id": "ag1", "query": "x"}}) + if err != nil { + t.Fatalf("unexpected transport error: %v", err) + } + if !res.IsError { + t.Fatal("expected IsError=true on mid-stream abort") + } +} diff --git a/cli/internal/sse/agent_accumulator.go b/cli/internal/sse/agent_accumulator.go new file mode 100644 index 00000000..318332e1 --- /dev/null +++ b/cli/internal/sse/agent_accumulator.go @@ -0,0 +1,86 @@ +package sse + +import ( + "strings" + + sdk "github.com/Tencent/WeKnora/client" +) + +// AgentToolEvent captures one tool_call / tool_result event from an +// agent SSE stream. Kind is the SDK event type (typed, not bare string) +// so consumers can compare against sdk.AgentResponseTypeToolCall / +// sdk.AgentResponseTypeToolResult without retyping the constants. +// NOTE: Kind is the event kind, NOT the function name; the function +// name typically lives in Data for tool_call events. +type AgentToolEvent struct { + ID string `json:"id"` + Kind sdk.AgentResponseType `json:"kind,omitempty"` + Result string `json:"result,omitempty"` + Data map[string]any `json:"data,omitempty"` +} + +// AgentAccumulator buffers an AgentQAStream callback sequence. Distinct +// from Accumulator (KnowledgeQAStream) because the agent event model is +// wider — events include thinking / reflection / tool_call / tool_result +// / answer / references / error, with a flat `Done bool` on each frame +// rather than the ResponseType=complete sentinel KnowledgeQAStream emits. +// +// Zero value is ready to use. Not safe for concurrent Append calls — the +// SDK callback contract is sequential on a single goroutine. +// +// API mirrors sse.Accumulator: private builders + accessor methods so the +// "Append is idempotent post-Done" invariant cannot be broken by external +// mutation. References and ToolEvents stay exported as plain slices since +// they carry no such invariant. +type AgentAccumulator struct { + answer strings.Builder + thinking strings.Builder + References []*sdk.SearchResult + ToolEvents []AgentToolEvent + done bool +} + +// Answer returns the accumulated `answer`-event content. +func (a *AgentAccumulator) Answer() string { return a.answer.String() } + +// Thinking returns the accumulated `thinking` / `reflection` content +// surfaced by the agent during its reasoning pass. +func (a *AgentAccumulator) Thinking() string { return a.thinking.String() } + +// Done reports whether the stream emitted a terminal frame. +func (a *AgentAccumulator) Done() bool { return a.done } + +// Append consumes one AgentStreamResponse event. Idempotent post-Done so +// callers do not need to special-case late events. +func (a *AgentAccumulator) Append(r *sdk.AgentStreamResponse) { + if r == nil || a.done { + return + } + switch r.ResponseType { + case sdk.AgentResponseTypeAnswer: + if r.Content != "" { + a.answer.WriteString(r.Content) + } + case sdk.AgentResponseTypeThinking, sdk.AgentResponseTypeReflection: + if r.Content != "" { + a.thinking.WriteString(r.Content) + } + case sdk.AgentResponseTypeToolCall, sdk.AgentResponseTypeToolResult: + a.ToolEvents = append(a.ToolEvents, AgentToolEvent{ + ID: r.ID, + Kind: r.ResponseType, + Result: r.Content, + Data: r.Data, + }) + } + // References can arrive on a dedicated `references` event OR + // piggyback on another event's KnowledgeReferences field; always + // capture the latest, matching sse.Accumulator's "always replace" + // semantic for the parallel KnowledgeQAStream contract. + if r.KnowledgeReferences != nil { + a.References = r.KnowledgeReferences + } + if r.Done { + a.done = true + } +}