mirror of
https://github.com/Tencent/WeKnora.git
synced 2026-06-04 13:30:32 +08:00
fix(mcp-server): address Copilot PR review comments
- SSL verification now defaults to enabled; set WEKNORA_VERIFY_SSL=false to opt out (with a logged warning). Fixes MITM risk from default-off TLS. - WEKNORA_CHAT_TIMEOUT parse is now guarded with try/except ValueError so a bad env value falls back to 300s instead of crashing at import. - SSE streaming response is now closed via context manager (with response:) to guarantee connection pool return even on early break. - Replace asyncio.get_event_loop() (deprecated) with asyncio.get_running_loop() in both chat and agent_chat handlers. - create_session now calls resolve_kb_id() so KB names are accepted in addition to UUIDs (consistent with chat / hybrid_search). - knowledge_base_ids description changed from REQUIRED to Strongly recommended to match actual schema optionality. - run_sse() handle_sse rewritten as raw ASGI callable (scope, receive, send) to avoid accessing Starlette private _send attribute. - Fix main.py comment: http transport is Streamable HTTP (MCP spec), not long-polling.
This commit is contained in:
@@ -138,7 +138,7 @@ async def main():
|
||||
# Select transport mode based on CLI argument or MCP_TRANSPORT env var
|
||||
# - stdio: Default, used by VS Code Copilot for local integration
|
||||
# - sse: Server-Sent Events over HTTP, suitable for cloud/remote deployments
|
||||
# - http: HTTP long-polling, compatible with various client architectures
|
||||
# - http: Streamable HTTP sessions (MCP 2025-03-26 spec), compatible with REST clients
|
||||
if args.transport == "stdio":
|
||||
# Stdio mode: communication via stdin/stdout pipes (typical for CLI integrations)
|
||||
await run_stdio()
|
||||
|
||||
@@ -30,7 +30,11 @@ logger = logging.getLogger(__name__)
|
||||
WEKNORA_BASE_URL = os.getenv("WEKNORA_BASE_URL", "http://localhost:8080/api/v1")
|
||||
WEKNORA_API_KEY = os.getenv("WEKNORA_API_KEY", "")
|
||||
# Chat SSE read timeout in seconds. LLM responses can be slow; default 300s.
|
||||
try:
|
||||
WEKNORA_CHAT_TIMEOUT = int(os.getenv("WEKNORA_CHAT_TIMEOUT", "300"))
|
||||
except ValueError:
|
||||
logger.warning("WEKNORA_CHAT_TIMEOUT is not a valid integer; falling back to 300s.")
|
||||
WEKNORA_CHAT_TIMEOUT = 300
|
||||
|
||||
|
||||
class WeKnoraClient:
|
||||
@@ -40,9 +44,14 @@ class WeKnoraClient:
|
||||
"""Initialize the WeKnora API client with base URL and authentication"""
|
||||
self.base_url = base_url
|
||||
self.api_key = api_key
|
||||
# SSL verification: set WEKNORA_VERIFY_SSL=true to enable (default off for self-signed certs)
|
||||
self.verify_ssl = os.getenv("WEKNORA_VERIFY_SSL", "false").lower() == "true"
|
||||
# SSL verification: enabled by default. Set WEKNORA_VERIFY_SSL=false to disable
|
||||
# (e.g. for self-signed certs in dev environments — NOT recommended for production).
|
||||
self.verify_ssl = os.getenv("WEKNORA_VERIFY_SSL", "true").lower() != "false"
|
||||
if not self.verify_ssl:
|
||||
logger.warning(
|
||||
"SSL certificate verification is DISABLED (WEKNORA_VERIFY_SSL=false). "
|
||||
"This is insecure and should not be used in production."
|
||||
)
|
||||
urllib3.disable_warnings(urllib3.exceptions.InsecureRequestWarning)
|
||||
# Create a persistent session for connection pooling and performance
|
||||
self.session = requests.Session()
|
||||
@@ -337,6 +346,9 @@ class WeKnoraClient:
|
||||
references: list = []
|
||||
debug_events: list = []
|
||||
|
||||
# Use context manager to ensure the connection is returned to the pool
|
||||
# even when breaking early on a 'complete' event.
|
||||
with response:
|
||||
for raw_line in response.iter_lines():
|
||||
if not raw_line:
|
||||
continue
|
||||
@@ -832,7 +844,7 @@ async def handle_list_tools() -> list[types.Tool]:
|
||||
"knowledge_base_ids": {
|
||||
"type": "array",
|
||||
"items": {"type": "string"},
|
||||
"description": "Knowledge base names OR UUIDs to search. REQUIRED for RAG. E.g. ['my-knowledge-base'] or ['a1b2c3d4-...']. Use list_knowledge_bases to find them.",
|
||||
"description": "Knowledge base names OR UUIDs to search. Strongly recommended for RAG — without them the answer falls back to LLM knowledge only. E.g. ['my-knowledge-base'] or ['a1b2c3d4-...']. Use list_knowledge_bases to find them.",
|
||||
},
|
||||
"web_search_enabled": {"type": "boolean", "description": "Enable web search alongside KB retrieval.", "default": False},
|
||||
"enable_memory": {"type": "boolean", "description": "Enable cross-session memory.", "default": False},
|
||||
@@ -1115,7 +1127,7 @@ async def handle_call_tool(
|
||||
# Strategy includes: max conversation rounds, query rewriting, summarization model,
|
||||
# fallback response handling, and retrieval thresholds (keyword/vector similarity).
|
||||
result = client.create_session(
|
||||
kb_id=args["kb_id"],
|
||||
kb_id=client.resolve_kb_id(args["kb_id"]),
|
||||
max_rounds=args.get("max_rounds", 5),
|
||||
enable_rewrite=args.get("enable_rewrite", True),
|
||||
fallback_response=args.get(
|
||||
@@ -1149,7 +1161,8 @@ async def handle_call_tool(
|
||||
web_search_enabled=args.get("web_search_enabled", False),
|
||||
enable_memory=args.get("enable_memory", False),
|
||||
)
|
||||
result = await asyncio.get_event_loop().run_in_executor(None, fn)
|
||||
# get_running_loop() is the correct API inside async functions (get_event_loop() is deprecated)
|
||||
result = await asyncio.get_running_loop().run_in_executor(None, fn)
|
||||
|
||||
elif name == "agent_chat":
|
||||
# Autonomous agent tool-calling: agent decides which tools to invoke (knowledge_search, web_search, etc.)
|
||||
@@ -1199,7 +1212,7 @@ async def handle_call_tool(
|
||||
web_search_enabled=args.get("web_search_enabled", False),
|
||||
enable_memory=args.get("enable_memory", False),
|
||||
)
|
||||
result = await asyncio.get_event_loop().run_in_executor(None, fn)
|
||||
result = await asyncio.get_running_loop().run_in_executor(None, fn)
|
||||
|
||||
elif name == "list_agents":
|
||||
result = client.list_agents(
|
||||
@@ -1273,8 +1286,7 @@ async def run_sse(host: str, port: int):
|
||||
try:
|
||||
from mcp.server.sse import SseServerTransport
|
||||
from starlette.applications import Starlette
|
||||
from starlette.requests import Request
|
||||
from starlette.routing import Mount, Route
|
||||
from starlette.routing import Mount
|
||||
import uvicorn
|
||||
except ImportError as e:
|
||||
raise ImportError(
|
||||
@@ -1283,15 +1295,15 @@ async def run_sse(host: str, port: int):
|
||||
|
||||
sse = SseServerTransport("/messages/")
|
||||
|
||||
async def handle_sse(request: Request):
|
||||
async with sse.connect_sse(
|
||||
request.scope, request.receive, request._send
|
||||
) as streams:
|
||||
# Use a raw ASGI callable instead of a Starlette Request endpoint to avoid
|
||||
# accessing Starlette's private _send attribute (which can break across versions).
|
||||
async def handle_sse(scope, receive, send):
|
||||
async with sse.connect_sse(scope, receive, send) as streams:
|
||||
await app.run(streams[0], streams[1], _init_options())
|
||||
|
||||
starlette_app = Starlette(
|
||||
routes=[
|
||||
Route("/sse", endpoint=handle_sse),
|
||||
Mount("/sse", app=handle_sse),
|
||||
Mount("/messages/", app=sse.handle_post_message),
|
||||
]
|
||||
)
|
||||
|
||||
Reference in New Issue
Block a user