mirror of
https://github.com/Tencent/WeKnora.git
synced 2026-06-04 13:30:32 +08:00
feat(agent): human-in-the-loop approval for MCP tool calls (#1173)
Add an opt-in human approval gate so Agent runs pause before executing MCP tools that operators flag as dangerous, surface an approval card in the chat UI, and only resume after the user approves (optionally with edited args) or rejects. Backend - New mcp_tool_approvals table + repo/service to mark per-tool approval required (PG migration 000042 + sqlite init). - approval.Gate coordinates RequestAndWait / Resolve with sync.Once delivery, configurable timeout, and Redis Pub/Sub fan-out so multi- replica deployments work without sticky sessions. - MCPTool.Execute integrates the gate; uses a round-level ApprovalCtx (without the per-tool 60s timeout) for the wait, and re-derives a fresh 60s exec ctx after approval so CallTool keeps a full window. - New SSE response types (tool_approval_required / _resolved) and EventBus events plumb approval state to AgentStreamDisplay. - REST: list/set per-tool approval flag, resolve pending approval. - Configurable via agent.tool_approval_timeout_seconds (yaml) or WEKNORA_AGENT_TOOL_APPROVAL_TIMEOUT env (accepts seconds or Go duration). Frontend - MCP settings: per-tool "require approval" switch on the test panel. - Chat: ToolApprovalCard renders the pause point with editable JSON args, validation feedback, mm:ss countdown that turns warning/danger near deadline, and a resolved state that retains context. - i18n strings added for zh-CN / en-US / ko-KR / ru-RU. Docs - docs/zh/mcp-approval.md covering behavior, config, API, deployment considerations (Redis cross-instance, restart limitations).
This commit is contained in:
@@ -381,6 +381,12 @@ WEKNORA_SANDBOX_DOCKER_IMAGE=wechatopenai/weknora-sandbox:latest
|
||||
# 注:此值为全局默认值。若单个智能体在数据库中配置了独立的 llm_call_timeout,则以智能体配置为准(优先级更高)。
|
||||
# WEKNORA_AGENT_LLM_TIMEOUT=300
|
||||
|
||||
# MCP 工具人工审核等待超时(秒)(可选)
|
||||
# 当某个 MCP 工具被标记为「需人工审核」后,Agent 会暂停并等待用户确认;
|
||||
# 该值控制最长等待时间,超时视为拒绝。默认 600(10 分钟)。
|
||||
# 也支持 Go duration 写法(如 30s / 5m / 1h)。
|
||||
# WEKNORA_AGENT_TOOL_APPROVAL_TIMEOUT=600
|
||||
|
||||
# APK 镜像源设置(可选)
|
||||
APK_MIRROR_ARG=mirrors.tencent.com
|
||||
|
||||
|
||||
@@ -150,6 +150,7 @@ services:
|
||||
- WEKNORA_SANDBOX_DOCKER_IMAGE=${WEKNORA_SANDBOX_DOCKER_IMAGE:-wechatopenai/weknora-sandbox:${WEKNORA_VERSION:-latest}}
|
||||
# Agent LLM call timeout
|
||||
- WEKNORA_AGENT_LLM_TIMEOUT=${WEKNORA_AGENT_LLM_TIMEOUT:-}
|
||||
- WEKNORA_AGENT_TOOL_APPROVAL_TIMEOUT=${WEKNORA_AGENT_TOOL_APPROVAL_TIMEOUT:-}
|
||||
- APK_MIRROR_ARG=${APK_MIRROR_ARG:-}
|
||||
depends_on:
|
||||
redis:
|
||||
|
||||
42
docs/zh/mcp-approval.md
Normal file
42
docs/zh/mcp-approval.md
Normal file
@@ -0,0 +1,42 @@
|
||||
# MCP 工具人工审核(危险调用)
|
||||
|
||||
对应需求:智能体调用 MCP 工具前可中断,待人工确认后再执行(GitHub #1173)。
|
||||
|
||||
## 行为说明
|
||||
|
||||
1. 在 **设置 → MCP** 中连接测试成功后,在工具列表上打开 **「需人工审核」** 开关,即可为该工具打标。
|
||||
2. Agent 运行时若即将调用已打标的工具,会推送 `tool_approval_required` 事件,对话界面展示审批卡片(可编辑 JSON 参数)。
|
||||
3. 用户 **通过** 或 **拒绝** 后,后端恢复执行;拒绝时工具返回错误信息给模型,不会调用远端 MCP。
|
||||
4. 若超时未处理(默认 10 分钟,可通过配置 `agent.tool_approval_timeout_seconds` 调整),视为拒绝。
|
||||
|
||||
## 配置示例
|
||||
|
||||
任选一种:
|
||||
|
||||
**1. config.yaml**
|
||||
|
||||
```yaml
|
||||
agent:
|
||||
tool_approval_timeout_seconds: 600 # 可选,默认 600(秒)
|
||||
```
|
||||
|
||||
**2. 环境变量**(优先级高于 yaml)
|
||||
|
||||
```bash
|
||||
# 支持纯秒数或 Go duration(30s / 5m / 1h)
|
||||
WEKNORA_AGENT_TOOL_APPROVAL_TIMEOUT=600
|
||||
```
|
||||
|
||||
## API
|
||||
|
||||
- `GET /api/v1/mcp-services/:id/tool-approvals` — 列出已保存的审核配置
|
||||
- `PUT /api/v1/mcp-services/:id/tool-approvals/:tool_name` — 设置某工具是否需审核(`{"require_approval": true}`)
|
||||
- `POST /api/v1/agent/tool-approvals/:pending_id` — 在审批卡片中提交结果
|
||||
- body: `{"decision":"approve"|"reject","modified_args":{...}可选,"reason":"..."可选}`
|
||||
|
||||
## 部署与限制
|
||||
|
||||
- **审批等待状态保存在进程内存** 中:`pending_id` 仅对当前实例有效;进程重启后进行中的等待会失败(表现为拒绝/取消)。
|
||||
- **多副本部署**:当配置了 `REDIS_ADDR` 时,`Resolve` 会通过 Redis Pub/Sub(频道 `weknora:mcp_approval:resolve`)跨实例转发,因此 SSE 与提交审批的 HTTP 请求落到不同实例也能正确唤醒等待者;未配置 Redis 时退化为单机模式,需要使用会话粘滞(sticky session)。
|
||||
- **审批等待不会被工具默认 60s 超时取消**:审批阶段使用 round 级别的 ctx(不带 `defaultToolExecTimeout`),仅受 `agent.tool_approval_timeout_seconds` 与请求级取消控制。
|
||||
- 安全边界:审核通过后的参数仍由当前登录租户提交;请仅在可信环境下授予「通过」权限。
|
||||
@@ -33,6 +33,15 @@ export interface MCPTool {
|
||||
name: string
|
||||
description: string
|
||||
inputSchema: Record<string, any>
|
||||
require_approval?: boolean
|
||||
}
|
||||
|
||||
export interface MCPToolApprovalRow {
|
||||
id: string
|
||||
tenant_id?: number
|
||||
service_id: string
|
||||
tool_name: string
|
||||
require_approval: boolean
|
||||
}
|
||||
|
||||
export interface MCPResource {
|
||||
@@ -102,3 +111,22 @@ export async function getMCPServiceResources(id: string): Promise<MCPResource[]>
|
||||
return response.data || []
|
||||
}
|
||||
|
||||
/** Persisted per-tool human-approval flags (issue #1173) */
|
||||
export async function getMCPToolApprovals(serviceId: string): Promise<MCPToolApprovalRow[]> {
|
||||
const response: any = await get(`/api/v1/mcp-services/${serviceId}/tool-approvals`)
|
||||
return response.data || []
|
||||
}
|
||||
|
||||
export async function setMCPToolApproval(serviceId: string, toolName: string, requireApproval: boolean): Promise<void> {
|
||||
await put(`/api/v1/mcp-services/${serviceId}/tool-approvals/${encodeURIComponent(toolName)}`, {
|
||||
require_approval: requireApproval
|
||||
})
|
||||
}
|
||||
|
||||
export async function resolveToolApproval(
|
||||
pendingId: string,
|
||||
body: { decision: 'approve' | 'reject'; modified_args?: Record<string, unknown>; reason?: string }
|
||||
): Promise<void> {
|
||||
await post(`/api/v1/agent/tool-approvals/${encodeURIComponent(pendingId)}`, body)
|
||||
}
|
||||
|
||||
|
||||
@@ -2411,7 +2411,11 @@ export default {
|
||||
resourcesTitle: 'Available resources',
|
||||
descriptionLabel: 'Description',
|
||||
schemaLabel: 'Parameter schema',
|
||||
emptyDescription: 'This service did not provide tools or resources'
|
||||
emptyDescription: 'This service did not provide tools or resources',
|
||||
requireApproval: 'Require human approval',
|
||||
requireApprovalTip:
|
||||
'When enabled, the agent pauses before calling this tool until you approve — use for DB writes, deletes, etc.',
|
||||
approvalSaveFailed: 'Failed to save approval setting'
|
||||
}
|
||||
},
|
||||
error: {
|
||||
@@ -3646,6 +3650,22 @@ export default {
|
||||
supportedFormats: 'Supported formats',
|
||||
},
|
||||
agentStream: {
|
||||
toolApproval: {
|
||||
banner: 'This MCP tool requires human approval. Review parameters before execution.',
|
||||
service: 'Service',
|
||||
tool: 'Tool',
|
||||
argsLabel: 'Arguments',
|
||||
argsModified: 'Modified',
|
||||
countdown: 'About {seconds}s remaining',
|
||||
approve: 'Approve & run',
|
||||
reject: 'Reject',
|
||||
approvedTag: 'Approved',
|
||||
rejectedTag: 'Rejected',
|
||||
invalidJson: 'Arguments must be valid JSON',
|
||||
submitted: 'Submitted',
|
||||
submitFailed: 'Submit failed',
|
||||
userRejected: 'User rejected',
|
||||
},
|
||||
tools: {
|
||||
searchKnowledge: 'Knowledge Search',
|
||||
grepChunks: 'Text Pattern Search',
|
||||
|
||||
@@ -1671,6 +1671,9 @@ export default {
|
||||
descriptionLabel: "설명",
|
||||
schemaLabel: "파라미터 구조",
|
||||
emptyDescription: "이 서비스에서 제공하는 도구 또는 리소스가 없습니다",
|
||||
requireApproval: "수동 승인 필요",
|
||||
requireApprovalTip: "활성화 시 에이전트가 이 도구를 호출하기 전에 승인을 기다립니다.",
|
||||
approvalSaveFailed: "승인 설정 저장 실패",
|
||||
},
|
||||
},
|
||||
error: {
|
||||
@@ -3714,6 +3717,22 @@ export default {
|
||||
supportedFormats: '지원 형식',
|
||||
},
|
||||
agentStream: {
|
||||
toolApproval: {
|
||||
banner: '이 MCP 도구는 수동 승인이 필요합니다. 실행 전 매개변수를 확인하세요.',
|
||||
service: '서비스',
|
||||
tool: '도구',
|
||||
argsLabel: '인수',
|
||||
argsModified: '수정됨',
|
||||
countdown: '약 {seconds}초 남음',
|
||||
approve: '승인 후 실행',
|
||||
reject: '거부',
|
||||
approvedTag: '승인됨',
|
||||
rejectedTag: '거부됨',
|
||||
invalidJson: '유효한 JSON이 아닙니다',
|
||||
submitted: '제출됨',
|
||||
submitFailed: '제출 실패',
|
||||
userRejected: '사용자 거부',
|
||||
},
|
||||
tools: {
|
||||
searchKnowledge: '지식베이스 검색',
|
||||
grepChunks: '텍스트 패턴 검색',
|
||||
|
||||
@@ -1487,7 +1487,10 @@ export default {
|
||||
resourcesTitle: 'Доступные ресурсы',
|
||||
descriptionLabel: 'Описание',
|
||||
schemaLabel: 'Структура параметров',
|
||||
emptyDescription: 'Сервис не предоставил инструменты или ресурсы'
|
||||
emptyDescription: 'Сервис не предоставил инструменты или ресурсы',
|
||||
requireApproval: 'Требуется подтверждение',
|
||||
requireApprovalTip: 'При включении агент ждёт подтверждения перед вызовом инструмента.',
|
||||
approvalSaveFailed: 'Не удалось сохранить настройку'
|
||||
}
|
||||
},
|
||||
error: {
|
||||
@@ -3323,6 +3326,22 @@ export default {
|
||||
supportedFormats: 'Поддерживаемые форматы'
|
||||
},
|
||||
agentStream: {
|
||||
toolApproval: {
|
||||
banner: 'Этот инструмент MCP требует подтверждения. Проверьте параметры.',
|
||||
service: 'Сервис',
|
||||
tool: 'Инструмент',
|
||||
argsLabel: 'Аргументы',
|
||||
argsModified: 'Изменено',
|
||||
countdown: 'Осталось около {seconds} с',
|
||||
approve: 'Подтвердить и выполнить',
|
||||
reject: 'Отклонить',
|
||||
approvedTag: 'Подтверждено',
|
||||
rejectedTag: 'Отклонено',
|
||||
invalidJson: 'Некорректный JSON',
|
||||
submitted: 'Отправлено',
|
||||
submitFailed: 'Ошибка отправки',
|
||||
userRejected: 'Отклонено пользователем',
|
||||
},
|
||||
tools: {
|
||||
searchKnowledge: 'Поиск по базе знаний',
|
||||
grepChunks: 'Поиск по текстовому шаблону',
|
||||
|
||||
@@ -1651,6 +1651,9 @@ export default {
|
||||
descriptionLabel: "描述",
|
||||
schemaLabel: "参数结构",
|
||||
emptyDescription: "该服务未提供工具或资源",
|
||||
requireApproval: "需人工审核",
|
||||
requireApprovalTip: "开启后,Agent 调用该工具前会暂停并等待确认,适用于可能改库/删文件等高危操作",
|
||||
approvalSaveFailed: "保存审核设置失败",
|
||||
},
|
||||
},
|
||||
error: {
|
||||
@@ -3643,6 +3646,22 @@ export default {
|
||||
supportedFormats: "支持格式",
|
||||
},
|
||||
agentStream: {
|
||||
toolApproval: {
|
||||
banner: "该 MCP 工具已标记为「需人工审核」,确认参数后再执行",
|
||||
service: "服务",
|
||||
tool: "工具",
|
||||
argsLabel: "调用参数",
|
||||
argsModified: "已修改",
|
||||
countdown: "剩余约 {seconds} 秒",
|
||||
approve: "通过并执行",
|
||||
reject: "拒绝",
|
||||
approvedTag: "已通过",
|
||||
rejectedTag: "已拒绝",
|
||||
invalidJson: "参数不是合法 JSON",
|
||||
submitted: "已提交",
|
||||
submitFailed: "提交失败",
|
||||
userRejected: "用户拒绝",
|
||||
},
|
||||
tools: {
|
||||
searchKnowledge: "知识库检索",
|
||||
grepChunks: "文本模式搜索",
|
||||
|
||||
@@ -70,6 +70,22 @@
|
||||
</div>
|
||||
</div>
|
||||
|
||||
<!-- MCP tool human approval (issue #1173) -->
|
||||
<div v-else-if="event.type === 'tool_approval_required'" class="tool-event">
|
||||
<ToolApprovalCard
|
||||
:pending-id="event.pending_id"
|
||||
:service-name="event.service_name || ''"
|
||||
:mcp-tool-name="event.mcp_tool_name || ''"
|
||||
:description="event.description"
|
||||
:args-json="event.args_json"
|
||||
:timeout-seconds="event.timeout_seconds"
|
||||
:requested-at="event.requested_at"
|
||||
:resolved="event.resolved"
|
||||
:approved="event.approved"
|
||||
:resolve-reason="event.resolve_reason"
|
||||
/>
|
||||
</div>
|
||||
|
||||
<!-- Tool Call Event (non-thinking) -->
|
||||
<div v-else-if="event.type === 'tool_call'" class="tool-event">
|
||||
<div
|
||||
@@ -179,6 +195,22 @@
|
||||
</div>
|
||||
</div>
|
||||
|
||||
<!-- MCP tool human approval -->
|
||||
<div v-else-if="event.type === 'tool_approval_required'" class="tool-event">
|
||||
<ToolApprovalCard
|
||||
:pending-id="event.pending_id"
|
||||
:service-name="event.service_name || ''"
|
||||
:mcp-tool-name="event.mcp_tool_name || ''"
|
||||
:description="event.description"
|
||||
:args-json="event.args_json"
|
||||
:timeout-seconds="event.timeout_seconds"
|
||||
:requested-at="event.requested_at"
|
||||
:resolved="event.resolved"
|
||||
:approved="event.approved"
|
||||
:resolve-reason="event.resolve_reason"
|
||||
/>
|
||||
</div>
|
||||
|
||||
<!-- Thinking Tool Call -->
|
||||
<div v-else-if="event.type === 'tool_call' && event.tool_name === 'thinking'" class="tool-event">
|
||||
<div class="action-card" :class="{ 'action-pending': event.pending || isThinkingActive(event.tool_call_id) }">
|
||||
@@ -375,6 +407,7 @@ import markedKatex from 'marked-katex-extension';
|
||||
import 'katex/dist/katex.min.css';
|
||||
import DOMPurify from 'dompurify';
|
||||
import ToolResultRenderer from './ToolResultRenderer.vue';
|
||||
import ToolApprovalCard from './ToolApprovalCard.vue';
|
||||
import picturePreview from '@/components/picture-preview.vue';
|
||||
import { getChunkByIdOnly } from '@/api/knowledge-base';
|
||||
import { getWikiPage, type WikiPage } from '@/api/wiki';
|
||||
@@ -1196,6 +1229,9 @@ const getEventKey = (event: any, index: number): string => {
|
||||
if (!event) return `event-${index}`;
|
||||
if (event.event_id) return `event-${event.event_id}`;
|
||||
if (event.tool_call_id) return `tool-${event.tool_call_id}`;
|
||||
if (event.type === 'tool_approval_required' && event.pending_id) {
|
||||
return `approval-${event.pending_id}`;
|
||||
}
|
||||
return `event-${index}-${event.type || 'unknown'}`;
|
||||
};
|
||||
|
||||
|
||||
430
frontend/src/views/chat/components/ToolApprovalCard.vue
Normal file
430
frontend/src/views/chat/components/ToolApprovalCard.vue
Normal file
@@ -0,0 +1,430 @@
|
||||
<template>
|
||||
<div class="approval-card" :class="cardClass">
|
||||
<!-- Status strip -->
|
||||
<div class="approval-strip">
|
||||
<span class="approval-strip-icon">
|
||||
<t-icon v-if="!resolved" name="info-circle-filled" />
|
||||
<t-icon v-else-if="approved" name="check-circle-filled" />
|
||||
<t-icon v-else name="close-circle-filled" />
|
||||
</span>
|
||||
<span class="approval-strip-text">
|
||||
<template v-if="!resolved">{{ $t('agentStream.toolApproval.banner') }}</template>
|
||||
<template v-else-if="approved">{{ $t('agentStream.toolApproval.approvedTag') }}</template>
|
||||
<template v-else>{{ $t('agentStream.toolApproval.rejectedTag') }}</template>
|
||||
</span>
|
||||
<span v-if="!resolved && secondsLeft >= 0" class="approval-strip-timer" :class="timerClass">
|
||||
<t-icon name="time" />
|
||||
{{ formatCountdown(secondsLeft) }}
|
||||
</span>
|
||||
</div>
|
||||
|
||||
<!-- Identity row -->
|
||||
<div class="approval-identity">
|
||||
<span class="ident-service">{{ serviceName }}</span>
|
||||
<t-icon name="chevron-right" class="ident-sep" />
|
||||
<span class="ident-tool">{{ mcpToolName }}</span>
|
||||
</div>
|
||||
|
||||
<div v-if="description" class="approval-desc">{{ description }}</div>
|
||||
|
||||
<!-- Args (editable while pending, read-only after resolve) -->
|
||||
<div class="approval-args">
|
||||
<div class="approval-args-label">
|
||||
<span class="args-label-text">{{ $t('agentStream.toolApproval.argsLabel') }}</span>
|
||||
<span v-if="!resolved && !isJsonValid" class="args-status args-invalid">
|
||||
<t-icon name="error-circle" /> {{ $t('agentStream.toolApproval.invalidJson') }}
|
||||
</span>
|
||||
<span v-else-if="!resolved && isJsonValid && argsDirty" class="args-status args-dirty">
|
||||
{{ $t('agentStream.toolApproval.argsModified') }}
|
||||
</span>
|
||||
</div>
|
||||
<t-textarea
|
||||
v-if="!resolved"
|
||||
v-model="argsText"
|
||||
class="approval-args-input"
|
||||
:autosize="{ minRows: 3, maxRows: 14 }"
|
||||
placeholder="{}"
|
||||
/>
|
||||
<pre v-else class="approval-args-readonly"><code>{{ argsText }}</code></pre>
|
||||
</div>
|
||||
|
||||
<!-- Footer (pending) -->
|
||||
<div v-if="!resolved" class="approval-footer">
|
||||
<span class="approval-spacer" />
|
||||
<t-button
|
||||
theme="default"
|
||||
variant="outline"
|
||||
size="small"
|
||||
:loading="submitting && pendingDecision === 'reject'"
|
||||
:disabled="submitting"
|
||||
@click="submit('reject')"
|
||||
>
|
||||
{{ $t('agentStream.toolApproval.reject') }}
|
||||
</t-button>
|
||||
<t-button
|
||||
theme="primary"
|
||||
size="small"
|
||||
:loading="submitting && pendingDecision === 'approve'"
|
||||
:disabled="submitting || !isJsonValid"
|
||||
@click="submit('approve')"
|
||||
>
|
||||
{{ $t('agentStream.toolApproval.approve') }}
|
||||
</t-button>
|
||||
</div>
|
||||
|
||||
<!-- Footer (resolved) -->
|
||||
<div v-else class="approval-resolved-footer">
|
||||
<span v-if="resolveReason" class="approval-resolved-reason">{{ resolveReason }}</span>
|
||||
</div>
|
||||
</div>
|
||||
</template>
|
||||
|
||||
<script setup lang="ts">
|
||||
import { ref, computed, onMounted, onBeforeUnmount } from 'vue'
|
||||
import { MessagePlugin } from 'tdesign-vue-next'
|
||||
import { useI18n } from 'vue-i18n'
|
||||
import { resolveToolApproval } from '@/api/mcp-service'
|
||||
|
||||
const props = defineProps<{
|
||||
pendingId: string
|
||||
serviceName: string
|
||||
mcpToolName: string
|
||||
description?: string
|
||||
argsJson?: string
|
||||
timeoutSeconds?: number
|
||||
requestedAt?: number
|
||||
resolved?: boolean
|
||||
approved?: boolean
|
||||
resolveReason?: string
|
||||
}>()
|
||||
|
||||
const { t } = useI18n()
|
||||
|
||||
function formatJson(raw: string): string {
|
||||
try {
|
||||
return JSON.stringify(JSON.parse(raw), null, 2)
|
||||
} catch {
|
||||
return raw
|
||||
}
|
||||
}
|
||||
|
||||
const initialArgs = formatJson(props.argsJson || '{}')
|
||||
const argsText = ref(initialArgs)
|
||||
const submitting = ref(false)
|
||||
const pendingDecision = ref<'approve' | 'reject' | null>(null)
|
||||
const now = ref(Date.now())
|
||||
let timer: ReturnType<typeof setInterval> | null = null
|
||||
|
||||
const isJsonValid = computed(() => {
|
||||
if (!argsText.value.trim()) return true
|
||||
try {
|
||||
JSON.parse(argsText.value)
|
||||
return true
|
||||
} catch {
|
||||
return false
|
||||
}
|
||||
})
|
||||
|
||||
const argsDirty = computed(() => argsText.value.trim() !== initialArgs.trim())
|
||||
|
||||
const deadline = computed(() => {
|
||||
const base = (props.requestedAt || 0) * 1000
|
||||
const add = (props.timeoutSeconds || 600) * 1000
|
||||
return base + add
|
||||
})
|
||||
|
||||
const secondsLeft = computed(() => {
|
||||
if (props.resolved) return -1
|
||||
return Math.max(0, Math.floor((deadline.value - now.value) / 1000))
|
||||
})
|
||||
|
||||
const timerClass = computed(() => {
|
||||
if (secondsLeft.value <= 30) return 'timer-critical'
|
||||
if (secondsLeft.value <= 120) return 'timer-warning'
|
||||
return ''
|
||||
})
|
||||
|
||||
const cardClass = computed(() => ({
|
||||
'is-resolved': !!props.resolved,
|
||||
'is-approved': !!props.resolved && !!props.approved,
|
||||
'is-rejected': !!props.resolved && !props.approved,
|
||||
'is-pending': !props.resolved,
|
||||
}))
|
||||
|
||||
function formatCountdown(s: number): string {
|
||||
if (s < 60) return t('agentStream.toolApproval.countdown', { seconds: s })
|
||||
const m = Math.floor(s / 60)
|
||||
const r = s % 60
|
||||
return `${m}:${r.toString().padStart(2, '0')}`
|
||||
}
|
||||
|
||||
onMounted(() => {
|
||||
timer = setInterval(() => {
|
||||
now.value = Date.now()
|
||||
}, 1000)
|
||||
})
|
||||
|
||||
onBeforeUnmount(() => {
|
||||
if (timer) clearInterval(timer)
|
||||
})
|
||||
|
||||
const submit = async (decision: 'approve' | 'reject') => {
|
||||
if (props.resolved) return
|
||||
submitting.value = true
|
||||
pendingDecision.value = decision
|
||||
try {
|
||||
let modified: Record<string, unknown> | undefined
|
||||
if (decision === 'approve') {
|
||||
try {
|
||||
modified = JSON.parse(argsText.value || '{}') as Record<string, unknown>
|
||||
} catch {
|
||||
MessagePlugin.error(t('agentStream.toolApproval.invalidJson'))
|
||||
return
|
||||
}
|
||||
}
|
||||
await resolveToolApproval(props.pendingId, {
|
||||
decision,
|
||||
modified_args: decision === 'approve' ? modified : undefined,
|
||||
reason: decision === 'reject' ? t('agentStream.toolApproval.userRejected') : undefined,
|
||||
})
|
||||
MessagePlugin.success(t('agentStream.toolApproval.submitted'))
|
||||
} catch (e: any) {
|
||||
const msg = e?.response?.data?.error?.message || e?.message || t('agentStream.toolApproval.submitFailed')
|
||||
MessagePlugin.error(msg)
|
||||
} finally {
|
||||
submitting.value = false
|
||||
pendingDecision.value = null
|
||||
}
|
||||
}
|
||||
</script>
|
||||
|
||||
<style scoped lang="less">
|
||||
@warning-rgb: 237, 122, 11;
|
||||
@success-rgb: 7, 192, 95;
|
||||
@danger-rgb: 232, 80, 91;
|
||||
|
||||
.approval-card {
|
||||
--strip-color: var(--td-warning-color);
|
||||
--strip-rgb: @warning-rgb;
|
||||
background: var(--td-bg-color-container);
|
||||
border: 1px solid var(--td-component-stroke);
|
||||
border-radius: 6px;
|
||||
overflow: hidden;
|
||||
box-shadow: 0 1px 2px rgba(0, 0, 0, 0.02);
|
||||
transition: border-color 0.2s ease, box-shadow 0.2s ease, opacity 0.2s ease;
|
||||
display: flex;
|
||||
flex-direction: column;
|
||||
position: relative;
|
||||
|
||||
&::before {
|
||||
content: '';
|
||||
position: absolute;
|
||||
inset: 0 auto 0 0;
|
||||
width: 3px;
|
||||
background: var(--strip-color);
|
||||
transition: background-color 0.2s ease;
|
||||
}
|
||||
|
||||
&.is-pending {
|
||||
box-shadow: 0 1px 6px rgba(@warning-rgb, 0.08);
|
||||
}
|
||||
|
||||
&.is-approved {
|
||||
--strip-color: var(--td-success-color);
|
||||
--strip-rgb: @success-rgb;
|
||||
opacity: 0.94;
|
||||
}
|
||||
|
||||
&.is-rejected {
|
||||
--strip-color: var(--td-error-color);
|
||||
--strip-rgb: @danger-rgb;
|
||||
opacity: 0.94;
|
||||
}
|
||||
}
|
||||
|
||||
.approval-strip {
|
||||
display: flex;
|
||||
align-items: center;
|
||||
gap: 6px;
|
||||
padding: 7px 12px 7px 14px;
|
||||
font-size: 12px;
|
||||
font-weight: 500;
|
||||
color: var(--strip-color);
|
||||
background: rgba(var(--strip-rgb), 0.06);
|
||||
border-bottom: 1px solid var(--td-component-stroke);
|
||||
|
||||
.approval-strip-icon {
|
||||
display: inline-flex;
|
||||
align-items: center;
|
||||
.t-icon {
|
||||
font-size: 14px;
|
||||
}
|
||||
}
|
||||
.approval-strip-text {
|
||||
flex: 1;
|
||||
color: var(--strip-color);
|
||||
}
|
||||
.approval-strip-timer {
|
||||
display: inline-flex;
|
||||
align-items: center;
|
||||
gap: 4px;
|
||||
padding: 2px 8px;
|
||||
border-radius: 10px;
|
||||
background: rgba(0, 0, 0, 0.04);
|
||||
color: var(--td-text-color-secondary);
|
||||
font-variant-numeric: tabular-nums;
|
||||
font-weight: 500;
|
||||
|
||||
.t-icon {
|
||||
font-size: 12px;
|
||||
}
|
||||
&.timer-warning {
|
||||
color: var(--td-warning-color);
|
||||
background: rgba(@warning-rgb, 0.1);
|
||||
}
|
||||
&.timer-critical {
|
||||
color: var(--td-error-color);
|
||||
background: rgba(@danger-rgb, 0.12);
|
||||
animation: timerPulse 1.2s ease-in-out infinite;
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
.approval-identity {
|
||||
display: flex;
|
||||
align-items: center;
|
||||
gap: 6px;
|
||||
padding: 10px 12px 6px 14px;
|
||||
font-size: 13px;
|
||||
flex-wrap: wrap;
|
||||
|
||||
.ident-service {
|
||||
color: var(--td-text-color-secondary);
|
||||
font-weight: 500;
|
||||
}
|
||||
.ident-sep {
|
||||
color: var(--td-text-color-placeholder);
|
||||
font-size: 12px;
|
||||
}
|
||||
.ident-tool {
|
||||
color: var(--td-brand-color);
|
||||
font-weight: 600;
|
||||
font-family: var(--td-font-family-mono, ui-monospace, SFMono-Regular, Menlo, monospace);
|
||||
font-size: 13px;
|
||||
}
|
||||
}
|
||||
|
||||
.approval-desc {
|
||||
padding: 0 12px 4px 14px;
|
||||
font-size: 12px;
|
||||
line-height: 1.6;
|
||||
color: var(--td-text-color-secondary);
|
||||
}
|
||||
|
||||
.approval-args {
|
||||
padding: 8px 12px 0 14px;
|
||||
display: flex;
|
||||
flex-direction: column;
|
||||
gap: 4px;
|
||||
}
|
||||
|
||||
.approval-args-label {
|
||||
display: flex;
|
||||
align-items: center;
|
||||
gap: 8px;
|
||||
font-size: 12px;
|
||||
|
||||
.args-label-text {
|
||||
color: var(--td-text-color-placeholder);
|
||||
text-transform: uppercase;
|
||||
letter-spacing: 0.04em;
|
||||
font-size: 11px;
|
||||
font-weight: 500;
|
||||
}
|
||||
.args-status {
|
||||
display: inline-flex;
|
||||
align-items: center;
|
||||
gap: 3px;
|
||||
margin-left: auto;
|
||||
font-size: 11px;
|
||||
.t-icon {
|
||||
font-size: 12px;
|
||||
}
|
||||
}
|
||||
.args-invalid {
|
||||
color: var(--td-error-color);
|
||||
}
|
||||
.args-dirty {
|
||||
color: var(--td-warning-color);
|
||||
}
|
||||
}
|
||||
|
||||
.approval-args-input {
|
||||
:deep(.t-textarea__inner) {
|
||||
font-family: var(--td-font-family-mono, ui-monospace, SFMono-Regular, Menlo, monospace);
|
||||
font-size: 12px;
|
||||
line-height: 1.7;
|
||||
background: var(--td-bg-color-secondarycontainer);
|
||||
border-color: var(--td-component-stroke);
|
||||
color: var(--td-text-color-primary);
|
||||
padding: 8px 10px;
|
||||
transition: border-color 0.15s ease, box-shadow 0.15s ease;
|
||||
|
||||
&:hover {
|
||||
border-color: var(--td-brand-color-hover);
|
||||
}
|
||||
&:focus,
|
||||
&:focus-visible {
|
||||
border-color: var(--td-brand-color);
|
||||
box-shadow: 0 0 0 2px rgba(@success-rgb, 0.12);
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
.approval-args-readonly {
|
||||
margin: 0;
|
||||
padding: 8px 10px;
|
||||
background: var(--td-bg-color-secondarycontainer);
|
||||
border: 1px solid var(--td-component-stroke);
|
||||
border-radius: 4px;
|
||||
font-family: var(--td-font-family-mono, ui-monospace, SFMono-Regular, Menlo, monospace);
|
||||
font-size: 12px;
|
||||
line-height: 1.7;
|
||||
color: var(--td-text-color-primary);
|
||||
white-space: pre-wrap;
|
||||
word-break: break-word;
|
||||
max-height: 180px;
|
||||
overflow: auto;
|
||||
}
|
||||
|
||||
.approval-footer {
|
||||
display: flex;
|
||||
align-items: center;
|
||||
gap: 8px;
|
||||
padding: 10px 12px 12px 14px;
|
||||
}
|
||||
|
||||
.approval-spacer {
|
||||
flex: 1;
|
||||
}
|
||||
|
||||
.approval-resolved-footer {
|
||||
padding: 6px 12px 10px 14px;
|
||||
font-size: 12px;
|
||||
color: var(--td-text-color-secondary);
|
||||
min-height: 0;
|
||||
|
||||
.approval-resolved-reason {
|
||||
color: var(--td-text-color-secondary);
|
||||
}
|
||||
|
||||
&:empty {
|
||||
display: none;
|
||||
}
|
||||
}
|
||||
|
||||
@keyframes timerPulse {
|
||||
0%, 100% { opacity: 1; }
|
||||
50% { opacity: 0.55; }
|
||||
}
|
||||
</style>
|
||||
@@ -825,7 +825,7 @@ const handleAgentChunk = (data) => {
|
||||
|
||||
// 确保在继续流式传输时(刷新页面场景),一旦接收到实际内容就关闭 loading
|
||||
// 这是一个保护措施,防止任何边缘情况导致 loading 残留
|
||||
if (loading.value && (data.response_type === 'thinking' || data.response_type === 'answer' || data.response_type === 'tool_call')) {
|
||||
if (loading.value && (data.response_type === 'thinking' || data.response_type === 'answer' || data.response_type === 'tool_call' || data.response_type === 'tool_approval_required')) {
|
||||
console.log('[Agent Chunk] Closing loading for continued stream');
|
||||
loading.value = false;
|
||||
}
|
||||
@@ -891,6 +891,38 @@ const handleAgentChunk = (data) => {
|
||||
}
|
||||
break;
|
||||
|
||||
case 'tool_approval_required': {
|
||||
if (!message.agentEventStream) message.agentEventStream = [];
|
||||
const d = data.data || {};
|
||||
message.agentEventStream.push({
|
||||
type: 'tool_approval_required',
|
||||
pending_id: d.pending_id,
|
||||
service_name: d.service_name,
|
||||
mcp_tool_name: d.mcp_tool_name,
|
||||
description: d.description,
|
||||
args_json: d.args_json,
|
||||
timeout_seconds: d.timeout_seconds,
|
||||
requested_at: d.requested_at,
|
||||
tool_call_id: d.tool_call_id,
|
||||
resolved: false,
|
||||
});
|
||||
break;
|
||||
}
|
||||
case 'tool_approval_resolved': {
|
||||
const d = data.data || {};
|
||||
const pid = d.pending_id;
|
||||
const ev = message.agentEventStream?.find(
|
||||
(e) => e.type === 'tool_approval_required' && e.pending_id === pid
|
||||
);
|
||||
if (ev) {
|
||||
ev.resolved = true;
|
||||
ev.approved = d.approved;
|
||||
ev.resolve_reason = d.reason;
|
||||
ev.timed_out = d.timed_out;
|
||||
ev.canceled = d.canceled;
|
||||
}
|
||||
break;
|
||||
}
|
||||
case 'tool_call':
|
||||
// Skip final_answer tool call from event stream - its content appears as answer events
|
||||
if (data.data && data.data.tool_name === 'final_answer') {
|
||||
|
||||
@@ -97,6 +97,7 @@
|
||||
v-model:visible="testDialogVisible"
|
||||
:result="testResult"
|
||||
:service-name="testingServiceName"
|
||||
:service-id="testingServiceId"
|
||||
/>
|
||||
</div>
|
||||
</template>
|
||||
@@ -129,6 +130,7 @@ const currentService = ref<MCPService | null>(null)
|
||||
const testDialogVisible = ref(false)
|
||||
const testResult = ref<MCPTestResult | null>(null)
|
||||
const testingServiceName = ref('')
|
||||
const testingServiceId = ref('')
|
||||
const testing = ref(false)
|
||||
|
||||
// Load MCP services
|
||||
@@ -184,6 +186,7 @@ const handleTest = async (service: MCPService) => {
|
||||
if (!service || !service.id) return
|
||||
|
||||
testingServiceName.value = service.name
|
||||
testingServiceId.value = service.id
|
||||
testing.value = true
|
||||
|
||||
MessagePlugin.info({
|
||||
|
||||
@@ -29,7 +29,7 @@
|
||||
</div>
|
||||
<div class="tools-grid">
|
||||
<div
|
||||
v-for="(tool, index) in result.tools"
|
||||
v-for="(tool, index) in displayTools"
|
||||
:key="index"
|
||||
class="tool-card"
|
||||
:class="{ 'tool-card-expanded': expandedToolIndex === index }"
|
||||
@@ -44,10 +44,24 @@
|
||||
</div>
|
||||
</div>
|
||||
</div>
|
||||
<t-icon
|
||||
:name="expandedToolIndex === index ? 'chevron-up' : 'chevron-down'"
|
||||
class="expand-icon"
|
||||
/>
|
||||
<div class="tool-header-right" @click.stop>
|
||||
<t-tooltip v-if="serviceId" :content="$t('mcp.testResult.requireApprovalTip')" placement="top">
|
||||
<div class="approval-switch">
|
||||
<t-icon name="error-circle-filled" class="danger-icon" />
|
||||
<span class="approval-label">{{ $t('mcp.testResult.requireApproval') }}</span>
|
||||
<t-switch
|
||||
:value="tool.require_approval"
|
||||
:loading="approvalLoading[tool.name]"
|
||||
size="small"
|
||||
@change="(v: boolean) => onRequireApprovalChange(tool.name, v)"
|
||||
/>
|
||||
</div>
|
||||
</t-tooltip>
|
||||
<t-icon
|
||||
:name="expandedToolIndex === index ? 'chevron-up' : 'chevron-down'"
|
||||
class="expand-icon"
|
||||
/>
|
||||
</div>
|
||||
</div>
|
||||
<div v-if="expandedToolIndex === index" class="tool-card-content">
|
||||
<div v-if="tool.description" class="tool-description">
|
||||
@@ -119,14 +133,18 @@
|
||||
</template>
|
||||
|
||||
<script setup lang="ts">
|
||||
import { computed, ref } from 'vue'
|
||||
import type { MCPTestResult } from '@/api/mcp-service'
|
||||
import { computed, ref, watch } from 'vue'
|
||||
import type { MCPTestResult, MCPTool } from '@/api/mcp-service'
|
||||
import { getMCPToolApprovals, setMCPToolApproval } from '@/api/mcp-service'
|
||||
import { MessagePlugin } from 'tdesign-vue-next'
|
||||
import { useI18n } from 'vue-i18n'
|
||||
|
||||
interface Props {
|
||||
visible: boolean
|
||||
result: MCPTestResult | null
|
||||
serviceName: string
|
||||
/** When set, loads/saves per-tool approval flags */
|
||||
serviceId?: string
|
||||
}
|
||||
|
||||
interface Emits {
|
||||
@@ -138,6 +156,56 @@ const emit = defineEmits<Emits>()
|
||||
|
||||
const expandedToolIndex = ref<number | null>(null)
|
||||
const { t } = useI18n()
|
||||
const displayTools = ref<MCPTool[]>([])
|
||||
const approvalLoading = ref<Record<string, boolean>>({})
|
||||
|
||||
const mergeApprovals = async () => {
|
||||
const tools = props.result?.tools
|
||||
if (!tools?.length) {
|
||||
displayTools.value = []
|
||||
return
|
||||
}
|
||||
if (!props.serviceId) {
|
||||
displayTools.value = tools.map((x) => ({ ...x }))
|
||||
return
|
||||
}
|
||||
try {
|
||||
const rows = await getMCPToolApprovals(props.serviceId)
|
||||
const map = new Map(rows.map((r) => [r.tool_name, r.require_approval]))
|
||||
displayTools.value = tools.map((tool) => ({
|
||||
...tool,
|
||||
require_approval: map.get(tool.name) || false,
|
||||
}))
|
||||
} catch {
|
||||
displayTools.value = tools.map((x) => ({ ...x }))
|
||||
}
|
||||
}
|
||||
|
||||
watch(
|
||||
() => [props.visible, props.serviceId, props.result?.tools],
|
||||
() => {
|
||||
if (props.visible) {
|
||||
void mergeApprovals()
|
||||
}
|
||||
},
|
||||
{ deep: true }
|
||||
)
|
||||
|
||||
const onRequireApprovalChange = async (toolName: string, value: boolean) => {
|
||||
if (!props.serviceId) return
|
||||
approvalLoading.value = { ...approvalLoading.value, [toolName]: true }
|
||||
try {
|
||||
await setMCPToolApproval(props.serviceId, toolName, value)
|
||||
displayTools.value = displayTools.value.map((x) =>
|
||||
x.name === toolName ? { ...x, require_approval: value } : x
|
||||
)
|
||||
} catch (e) {
|
||||
console.error(e)
|
||||
MessagePlugin.error(t('mcp.testResult.approvalSaveFailed'))
|
||||
} finally {
|
||||
approvalLoading.value = { ...approvalLoading.value, [toolName]: false }
|
||||
}
|
||||
}
|
||||
|
||||
const dialogVisible = computed({
|
||||
get: () => props.visible,
|
||||
@@ -310,6 +378,29 @@ const handleClose = () => {
|
||||
}
|
||||
}
|
||||
|
||||
.tool-header-right {
|
||||
display: flex;
|
||||
align-items: center;
|
||||
gap: 10px;
|
||||
flex-shrink: 0;
|
||||
}
|
||||
|
||||
.approval-switch {
|
||||
display: flex;
|
||||
align-items: center;
|
||||
gap: 6px;
|
||||
font-size: 12px;
|
||||
color: var(--td-text-color-secondary);
|
||||
.danger-icon {
|
||||
color: var(--td-warning-color);
|
||||
font-size: 16px;
|
||||
}
|
||||
.approval-label {
|
||||
max-width: 88px;
|
||||
line-height: 1.2;
|
||||
}
|
||||
}
|
||||
|
||||
.expand-icon {
|
||||
color: var(--td-text-color-placeholder);
|
||||
font-size: 16px;
|
||||
|
||||
@@ -165,7 +165,7 @@ func formatToolHint(name string, args map[string]any) string {
|
||||
// When ParallelToolCalls is enabled and there are 2+ tool calls, they execute concurrently.
|
||||
func (e *AgentEngine) executeToolCalls(
|
||||
ctx context.Context, response *types.ChatResponse,
|
||||
step *types.AgentStep, iteration int, sessionID string,
|
||||
step *types.AgentStep, iteration int, sessionID, assistantMessageID string,
|
||||
) {
|
||||
if len(response.ToolCalls) == 0 {
|
||||
return
|
||||
@@ -177,12 +177,12 @@ func (e *AgentEngine) executeToolCalls(
|
||||
|
||||
// Use parallel execution when enabled and there are multiple tool calls
|
||||
if e.config.ParallelToolCalls && n >= 2 {
|
||||
e.executeToolCallsParallel(ctx, response, step, iteration, sessionID)
|
||||
e.executeToolCallsParallel(ctx, response, step, iteration, sessionID, assistantMessageID)
|
||||
return
|
||||
}
|
||||
|
||||
for i, tc := range response.ToolCalls {
|
||||
e.executeSingleToolCall(ctx, tc, i, step, iteration, round, sessionID)
|
||||
e.executeSingleToolCall(ctx, tc, i, step, iteration, round, sessionID, assistantMessageID)
|
||||
}
|
||||
}
|
||||
|
||||
@@ -190,7 +190,7 @@ func (e *AgentEngine) executeToolCalls(
|
||||
// collecting results in original order.
|
||||
func (e *AgentEngine) executeToolCallsParallel(
|
||||
ctx context.Context, response *types.ChatResponse,
|
||||
step *types.AgentStep, iteration int, sessionID string,
|
||||
step *types.AgentStep, iteration int, sessionID, assistantMessageID string,
|
||||
) {
|
||||
round := iteration + 1
|
||||
n := len(response.ToolCalls)
|
||||
@@ -203,7 +203,7 @@ func (e *AgentEngine) executeToolCallsParallel(
|
||||
for i, tc := range response.ToolCalls {
|
||||
i, tc := i, tc // capture loop vars
|
||||
g.Go(func() error {
|
||||
toolCall := e.runToolCall(gCtx, tc, i, iteration, round, sessionID)
|
||||
toolCall := e.runToolCall(gCtx, tc, i, iteration, round, sessionID, assistantMessageID)
|
||||
mu.Lock()
|
||||
results[i] = toolCall
|
||||
mu.Unlock()
|
||||
@@ -258,9 +258,9 @@ func (e *AgentEngine) executeToolCallsParallel(
|
||||
// executeSingleToolCall runs one tool call sequentially (original behavior).
|
||||
func (e *AgentEngine) executeSingleToolCall(
|
||||
ctx context.Context, tc types.LLMToolCall, i int,
|
||||
step *types.AgentStep, iteration, round int, sessionID string,
|
||||
step *types.AgentStep, iteration, round int, sessionID, assistantMessageID string,
|
||||
) {
|
||||
toolCall := e.runToolCall(ctx, tc, i, iteration, round, sessionID)
|
||||
toolCall := e.runToolCall(ctx, tc, i, iteration, round, sessionID, assistantMessageID)
|
||||
step.ToolCalls = append(step.ToolCalls, toolCall)
|
||||
|
||||
result := toolCall.Result
|
||||
@@ -304,7 +304,7 @@ func (e *AgentEngine) executeSingleToolCall(
|
||||
// It returns the completed ToolCall struct. Safe to call from multiple goroutines.
|
||||
func (e *AgentEngine) runToolCall(
|
||||
ctx context.Context, tc types.LLMToolCall, i int,
|
||||
iteration, round int, sessionID string,
|
||||
iteration, round int, sessionID, assistantMessageID string,
|
||||
) types.ToolCall {
|
||||
tc.ID = agenttools.NormalizeToolCallID(tc.ID, tc.Function.Name, i)
|
||||
total := "?" // unknown in isolation; callers log the batch size
|
||||
@@ -392,7 +392,17 @@ func (e *AgentEngine) runToolCall(
|
||||
},
|
||||
})
|
||||
|
||||
execCtx, toolCancel := context.WithTimeout(toolCtx, defaultToolExecTimeout)
|
||||
toolExecCtx := agenttools.WithToolExecContext(toolCtx, &agenttools.ToolExecContext{
|
||||
SessionID: sessionID,
|
||||
AssistantMessageID: assistantMessageID,
|
||||
EventBus: e.eventBus,
|
||||
ToolCallID: tc.ID,
|
||||
// ApprovalCtx keeps the round-level ctx without the per-tool 60s timeout,
|
||||
// so MCP tool human-approval (issue #1173) can legitimately block longer.
|
||||
ApprovalCtx: toolCtx,
|
||||
})
|
||||
|
||||
execCtx, toolCancel := context.WithTimeout(toolExecCtx, defaultToolExecTimeout)
|
||||
result, err := e.toolRegistry.ExecuteTool(
|
||||
execCtx, tc.Function.Name,
|
||||
json.RawMessage(tc.Function.Arguments),
|
||||
|
||||
361
internal/agent/approval/gate.go
Normal file
361
internal/agent/approval/gate.go
Normal file
@@ -0,0 +1,361 @@
|
||||
// Package approval implements human-in-the-loop gating for dangerous MCP tool calls (issue #1173).
|
||||
package approval
|
||||
|
||||
import (
|
||||
"context"
|
||||
"encoding/json"
|
||||
"errors"
|
||||
"fmt"
|
||||
"sync"
|
||||
"time"
|
||||
|
||||
"github.com/Tencent/WeKnora/internal/config"
|
||||
"github.com/Tencent/WeKnora/internal/event"
|
||||
"github.com/Tencent/WeKnora/internal/logger"
|
||||
"github.com/google/uuid"
|
||||
"github.com/redis/go-redis/v9"
|
||||
)
|
||||
|
||||
// pubsubChannel is the Redis channel used to fan-out Resolve calls across
|
||||
// backend replicas (issue #1173 cross-instance support).
|
||||
const pubsubChannel = "weknora:mcp_approval:resolve"
|
||||
|
||||
// resolveMessage is the JSON payload published when one instance receives a
|
||||
// Resolve API call but the pending wait may live on another instance.
|
||||
type resolveMessage struct {
|
||||
TenantID uint64 `json:"tenant_id"`
|
||||
PendingID string `json:"pending_id"`
|
||||
Approved bool `json:"approved"`
|
||||
ModifiedArgs json.RawMessage `json:"modified_args,omitempty"`
|
||||
Reason string `json:"reason,omitempty"`
|
||||
TimedOut bool `json:"timed_out,omitempty"`
|
||||
Canceled bool `json:"canceled,omitempty"`
|
||||
}
|
||||
|
||||
// Checker answers whether a concrete MCP tool requires human approval before execution.
|
||||
type Checker interface {
|
||||
IsRequired(ctx context.Context, tenantID uint64, serviceID, toolName string) (bool, error)
|
||||
}
|
||||
|
||||
// Decision is the outcome of a pending tool approval.
|
||||
type Decision struct {
|
||||
Approved bool
|
||||
ModifiedArgs json.RawMessage // optional JSON object; when set and Approved, replaces original args
|
||||
Reason string
|
||||
TimedOut bool
|
||||
ContextCanceled bool
|
||||
}
|
||||
|
||||
// PendingRequest carries everything needed to block and notify the UI.
|
||||
type PendingRequest struct {
|
||||
TenantID uint64
|
||||
SessionID string
|
||||
AssistantMessageID string
|
||||
RequestID string
|
||||
EventBus *event.EventBus
|
||||
ServiceID string
|
||||
ServiceName string
|
||||
MCPToolName string // name on MCP server
|
||||
RegisteredToolName string // registry name e.g. mcp_svc_tool
|
||||
Description string
|
||||
Args json.RawMessage
|
||||
ToolCallID string
|
||||
}
|
||||
|
||||
// MCPApproval is the surface used by MCPTool (mockable in tests).
|
||||
type MCPApproval interface {
|
||||
NeedsApproval(ctx context.Context, tenantID uint64, serviceID, toolName string) bool
|
||||
RequestAndWait(ctx context.Context, req PendingRequest) (Decision, error)
|
||||
}
|
||||
|
||||
var _ MCPApproval = (*Gate)(nil)
|
||||
|
||||
// Gate coordinates wait/resolve for MCP tool approvals.
|
||||
//
|
||||
// Pending waiters live in-memory on the instance that started RequestAndWait.
|
||||
// When a redis client is supplied, Resolve calls hitting any replica are
|
||||
// published over Redis Pub/Sub so the owning instance can deliver the decision
|
||||
// (issue #1173 cross-instance support). Without redis, the gate degrades to
|
||||
// single-process behavior (deployments must use sticky sessions).
|
||||
type Gate struct {
|
||||
mu sync.Mutex
|
||||
pending map[string]*waiter
|
||||
checker Checker
|
||||
timeout time.Duration
|
||||
rdb *redis.Client // optional; nil disables cross-instance fan-out
|
||||
}
|
||||
|
||||
type waiter struct {
|
||||
ch chan Decision
|
||||
tenantID uint64
|
||||
once sync.Once
|
||||
}
|
||||
|
||||
func (w *waiter) deliver(d Decision) {
|
||||
if w == nil {
|
||||
return
|
||||
}
|
||||
w.once.Do(func() {
|
||||
select {
|
||||
case w.ch <- d:
|
||||
default:
|
||||
}
|
||||
})
|
||||
}
|
||||
|
||||
var (
|
||||
// ErrPendingNotFound is returned when Resolve is called with an unknown id.
|
||||
ErrPendingNotFound = errors.New("tool approval pending not found")
|
||||
// ErrTenantMismatch is returned when Resolve tenant does not match the pending request.
|
||||
ErrTenantMismatch = errors.New("tenant mismatch for tool approval")
|
||||
)
|
||||
|
||||
// NewGate builds a gate. checker may be nil (disables gating). cfg may be nil
|
||||
// (defaults apply). rdb may be nil (single-instance mode).
|
||||
func NewGate(cfg *config.Config, checker Checker, rdb *redis.Client) *Gate {
|
||||
timeout := 10 * time.Minute
|
||||
if cfg != nil && cfg.Agent != nil && cfg.Agent.ToolApprovalTimeoutSeconds > 0 {
|
||||
timeout = time.Duration(cfg.Agent.ToolApprovalTimeoutSeconds) * time.Second
|
||||
}
|
||||
g := &Gate{
|
||||
pending: make(map[string]*waiter),
|
||||
checker: checker,
|
||||
timeout: timeout,
|
||||
rdb: rdb,
|
||||
}
|
||||
if rdb != nil {
|
||||
go g.runSubscriber()
|
||||
}
|
||||
return g
|
||||
}
|
||||
|
||||
// runSubscriber listens for cross-instance Resolve fan-outs and delivers
|
||||
// decisions to local waiters. Runs for the lifetime of the process.
|
||||
func (g *Gate) runSubscriber() {
|
||||
ctx := context.Background()
|
||||
for {
|
||||
sub := g.rdb.Subscribe(ctx, pubsubChannel)
|
||||
ch := sub.Channel()
|
||||
for msg := range ch {
|
||||
var m resolveMessage
|
||||
if err := json.Unmarshal([]byte(msg.Payload), &m); err != nil {
|
||||
logger.GetLogger(ctx).Warnf("mcp approval pubsub: bad payload: %v", err)
|
||||
continue
|
||||
}
|
||||
if err := g.deliverLocal(m.TenantID, m.PendingID, Decision{
|
||||
Approved: m.Approved,
|
||||
ModifiedArgs: m.ModifiedArgs,
|
||||
Reason: m.Reason,
|
||||
TimedOut: m.TimedOut,
|
||||
ContextCanceled: m.Canceled,
|
||||
}); err != nil && !errors.Is(err, ErrPendingNotFound) {
|
||||
logger.GetLogger(ctx).Warnf("mcp approval pubsub deliver: %v", err)
|
||||
}
|
||||
}
|
||||
_ = sub.Close()
|
||||
// Reconnect after brief backoff if Redis hiccups.
|
||||
time.Sleep(2 * time.Second)
|
||||
}
|
||||
}
|
||||
|
||||
// NeedsApproval returns whether execution should pause for human confirmation.
|
||||
func (g *Gate) NeedsApproval(ctx context.Context, tenantID uint64, serviceID, toolName string) bool {
|
||||
if g == nil || g.checker == nil || tenantID == 0 || serviceID == "" || toolName == "" {
|
||||
return false
|
||||
}
|
||||
ok, err := g.checker.IsRequired(ctx, tenantID, serviceID, toolName)
|
||||
if err != nil {
|
||||
logger.GetLogger(ctx).Warnf("mcp tool approval check failed (skip gate): %v", err)
|
||||
return false
|
||||
}
|
||||
return ok
|
||||
}
|
||||
|
||||
// RequestAndWait emits a UI event, then blocks until Resolve, timeout, or ctx cancellation.
|
||||
func (g *Gate) RequestAndWait(ctx context.Context, req PendingRequest) (Decision, error) {
|
||||
if g == nil {
|
||||
return Decision{Approved: true}, nil
|
||||
}
|
||||
if g.checker == nil {
|
||||
return Decision{Approved: true}, nil
|
||||
}
|
||||
if req.EventBus == nil {
|
||||
return Decision{}, fmt.Errorf("tool approval: EventBus is nil")
|
||||
}
|
||||
|
||||
pendingID := uuid.New().String()
|
||||
w := &waiter{
|
||||
ch: make(chan Decision, 1),
|
||||
tenantID: req.TenantID,
|
||||
}
|
||||
|
||||
g.mu.Lock()
|
||||
g.pending[pendingID] = w
|
||||
g.mu.Unlock()
|
||||
|
||||
defer func() {
|
||||
g.mu.Lock()
|
||||
delete(g.pending, pendingID)
|
||||
g.mu.Unlock()
|
||||
}()
|
||||
|
||||
var argsObj interface{}
|
||||
if len(req.Args) > 0 {
|
||||
_ = json.Unmarshal(req.Args, &argsObj)
|
||||
}
|
||||
|
||||
timeoutSec := int(g.timeout / time.Second)
|
||||
if timeoutSec < 1 {
|
||||
timeoutSec = 1
|
||||
}
|
||||
|
||||
evtData := event.ToolApprovalRequiredData{
|
||||
PendingID: pendingID,
|
||||
TenantID: req.TenantID,
|
||||
SessionID: req.SessionID,
|
||||
AssistantMessageID: req.AssistantMessageID,
|
||||
ServiceID: req.ServiceID,
|
||||
ServiceName: req.ServiceName,
|
||||
MCPToolName: req.MCPToolName,
|
||||
RegisteredToolName: req.RegisteredToolName,
|
||||
Description: req.Description,
|
||||
Args: argsObj,
|
||||
ArgsJSON: string(req.Args),
|
||||
TimeoutSeconds: timeoutSec,
|
||||
RequestedAtUnix: time.Now().Unix(),
|
||||
ToolCallID: req.ToolCallID,
|
||||
RequestID: req.RequestID,
|
||||
}
|
||||
|
||||
if err := req.EventBus.Emit(ctx, event.Event{
|
||||
ID: pendingID + "-approval-required",
|
||||
Type: event.EventToolApprovalRequired,
|
||||
SessionID: req.SessionID,
|
||||
Data: evtData,
|
||||
Metadata: map[string]interface{}{
|
||||
"assistant_message_id": req.AssistantMessageID,
|
||||
"pending_id": pendingID,
|
||||
},
|
||||
RequestID: req.RequestID,
|
||||
}); err != nil {
|
||||
return Decision{}, fmt.Errorf("emit tool approval required: %w", err)
|
||||
}
|
||||
|
||||
timer := time.NewTimer(g.timeout)
|
||||
defer timer.Stop()
|
||||
|
||||
emitResolved := func(d Decision) {
|
||||
if req.EventBus == nil {
|
||||
return
|
||||
}
|
||||
_ = req.EventBus.Emit(context.WithoutCancel(ctx), event.Event{
|
||||
ID: pendingID + "-approval-resolved",
|
||||
Type: event.EventToolApprovalResolved,
|
||||
SessionID: req.SessionID,
|
||||
Data: event.ToolApprovalResolvedData{
|
||||
PendingID: pendingID,
|
||||
Approved: d.Approved,
|
||||
Reason: d.Reason,
|
||||
TimedOut: d.TimedOut,
|
||||
Canceled: d.ContextCanceled,
|
||||
},
|
||||
Metadata: map[string]interface{}{
|
||||
"assistant_message_id": req.AssistantMessageID,
|
||||
},
|
||||
RequestID: req.RequestID,
|
||||
})
|
||||
}
|
||||
|
||||
var d Decision
|
||||
select {
|
||||
case d = <-w.ch:
|
||||
emitResolved(d)
|
||||
return d, nil
|
||||
case <-timer.C:
|
||||
d = Decision{Approved: false, Reason: "approval timeout", TimedOut: true}
|
||||
w.deliver(d)
|
||||
d = <-w.ch
|
||||
emitResolved(d)
|
||||
return d, nil
|
||||
case <-ctx.Done():
|
||||
d = Decision{Approved: false, Reason: "request canceled", ContextCanceled: true}
|
||||
w.deliver(d)
|
||||
d = <-w.ch
|
||||
emitResolved(d)
|
||||
return d, nil
|
||||
}
|
||||
}
|
||||
|
||||
// Resolve completes a pending approval. tenantID must match the tenant that
|
||||
// started the wait. If the waiter is not on this instance and Redis Pub/Sub
|
||||
// is configured, the decision is fanned out to all replicas (best-effort).
|
||||
func (g *Gate) Resolve(tenantID uint64, pendingID string, d Decision) error {
|
||||
if g == nil {
|
||||
return fmt.Errorf("gate is nil")
|
||||
}
|
||||
switch err := g.deliverLocal(tenantID, pendingID, d); {
|
||||
case err == nil:
|
||||
return nil
|
||||
case errors.Is(err, ErrTenantMismatch):
|
||||
return err
|
||||
case errors.Is(err, ErrPendingNotFound):
|
||||
if g.rdb == nil {
|
||||
return err
|
||||
}
|
||||
// Fan out to other replicas.
|
||||
payload, mErr := json.Marshal(resolveMessage{
|
||||
TenantID: tenantID,
|
||||
PendingID: pendingID,
|
||||
Approved: d.Approved,
|
||||
ModifiedArgs: d.ModifiedArgs,
|
||||
Reason: d.Reason,
|
||||
TimedOut: d.TimedOut,
|
||||
Canceled: d.ContextCanceled,
|
||||
})
|
||||
if mErr != nil {
|
||||
return fmt.Errorf("encode pubsub payload: %w", mErr)
|
||||
}
|
||||
ctx, cancel := context.WithTimeout(context.Background(), 2*time.Second)
|
||||
defer cancel()
|
||||
if pErr := g.rdb.Publish(ctx, pubsubChannel, payload).Err(); pErr != nil {
|
||||
return fmt.Errorf("publish approval resolve: %w", pErr)
|
||||
}
|
||||
// Best-effort: the owning instance will deliver and emit Resolved.
|
||||
return nil
|
||||
default:
|
||||
return err
|
||||
}
|
||||
}
|
||||
|
||||
// deliverLocal attempts to satisfy a waiter on this instance only.
|
||||
func (g *Gate) deliverLocal(tenantID uint64, pendingID string, d Decision) error {
|
||||
g.mu.Lock()
|
||||
w, ok := g.pending[pendingID]
|
||||
if !ok {
|
||||
g.mu.Unlock()
|
||||
return ErrPendingNotFound
|
||||
}
|
||||
if w.tenantID != tenantID {
|
||||
g.mu.Unlock()
|
||||
return ErrTenantMismatch
|
||||
}
|
||||
g.mu.Unlock()
|
||||
|
||||
w.deliver(d)
|
||||
return nil
|
||||
}
|
||||
|
||||
// Adapter makes MCPToolApprovalService satisfy Checker without importing the service package here.
|
||||
type Adapter struct {
|
||||
Svc interface {
|
||||
IsRequired(ctx context.Context, tenantID uint64, serviceID, toolName string) (bool, error)
|
||||
}
|
||||
}
|
||||
|
||||
// IsRequired implements Checker.
|
||||
func (a *Adapter) IsRequired(ctx context.Context, tenantID uint64, serviceID, toolName string) (bool, error) {
|
||||
if a == nil || a.Svc == nil {
|
||||
return false, nil
|
||||
}
|
||||
return a.Svc.IsRequired(ctx, tenantID, serviceID, toolName)
|
||||
}
|
||||
80
internal/agent/approval/gate_test.go
Normal file
80
internal/agent/approval/gate_test.go
Normal file
@@ -0,0 +1,80 @@
|
||||
package approval
|
||||
|
||||
import (
|
||||
"context"
|
||||
"encoding/json"
|
||||
"testing"
|
||||
|
||||
"github.com/Tencent/WeKnora/internal/config"
|
||||
"github.com/Tencent/WeKnora/internal/event"
|
||||
"github.com/stretchr/testify/require"
|
||||
)
|
||||
|
||||
type stubChecker struct {
|
||||
required bool
|
||||
err error
|
||||
}
|
||||
|
||||
func (s *stubChecker) IsRequired(ctx context.Context, tenantID uint64, serviceID, toolName string) (bool, error) {
|
||||
return s.required, s.err
|
||||
}
|
||||
|
||||
func TestGate_RequestAndWait_Approve(t *testing.T) {
|
||||
bus := event.NewEventBus()
|
||||
g := NewGate(&config.Config{Agent: &config.AgentConfig{ToolApprovalTimeoutSeconds: 2}}, &stubChecker{required: true}, nil)
|
||||
|
||||
ctx := context.Background()
|
||||
req := PendingRequest{
|
||||
TenantID: 1,
|
||||
SessionID: "s1",
|
||||
AssistantMessageID: "m1",
|
||||
EventBus: bus,
|
||||
ServiceID: "svc",
|
||||
ServiceName: "svcname",
|
||||
MCPToolName: "danger_tool",
|
||||
RegisteredToolName: "mcp_svcname_danger_tool",
|
||||
Description: "desc",
|
||||
Args: json.RawMessage(`{"a":1}`),
|
||||
ToolCallID: "tc1",
|
||||
}
|
||||
|
||||
bus.On(event.EventToolApprovalRequired, func(_ context.Context, evt event.Event) error {
|
||||
data, ok := evt.Data.(event.ToolApprovalRequiredData)
|
||||
require.True(t, ok)
|
||||
require.NotEmpty(t, data.PendingID)
|
||||
go func() {
|
||||
_ = g.Resolve(1, data.PendingID, Decision{Approved: true, ModifiedArgs: json.RawMessage(`{"a":2}`)})
|
||||
}()
|
||||
return nil
|
||||
})
|
||||
|
||||
d, err := g.RequestAndWait(ctx, req)
|
||||
require.NoError(t, err)
|
||||
require.True(t, d.Approved)
|
||||
require.JSONEq(t, `{"a":2}`, string(d.ModifiedArgs))
|
||||
}
|
||||
|
||||
func TestGate_RequestAndWait_Timeout(t *testing.T) {
|
||||
g := NewGate(&config.Config{Agent: &config.AgentConfig{ToolApprovalTimeoutSeconds: 1}}, &stubChecker{required: true}, nil)
|
||||
ctx := context.Background()
|
||||
req := PendingRequest{
|
||||
TenantID: 1,
|
||||
SessionID: "s1",
|
||||
AssistantMessageID: "m1",
|
||||
EventBus: event.NewEventBus(),
|
||||
ServiceID: "svc",
|
||||
ServiceName: "svcname",
|
||||
MCPToolName: "t",
|
||||
RegisteredToolName: "mcp_svcname_t",
|
||||
Args: json.RawMessage(`{}`),
|
||||
}
|
||||
d, err := g.RequestAndWait(ctx, req)
|
||||
require.NoError(t, err)
|
||||
require.False(t, d.Approved)
|
||||
require.True(t, d.TimedOut)
|
||||
}
|
||||
|
||||
func TestGate_NeedsApproval_NoChecker(t *testing.T) {
|
||||
g := NewGate(nil, nil, nil)
|
||||
require.False(t, g.NeedsApproval(context.Background(), 1, "x", "y"))
|
||||
}
|
||||
@@ -397,7 +397,7 @@ loop:
|
||||
// every exit path (break/continue/next) without having to sprinkle
|
||||
// manual finish calls throughout the many branches below.
|
||||
outcome, iterErr := e.runReActIteration(ctx, state, &messages, tools,
|
||||
sessionID, query, &emptyRetries, &consecutiveSameContent, &lastResponseContent)
|
||||
sessionID, messageID, query, &emptyRetries, &consecutiveSameContent, &lastResponseContent)
|
||||
if iterErr != nil {
|
||||
return state, iterErr
|
||||
}
|
||||
@@ -451,7 +451,7 @@ func (e *AgentEngine) runReActIteration(
|
||||
state *types.AgentState,
|
||||
messagesPtr *[]chat.Message,
|
||||
tools []chat.Tool,
|
||||
sessionID, query string,
|
||||
sessionID, assistantMessageID, query string,
|
||||
emptyRetries, consecutiveSameContent *int,
|
||||
lastResponseContent *string,
|
||||
) (outcome iterOutcome, retErr error) {
|
||||
@@ -625,7 +625,7 @@ func (e *AgentEngine) runReActIteration(
|
||||
}
|
||||
|
||||
// 3. Act: Execute tool calls
|
||||
e.executeToolCalls(ctx, response, &step, state.CurrentRound, sessionID)
|
||||
e.executeToolCalls(ctx, response, &step, state.CurrentRound, sessionID, assistantMessageID)
|
||||
toolCallCount = len(step.ToolCalls)
|
||||
|
||||
// 4. Observe: Add tool results to messages and write to context
|
||||
|
||||
40
internal/agent/tools/exec_context.go
Normal file
40
internal/agent/tools/exec_context.go
Normal file
@@ -0,0 +1,40 @@
|
||||
package tools
|
||||
|
||||
import (
|
||||
"context"
|
||||
|
||||
"github.com/Tencent/WeKnora/internal/event"
|
||||
)
|
||||
|
||||
type execCtxKey struct{}
|
||||
|
||||
// ToolExecContext is attached to context during agent tool execution (per tool call).
|
||||
type ToolExecContext struct {
|
||||
SessionID string
|
||||
AssistantMessageID string
|
||||
RequestID string
|
||||
ToolCallID string
|
||||
EventBus *event.EventBus
|
||||
// ApprovalCtx is the parent ctx WITHOUT defaultToolExecTimeout; used when the tool
|
||||
// must wait for human approval that may exceed normal tool exec timeout (issue #1173).
|
||||
// Falls back to the per-tool execCtx when nil.
|
||||
ApprovalCtx context.Context
|
||||
}
|
||||
|
||||
// WithToolExecContext returns ctx that carries ToolExecContext for MCP approval and similar features.
|
||||
func WithToolExecContext(ctx context.Context, meta *ToolExecContext) context.Context {
|
||||
if meta == nil {
|
||||
return ctx
|
||||
}
|
||||
return context.WithValue(ctx, execCtxKey{}, meta)
|
||||
}
|
||||
|
||||
// ToolExecFromContext returns metadata attached by the agent engine, if any.
|
||||
func ToolExecFromContext(ctx context.Context) (*ToolExecContext, bool) {
|
||||
v := ctx.Value(execCtxKey{})
|
||||
if v == nil {
|
||||
return nil, false
|
||||
}
|
||||
meta, ok := v.(*ToolExecContext)
|
||||
return meta, ok && meta != nil
|
||||
}
|
||||
@@ -7,6 +7,7 @@ import (
|
||||
"strings"
|
||||
"time"
|
||||
|
||||
"github.com/Tencent/WeKnora/internal/agent/approval"
|
||||
"github.com/Tencent/WeKnora/internal/logger"
|
||||
"github.com/Tencent/WeKnora/internal/mcp"
|
||||
"github.com/Tencent/WeKnora/internal/types"
|
||||
@@ -19,14 +20,16 @@ type MCPTool struct {
|
||||
service *types.MCPService
|
||||
mcpTool *types.MCPTool
|
||||
mcpManager *mcp.MCPManager
|
||||
gate approval.MCPApproval // optional human approval before CallTool (issue #1173)
|
||||
}
|
||||
|
||||
// NewMCPTool creates a new MCP tool wrapper
|
||||
func NewMCPTool(service *types.MCPService, mcpTool *types.MCPTool, mcpManager *mcp.MCPManager) *MCPTool {
|
||||
func NewMCPTool(service *types.MCPService, mcpTool *types.MCPTool, mcpManager *mcp.MCPManager, gate approval.MCPApproval) *MCPTool {
|
||||
return &MCPTool{
|
||||
service: service,
|
||||
mcpTool: mcpTool,
|
||||
mcpManager: mcpManager,
|
||||
gate: gate,
|
||||
}
|
||||
}
|
||||
|
||||
@@ -93,13 +96,77 @@ func (t *MCPTool) Execute(ctx context.Context, args json.RawMessage) (*types.Too
|
||||
// Parse args from json.RawMessage
|
||||
var input MCPInput
|
||||
if err := json.Unmarshal(args, &input); err != nil {
|
||||
logger.Errorf(ctx, "[Tool][DatabaseQuery] Failed to parse args: %v", err)
|
||||
logger.Errorf(ctx, "[Tool][MCPTool] Failed to parse args: %v", err)
|
||||
return &types.ToolResult{
|
||||
Success: false,
|
||||
Error: fmt.Sprintf("Failed to parse args: %v", err),
|
||||
}, err
|
||||
}
|
||||
|
||||
// Human approval gate for dangerous tools (issue #1173)
|
||||
if t.gate != nil {
|
||||
if meta, ok := ToolExecFromContext(ctx); ok && meta != nil && meta.EventBus != nil {
|
||||
tenantID, _ := types.TenantIDFromContext(ctx)
|
||||
if t.gate.NeedsApproval(ctx, tenantID, t.service.ID, t.mcpTool.Name) {
|
||||
// Use ApprovalCtx (round-level ctx WITHOUT defaultToolExecTimeout) so
|
||||
// human approval can legitimately wait longer than the per-tool 60s.
|
||||
// User-stop / request cancel still propagates because ApprovalCtx is a
|
||||
// child of the request ctx.
|
||||
waitCtx := ctx
|
||||
if meta.ApprovalCtx != nil {
|
||||
waitCtx = meta.ApprovalCtx
|
||||
}
|
||||
decision, waitErr := t.gate.RequestAndWait(waitCtx, approval.PendingRequest{
|
||||
TenantID: tenantID,
|
||||
SessionID: meta.SessionID,
|
||||
AssistantMessageID: meta.AssistantMessageID,
|
||||
RequestID: meta.RequestID,
|
||||
EventBus: meta.EventBus,
|
||||
ServiceID: t.service.ID,
|
||||
ServiceName: t.service.Name,
|
||||
MCPToolName: t.mcpTool.Name,
|
||||
RegisteredToolName: t.Name(),
|
||||
Description: t.mcpTool.Description,
|
||||
Args: args,
|
||||
ToolCallID: meta.ToolCallID,
|
||||
})
|
||||
if waitErr != nil {
|
||||
return &types.ToolResult{
|
||||
Success: false,
|
||||
Error: fmt.Sprintf("Tool approval failed: %v", waitErr),
|
||||
}, nil
|
||||
}
|
||||
if !decision.Approved {
|
||||
msg := decision.Reason
|
||||
if msg == "" {
|
||||
msg = "tool execution rejected by user"
|
||||
}
|
||||
return &types.ToolResult{
|
||||
Success: false,
|
||||
Error: msg,
|
||||
}, nil
|
||||
}
|
||||
if len(decision.ModifiedArgs) > 0 {
|
||||
args = decision.ModifiedArgs
|
||||
if err := json.Unmarshal(args, &input); err != nil {
|
||||
return &types.ToolResult{
|
||||
Success: false,
|
||||
Error: fmt.Sprintf("Invalid modified_args after approval: %v", err),
|
||||
}, nil
|
||||
}
|
||||
}
|
||||
// Approval may have consumed most/all of the per-tool 60s budget set by the
|
||||
// agent engine (act.go). Re-derive a fresh tool-exec ctx from ApprovalCtx so
|
||||
// the actual MCP CallTool gets a full timeout window. (issue #1173 follow-up)
|
||||
if meta.ApprovalCtx != nil {
|
||||
freshCtx, freshCancel := context.WithTimeout(meta.ApprovalCtx, 60*time.Second)
|
||||
defer freshCancel()
|
||||
ctx = freshCtx
|
||||
}
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
// Get or create MCP client
|
||||
client, err := t.mcpManager.GetOrCreateClient(t.service)
|
||||
if err != nil {
|
||||
@@ -326,6 +393,7 @@ func RegisterMCPTools(
|
||||
registry *ToolRegistry,
|
||||
services []*types.MCPService,
|
||||
mcpManager *mcp.MCPManager,
|
||||
gate approval.MCPApproval,
|
||||
) error {
|
||||
if len(services) == 0 {
|
||||
return nil
|
||||
@@ -393,7 +461,7 @@ func RegisterMCPTools(
|
||||
|
||||
// Register each tool
|
||||
for _, mcpTool := range mcpTools {
|
||||
tool := NewMCPTool(service, mcpTool, mcpManager)
|
||||
tool := NewMCPTool(service, mcpTool, mcpManager, gate)
|
||||
toolName := tool.Name()
|
||||
|
||||
// Check for name collision before registering (first-wins policy).
|
||||
|
||||
@@ -0,0 +1,79 @@
|
||||
package repository
|
||||
|
||||
import (
|
||||
"context"
|
||||
"errors"
|
||||
"fmt"
|
||||
"time"
|
||||
|
||||
"github.com/Tencent/WeKnora/internal/types"
|
||||
"github.com/Tencent/WeKnora/internal/types/interfaces"
|
||||
"github.com/google/uuid"
|
||||
"gorm.io/gorm"
|
||||
)
|
||||
|
||||
// MCPToolApprovalRepository implements interfaces.MCPToolApprovalRepository.
|
||||
type MCPToolApprovalRepository struct {
|
||||
db *gorm.DB
|
||||
}
|
||||
|
||||
// NewMCPToolApprovalRepository creates a repository backed by GORM.
|
||||
func NewMCPToolApprovalRepository(db *gorm.DB) interfaces.MCPToolApprovalRepository {
|
||||
return &MCPToolApprovalRepository{db: db}
|
||||
}
|
||||
|
||||
// ListByService returns all stored approval rows for an MCP service (may be empty).
|
||||
func (r *MCPToolApprovalRepository) ListByService(ctx context.Context, tenantID uint64, serviceID string) ([]*types.MCPToolApproval, error) {
|
||||
var rows []*types.MCPToolApproval
|
||||
err := r.db.WithContext(ctx).
|
||||
Where("tenant_id = ? AND service_id = ?", tenantID, serviceID).
|
||||
Order("tool_name ASC").
|
||||
Find(&rows).Error
|
||||
if err != nil {
|
||||
return nil, fmt.Errorf("list mcp tool approvals: %w", err)
|
||||
}
|
||||
return rows, nil
|
||||
}
|
||||
|
||||
// IsRequired returns true when a row exists with require_approval = true.
|
||||
func (r *MCPToolApprovalRepository) IsRequired(ctx context.Context, tenantID uint64, serviceID, toolName string) (bool, error) {
|
||||
var row types.MCPToolApproval
|
||||
err := r.db.WithContext(ctx).
|
||||
Select("require_approval").
|
||||
Where("tenant_id = ? AND service_id = ? AND tool_name = ?", tenantID, serviceID, toolName).
|
||||
First(&row).Error
|
||||
if errors.Is(err, gorm.ErrRecordNotFound) {
|
||||
return false, nil
|
||||
}
|
||||
if err != nil {
|
||||
return false, fmt.Errorf("get mcp tool approval: %w", err)
|
||||
}
|
||||
return row.RequireApproval, nil
|
||||
}
|
||||
|
||||
// Upsert creates or updates the approval flag for a tool.
|
||||
func (r *MCPToolApprovalRepository) Upsert(ctx context.Context, row *types.MCPToolApproval) error {
|
||||
if row == nil {
|
||||
return errors.New("row is nil")
|
||||
}
|
||||
var existing types.MCPToolApproval
|
||||
err := r.db.WithContext(ctx).
|
||||
Where("tenant_id = ? AND service_id = ? AND tool_name = ?", row.TenantID, row.ServiceID, row.ToolName).
|
||||
First(&existing).Error
|
||||
if errors.Is(err, gorm.ErrRecordNotFound) {
|
||||
if row.ID == "" {
|
||||
row.ID = uuid.New().String()
|
||||
}
|
||||
if err := r.db.WithContext(ctx).Create(row).Error; err != nil {
|
||||
return fmt.Errorf("create mcp tool approval: %w", err)
|
||||
}
|
||||
return nil
|
||||
}
|
||||
if err != nil {
|
||||
return fmt.Errorf("get mcp tool approval for upsert: %w", err)
|
||||
}
|
||||
return r.db.WithContext(ctx).Model(&existing).Updates(map[string]interface{}{
|
||||
"require_approval": row.RequireApproval,
|
||||
"updated_at": time.Now(),
|
||||
}).Error
|
||||
}
|
||||
@@ -8,6 +8,7 @@ import (
|
||||
"strconv"
|
||||
|
||||
"github.com/Tencent/WeKnora/internal/agent"
|
||||
"github.com/Tencent/WeKnora/internal/agent/approval"
|
||||
"github.com/Tencent/WeKnora/internal/agent/skills"
|
||||
"github.com/Tencent/WeKnora/internal/agent/tools"
|
||||
"github.com/Tencent/WeKnora/internal/config"
|
||||
@@ -59,6 +60,7 @@ type agentService struct {
|
||||
webSearchStateService interfaces.WebSearchStateService
|
||||
wikiPageService interfaces.WikiPageService
|
||||
tenantService interfaces.TenantService
|
||||
toolApprovalGate approval.MCPApproval
|
||||
}
|
||||
|
||||
// NewAgentService creates a new agent service
|
||||
@@ -78,6 +80,7 @@ func NewAgentService(
|
||||
webSearchStateService interfaces.WebSearchStateService,
|
||||
wikiPageService interfaces.WikiPageService,
|
||||
tenantService interfaces.TenantService,
|
||||
toolApprovalGate approval.MCPApproval,
|
||||
) interfaces.AgentService {
|
||||
return &agentService{
|
||||
cfg: cfg,
|
||||
@@ -95,6 +98,7 @@ func NewAgentService(
|
||||
webSearchStateService: webSearchStateService,
|
||||
wikiPageService: wikiPageService,
|
||||
tenantService: tenantService,
|
||||
toolApprovalGate: toolApprovalGate,
|
||||
}
|
||||
}
|
||||
|
||||
@@ -223,7 +227,7 @@ func (s *agentService) registerMCPTools(
|
||||
}
|
||||
}
|
||||
if len(enabledServices) > 0 {
|
||||
if err := tools.RegisterMCPTools(ctx, toolRegistry, enabledServices, s.mcpManager); err != nil {
|
||||
if err := tools.RegisterMCPTools(ctx, toolRegistry, enabledServices, s.mcpManager, s.toolApprovalGate); err != nil {
|
||||
logger.Warnf(ctx, "Failed to register MCP tools: %v", err)
|
||||
} else {
|
||||
logger.Infof(ctx, "Registered MCP tools from %d enabled services", len(enabledServices))
|
||||
|
||||
59
internal/application/service/mcp_tool_approval_service.go
Normal file
59
internal/application/service/mcp_tool_approval_service.go
Normal file
@@ -0,0 +1,59 @@
|
||||
package service
|
||||
|
||||
import (
|
||||
"context"
|
||||
"fmt"
|
||||
|
||||
"github.com/Tencent/WeKnora/internal/types"
|
||||
"github.com/Tencent/WeKnora/internal/types/interfaces"
|
||||
)
|
||||
|
||||
type mcpToolApprovalService struct {
|
||||
repo interfaces.MCPToolApprovalRepository
|
||||
mcpRepo interfaces.MCPServiceRepository
|
||||
}
|
||||
|
||||
// NewMCPToolApprovalService constructs the MCP tool approval service.
|
||||
func NewMCPToolApprovalService(
|
||||
repo interfaces.MCPToolApprovalRepository,
|
||||
mcpRepo interfaces.MCPServiceRepository,
|
||||
) interfaces.MCPToolApprovalService {
|
||||
return &mcpToolApprovalService{repo: repo, mcpRepo: mcpRepo}
|
||||
}
|
||||
|
||||
func (s *mcpToolApprovalService) ListByService(ctx context.Context, tenantID uint64, serviceID string) ([]*types.MCPToolApproval, error) {
|
||||
svc, err := s.mcpRepo.GetByID(ctx, tenantID, serviceID)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
if svc == nil {
|
||||
return nil, fmt.Errorf("mcp service not found")
|
||||
}
|
||||
return s.repo.ListByService(ctx, tenantID, serviceID)
|
||||
}
|
||||
|
||||
func (s *mcpToolApprovalService) SetRequireApproval(
|
||||
ctx context.Context, tenantID uint64, serviceID, toolName string, require bool,
|
||||
) error {
|
||||
if toolName == "" {
|
||||
return fmt.Errorf("tool_name is required")
|
||||
}
|
||||
svc, err := s.mcpRepo.GetByID(ctx, tenantID, serviceID)
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
if svc == nil {
|
||||
return fmt.Errorf("mcp service not found")
|
||||
}
|
||||
row := &types.MCPToolApproval{
|
||||
TenantID: tenantID,
|
||||
ServiceID: serviceID,
|
||||
ToolName: toolName,
|
||||
RequireApproval: require,
|
||||
}
|
||||
return s.repo.Upsert(ctx, row)
|
||||
}
|
||||
|
||||
func (s *mcpToolApprovalService) IsRequired(ctx context.Context, tenantID uint64, serviceID, toolName string) (bool, error) {
|
||||
return s.repo.IsRequired(ctx, tenantID, serviceID, toolName)
|
||||
}
|
||||
@@ -37,6 +37,9 @@ type AgentConfig struct {
|
||||
// LLMCallTimeout is the default timeout for a single LLM call in seconds.
|
||||
// Default: 120 (standard agents) or 300 (can be overridden by Env).
|
||||
LLMCallTimeout int `yaml:"llm_call_timeout" json:"llm_call_timeout"`
|
||||
// ToolApprovalTimeoutSeconds is how long the agent waits for human approval on a flagged MCP tool.
|
||||
// 0 means default 600 (10 minutes).
|
||||
ToolApprovalTimeoutSeconds int `yaml:"tool_approval_timeout_seconds" json:"tool_approval_timeout_seconds"`
|
||||
}
|
||||
|
||||
// IMConfig configures the IM integration service.
|
||||
@@ -269,7 +272,9 @@ func DefaultTemplateByMode(templates []PromptTemplate, mode string) *PromptTempl
|
||||
|
||||
// LocalizeTemplates returns a deep copy of the template list with Name and
|
||||
// Description replaced according to the given locale. Fallback chain:
|
||||
// locale → primary language (e.g. "zh" from "zh-CN") → original Name/Description.
|
||||
//
|
||||
// locale → primary language (e.g. "zh" from "zh-CN") → original Name/Description.
|
||||
//
|
||||
// The returned slice is safe to serialise directly; it never mutates the original.
|
||||
func LocalizeTemplates(templates []PromptTemplate, locale string) []PromptTemplate {
|
||||
if len(templates) == 0 {
|
||||
@@ -568,6 +573,15 @@ func applyAgentEnvOverrides(cfg *Config) {
|
||||
cfg.Agent.LLMCallTimeout = int(sec.Seconds())
|
||||
}
|
||||
}
|
||||
// MCP tool human-approval wait timeout (issue #1173). Accepts Go duration
|
||||
// (e.g. "10m", "30s") or a bare number interpreted as seconds.
|
||||
if value := strings.TrimSpace(os.Getenv("WEKNORA_AGENT_TOOL_APPROVAL_TIMEOUT")); value != "" {
|
||||
if d, err := time.ParseDuration(value); err == nil {
|
||||
cfg.Agent.ToolApprovalTimeoutSeconds = int(d.Seconds())
|
||||
} else if d, err := time.ParseDuration(value + "s"); err == nil {
|
||||
cfg.Agent.ToolApprovalTimeoutSeconds = int(d.Seconds())
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
// backfillConversationDefaults resolves prompt template ID references
|
||||
|
||||
@@ -32,6 +32,7 @@ import (
|
||||
"gorm.io/driver/sqlite"
|
||||
"gorm.io/gorm"
|
||||
|
||||
"github.com/Tencent/WeKnora/internal/agent/approval"
|
||||
"github.com/Tencent/WeKnora/internal/application/repository"
|
||||
memoryRepo "github.com/Tencent/WeKnora/internal/application/repository/memory/neo4j"
|
||||
dorisRepo "github.com/Tencent/WeKnora/internal/application/repository/retriever/doris"
|
||||
@@ -148,6 +149,7 @@ func BuildContainer(container *dig.Container) *dig.Container {
|
||||
must(container.Provide(neo4jRepo.NewNeo4jRepository))
|
||||
must(container.Provide(memoryRepo.NewMemoryRepository))
|
||||
must(container.Provide(repository.NewMCPServiceRepository))
|
||||
must(container.Provide(repository.NewMCPToolApprovalRepository))
|
||||
must(container.Provide(repository.NewCustomAgentRepository))
|
||||
must(container.Provide(repository.NewOrganizationRepository))
|
||||
must(container.Provide(repository.NewKBShareRepository))
|
||||
@@ -190,6 +192,7 @@ func BuildContainer(container *dig.Container) *dig.Container {
|
||||
|
||||
must(container.Provide(service.NewMessageService))
|
||||
must(container.Provide(service.NewMCPServiceService))
|
||||
must(container.Provide(service.NewMCPToolApprovalService))
|
||||
must(container.Provide(service.NewCustomAgentService))
|
||||
must(container.Provide(memoryService.NewMemoryService))
|
||||
must(container.Provide(service.NewWikiPageService))
|
||||
@@ -221,6 +224,11 @@ func BuildContainer(container *dig.Container) *dig.Container {
|
||||
// SessionService is passed as parameter to CreateAgentEngine method when creating AgentService
|
||||
logger.Debugf(ctx, "[Container] Registering event bus and agent service...")
|
||||
must(container.Provide(event.NewEventBus))
|
||||
must(container.Provide(func(cfg *config.Config, s interfaces.MCPToolApprovalService, rdb *redis.Client) *approval.Gate {
|
||||
return approval.NewGate(cfg, &approval.Adapter{Svc: s}, rdb)
|
||||
}))
|
||||
// Expose Gate as MCPApproval interface so AgentService and others can depend on the abstraction.
|
||||
must(container.Provide(func(g *approval.Gate) approval.MCPApproval { return g }))
|
||||
must(container.Provide(service.NewAgentService))
|
||||
|
||||
// Session service (depends on agent service)
|
||||
|
||||
@@ -54,6 +54,10 @@ const (
|
||||
EventAgentReferences EventType = "references" // 知识引用
|
||||
EventAgentFinalAnswer EventType = "final_answer" // 最终答案
|
||||
|
||||
// MCP tool human approval (issue #1173)
|
||||
EventToolApprovalRequired EventType = "tool_approval_required"
|
||||
EventToolApprovalResolved EventType = "tool_approval_resolved"
|
||||
|
||||
// Error events
|
||||
EventError EventType = "error" // 错误事件
|
||||
|
||||
|
||||
@@ -209,3 +209,31 @@ type StopData struct {
|
||||
MessageID string `json:"message_id"`
|
||||
Reason string `json:"reason,omitempty"` // Optional reason for stopping
|
||||
}
|
||||
|
||||
// ToolApprovalRequiredData is emitted when an MCP tool marked dangerous is about to run.
|
||||
type ToolApprovalRequiredData struct {
|
||||
PendingID string `json:"pending_id"`
|
||||
TenantID uint64 `json:"tenant_id"`
|
||||
SessionID string `json:"session_id"`
|
||||
AssistantMessageID string `json:"assistant_message_id"`
|
||||
ServiceID string `json:"service_id"`
|
||||
ServiceName string `json:"service_name"`
|
||||
MCPToolName string `json:"mcp_tool_name"`
|
||||
RegisteredToolName string `json:"registered_tool_name"`
|
||||
Description string `json:"description"`
|
||||
Args interface{} `json:"args,omitempty"`
|
||||
ArgsJSON string `json:"args_json,omitempty"`
|
||||
TimeoutSeconds int `json:"timeout_seconds"`
|
||||
RequestedAtUnix int64 `json:"requested_at"`
|
||||
ToolCallID string `json:"tool_call_id"`
|
||||
RequestID string `json:"request_id,omitempty"`
|
||||
}
|
||||
|
||||
// ToolApprovalResolvedData confirms the user decision (or timeout/cancel).
|
||||
type ToolApprovalResolvedData struct {
|
||||
PendingID string `json:"pending_id"`
|
||||
Approved bool `json:"approved"`
|
||||
Reason string `json:"reason,omitempty"`
|
||||
TimedOut bool `json:"timed_out,omitempty"`
|
||||
Canceled bool `json:"canceled,omitempty"`
|
||||
}
|
||||
|
||||
@@ -1,9 +1,12 @@
|
||||
package handler
|
||||
|
||||
import (
|
||||
"encoding/json"
|
||||
"fmt"
|
||||
"net/http"
|
||||
"net/url"
|
||||
|
||||
"github.com/Tencent/WeKnora/internal/agent/approval"
|
||||
"github.com/Tencent/WeKnora/internal/errors"
|
||||
"github.com/Tencent/WeKnora/internal/logger"
|
||||
"github.com/Tencent/WeKnora/internal/types"
|
||||
@@ -14,13 +17,21 @@ import (
|
||||
|
||||
// MCPServiceHandler handles MCP service related HTTP requests
|
||||
type MCPServiceHandler struct {
|
||||
mcpServiceService interfaces.MCPServiceService
|
||||
mcpServiceService interfaces.MCPServiceService
|
||||
mcpToolApprovalService interfaces.MCPToolApprovalService
|
||||
toolApprovalGate *approval.Gate
|
||||
}
|
||||
|
||||
// NewMCPServiceHandler creates a new MCP service handler
|
||||
func NewMCPServiceHandler(mcpServiceService interfaces.MCPServiceService) *MCPServiceHandler {
|
||||
func NewMCPServiceHandler(
|
||||
mcpServiceService interfaces.MCPServiceService,
|
||||
mcpToolApprovalService interfaces.MCPToolApprovalService,
|
||||
toolApprovalGate *approval.Gate,
|
||||
) *MCPServiceHandler {
|
||||
return &MCPServiceHandler{
|
||||
mcpServiceService: mcpServiceService,
|
||||
mcpServiceService: mcpServiceService,
|
||||
mcpToolApprovalService: mcpToolApprovalService,
|
||||
toolApprovalGate: toolApprovalGate,
|
||||
}
|
||||
}
|
||||
|
||||
@@ -445,3 +456,116 @@ func (h *MCPServiceHandler) GetMCPServiceResources(c *gin.Context) {
|
||||
"data": resources,
|
||||
})
|
||||
}
|
||||
|
||||
// ListMCPToolApprovals returns persisted require_approval flags for tools on an MCP service.
|
||||
func (h *MCPServiceHandler) ListMCPToolApprovals(c *gin.Context) {
|
||||
ctx := c.Request.Context()
|
||||
serviceID := secutils.SanitizeForLog(c.Param("id"))
|
||||
tenantID := c.GetUint64(types.TenantIDContextKey.String())
|
||||
if tenantID == 0 {
|
||||
c.Error(errors.NewBadRequestError("Tenant ID cannot be empty"))
|
||||
return
|
||||
}
|
||||
if h.mcpToolApprovalService == nil {
|
||||
c.Error(errors.NewInternalServerError("MCP tool approval is not configured"))
|
||||
return
|
||||
}
|
||||
rows, err := h.mcpToolApprovalService.ListByService(ctx, tenantID, serviceID)
|
||||
if err != nil {
|
||||
c.Error(errors.NewNotFoundError(err.Error()))
|
||||
return
|
||||
}
|
||||
c.JSON(http.StatusOK, gin.H{"success": true, "data": rows})
|
||||
}
|
||||
|
||||
type setMCPToolApprovalBody struct {
|
||||
RequireApproval bool `json:"require_approval"`
|
||||
}
|
||||
|
||||
// SetMCPToolApproval sets whether a tool requires human approval before the agent may call it.
|
||||
func (h *MCPServiceHandler) SetMCPToolApproval(c *gin.Context) {
|
||||
ctx := c.Request.Context()
|
||||
serviceID := secutils.SanitizeForLog(c.Param("id"))
|
||||
rawName := c.Param("tool_name")
|
||||
toolName, err := url.PathUnescape(rawName)
|
||||
if err != nil {
|
||||
toolName = rawName
|
||||
}
|
||||
tenantID := c.GetUint64(types.TenantIDContextKey.String())
|
||||
if tenantID == 0 {
|
||||
c.Error(errors.NewBadRequestError("Tenant ID cannot be empty"))
|
||||
return
|
||||
}
|
||||
if h.mcpToolApprovalService == nil {
|
||||
c.Error(errors.NewInternalServerError("MCP tool approval is not configured"))
|
||||
return
|
||||
}
|
||||
var body setMCPToolApprovalBody
|
||||
if err := c.ShouldBindJSON(&body); err != nil {
|
||||
c.Error(errors.NewBadRequestError(err.Error()))
|
||||
return
|
||||
}
|
||||
if err := h.mcpToolApprovalService.SetRequireApproval(ctx, tenantID, serviceID, toolName, body.RequireApproval); err != nil {
|
||||
c.Error(errors.NewInternalServerError(err.Error()))
|
||||
return
|
||||
}
|
||||
c.JSON(http.StatusOK, gin.H{"success": true})
|
||||
}
|
||||
|
||||
type resolveToolApprovalBody struct {
|
||||
Decision string `json:"decision" binding:"required"` // approve | reject
|
||||
ModifiedArgs json.RawMessage `json:"modified_args"`
|
||||
Reason string `json:"reason"`
|
||||
}
|
||||
|
||||
// ResolveToolApproval completes a pending MCP tool approval (agent execution resumes).
|
||||
func (h *MCPServiceHandler) ResolveToolApproval(c *gin.Context) {
|
||||
ctx := c.Request.Context()
|
||||
pendingID := c.Param("pending_id")
|
||||
tenantID := c.GetUint64(types.TenantIDContextKey.String())
|
||||
if tenantID == 0 {
|
||||
c.Error(errors.NewBadRequestError("Tenant ID cannot be empty"))
|
||||
return
|
||||
}
|
||||
if h.toolApprovalGate == nil {
|
||||
c.Error(errors.NewInternalServerError("Tool approval gate is not configured"))
|
||||
return
|
||||
}
|
||||
var body resolveToolApprovalBody
|
||||
if err := c.ShouldBindJSON(&body); err != nil {
|
||||
c.Error(errors.NewBadRequestError(err.Error()))
|
||||
return
|
||||
}
|
||||
dec := approval.Decision{Reason: body.Reason}
|
||||
switch body.Decision {
|
||||
case "approve":
|
||||
dec.Approved = true
|
||||
if len(body.ModifiedArgs) > 0 {
|
||||
var probe map[string]interface{}
|
||||
if err := json.Unmarshal(body.ModifiedArgs, &probe); err != nil {
|
||||
c.Error(errors.NewBadRequestError("modified_args must be a JSON object"))
|
||||
return
|
||||
}
|
||||
dec.ModifiedArgs = body.ModifiedArgs
|
||||
}
|
||||
case "reject":
|
||||
dec.Approved = false
|
||||
default:
|
||||
c.Error(errors.NewBadRequestError("decision must be approve or reject"))
|
||||
return
|
||||
}
|
||||
if err := h.toolApprovalGate.Resolve(tenantID, pendingID, dec); err != nil {
|
||||
if err == approval.ErrPendingNotFound {
|
||||
c.Error(errors.NewNotFoundError("pending approval not found or already completed"))
|
||||
return
|
||||
}
|
||||
if err == approval.ErrTenantMismatch {
|
||||
c.Error(errors.NewBadRequestError("tenant mismatch"))
|
||||
return
|
||||
}
|
||||
logger.ErrorWithFields(ctx, err, map[string]interface{}{"pending_id": pendingID})
|
||||
c.Error(errors.NewInternalServerError(err.Error()))
|
||||
return
|
||||
}
|
||||
c.JSON(http.StatusOK, gin.H{"success": true})
|
||||
}
|
||||
|
||||
@@ -2,6 +2,7 @@ package session
|
||||
|
||||
import (
|
||||
"context"
|
||||
"encoding/json"
|
||||
"fmt"
|
||||
"sync"
|
||||
"time"
|
||||
@@ -66,6 +67,8 @@ func (h *AgentStreamHandler) Subscribe() {
|
||||
h.eventBus.On(event.EventError, h.handleError)
|
||||
h.eventBus.On(event.EventSessionTitle, h.handleSessionTitle)
|
||||
h.eventBus.On(event.EventAgentComplete, h.handleComplete)
|
||||
h.eventBus.On(event.EventToolApprovalRequired, h.handleToolApprovalRequired)
|
||||
h.eventBus.On(event.EventToolApprovalResolved, h.handleToolApprovalResolved)
|
||||
}
|
||||
|
||||
// handleThought handles agent thought events
|
||||
@@ -210,6 +213,60 @@ func (h *AgentStreamHandler) handleToolResult(ctx context.Context, evt event.Eve
|
||||
return nil
|
||||
}
|
||||
|
||||
func toolApprovalDataToMap(v interface{}) map[string]interface{} {
|
||||
b, err := json.Marshal(v)
|
||||
if err != nil {
|
||||
return map[string]interface{}{}
|
||||
}
|
||||
var m map[string]interface{}
|
||||
if err := json.Unmarshal(b, &m); err != nil {
|
||||
return map[string]interface{}{}
|
||||
}
|
||||
return m
|
||||
}
|
||||
|
||||
// handleToolApprovalRequired persists MCP tool human-approval prompts for SSE / replay (issue #1173).
|
||||
func (h *AgentStreamHandler) handleToolApprovalRequired(ctx context.Context, evt event.Event) error {
|
||||
data, ok := evt.Data.(event.ToolApprovalRequiredData)
|
||||
if !ok {
|
||||
return nil
|
||||
}
|
||||
meta := toolApprovalDataToMap(data)
|
||||
meta["pending_id"] = data.PendingID
|
||||
if err := h.streamManager.AppendEvent(h.ctx, h.sessionID, h.assistantMessageID, interfaces.StreamEvent{
|
||||
ID: evt.ID,
|
||||
Type: types.ResponseTypeToolApprovalRequired,
|
||||
Content: "MCP tool requires human approval",
|
||||
Done: true,
|
||||
Timestamp: time.Now(),
|
||||
Data: meta,
|
||||
}); err != nil {
|
||||
logger.GetLogger(h.ctx).Error("Append tool approval required event failed", "error", err)
|
||||
}
|
||||
return nil
|
||||
}
|
||||
|
||||
// handleToolApprovalResolved persists the outcome of a tool approval (issue #1173).
|
||||
func (h *AgentStreamHandler) handleToolApprovalResolved(ctx context.Context, evt event.Event) error {
|
||||
data, ok := evt.Data.(event.ToolApprovalResolvedData)
|
||||
if !ok {
|
||||
return nil
|
||||
}
|
||||
meta := toolApprovalDataToMap(data)
|
||||
meta["pending_id"] = data.PendingID
|
||||
if err := h.streamManager.AppendEvent(h.ctx, h.sessionID, h.assistantMessageID, interfaces.StreamEvent{
|
||||
ID: evt.ID,
|
||||
Type: types.ResponseTypeToolApprovalResolved,
|
||||
Content: "MCP tool approval resolved",
|
||||
Done: true,
|
||||
Timestamp: time.Now(),
|
||||
Data: meta,
|
||||
}); err != nil {
|
||||
logger.GetLogger(h.ctx).Error("Append tool approval resolved event failed", "error", err)
|
||||
}
|
||||
return nil
|
||||
}
|
||||
|
||||
// handleReferences handles knowledge references events
|
||||
func (h *AgentStreamHandler) handleReferences(ctx context.Context, evt event.Event) error {
|
||||
data, ok := evt.Data.(event.AgentReferencesData)
|
||||
|
||||
@@ -500,6 +500,14 @@ func RegisterMCPServiceRoutes(r *gin.RouterGroup, handler *handler.MCPServiceHan
|
||||
mcpServices.GET("/:id/tools", handler.GetMCPServiceTools)
|
||||
// Get MCP service resources
|
||||
mcpServices.GET("/:id/resources", handler.GetMCPServiceResources)
|
||||
// MCP tool human approval (issue #1173)
|
||||
mcpServices.GET("/:id/tool-approvals", handler.ListMCPToolApprovals)
|
||||
mcpServices.PUT("/:id/tool-approvals/:tool_name", handler.SetMCPToolApproval)
|
||||
}
|
||||
|
||||
agentTool := r.Group("/agent")
|
||||
{
|
||||
agentTool.POST("/tool-approvals/:pending_id", handler.ResolveToolApproval)
|
||||
}
|
||||
}
|
||||
|
||||
|
||||
@@ -27,10 +27,10 @@ type FunctionCall struct {
|
||||
|
||||
// ChatResponse chat response
|
||||
type ChatResponse struct {
|
||||
Content string `json:"content"`
|
||||
Content string `json:"content"`
|
||||
ToolCalls []LLMToolCall `json:"tool_calls,omitempty"`
|
||||
FinishReason string `json:"finish_reason,omitempty"`
|
||||
Usage TokenUsage `json:"usage"`
|
||||
FinishReason string `json:"finish_reason,omitempty"`
|
||||
Usage TokenUsage `json:"usage"`
|
||||
}
|
||||
|
||||
// Response type
|
||||
@@ -57,6 +57,10 @@ const (
|
||||
ResponseTypeAgentQuery ResponseType = "agent_query"
|
||||
// Complete response type (agent complete)
|
||||
ResponseTypeComplete ResponseType = "complete"
|
||||
// ToolApprovalRequired: MCP tool marked dangerous — UI must collect user approval before execution continues
|
||||
ResponseTypeToolApprovalRequired ResponseType = "tool_approval_required"
|
||||
// ToolApprovalResolved: user approved/rejected (or timeout); informational for UI replay
|
||||
ResponseTypeToolApprovalResolved ResponseType = "tool_approval_resolved"
|
||||
)
|
||||
|
||||
// StreamResponse stream response
|
||||
|
||||
21
internal/types/interfaces/mcp_tool_approval.go
Normal file
21
internal/types/interfaces/mcp_tool_approval.go
Normal file
@@ -0,0 +1,21 @@
|
||||
package interfaces
|
||||
|
||||
import (
|
||||
"context"
|
||||
|
||||
"github.com/Tencent/WeKnora/internal/types"
|
||||
)
|
||||
|
||||
// MCPToolApprovalRepository persists per-tool approval requirements.
|
||||
type MCPToolApprovalRepository interface {
|
||||
ListByService(ctx context.Context, tenantID uint64, serviceID string) ([]*types.MCPToolApproval, error)
|
||||
IsRequired(ctx context.Context, tenantID uint64, serviceID, toolName string) (bool, error)
|
||||
Upsert(ctx context.Context, row *types.MCPToolApproval) error
|
||||
}
|
||||
|
||||
// MCPToolApprovalService is the business layer for MCP tool approval flags.
|
||||
type MCPToolApprovalService interface {
|
||||
ListByService(ctx context.Context, tenantID uint64, serviceID string) ([]*types.MCPToolApproval, error)
|
||||
SetRequireApproval(ctx context.Context, tenantID uint64, serviceID, toolName string, require bool) error
|
||||
IsRequired(ctx context.Context, tenantID uint64, serviceID, toolName string) (bool, error)
|
||||
}
|
||||
@@ -30,9 +30,9 @@ type MCPService struct {
|
||||
Headers MCPHeaders `json:"headers" gorm:"type:json"`
|
||||
AuthConfig *MCPAuthConfig `json:"auth_config" gorm:"type:json"`
|
||||
AdvancedConfig *MCPAdvancedConfig `json:"advanced_config" gorm:"type:json"`
|
||||
StdioConfig *MCPStdioConfig `json:"stdio_config,omitempty" gorm:"type:json"` // Required for stdio transport
|
||||
EnvVars MCPEnvVars `json:"env_vars,omitempty" gorm:"type:json"` // Environment variables for stdio
|
||||
IsBuiltin bool `json:"is_builtin" gorm:"default:false"` // Whether this is a builtin MCP service (visible to all tenants)
|
||||
StdioConfig *MCPStdioConfig `json:"stdio_config,omitempty" gorm:"type:json"` // Required for stdio transport
|
||||
EnvVars MCPEnvVars `json:"env_vars,omitempty" gorm:"type:json"` // Environment variables for stdio
|
||||
IsBuiltin bool `json:"is_builtin" gorm:"default:false"` // Whether this is a builtin MCP service (visible to all tenants)
|
||||
CreatedAt time.Time `json:"created_at"`
|
||||
UpdatedAt time.Time `json:"updated_at"`
|
||||
DeletedAt gorm.DeletedAt `json:"deleted_at" gorm:"index"`
|
||||
@@ -69,6 +69,28 @@ type MCPTool struct {
|
||||
Name string `json:"name"`
|
||||
Description string `json:"description"`
|
||||
InputSchema json.RawMessage `json:"inputSchema"` // JSON Schema for tool parameters
|
||||
// RequireApproval when true: agent execution pauses until the user approves in UI (issue #1173).
|
||||
RequireApproval bool `json:"require_approval,omitempty"`
|
||||
}
|
||||
|
||||
// MCPToolApproval persists per-tool "danger / needs human approval" for an MCP service.
|
||||
// Tool list itself comes from MCP ListTools; this table only stores overrides.
|
||||
type MCPToolApproval struct {
|
||||
ID string `json:"id" gorm:"type:varchar(36);primaryKey"`
|
||||
TenantID uint64 `json:"tenant_id" gorm:"not null;uniqueIndex:idx_mcp_tool_approvals_tenant_svc_tool"`
|
||||
ServiceID string `json:"service_id" gorm:"type:varchar(36);not null;uniqueIndex:idx_mcp_tool_approvals_tenant_svc_tool;index"`
|
||||
ToolName string `json:"tool_name" gorm:"type:varchar(512);not null;uniqueIndex:idx_mcp_tool_approvals_tenant_svc_tool"`
|
||||
RequireApproval bool `json:"require_approval" gorm:"not null;default:false"`
|
||||
CreatedAt time.Time `json:"created_at"`
|
||||
UpdatedAt time.Time `json:"updated_at"`
|
||||
}
|
||||
|
||||
// BeforeCreate sets ID for MCPToolApproval.
|
||||
func (m *MCPToolApproval) BeforeCreate(tx *gorm.DB) error {
|
||||
if m.ID == "" {
|
||||
m.ID = uuid.New().String()
|
||||
}
|
||||
return nil
|
||||
}
|
||||
|
||||
// MCPResource represents a resource exposed by an MCP service
|
||||
|
||||
@@ -5,6 +5,7 @@ DROP TABLE IF EXISTS kb_shares;
|
||||
DROP TABLE IF EXISTS organization_members;
|
||||
DROP TABLE IF EXISTS organizations;
|
||||
DROP TABLE IF EXISTS custom_agents;
|
||||
DROP TABLE IF EXISTS mcp_tool_approvals;
|
||||
DROP TABLE IF EXISTS mcp_services;
|
||||
DROP TABLE IF EXISTS knowledge_tags;
|
||||
DROP TABLE IF EXISTS auth_tokens;
|
||||
|
||||
@@ -286,6 +286,20 @@ CREATE INDEX IF NOT EXISTS idx_mcp_services_enabled ON mcp_services(enabled);
|
||||
CREATE INDEX IF NOT EXISTS idx_mcp_services_is_builtin ON mcp_services(is_builtin);
|
||||
CREATE INDEX IF NOT EXISTS idx_mcp_services_deleted_at ON mcp_services(deleted_at);
|
||||
|
||||
CREATE TABLE IF NOT EXISTS mcp_tool_approvals (
|
||||
id VARCHAR(36) PRIMARY KEY,
|
||||
tenant_id INTEGER NOT NULL,
|
||||
service_id VARCHAR(36) NOT NULL,
|
||||
tool_name VARCHAR(512) NOT NULL,
|
||||
require_approval BOOLEAN NOT NULL DEFAULT 0,
|
||||
created_at DATETIME DEFAULT CURRENT_TIMESTAMP,
|
||||
updated_at DATETIME DEFAULT CURRENT_TIMESTAMP,
|
||||
FOREIGN KEY (service_id) REFERENCES mcp_services(id) ON DELETE CASCADE
|
||||
);
|
||||
|
||||
CREATE UNIQUE INDEX IF NOT EXISTS idx_mcp_tool_approvals_tenant_svc_tool ON mcp_tool_approvals(tenant_id, service_id, tool_name);
|
||||
CREATE INDEX IF NOT EXISTS idx_mcp_tool_approvals_service_id ON mcp_tool_approvals(service_id);
|
||||
|
||||
CREATE TABLE IF NOT EXISTS custom_agents (
|
||||
id VARCHAR(36) NOT NULL,
|
||||
name VARCHAR(255) NOT NULL,
|
||||
|
||||
7
migrations/versioned/000042_mcp_tool_approval.down.sql
Normal file
7
migrations/versioned/000042_mcp_tool_approval.down.sql
Normal file
@@ -0,0 +1,7 @@
|
||||
DO $$ BEGIN RAISE NOTICE '[Migration 000042 DOWN] Dropping mcp_tool_approvals...'; END $$;
|
||||
|
||||
DROP INDEX IF EXISTS idx_mcp_tool_approvals_service_id;
|
||||
DROP INDEX IF EXISTS idx_mcp_tool_approvals_tenant_svc_tool;
|
||||
DROP TABLE IF EXISTS mcp_tool_approvals;
|
||||
|
||||
DO $$ BEGIN RAISE NOTICE '[Migration 000042 DOWN] Done'; END $$;
|
||||
19
migrations/versioned/000042_mcp_tool_approval.up.sql
Normal file
19
migrations/versioned/000042_mcp_tool_approval.up.sql
Normal file
@@ -0,0 +1,19 @@
|
||||
-- MCP tool human-approval flags (per service + tool name, issue #1173)
|
||||
DO $$ BEGIN RAISE NOTICE '[Migration 000042] Creating mcp_tool_approvals...'; END $$;
|
||||
|
||||
CREATE TABLE IF NOT EXISTS mcp_tool_approvals (
|
||||
id VARCHAR(36) PRIMARY KEY,
|
||||
tenant_id INTEGER NOT NULL,
|
||||
service_id VARCHAR(36) NOT NULL REFERENCES mcp_services(id) ON DELETE CASCADE,
|
||||
tool_name VARCHAR(512) NOT NULL,
|
||||
require_approval BOOLEAN NOT NULL DEFAULT false,
|
||||
created_at TIMESTAMP NOT NULL DEFAULT CURRENT_TIMESTAMP,
|
||||
updated_at TIMESTAMP NOT NULL DEFAULT CURRENT_TIMESTAMP
|
||||
);
|
||||
|
||||
CREATE UNIQUE INDEX IF NOT EXISTS idx_mcp_tool_approvals_tenant_svc_tool
|
||||
ON mcp_tool_approvals(tenant_id, service_id, tool_name);
|
||||
|
||||
CREATE INDEX IF NOT EXISTS idx_mcp_tool_approvals_service_id ON mcp_tool_approvals(service_id);
|
||||
|
||||
DO $$ BEGIN RAISE NOTICE '[Migration 000042] mcp_tool_approvals ready'; END $$;
|
||||
Reference in New Issue
Block a user