mirror of
https://github.com/Tencent/WeKnora.git
synced 2026-06-04 13:30:32 +08:00
fix(mcp-server): address Copilot PR review comments
- SSL verification now defaults to enabled; set WEKNORA_VERIFY_SSL=false to opt out (with a logged warning). Fixes MITM risk from default-off TLS verification.
- WEKNORA_CHAT_TIMEOUT parse is now guarded with try/except ValueError so a bad env value falls back to 300s instead of crashing at import.
- SSE streaming response is now closed via context manager (with response:) to guarantee connection pool return even on early break.
- Replace asyncio.get_event_loop() (deprecated) with asyncio.get_running_loop() in both chat and agent_chat handlers.
- create_session now calls resolve_kb_id() so KB names are accepted in addition to UUIDs (consistent with chat / hybrid_search).
- knowledge_base_ids description changed from REQUIRED to Strongly recommended to match actual schema optionality.
- run_sse() handle_sse rewritten as raw ASGI callable (scope, receive, send) to avoid accessing Starlette private _send attribute.
- Fix main.py comment: http transport is Streamable HTTP (MCP spec), not long-polling.
This commit is contained in:
@@ -138,7 +138,7 @@ async def main():
|
|||||||
# Select transport mode based on CLI argument or MCP_TRANSPORT env var
|
# Select transport mode based on CLI argument or MCP_TRANSPORT env var
|
||||||
# - stdio: Default, used by VS Code Copilot for local integration
|
# - stdio: Default, used by VS Code Copilot for local integration
|
||||||
# - sse: Server-Sent Events over HTTP, suitable for cloud/remote deployments
|
# - sse: Server-Sent Events over HTTP, suitable for cloud/remote deployments
|
||||||
# - http: HTTP long-polling, compatible with various client architectures
|
# - http: Streamable HTTP sessions (MCP 2025-03-26 spec), compatible with REST clients
|
||||||
if args.transport == "stdio":
|
if args.transport == "stdio":
|
||||||
# Stdio mode: communication via stdin/stdout pipes (typical for CLI integrations)
|
# Stdio mode: communication via stdin/stdout pipes (typical for CLI integrations)
|
||||||
await run_stdio()
|
await run_stdio()
|
||||||
|
|||||||
@@ -30,7 +30,11 @@ logger = logging.getLogger(__name__)
|
|||||||
WEKNORA_BASE_URL = os.getenv("WEKNORA_BASE_URL", "http://localhost:8080/api/v1")
|
WEKNORA_BASE_URL = os.getenv("WEKNORA_BASE_URL", "http://localhost:8080/api/v1")
|
||||||
WEKNORA_API_KEY = os.getenv("WEKNORA_API_KEY", "")
|
WEKNORA_API_KEY = os.getenv("WEKNORA_API_KEY", "")
|
||||||
# Chat SSE read timeout in seconds. LLM responses can be slow; default 300s.
|
# Chat SSE read timeout in seconds. LLM responses can be slow; default 300s.
|
||||||
WEKNORA_CHAT_TIMEOUT = int(os.getenv("WEKNORA_CHAT_TIMEOUT", "300"))
|
try:
|
||||||
|
WEKNORA_CHAT_TIMEOUT = int(os.getenv("WEKNORA_CHAT_TIMEOUT", "300"))
|
||||||
|
except ValueError:
|
||||||
|
logger.warning("WEKNORA_CHAT_TIMEOUT is not a valid integer; falling back to 300s.")
|
||||||
|
WEKNORA_CHAT_TIMEOUT = 300
|
||||||
|
|
||||||
|
|
||||||
class WeKnoraClient:
|
class WeKnoraClient:
|
||||||
@@ -40,9 +44,14 @@ class WeKnoraClient:
|
|||||||
"""Initialize the WeKnora API client with base URL and authentication"""
|
"""Initialize the WeKnora API client with base URL and authentication"""
|
||||||
self.base_url = base_url
|
self.base_url = base_url
|
||||||
self.api_key = api_key
|
self.api_key = api_key
|
||||||
# SSL verification: set WEKNORA_VERIFY_SSL=true to enable (default off for self-signed certs)
|
# SSL verification: enabled by default. Set WEKNORA_VERIFY_SSL=false to disable
|
||||||
self.verify_ssl = os.getenv("WEKNORA_VERIFY_SSL", "false").lower() == "true"
|
# (e.g. for self-signed certs in dev environments — NOT recommended for production).
|
||||||
|
self.verify_ssl = os.getenv("WEKNORA_VERIFY_SSL", "true").lower() != "false"
|
||||||
if not self.verify_ssl:
|
if not self.verify_ssl:
|
||||||
|
logger.warning(
|
||||||
|
"SSL certificate verification is DISABLED (WEKNORA_VERIFY_SSL=false). "
|
||||||
|
"This is insecure and should not be used in production."
|
||||||
|
)
|
||||||
urllib3.disable_warnings(urllib3.exceptions.InsecureRequestWarning)
|
urllib3.disable_warnings(urllib3.exceptions.InsecureRequestWarning)
|
||||||
# Create a persistent session for connection pooling and performance
|
# Create a persistent session for connection pooling and performance
|
||||||
self.session = requests.Session()
|
self.session = requests.Session()
|
||||||
@@ -337,36 +346,39 @@ class WeKnoraClient:
|
|||||||
references: list = []
|
references: list = []
|
||||||
debug_events: list = []
|
debug_events: list = []
|
||||||
|
|
||||||
for raw_line in response.iter_lines():
|
# Use context manager to ensure the connection is returned to the pool
|
||||||
if not raw_line:
|
# even when breaking early on a 'complete' event.
|
||||||
continue
|
with response:
|
||||||
if isinstance(raw_line, bytes):
|
for raw_line in response.iter_lines():
|
||||||
raw_line = raw_line.decode("utf-8")
|
if not raw_line:
|
||||||
# Each SSE event is prefixed with "data: " followed by JSON payload
|
continue
|
||||||
if not raw_line.startswith("data:"):
|
if isinstance(raw_line, bytes):
|
||||||
continue
|
raw_line = raw_line.decode("utf-8")
|
||||||
payload = raw_line[5:].lstrip(" ")
|
# Each SSE event is prefixed with "data: " followed by JSON payload
|
||||||
try:
|
if not raw_line.startswith("data:"):
|
||||||
event_data = json.loads(payload)
|
continue
|
||||||
except json.JSONDecodeError:
|
payload = raw_line[5:].lstrip(" ")
|
||||||
continue
|
try:
|
||||||
|
event_data = json.loads(payload)
|
||||||
|
except json.JSONDecodeError:
|
||||||
|
continue
|
||||||
|
|
||||||
response_type = event_data.get("response_type", "")
|
response_type = event_data.get("response_type", "")
|
||||||
debug_events.append({"type": response_type, "content": event_data.get("content", "")[:80]})
|
debug_events.append({"type": response_type, "content": event_data.get("content", "")[:80]})
|
||||||
|
|
||||||
# Parse different SSE event types: answer chunks, references, errors, completion
|
# Parse different SSE event types: answer chunks, references, errors, completion
|
||||||
if response_type == "answer":
|
if response_type == "answer":
|
||||||
chunk = event_data.get("content", "")
|
chunk = event_data.get("content", "")
|
||||||
if chunk:
|
if chunk:
|
||||||
answer_chunks.append(chunk)
|
answer_chunks.append(chunk)
|
||||||
elif response_type == "references":
|
elif response_type == "references":
|
||||||
references = event_data.get("knowledge_references") or []
|
references = event_data.get("knowledge_references") or []
|
||||||
elif response_type == "error":
|
elif response_type == "error":
|
||||||
raise RequestException(
|
raise RequestException(
|
||||||
f"Server error: {event_data.get('content', 'unknown error')}"
|
f"Server error: {event_data.get('content', 'unknown error')}"
|
||||||
)
|
)
|
||||||
elif response_type == "complete":
|
elif response_type == "complete":
|
||||||
break
|
break
|
||||||
|
|
||||||
return {
|
return {
|
||||||
"answer": "".join(answer_chunks),
|
"answer": "".join(answer_chunks),
|
||||||
@@ -832,7 +844,7 @@ async def handle_list_tools() -> list[types.Tool]:
|
|||||||
"knowledge_base_ids": {
|
"knowledge_base_ids": {
|
||||||
"type": "array",
|
"type": "array",
|
||||||
"items": {"type": "string"},
|
"items": {"type": "string"},
|
||||||
"description": "Knowledge base names OR UUIDs to search. REQUIRED for RAG. E.g. ['my-knowledge-base'] or ['a1b2c3d4-...']. Use list_knowledge_bases to find them.",
|
"description": "Knowledge base names OR UUIDs to search. Strongly recommended for RAG — without them the answer falls back to LLM knowledge only. E.g. ['my-knowledge-base'] or ['a1b2c3d4-...']. Use list_knowledge_bases to find them.",
|
||||||
},
|
},
|
||||||
"web_search_enabled": {"type": "boolean", "description": "Enable web search alongside KB retrieval.", "default": False},
|
"web_search_enabled": {"type": "boolean", "description": "Enable web search alongside KB retrieval.", "default": False},
|
||||||
"enable_memory": {"type": "boolean", "description": "Enable cross-session memory.", "default": False},
|
"enable_memory": {"type": "boolean", "description": "Enable cross-session memory.", "default": False},
|
||||||
@@ -1115,7 +1127,7 @@ async def handle_call_tool(
|
|||||||
# Strategy includes: max conversation rounds, query rewriting, summarization model,
|
# Strategy includes: max conversation rounds, query rewriting, summarization model,
|
||||||
# fallback response handling, and retrieval thresholds (keyword/vector similarity).
|
# fallback response handling, and retrieval thresholds (keyword/vector similarity).
|
||||||
result = client.create_session(
|
result = client.create_session(
|
||||||
kb_id=args["kb_id"],
|
kb_id=client.resolve_kb_id(args["kb_id"]),
|
||||||
max_rounds=args.get("max_rounds", 5),
|
max_rounds=args.get("max_rounds", 5),
|
||||||
enable_rewrite=args.get("enable_rewrite", True),
|
enable_rewrite=args.get("enable_rewrite", True),
|
||||||
fallback_response=args.get(
|
fallback_response=args.get(
|
||||||
@@ -1149,7 +1161,8 @@ async def handle_call_tool(
|
|||||||
web_search_enabled=args.get("web_search_enabled", False),
|
web_search_enabled=args.get("web_search_enabled", False),
|
||||||
enable_memory=args.get("enable_memory", False),
|
enable_memory=args.get("enable_memory", False),
|
||||||
)
|
)
|
||||||
result = await asyncio.get_event_loop().run_in_executor(None, fn)
|
# get_running_loop() is the correct API inside async functions (get_event_loop() is deprecated)
|
||||||
|
result = await asyncio.get_running_loop().run_in_executor(None, fn)
|
||||||
|
|
||||||
elif name == "agent_chat":
|
elif name == "agent_chat":
|
||||||
# Autonomous agent tool-calling: agent decides which tools to invoke (knowledge_search, web_search, etc.)
|
# Autonomous agent tool-calling: agent decides which tools to invoke (knowledge_search, web_search, etc.)
|
||||||
@@ -1199,7 +1212,7 @@ async def handle_call_tool(
|
|||||||
web_search_enabled=args.get("web_search_enabled", False),
|
web_search_enabled=args.get("web_search_enabled", False),
|
||||||
enable_memory=args.get("enable_memory", False),
|
enable_memory=args.get("enable_memory", False),
|
||||||
)
|
)
|
||||||
result = await asyncio.get_event_loop().run_in_executor(None, fn)
|
result = await asyncio.get_running_loop().run_in_executor(None, fn)
|
||||||
|
|
||||||
elif name == "list_agents":
|
elif name == "list_agents":
|
||||||
result = client.list_agents(
|
result = client.list_agents(
|
||||||
@@ -1273,8 +1286,7 @@ async def run_sse(host: str, port: int):
|
|||||||
try:
|
try:
|
||||||
from mcp.server.sse import SseServerTransport
|
from mcp.server.sse import SseServerTransport
|
||||||
from starlette.applications import Starlette
|
from starlette.applications import Starlette
|
||||||
from starlette.requests import Request
|
from starlette.routing import Mount
|
||||||
from starlette.routing import Mount, Route
|
|
||||||
import uvicorn
|
import uvicorn
|
||||||
except ImportError as e:
|
except ImportError as e:
|
||||||
raise ImportError(
|
raise ImportError(
|
||||||
@@ -1283,15 +1295,15 @@ async def run_sse(host: str, port: int):
|
|||||||
|
|
||||||
sse = SseServerTransport("/messages/")
|
sse = SseServerTransport("/messages/")
|
||||||
|
|
||||||
async def handle_sse(request: Request):
|
# Use a raw ASGI callable instead of a Starlette Request endpoint to avoid
|
||||||
async with sse.connect_sse(
|
# accessing Starlette's private _send attribute (which can break across versions).
|
||||||
request.scope, request.receive, request._send
|
async def handle_sse(scope, receive, send):
|
||||||
) as streams:
|
async with sse.connect_sse(scope, receive, send) as streams:
|
||||||
await app.run(streams[0], streams[1], _init_options())
|
await app.run(streams[0], streams[1], _init_options())
|
||||||
|
|
||||||
starlette_app = Starlette(
|
starlette_app = Starlette(
|
||||||
routes=[
|
routes=[
|
||||||
Route("/sse", endpoint=handle_sse),
|
Mount("/sse", app=handle_sse),
|
||||||
Mount("/messages/", app=sse.handle_post_message),
|
Mount("/messages/", app=sse.handle_post_message),
|
||||||
]
|
]
|
||||||
)
|
)
|
||||||
|
|||||||
Reference in New Issue
Block a user