From b603b1dcfa7ed1d35b1ac20812173b84d45e08f4 Mon Sep 17 00:00:00 2001 From: mileslai Date: Fri, 29 May 2026 11:48:19 +0800 Subject: [PATCH] fix(mcp-server): address Copilot PR review comments - SSL verification now defaults to enabled; set WEKNORA_VERIFY_SSL=false to opt out (with a logged warning). Fixes MITM risk from default-off TLS. - WEKNORA_CHAT_TIMEOUT parse is now guarded with try/except ValueError so a bad env value falls back to 300s instead of crashing at import. - SSE streaming response is now closed via context manager (with response:) to guarantee connection pool return even on early break. - Replace asyncio.get_event_loop() (deprecated) with asyncio.get_running_loop() in both chat and agent_chat handlers. - create_session now calls resolve_kb_id() so KB names are accepted in addition to UUIDs (consistent with chat / hybrid_search). - knowledge_base_ids description changed from REQUIRED to Strongly recommended to match actual schema optionality. - run_sse() handle_sse rewritten as raw ASGI callable (scope, receive, send) to avoid accessing Starlette private _send attribute. - Fix main.py comment: http transport is Streamable HTTP (MCP spec), not long-polling. --- mcp-server/main.py | 2 +- mcp-server/weknora_mcp_server.py | 96 ++++++++++++++++++-------------- 2 files changed, 55 insertions(+), 43 deletions(-) diff --git a/mcp-server/main.py b/mcp-server/main.py index 47afa162..d1ff6cc6 100644 --- a/mcp-server/main.py +++ b/mcp-server/main.py @@ -138,7 +138,7 @@ async def main(): # Select transport mode based on CLI argument or MCP_TRANSPORT env var # - stdio: Default, used by VS Code Copilot for local integration # - sse: Server-Sent Events over HTTP, suitable for cloud/remote deployments - # - http: HTTP long-polling, compatible with various client architectures + # - http: Streamable HTTP sessions (MCP 2025-03-26 spec), compatible with REST clients if args.transport == "stdio": # Stdio mode: communication via stdin/stdout pipes (typical for CLI integrations) await run_stdio() diff --git a/mcp-server/weknora_mcp_server.py b/mcp-server/weknora_mcp_server.py index 67554449..49904b8e 100644 --- a/mcp-server/weknora_mcp_server.py +++ b/mcp-server/weknora_mcp_server.py @@ -30,7 +30,11 @@ logger = logging.getLogger(__name__) WEKNORA_BASE_URL = os.getenv("WEKNORA_BASE_URL", "http://localhost:8080/api/v1") WEKNORA_API_KEY = os.getenv("WEKNORA_API_KEY", "") # Chat SSE read timeout in seconds. LLM responses can be slow; default 300s. -WEKNORA_CHAT_TIMEOUT = int(os.getenv("WEKNORA_CHAT_TIMEOUT", "300")) +try: + WEKNORA_CHAT_TIMEOUT = int(os.getenv("WEKNORA_CHAT_TIMEOUT", "300")) +except ValueError: + logger.warning("WEKNORA_CHAT_TIMEOUT is not a valid integer; falling back to 300s.") + WEKNORA_CHAT_TIMEOUT = 300 class WeKnoraClient: @@ -40,9 +44,14 @@ class WeKnoraClient: """Initialize the WeKnora API client with base URL and authentication""" self.base_url = base_url self.api_key = api_key - # SSL verification: set WEKNORA_VERIFY_SSL=true to enable (default off for self-signed certs) - self.verify_ssl = os.getenv("WEKNORA_VERIFY_SSL", "false").lower() == "true" + # SSL verification: enabled by default. Set WEKNORA_VERIFY_SSL=false to disable + # (e.g. for self-signed certs in dev environments — NOT recommended for production). + self.verify_ssl = os.getenv("WEKNORA_VERIFY_SSL", "true").lower() != "false" if not self.verify_ssl: + logger.warning( + "SSL certificate verification is DISABLED (WEKNORA_VERIFY_SSL=false). " + "This is insecure and should not be used in production." + ) urllib3.disable_warnings(urllib3.exceptions.InsecureRequestWarning) # Create a persistent session for connection pooling and performance self.session = requests.Session() @@ -337,36 +346,39 @@ class WeKnoraClient: references: list = [] debug_events: list = [] - for raw_line in response.iter_lines(): - if not raw_line: - continue - if isinstance(raw_line, bytes): - raw_line = raw_line.decode("utf-8") - # Each SSE event is prefixed with "data: " followed by JSON payload - if not raw_line.startswith("data:"): - continue - payload = raw_line[5:].lstrip(" ") - try: - event_data = json.loads(payload) - except json.JSONDecodeError: - continue + # Use context manager to ensure the connection is returned to the pool + # even when breaking early on a 'complete' event. + with response: + for raw_line in response.iter_lines(): + if not raw_line: + continue + if isinstance(raw_line, bytes): + raw_line = raw_line.decode("utf-8") + # Each SSE event is prefixed with "data: " followed by JSON payload + if not raw_line.startswith("data:"): + continue + payload = raw_line[5:].lstrip(" ") + try: + event_data = json.loads(payload) + except json.JSONDecodeError: + continue - response_type = event_data.get("response_type", "") - debug_events.append({"type": response_type, "content": event_data.get("content", "")[:80]}) + response_type = event_data.get("response_type", "") + debug_events.append({"type": response_type, "content": event_data.get("content", "")[:80]}) - # Parse different SSE event types: answer chunks, references, errors, completion - if response_type == "answer": - chunk = event_data.get("content", "") - if chunk: - answer_chunks.append(chunk) - elif response_type == "references": - references = event_data.get("knowledge_references") or [] - elif response_type == "error": - raise RequestException( - f"Server error: {event_data.get('content', 'unknown error')}" - ) - elif response_type == "complete": - break + # Parse different SSE event types: answer chunks, references, errors, completion + if response_type == "answer": + chunk = event_data.get("content", "") + if chunk: + answer_chunks.append(chunk) + elif response_type == "references": + references = event_data.get("knowledge_references") or [] + elif response_type == "error": + raise RequestException( + f"Server error: {event_data.get('content', 'unknown error')}" + ) + elif response_type == "complete": + break return { "answer": "".join(answer_chunks), @@ -832,7 +844,7 @@ async def handle_list_tools() -> list[types.Tool]: "knowledge_base_ids": { "type": "array", "items": {"type": "string"}, - "description": "Knowledge base names OR UUIDs to search. REQUIRED for RAG. E.g. ['my-knowledge-base'] or ['a1b2c3d4-...']. Use list_knowledge_bases to find them.", + "description": "Knowledge base names OR UUIDs to search. Strongly recommended for RAG — without them the answer falls back to LLM knowledge only. E.g. ['my-knowledge-base'] or ['a1b2c3d4-...']. Use list_knowledge_bases to find them.", }, "web_search_enabled": {"type": "boolean", "description": "Enable web search alongside KB retrieval.", "default": False}, "enable_memory": {"type": "boolean", "description": "Enable cross-session memory.", "default": False}, @@ -1115,7 +1127,7 @@ async def handle_call_tool( # Strategy includes: max conversation rounds, query rewriting, summarization model, # fallback response handling, and retrieval thresholds (keyword/vector similarity). result = client.create_session( - kb_id=args["kb_id"], + kb_id=client.resolve_kb_id(args["kb_id"]), max_rounds=args.get("max_rounds", 5), enable_rewrite=args.get("enable_rewrite", True), fallback_response=args.get( @@ -1149,7 +1161,8 @@ async def handle_call_tool( web_search_enabled=args.get("web_search_enabled", False), enable_memory=args.get("enable_memory", False), ) - result = await asyncio.get_event_loop().run_in_executor(None, fn) + # get_running_loop() is the correct API inside async functions (get_event_loop() is deprecated) + result = await asyncio.get_running_loop().run_in_executor(None, fn) elif name == "agent_chat": # Autonomous agent tool-calling: agent decides which tools to invoke (knowledge_search, web_search, etc.) @@ -1199,7 +1212,7 @@ async def handle_call_tool( web_search_enabled=args.get("web_search_enabled", False), enable_memory=args.get("enable_memory", False), ) - result = await asyncio.get_event_loop().run_in_executor(None, fn) + result = await asyncio.get_running_loop().run_in_executor(None, fn) elif name == "list_agents": result = client.list_agents( @@ -1273,8 +1286,7 @@ async def run_sse(host: str, port: int): try: from mcp.server.sse import SseServerTransport from starlette.applications import Starlette - from starlette.requests import Request - from starlette.routing import Mount, Route + from starlette.routing import Mount import uvicorn except ImportError as e: raise ImportError( @@ -1283,15 +1295,15 @@ async def run_sse(host: str, port: int): sse = SseServerTransport("/messages/") - async def handle_sse(request: Request): - async with sse.connect_sse( - request.scope, request.receive, request._send - ) as streams: + # Use a raw ASGI callable instead of a Starlette Request endpoint to avoid + # accessing Starlette's private _send attribute (which can break across versions). + async def handle_sse(scope, receive, send): + async with sse.connect_sse(scope, receive, send) as streams: await app.run(streams[0], streams[1], _init_options()) starlette_app = Starlette( routes=[ - Route("/sse", endpoint=handle_sse), + Mount("/sse", app=handle_sse), Mount("/messages/", app=sse.handle_post_message), ] )