feat(cli): mcp serve curated stdio MCP server

`weknora mcp serve` — long-lived stdio MCP (Model Context Protocol)
transport that exposes a fixed, curated tool surface to MCP-aware
agents (Claude Desktop, Claude Code, custom MCP clients).

Curated tool set (readonly by default):
- whoami — active context + tenant
- search (hybrid retrieval against a KB)
- kb list / view
- doc list / view
- agent list / view / invoke
- session list / view

The list is intentionally narrow to the read + agent-invoke surface;
destructive verbs (`delete` / `empty` / `upload`) are gated behind
`--write`. Schema is built from each leaf cobra command's flags so
adding a new tool is a single registry entry plus a Service interface.

Includes the simplify post-review polish plus a second simplify pass that
folds in the resulting feedback (typed schemas, agent_help wording, unified
chat / agent invoke option names).
This commit is contained in:
nullkey
2026-05-14 23:28:14 +08:00
committed by lyingbug
parent 493fc41e98
commit 9bb83b47fd
14 changed files with 1213 additions and 154 deletions

View File

@@ -13,14 +13,15 @@ import (
"github.com/Tencent/WeKnora/cli/internal/cmdutil"
"github.com/Tencent/WeKnora/cli/internal/format"
"github.com/Tencent/WeKnora/cli/internal/iostreams"
"github.com/Tencent/WeKnora/cli/internal/sse"
sdk "github.com/Tencent/WeKnora/client"
)
// agentInvokeFields enumerates fields surfaced for `--json` discovery on
// `agent invoke`. Matches invokeData below — single-shot result envelope
// with the agent's final answer plus the trace (references, tool calls).
// with the agent's final answer plus the trace (references, tool events).
var agentInvokeFields = []string{
"answer", "references", "tool_calls", "thinking",
"answer", "references", "tool_events", "thinking",
"session_id", "agent_id", "query",
}
@@ -45,25 +46,15 @@ type InvokeService interface {
AgentQAStreamWithRequest(ctx context.Context, sessionID string, req *sdk.AgentQARequest, cb sdk.AgentEventCallback) error
}
// toolCallTrace mirrors a single SSE `tool_call` event so agents can see
// which tools the WeKnora-side agent invoked (and their results). Only the
// fields the server actually emits are captured.
type toolCallTrace struct {
ID string `json:"id"`
Name string `json:"name,omitempty"`
Result string `json:"result,omitempty"`
Data map[string]any `json:"data,omitempty"`
}
// invokeData is the JSON envelope payload.
type invokeData struct {
Answer string `json:"answer"`
References []*sdk.SearchResult `json:"references"`
ToolCalls []toolCallTrace `json:"tool_calls,omitempty"`
Thinking string `json:"thinking,omitempty"`
SessionID string `json:"session_id"`
AgentID string `json:"agent_id"`
Query string `json:"query"`
Answer string `json:"answer"`
References []*sdk.SearchResult `json:"references"`
ToolEvents []sse.AgentToolEvent `json:"tool_events,omitempty"`
Thinking string `json:"thinking,omitempty"`
SessionID string `json:"session_id"`
AgentID string `json:"agent_id"`
Query string `json:"query"`
}
// NewCmdInvoke builds `weknora agent invoke <agent-id> "<text>"`.
@@ -152,12 +143,12 @@ func runInvoke(ctx context.Context, opts *InvokeOptions, jopts *cmdutil.JSONOpti
Channel: "api",
}
acc := newAgentAccumulator()
acc := &sse.AgentAccumulator{}
cb := func(r *sdk.AgentStreamResponse) error {
if streamMode && r != nil && r.ResponseType == sdk.AgentResponseTypeAnswer && r.Content != "" {
_, _ = iostreams.IO.Out.Write([]byte(r.Content))
}
acc.append(r)
acc.Append(r)
return nil
}
@@ -169,7 +160,7 @@ func runInvoke(ctx context.Context, opts *InvokeOptions, jopts *cmdutil.JSONOpti
if errors.Is(streamErr, context.Canceled) || errors.Is(ctx.Err(), context.Canceled) {
return cmdutil.Wrapf(cmdutil.CodeUserAborted, streamErr, "agent invoke cancelled")
}
if acc.answer.Len() > 0 && !acc.done {
if acc.Answer() != "" && !acc.Done() {
return cmdutil.Wrapf(cmdutil.CodeSSEStreamAborted, streamErr, "stream aborted before completion")
}
return cmdutil.WrapHTTP(streamErr, "agent-chat stream")
@@ -177,17 +168,17 @@ func runInvoke(ctx context.Context, opts *InvokeOptions, jopts *cmdutil.JSONOpti
// Server closed cleanly but never sent a Done event — treat as aborted
// so agents don't silently emit a truncated answer as ok=true.
if !acc.done {
if !acc.Done() {
return cmdutil.NewError(cmdutil.CodeSSEStreamAborted, "stream ended without a terminal event")
}
answer := acc.answer.String()
answer := acc.Answer()
if jsonOut {
data := invokeData{
Answer: answer,
References: acc.references,
ToolCalls: acc.toolCalls,
Thinking: acc.thinking.String(),
References: acc.References,
ToolEvents: acc.ToolEvents,
Thinking: acc.Thinking(),
SessionID: sessionID,
AgentID: opts.AgentID,
Query: opts.Query,
@@ -210,98 +201,24 @@ func runInvoke(ctx context.Context, opts *InvokeOptions, jopts *cmdutil.JSONOpti
fmt.Fprintln(out)
}
}
renderToolTrace(out, acc.toolCalls)
renderReferences(out, acc.references)
renderToolTrace(out, acc.ToolEvents)
format.WriteReferences(out, acc.References)
return nil
}
// agentAccumulator buffers an AgentQAStream callback sequence. Distinct
// from internal/sse.Accumulator because the agent event model is wider
// (thinking / tool_call / tool_result / answer / references / reflection /
// error) and uses a flat r.Done bool instead of the
// ResponseType=complete sentinel that KnowledgeQAStream emits.
type agentAccumulator struct {
answer strings.Builder
thinking strings.Builder
references []*sdk.SearchResult
toolCalls []toolCallTrace
done bool
}
func newAgentAccumulator() *agentAccumulator { return &agentAccumulator{} }
func (a *agentAccumulator) append(r *sdk.AgentStreamResponse) {
if r == nil || a.done {
return
}
switch r.ResponseType {
case sdk.AgentResponseTypeAnswer:
if r.Content != "" {
a.answer.WriteString(r.Content)
}
case sdk.AgentResponseTypeThinking, sdk.AgentResponseTypeReflection:
if r.Content != "" {
a.thinking.WriteString(r.Content)
}
case sdk.AgentResponseTypeReferences:
if r.KnowledgeReferences != nil {
a.references = r.KnowledgeReferences
}
case sdk.AgentResponseTypeToolCall, sdk.AgentResponseTypeToolResult:
a.toolCalls = append(a.toolCalls, toolCallTrace{
ID: r.ID,
Name: string(r.ResponseType),
Result: r.Content,
Data: r.Data,
})
}
// References can also arrive on the terminal frame.
if r.KnowledgeReferences != nil && a.references == nil {
a.references = r.KnowledgeReferences
}
if r.Done {
a.done = true
}
}
// renderToolTrace prints a compact tool-call footer in human mode. Skipped
// when the agent invoked no tools — silent beats an empty banner.
func renderToolTrace(w io.Writer, calls []toolCallTrace) {
if len(calls) == 0 {
// renderToolTrace prints a compact tool-event footer in human mode.
// Skipped when the agent emitted no tool events — silent beats an empty
// banner.
func renderToolTrace(w io.Writer, events []sse.AgentToolEvent) {
if len(events) == 0 {
return
}
fmt.Fprintln(w)
fmt.Fprintln(w, "──── Tool trace ────")
for i, c := range calls {
fmt.Fprintf(w, "[%d] %s", i+1, c.Name)
if c.Result != "" {
fmt.Fprintf(w, " %s", truncateInline(c.Result, 80))
}
fmt.Fprintln(w)
}
}
// renderReferences mirrors chat.go's references footer for parity.
func renderReferences(w io.Writer, refs []*sdk.SearchResult) {
if len(refs) == 0 {
return
}
fmt.Fprintln(w)
fmt.Fprintln(w, "──── References ────")
for i, r := range refs {
if r == nil {
continue
}
title := r.KnowledgeTitle
if title == "" {
title = r.KnowledgeFilename
}
if title == "" {
title = r.KnowledgeID
}
fmt.Fprintf(w, "[%d] %s", i+1, title)
if r.Score > 0 {
fmt.Fprintf(w, " score=%.3f", r.Score)
for i, e := range events {
fmt.Fprintf(w, "[%d] %s", i+1, e.Kind)
if e.Result != "" {
fmt.Fprintf(w, " %s", truncateInline(e.Result, 80))
}
fmt.Fprintln(w)
}

View File

@@ -150,7 +150,7 @@ func (c *createSessionTracker) CreateSession(ctx context.Context, req *sdk.Creat
return c.InvokeService.CreateSession(ctx, req)
}
func TestInvoke_ToolCallsCaptured(t *testing.T) {
func TestInvoke_ToolEventsCaptured(t *testing.T) {
out, _ := iostreams.SetForTest(t)
svc := &scriptedInvokeSvc{events: []*sdk.AgentStreamResponse{
toolCallEvent("call_1", "knowledge_search"),
@@ -167,11 +167,11 @@ func TestInvoke_ToolCallsCaptured(t *testing.T) {
if err := json.Unmarshal(out.Bytes(), &env); err != nil {
t.Fatalf("parse: %v", err)
}
if len(env.Data.ToolCalls) != 1 {
t.Fatalf("expected 1 tool call, got %d", len(env.Data.ToolCalls))
if len(env.Data.ToolEvents) != 1 {
t.Fatalf("expected 1 tool call, got %d", len(env.Data.ToolEvents))
}
if env.Data.ToolCalls[0].ID != "call_1" {
t.Errorf("tool_calls[0].id = %q, want call_1", env.Data.ToolCalls[0].ID)
if env.Data.ToolEvents[0].ID != "call_1" {
t.Errorf("tool_calls[0].id = %q, want call_1", env.Data.ToolEvents[0].ID)
}
}

View File

@@ -23,7 +23,6 @@ import (
"context"
"errors"
"fmt"
"io"
"strings"
"github.com/spf13/cobra"
@@ -258,37 +257,9 @@ func runChat(ctx context.Context, opts *Options, jopts *cmdutil.JSONOptions, svc
fmt.Fprintln(out)
}
}
renderReferences(out, references)
format.WriteReferences(out, references)
return nil
}
// renderReferences prints a compact human-readable references block.
// Skipped entirely when the server returned no references — agent-friendly
// silence beats an empty banner.
func renderReferences(w io.Writer, refs []*sdk.SearchResult) {
if len(refs) == 0 {
return
}
fmt.Fprintln(w)
fmt.Fprintln(w, "──── References ────")
for i, r := range refs {
if r == nil {
continue
}
title := r.KnowledgeTitle
if title == "" {
title = r.KnowledgeFilename
}
if title == "" {
title = r.KnowledgeID
}
fmt.Fprintf(w, "[%d] %s", i+1, title)
if r.Score > 0 {
fmt.Fprintf(w, " score=%.3f", r.Score)
}
fmt.Fprintln(w)
}
}
// compile-time check: the production SDK client implements chatService.
var _ chatService = (*sdk.Client)(nil)

40
cli/cmd/mcp/mcp.go Normal file
View File

@@ -0,0 +1,40 @@
// Package mcpcmd holds the `weknora mcp` command tree.
//
// MCP (Model Context Protocol; https://spec.modelcontextprotocol.io/) is the
// JSON-RPC 2.0 wire protocol agentic IDEs (Claude Code, Cursor, Continue,
// Zed) and runtimes (Anthropic Reference, Stripe MCP) use to call external
// tools. `weknora mcp serve` exposes a curated subset of the CLI's read
// surface as MCP tools so an IDE-side agent can list / view / search / chat
// against the user's active WeKnora context without shelling out to the CLI
// per call.
//
// Package name is `mcpcmd` to avoid shadowing `cli/internal/mcp` (the
// transport-and-handlers implementation). Same naming hygiene as
// `agentcmd` / `sessioncmd`.
package mcpcmd
import (
"github.com/spf13/cobra"
"github.com/Tencent/WeKnora/cli/internal/cmdutil"
)
// NewCmd returns the `weknora mcp` parent command with the `serve`
// subcommand attached. It accepts no positional arguments; run on its own
// it prints its help text. Called from cli/cmd/root.go.
func NewCmd(f *cmdutil.Factory) *cobra.Command {
	const longHelp = `Exposes weknora's read surface as MCP tools so agentic IDE clients
(Claude Code, Cursor, Continue, Zed) can call them over JSON-RPC.
Initial tool surface is read-only and curated: kb_list / kb_view /
doc_list / doc_view / doc_download / search_chunks / chat / agent_list /
agent_invoke. Destructive verbs (create / delete / upload) are deliberately
excluded — the agent should ask the user before mutating; the CLI's
exit-10 protocol covers that path.`

	// Run has no error return, so the result of Help is discarded.
	showHelp := func(c *cobra.Command, _ []string) {
		_ = c.Help()
	}

	parent := &cobra.Command{
		Use:   "mcp",
		Short: "Run weknora as a Model Context Protocol server",
		Long:  longHelp,
		Args:  cobra.NoArgs,
		Run:   showHelp,
	}
	parent.AddCommand(NewCmdServe(f))
	return parent
}

47
cli/cmd/mcp/serve.go Normal file
View File

@@ -0,0 +1,47 @@
package mcpcmd
import (
"github.com/spf13/cobra"
"github.com/Tencent/WeKnora/cli/internal/agent"
"github.com/Tencent/WeKnora/cli/internal/cmdutil"
mcpserver "github.com/Tencent/WeKnora/cli/internal/mcp"
)
// NewCmdServe returns the `weknora mcp serve` command. Its RunE resolves
// the SDK client from the factory first and only then hands control to
// mcpserver.RunStdio, so a client-construction error is returned before
// the server starts. Only the stdio transport is wired up here; HTTP
// (streamable / SSE) is roadmap 5-7.
func NewCmdServe(f *cmdutil.Factory) *cobra.Command {
	const longHelp = `Speaks JSON-RPC 2.0 on stdin/stdout to an MCP client. Logs go to
stderr; the data channel is reserved for protocol traffic.
Authentication is inherited from the active context (or --context). The
server eagerly resolves the SDK client at startup — if no context is
configured, the process exits with auth.unauthenticated before any MCP
handshake. This way an IDE-side agent sees a clear failure mode rather
than a server that handshakes successfully then errors on every tool.
To use with Claude Code, add to ~/.claude/mcp_servers.json:
{
"weknora": {
"command": "weknora",
"args": ["mcp", "serve"]
}
}`

	const agentHelp = "Long-lived stdio MCP server. Reads JSON-RPC requests from stdin, writes responses to stdout, logs to stderr. Surfaces 9 read-only tools (kb_list/kb_view/doc_list/doc_view/doc_download/search_chunks/chat/agent_list/agent_invoke); destructive verbs are intentionally excluded. Auth is inherited from the active context — to switch, exit and re-launch with --context. Long tools (chat / agent_invoke) accumulate the LLM stream server-side and return a single CallToolResult — no MCP streaming-content extension, per spec 2025-06-18 (tools.mdx)."

	// serve builds the SDK client up front so auth / configuration
	// problems surface before any MCP handshake takes place.
	serve := func(c *cobra.Command, _ []string) error {
		client, err := f.Client()
		if err != nil {
			return err
		}
		return mcpserver.RunStdio(c.Context(), client)
	}

	serveCmd := &cobra.Command{
		Use:   "serve",
		Short: "Run an MCP server over stdio",
		Long:  longHelp,
		Args:  cobra.NoArgs,
		RunE:  serve,
	}
	agent.SetAgentHelp(serveCmd, agentHelp)
	return serveCmd
}

View File

@@ -17,6 +17,7 @@ import (
"github.com/Tencent/WeKnora/cli/cmd/doctor"
"github.com/Tencent/WeKnora/cli/cmd/kb"
linkcmd "github.com/Tencent/WeKnora/cli/cmd/link"
mcpcmd "github.com/Tencent/WeKnora/cli/cmd/mcp"
"github.com/Tencent/WeKnora/cli/cmd/search"
sessioncmd "github.com/Tencent/WeKnora/cli/cmd/session"
"github.com/Tencent/WeKnora/cli/internal/agent"
@@ -177,6 +178,7 @@ hybrid searches against a WeKnora server from your shell or an AI agent.`,
cmd.AddCommand(chatcmd.NewCmd(f))
cmd.AddCommand(sessioncmd.NewCmd(f))
cmd.AddCommand(agentcmd.NewCmd(f))
cmd.AddCommand(mcpcmd.NewCmd(f))
return cmd
}

View File

@@ -119,7 +119,15 @@ func runList(ctx context.Context, opts *ListOptions, jopts *cmdutil.JSONOptions,
}
if jopts.Enabled() {
meta := &format.Meta{HasMore: opts.Page*opts.PageSize < total}
// When --since is active, has_more is meaningless: the server's
// total counts all sessions but the page is now client-side
// filtered. An agent walking pages would see has_more=true even
// when no later page can contain matching items. Drop the flag
// rather than mislead. The agent_help string documents this.
meta := &format.Meta{}
if since == 0 {
meta.HasMore = opts.Page*opts.PageSize < total
}
return format.WriteEnvelopeFiltered(
iostreams.IO.Out,
format.Success(listResult{Items: items}, meta),

View File

@@ -1,8 +1,6 @@
module github.com/Tencent/WeKnora/cli
go 1.24.2
toolchain go1.24.4
go 1.25.0
require (
github.com/Tencent/WeKnora/client v0.0.0-00010101000000-000000000000
@@ -10,6 +8,7 @@ require (
github.com/itchyny/gojq v0.12.19
github.com/mattn/go-isatty v0.0.22
github.com/mattn/go-runewidth v0.0.23
github.com/modelcontextprotocol/go-sdk v1.6.0
github.com/spf13/cobra v1.10.2
github.com/spf13/pflag v1.0.10
github.com/stretchr/testify v1.11.1
@@ -36,6 +35,7 @@ require (
github.com/dustin/go-humanize v1.0.1 // indirect
github.com/erikgeiser/coninput v0.0.0-20211004153227-1c3628e74d0f // indirect
github.com/godbus/dbus/v5 v5.2.2 // indirect
github.com/google/jsonschema-go v0.4.3 // indirect
github.com/inconshreveable/mousetrap v1.1.0 // indirect
github.com/itchyny/timefmt-go v0.1.8 // indirect
github.com/lucasb-eyer/go-colorful v1.2.0 // indirect
@@ -46,9 +46,13 @@ require (
github.com/muesli/termenv v0.16.0 // indirect
github.com/pmezard/go-difflib v1.0.0 // indirect
github.com/rivo/uniseg v0.4.7 // indirect
github.com/segmentio/asm v1.1.3 // indirect
github.com/segmentio/encoding v0.5.4 // indirect
github.com/xo/terminfo v0.0.0-20220910002029-abceb7e1c41e // indirect
github.com/yosida95/uritemplate/v3 v3.0.2 // indirect
golang.org/x/oauth2 v0.35.0 // indirect
golang.org/x/sync v0.15.0 // indirect
golang.org/x/sys v0.38.0 // indirect
golang.org/x/sys v0.41.0 // indirect
golang.org/x/text v0.23.0 // indirect
)

View File

@@ -53,6 +53,12 @@ github.com/erikgeiser/coninput v0.0.0-20211004153227-1c3628e74d0f h1:Y/CXytFA4m6
github.com/erikgeiser/coninput v0.0.0-20211004153227-1c3628e74d0f/go.mod h1:vw97MGsxSvLiUE2X8qFplwetxpGLQrlU1Q9AUEIzCaM=
github.com/godbus/dbus/v5 v5.2.2 h1:TUR3TgtSVDmjiXOgAAyaZbYmIeP3DPkld3jgKGV8mXQ=
github.com/godbus/dbus/v5 v5.2.2/go.mod h1:3AAv2+hPq5rdnr5txxxRwiGjPXamgoIHgz9FPBfOp3c=
github.com/golang-jwt/jwt/v5 v5.3.1 h1:kYf81DTWFe7t+1VvL7eS+jKFVWaUnK9cB1qbwn63YCY=
github.com/golang-jwt/jwt/v5 v5.3.1/go.mod h1:fxCRLWMO43lRc8nhHWY6LGqRcf+1gQWArsqaEUEa5bE=
github.com/google/go-cmp v0.7.0 h1:wk8382ETsv4JYUZwIsn6YpYiWiBsYLSJiTsyBybVuN8=
github.com/google/go-cmp v0.7.0/go.mod h1:pXiqmnSA92OHEEa9HXL2W4E7lf9JzCmGVUdgjX3N/iU=
github.com/google/jsonschema-go v0.4.3 h1:/DBOLZTfDow7pe2GmaJNhltueGTtDKICi8V8p+DQPd0=
github.com/google/jsonschema-go v0.4.3/go.mod h1:r5quNTdLOYEz95Ru18zA0ydNbBuYoo9tgaYcxEYhJVE=
github.com/inconshreveable/mousetrap v1.1.0 h1:wN+x4NVGpMsO7ErUn/mUI3vEoE6Jt13X2s0bqwp9tc8=
github.com/inconshreveable/mousetrap v1.1.0/go.mod h1:vpF70FUmC8bwa3OWnCshd2FqLfsEA9PFc4w1p2J65bw=
github.com/itchyny/gojq v0.12.19 h1:ttXA0XCLEMoaLOz5lSeFOZ6u6Q3QxmG46vfgI4O0DEs=
@@ -69,6 +75,8 @@ github.com/mattn/go-runewidth v0.0.23 h1:7ykA0T0jkPpzSvMS5i9uoNn2Xy3R383f9HDx3Ry
github.com/mattn/go-runewidth v0.0.23/go.mod h1:XBkDxAl56ILZc9knddidhrOlY5R/pDhgLpndooCuJAs=
github.com/mitchellh/hashstructure/v2 v2.0.2 h1:vGKWl0YJqUNxE8d+h8f6NJLcCJrgbhC4NcD46KavDd4=
github.com/mitchellh/hashstructure/v2 v2.0.2/go.mod h1:MG3aRVU/N29oo/V/IhBX8GR/zz4kQkprJgF2EVszyDE=
github.com/modelcontextprotocol/go-sdk v1.6.0 h1:PPLS3kn7WtOEnR+Af4X5H96SG0qSab8R/ZQT/HkhPkY=
github.com/modelcontextprotocol/go-sdk v1.6.0/go.mod h1:kzm3kzFL1/+AziGOE0nUs3gvPoNxMCvkxokMkuFapXQ=
github.com/muesli/ansi v0.0.0-20230316100256-276c6243b2f6 h1:ZK8zHtRHOkbHy6Mmr5D264iyp3TiX5OmNcI5cIARiQI=
github.com/muesli/ansi v0.0.0-20230316100256-276c6243b2f6/go.mod h1:CJlz5H+gyd6CUWT45Oy4q24RdLyn7Md9Vj2/ldJBSIo=
github.com/muesli/cancelreader v0.2.2 h1:3I4Kt4BQjOR54NavqnDogx/MIoWBFa0StPA8ELUXHmA=
@@ -80,6 +88,10 @@ github.com/pmezard/go-difflib v1.0.0/go.mod h1:iKH77koFhYxTK1pcRnkKkqfTogsbg7gZN
github.com/rivo/uniseg v0.4.7 h1:WUdvkW8uEhrYfLC4ZzdpI2ztxP1I582+49Oc5Mq64VQ=
github.com/rivo/uniseg v0.4.7/go.mod h1:FN3SvrM+Zdj16jyLfmOkMNblXMcoc8DfTHruCPUcx88=
github.com/russross/blackfriday/v2 v2.1.0/go.mod h1:+Rmxgy9KzJVeS9/2gXHxylqXiyQDYRxCVz55jmeOWTM=
github.com/segmentio/asm v1.1.3 h1:WM03sfUOENvvKexOLp+pCqgb/WDjsi7EK8gIsICtzhc=
github.com/segmentio/asm v1.1.3/go.mod h1:Ld3L4ZXGNcSLRg4JBsZ3//1+f/TjYl0Mzen/DQy1EJg=
github.com/segmentio/encoding v0.5.4 h1:OW1VRern8Nw6ITAtwSZ7Idrl3MXCFwXHPgqESYfvNt0=
github.com/segmentio/encoding v0.5.4/go.mod h1:HS1ZKa3kSN32ZHVZ7ZLPLXWvOVIiZtyJnO1gPH1sKt0=
github.com/spf13/cobra v1.10.2 h1:DMTTonx5m65Ic0GOoRY2c16WCbHxOOw6xxezuLaBpcU=
github.com/spf13/cobra v1.10.2/go.mod h1:7C1pvHqHw5A4vrJfjNwvOdzYu0Gml16OCs2GRiTUUS4=
github.com/spf13/pflag v1.0.9/go.mod h1:McXfInJRrz4CZXVZOBLb0bTZqETkiAhM9Iw0y3An2Bg=
@@ -91,18 +103,24 @@ github.com/stretchr/testify v1.11.1 h1:7s2iGBzp5EwR7/aIZr8ao5+dra3wiQyKjjFuvgVKu
github.com/stretchr/testify v1.11.1/go.mod h1:wZwfW3scLgRK+23gO65QZefKpKQRnfz6sD981Nm4B6U=
github.com/xo/terminfo v0.0.0-20220910002029-abceb7e1c41e h1:JVG44RsyaB9T2KIHavMF/ppJZNG9ZpyihvCd0w101no=
github.com/xo/terminfo v0.0.0-20220910002029-abceb7e1c41e/go.mod h1:RbqR21r5mrJuqunuUZ/Dhy/avygyECGrLceyNeo4LiM=
github.com/yosida95/uritemplate/v3 v3.0.2 h1:Ed3Oyj9yrmi9087+NczuL5BwkIc4wvTb5zIM+UJPGz4=
github.com/yosida95/uritemplate/v3 v3.0.2/go.mod h1:ILOh0sOhIJR3+L/8afwt/kE++YT040gmv5BQTMR2HP4=
github.com/zalando/go-keyring v0.2.8 h1:6sD/Ucpl7jNq10rM2pgqTs0sZ9V3qMrqfIIy5YPccHs=
github.com/zalando/go-keyring v0.2.8/go.mod h1:tsMo+VpRq5NGyKfxoBVjCuMrG47yj8cmakZDO5QGii0=
go.yaml.in/yaml/v3 v3.0.4/go.mod h1:DhzuOOF2ATzADvBadXxruRBLzYTpT36CKvDb3+aBEFg=
golang.org/x/exp v0.0.0-20231006140011-7918f672742d h1:jtJma62tbqLibJ5sFQz8bKtEM8rJBtfilJ2qTU199MI=
golang.org/x/exp v0.0.0-20231006140011-7918f672742d/go.mod h1:ldy0pHrwJyGW56pPQzzkH36rKxoZW1tw7ZJpeKx+hdo=
golang.org/x/oauth2 v0.35.0 h1:Mv2mzuHuZuY2+bkyWXIHMfhNdJAdwW3FuWeCPYN5GVQ=
golang.org/x/oauth2 v0.35.0/go.mod h1:lzm5WQJQwKZ3nwavOZ3IS5Aulzxi68dUSgRHujetwEA=
golang.org/x/sync v0.15.0 h1:KWH3jNZsfyT6xfAfKiz6MRNmd46ByHDYaZ7KSkCtdW8=
golang.org/x/sync v0.15.0/go.mod h1:1dzgHSNfp02xaA81J2MS99Qcpr2w7fw1gpm99rleRqA=
golang.org/x/sys v0.0.0-20210809222454-d867a43fc93e/go.mod h1:oPkhp1MJrh7nUepCBck5+mAzfO9JrbApNNgaTdGDITg=
golang.org/x/sys v0.38.0 h1:3yZWxaJjBmCWXqhN1qh02AkOnCQ1poK6oF+a7xWL6Gc=
golang.org/x/sys v0.38.0/go.mod h1:OgkHotnGiDImocRcuBABYBEXf8A9a87e/uXjp9XT3ks=
golang.org/x/sys v0.41.0 h1:Ivj+2Cp/ylzLiEU89QhWblYnOE9zerudt9Ftecq2C6k=
golang.org/x/sys v0.41.0/go.mod h1:OgkHotnGiDImocRcuBABYBEXf8A9a87e/uXjp9XT3ks=
golang.org/x/text v0.23.0 h1:D71I7dUrlY+VX0gQShAThNGHFxZ13dGLBHQLVl1mJlY=
golang.org/x/text v0.23.0/go.mod h1:/BLNzu4aZCJ1+kcD0DNRotWKage4q2rGVAg4o22unh4=
golang.org/x/tools v0.42.0 h1:uNgphsn75Tdz5Ji2q36v/nsFSfR/9BRFvqhGBaJGd5k=
golang.org/x/tools v0.42.0/go.mod h1:Ma6lCIwGZvHK6XtgbswSoWroEkhugApmsXyrUmBhfr0=
gopkg.in/check.v1 v0.0.0-20161208181325-20d25e280405 h1:yhCVgyC4o1eVCa2tZl7eS0r+SDo693bJlVdllGtEeKM=
gopkg.in/check.v1 v0.0.0-20161208181325-20d25e280405/go.mod h1:Co6ibVJAznAaIkqp8huTwlJQCZ016jof/cbN4VW5Yz0=
gopkg.in/yaml.v3 v3.0.1 h1:fxVm/GzAzEWqLHuvctI91KS9hhNmmWOoWu0XTYJS7CA=

View File

@@ -0,0 +1,37 @@
package format
import (
"fmt"
"io"
sdk "github.com/Tencent/WeKnora/client"
)
// WriteReferences prints the references footer shared by chat and agent
// invoke: a blank line, a rule, then one numbered line per reference.
//
// Each line shows the first non-empty of KnowledgeTitle, KnowledgeFilename
// and KnowledgeID, followed by the score when it is positive. Nil entries
// are skipped but still consume their index, so the printed numbers track
// positions in refs. Nothing at all is written when refs is empty —
// agent-friendly silence beats an empty banner.
func WriteReferences(w io.Writer, refs []*sdk.SearchResult) {
	if len(refs) == 0 {
		return
	}

	fmt.Fprintln(w)
	fmt.Fprintln(w, "──── References ────")

	for idx, ref := range refs {
		if ref == nil {
			continue
		}

		// Fall back from title to filename to ID.
		var label string
		switch {
		case ref.KnowledgeTitle != "":
			label = ref.KnowledgeTitle
		case ref.KnowledgeFilename != "":
			label = ref.KnowledgeFilename
		default:
			label = ref.KnowledgeID
		}

		fmt.Fprintf(w, "[%d] %s", idx+1, label)
		if hasScore := ref.Score > 0; hasScore {
			fmt.Fprintf(w, " score=%.3f", ref.Score)
		}
		fmt.Fprintln(w)
	}
}

View File

@@ -0,0 +1,59 @@
// Package mcp wires the curated weknora tool set to a
// modelcontextprotocol/go-sdk server. RunStdio is the entry point invoked
// by `weknora mcp serve`.
//
// Design notes:
//
// - Tool surface is hand-curated rather than auto-derived from the cobra
// tree (which would expose auth/link/completion/destructive verbs that
// don't belong on an agent-callable surface).
// - Long-running tools (chat / agent_invoke) accumulate the LLM SSE
// stream server-side and return a single CallToolResult — MCP spec
// 2025-06-18 does not define streamed tool-result content, so this is
// the canonical pattern (see Anthropic reference `everything` server
// and Stripe @stripe/mcp).
// - Handlers receive ctx for cancellation; mid-LLM-stream cancellation
// propagates to the SDK via context, which closes the SSE connection.
package mcp
import (
"context"
"fmt"
mcpsdk "github.com/modelcontextprotocol/go-sdk/mcp"
"github.com/Tencent/WeKnora/cli/internal/build"
)
// ServiceClient bundles the SDK methods the tool registry needs. *sdk.Client
// satisfies it; tests substitute a fake to exercise the tool handlers
// in-process without standing up a real WeKnora server.
//
// It is the union of the four unexported per-domain interfaces embedded
// below and is the type accepted by RunStdio and registerTools.
//
// Embedding the full SDK Client would couple every tool test to every SDK
// method; declaring the narrow surface here keeps the seam tight.
type ServiceClient interface {
// Knowledge-base listing and lookup.
knowledgeBaseService
// Document listing, lookup, file download and hybrid search.
knowledgeService
// Session creation and knowledge-QA streaming.
chatService
// Agent listing, lookup and agent-QA streaming.
agentService
}
// RunStdio builds an MCP server named "weknora", registers the curated
// tool set against svc, and runs it over the stdio transport with ctx.
//
// The advertised server version is the first value returned by
// build.Info; the remaining two values are not used. A non-nil error from
// the server run is returned wrapped with an "mcp serve" prefix so the
// cobra RunE caller can map it through the usual cmdutil exit-code path;
// a clean run returns nil.
func RunStdio(ctx context.Context, svc ServiceClient) error {
	version, _, _ := build.Info()

	impl := &mcpsdk.Implementation{
		Name:    "weknora",
		Version: version,
	}
	srv := mcpsdk.NewServer(impl, nil)
	registerTools(srv, svc)

	runErr := srv.Run(ctx, &mcpsdk.StdioTransport{})
	if runErr == nil {
		return nil
	}
	return fmt.Errorf("mcp serve: %w", runErr)
}

463
cli/internal/mcp/tools.go Normal file
View File

@@ -0,0 +1,463 @@
package mcp
import (
"context"
"encoding/base64"
"fmt"
"io"
"strings"
mcpsdk "github.com/modelcontextprotocol/go-sdk/mcp"
"github.com/Tencent/WeKnora/cli/internal/sse"
sdk "github.com/Tencent/WeKnora/client"
)
// Narrow per-domain service interfaces. ServiceClient (server.go) embeds
// them all; *sdk.Client satisfies the union implicitly.

// knowledgeBaseService is the knowledge-base read surface consumed by the
// kb_list and kb_view tools.
type knowledgeBaseService interface {
ListKnowledgeBases(ctx context.Context) ([]sdk.KnowledgeBase, error)
GetKnowledgeBase(ctx context.Context, id string) (*sdk.KnowledgeBase, error)
}
// knowledgeService is the document surface consumed by the doc_list,
// doc_view and doc_download tools. HybridSearch is presumably the backend
// of search_chunks — its handler is not shown here; confirm there.
type knowledgeService interface {
ListKnowledgeWithFilter(ctx context.Context, kbID string, page, pageSize int, filter sdk.KnowledgeListFilter) ([]sdk.Knowledge, int64, error)
GetKnowledge(ctx context.Context, knowledgeID string) (*sdk.Knowledge, error)
// OpenKnowledgeFile returns a string (used as the file name by
// doc_download) and a body the caller must close.
OpenKnowledgeFile(ctx context.Context, knowledgeID string) (string, io.ReadCloser, error)
HybridSearch(ctx context.Context, kbID string, params *sdk.SearchParams) ([]*sdk.SearchResult, error)
}
// chatService covers session creation and the streaming knowledge-QA call
// (presumably the backend of the chat tool — confirm against addChat).
type chatService interface {
CreateSession(ctx context.Context, req *sdk.CreateSessionRequest) (*sdk.Session, error)
KnowledgeQAStream(ctx context.Context, sessionID string, req *sdk.KnowledgeQARequest, cb func(*sdk.StreamResponse) error) error
}
// agentService covers agent listing, lookup and the streaming agent-QA
// call (presumably the backend of agent_list / agent_invoke — confirm
// against their handlers).
type agentService interface {
ListAgents(ctx context.Context) ([]sdk.Agent, error)
GetAgent(ctx context.Context, agentID string) (*sdk.Agent, error)
AgentQAStreamWithRequest(ctx context.Context, sessionID string, req *sdk.AgentQARequest, cb sdk.AgentEventCallback) error
}
// agentInvokeService composes the two SDK methods agent_invoke needs
// (CreateSession for the auto-session path + AgentQAStreamWithRequest
// for the run itself). Declared here alongside the per-domain
// interfaces above so ServiceClient (server.go) — which embeds the
// four domain interfaces — also satisfies it.
//
// Both method signatures are identical to the ones declared on
// chatService (CreateSession) and agentService (AgentQAStreamWithRequest).
type agentInvokeService interface {
CreateSession(ctx context.Context, req *sdk.CreateSessionRequest) (*sdk.Session, error)
AgentQAStreamWithRequest(ctx context.Context, sessionID string, req *sdk.AgentQARequest, cb sdk.AgentEventCallback) error
}
// registerTools wires the curated 9 tools onto server, passing the same
// svc to every per-tool add function.
//
// Adding a tool here is a deliberate API expansion: the agent-callable
// surface is hand-picked rather than mirrored from the CLI command list,
// so this list must be maintained by hand (see also AGENTS.md mcp serve
// section).
func registerTools(server *mcpsdk.Server, svc ServiceClient) {
// Knowledge bases.
addKBList(server, svc)
addKBView(server, svc)
// Documents.
addDocList(server, svc)
addDocView(server, svc)
addDocDownload(server, svc)
// Retrieval and chat.
addSearchChunks(server, svc)
addChat(server, svc)
// Agents.
addAgentList(server, svc)
addAgentInvoke(server, svc)
}
// ---- kb_list -------------------------------------------------------------
// kbListInput is the argument payload of the kb_list tool; it is empty
// because the tool takes no arguments.
type kbListInput struct{}
// kbListOutput is the structured result of the kb_list tool.
type kbListOutput struct {
// Items holds the knowledge bases returned by ListKnowledgeBases; the
// handler substitutes an empty slice for nil, so it is never nil.
Items []sdk.KnowledgeBase `json:"items"`
}
// addKBList registers the kb_list tool on server. The handler ignores its
// (empty) input, calls svc.ListKnowledgeBases and returns the result under
// items; an SDK failure is wrapped with a "list knowledge bases" prefix.
func addKBList(server *mcpsdk.Server, svc knowledgeBaseService) {
	tool := &mcpsdk.Tool{
		Name:        "kb_list",
		Description: "List all knowledge bases visible to the active WeKnora tenant. No arguments. Returns items[]: each item carries id, name, description, knowledge_count, is_pinned, updated_at — useful for selecting a kb_id to pass to other tools.",
	}

	handler := func(ctx context.Context, _ *mcpsdk.CallToolRequest, _ kbListInput) (*mcpsdk.CallToolResult, kbListOutput, error) {
		kbs, err := svc.ListKnowledgeBases(ctx)
		switch {
		case err != nil:
			return nil, kbListOutput{}, fmt.Errorf("list knowledge bases: %w", err)
		case kbs == nil:
			// Hand back an empty slice rather than nil.
			kbs = []sdk.KnowledgeBase{}
		}
		return nil, kbListOutput{Items: kbs}, nil
	}

	mcpsdk.AddTool(server, tool, handler)
}
// ---- kb_view -------------------------------------------------------------
// kbViewInput is the argument payload of the kb_view tool.
type kbViewInput struct {
// KBID must be non-empty; the handler rejects an empty value with
// "kb_id is required".
KBID string `json:"kb_id" jsonschema:"knowledge base ID"`
}
// addKBView registers the kb_view tool on server. The handler requires a
// non-empty kb_id, looks the record up via svc.GetKnowledgeBase and
// returns it as the structured output; an SDK failure is wrapped with a
// "get knowledge base" prefix.
func addKBView(server *mcpsdk.Server, svc knowledgeBaseService) {
	tool := &mcpsdk.Tool{
		Name:        "kb_view",
		Description: "Fetch a knowledge base by ID. Returns the full record including chunking config, embedding/summary model IDs, knowledge_count, and chunk_count.",
	}

	handler := func(ctx context.Context, _ *mcpsdk.CallToolRequest, in kbViewInput) (*mcpsdk.CallToolResult, *sdk.KnowledgeBase, error) {
		id := in.KBID
		if id == "" {
			return nil, nil, fmt.Errorf("kb_id is required")
		}

		record, err := svc.GetKnowledgeBase(ctx, id)
		if err != nil {
			return nil, nil, fmt.Errorf("get knowledge base: %w", err)
		}
		return nil, record, nil
	}

	mcpsdk.AddTool(server, tool, handler)
}
// ---- doc_list ------------------------------------------------------------
// docListInput is the argument payload of the doc_list tool.
type docListInput struct {
// KBID must be non-empty; the handler rejects an empty value.
KBID string `json:"kb_id" jsonschema:"knowledge base ID"`
// Page values below 1 (including the zero value) are replaced with 1.
Page int `json:"page,omitempty" jsonschema:"1-indexed page number; defaults to 1"`
// PageSize values below 1 are replaced with 20; values above 1000 are
// rejected by the handler.
PageSize int `json:"page_size,omitempty" jsonschema:"items per page (1..1000); defaults to 20"`
// Status is forwarded as the ParseStatus field of the SDK list filter.
Status string `json:"status,omitempty" jsonschema:"filter by parse status: pending | processing | completed | failed"`
}
// docListOutput is the structured result of the doc_list tool.
type docListOutput struct {
// Items is never nil; the handler substitutes an empty slice.
Items []sdk.Knowledge `json:"items"`
// Page and PageSize echo the effective values after defaulting.
Page int `json:"page"`
PageSize int `json:"page_size"`
// Total is the count returned by ListKnowledgeWithFilter.
Total int64 `json:"total"`
}
// addDocList registers the doc_list tool on server. The handler requires
// kb_id, normalizes paging (page below 1 becomes 1, page_size below 1
// becomes 20, page_size above 1000 is an error), forwards status as the
// parse-status filter to svc.ListKnowledgeWithFilter, and returns the
// items together with the effective page, page size and total.
func addDocList(server *mcpsdk.Server, svc knowledgeService) {
	tool := &mcpsdk.Tool{
		Name:        "doc_list",
		Description: "List documents in a knowledge base, with pagination and optional parse-status filter. Returns items[] with id, file_name, title, parse_status, size, updated_at — plus the page/total metadata.",
	}

	handler := func(ctx context.Context, _ *mcpsdk.CallToolRequest, in docListInput) (*mcpsdk.CallToolResult, docListOutput, error) {
		var none docListOutput
		if in.KBID == "" {
			return nil, none, fmt.Errorf("kb_id is required")
		}

		pageNum := max(in.Page, 1)
		perPage := in.PageSize
		switch {
		case perPage < 1:
			perPage = 20
		case perPage > 1000:
			return nil, none, fmt.Errorf("page_size must be in 1..1000")
		}

		filter := sdk.KnowledgeListFilter{ParseStatus: in.Status}
		docs, total, err := svc.ListKnowledgeWithFilter(ctx, in.KBID, pageNum, perPage, filter)
		if err != nil {
			return nil, none, fmt.Errorf("list documents: %w", err)
		}
		if docs == nil {
			// Hand back an empty slice rather than nil.
			docs = []sdk.Knowledge{}
		}

		result := docListOutput{
			Items:    docs,
			Page:     pageNum,
			PageSize: perPage,
			Total:    total,
		}
		return nil, result, nil
	}

	mcpsdk.AddTool(server, tool, handler)
}
// ---- doc_view ------------------------------------------------------------
// docViewInput is the argument payload of the doc_view tool.
type docViewInput struct {
// KnowledgeID must be non-empty; the handler rejects an empty value
// with "knowledge_id is required".
KnowledgeID string `json:"knowledge_id" jsonschema:"document (knowledge entry) ID"`
}
// addDocView registers the doc_view tool on server. The handler requires a
// non-empty knowledge_id, fetches the record via svc.GetKnowledge and
// returns it as the structured output; an SDK failure is wrapped with a
// "get knowledge" prefix.
func addDocView(server *mcpsdk.Server, svc knowledgeService) {
	tool := &mcpsdk.Tool{
		Name:        "doc_view",
		Description: "Fetch a single document by ID. Returns the Knowledge record (file_name, title, type, parse_status, size, embedding_model_id, source URL if any, etc.).",
	}

	handler := func(ctx context.Context, _ *mcpsdk.CallToolRequest, in docViewInput) (*mcpsdk.CallToolResult, *sdk.Knowledge, error) {
		id := in.KnowledgeID
		if id == "" {
			return nil, nil, fmt.Errorf("knowledge_id is required")
		}

		doc, err := svc.GetKnowledge(ctx, id)
		if err != nil {
			return nil, nil, fmt.Errorf("get knowledge: %w", err)
		}
		return nil, doc, nil
	}

	mcpsdk.AddTool(server, tool, handler)
}
// ---- doc_download --------------------------------------------------------
// docDownloadInput is the argument payload of the doc_download tool.
type docDownloadInput struct {
// KnowledgeID must be non-empty; the handler rejects an empty value.
KnowledgeID string `json:"knowledge_id" jsonschema:"document (knowledge entry) ID"`
}
// docDownloadOutput is the structured result of the doc_download tool.
type docDownloadOutput struct {
// KnowledgeID echoes the requested ID.
KnowledgeID string `json:"knowledge_id"`
// FileName is the name returned by OpenKnowledgeFile.
FileName string `json:"file_name"`
// Bytes is the number of raw bytes read from the file, before any
// base64 encoding.
Bytes int `json:"bytes"`
// Content is the file contents as produced by encodeDownload: plain
// text, or base64 when IsBase64 is true. Per the tool description the
// binary check is a heuristic (a NUL byte in the first 512 bytes), not
// something the SDK reports. For binary, agents should decode before
// consuming.
Content string `json:"content"`
// IsBase64 reports whether Content is base64-encoded.
IsBase64 bool `json:"is_base64"`
}
// maxDocDownloadBytes caps the per-call payload to keep an agent's context
// window safe; agents needing larger documents should chunk via doc_view +
// search_chunks. 1 MiB matches a typical LLM context-window budget for
// inline content (~250k tokens) while remaining cheap to serialize.
// The handler reads at most one byte beyond this cap and rejects the call
// if the file turns out to be larger.
const maxDocDownloadBytes = 1 << 20
func addDocDownload(server *mcpsdk.Server, svc knowledgeService) {
mcpsdk.AddTool(server, &mcpsdk.Tool{
Name: "doc_download",
Description: "Download a document's raw bytes by ID. Capped at 1 MiB per call — for larger documents, use search_chunks to find the relevant excerpts. is_base64 reports whether content was base64-encoded (heuristic: presence of NUL byte in the first 512 bytes).",
}, func(ctx context.Context, _ *mcpsdk.CallToolRequest, in docDownloadInput) (*mcpsdk.CallToolResult, docDownloadOutput, error) {
if in.KnowledgeID == "" {
return nil, docDownloadOutput{}, fmt.Errorf("knowledge_id is required")
}
name, body, err := svc.OpenKnowledgeFile(ctx, in.KnowledgeID)
if err != nil {
return nil, docDownloadOutput{}, fmt.Errorf("open knowledge file: %w", err)
}
defer body.Close()
buf, err := io.ReadAll(io.LimitReader(body, maxDocDownloadBytes+1))
if err != nil {
return nil, docDownloadOutput{}, fmt.Errorf("read knowledge file: %w", err)
}
if len(buf) > maxDocDownloadBytes {
return nil, docDownloadOutput{}, fmt.Errorf("document exceeds the %d-byte per-call cap; use search_chunks for excerpts", maxDocDownloadBytes)
}
content, isBase64 := encodeDownload(buf)
return nil, docDownloadOutput{
KnowledgeID: in.KnowledgeID,
FileName: name,
Bytes: len(buf),
Content: content,
IsBase64: isBase64,
}, nil
})
}
// ---- search_chunks -------------------------------------------------------

// searchChunksInput is the argument payload accepted by the search_chunks
// tool. Limit is applied by this process after the search returns.
type searchChunksInput struct {
	KBID             string  `json:"kb_id" jsonschema:"knowledge base ID to search"`
	Query            string  `json:"query" jsonschema:"natural-language search query"`
	Limit            int     `json:"limit,omitempty" jsonschema:"client-side cap on results (1..1000); defaults to 10"`
	VectorThreshold  float64 `json:"vector_threshold,omitempty" jsonschema:"minimum vector similarity (0..1)"`
	KeywordThreshold float64 `json:"keyword_threshold,omitempty" jsonschema:"minimum keyword score (0..1)"`
}

// searchChunksOutput is the structured result returned by search_chunks.
type searchChunksOutput struct {
	Results []*sdk.SearchResult `json:"results"`
}

// addSearchChunks registers the search_chunks tool on server. The handler
// validates kb_id / query / limit, runs svc.HybridSearch, truncates the hits
// to the requested limit, and always returns a non-nil results slice.
func addSearchChunks(server *mcpsdk.Server, svc knowledgeService) {
	// The handler's output type is declared as `any` on purpose. With a
	// concrete searchChunksOutput the SDK would derive an output schema
	// from it, and per the original author's note the nilable
	// map[string]any in SearchResult.Metadata then trips the generated
	// type=object validation when empty. Returning through `any` skips
	// schema derivation while the JSON shape stays the same. chat and
	// agent_invoke use the same approach.
	tool := &mcpsdk.Tool{
		Name:        "search_chunks",
		Description: "Hybrid (vector + keyword) retrieval against a knowledge base. Returns the top chunks ranked by RRF; use this before chat to ground an answer in cited context. Results include knowledge_id, content, score — feed back into chat as context or display directly.",
	}
	handler := func(ctx context.Context, _ *mcpsdk.CallToolRequest, args searchChunksInput) (*mcpsdk.CallToolResult, any, error) {
		if args.KBID == "" {
			return nil, nil, fmt.Errorf("kb_id is required")
		}
		if strings.TrimSpace(args.Query) == "" {
			return nil, nil, fmt.Errorf("query cannot be empty")
		}

		// Anything below 1 (including an omitted limit) means "use the
		// default"; only the upper bound is rejected.
		maxHits := args.Limit
		switch {
		case maxHits < 1:
			maxHits = 10
		case maxHits > 1000:
			return nil, nil, fmt.Errorf("limit must be in 1..1000")
		}

		params := &sdk.SearchParams{
			QueryText:        args.Query,
			VectorThreshold:  args.VectorThreshold,
			KeywordThreshold: args.KeywordThreshold,
		}
		hits, err := svc.HybridSearch(ctx, args.KBID, params)
		if err != nil {
			return nil, nil, fmt.Errorf("hybrid search: %w", err)
		}

		// Normalise nil to an empty slice so the JSON field is [] rather
		// than null, then apply the client-side cap.
		if hits == nil {
			hits = []*sdk.SearchResult{}
		}
		if len(hits) > maxHits {
			hits = hits[:maxHits]
		}
		return nil, searchChunksOutput{Results: hits}, nil
	}
	mcpsdk.AddTool(server, tool, handler)
}
// ---- chat ----------------------------------------------------------------

// chatInput is the argument payload accepted by the chat tool.
type chatInput struct {
	KBID      string `json:"kb_id" jsonschema:"knowledge base ID to chat against"`
	Query     string `json:"query" jsonschema:"user query"`
	SessionID string `json:"session_id,omitempty" jsonschema:"existing session to continue; auto-created when empty"`
}

// chatOutput is the structured result returned by chat once the stream has
// been fully accumulated.
type chatOutput struct {
	Answer             string              `json:"answer"`
	References         []*sdk.SearchResult `json:"references"`
	SessionID          string              `json:"session_id"`
	AssistantMessageID string              `json:"assistant_message_id,omitempty"`
}

// addChat registers the chat tool on server. The handler validates the
// input, creates a session when none was supplied, drains
// svc.KnowledgeQAStream into an sse.Accumulator, and returns the collected
// answer. A stream that ends without a terminal event is reported as an
// error rather than returned as a partial answer.
func addChat(server *mcpsdk.Server, svc chatService) {
	tool := &mcpsdk.Tool{
		Name:        "chat",
		Description: "Stream a RAG answer from the LLM, grounded in the given knowledge base. The SSE stream is accumulated server-side; this tool returns the full answer + references + session_id once the stream completes. Pass session_id to continue a multi-turn conversation; otherwise a fresh session is auto-created.",
	}
	handler := func(ctx context.Context, _ *mcpsdk.CallToolRequest, args chatInput) (*mcpsdk.CallToolResult, any, error) {
		if args.KBID == "" {
			return nil, nil, fmt.Errorf("kb_id is required")
		}
		if strings.TrimSpace(args.Query) == "" {
			return nil, nil, fmt.Errorf("query cannot be empty")
		}

		requestedSession := args.SessionID
		if requestedSession == "" {
			created, err := svc.CreateSession(ctx, &sdk.CreateSessionRequest{Title: "weknora mcp chat"})
			if err != nil {
				return nil, nil, fmt.Errorf("create chat session: %w", err)
			}
			requestedSession = created.ID
		}

		qa := &sdk.KnowledgeQARequest{
			Query:            args.Query,
			KnowledgeBaseIDs: []string{args.KBID},
			AgentEnabled:     false,
			Channel:          "api",
		}
		collected := new(sse.Accumulator)
		onEvent := func(ev *sdk.StreamResponse) error {
			collected.Append(ev)
			return nil
		}
		if err := svc.KnowledgeQAStream(ctx, requestedSession, qa, onEvent); err != nil {
			return nil, nil, fmt.Errorf("knowledge qa stream: %w", err)
		}
		if !collected.Done() {
			return nil, nil, fmt.Errorf("stream ended without a terminal event")
		}

		// Prefer the session ID seen on the stream; fall back to the one
		// the request was sent with when the stream did not carry any.
		reportedSession := requestedSession
		if collected.SessionID != "" {
			reportedSession = collected.SessionID
		}
		out := chatOutput{
			Answer:             collected.Result(),
			References:         collected.References,
			SessionID:          reportedSession,
			AssistantMessageID: collected.AssistantMessageID,
		}
		return nil, out, nil
	}
	mcpsdk.AddTool(server, tool, handler)
}
// ---- agent_list ----------------------------------------------------------

// agentListInput is the (empty) argument payload of the agent_list tool.
type agentListInput struct{}

// agentListOutput is the structured result returned by agent_list.
type agentListOutput struct {
	Items []sdk.Agent `json:"items"`
}

// addAgentList registers the agent_list tool on server. The handler returns
// everything svc.ListAgents reports, substituting an empty slice for nil so
// the items field is never JSON null.
func addAgentList(server *mcpsdk.Server, svc agentService) {
	tool := &mcpsdk.Tool{
		Name:        "agent_list",
		Description: "List the tenant's custom agents. Returns items[] with id, name, description, is_builtin — use to discover an agent_id before agent_invoke.",
	}
	handler := func(ctx context.Context, _ *mcpsdk.CallToolRequest, _ agentListInput) (*mcpsdk.CallToolResult, agentListOutput, error) {
		agents, err := svc.ListAgents(ctx)
		switch {
		case err != nil:
			return nil, agentListOutput{}, fmt.Errorf("list agents: %w", err)
		case agents == nil:
			agents = []sdk.Agent{}
		}
		return nil, agentListOutput{Items: agents}, nil
	}
	mcpsdk.AddTool(server, tool, handler)
}
// ---- agent_invoke --------------------------------------------------------

// agentInvokeInput is the argument payload accepted by the agent_invoke tool.
type agentInvokeInput struct {
	AgentID   string `json:"agent_id" jsonschema:"custom agent ID"`
	Query     string `json:"query" jsonschema:"user query"`
	SessionID string `json:"session_id,omitempty" jsonschema:"existing session to continue; auto-created when empty"`
}

// agentInvokeOutput is the structured result returned by agent_invoke: the
// final answer plus the trace gathered while the agent ran.
type agentInvokeOutput struct {
	Answer     string               `json:"answer"`
	References []*sdk.SearchResult  `json:"references"`
	ToolEvents []sse.AgentToolEvent `json:"tool_events,omitempty"`
	Thinking   string               `json:"thinking,omitempty"`
	SessionID  string               `json:"session_id"`
	AgentID    string               `json:"agent_id"`
}

// addAgentInvoke registers the agent_invoke tool on server. The handler
// validates the input, creates a session when none was supplied, drains
// svc.AgentQAStreamWithRequest into an sse.AgentAccumulator, and returns the
// collected answer and trace. A stream that ends without a terminal frame is
// reported as an error rather than returned as a partial answer.
func addAgentInvoke(server *mcpsdk.Server, svc agentInvokeService) {
	tool := &mcpsdk.Tool{
		Name:        "agent_invoke",
		Description: "Run a query through a custom agent (system prompt + tool allow-list + KB scope). The agent's SSE stream is accumulated server-side; this tool returns the final answer plus the trace (references, tool_events, thinking).",
	}
	handler := func(ctx context.Context, _ *mcpsdk.CallToolRequest, args agentInvokeInput) (*mcpsdk.CallToolResult, any, error) {
		if args.AgentID == "" {
			return nil, nil, fmt.Errorf("agent_id is required")
		}
		if strings.TrimSpace(args.Query) == "" {
			return nil, nil, fmt.Errorf("query cannot be empty")
		}

		// The session create request carries only a title; the agent is
		// bound per request through AgentQARequest.AgentID below.
		session := args.SessionID
		if session == "" {
			created, err := svc.CreateSession(ctx, &sdk.CreateSessionRequest{Title: "weknora mcp agent_invoke"})
			if err != nil {
				return nil, nil, fmt.Errorf("create chat session: %w", err)
			}
			session = created.ID
		}

		qa := &sdk.AgentQARequest{
			Query:        args.Query,
			AgentEnabled: true,
			AgentID:      args.AgentID,
			Channel:      "api",
		}
		trace := new(sse.AgentAccumulator)
		onEvent := func(ev *sdk.AgentStreamResponse) error {
			trace.Append(ev)
			return nil
		}
		if err := svc.AgentQAStreamWithRequest(ctx, session, qa, onEvent); err != nil {
			return nil, nil, fmt.Errorf("agent-chat stream: %w", err)
		}
		if !trace.Done() {
			return nil, nil, fmt.Errorf("stream ended without a terminal event")
		}

		out := agentInvokeOutput{
			Answer:     trace.Answer(),
			References: trace.References,
			ToolEvents: trace.ToolEvents,
			Thinking:   trace.Thinking(),
			SessionID:  session,
			AgentID:    args.AgentID,
		}
		return nil, out, nil
	}
	mcpsdk.AddTool(server, tool, handler)
}
// encodeDownload returns (content, isBase64) for a downloaded payload.
//
// The payload is treated as binary, and returned base64-encoded, when either:
//   - one of its first 512 bytes is NUL (a cheap sniff that catches most
//     binary formats without scanning the whole buffer), or
//   - it is not valid UTF-8 anywhere in the buffer.
//
// The second check matters because the content ends up in a JSON string:
// encoding/json replaces invalid UTF-8 with U+FFFD, which would silently
// corrupt files whose leading bytes look like text (e.g. a PDF header, or
// text in a legacy non-UTF-8 encoding). Everything else is returned verbatim
// with isBase64 == false, so agents are spared decoding ordinary text.
func encodeDownload(buf []byte) (string, bool) {
	probe := buf
	if len(probe) > 512 {
		probe = probe[:512]
	}
	for _, b := range probe {
		if b == 0 {
			return base64.StdEncoding.EncodeToString(buf), true
		}
	}

	text := string(buf)
	// ToValidUTF8 with an empty replacement drops every invalid byte and
	// returns its input unchanged when there is none, so a shorter result
	// means the payload contains invalid UTF-8.
	if len(strings.ToValidUTF8(text, "")) != len(text) {
		return base64.StdEncoding.EncodeToString(buf), true
	}
	return text, false
}

View File

@@ -0,0 +1,407 @@
package mcp
import (
"context"
"encoding/json"
"errors"
"io"
"strings"
"testing"
"time"
mcpsdk "github.com/modelcontextprotocol/go-sdk/mcp"
sdk "github.com/Tencent/WeKnora/client"
)
// fakeSvc is a hand-rolled test double that satisfies ServiceClient. Tests
// preload the canned result fields they care about; every method hands those
// back unchanged and notes what it was called with in calls.
type fakeSvc struct {
	// Canned results, grouped by the method that returns them.
	listKBs    []sdk.KnowledgeBase
	listKBsErr error

	getKB    *sdk.KnowledgeBase
	getKBErr error

	listDocs      []sdk.Knowledge
	listDocsTotal int64
	listDocsErr   error

	getDoc    *sdk.Knowledge
	getDocErr error

	openDocName string
	openDocBody io.ReadCloser
	openDocErr  error

	hybridResults []*sdk.SearchResult
	hybridErr     error

	createSess    *sdk.Session
	createSessErr error

	kbStreamEvents []*sdk.StreamResponse
	kbStreamErr    error

	agents    []sdk.Agent
	agentsErr error

	agent    *sdk.Agent
	agentErr error

	agentEvents    []*sdk.AgentStreamResponse
	agentStreamErr error

	// calls holds what the methods observed: counters for the list calls,
	// and the most recent arguments for everything else.
	calls struct {
		listKBs       int
		kbViewID      string
		docListKBID   string
		docListFilter sdk.KnowledgeListFilter
		docViewID     string
		openDocID     string
		hybridKBID    string
		hybridParams  *sdk.SearchParams
		createSessReq *sdk.CreateSessionRequest
		kbQAReq       *sdk.KnowledgeQARequest
		kbQASess      string
		agentListN    int
		agentViewID   string
		agentReq      *sdk.AgentQARequest
		agentSess     string
	}
}

// ListKnowledgeBases counts the call and returns the canned list.
func (f *fakeSvc) ListKnowledgeBases(_ context.Context) ([]sdk.KnowledgeBase, error) {
	f.calls.listKBs += 1
	return f.listKBs, f.listKBsErr
}

// GetKnowledgeBase records the requested ID and returns the canned KB.
func (f *fakeSvc) GetKnowledgeBase(_ context.Context, kbID string) (*sdk.KnowledgeBase, error) {
	f.calls.kbViewID = kbID
	return f.getKB, f.getKBErr
}

// ListKnowledgeWithFilter records the KB ID and filter (page and size are
// ignored) and returns the canned documents and total.
func (f *fakeSvc) ListKnowledgeWithFilter(_ context.Context, kbID string, _, _ int, filter sdk.KnowledgeListFilter) ([]sdk.Knowledge, int64, error) {
	seen := &f.calls
	seen.docListKBID = kbID
	seen.docListFilter = filter
	return f.listDocs, f.listDocsTotal, f.listDocsErr
}

// GetKnowledge records the requested ID and returns the canned document.
func (f *fakeSvc) GetKnowledge(_ context.Context, docID string) (*sdk.Knowledge, error) {
	f.calls.docViewID = docID
	return f.getDoc, f.getDocErr
}

// OpenKnowledgeFile records the requested ID and returns the canned file
// name and body.
func (f *fakeSvc) OpenKnowledgeFile(_ context.Context, docID string) (string, io.ReadCloser, error) {
	f.calls.openDocID = docID
	return f.openDocName, f.openDocBody, f.openDocErr
}

// HybridSearch records the KB ID and params and returns the canned hits.
func (f *fakeSvc) HybridSearch(_ context.Context, kbID string, params *sdk.SearchParams) ([]*sdk.SearchResult, error) {
	f.calls.hybridKBID = kbID
	f.calls.hybridParams = params
	return f.hybridResults, f.hybridErr
}

// CreateSession records the request. When the test configured neither a
// session nor an error it falls back to a session with ID "sess_auto".
func (f *fakeSvc) CreateSession(_ context.Context, req *sdk.CreateSessionRequest) (*sdk.Session, error) {
	f.calls.createSessReq = req
	if f.createSess != nil || f.createSessErr != nil {
		return f.createSess, f.createSessErr
	}
	return &sdk.Session{ID: "sess_auto"}, nil
}

// KnowledgeQAStream records the session and request, then replays the
// canned events through cb in order. A callback error stops the replay and
// is returned as-is; otherwise the canned stream error is returned.
func (f *fakeSvc) KnowledgeQAStream(_ context.Context, sess string, req *sdk.KnowledgeQARequest, cb func(*sdk.StreamResponse) error) error {
	f.calls.kbQASess = sess
	f.calls.kbQAReq = req
	for i := range f.kbStreamEvents {
		if cbErr := cb(f.kbStreamEvents[i]); cbErr != nil {
			return cbErr
		}
	}
	return f.kbStreamErr
}

// ListAgents counts the call and returns the canned agents.
func (f *fakeSvc) ListAgents(_ context.Context) ([]sdk.Agent, error) {
	f.calls.agentListN += 1
	return f.agents, f.agentsErr
}

// GetAgent records the requested ID and returns the canned agent.
func (f *fakeSvc) GetAgent(_ context.Context, agentID string) (*sdk.Agent, error) {
	f.calls.agentViewID = agentID
	return f.agent, f.agentErr
}

// AgentQAStreamWithRequest records the session and request, then replays
// the canned agent events through cb in order. A callback error stops the
// replay and is returned as-is; otherwise the canned stream error is
// returned.
func (f *fakeSvc) AgentQAStreamWithRequest(_ context.Context, sess string, req *sdk.AgentQARequest, cb sdk.AgentEventCallback) error {
	f.calls.agentSess = sess
	f.calls.agentReq = req
	for i := range f.agentEvents {
		if cbErr := cb(f.agentEvents[i]); cbErr != nil {
			return cbErr
		}
	}
	return f.agentStreamErr
}
// newTestServer builds an MCP server with the tools registered against svc,
// connects it to a client over an in-memory transport pair, and returns the
// client session. Both sessions are closed and the context cancelled through
// t.Cleanup; the returned CancelFunc cancels that same context early.
func newTestServer(t *testing.T, svc ServiceClient) (*mcpsdk.ClientSession, context.CancelFunc) {
	t.Helper()
	ctx, cancel := context.WithCancel(context.Background())

	serverInfo := &mcpsdk.Implementation{Name: "weknora-test", Version: "v0.0.0-test"}
	srv := mcpsdk.NewServer(serverInfo, nil)
	registerTools(srv, svc)

	serverSide, clientSide := mcpsdk.NewInMemoryTransports()
	srvSession, err := srv.Connect(ctx, serverSide, nil)
	if err != nil {
		cancel()
		t.Fatalf("server.Connect: %v", err)
	}

	clientInfo := &mcpsdk.Implementation{Name: "test-client", Version: "v0.0.0"}
	cliSession, err := mcpsdk.NewClient(clientInfo, nil).Connect(ctx, clientSide, nil)
	if err != nil {
		_ = srvSession.Close()
		cancel()
		t.Fatalf("client.Connect: %v", err)
	}

	// Tear down client first, then server, then the shared context.
	teardown := func() {
		_ = cliSession.Close()
		_ = srvSession.Close()
		cancel()
	}
	t.Cleanup(teardown)
	return cliSession, cancel
}
// callTool invokes the tool called name with args over c (5-second timeout)
// and returns the raw result. The test is failed immediately on a transport
// error or when the tool reports IsError.
//
// When out is non-nil and the result carries structured content, that
// content is round-tripped through JSON into out so callers can assert
// against a typed value; a failure in either direction of the round trip
// fails the test. If the result has no structured content, out is left
// untouched.
func callTool(t *testing.T, c *mcpsdk.ClientSession, name string, args any, out any) *mcpsdk.CallToolResult {
	t.Helper()
	ctx, cancel := context.WithTimeout(context.Background(), 5*time.Second)
	defer cancel()

	res, err := c.CallTool(ctx, &mcpsdk.CallToolParams{Name: name, Arguments: args})
	if err != nil {
		t.Fatalf("CallTool(%s): %v", name, err)
	}
	if res.IsError {
		if len(res.Content) > 0 {
			t.Fatalf("tool %s returned error: %+v", name, res.Content)
		}
		t.Fatalf("tool %s returned error (no content)", name)
	}
	if out == nil || res.StructuredContent == nil {
		return res
	}

	// Report a marshal failure where it happens instead of letting it
	// surface as a confusing decode error on an empty payload.
	b, err := json.Marshal(res.StructuredContent)
	if err != nil {
		t.Fatalf("re-encode %s structured content: %v", name, err)
	}
	if err := json.Unmarshal(b, out); err != nil {
		t.Fatalf("decode %s output: %v\nraw=%s", name, err, b)
	}
	return res
}
// TestTool_ListsRegistered pins the exact tool surface: every expected tool
// must be listed, and nothing beyond the expected set may be registered.
func TestTool_ListsRegistered(t *testing.T) {
	session, _ := newTestServer(t, &fakeSvc{})
	ctx, cancel := context.WithTimeout(context.Background(), 2*time.Second)
	defer cancel()

	listing, err := session.ListTools(ctx, nil)
	if err != nil {
		t.Fatalf("ListTools: %v", err)
	}

	expected := []string{"kb_list", "kb_view", "doc_list", "doc_view", "doc_download", "search_chunks", "chat", "agent_list", "agent_invoke"}
	registered := make(map[string]bool, len(listing.Tools))
	for i := range listing.Tools {
		registered[listing.Tools[i].Name] = true
	}
	for _, name := range expected {
		if registered[name] {
			continue
		}
		t.Errorf("missing tool %q in ListTools response", name)
	}
	if n := len(listing.Tools); n != len(expected) {
		t.Errorf("registered %d tools, want exactly %d (no scope creep)", n, len(expected))
	}
}
// TestTool_KBList checks that kb_list returns the knowledge bases reported
// by the service.
func TestTool_KBList(t *testing.T) {
	fake := &fakeSvc{listKBs: []sdk.KnowledgeBase{{ID: "kb1", Name: "Marketing"}}}
	session, _ := newTestServer(t, fake)

	var got kbListOutput
	callTool(t, session, "kb_list", map[string]any{}, &got)

	if len(got.Items) == 1 && got.Items[0].ID == "kb1" {
		return
	}
	t.Errorf("got %+v", got)
}
// TestTool_KBView_RequiresKBID checks that kb_view without kb_id is reported
// as a tool-level error (IsError) rather than a transport failure.
func TestTool_KBView_RequiresKBID(t *testing.T) {
	session, _ := newTestServer(t, &fakeSvc{})
	ctx, cancel := context.WithTimeout(context.Background(), 2*time.Second)
	defer cancel()

	params := &mcpsdk.CallToolParams{Name: "kb_view", Arguments: map[string]any{}}
	result, err := session.CallTool(ctx, params)
	switch {
	case err != nil:
		t.Fatalf("unexpected transport error: %v", err)
	case !result.IsError:
		t.Fatal("expected IsError=true on missing kb_id")
	}
}
// TestTool_KBView checks that kb_view returns the service's knowledge base
// and forwards the requested kb_id to the service.
func TestTool_KBView(t *testing.T) {
	fake := &fakeSvc{getKB: &sdk.KnowledgeBase{ID: "kb_x", Name: "Eng"}}
	session, _ := newTestServer(t, fake)

	var got sdk.KnowledgeBase
	callTool(t, session, "kb_view", map[string]any{"kb_id": "kb_x"}, &got)

	if !(got.ID == "kb_x" && got.Name == "Eng") {
		t.Errorf("got %+v", got)
	}
	if forwarded := fake.calls.kbViewID; forwarded != "kb_x" {
		t.Errorf("kb_id not forwarded: %s", forwarded)
	}
}
// TestTool_DocList_DefaultPagination checks that doc_list called without
// paging arguments reports page 1 / page_size 20 and forwards the kb_id.
func TestTool_DocList_DefaultPagination(t *testing.T) {
	fake := &fakeSvc{
		listDocs:      []sdk.Knowledge{{ID: "k1"}},
		listDocsTotal: 1,
	}
	session, _ := newTestServer(t, fake)

	var got docListOutput
	callTool(t, session, "doc_list", map[string]any{"kb_id": "kb_x"}, &got)

	if !(got.Page == 1 && got.PageSize == 20) {
		t.Errorf("default pagination not applied: %+v", got)
	}
	if forwarded := fake.calls.docListKBID; forwarded != "kb_x" {
		t.Errorf("kb_id not forwarded: %s", forwarded)
	}
}
// TestTool_DocList_StatusFilter_Forwarded checks that the doc_list "status"
// argument reaches the service as KnowledgeListFilter.ParseStatus.
func TestTool_DocList_StatusFilter_Forwarded(t *testing.T) {
	fake := &fakeSvc{}
	session, _ := newTestServer(t, fake)

	args := map[string]any{"kb_id": "kb_x", "status": "failed"}
	callTool(t, session, "doc_list", args, nil)

	if filter := fake.calls.docListFilter; filter.ParseStatus != "failed" {
		t.Errorf("status not forwarded as filter.ParseStatus: %+v", filter)
	}
}
// TestTool_DocView checks that doc_view returns the document reported by the
// service.
func TestTool_DocView(t *testing.T) {
	fake := &fakeSvc{getDoc: &sdk.Knowledge{ID: "k1", FileName: "a.pdf"}}
	session, _ := newTestServer(t, fake)

	var got sdk.Knowledge
	callTool(t, session, "doc_view", map[string]any{"knowledge_id": "k1"}, &got)

	if got.ID == "k1" {
		return
	}
	t.Errorf("got %+v", got)
}
// TestTool_DocDownload_Text checks that a plain-text file comes back
// verbatim with is_base64 unset.
func TestTool_DocDownload_Text(t *testing.T) {
	body := io.NopCloser(strings.NewReader("hello world"))
	fake := &fakeSvc{openDocName: "notes.txt", openDocBody: body}
	session, _ := newTestServer(t, fake)

	var got docDownloadOutput
	callTool(t, session, "doc_download", map[string]any{"knowledge_id": "k1"}, &got)

	if text := got.Content; text != "hello world" {
		t.Errorf("content = %q", text)
	}
	if got.IsBase64 {
		t.Error("text content should not be base64-encoded")
	}
}
// TestTool_DocDownload_BinaryBase64 checks that a payload with a NUL among
// its leading bytes is flagged as base64-encoded.
func TestTool_DocDownload_BinaryBase64(t *testing.T) {
	// The very first byte is NUL, which encodeDownload treats as binary.
	payload := string([]byte{0x00, 0x01, 0x02, 0x03})
	fake := &fakeSvc{
		openDocName: "blob.bin",
		openDocBody: io.NopCloser(strings.NewReader(payload)),
	}
	session, _ := newTestServer(t, fake)

	var got docDownloadOutput
	callTool(t, session, "doc_download", map[string]any{"knowledge_id": "k1"}, &got)

	if got.IsBase64 {
		return
	}
	t.Errorf("binary should be base64; got is_base64=%v content=%q", got.IsBase64, got.Content)
}
func TestTool_SearchChunks(t *testing.T) {
svc := &fakeSvc{hybridResults: []*sdk.SearchResult{{KnowledgeID: "k1", Score: 0.9}}}
c, _ := newTestServer(t, svc)
var out searchChunksOutput
callTool(t, c, "search_chunks", map[string]any{"kb_id": "kb_x", "query": "what is RAG"}, &out)
if len(out.Results) != 1 || out.Results[0].KnowledgeID != "k1" {
t.Errorf("got %+v", out)
}
}
func TestTool_SearchChunks_LimitCap(t *testing.T) {
// 5 results, limit 3 → 3 returned.
svc := &fakeSvc{}
for i := 0; i < 5; i++ {
svc.hybridResults = append(svc.hybridResults, &sdk.SearchResult{KnowledgeID: "k", Score: float64(i)})
}
c, _ := newTestServer(t, svc)
var out searchChunksOutput
callTool(t, c, "search_chunks", map[string]any{"kb_id": "kb_x", "query": "x", "limit": 3}, &out)
if len(out.Results) != 3 {
t.Errorf("limit not honored: got %d, want 3", len(out.Results))
}
}
func TestTool_Chat_AccumulateAnswerAndReferences(t *testing.T) {
svc := &fakeSvc{
kbStreamEvents: []*sdk.StreamResponse{
{Content: "Hello "},
{Content: "world."},
{KnowledgeReferences: []*sdk.SearchResult{{KnowledgeID: "k1"}}},
{ResponseType: sdk.ResponseTypeComplete},
},
}
c, _ := newTestServer(t, svc)
var out chatOutput
callTool(t, c, "chat", map[string]any{"kb_id": "kb_x", "query": "ping"}, &out)
if out.Answer != "Hello world." {
t.Errorf("answer = %q", out.Answer)
}
if len(out.References) != 1 || out.References[0].KnowledgeID != "k1" {
t.Errorf("references missing: %+v", out.References)
}
if out.SessionID != "sess_auto" {
t.Errorf("session_id = %q, want sess_auto", out.SessionID)
}
}
func TestTool_Chat_ExistingSessionSkipsCreate(t *testing.T) {
svc := &fakeSvc{
kbStreamEvents: []*sdk.StreamResponse{{ResponseType: sdk.ResponseTypeComplete}},
}
c, _ := newTestServer(t, svc)
callTool(t, c, "chat", map[string]any{"kb_id": "kb_x", "query": "x", "session_id": "sess_existing"}, nil)
if svc.calls.createSessReq != nil {
t.Error("CreateSession should not fire when session_id is supplied")
}
if svc.calls.kbQASess != "sess_existing" {
t.Errorf("session id not forwarded to QA stream: %s", svc.calls.kbQASess)
}
}
// TestTool_AgentList checks that agent_list returns the agents reported by
// the service.
func TestTool_AgentList(t *testing.T) {
	fake := &fakeSvc{agents: []sdk.Agent{{ID: "ag1", Name: "Research"}}}
	session, _ := newTestServer(t, fake)

	var got agentListOutput
	callTool(t, session, "agent_list", map[string]any{}, &got)

	if len(got.Items) == 1 && got.Items[0].ID == "ag1" {
		return
	}
	t.Errorf("got %+v", got)
}
// TestTool_AgentInvoke replays an answer frame, a tool_call frame and a
// terminal frame, then checks that agent_invoke returns the answer, exactly
// one tool event, and echoes the agent_id.
func TestTool_AgentInvoke(t *testing.T) {
	svc := &fakeSvc{
		agentEvents: []*sdk.AgentStreamResponse{
			{ResponseType: sdk.AgentResponseTypeAnswer, Content: "result"},
			{ResponseType: sdk.AgentResponseTypeToolCall, ID: "c1", Content: "knowledge_search"},
			{Done: true},
		},
	}
	c, _ := newTestServer(t, svc)

	var out agentInvokeOutput
	callTool(t, c, "agent_invoke", map[string]any{"agent_id": "ag1", "query": "x"}, &out)

	if out.Answer != "result" {
		t.Errorf("answer = %q", out.Answer)
	}
	// The message names the field as it appears on the wire (tool_events),
	// so a failure points at the right key.
	if len(out.ToolEvents) != 1 {
		t.Errorf("tool_events len = %d, want 1", len(out.ToolEvents))
	}
	if out.AgentID != "ag1" {
		t.Errorf("agent_id = %q", out.AgentID)
	}
}
func TestTool_AgentInvoke_StreamAbort(t *testing.T) {
svc := &fakeSvc{
agentEvents: []*sdk.AgentStreamResponse{{ResponseType: sdk.AgentResponseTypeAnswer, Content: "partial"}},
agentStreamErr: errors.New("connection reset"),
}
c, _ := newTestServer(t, svc)
ctx, cancel := context.WithTimeout(context.Background(), 2*time.Second)
defer cancel()
res, err := c.CallTool(ctx, &mcpsdk.CallToolParams{Name: "agent_invoke", Arguments: map[string]any{"agent_id": "ag1", "query": "x"}})
if err != nil {
t.Fatalf("unexpected transport error: %v", err)
}
if !res.IsError {
t.Fatal("expected IsError=true on mid-stream abort")
}
}

View File

@@ -0,0 +1,86 @@
package sse
import (
"strings"
sdk "github.com/Tencent/WeKnora/client"
)
// AgentToolEvent captures one tool_call / tool_result event from an
// agent SSE stream, as recorded by AgentAccumulator.Append.
//
// Kind is the SDK event type (typed, not a bare string), so consumers can
// compare against sdk.AgentResponseTypeToolCall /
// sdk.AgentResponseTypeToolResult without retyping the constants.
//
// NOTE: Kind is the event kind, NOT the function name; the function
// name typically lives in Data for tool_call events.
type AgentToolEvent struct {
	// ID is the event's ID field, copied as-is.
	ID string `json:"id"`
	// Kind is the event's ResponseType (tool_call or tool_result).
	Kind sdk.AgentResponseType `json:"kind,omitempty"`
	// Result is the event's Content field, copied as-is for both kinds.
	Result string `json:"result,omitempty"`
	// Data is the event's Data map; it is stored without copying, so it
	// aliases the map held by the source event.
	Data map[string]any `json:"data,omitempty"`
}
// AgentAccumulator collects the events delivered to an agent QA stream
// callback into a final answer plus trace. It is separate from Accumulator
// (used for KnowledgeQAStream) because agent streams carry more event kinds
// (thinking, reflection, tool_call, tool_result, answer, references, error)
// and signal completion with a Done flag on the frame instead of a
// ResponseType=complete sentinel.
//
// The zero value is ready to use. There is no internal locking, so Append
// must not be called concurrently.
//
// The text builders and the done flag are unexported and reached through
// accessor methods, so code outside the package cannot break the rule that
// nothing is recorded after the terminal frame. References and ToolEvents
// are plain exported slices.
type AgentAccumulator struct {
	answer     strings.Builder
	thinking   strings.Builder
	References []*sdk.SearchResult
	ToolEvents []AgentToolEvent
	done       bool
}

// Answer returns the concatenated content of all `answer` events so far.
func (a *AgentAccumulator) Answer() string {
	return a.answer.String()
}

// Thinking returns the concatenated content of all `thinking` and
// `reflection` events so far.
func (a *AgentAccumulator) Thinking() string {
	return a.thinking.String()
}

// Done reports whether a frame with Done set has been appended.
func (a *AgentAccumulator) Done() bool {
	return a.done
}

// Append folds one stream event into the accumulator. Nil events, and any
// event arriving after the terminal frame, are ignored, so callers need no
// special handling for late events.
func (a *AgentAccumulator) Append(r *sdk.AgentStreamResponse) {
	if r == nil || a.done {
		return
	}

	kind := r.ResponseType
	switch kind {
	case sdk.AgentResponseTypeToolCall, sdk.AgentResponseTypeToolResult:
		event := AgentToolEvent{
			ID:     r.ID,
			Kind:   kind,
			Result: r.Content,
			Data:   r.Data,
		}
		a.ToolEvents = append(a.ToolEvents, event)
	case sdk.AgentResponseTypeThinking, sdk.AgentResponseTypeReflection:
		if text := r.Content; text != "" {
			a.thinking.WriteString(text)
		}
	case sdk.AgentResponseTypeAnswer:
		if text := r.Content; text != "" {
			a.answer.WriteString(text)
		}
	}

	// References are taken from whichever event carries them, whatever its
	// kind, and the most recent non-nil set replaces any earlier one.
	if refs := r.KnowledgeReferences; refs != nil {
		a.References = refs
	}

	// a.done is known to be false here (guarded above), so a plain
	// assignment latches it exactly when this frame is terminal.
	a.done = r.Done
}