fix(mcp-server): address Copilot PR review comments

- SSL verification now defaults to enabled; set WEKNORA_VERIFY_SSL=false to
  opt out (with a logged warning). Fixes MITM risk from default-off TLS.
- WEKNORA_CHAT_TIMEOUT parse is now guarded with try/except ValueError so a
  bad env value falls back to 300s instead of crashing at import.
- SSE streaming response is now closed via context manager (with response:)
  to guarantee connection pool return even on early break.
- Replace asyncio.get_event_loop() (deprecated) with asyncio.get_running_loop()
  in both chat and agent_chat handlers.
- create_session now calls resolve_kb_id() so KB names are accepted in addition
  to UUIDs (consistent with chat / hybrid_search).
- knowledge_base_ids description changed from REQUIRED to Strongly recommended
  to match actual schema optionality.
- run_sse() handle_sse rewritten as raw ASGI callable (scope, receive, send) to
  avoid accessing Starlette private _send attribute.
- Fix main.py comment: http transport is Streamable HTTP (MCP spec), not long-polling.
This commit is contained in:
mileslai
2026-05-29 11:48:19 +08:00
committed by lyingbug
parent 835148626b
commit b603b1dcfa
2 changed files with 55 additions and 43 deletions

View File

@@ -138,7 +138,7 @@ async def main():
# Select transport mode based on CLI argument or MCP_TRANSPORT env var # Select transport mode based on CLI argument or MCP_TRANSPORT env var
# - stdio: Default, used by VS Code Copilot for local integration # - stdio: Default, used by VS Code Copilot for local integration
# - sse: Server-Sent Events over HTTP, suitable for cloud/remote deployments # - sse: Server-Sent Events over HTTP, suitable for cloud/remote deployments
# - http: HTTP long-polling, compatible with various client architectures # - http: Streamable HTTP sessions (MCP 2025-03-26 spec), compatible with REST clients
if args.transport == "stdio": if args.transport == "stdio":
# Stdio mode: communication via stdin/stdout pipes (typical for CLI integrations) # Stdio mode: communication via stdin/stdout pipes (typical for CLI integrations)
await run_stdio() await run_stdio()

View File

