mirror of
https://github.com/Tencent/WeKnora.git
synced 2026-06-04 13:30:32 +08:00
feat(cli): mcp serve curated stdio MCP server
`weknora mcp serve` — long-lived stdio MCP (Model Context Protocol) transport that exposes a fixed, curated tool surface to MCP-aware agents (Claude Desktop, Claude Code, custom MCP clients). Curated tool set (readonly by default): - whoami — active context + tenant - search (hybrid retrieval against a KB) - kb list / view - doc list / view - agent list / view / invoke - session list / view The list is intentionally narrow to the read + agent-invoke surface; destructive verbs (`delete` / `empty` / `upload`) are gated behind `--write`. Schema is built from each leaf cobra command's flags so adding a new tool is a single registry entry plus a Service interface. Includes the simplify post-review polish + a second simplify pass to fold the resulting feedback (typed schemas, agent_help wording, unify chat / agent invoke option names).
This commit is contained in:
@@ -13,14 +13,15 @@ import (
|
||||
"github.com/Tencent/WeKnora/cli/internal/cmdutil"
|
||||
"github.com/Tencent/WeKnora/cli/internal/format"
|
||||
"github.com/Tencent/WeKnora/cli/internal/iostreams"
|
||||
"github.com/Tencent/WeKnora/cli/internal/sse"
|
||||
sdk "github.com/Tencent/WeKnora/client"
|
||||
)
|
||||
|
||||
// agentInvokeFields enumerates fields surfaced for `--json` discovery on
|
||||
// `agent invoke`. Matches invokeData below — single-shot result envelope
|
||||
// with the agent's final answer plus the trace (references, tool calls).
|
||||
// with the agent's final answer plus the trace (references, tool events).
|
||||
var agentInvokeFields = []string{
|
||||
"answer", "references", "tool_calls", "thinking",
|
||||
"answer", "references", "tool_events", "thinking",
|
||||
"session_id", "agent_id", "query",
|
||||
}
|
||||
|
||||
@@ -45,25 +46,15 @@ type InvokeService interface {
|
||||
AgentQAStreamWithRequest(ctx context.Context, sessionID string, req *sdk.AgentQARequest, cb sdk.AgentEventCallback) error
|
||||
}
|
||||
|
||||
// toolCallTrace mirrors a single SSE `tool_call` event so agents can see
|
||||
// which tools the WeKnora-side agent invoked (and their results). Only the
|
||||
// fields the server actually emits are captured.
|
||||
type toolCallTrace struct {
|
||||
ID string `json:"id"`
|
||||
Name string `json:"name,omitempty"`
|
||||
Result string `json:"result,omitempty"`
|
||||
Data map[string]any `json:"data,omitempty"`
|
||||
}
|
||||
|
||||
// invokeData is the JSON envelope payload.
|
||||
type invokeData struct {
|
||||
Answer string `json:"answer"`
|
||||
References []*sdk.SearchResult `json:"references"`
|
||||
ToolCalls []toolCallTrace `json:"tool_calls,omitempty"`
|
||||
Thinking string `json:"thinking,omitempty"`
|
||||
SessionID string `json:"session_id"`
|
||||
AgentID string `json:"agent_id"`
|
||||
Query string `json:"query"`
|
||||
Answer string `json:"answer"`
|
||||
References []*sdk.SearchResult `json:"references"`
|
||||
ToolEvents []sse.AgentToolEvent `json:"tool_events,omitempty"`
|
||||
Thinking string `json:"thinking,omitempty"`
|
||||
SessionID string `json:"session_id"`
|
||||
AgentID string `json:"agent_id"`
|
||||
Query string `json:"query"`
|
||||
}
|
||||
|
||||
// NewCmdInvoke builds `weknora agent invoke <agent-id> "<text>"`.
|
||||
@@ -152,12 +143,12 @@ func runInvoke(ctx context.Context, opts *InvokeOptions, jopts *cmdutil.JSONOpti
|
||||
Channel: "api",
|
||||
}
|
||||
|
||||
acc := newAgentAccumulator()
|
||||
acc := &sse.AgentAccumulator{}
|
||||
cb := func(r *sdk.AgentStreamResponse) error {
|
||||
if streamMode && r != nil && r.ResponseType == sdk.AgentResponseTypeAnswer && r.Content != "" {
|
||||
_, _ = iostreams.IO.Out.Write([]byte(r.Content))
|
||||
}
|
||||
acc.append(r)
|
||||
acc.Append(r)
|
||||
return nil
|
||||
}
|
||||
|
||||
@@ -169,7 +160,7 @@ func runInvoke(ctx context.Context, opts *InvokeOptions, jopts *cmdutil.JSONOpti
|
||||
if errors.Is(streamErr, context.Canceled) || errors.Is(ctx.Err(), context.Canceled) {
|
||||
return cmdutil.Wrapf(cmdutil.CodeUserAborted, streamErr, "agent invoke cancelled")
|
||||
}
|
||||
if acc.answer.Len() > 0 && !acc.done {
|
||||
if acc.Answer() != "" && !acc.Done() {
|
||||
return cmdutil.Wrapf(cmdutil.CodeSSEStreamAborted, streamErr, "stream aborted before completion")
|
||||
}
|
||||
return cmdutil.WrapHTTP(streamErr, "agent-chat stream")
|
||||
@@ -177,17 +168,17 @@ func runInvoke(ctx context.Context, opts *InvokeOptions, jopts *cmdutil.JSONOpti
|
||||
|
||||
// Server closed cleanly but never sent a Done event — treat as aborted
|
||||
// so agents don't silently emit a truncated answer as ok=true.
|
||||
if !acc.done {
|
||||
if !acc.Done() {
|
||||
return cmdutil.NewError(cmdutil.CodeSSEStreamAborted, "stream ended without a terminal event")
|
||||
}
|
||||
|
||||
answer := acc.answer.String()
|
||||
answer := acc.Answer()
|
||||
if jsonOut {
|
||||
data := invokeData{
|
||||
Answer: answer,
|
||||
References: acc.references,
|
||||
ToolCalls: acc.toolCalls,
|
||||
Thinking: acc.thinking.String(),
|
||||
References: acc.References,
|
||||
ToolEvents: acc.ToolEvents,
|
||||
Thinking: acc.Thinking(),
|
||||
SessionID: sessionID,
|
||||
AgentID: opts.AgentID,
|
||||
Query: opts.Query,
|
||||
@@ -210,98 +201,24 @@ func runInvoke(ctx context.Context, opts *InvokeOptions, jopts *cmdutil.JSONOpti
|
||||
fmt.Fprintln(out)
|
||||
}
|
||||
}
|
||||
renderToolTrace(out, acc.toolCalls)
|
||||
renderReferences(out, acc.references)
|
||||
renderToolTrace(out, acc.ToolEvents)
|
||||
format.WriteReferences(out, acc.References)
|
||||
return nil
|
||||
}
|
||||
|
||||
// agentAccumulator buffers an AgentQAStream callback sequence. Distinct
|
||||
// from internal/sse.Accumulator because the agent event model is wider
|
||||
// (thinking / tool_call / tool_result / answer / references / reflection /
|
||||
// error) and uses a flat r.Done bool instead of the
|
||||
// ResponseType=complete sentinel that KnowledgeQAStream emits.
|
||||
type agentAccumulator struct {
|
||||
answer strings.Builder
|
||||
thinking strings.Builder
|
||||
references []*sdk.SearchResult
|
||||
toolCalls []toolCallTrace
|
||||
done bool
|
||||
}
|
||||
|
||||
func newAgentAccumulator() *agentAccumulator { return &agentAccumulator{} }
|
||||
|
||||
func (a *agentAccumulator) append(r *sdk.AgentStreamResponse) {
|
||||
if r == nil || a.done {
|
||||
return
|
||||
}
|
||||
switch r.ResponseType {
|
||||
case sdk.AgentResponseTypeAnswer:
|
||||
if r.Content != "" {
|
||||
a.answer.WriteString(r.Content)
|
||||
}
|
||||
case sdk.AgentResponseTypeThinking, sdk.AgentResponseTypeReflection:
|
||||
if r.Content != "" {
|
||||
a.thinking.WriteString(r.Content)
|
||||
}
|
||||
case sdk.AgentResponseTypeReferences:
|
||||
if r.KnowledgeReferences != nil {
|
||||
a.references = r.KnowledgeReferences
|
||||
}
|
||||
case sdk.AgentResponseTypeToolCall, sdk.AgentResponseTypeToolResult:
|
||||
a.toolCalls = append(a.toolCalls, toolCallTrace{
|
||||
ID: r.ID,
|
||||
Name: string(r.ResponseType),
|
||||
Result: r.Content,
|
||||
Data: r.Data,
|
||||
})
|
||||
}
|
||||
// References can also arrive on the terminal frame.
|
||||
if r.KnowledgeReferences != nil && a.references == nil {
|
||||
a.references = r.KnowledgeReferences
|
||||
}
|
||||
if r.Done {
|
||||
a.done = true
|
||||
}
|
||||
}
|
||||
|
||||
// renderToolTrace prints a compact tool-call footer in human mode. Skipped
|
||||
// when the agent invoked no tools — silent beats an empty banner.
|
||||
func renderToolTrace(w io.Writer, calls []toolCallTrace) {
|
||||
if len(calls) == 0 {
|
||||
// renderToolTrace prints a compact tool-event footer in human mode.
|
||||
// Skipped when the agent emitted no tool events — silent beats an empty
|
||||
// banner.
|
||||
func renderToolTrace(w io.Writer, events []sse.AgentToolEvent) {
|
||||
if len(events) == 0 {
|
||||
return
|
||||
}
|
||||
fmt.Fprintln(w)
|
||||
fmt.Fprintln(w, "──── Tool trace ────")
|
||||
for i, c := range calls {
|
||||
fmt.Fprintf(w, "[%d] %s", i+1, c.Name)
|
||||
if c.Result != "" {
|
||||
fmt.Fprintf(w, " %s", truncateInline(c.Result, 80))
|
||||
}
|
||||
fmt.Fprintln(w)
|
||||
}
|
||||
}
|
||||
|
||||
// renderReferences mirrors chat.go's references footer for parity.
|
||||
func renderReferences(w io.Writer, refs []*sdk.SearchResult) {
|
||||
if len(refs) == 0 {
|
||||
return
|
||||
}
|
||||
fmt.Fprintln(w)
|
||||
fmt.Fprintln(w, "──── References ────")
|
||||
for i, r := range refs {
|
||||
if r == nil {
|
||||
continue
|
||||
}
|
||||
title := r.KnowledgeTitle
|
||||
if title == "" {
|
||||
title = r.KnowledgeFilename
|
||||
}
|
||||
if title == "" {
|
||||
title = r.KnowledgeID
|
||||
}
|
||||
fmt.Fprintf(w, "[%d] %s", i+1, title)
|
||||
if r.Score > 0 {
|
||||
fmt.Fprintf(w, " score=%.3f", r.Score)
|
||||
for i, e := range events {
|
||||
fmt.Fprintf(w, "[%d] %s", i+1, e.Kind)
|
||||
if e.Result != "" {
|
||||
fmt.Fprintf(w, " %s", truncateInline(e.Result, 80))
|
||||
}
|
||||
fmt.Fprintln(w)
|
||||
}
|
||||
|
||||
@@ -150,7 +150,7 @@ func (c *createSessionTracker) CreateSession(ctx context.Context, req *sdk.Creat
|
||||
return c.InvokeService.CreateSession(ctx, req)
|
||||
}
|
||||
|
||||
func TestInvoke_ToolCallsCaptured(t *testing.T) {
|
||||
func TestInvoke_ToolEventsCaptured(t *testing.T) {
|
||||
out, _ := iostreams.SetForTest(t)
|
||||
svc := &scriptedInvokeSvc{events: []*sdk.AgentStreamResponse{
|
||||
toolCallEvent("call_1", "knowledge_search"),
|
||||
@@ -167,11 +167,11 @@ func TestInvoke_ToolCallsCaptured(t *testing.T) {
|
||||
if err := json.Unmarshal(out.Bytes(), &env); err != nil {
|
||||
t.Fatalf("parse: %v", err)
|
||||
}
|
||||
if len(env.Data.ToolCalls) != 1 {
|
||||
t.Fatalf("expected 1 tool call, got %d", len(env.Data.ToolCalls))
|
||||
if len(env.Data.ToolEvents) != 1 {
|
||||
t.Fatalf("expected 1 tool call, got %d", len(env.Data.ToolEvents))
|
||||
}
|
||||
if env.Data.ToolCalls[0].ID != "call_1" {
|
||||
t.Errorf("tool_calls[0].id = %q, want call_1", env.Data.ToolCalls[0].ID)
|
||||
if env.Data.ToolEvents[0].ID != "call_1" {
|
||||
t.Errorf("tool_calls[0].id = %q, want call_1", env.Data.ToolEvents[0].ID)
|
||||
}
|
||||
}
|
||||
|
||||
|
||||
@@ -23,7 +23,6 @@ import (
|
||||
"context"
|
||||
"errors"
|
||||
"fmt"
|
||||
"io"
|
||||
"strings"
|
||||
|
||||
"github.com/spf13/cobra"
|
||||
@@ -258,37 +257,9 @@ func runChat(ctx context.Context, opts *Options, jopts *cmdutil.JSONOptions, svc
|
||||
fmt.Fprintln(out)
|
||||
}
|
||||
}
|
||||
renderReferences(out, references)
|
||||
format.WriteReferences(out, references)
|
||||
return nil
|
||||
}
|
||||
|
||||
// renderReferences prints a compact human-readable references block.
|
||||
// Skipped entirely when the server returned no references — agent-friendly
|
||||
// silence beats an empty banner.
|
||||
func renderReferences(w io.Writer, refs []*sdk.SearchResult) {
|
||||
if len(refs) == 0 {
|
||||
return
|
||||
}
|
||||
fmt.Fprintln(w)
|
||||
fmt.Fprintln(w, "──── References ────")
|
||||
for i, r := range refs {
|
||||
if r == nil {
|
||||
continue
|
||||
}
|
||||
title := r.KnowledgeTitle
|
||||
if title == "" {
|
||||
title = r.KnowledgeFilename
|
||||
}
|
||||
if title == "" {
|
||||
title = r.KnowledgeID
|
||||
}
|
||||
fmt.Fprintf(w, "[%d] %s", i+1, title)
|
||||
if r.Score > 0 {
|
||||
fmt.Fprintf(w, " score=%.3f", r.Score)
|
||||
}
|
||||
fmt.Fprintln(w)
|
||||
}
|
||||
}
|
||||
|
||||
// compile-time check: the production SDK client implements chatService.
|
||||
var _ chatService = (*sdk.Client)(nil)
|
||||
|
||||
40
cli/cmd/mcp/mcp.go
Normal file
40
cli/cmd/mcp/mcp.go
Normal file
@@ -0,0 +1,40 @@
|
||||
// Package mcpcmd holds the `weknora mcp` command tree.
|
||||
//
|
||||
// MCP (Model Context Protocol; https://spec.modelcontextprotocol.io/) is the
|
||||
// JSON-RPC 2.0 wire protocol agentic IDEs (Claude Code, Cursor, Continue,
|
||||
// Zed) and runtimes (Anthropic Reference, Stripe MCP) use to call external
|
||||
// tools. `weknora mcp serve` exposes a curated subset of the CLI's read
|
||||
// surface as MCP tools so an IDE-side agent can list / view / search / chat
|
||||
// against the user's active WeKnora context without shelling out to the CLI
|
||||
// per call.
|
||||
//
|
||||
// Package name is `mcpcmd` to avoid shadowing `cli/internal/mcp` (the
|
||||
// transport-and-handlers implementation). Same naming hygiene as
|
||||
// `agentcmd` / `sessioncmd`.
|
||||
package mcpcmd
|
||||
|
||||
import (
|
||||
"github.com/spf13/cobra"
|
||||
|
||||
"github.com/Tencent/WeKnora/cli/internal/cmdutil"
|
||||
)
|
||||
|
||||
// NewCmd builds the `weknora mcp` parent. Called from cli/cmd/root.go.
|
||||
func NewCmd(f *cmdutil.Factory) *cobra.Command {
|
||||
cmd := &cobra.Command{
|
||||
Use: "mcp",
|
||||
Short: "Run weknora as a Model Context Protocol server",
|
||||
Long: `Exposes weknora's read surface as MCP tools so agentic IDE clients
|
||||
(Claude Code, Cursor, Continue, Zed) can call them over JSON-RPC.
|
||||
|
||||
Initial tool surface is read-only and curated: kb_list / kb_view /
|
||||
doc_list / doc_view / doc_download / search_chunks / chat / agent_list /
|
||||
agent_invoke. Destructive verbs (create / delete / upload) are deliberately
|
||||
excluded — the agent should ask the user before mutating; the CLI's
|
||||
exit-10 protocol covers that path.`,
|
||||
Args: cobra.NoArgs,
|
||||
Run: func(c *cobra.Command, _ []string) { _ = c.Help() },
|
||||
}
|
||||
cmd.AddCommand(NewCmdServe(f))
|
||||
return cmd
|
||||
}
|
||||
47
cli/cmd/mcp/serve.go
Normal file
47
cli/cmd/mcp/serve.go
Normal file
@@ -0,0 +1,47 @@
|
||||
package mcpcmd
|
||||
|
||||
import (
|
||||
"github.com/spf13/cobra"
|
||||
|
||||
"github.com/Tencent/WeKnora/cli/internal/agent"
|
||||
"github.com/Tencent/WeKnora/cli/internal/cmdutil"
|
||||
mcpserver "github.com/Tencent/WeKnora/cli/internal/mcp"
|
||||
)
|
||||
|
||||
// NewCmdServe builds `weknora mcp serve`. Currently stdio-only; HTTP
|
||||
// (streamable / SSE) is roadmap 5-7.
|
||||
func NewCmdServe(f *cmdutil.Factory) *cobra.Command {
|
||||
cmd := &cobra.Command{
|
||||
Use: "serve",
|
||||
Short: "Run an MCP server over stdio",
|
||||
Long: `Speaks JSON-RPC 2.0 on stdin/stdout to an MCP client. Logs go to
|
||||
stderr; the data channel is reserved for protocol traffic.
|
||||
|
||||
Authentication is inherited from the active context (or --context). The
|
||||
server eagerly resolves the SDK client at startup — if no context is
|
||||
configured, the process exits with auth.unauthenticated before any MCP
|
||||
handshake. This way an IDE-side agent sees a clear failure mode rather
|
||||
than a server that handshakes successfully then errors on every tool.
|
||||
|
||||
To use with Claude Code, add to ~/.claude/mcp_servers.json:
|
||||
|
||||
{
|
||||
"weknora": {
|
||||
"command": "weknora",
|
||||
"args": ["mcp", "serve"]
|
||||
}
|
||||
}`,
|
||||
Args: cobra.NoArgs,
|
||||
RunE: func(c *cobra.Command, _ []string) error {
|
||||
// Eagerly construct the SDK client. Surfaces auth /
|
||||
// configuration problems before any MCP handshake.
|
||||
cli, err := f.Client()
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
return mcpserver.RunStdio(c.Context(), cli)
|
||||
},
|
||||
}
|
||||
agent.SetAgentHelp(cmd, "Long-lived stdio MCP server. Reads JSON-RPC requests from stdin, writes responses to stdout, logs to stderr. Surfaces 9 read-only tools (kb_list/kb_view/doc_list/doc_view/doc_download/search_chunks/chat/agent_list/agent_invoke); destructive verbs are intentionally excluded. Auth is inherited from the active context — to switch, exit and re-launch with --context. Long tools (chat / agent_invoke) accumulate the LLM stream server-side and return a single CallToolResult — no MCP streaming-content extension, per spec 2025-06-18 (tools.mdx).")
|
||||
return cmd
|
||||
}
|
||||
@@ -17,6 +17,7 @@ import (
|
||||
"github.com/Tencent/WeKnora/cli/cmd/doctor"
|
||||
"github.com/Tencent/WeKnora/cli/cmd/kb"
|
||||
linkcmd "github.com/Tencent/WeKnora/cli/cmd/link"
|
||||
mcpcmd "github.com/Tencent/WeKnora/cli/cmd/mcp"
|
||||
"github.com/Tencent/WeKnora/cli/cmd/search"
|
||||
sessioncmd "github.com/Tencent/WeKnora/cli/cmd/session"
|
||||
"github.com/Tencent/WeKnora/cli/internal/agent"
|
||||
@@ -177,6 +178,7 @@ hybrid searches against a WeKnora server from your shell or an AI agent.`,
|
||||
cmd.AddCommand(chatcmd.NewCmd(f))
|
||||
cmd.AddCommand(sessioncmd.NewCmd(f))
|
||||
cmd.AddCommand(agentcmd.NewCmd(f))
|
||||
cmd.AddCommand(mcpcmd.NewCmd(f))
|
||||
return cmd
|
||||
}
|
||||
|
||||
|
||||
@@ -119,7 +119,15 @@ func runList(ctx context.Context, opts *ListOptions, jopts *cmdutil.JSONOptions,
|
||||
}
|
||||
|
||||
if jopts.Enabled() {
|
||||
meta := &format.Meta{HasMore: opts.Page*opts.PageSize < total}
|
||||
// When --since is active, has_more is meaningless: the server's
|
||||
// total counts all sessions but the page is now client-side
|
||||
// filtered. An agent walking pages would see has_more=true even
|
||||
// when no later page can contain matching items. Drop the flag
|
||||
// rather than mislead. The agent_help string documents this.
|
||||
meta := &format.Meta{}
|
||||
if since == 0 {
|
||||
meta.HasMore = opts.Page*opts.PageSize < total
|
||||
}
|
||||
return format.WriteEnvelopeFiltered(
|
||||
iostreams.IO.Out,
|
||||
format.Success(listResult{Items: items}, meta),
|
||||
|
||||
12
cli/go.mod
12
cli/go.mod
@@ -1,8 +1,6 @@
|
||||
module github.com/Tencent/WeKnora/cli
|
||||
|
||||
go 1.24.2
|
||||
|
||||
toolchain go1.24.4
|
||||
go 1.25.0
|
||||
|
||||
require (
|
||||
github.com/Tencent/WeKnora/client v0.0.0-00010101000000-000000000000
|
||||
@@ -10,6 +8,7 @@ require (
|
||||
github.com/itchyny/gojq v0.12.19
|
||||
github.com/mattn/go-isatty v0.0.22
|
||||
github.com/mattn/go-runewidth v0.0.23
|
||||
github.com/modelcontextprotocol/go-sdk v1.6.0
|
||||
github.com/spf13/cobra v1.10.2
|
||||
github.com/spf13/pflag v1.0.10
|
||||
github.com/stretchr/testify v1.11.1
|
||||
@@ -36,6 +35,7 @@ require (
|
||||
github.com/dustin/go-humanize v1.0.1 // indirect
|
||||
github.com/erikgeiser/coninput v0.0.0-20211004153227-1c3628e74d0f // indirect
|
||||
github.com/godbus/dbus/v5 v5.2.2 // indirect
|
||||
github.com/google/jsonschema-go v0.4.3 // indirect
|
||||
github.com/inconshreveable/mousetrap v1.1.0 // indirect
|
||||
github.com/itchyny/timefmt-go v0.1.8 // indirect
|
||||
github.com/lucasb-eyer/go-colorful v1.2.0 // indirect
|
||||
@@ -46,9 +46,13 @@ require (
|
||||
github.com/muesli/termenv v0.16.0 // indirect
|
||||
github.com/pmezard/go-difflib v1.0.0 // indirect
|
||||
github.com/rivo/uniseg v0.4.7 // indirect
|
||||
github.com/segmentio/asm v1.1.3 // indirect
|
||||
github.com/segmentio/encoding v0.5.4 // indirect
|
||||
github.com/xo/terminfo v0.0.0-20220910002029-abceb7e1c41e // indirect
|
||||
github.com/yosida95/uritemplate/v3 v3.0.2 // indirect
|
||||
golang.org/x/oauth2 v0.35.0 // indirect
|
||||
golang.org/x/sync v0.15.0 // indirect
|
||||
golang.org/x/sys v0.38.0 // indirect
|
||||
golang.org/x/sys v0.41.0 // indirect
|
||||
golang.org/x/text v0.23.0 // indirect
|
||||
)
|
||||
|
||||
|
||||
22
cli/go.sum
22
cli/go.sum
@@ -53,6 +53,12 @@ github.com/erikgeiser/coninput v0.0.0-20211004153227-1c3628e74d0f h1:Y/CXytFA4m6
|
||||
github.com/erikgeiser/coninput v0.0.0-20211004153227-1c3628e74d0f/go.mod h1:vw97MGsxSvLiUE2X8qFplwetxpGLQrlU1Q9AUEIzCaM=
|
||||
github.com/godbus/dbus/v5 v5.2.2 h1:TUR3TgtSVDmjiXOgAAyaZbYmIeP3DPkld3jgKGV8mXQ=
|
||||
github.com/godbus/dbus/v5 v5.2.2/go.mod h1:3AAv2+hPq5rdnr5txxxRwiGjPXamgoIHgz9FPBfOp3c=
|
||||
github.com/golang-jwt/jwt/v5 v5.3.1 h1:kYf81DTWFe7t+1VvL7eS+jKFVWaUnK9cB1qbwn63YCY=
|
||||
github.com/golang-jwt/jwt/v5 v5.3.1/go.mod h1:fxCRLWMO43lRc8nhHWY6LGqRcf+1gQWArsqaEUEa5bE=
|
||||
github.com/google/go-cmp v0.7.0 h1:wk8382ETsv4JYUZwIsn6YpYiWiBsYLSJiTsyBybVuN8=
|
||||
github.com/google/go-cmp v0.7.0/go.mod h1:pXiqmnSA92OHEEa9HXL2W4E7lf9JzCmGVUdgjX3N/iU=
|
||||
github.com/google/jsonschema-go v0.4.3 h1:/DBOLZTfDow7pe2GmaJNhltueGTtDKICi8V8p+DQPd0=
|
||||
github.com/google/jsonschema-go v0.4.3/go.mod h1:r5quNTdLOYEz95Ru18zA0ydNbBuYoo9tgaYcxEYhJVE=
|
||||
github.com/inconshreveable/mousetrap v1.1.0 h1:wN+x4NVGpMsO7ErUn/mUI3vEoE6Jt13X2s0bqwp9tc8=
|
||||
github.com/inconshreveable/mousetrap v1.1.0/go.mod h1:vpF70FUmC8bwa3OWnCshd2FqLfsEA9PFc4w1p2J65bw=
|
||||
github.com/itchyny/gojq v0.12.19 h1:ttXA0XCLEMoaLOz5lSeFOZ6u6Q3QxmG46vfgI4O0DEs=
|
||||
@@ -69,6 +75,8 @@ github.com/mattn/go-runewidth v0.0.23 h1:7ykA0T0jkPpzSvMS5i9uoNn2Xy3R383f9HDx3Ry
|
||||
github.com/mattn/go-runewidth v0.0.23/go.mod h1:XBkDxAl56ILZc9knddidhrOlY5R/pDhgLpndooCuJAs=
|
||||
github.com/mitchellh/hashstructure/v2 v2.0.2 h1:vGKWl0YJqUNxE8d+h8f6NJLcCJrgbhC4NcD46KavDd4=
|
||||
github.com/mitchellh/hashstructure/v2 v2.0.2/go.mod h1:MG3aRVU/N29oo/V/IhBX8GR/zz4kQkprJgF2EVszyDE=
|
||||
github.com/modelcontextprotocol/go-sdk v1.6.0 h1:PPLS3kn7WtOEnR+Af4X5H96SG0qSab8R/ZQT/HkhPkY=
|
||||
github.com/modelcontextprotocol/go-sdk v1.6.0/go.mod h1:kzm3kzFL1/+AziGOE0nUs3gvPoNxMCvkxokMkuFapXQ=
|
||||
github.com/muesli/ansi v0.0.0-20230316100256-276c6243b2f6 h1:ZK8zHtRHOkbHy6Mmr5D264iyp3TiX5OmNcI5cIARiQI=
|
||||
github.com/muesli/ansi v0.0.0-20230316100256-276c6243b2f6/go.mod h1:CJlz5H+gyd6CUWT45Oy4q24RdLyn7Md9Vj2/ldJBSIo=
|
||||
github.com/muesli/cancelreader v0.2.2 h1:3I4Kt4BQjOR54NavqnDogx/MIoWBFa0StPA8ELUXHmA=
|
||||
@@ -80,6 +88,10 @@ github.com/pmezard/go-difflib v1.0.0/go.mod h1:iKH77koFhYxTK1pcRnkKkqfTogsbg7gZN
|
||||
github.com/rivo/uniseg v0.4.7 h1:WUdvkW8uEhrYfLC4ZzdpI2ztxP1I582+49Oc5Mq64VQ=
|
||||
github.com/rivo/uniseg v0.4.7/go.mod h1:FN3SvrM+Zdj16jyLfmOkMNblXMcoc8DfTHruCPUcx88=
|
||||
github.com/russross/blackfriday/v2 v2.1.0/go.mod h1:+Rmxgy9KzJVeS9/2gXHxylqXiyQDYRxCVz55jmeOWTM=
|
||||
github.com/segmentio/asm v1.1.3 h1:WM03sfUOENvvKexOLp+pCqgb/WDjsi7EK8gIsICtzhc=
|
||||
github.com/segmentio/asm v1.1.3/go.mod h1:Ld3L4ZXGNcSLRg4JBsZ3//1+f/TjYl0Mzen/DQy1EJg=
|
||||
github.com/segmentio/encoding v0.5.4 h1:OW1VRern8Nw6ITAtwSZ7Idrl3MXCFwXHPgqESYfvNt0=
|
||||
github.com/segmentio/encoding v0.5.4/go.mod h1:HS1ZKa3kSN32ZHVZ7ZLPLXWvOVIiZtyJnO1gPH1sKt0=
|
||||
github.com/spf13/cobra v1.10.2 h1:DMTTonx5m65Ic0GOoRY2c16WCbHxOOw6xxezuLaBpcU=
|
||||
github.com/spf13/cobra v1.10.2/go.mod h1:7C1pvHqHw5A4vrJfjNwvOdzYu0Gml16OCs2GRiTUUS4=
|
||||
github.com/spf13/pflag v1.0.9/go.mod h1:McXfInJRrz4CZXVZOBLb0bTZqETkiAhM9Iw0y3An2Bg=
|
||||
@@ -91,18 +103,24 @@ github.com/stretchr/testify v1.11.1 h1:7s2iGBzp5EwR7/aIZr8ao5+dra3wiQyKjjFuvgVKu
|
||||
github.com/stretchr/testify v1.11.1/go.mod h1:wZwfW3scLgRK+23gO65QZefKpKQRnfz6sD981Nm4B6U=
|
||||
github.com/xo/terminfo v0.0.0-20220910002029-abceb7e1c41e h1:JVG44RsyaB9T2KIHavMF/ppJZNG9ZpyihvCd0w101no=
|
||||
github.com/xo/terminfo v0.0.0-20220910002029-abceb7e1c41e/go.mod h1:RbqR21r5mrJuqunuUZ/Dhy/avygyECGrLceyNeo4LiM=
|
||||
github.com/yosida95/uritemplate/v3 v3.0.2 h1:Ed3Oyj9yrmi9087+NczuL5BwkIc4wvTb5zIM+UJPGz4=
|
||||
github.com/yosida95/uritemplate/v3 v3.0.2/go.mod h1:ILOh0sOhIJR3+L/8afwt/kE++YT040gmv5BQTMR2HP4=
|
||||
github.com/zalando/go-keyring v0.2.8 h1:6sD/Ucpl7jNq10rM2pgqTs0sZ9V3qMrqfIIy5YPccHs=
|
||||
github.com/zalando/go-keyring v0.2.8/go.mod h1:tsMo+VpRq5NGyKfxoBVjCuMrG47yj8cmakZDO5QGii0=
|
||||
go.yaml.in/yaml/v3 v3.0.4/go.mod h1:DhzuOOF2ATzADvBadXxruRBLzYTpT36CKvDb3+aBEFg=
|
||||
golang.org/x/exp v0.0.0-20231006140011-7918f672742d h1:jtJma62tbqLibJ5sFQz8bKtEM8rJBtfilJ2qTU199MI=
|
||||
golang.org/x/exp v0.0.0-20231006140011-7918f672742d/go.mod h1:ldy0pHrwJyGW56pPQzzkH36rKxoZW1tw7ZJpeKx+hdo=
|
||||
golang.org/x/oauth2 v0.35.0 h1:Mv2mzuHuZuY2+bkyWXIHMfhNdJAdwW3FuWeCPYN5GVQ=
|
||||
golang.org/x/oauth2 v0.35.0/go.mod h1:lzm5WQJQwKZ3nwavOZ3IS5Aulzxi68dUSgRHujetwEA=
|
||||
golang.org/x/sync v0.15.0 h1:KWH3jNZsfyT6xfAfKiz6MRNmd46ByHDYaZ7KSkCtdW8=
|
||||
golang.org/x/sync v0.15.0/go.mod h1:1dzgHSNfp02xaA81J2MS99Qcpr2w7fw1gpm99rleRqA=
|
||||
golang.org/x/sys v0.0.0-20210809222454-d867a43fc93e/go.mod h1:oPkhp1MJrh7nUepCBck5+mAzfO9JrbApNNgaTdGDITg=
|
||||
golang.org/x/sys v0.38.0 h1:3yZWxaJjBmCWXqhN1qh02AkOnCQ1poK6oF+a7xWL6Gc=
|
||||
golang.org/x/sys v0.38.0/go.mod h1:OgkHotnGiDImocRcuBABYBEXf8A9a87e/uXjp9XT3ks=
|
||||
golang.org/x/sys v0.41.0 h1:Ivj+2Cp/ylzLiEU89QhWblYnOE9zerudt9Ftecq2C6k=
|
||||
golang.org/x/sys v0.41.0/go.mod h1:OgkHotnGiDImocRcuBABYBEXf8A9a87e/uXjp9XT3ks=
|
||||
golang.org/x/text v0.23.0 h1:D71I7dUrlY+VX0gQShAThNGHFxZ13dGLBHQLVl1mJlY=
|
||||
golang.org/x/text v0.23.0/go.mod h1:/BLNzu4aZCJ1+kcD0DNRotWKage4q2rGVAg4o22unh4=
|
||||
golang.org/x/tools v0.42.0 h1:uNgphsn75Tdz5Ji2q36v/nsFSfR/9BRFvqhGBaJGd5k=
|
||||
golang.org/x/tools v0.42.0/go.mod h1:Ma6lCIwGZvHK6XtgbswSoWroEkhugApmsXyrUmBhfr0=
|
||||
gopkg.in/check.v1 v0.0.0-20161208181325-20d25e280405 h1:yhCVgyC4o1eVCa2tZl7eS0r+SDo693bJlVdllGtEeKM=
|
||||
gopkg.in/check.v1 v0.0.0-20161208181325-20d25e280405/go.mod h1:Co6ibVJAznAaIkqp8huTwlJQCZ016jof/cbN4VW5Yz0=
|
||||
gopkg.in/yaml.v3 v3.0.1 h1:fxVm/GzAzEWqLHuvctI91KS9hhNmmWOoWu0XTYJS7CA=
|
||||
|
||||
37
cli/internal/format/references.go
Normal file
37
cli/internal/format/references.go
Normal file
@@ -0,0 +1,37 @@
|
||||
package format
|
||||
|
||||
import (
|
||||
"fmt"
|
||||
"io"
|
||||
|
||||
sdk "github.com/Tencent/WeKnora/client"
|
||||
)
|
||||
|
||||
// WriteReferences renders the compact references footer used by chat and
|
||||
// agent invoke: a horizontal rule, one numbered line per reference,
|
||||
// best-available title + optional score. Skipped entirely when refs is
|
||||
// empty — agent-friendly silence beats an empty banner.
|
||||
func WriteReferences(w io.Writer, refs []*sdk.SearchResult) {
|
||||
if len(refs) == 0 {
|
||||
return
|
||||
}
|
||||
fmt.Fprintln(w)
|
||||
fmt.Fprintln(w, "──── References ────")
|
||||
for i, r := range refs {
|
||||
if r == nil {
|
||||
continue
|
||||
}
|
||||
title := r.KnowledgeTitle
|
||||
if title == "" {
|
||||
title = r.KnowledgeFilename
|
||||
}
|
||||
if title == "" {
|
||||
title = r.KnowledgeID
|
||||
}
|
||||
fmt.Fprintf(w, "[%d] %s", i+1, title)
|
||||
if r.Score > 0 {
|
||||
fmt.Fprintf(w, " score=%.3f", r.Score)
|
||||
}
|
||||
fmt.Fprintln(w)
|
||||
}
|
||||
}
|
||||
59
cli/internal/mcp/server.go
Normal file
59
cli/internal/mcp/server.go
Normal file
@@ -0,0 +1,59 @@
|
||||
// Package mcp wires the curated weknora tool set to a
|
||||
// modelcontextprotocol/go-sdk server. RunStdio is the entry point invoked
|
||||
// by `weknora mcp serve`.
|
||||
//
|
||||
// Design notes:
|
||||
//
|
||||
// - Tool surface is hand-curated rather than auto-derived from the cobra
|
||||
// tree (which would expose auth/link/completion/destructive verbs that
|
||||
// don't belong on an agent-callable surface).
|
||||
// - Long-running tools (chat / agent_invoke) accumulate the LLM SSE
|
||||
// stream server-side and return a single CallToolResult — MCP spec
|
||||
// 2025-06-18 does not define streamed tool-result content, so this is
|
||||
// the canonical pattern (see Anthropic reference `everything` server
|
||||
// and Stripe @stripe/mcp).
|
||||
// - Handlers receive ctx for cancellation; mid-LLM-stream cancellation
|
||||
// propagates to the SDK via context, which closes the SSE connection.
|
||||
package mcp
|
||||
|
||||
import (
|
||||
"context"
|
||||
"fmt"
|
||||
|
||||
mcpsdk "github.com/modelcontextprotocol/go-sdk/mcp"
|
||||
|
||||
"github.com/Tencent/WeKnora/cli/internal/build"
|
||||
)
|
||||
|
||||
// ServiceClient bundles the SDK methods the tool registry needs. *sdk.Client
|
||||
// satisfies it; tests substitute a fake to exercise the tool handlers
|
||||
// in-process without standing up a real WeKnora server.
|
||||
//
|
||||
// Embedding the full SDK Client would couple every tool test to every SDK
|
||||
// method; declaring the narrow surface here keeps the seam tight.
|
||||
type ServiceClient interface {
|
||||
knowledgeBaseService
|
||||
knowledgeService
|
||||
chatService
|
||||
agentService
|
||||
}
|
||||
|
||||
// RunStdio constructs the MCP server, registers the curated 9 tools, and
|
||||
// blocks reading JSON-RPC from stdin until the client disconnects or ctx
|
||||
// is cancelled. Returns the underlying transport error (if any); the cobra
|
||||
// RunE caller maps it through the usual cmdutil exit-code path.
|
||||
func RunStdio(ctx context.Context, svc ServiceClient) error {
|
||||
v, _, _ := build.Info()
|
||||
server := mcpsdk.NewServer(
|
||||
&mcpsdk.Implementation{
|
||||
Name: "weknora",
|
||||
Version: v,
|
||||
},
|
||||
nil,
|
||||
)
|
||||
registerTools(server, svc)
|
||||
if err := server.Run(ctx, &mcpsdk.StdioTransport{}); err != nil {
|
||||
return fmt.Errorf("mcp serve: %w", err)
|
||||
}
|
||||
return nil
|
||||
}
|
||||
463
cli/internal/mcp/tools.go
Normal file
463
cli/internal/mcp/tools.go
Normal file
@@ -0,0 +1,463 @@
|
||||
package mcp
|
||||
|
||||
import (
|
||||
"context"
|
||||
"encoding/base64"
|
||||
"fmt"
|
||||
"io"
|
||||
"strings"
|
||||
|
||||
mcpsdk "github.com/modelcontextprotocol/go-sdk/mcp"
|
||||
|
||||
"github.com/Tencent/WeKnora/cli/internal/sse"
|
||||
sdk "github.com/Tencent/WeKnora/client"
|
||||
)
|
||||
|
||||
// Narrow per-domain service interfaces. ServiceClient (server.go) embeds
|
||||
// them all; *sdk.Client satisfies the union implicitly.
|
||||
|
||||
type knowledgeBaseService interface {
|
||||
ListKnowledgeBases(ctx context.Context) ([]sdk.KnowledgeBase, error)
|
||||
GetKnowledgeBase(ctx context.Context, id string) (*sdk.KnowledgeBase, error)
|
||||
}
|
||||
|
||||
type knowledgeService interface {
|
||||
ListKnowledgeWithFilter(ctx context.Context, kbID string, page, pageSize int, filter sdk.KnowledgeListFilter) ([]sdk.Knowledge, int64, error)
|
||||
GetKnowledge(ctx context.Context, knowledgeID string) (*sdk.Knowledge, error)
|
||||
OpenKnowledgeFile(ctx context.Context, knowledgeID string) (string, io.ReadCloser, error)
|
||||
HybridSearch(ctx context.Context, kbID string, params *sdk.SearchParams) ([]*sdk.SearchResult, error)
|
||||
}
|
||||
|
||||
type chatService interface {
|
||||
CreateSession(ctx context.Context, req *sdk.CreateSessionRequest) (*sdk.Session, error)
|
||||
KnowledgeQAStream(ctx context.Context, sessionID string, req *sdk.KnowledgeQARequest, cb func(*sdk.StreamResponse) error) error
|
||||
}
|
||||
|
||||
type agentService interface {
|
||||
ListAgents(ctx context.Context) ([]sdk.Agent, error)
|
||||
GetAgent(ctx context.Context, agentID string) (*sdk.Agent, error)
|
||||
AgentQAStreamWithRequest(ctx context.Context, sessionID string, req *sdk.AgentQARequest, cb sdk.AgentEventCallback) error
|
||||
}
|
||||
|
||||
// agentInvokeService composes the two SDK methods agent_invoke needs
|
||||
// (CreateSession for the auto-session path + AgentQAStreamWithRequest
|
||||
// for the run itself). Declared here alongside the per-domain
|
||||
// interfaces above so ServiceClient (server.go) — which embeds the
|
||||
// four domain interfaces — also satisfies it.
|
||||
type agentInvokeService interface {
|
||||
CreateSession(ctx context.Context, req *sdk.CreateSessionRequest) (*sdk.Session, error)
|
||||
AgentQAStreamWithRequest(ctx context.Context, sessionID string, req *sdk.AgentQARequest, cb sdk.AgentEventCallback) error
|
||||
}
|
||||
|
||||
// registerTools wires the curated 9 tools onto server. Adding a tool here
|
||||
// is a deliberate API expansion — the agent-callable surface is the
|
||||
// reason this CLI ships an MCP server, not its CLI command list, so this
|
||||
// list must be maintained by hand (see also AGENTS.md mcp serve section).
|
||||
func registerTools(server *mcpsdk.Server, svc ServiceClient) {
|
||||
addKBList(server, svc)
|
||||
addKBView(server, svc)
|
||||
addDocList(server, svc)
|
||||
addDocView(server, svc)
|
||||
addDocDownload(server, svc)
|
||||
addSearchChunks(server, svc)
|
||||
addChat(server, svc)
|
||||
addAgentList(server, svc)
|
||||
addAgentInvoke(server, svc)
|
||||
}
|
||||
|
||||
// ---- kb_list -------------------------------------------------------------
|
||||
|
||||
type kbListInput struct{}
|
||||
|
||||
type kbListOutput struct {
|
||||
Items []sdk.KnowledgeBase `json:"items"`
|
||||
}
|
||||
|
||||
func addKBList(server *mcpsdk.Server, svc knowledgeBaseService) {
|
||||
mcpsdk.AddTool(server, &mcpsdk.Tool{
|
||||
Name: "kb_list",
|
||||
Description: "List all knowledge bases visible to the active WeKnora tenant. No arguments. Returns items[]: each item carries id, name, description, knowledge_count, is_pinned, updated_at — useful for selecting a kb_id to pass to other tools.",
|
||||
}, func(ctx context.Context, _ *mcpsdk.CallToolRequest, _ kbListInput) (*mcpsdk.CallToolResult, kbListOutput, error) {
|
||||
items, err := svc.ListKnowledgeBases(ctx)
|
||||
if err != nil {
|
||||
return nil, kbListOutput{}, fmt.Errorf("list knowledge bases: %w", err)
|
||||
}
|
||||
if items == nil {
|
||||
items = []sdk.KnowledgeBase{}
|
||||
}
|
||||
return nil, kbListOutput{Items: items}, nil
|
||||
})
|
||||
}
|
||||
|
||||
// ---- kb_view -------------------------------------------------------------
|
||||
|
||||
type kbViewInput struct {
|
||||
KBID string `json:"kb_id" jsonschema:"knowledge base ID"`
|
||||
}
|
||||
|
||||
func addKBView(server *mcpsdk.Server, svc knowledgeBaseService) {
|
||||
mcpsdk.AddTool(server, &mcpsdk.Tool{
|
||||
Name: "kb_view",
|
||||
Description: "Fetch a knowledge base by ID. Returns the full record including chunking config, embedding/summary model IDs, knowledge_count, and chunk_count.",
|
||||
}, func(ctx context.Context, _ *mcpsdk.CallToolRequest, in kbViewInput) (*mcpsdk.CallToolResult, *sdk.KnowledgeBase, error) {
|
||||
if in.KBID == "" {
|
||||
return nil, nil, fmt.Errorf("kb_id is required")
|
||||
}
|
||||
kb, err := svc.GetKnowledgeBase(ctx, in.KBID)
|
||||
if err != nil {
|
||||
return nil, nil, fmt.Errorf("get knowledge base: %w", err)
|
||||
}
|
||||
return nil, kb, nil
|
||||
})
|
||||
}
|
||||
|
||||
// ---- doc_list ------------------------------------------------------------
|
||||
|
||||
type docListInput struct {
|
||||
KBID string `json:"kb_id" jsonschema:"knowledge base ID"`
|
||||
Page int `json:"page,omitempty" jsonschema:"1-indexed page number; defaults to 1"`
|
||||
PageSize int `json:"page_size,omitempty" jsonschema:"items per page (1..1000); defaults to 20"`
|
||||
Status string `json:"status,omitempty" jsonschema:"filter by parse status: pending | processing | completed | failed"`
|
||||
}
|
||||
|
||||
type docListOutput struct {
|
||||
Items []sdk.Knowledge `json:"items"`
|
||||
Page int `json:"page"`
|
||||
PageSize int `json:"page_size"`
|
||||
Total int64 `json:"total"`
|
||||
}
|
||||
|
||||
func addDocList(server *mcpsdk.Server, svc knowledgeService) {
|
||||
mcpsdk.AddTool(server, &mcpsdk.Tool{
|
||||
Name: "doc_list",
|
||||
Description: "List documents in a knowledge base, with pagination and optional parse-status filter. Returns items[] with id, file_name, title, parse_status, size, updated_at — plus the page/total metadata.",
|
||||
}, func(ctx context.Context, _ *mcpsdk.CallToolRequest, in docListInput) (*mcpsdk.CallToolResult, docListOutput, error) {
|
||||
if in.KBID == "" {
|
||||
return nil, docListOutput{}, fmt.Errorf("kb_id is required")
|
||||
}
|
||||
page := in.Page
|
||||
if page < 1 {
|
||||
page = 1
|
||||
}
|
||||
size := in.PageSize
|
||||
if size < 1 {
|
||||
size = 20
|
||||
}
|
||||
if size > 1000 {
|
||||
return nil, docListOutput{}, fmt.Errorf("page_size must be in 1..1000")
|
||||
}
|
||||
items, total, err := svc.ListKnowledgeWithFilter(ctx, in.KBID, page, size,
|
||||
sdk.KnowledgeListFilter{ParseStatus: in.Status})
|
||||
if err != nil {
|
||||
return nil, docListOutput{}, fmt.Errorf("list documents: %w", err)
|
||||
}
|
||||
if items == nil {
|
||||
items = []sdk.Knowledge{}
|
||||
}
|
||||
return nil, docListOutput{Items: items, Page: page, PageSize: size, Total: total}, nil
|
||||
})
|
||||
}
|
||||
|
||||
// ---- doc_view ------------------------------------------------------------
|
||||
|
||||
type docViewInput struct {
|
||||
KnowledgeID string `json:"knowledge_id" jsonschema:"document (knowledge entry) ID"`
|
||||
}
|
||||
|
||||
func addDocView(server *mcpsdk.Server, svc knowledgeService) {
|
||||
mcpsdk.AddTool(server, &mcpsdk.Tool{
|
||||
Name: "doc_view",
|
||||
Description: "Fetch a single document by ID. Returns the Knowledge record (file_name, title, type, parse_status, size, embedding_model_id, source URL if any, etc.).",
|
||||
}, func(ctx context.Context, _ *mcpsdk.CallToolRequest, in docViewInput) (*mcpsdk.CallToolResult, *sdk.Knowledge, error) {
|
||||
if in.KnowledgeID == "" {
|
||||
return nil, nil, fmt.Errorf("knowledge_id is required")
|
||||
}
|
||||
k, err := svc.GetKnowledge(ctx, in.KnowledgeID)
|
||||
if err != nil {
|
||||
return nil, nil, fmt.Errorf("get knowledge: %w", err)
|
||||
}
|
||||
return nil, k, nil
|
||||
})
|
||||
}
|
||||
|
||||
// ---- doc_download --------------------------------------------------------
|
||||
|
||||
type docDownloadInput struct {
|
||||
KnowledgeID string `json:"knowledge_id" jsonschema:"document (knowledge entry) ID"`
|
||||
}
|
||||
|
||||
type docDownloadOutput struct {
|
||||
KnowledgeID string `json:"knowledge_id"`
|
||||
FileName string `json:"file_name"`
|
||||
Bytes int `json:"bytes"`
|
||||
// Content is the file contents (UTF-8 if text, base64 if the SDK
|
||||
// reports a binary-looking blob). For binary, agents should decode
|
||||
// before consuming.
|
||||
Content string `json:"content"`
|
||||
IsBase64 bool `json:"is_base64"`
|
||||
}
|
||||
|
||||
// maxDocDownloadBytes caps the per-call payload to keep an agent's context
|
||||
// window safe; agents needing larger documents should chunk via doc_view +
|
||||
// search_chunks. 1 MiB matches a typical LLM context-window budget for
|
||||
// inline content (~250k tokens) while remaining cheap to serialize.
|
||||
const maxDocDownloadBytes = 1 << 20
|
||||
|
||||
func addDocDownload(server *mcpsdk.Server, svc knowledgeService) {
|
||||
mcpsdk.AddTool(server, &mcpsdk.Tool{
|
||||
Name: "doc_download",
|
||||
Description: "Download a document's raw bytes by ID. Capped at 1 MiB per call — for larger documents, use search_chunks to find the relevant excerpts. is_base64 reports whether content was base64-encoded (heuristic: presence of NUL byte in the first 512 bytes).",
|
||||
}, func(ctx context.Context, _ *mcpsdk.CallToolRequest, in docDownloadInput) (*mcpsdk.CallToolResult, docDownloadOutput, error) {
|
||||
if in.KnowledgeID == "" {
|
||||
return nil, docDownloadOutput{}, fmt.Errorf("knowledge_id is required")
|
||||
}
|
||||
name, body, err := svc.OpenKnowledgeFile(ctx, in.KnowledgeID)
|
||||
if err != nil {
|
||||
return nil, docDownloadOutput{}, fmt.Errorf("open knowledge file: %w", err)
|
||||
}
|
||||
defer body.Close()
|
||||
buf, err := io.ReadAll(io.LimitReader(body, maxDocDownloadBytes+1))
|
||||
if err != nil {
|
||||
return nil, docDownloadOutput{}, fmt.Errorf("read knowledge file: %w", err)
|
||||
}
|
||||
if len(buf) > maxDocDownloadBytes {
|
||||
return nil, docDownloadOutput{}, fmt.Errorf("document exceeds the %d-byte per-call cap; use search_chunks for excerpts", maxDocDownloadBytes)
|
||||
}
|
||||
content, isBase64 := encodeDownload(buf)
|
||||
return nil, docDownloadOutput{
|
||||
KnowledgeID: in.KnowledgeID,
|
||||
FileName: name,
|
||||
Bytes: len(buf),
|
||||
Content: content,
|
||||
IsBase64: isBase64,
|
||||
}, nil
|
||||
})
|
||||
}
|
||||
|
||||
// ---- search_chunks -------------------------------------------------------
|
||||
|
||||
type searchChunksInput struct {
|
||||
KBID string `json:"kb_id" jsonschema:"knowledge base ID to search"`
|
||||
Query string `json:"query" jsonschema:"natural-language search query"`
|
||||
Limit int `json:"limit,omitempty" jsonschema:"client-side cap on results (1..1000); defaults to 10"`
|
||||
VectorThreshold float64 `json:"vector_threshold,omitempty" jsonschema:"minimum vector similarity (0..1)"`
|
||||
KeywordThreshold float64 `json:"keyword_threshold,omitempty" jsonschema:"minimum keyword score (0..1)"`
|
||||
}
|
||||
|
||||
type searchChunksOutput struct {
|
||||
Results []*sdk.SearchResult `json:"results"`
|
||||
}
|
||||
|
||||
func addSearchChunks(server *mcpsdk.Server, svc knowledgeService) {
|
||||
// Out = any: SDK output schema would derive from searchChunksOutput,
|
||||
// which embeds *sdk.SearchResult — and SearchResult.Metadata is a
|
||||
// nilable map[string]any that violates the auto-generated
|
||||
// type=object constraint when empty. Skipping derivation by using
|
||||
// `any` keeps the structured JSON shape identical while bypassing
|
||||
// the over-eager validator. Same pattern applied to chat / agent_invoke
|
||||
// below.
|
||||
mcpsdk.AddTool(server, &mcpsdk.Tool{
|
||||
Name: "search_chunks",
|
||||
Description: "Hybrid (vector + keyword) retrieval against a knowledge base. Returns the top chunks ranked by RRF; use this before chat to ground an answer in cited context. Results include knowledge_id, content, score — feed back into chat as context or display directly.",
|
||||
}, func(ctx context.Context, _ *mcpsdk.CallToolRequest, in searchChunksInput) (*mcpsdk.CallToolResult, any, error) {
|
||||
if in.KBID == "" {
|
||||
return nil, nil, fmt.Errorf("kb_id is required")
|
||||
}
|
||||
if strings.TrimSpace(in.Query) == "" {
|
||||
return nil, nil, fmt.Errorf("query cannot be empty")
|
||||
}
|
||||
limit := in.Limit
|
||||
if limit < 1 {
|
||||
limit = 10
|
||||
}
|
||||
if limit > 1000 {
|
||||
return nil, nil, fmt.Errorf("limit must be in 1..1000")
|
||||
}
|
||||
results, err := svc.HybridSearch(ctx, in.KBID, &sdk.SearchParams{
|
||||
QueryText: in.Query,
|
||||
VectorThreshold: in.VectorThreshold,
|
||||
KeywordThreshold: in.KeywordThreshold,
|
||||
})
|
||||
if err != nil {
|
||||
return nil, nil, fmt.Errorf("hybrid search: %w", err)
|
||||
}
|
||||
if len(results) > limit {
|
||||
results = results[:limit]
|
||||
}
|
||||
if results == nil {
|
||||
results = []*sdk.SearchResult{}
|
||||
}
|
||||
return nil, searchChunksOutput{Results: results}, nil
|
||||
})
|
||||
}
|
||||
|
||||
// ---- chat ----------------------------------------------------------------
|
||||
|
||||
type chatInput struct {
|
||||
KBID string `json:"kb_id" jsonschema:"knowledge base ID to chat against"`
|
||||
Query string `json:"query" jsonschema:"user query"`
|
||||
SessionID string `json:"session_id,omitempty" jsonschema:"existing session to continue; auto-created when empty"`
|
||||
}
|
||||
|
||||
type chatOutput struct {
|
||||
Answer string `json:"answer"`
|
||||
References []*sdk.SearchResult `json:"references"`
|
||||
SessionID string `json:"session_id"`
|
||||
AssistantMessageID string `json:"assistant_message_id,omitempty"`
|
||||
}
|
||||
|
||||
func addChat(server *mcpsdk.Server, svc chatService) {
|
||||
mcpsdk.AddTool(server, &mcpsdk.Tool{
|
||||
Name: "chat",
|
||||
Description: "Stream a RAG answer from the LLM, grounded in the given knowledge base. The SSE stream is accumulated server-side; this tool returns the full answer + references + session_id once the stream completes. Pass session_id to continue a multi-turn conversation; otherwise a fresh session is auto-created.",
|
||||
}, func(ctx context.Context, _ *mcpsdk.CallToolRequest, in chatInput) (*mcpsdk.CallToolResult, any, error) {
|
||||
if in.KBID == "" {
|
||||
return nil, nil, fmt.Errorf("kb_id is required")
|
||||
}
|
||||
if strings.TrimSpace(in.Query) == "" {
|
||||
return nil, nil, fmt.Errorf("query cannot be empty")
|
||||
}
|
||||
sessionID := in.SessionID
|
||||
if sessionID == "" {
|
||||
sess, err := svc.CreateSession(ctx, &sdk.CreateSessionRequest{Title: "weknora mcp chat"})
|
||||
if err != nil {
|
||||
return nil, nil, fmt.Errorf("create chat session: %w", err)
|
||||
}
|
||||
sessionID = sess.ID
|
||||
}
|
||||
req := &sdk.KnowledgeQARequest{
|
||||
Query: in.Query,
|
||||
KnowledgeBaseIDs: []string{in.KBID},
|
||||
AgentEnabled: false,
|
||||
Channel: "api",
|
||||
}
|
||||
acc := &sse.Accumulator{}
|
||||
streamErr := svc.KnowledgeQAStream(ctx, sessionID, req, func(r *sdk.StreamResponse) error {
|
||||
acc.Append(r)
|
||||
return nil
|
||||
})
|
||||
if streamErr != nil {
|
||||
return nil, nil, fmt.Errorf("knowledge qa stream: %w", streamErr)
|
||||
}
|
||||
if !acc.Done() {
|
||||
return nil, nil, fmt.Errorf("stream ended without a terminal event")
|
||||
}
|
||||
sid := acc.SessionID
|
||||
if sid == "" {
|
||||
sid = sessionID
|
||||
}
|
||||
return nil, chatOutput{
|
||||
Answer: acc.Result(),
|
||||
References: acc.References,
|
||||
SessionID: sid,
|
||||
AssistantMessageID: acc.AssistantMessageID,
|
||||
}, nil
|
||||
})
|
||||
}
|
||||
|
||||
// ---- agent_list ----------------------------------------------------------
|
||||
|
||||
type agentListInput struct{}
|
||||
|
||||
type agentListOutput struct {
|
||||
Items []sdk.Agent `json:"items"`
|
||||
}
|
||||
|
||||
func addAgentList(server *mcpsdk.Server, svc agentService) {
|
||||
mcpsdk.AddTool(server, &mcpsdk.Tool{
|
||||
Name: "agent_list",
|
||||
Description: "List the tenant's custom agents. Returns items[] with id, name, description, is_builtin — use to discover an agent_id before agent_invoke.",
|
||||
}, func(ctx context.Context, _ *mcpsdk.CallToolRequest, _ agentListInput) (*mcpsdk.CallToolResult, agentListOutput, error) {
|
||||
items, err := svc.ListAgents(ctx)
|
||||
if err != nil {
|
||||
return nil, agentListOutput{}, fmt.Errorf("list agents: %w", err)
|
||||
}
|
||||
if items == nil {
|
||||
items = []sdk.Agent{}
|
||||
}
|
||||
return nil, agentListOutput{Items: items}, nil
|
||||
})
|
||||
}
|
||||
|
||||
// ---- agent_invoke --------------------------------------------------------
|
||||
|
||||
type agentInvokeInput struct {
|
||||
AgentID string `json:"agent_id" jsonschema:"custom agent ID"`
|
||||
Query string `json:"query" jsonschema:"user query"`
|
||||
SessionID string `json:"session_id,omitempty" jsonschema:"existing session to continue; auto-created when empty"`
|
||||
}
|
||||
|
||||
type agentInvokeOutput struct {
|
||||
Answer string `json:"answer"`
|
||||
References []*sdk.SearchResult `json:"references"`
|
||||
ToolEvents []sse.AgentToolEvent `json:"tool_events,omitempty"`
|
||||
Thinking string `json:"thinking,omitempty"`
|
||||
SessionID string `json:"session_id"`
|
||||
AgentID string `json:"agent_id"`
|
||||
}
|
||||
|
||||
func addAgentInvoke(server *mcpsdk.Server, svc agentInvokeService) {
|
||||
mcpsdk.AddTool(server, &mcpsdk.Tool{
|
||||
Name: "agent_invoke",
|
||||
Description: "Run a query through a custom agent (system prompt + tool allow-list + KB scope). The agent's SSE stream is accumulated server-side; this tool returns the final answer plus the trace (references, tool_events, thinking).",
|
||||
}, func(ctx context.Context, _ *mcpsdk.CallToolRequest, in agentInvokeInput) (*mcpsdk.CallToolResult, any, error) {
|
||||
if in.AgentID == "" {
|
||||
return nil, nil, fmt.Errorf("agent_id is required")
|
||||
}
|
||||
if strings.TrimSpace(in.Query) == "" {
|
||||
return nil, nil, fmt.Errorf("query cannot be empty")
|
||||
}
|
||||
acc := &sse.AgentAccumulator{}
|
||||
req := &sdk.AgentQARequest{
|
||||
Query: in.Query,
|
||||
AgentEnabled: true,
|
||||
AgentID: in.AgentID,
|
||||
Channel: "api",
|
||||
}
|
||||
// Auto-create session if not supplied. Sessions are agent-
|
||||
// agnostic at creation (Q3 — verified against server source).
|
||||
sessionID := in.SessionID
|
||||
if sessionID == "" {
|
||||
sess, err := svc.CreateSession(ctx, &sdk.CreateSessionRequest{Title: "weknora mcp agent_invoke"})
|
||||
if err != nil {
|
||||
return nil, nil, fmt.Errorf("create chat session: %w", err)
|
||||
}
|
||||
sessionID = sess.ID
|
||||
}
|
||||
streamErr := svc.AgentQAStreamWithRequest(ctx, sessionID, req, func(r *sdk.AgentStreamResponse) error {
|
||||
acc.Append(r)
|
||||
return nil
|
||||
})
|
||||
if streamErr != nil {
|
||||
return nil, nil, fmt.Errorf("agent-chat stream: %w", streamErr)
|
||||
}
|
||||
if !acc.Done() {
|
||||
return nil, nil, fmt.Errorf("stream ended without a terminal event")
|
||||
}
|
||||
return nil, agentInvokeOutput{
|
||||
Answer: acc.Answer(),
|
||||
References: acc.References,
|
||||
ToolEvents: acc.ToolEvents,
|
||||
Thinking: acc.Thinking(),
|
||||
SessionID: sessionID,
|
||||
AgentID: in.AgentID,
|
||||
}, nil
|
||||
})
|
||||
}
|
||||
|
||||
// encodeDownload decides how a downloaded file is carried in a JSON string
// and returns (content, isBase64).
//
// The payload is base64-encoded when either
//   - a NUL byte appears in the first 512 bytes (a coarse "this is binary"
//     probe, cheap enough to run on every download), or
//   - the payload as a whole is not valid UTF-8.
//
// The second check matters because the result is serialised with
// encoding/json, which replaces invalid UTF-8 with U+FFFD: returning such
// bytes as "text" would silently corrupt the document. Everything else is
// returned verbatim so agents are spared a base64 decode for ordinary text.
func encodeDownload(buf []byte) (string, bool) {
	probe := buf
	if len(probe) > 512 {
		probe = probe[:512]
	}
	for _, b := range probe {
		if b == 0 {
			return base64.StdEncoding.EncodeToString(buf), true
		}
	}

	text := string(buf)
	// ToValidUTF8 returns its input unchanged when it is already valid, so a
	// difference means at least one invalid byte sequence was dropped.
	if strings.ToValidUTF8(text, "") != text {
		return base64.StdEncoding.EncodeToString(buf), true
	}
	return text, false
}
|
||||
407
cli/internal/mcp/tools_test.go
Normal file
407
cli/internal/mcp/tools_test.go
Normal file
@@ -0,0 +1,407 @@
|
||||
package mcp
|
||||
|
||||
import (
|
||||
"context"
|
||||
"encoding/json"
|
||||
"errors"
|
||||
"io"
|
||||
"strings"
|
||||
"testing"
|
||||
"time"
|
||||
|
||||
mcpsdk "github.com/modelcontextprotocol/go-sdk/mcp"
|
||||
|
||||
sdk "github.com/Tencent/WeKnora/client"
|
||||
)
|
||||
|
||||
// fakeSvc implements every narrow service interface ServiceClient embeds.
|
||||
// Each method records the last call args; per-test setup populates the
|
||||
// return values it wants to assert against.
|
||||
type fakeSvc struct {
|
||||
listKBs []sdk.KnowledgeBase
|
||||
listKBsErr error
|
||||
getKB *sdk.KnowledgeBase
|
||||
getKBErr error
|
||||
listDocs []sdk.Knowledge
|
||||
listDocsTotal int64
|
||||
listDocsErr error
|
||||
getDoc *sdk.Knowledge
|
||||
getDocErr error
|
||||
openDocName string
|
||||
openDocBody io.ReadCloser
|
||||
openDocErr error
|
||||
hybridResults []*sdk.SearchResult
|
||||
hybridErr error
|
||||
createSess *sdk.Session
|
||||
createSessErr error
|
||||
kbStreamEvents []*sdk.StreamResponse
|
||||
kbStreamErr error
|
||||
agents []sdk.Agent
|
||||
agentsErr error
|
||||
agent *sdk.Agent
|
||||
agentErr error
|
||||
agentEvents []*sdk.AgentStreamResponse
|
||||
agentStreamErr error
|
||||
// Captured args:
|
||||
calls struct {
|
||||
listKBs int
|
||||
kbViewID string
|
||||
docListKBID string
|
||||
docListFilter sdk.KnowledgeListFilter
|
||||
docViewID string
|
||||
openDocID string
|
||||
hybridKBID string
|
||||
hybridParams *sdk.SearchParams
|
||||
createSessReq *sdk.CreateSessionRequest
|
||||
kbQAReq *sdk.KnowledgeQARequest
|
||||
kbQASess string
|
||||
agentListN int
|
||||
agentViewID string
|
||||
agentReq *sdk.AgentQARequest
|
||||
agentSess string
|
||||
}
|
||||
}
|
||||
|
||||
func (f *fakeSvc) ListKnowledgeBases(_ context.Context) ([]sdk.KnowledgeBase, error) {
|
||||
f.calls.listKBs++
|
||||
return f.listKBs, f.listKBsErr
|
||||
}
|
||||
func (f *fakeSvc) GetKnowledgeBase(_ context.Context, id string) (*sdk.KnowledgeBase, error) {
|
||||
f.calls.kbViewID = id
|
||||
return f.getKB, f.getKBErr
|
||||
}
|
||||
func (f *fakeSvc) ListKnowledgeWithFilter(_ context.Context, kbID string, _, _ int, filter sdk.KnowledgeListFilter) ([]sdk.Knowledge, int64, error) {
|
||||
f.calls.docListKBID = kbID
|
||||
f.calls.docListFilter = filter
|
||||
return f.listDocs, f.listDocsTotal, f.listDocsErr
|
||||
}
|
||||
func (f *fakeSvc) GetKnowledge(_ context.Context, id string) (*sdk.Knowledge, error) {
|
||||
f.calls.docViewID = id
|
||||
return f.getDoc, f.getDocErr
|
||||
}
|
||||
func (f *fakeSvc) OpenKnowledgeFile(_ context.Context, id string) (string, io.ReadCloser, error) {
|
||||
f.calls.openDocID = id
|
||||
return f.openDocName, f.openDocBody, f.openDocErr
|
||||
}
|
||||
func (f *fakeSvc) HybridSearch(_ context.Context, kbID string, p *sdk.SearchParams) ([]*sdk.SearchResult, error) {
|
||||
f.calls.hybridKBID, f.calls.hybridParams = kbID, p
|
||||
return f.hybridResults, f.hybridErr
|
||||
}
|
||||
func (f *fakeSvc) CreateSession(_ context.Context, req *sdk.CreateSessionRequest) (*sdk.Session, error) {
|
||||
f.calls.createSessReq = req
|
||||
if f.createSess == nil && f.createSessErr == nil {
|
||||
return &sdk.Session{ID: "sess_auto"}, nil
|
||||
}
|
||||
return f.createSess, f.createSessErr
|
||||
}
|
||||
func (f *fakeSvc) KnowledgeQAStream(_ context.Context, sess string, req *sdk.KnowledgeQARequest, cb func(*sdk.StreamResponse) error) error {
|
||||
f.calls.kbQASess, f.calls.kbQAReq = sess, req
|
||||
for _, e := range f.kbStreamEvents {
|
||||
if err := cb(e); err != nil {
|
||||
return err
|
||||
}
|
||||
}
|
||||
return f.kbStreamErr
|
||||
}
|
||||
func (f *fakeSvc) ListAgents(_ context.Context) ([]sdk.Agent, error) {
|
||||
f.calls.agentListN++
|
||||
return f.agents, f.agentsErr
|
||||
}
|
||||
func (f *fakeSvc) GetAgent(_ context.Context, id string) (*sdk.Agent, error) {
|
||||
f.calls.agentViewID = id
|
||||
return f.agent, f.agentErr
|
||||
}
|
||||
func (f *fakeSvc) AgentQAStreamWithRequest(_ context.Context, sess string, req *sdk.AgentQARequest, cb sdk.AgentEventCallback) error {
|
||||
f.calls.agentSess, f.calls.agentReq = sess, req
|
||||
for _, e := range f.agentEvents {
|
||||
if err := cb(e); err != nil {
|
||||
return err
|
||||
}
|
||||
}
|
||||
return f.agentStreamErr
|
||||
}
|
||||
|
||||
// newTestServer wires svc to an in-process MCP server and returns a
|
||||
// connected client session ready to CallTool against it.
|
||||
func newTestServer(t *testing.T, svc ServiceClient) (*mcpsdk.ClientSession, context.CancelFunc) {
|
||||
t.Helper()
|
||||
ctx, cancel := context.WithCancel(context.Background())
|
||||
server := mcpsdk.NewServer(&mcpsdk.Implementation{Name: "weknora-test", Version: "v0.0.0-test"}, nil)
|
||||
registerTools(server, svc)
|
||||
|
||||
st, ct := mcpsdk.NewInMemoryTransports()
|
||||
serverSession, err := server.Connect(ctx, st, nil)
|
||||
if err != nil {
|
||||
cancel()
|
||||
t.Fatalf("server.Connect: %v", err)
|
||||
}
|
||||
client := mcpsdk.NewClient(&mcpsdk.Implementation{Name: "test-client", Version: "v0.0.0"}, nil)
|
||||
clientSession, err := client.Connect(ctx, ct, nil)
|
||||
if err != nil {
|
||||
_ = serverSession.Close()
|
||||
cancel()
|
||||
t.Fatalf("client.Connect: %v", err)
|
||||
}
|
||||
t.Cleanup(func() {
|
||||
_ = clientSession.Close()
|
||||
_ = serverSession.Close()
|
||||
cancel()
|
||||
})
|
||||
return clientSession, cancel
|
||||
}
|
||||
|
||||
// callTool invokes name with args and returns the parsed structured output.
|
||||
func callTool(t *testing.T, c *mcpsdk.ClientSession, name string, args any, out any) *mcpsdk.CallToolResult {
|
||||
t.Helper()
|
||||
ctx, cancel := context.WithTimeout(context.Background(), 5*time.Second)
|
||||
defer cancel()
|
||||
res, err := c.CallTool(ctx, &mcpsdk.CallToolParams{Name: name, Arguments: args})
|
||||
if err != nil {
|
||||
t.Fatalf("CallTool(%s): %v", name, err)
|
||||
}
|
||||
if res.IsError {
|
||||
if len(res.Content) > 0 {
|
||||
t.Fatalf("tool %s returned error: %+v", name, res.Content)
|
||||
}
|
||||
t.Fatalf("tool %s returned error (no content)", name)
|
||||
}
|
||||
if out != nil && res.StructuredContent != nil {
|
||||
b, _ := json.Marshal(res.StructuredContent)
|
||||
if err := json.Unmarshal(b, out); err != nil {
|
||||
t.Fatalf("decode %s output: %v\nraw=%s", name, err, b)
|
||||
}
|
||||
}
|
||||
return res
|
||||
}
|
||||
|
||||
func TestTool_ListsRegistered(t *testing.T) {
|
||||
c, _ := newTestServer(t, &fakeSvc{})
|
||||
ctx, cancel := context.WithTimeout(context.Background(), 2*time.Second)
|
||||
defer cancel()
|
||||
res, err := c.ListTools(ctx, nil)
|
||||
if err != nil {
|
||||
t.Fatalf("ListTools: %v", err)
|
||||
}
|
||||
want := []string{"kb_list", "kb_view", "doc_list", "doc_view", "doc_download", "search_chunks", "chat", "agent_list", "agent_invoke"}
|
||||
got := map[string]bool{}
|
||||
for _, tool := range res.Tools {
|
||||
got[tool.Name] = true
|
||||
}
|
||||
for _, name := range want {
|
||||
if !got[name] {
|
||||
t.Errorf("missing tool %q in ListTools response", name)
|
||||
}
|
||||
}
|
||||
if len(res.Tools) != len(want) {
|
||||
t.Errorf("registered %d tools, want exactly %d (no scope creep)", len(res.Tools), len(want))
|
||||
}
|
||||
}
|
||||
|
||||
func TestTool_KBList(t *testing.T) {
|
||||
svc := &fakeSvc{listKBs: []sdk.KnowledgeBase{{ID: "kb1", Name: "Marketing"}}}
|
||||
c, _ := newTestServer(t, svc)
|
||||
var out kbListOutput
|
||||
callTool(t, c, "kb_list", map[string]any{}, &out)
|
||||
if len(out.Items) != 1 || out.Items[0].ID != "kb1" {
|
||||
t.Errorf("got %+v", out)
|
||||
}
|
||||
}
|
||||
|
||||
func TestTool_KBView_RequiresKBID(t *testing.T) {
|
||||
c, _ := newTestServer(t, &fakeSvc{})
|
||||
ctx, cancel := context.WithTimeout(context.Background(), 2*time.Second)
|
||||
defer cancel()
|
||||
res, err := c.CallTool(ctx, &mcpsdk.CallToolParams{Name: "kb_view", Arguments: map[string]any{}})
|
||||
if err != nil {
|
||||
t.Fatalf("unexpected transport error: %v", err)
|
||||
}
|
||||
if !res.IsError {
|
||||
t.Fatal("expected IsError=true on missing kb_id")
|
||||
}
|
||||
}
|
||||
|
||||
func TestTool_KBView(t *testing.T) {
|
||||
svc := &fakeSvc{getKB: &sdk.KnowledgeBase{ID: "kb_x", Name: "Eng"}}
|
||||
c, _ := newTestServer(t, svc)
|
||||
var out sdk.KnowledgeBase
|
||||
callTool(t, c, "kb_view", map[string]any{"kb_id": "kb_x"}, &out)
|
||||
if out.ID != "kb_x" || out.Name != "Eng" {
|
||||
t.Errorf("got %+v", out)
|
||||
}
|
||||
if svc.calls.kbViewID != "kb_x" {
|
||||
t.Errorf("kb_id not forwarded: %s", svc.calls.kbViewID)
|
||||
}
|
||||
}
|
||||
|
||||
func TestTool_DocList_DefaultPagination(t *testing.T) {
|
||||
svc := &fakeSvc{listDocs: []sdk.Knowledge{{ID: "k1"}}, listDocsTotal: 1}
|
||||
c, _ := newTestServer(t, svc)
|
||||
var out docListOutput
|
||||
callTool(t, c, "doc_list", map[string]any{"kb_id": "kb_x"}, &out)
|
||||
if out.Page != 1 || out.PageSize != 20 {
|
||||
t.Errorf("default pagination not applied: %+v", out)
|
||||
}
|
||||
if svc.calls.docListKBID != "kb_x" {
|
||||
t.Errorf("kb_id not forwarded: %s", svc.calls.docListKBID)
|
||||
}
|
||||
}
|
||||
|
||||
func TestTool_DocList_StatusFilter_Forwarded(t *testing.T) {
|
||||
svc := &fakeSvc{}
|
||||
c, _ := newTestServer(t, svc)
|
||||
callTool(t, c, "doc_list", map[string]any{"kb_id": "kb_x", "status": "failed"}, nil)
|
||||
if svc.calls.docListFilter.ParseStatus != "failed" {
|
||||
t.Errorf("status not forwarded as filter.ParseStatus: %+v", svc.calls.docListFilter)
|
||||
}
|
||||
}
|
||||
|
||||
func TestTool_DocView(t *testing.T) {
|
||||
svc := &fakeSvc{getDoc: &sdk.Knowledge{ID: "k1", FileName: "a.pdf"}}
|
||||
c, _ := newTestServer(t, svc)
|
||||
var out sdk.Knowledge
|
||||
callTool(t, c, "doc_view", map[string]any{"knowledge_id": "k1"}, &out)
|
||||
if out.ID != "k1" {
|
||||
t.Errorf("got %+v", out)
|
||||
}
|
||||
}
|
||||
|
||||
func TestTool_DocDownload_Text(t *testing.T) {
|
||||
svc := &fakeSvc{
|
||||
openDocName: "notes.txt",
|
||||
openDocBody: io.NopCloser(strings.NewReader("hello world")),
|
||||
}
|
||||
c, _ := newTestServer(t, svc)
|
||||
var out docDownloadOutput
|
||||
callTool(t, c, "doc_download", map[string]any{"knowledge_id": "k1"}, &out)
|
||||
if out.Content != "hello world" {
|
||||
t.Errorf("content = %q", out.Content)
|
||||
}
|
||||
if out.IsBase64 {
|
||||
t.Error("text content should not be base64-encoded")
|
||||
}
|
||||
}
|
||||
|
||||
func TestTool_DocDownload_BinaryBase64(t *testing.T) {
|
||||
// First 512 bytes contain a NUL → encodeDownload returns base64.
|
||||
bin := []byte{0x00, 0x01, 0x02, 0x03}
|
||||
svc := &fakeSvc{
|
||||
openDocName: "blob.bin",
|
||||
openDocBody: io.NopCloser(strings.NewReader(string(bin))),
|
||||
}
|
||||
c, _ := newTestServer(t, svc)
|
||||
var out docDownloadOutput
|
||||
callTool(t, c, "doc_download", map[string]any{"knowledge_id": "k1"}, &out)
|
||||
if !out.IsBase64 {
|
||||
t.Errorf("binary should be base64; got is_base64=%v content=%q", out.IsBase64, out.Content)
|
||||
}
|
||||
}
|
||||
|
||||
func TestTool_SearchChunks(t *testing.T) {
|
||||
svc := &fakeSvc{hybridResults: []*sdk.SearchResult{{KnowledgeID: "k1", Score: 0.9}}}
|
||||
c, _ := newTestServer(t, svc)
|
||||
var out searchChunksOutput
|
||||
callTool(t, c, "search_chunks", map[string]any{"kb_id": "kb_x", "query": "what is RAG"}, &out)
|
||||
if len(out.Results) != 1 || out.Results[0].KnowledgeID != "k1" {
|
||||
t.Errorf("got %+v", out)
|
||||
}
|
||||
}
|
||||
|
||||
func TestTool_SearchChunks_LimitCap(t *testing.T) {
|
||||
// 5 results, limit 3 → 3 returned.
|
||||
svc := &fakeSvc{}
|
||||
for i := 0; i < 5; i++ {
|
||||
svc.hybridResults = append(svc.hybridResults, &sdk.SearchResult{KnowledgeID: "k", Score: float64(i)})
|
||||
}
|
||||
c, _ := newTestServer(t, svc)
|
||||
var out searchChunksOutput
|
||||
callTool(t, c, "search_chunks", map[string]any{"kb_id": "kb_x", "query": "x", "limit": 3}, &out)
|
||||
if len(out.Results) != 3 {
|
||||
t.Errorf("limit not honored: got %d, want 3", len(out.Results))
|
||||
}
|
||||
}
|
||||
|
||||
func TestTool_Chat_AccumulateAnswerAndReferences(t *testing.T) {
|
||||
svc := &fakeSvc{
|
||||
kbStreamEvents: []*sdk.StreamResponse{
|
||||
{Content: "Hello "},
|
||||
{Content: "world."},
|
||||
{KnowledgeReferences: []*sdk.SearchResult{{KnowledgeID: "k1"}}},
|
||||
{ResponseType: sdk.ResponseTypeComplete},
|
||||
},
|
||||
}
|
||||
c, _ := newTestServer(t, svc)
|
||||
var out chatOutput
|
||||
callTool(t, c, "chat", map[string]any{"kb_id": "kb_x", "query": "ping"}, &out)
|
||||
if out.Answer != "Hello world." {
|
||||
t.Errorf("answer = %q", out.Answer)
|
||||
}
|
||||
if len(out.References) != 1 || out.References[0].KnowledgeID != "k1" {
|
||||
t.Errorf("references missing: %+v", out.References)
|
||||
}
|
||||
if out.SessionID != "sess_auto" {
|
||||
t.Errorf("session_id = %q, want sess_auto", out.SessionID)
|
||||
}
|
||||
}
|
||||
|
||||
func TestTool_Chat_ExistingSessionSkipsCreate(t *testing.T) {
|
||||
svc := &fakeSvc{
|
||||
kbStreamEvents: []*sdk.StreamResponse{{ResponseType: sdk.ResponseTypeComplete}},
|
||||
}
|
||||
c, _ := newTestServer(t, svc)
|
||||
callTool(t, c, "chat", map[string]any{"kb_id": "kb_x", "query": "x", "session_id": "sess_existing"}, nil)
|
||||
if svc.calls.createSessReq != nil {
|
||||
t.Error("CreateSession should not fire when session_id is supplied")
|
||||
}
|
||||
if svc.calls.kbQASess != "sess_existing" {
|
||||
t.Errorf("session id not forwarded to QA stream: %s", svc.calls.kbQASess)
|
||||
}
|
||||
}
|
||||
|
||||
func TestTool_AgentList(t *testing.T) {
|
||||
svc := &fakeSvc{agents: []sdk.Agent{{ID: "ag1", Name: "Research"}}}
|
||||
c, _ := newTestServer(t, svc)
|
||||
var out agentListOutput
|
||||
callTool(t, c, "agent_list", map[string]any{}, &out)
|
||||
if len(out.Items) != 1 || out.Items[0].ID != "ag1" {
|
||||
t.Errorf("got %+v", out)
|
||||
}
|
||||
}
|
||||
|
||||
func TestTool_AgentInvoke(t *testing.T) {
|
||||
svc := &fakeSvc{
|
||||
agentEvents: []*sdk.AgentStreamResponse{
|
||||
{ResponseType: sdk.AgentResponseTypeAnswer, Content: "result"},
|
||||
{ResponseType: sdk.AgentResponseTypeToolCall, ID: "c1", Content: "knowledge_search"},
|
||||
{Done: true},
|
||||
},
|
||||
}
|
||||
c, _ := newTestServer(t, svc)
|
||||
var out agentInvokeOutput
|
||||
callTool(t, c, "agent_invoke", map[string]any{"agent_id": "ag1", "query": "x"}, &out)
|
||||
if out.Answer != "result" {
|
||||
t.Errorf("answer = %q", out.Answer)
|
||||
}
|
||||
if len(out.ToolEvents) != 1 {
|
||||
t.Errorf("tool_calls len = %d, want 1", len(out.ToolEvents))
|
||||
}
|
||||
if out.AgentID != "ag1" {
|
||||
t.Errorf("agent_id = %q", out.AgentID)
|
||||
}
|
||||
}
|
||||
|
||||
func TestTool_AgentInvoke_StreamAbort(t *testing.T) {
|
||||
svc := &fakeSvc{
|
||||
agentEvents: []*sdk.AgentStreamResponse{{ResponseType: sdk.AgentResponseTypeAnswer, Content: "partial"}},
|
||||
agentStreamErr: errors.New("connection reset"),
|
||||
}
|
||||
c, _ := newTestServer(t, svc)
|
||||
ctx, cancel := context.WithTimeout(context.Background(), 2*time.Second)
|
||||
defer cancel()
|
||||
res, err := c.CallTool(ctx, &mcpsdk.CallToolParams{Name: "agent_invoke", Arguments: map[string]any{"agent_id": "ag1", "query": "x"}})
|
||||
if err != nil {
|
||||
t.Fatalf("unexpected transport error: %v", err)
|
||||
}
|
||||
if !res.IsError {
|
||||
t.Fatal("expected IsError=true on mid-stream abort")
|
||||
}
|
||||
}
|
||||
86
cli/internal/sse/agent_accumulator.go
Normal file
86
cli/internal/sse/agent_accumulator.go
Normal file
@@ -0,0 +1,86 @@
|
||||
package sse
|
||||
|
||||
import (
|
||||
"strings"
|
||||
|
||||
sdk "github.com/Tencent/WeKnora/client"
|
||||
)
|
||||
|
||||
// AgentToolEvent captures one tool_call / tool_result event from an
|
||||
// agent SSE stream. Kind is the SDK event type (typed, not bare string)
|
||||
// so consumers can compare against sdk.AgentResponseTypeToolCall /
|
||||
// sdk.AgentResponseTypeToolResult without retyping the constants.
|
||||
// NOTE: Kind is the event kind, NOT the function name; the function
|
||||
// name typically lives in Data for tool_call events.
|
||||
type AgentToolEvent struct {
|
||||
ID string `json:"id"`
|
||||
Kind sdk.AgentResponseType `json:"kind,omitempty"`
|
||||
Result string `json:"result,omitempty"`
|
||||
Data map[string]any `json:"data,omitempty"`
|
||||
}
|
||||
|
||||
// AgentAccumulator buffers an AgentQAStream callback sequence. Distinct
|
||||
// from Accumulator (KnowledgeQAStream) because the agent event model is
|
||||
// wider — events include thinking / reflection / tool_call / tool_result
|
||||
// / answer / references / error, with a flat `Done bool` on each frame
|
||||
// rather than the ResponseType=complete sentinel KnowledgeQAStream emits.
|
||||
//
|
||||
// Zero value is ready to use. Not safe for concurrent Append calls — the
|
||||
// SDK callback contract is sequential on a single goroutine.
|
||||
//
|
||||
// API mirrors sse.Accumulator: private builders + accessor methods so the
|
||||
// "Append is idempotent post-Done" invariant cannot be broken by external
|
||||
// mutation. References and ToolEvents stay exported as plain slices since
|
||||
// they carry no such invariant.
|
||||
type AgentAccumulator struct {
|
||||
answer strings.Builder
|
||||
thinking strings.Builder
|
||||
References []*sdk.SearchResult
|
||||
ToolEvents []AgentToolEvent
|
||||
done bool
|
||||
}
|
||||
|
||||
// Answer returns the accumulated `answer`-event content.
|
||||
func (a *AgentAccumulator) Answer() string { return a.answer.String() }
|
||||
|
||||
// Thinking returns the accumulated `thinking` / `reflection` content
|
||||
// surfaced by the agent during its reasoning pass.
|
||||
func (a *AgentAccumulator) Thinking() string { return a.thinking.String() }
|
||||
|
||||
// Done reports whether the stream emitted a terminal frame.
|
||||
func (a *AgentAccumulator) Done() bool { return a.done }
|
||||
|
||||
// Append consumes one AgentStreamResponse event. Idempotent post-Done so
|
||||
// callers do not need to special-case late events.
|
||||
func (a *AgentAccumulator) Append(r *sdk.AgentStreamResponse) {
|
||||
if r == nil || a.done {
|
||||
return
|
||||
}
|
||||
switch r.ResponseType {
|
||||
case sdk.AgentResponseTypeAnswer:
|
||||
if r.Content != "" {
|
||||
a.answer.WriteString(r.Content)
|
||||
}
|
||||
case sdk.AgentResponseTypeThinking, sdk.AgentResponseTypeReflection:
|
||||
if r.Content != "" {
|
||||
a.thinking.WriteString(r.Content)
|
||||
}
|
||||
case sdk.AgentResponseTypeToolCall, sdk.AgentResponseTypeToolResult:
|
||||
a.ToolEvents = append(a.ToolEvents, AgentToolEvent{
|
||||
ID: r.ID,
|
||||
Kind: r.ResponseType,
|
||||
Result: r.Content,
|
||||
Data: r.Data,
|
||||
})
|
||||
}
|
||||
// References can arrive on a dedicated `references` event OR
|
||||
// piggyback on another event's KnowledgeReferences field; always
|
||||
// capture the latest, matching sse.Accumulator's "always replace"
|
||||
// semantic for the parallel KnowledgeQAStream contract.
|
||||
if r.KnowledgeReferences != nil {
|
||||
a.References = r.KnowledgeReferences
|
||||
}
|
||||
if r.Done {
|
||||
a.done = true
|
||||
}
|
||||
}
|
||||
Reference in New Issue
Block a user