fix(mcp-server): address Copilot PR review comments

- SSL verification now defaults to enabled; set WEKNORA_VERIFY_SSL=false to
  opt out (with a logged warning). Fixes MITM risk from default-off TLS.
- WEKNORA_CHAT_TIMEOUT parse is now guarded with try/except ValueError so a
  bad env value falls back to 300s instead of crashing at import.
- SSE streaming response is now closed via context manager (with response:)
  to guarantee connection pool return even on early break.
- Replace asyncio.get_event_loop() (deprecated) with asyncio.get_running_loop()
  in both chat and agent_chat handlers.
- create_session now calls resolve_kb_id() so KB names are accepted in addition
  to UUIDs (consistent with chat / hybrid_search).
- knowledge_base_ids description changed from REQUIRED to Strongly recommended
  to match actual schema optionality.
- run_sse() handle_sse rewritten as raw ASGI callable (scope, receive, send) to
  avoid accessing Starlette private _send attribute.
- Fix main.py comment: http transport is Streamable HTTP (MCP spec), not long-polling.
This commit is contained in:
mileslai
2026-05-29 11:48:19 +08:00
committed by lyingbug
parent 835148626b
commit b603b1dcfa
2 changed files with 55 additions and 43 deletions

View File

@@ -138,7 +138,7 @@ async def main():
# Select transport mode based on CLI argument or MCP_TRANSPORT env var
# - stdio: Default, used by VS Code Copilot for local integration
# - sse: Server-Sent Events over HTTP, suitable for cloud/remote deployments
# - http: HTTP long-polling, compatible with various client architectures
# - http: Streamable HTTP sessions (MCP 2025-03-26 spec), compatible with REST clients
if args.transport == "stdio":
# Stdio mode: communication via stdin/stdout pipes (typical for CLI integrations)
await run_stdio()

View File