@@ -30,7 +30,11 @@ logger = logging.getLogger(__name__)
WEKNORA_BASE_URL = os.getenv("WEKNORA_BASE_URL", "http://localhost:8080/api/v1") WEKNORA_BASE_URL = os.getenv("WEKNORA_BASE_URL", "http://localhost:8080/api/v1")
WEKNORA_API_KEY = os.getenv("WEKNORA_API_KEY", "") WEKNORA_API_KEY = os.getenv("WEKNORA_API_KEY", "")
# Chat SSE read timeout in seconds. LLM responses can be slow; default 300s. # Chat SSE read timeout in seconds. LLM responses can be slow; default 300s.
WEKNORA_CHAT_TIMEOUT = int(os.getenv("WEKNORA_CHAT_TIMEOUT", "300")) try:
WEKNORA_CHAT_TIMEOUT = int(os.getenv("WEKNORA_CHAT_TIMEOUT", "300"))
except ValueError:
logger.warning("WEKNORA_CHAT_TIMEOUT is not a valid integer; falling back to 300s.")
WEKNORA_CHAT_TIMEOUT = 300
class WeKnoraClient: class WeKnoraClient:
@@ -40,9 +44,14 @@ class WeKnoraClient:
"""Initialize the WeKnora API client with base URL and authentication""" """Initialize the WeKnora API client with base URL and authentication"""
self.base_url = base_url self.base_url = base_url
self.api_key = api_key self.api_key = api_key
# SSL verification: set WEKNORA_VERIFY_SSL=true to enable (default off for self-signed certs) # SSL verification: enabled by default. Set WEKNORA_VERIFY_SSL=false to disable
self.verify_ssl = os.getenv("WEKNORA_VERIFY_SSL", "false").lower() == "true" # (e.g. for self-signed certs in dev environments — NOT recommended for production).
self.verify_ssl = os.getenv("WEKNORA_VERIFY_SSL", "true").lower() != "false"
if not self.verify_ssl: if not self.verify_ssl:
logger.warning(
"SSL certificate verification is DISABLED (WEKNORA_VERIFY_SSL=false). "
"This is insecure and should not be used in production."
)
urllib3.disable_warnings(urllib3.exceptions.InsecureRequestWarning) urllib3.disable_warnings(urllib3.exceptions.InsecureRequestWarning)
# Create a persistent session for connection pooling and performance # Create a persistent session for connection pooling and performance
self.session = requests.Session() self.session = requests.Session()
@@ -337,36 +346,39 @@ class WeKnoraClient:
references: list = [] references: list = []
debug_events: list = [] debug_events: list = []
for raw_line in response.iter_lines(): # Use context manager to ensure the connection is returned to the pool
if not raw_line: # even when breaking early on a 'complete' event.
continue with response:
if isinstance(raw_line, bytes): for raw_line in response.iter_lines():
raw_line = raw_line.decode("utf-8") if not raw_line:
# Each SSE event is prefixed with "data: " followed by JSON payload continue
if not raw_line.startswith("data:"): if isinstance(raw_line, bytes):
continue raw_line = raw_line.decode("utf-8")
payload = raw_line[5:].lstrip(" ") # Each SSE event is prefixed with "data: " followed by JSON payload
try: if not raw_line.startswith("data:"):
event_data = json.loads(payload) continue
except json.JSONDecodeError: payload = raw_line[5:].lstrip(" ")
continue try:
event_data = json.loads(payload)
except json.JSONDecodeError:
continue
response_type = event_data.get("response_type", "") response_type = event_data.get("response_type", "")
debug_events.append({"type": response_type, "content": event_data.get("content", "")[:80]}) debug_events.append({"type": response_type, "content": event_data.get("content", "")[:80]})
# Parse different SSE event types: answer chunks, references, errors, completion # Parse different SSE event types: answer chunks, references, errors, completion
if response_type == "answer": if response_type == "answer":
chunk = event_data.get("content", "") chunk = event_data.get("content", "")
if chunk: if chunk:
answer_chunks.append(chunk) answer_chunks.append(chunk)
elif response_type == "references": elif response_type == "references":
references = event_data.get("knowledge_references") or [] references = event_data.get("knowledge_references") or []
elif response_type == "error": elif response_type == "error":
raise RequestException( raise RequestException(
f"Server error: {event_data.get('content', 'unknown error')}" f"Server error: {event_data.get('content', 'unknown error')}"
) )
elif response_type == "complete": elif response_type == "complete":
break break
return { return {
"answer": "".join(answer_chunks), "answer": "".join(answer_chunks),
@@ -832,7 +844,7 @@ async def handle_list_tools() -> list[types.Tool]:
"knowledge_base_ids": { "knowledge_base_ids": {
"type": "array", "type": "array",
"items": {"type": "string"}, "items": {"type": "string"},
"description": "Knowledge base names OR UUIDs to search. REQUIRED for RAG. E.g. ['my-knowledge-base'] or ['a1b2c3d4-...']. Use list_knowledge_bases to find them.", "description": "Knowledge base names OR UUIDs to search. Strongly recommended for RAG — without them the answer falls back to LLM knowledge only. E.g. ['my-knowledge-base'] or ['a1b2c3d4-...']. Use list_knowledge_bases to find them.",
}, },
"web_search_enabled": {"type": "boolean", "description": "Enable web search alongside KB retrieval.", "default": False}, "web_search_enabled": {"type": "boolean", "description": "Enable web search alongside KB retrieval.", "default": False},
"enable_memory": {"type": "boolean", "description": "Enable cross-session memory.", "default": False}, "enable_memory": {"type": "boolean", "description": "Enable cross-session memory.", "default": False},
@@ -1115,7 +1127,7 @@ async def handle_call_tool(
# Strategy includes: max conversation rounds, query rewriting, summarization model, # Strategy includes: max conversation rounds, query rewriting, summarization model,
# fallback response handling, and retrieval thresholds (keyword/vector similarity). # fallback response handling, and retrieval thresholds (keyword/vector similarity).
result = client.create_session( result = client.create_session(
kb_id=args["kb_id"], kb_id=client.resolve_kb_id(args["kb_id"]),
max_rounds=args.get("max_rounds", 5), max_rounds=args.get("max_rounds", 5),
enable_rewrite=args.get("enable_rewrite", True), enable_rewrite=args.get("enable_rewrite", True),
fallback_response=args.get( fallback_response=args.get(
@@ -1149,7 +1161,8 @@ async def handle_call_tool(
web_search_enabled=args.get("web_search_enabled", False), web_search_enabled=args.get("web_search_enabled", False),
enable_memory=args.get("enable_memory", False), enable_memory=args.get("enable_memory", False),
) )
result = await asyncio.get_event_loop().run_in_executor(None, fn) # get_running_loop() is the correct API inside async functions (get_event_loop() is deprecated)
result = await asyncio.get_running_loop().run_in_executor(None, fn)
elif name == "agent_chat": elif name == "agent_chat":
# Autonomous agent tool-calling: agent decides which tools to invoke (knowledge_search, web_search, etc.) # Autonomous agent tool-calling: agent decides which tools to invoke (knowledge_search, web_search, etc.)
@@ -1199,7 +1212,7 @@ async def handle_call_tool(
web_search_enabled=args.get("web_search_enabled", False), web_search_enabled=args.get("web_search_enabled", False),
enable_memory=args.get("enable_memory", False), enable_memory=args.get("enable_memory", False),
) )
result = await asyncio.get_event_loop().run_in_executor(None, fn) result = await asyncio.get_running_loop().run_in_executor(None, fn)
elif name == "list_agents": elif name == "list_agents":
result = client.list_agents( result = client.list_agents(
@@ -1273,8 +1286,7 @@ async def run_sse(host: str, port: int):
try: try:
from mcp.server.sse import SseServerTransport from mcp.server.sse import SseServerTransport
from starlette.applications import Starlette from starlette.applications import Starlette
from starlette.requests import Request from starlette.routing import Mount
from starlette.routing import Mount, Route
import uvicorn import uvicorn
except ImportError as e: except ImportError as e:
raise ImportError( raise ImportError(
@@ -1283,15 +1295,15 @@ async def run_sse(host: str, port: int):
sse = SseServerTransport("/messages/") sse = SseServerTransport("/messages/")
async def handle_sse(request: Request): # Use a raw ASGI callable instead of a Starlette Request endpoint to avoid
async with sse.connect_sse( # accessing Starlette's private _send attribute (which can break across versions).
request.scope, request.receive, request._send async def handle_sse(scope, receive, send):
) as streams: async with sse.connect_sse(scope, receive, send) as streams:
await app.run(streams[0], streams[1], _init_options()) await app.run(streams[0], streams[1], _init_options())
starlette_app = Starlette( starlette_app = Starlette(
routes=[ routes=[
Route("/sse", endpoint=handle_sse), Mount("/sse", app=handle_sse),
Mount("/messages/", app=sse.handle_post_message), Mount("/messages/", app=sse.handle_post_message),
] ]
) )