@@ -30,7 +30,11 @@ logger = logging.getLogger(__name__)
WEKNORA_BASE_URL = os.getenv("WEKNORA_BASE_URL", "http://localhost:8080/api/v1")
WEKNORA_API_KEY = os.getenv("WEKNORA_API_KEY", "")
# Chat SSE read timeout in seconds. LLM responses can be slow; default 300s.
WEKNORA_CHAT_TIMEOUT = int(os.getenv("WEKNORA_CHAT_TIMEOUT", "300"))
try:
WEKNORA_CHAT_TIMEOUT = int(os.getenv("WEKNORA_CHAT_TIMEOUT", "300"))
except ValueError:
logger.warning("WEKNORA_CHAT_TIMEOUT is not a valid integer; falling back to 300s.")
WEKNORA_CHAT_TIMEOUT = 300
class WeKnoraClient:
@@ -40,9 +44,14 @@ class WeKnoraClient:
"""Initialize the WeKnora API client with base URL and authentication"""
self.base_url = base_url
self.api_key = api_key
# SSL verification: set WEKNORA_VERIFY_SSL=true to enable (default off for self-signed certs)
self.verify_ssl = os.getenv("WEKNORA_VERIFY_SSL", "false").lower() == "true"
# SSL verification: enabled by default. Set WEKNORA_VERIFY_SSL=false to disable
# (e.g. for self-signed certs in dev environments — NOT recommended for production).
self.verify_ssl = os.getenv("WEKNORA_VERIFY_SSL", "true").lower() != "false"
if not self.verify_ssl:
logger.warning(
"SSL certificate verification is DISABLED (WEKNORA_VERIFY_SSL=false). "
"This is insecure and should not be used in production."
)
urllib3.disable_warnings(urllib3.exceptions.InsecureRequestWarning)
# Create a persistent session for connection pooling and performance
self.session = requests.Session()
@@ -337,36 +346,39 @@ class WeKnoraClient:
references: list = []
debug_events: list = []
for raw_line in response.iter_lines():
if not raw_line:
continue
if isinstance(raw_line, bytes):
raw_line = raw_line.decode("utf-8")
# Each SSE event is prefixed with "data: " followed by JSON payload
if not raw_line.startswith("data:"):
continue
payload = raw_line[5:].lstrip(" ")
try:
event_data = json.loads(payload)
except json.JSONDecodeError:
continue
# Use context manager to ensure the connection is returned to the pool
# even when breaking early on a 'complete' event.
with response:
for raw_line in response.iter_lines():
if not raw_line:
continue
if isinstance(raw_line, bytes):
raw_line = raw_line.decode("utf-8")
# Each SSE event is prefixed with "data: " followed by JSON payload
if not raw_line.startswith("data:"):
continue
payload = raw_line[5:].lstrip(" ")
try:
event_data = json.loads(payload)
except json.JSONDecodeError:
continue
response_type = event_data.get("response_type", "")
debug_events.append({"type": response_type, "content": event_data.get("content", "")[:80]})
response_type = event_data.get("response_type", "")
debug_events.append({"type": response_type, "content": event_data.get("content", "")[:80]})
# Parse different SSE event types: answer chunks, references, errors, completion
if response_type == "answer":
chunk = event_data.get("content", "")
if chunk:
answer_chunks.append(chunk)
elif response_type == "references":
references = event_data.get("knowledge_references") or []
elif response_type == "error":
raise RequestException(
f"Server error: {event_data.get('content', 'unknown error')}"
)
elif response_type == "complete":
break
# Parse different SSE event types: answer chunks, references, errors, completion
if response_type == "answer":
chunk = event_data.get("content", "")
if chunk:
answer_chunks.append(chunk)
elif response_type == "references":
references = event_data.get("knowledge_references") or []
elif response_type == "error":
raise RequestException(
f"Server error: {event_data.get('content', 'unknown error')}"
)
elif response_type == "complete":
break
return {
"answer": "".join(answer_chunks),
@@ -832,7 +844,7 @@ async def handle_list_tools() -> list[types.Tool]:
"knowledge_base_ids": {
"type": "array",
"items": {"type": "string"},
"description": "Knowledge base names OR UUIDs to search. REQUIRED for RAG. E.g. ['my-knowledge-base'] or ['a1b2c3d4-...']. Use list_knowledge_bases to find them.",
"description": "Knowledge base names OR UUIDs to search. Strongly recommended for RAG — without them the answer falls back to LLM knowledge only. E.g. ['my-knowledge-base'] or ['a1b2c3d4-...']. Use list_knowledge_bases to find them.",
},
"web_search_enabled": {"type": "boolean", "description": "Enable web search alongside KB retrieval.", "default": False},
"enable_memory": {"type": "boolean", "description": "Enable cross-session memory.", "default": False},
@@ -1115,7 +1127,7 @@ async def handle_call_tool(
# Strategy includes: max conversation rounds, query rewriting, summarization model,
# fallback response handling, and retrieval thresholds (keyword/vector similarity).
result = client.create_session(
kb_id=args["kb_id"],
kb_id=client.resolve_kb_id(args["kb_id"]),
max_rounds=args.get("max_rounds", 5),
enable_rewrite=args.get("enable_rewrite", True),
fallback_response=args.get(
@@ -1149,7 +1161,8 @@ async def handle_call_tool(
web_search_enabled=args.get("web_search_enabled", False),
enable_memory=args.get("enable_memory", False),
)
result = await asyncio.get_event_loop().run_in_executor(None, fn)
# get_running_loop() is the correct API inside async functions (get_event_loop() is deprecated)
result = await asyncio.get_running_loop().run_in_executor(None, fn)
elif name == "agent_chat":
# Autonomous agent tool-calling: agent decides which tools to invoke (knowledge_search, web_search, etc.)
@@ -1199,7 +1212,7 @@ async def handle_call_tool(
web_search_enabled=args.get("web_search_enabled", False),
enable_memory=args.get("enable_memory", False),
)
result = await asyncio.get_event_loop().run_in_executor(None, fn)
result = await asyncio.get_running_loop().run_in_executor(None, fn)
elif name == "list_agents":
result = client.list_agents(
@@ -1273,8 +1286,7 @@ async def run_sse(host: str, port: int):
try:
from mcp.server.sse import SseServerTransport
from starlette.applications import Starlette
from starlette.requests import Request
from starlette.routing import Mount, Route
from starlette.routing import Mount
import uvicorn
except ImportError as e:
raise ImportError(
@@ -1283,15 +1295,15 @@ async def run_sse(host: str, port: int):
sse = SseServerTransport("/messages/")
async def handle_sse(request: Request):
async with sse.connect_sse(
request.scope, request.receive, request._send
) as streams:
# Use a raw ASGI callable instead of a Starlette Request endpoint to avoid
# accessing Starlette's private _send attribute (which can break across versions).
async def handle_sse(scope, receive, send):
async with sse.connect_sse(scope, receive, send) as streams:
await app.run(streams[0], streams[1], _init_options())
starlette_app = Starlette(
routes=[
Route("/sse", endpoint=handle_sse),
Mount("/sse", app=handle_sse),
Mount("/messages/", app=sse.handle_post_message),
]
)