mirror of
https://github.com/Tencent/WeKnora.git
synced 2026-06-04 13:30:32 +08:00
feat: add WeCom and Feishu IM bot integration
- support webhook and websocket modes for both platforms - add im_channel_sessions migration for channel-session mapping - register IM adapters and callback routes - update config and docker-compose for IM env vars
This commit is contained in:
@@ -206,7 +206,12 @@ conversation:
|
||||
- If the user asks in English, respond in English
|
||||
- If the user asks in Chinese, respond in Chinese
|
||||
|
||||
context_template: "{{query}}"
|
||||
context_template: |
|
||||
The following is retrieved information:
|
||||
{{contexts}}
|
||||
|
||||
Based on the above retrieved information, answer the following question:
|
||||
{{query}}
|
||||
extract_entities_prompt: |
|
||||
## Task
|
||||
Extract all entities from the user-provided text that match the following entity types:
|
||||
@@ -536,3 +541,39 @@ extract:
|
||||
tenant:
|
||||
# Enable cross-tenant access (can be enabled for intranet environments)
|
||||
enable_cross_tenant_access: false
|
||||
|
||||
# IM integration configuration (optional)
|
||||
# Uncomment and configure to enable WeCom/Feishu bot integration
|
||||
#
|
||||
# mode:
|
||||
# "webhook" — (default) platform pushes events to your callback URL, requires public domain
|
||||
# "websocket" — long connection, no public domain needed, SDK maintains persistent connection
|
||||
im:
|
||||
wecom:
|
||||
enabled: true
|
||||
tenant_id: 10000 # WeKnora tenant ID to use
|
||||
agent_id: "" # Default agent ID (optional)
|
||||
knowledge_base_ids:
|
||||
- "" # Default knowledge bases (optional)
|
||||
mode: "websocket" # "webhook" or "websocket"
|
||||
# --- websocket mode (智能机器人长连接) ---
|
||||
bot_id: "${WECOM_BOT_ID}" # 智能机器人 Bot ID
|
||||
bot_secret: "${WECOM_BOT_SECRET}" # 智能机器人 Secret
|
||||
# --- webhook mode (自建应用回调) ---
|
||||
# corp_id: "${WECOM_CORP_ID}" # 企业微信 Corp ID
|
||||
# agent_secret: "${WECOM_AGENT_SECRET}" # 应用 Secret
|
||||
# token: "${WECOM_TOKEN}" # 回调 Token
|
||||
# encoding_aes_key: "${WECOM_AES_KEY}" # 回调 EncodingAESKey
|
||||
# corp_agent_id: 1000001 # 应用 AgentID
|
||||
feishu:
|
||||
enabled: true
|
||||
tenant_id: 10000 # WeKnora tenant ID to use
|
||||
agent_id: "" # Default agent ID (optional)
|
||||
knowledge_base_ids: # Default knowledge bases
|
||||
- ""
|
||||
mode: "websocket" # "webhook" or "websocket"
|
||||
app_id: "${FEISHU_APP_ID}" # 飞书 App ID
|
||||
app_secret: "${FEISHU_APP_SECRET}" # 飞书 App Secret
|
||||
# --- 以下仅 webhook 模式需要,websocket 模式可留空 ---
|
||||
# verification_token: "${FEISHU_TOKEN}" # 事件订阅 Verification Token (webhook only)
|
||||
# encrypt_key: "${FEISHU_ENCRYPT_KEY}" # 事件订阅 Encrypt Key (webhook only)
|
||||
|
||||
@@ -38,8 +38,8 @@ services:
|
||||
volumes:
|
||||
- data-files:/data/files
|
||||
- docreader-tmp:/tmp/docreader:ro
|
||||
# Optional: mount custom config file
|
||||
# - ./config/config.yaml:/app/config/config.yaml
|
||||
# Mount custom config file (required for IM integration)
|
||||
- ./config/config.yaml:/app/config/config.yaml
|
||||
# Optional: mount custom skills directory (allows adding skills without rebuilding image)
|
||||
- ./skills/preloaded:/app/skills/preloaded
|
||||
healthcheck:
|
||||
@@ -114,6 +114,11 @@ services:
|
||||
- SYSTEM_AES_KEY=${SYSTEM_AES_KEY:-}
|
||||
- CONCURRENCY_POOL_SIZE=${CONCURRENCY_POOL_SIZE:-5}
|
||||
- JWT_SECRET=${JWT_SECRET:-}
|
||||
# IM integration
|
||||
- FEISHU_APP_ID=${FEISHU_APP_ID:-}
|
||||
- FEISHU_APP_SECRET=${FEISHU_APP_SECRET:-}
|
||||
- WECOM_BOT_ID=${WECOM_BOT_ID:-}
|
||||
- WECOM_BOT_SECRET=${WECOM_BOT_SECRET:-}
|
||||
# File size limit (in MB)
|
||||
- MAX_FILE_SIZE_MB=${MAX_FILE_SIZE_MB:-50}
|
||||
# Agent Skills Sandbox
|
||||
|
||||
3
go.mod
3
go.mod
@@ -22,7 +22,9 @@ require (
|
||||
github.com/golang-migrate/migrate/v4 v4.19.1
|
||||
github.com/google/jsonschema-go v0.4.2
|
||||
github.com/google/uuid v1.6.0
|
||||
github.com/gorilla/websocket v1.5.0
|
||||
github.com/hibiken/asynq v0.25.1
|
||||
github.com/larksuite/oapi-sdk-go/v3 v3.5.3
|
||||
github.com/mark3labs/mcp-go v0.43.0
|
||||
github.com/milvus-io/milvus/client/v2 v2.6.2
|
||||
github.com/minio/minio-go/v7 v7.0.91
|
||||
@@ -167,7 +169,6 @@ require (
|
||||
github.com/google/s2a-go v0.1.9 // indirect
|
||||
github.com/googleapis/enterprise-certificate-proxy v0.3.7 // indirect
|
||||
github.com/googleapis/gax-go/v2 v2.16.0 // indirect
|
||||
github.com/gorilla/websocket v1.5.0 // indirect
|
||||
github.com/grpc-ecosystem/go-grpc-middleware v1.4.0 // indirect
|
||||
github.com/grpc-ecosystem/go-grpc-prometheus v1.2.0 // indirect
|
||||
github.com/grpc-ecosystem/grpc-gateway v1.16.0 // indirect
|
||||
|
||||
2
go.sum
2
go.sum
@@ -2085,6 +2085,8 @@ github.com/kylelemons/godebug v1.1.0 h1:RPNrshWIDI6G2gRW9EHilWtl7Z6Sb1BR0xunSBf0
|
||||
github.com/kylelemons/godebug v1.1.0/go.mod h1:9/0rRGxNHcop5bhtWyNeEfOS8JIWk580+fNqagV/RAw=
|
||||
github.com/labstack/echo/v4 v4.5.0/go.mod h1:czIriw4a0C1dFun+ObrXp7ok03xON0N1awStJ6ArI7Y=
|
||||
github.com/labstack/gommon v0.3.0/go.mod h1:MULnywXg0yavhxWKc+lOruYdAhDwPK9wf0OL7NoOu+k=
|
||||
github.com/larksuite/oapi-sdk-go/v3 v3.5.3 h1:xvf8Dv29kBXC5/DNDCLhHkAFW8l/0LlQJimO5Zn+JUk=
|
||||
github.com/larksuite/oapi-sdk-go/v3 v3.5.3/go.mod h1:ZEplY+kwuIrj/nqw5uSCINNATcH3KdxSN7y+UxYY5fI=
|
||||
github.com/ledongthuc/pdf v0.0.0-20220302134840-0c2507a12d80 h1:6Yzfa6GP0rIo/kULo2bwGEkFvCePZ3qHDDTC3/J9Swo=
|
||||
github.com/ledongthuc/pdf v0.0.0-20220302134840-0c2507a12d80/go.mod h1:imJHygn/1yfhB7XSJJKlFZKl/J+dCPAknuiaGOshXAs=
|
||||
github.com/leodido/go-urn v1.2.0/go.mod h1:+8+nEpDfqqsY+g338gtMEUOtuK+4dEMhiQEgxpxOKII=
|
||||
|
||||
@@ -27,6 +27,46 @@ type Config struct {
|
||||
ExtractManager *ExtractManagerConfig `yaml:"extract" json:"extract"`
|
||||
WebSearch *WebSearchConfig `yaml:"web_search" json:"web_search"`
|
||||
PromptTemplates *PromptTemplatesConfig `yaml:"prompt_templates" json:"prompt_templates"`
|
||||
IM *IMConfig `yaml:"im" json:"im"`
|
||||
}
|
||||
|
||||
// IMConfig IM 集成配置
|
||||
type IMConfig struct {
|
||||
WeCom *WeComIMConfig `yaml:"wecom" json:"wecom"`
|
||||
Feishu *FeishuIMConfig `yaml:"feishu" json:"feishu"`
|
||||
}
|
||||
|
||||
// WeComIMConfig 企业微信配置
|
||||
type WeComIMConfig struct {
|
||||
Enabled bool `yaml:"enabled" json:"enabled"`
|
||||
TenantID uint64 `yaml:"tenant_id" json:"tenant_id"`
|
||||
AgentID string `yaml:"agent_id" json:"agent_id"`
|
||||
KnowledgeBases []string `yaml:"knowledge_base_ids" json:"knowledge_base_ids"`
|
||||
// Mode: "webhook" (default, requires public domain) or "websocket" (long connection via intelligent bot, no public domain needed)
|
||||
Mode string `yaml:"mode" json:"mode"`
|
||||
// --- Webhook mode fields (self-built app callback) ---
|
||||
CorpID string `yaml:"corp_id" json:"corp_id"`
|
||||
AgentSecret string `yaml:"agent_secret" json:"agent_secret"`
|
||||
Token string `yaml:"token" json:"token"`
|
||||
EncodingAESKey string `yaml:"encoding_aes_key" json:"encoding_aes_key"`
|
||||
CorpAgentID int `yaml:"corp_agent_id" json:"corp_agent_id"`
|
||||
// --- WebSocket mode fields (intelligent bot long connection) ---
|
||||
BotID string `yaml:"bot_id" json:"bot_id"`
|
||||
BotSecret string `yaml:"bot_secret" json:"bot_secret"`
|
||||
}
|
||||
|
||||
// FeishuIMConfig 飞书配置
|
||||
type FeishuIMConfig struct {
|
||||
Enabled bool `yaml:"enabled" json:"enabled"`
|
||||
TenantID uint64 `yaml:"tenant_id" json:"tenant_id"`
|
||||
AgentID string `yaml:"agent_id" json:"agent_id"`
|
||||
KnowledgeBases []string `yaml:"knowledge_base_ids" json:"knowledge_base_ids"`
|
||||
AppID string `yaml:"app_id" json:"app_id"`
|
||||
AppSecret string `yaml:"app_secret" json:"app_secret"`
|
||||
VerificationToken string `yaml:"verification_token" json:"verification_token"`
|
||||
EncryptKey string `yaml:"encrypt_key" json:"encrypt_key"`
|
||||
// Mode: "websocket" (default, long connection, no public domain needed) or "webhook" (requires public domain)
|
||||
Mode string `yaml:"mode" json:"mode"`
|
||||
}
|
||||
|
||||
// DocReaderConfig configures the document parser client (gRPC or HTTP).
|
||||
|
||||
@@ -52,6 +52,9 @@ import (
|
||||
"github.com/Tencent/WeKnora/internal/event"
|
||||
"github.com/Tencent/WeKnora/internal/handler"
|
||||
"github.com/Tencent/WeKnora/internal/handler/session"
|
||||
imPkg "github.com/Tencent/WeKnora/internal/im"
|
||||
"github.com/Tencent/WeKnora/internal/im/feishu"
|
||||
"github.com/Tencent/WeKnora/internal/im/wecom"
|
||||
"github.com/Tencent/WeKnora/internal/infrastructure/docparser"
|
||||
"github.com/Tencent/WeKnora/internal/logger"
|
||||
"github.com/Tencent/WeKnora/internal/mcp"
|
||||
@@ -235,6 +238,12 @@ func BuildContainer(container *dig.Container) *dig.Container {
|
||||
must(container.Provide(service.NewSkillService))
|
||||
must(container.Provide(handler.NewSkillHandler))
|
||||
must(container.Provide(handler.NewOrganizationHandler))
|
||||
|
||||
// IM integration
|
||||
logger.Debugf(ctx, "[Container] Registering IM integration...")
|
||||
must(container.Provide(imPkg.NewService))
|
||||
must(container.Invoke(registerIMAdapters))
|
||||
must(container.Provide(handler.NewIMHandler))
|
||||
logger.Debugf(ctx, "[Container] HTTP handlers registered")
|
||||
|
||||
// Router configuration
|
||||
@@ -937,3 +946,115 @@ func registerWebSearchProviders(registry *web_search.Registry) {
|
||||
return web_search.NewBingProvider()
|
||||
})
|
||||
}
|
||||
|
||||
// registerIMAdapters registers IM platform adapters based on configuration.
|
||||
// For "websocket" mode, it also starts a long connection client in a goroutine.
|
||||
func registerIMAdapters(cfg *config.Config, imService *imPkg.Service) {
|
||||
if cfg.IM == nil {
|
||||
logger.Infof(context.Background(), "[IM] No IM configuration found, skipping adapter registration")
|
||||
return
|
||||
}
|
||||
|
||||
ctx := context.Background()
|
||||
|
||||
// Register WeCom
|
||||
if cfg.IM.WeCom != nil && cfg.IM.WeCom.Enabled {
|
||||
registerWeComAdapter(ctx, cfg.IM.WeCom, imService)
|
||||
}
|
||||
|
||||
// Register Feishu
|
||||
if cfg.IM.Feishu != nil && cfg.IM.Feishu.Enabled {
|
||||
registerFeishuAdapter(ctx, cfg.IM.Feishu, imService)
|
||||
}
|
||||
}
|
||||
|
||||
func registerWeComAdapter(ctx context.Context, cfg *config.WeComIMConfig, imService *imPkg.Service) {
|
||||
mode := cfg.Mode
|
||||
if mode == "" {
|
||||
mode = "websocket"
|
||||
}
|
||||
|
||||
switch mode {
|
||||
case "webhook":
|
||||
adapter, err := wecom.NewAdapter(
|
||||
cfg.CorpID,
|
||||
cfg.AgentSecret,
|
||||
cfg.Token,
|
||||
cfg.EncodingAESKey,
|
||||
cfg.CorpAgentID,
|
||||
)
|
||||
if err != nil {
|
||||
logger.Warnf(ctx, "[IM] Failed to create WeCom webhook adapter: %v", err)
|
||||
return
|
||||
}
|
||||
imService.RegisterAdapter(adapter)
|
||||
logger.Infof(ctx, "[IM] WeCom adapter registered (mode=webhook, corp_id=%s)", cfg.CorpID)
|
||||
|
||||
case "websocket":
|
||||
// Build the message handler that delegates to imService.HandleMessage
|
||||
handler := func(msgCtx context.Context, msg *imPkg.IncomingMessage) error {
|
||||
return imService.HandleMessage(msgCtx, msg, cfg.TenantID, cfg.AgentID, cfg.KnowledgeBases)
|
||||
}
|
||||
|
||||
client := wecom.NewLongConnClient(cfg.BotID, cfg.BotSecret, handler)
|
||||
|
||||
// Register a BotAdapter so the service can send replies via WebSocket
|
||||
imService.RegisterAdapter(wecom.NewBotAdapter(client))
|
||||
logger.Infof(ctx, "[IM] WeCom adapter registered (mode=websocket, bot_id=%s)", cfg.BotID)
|
||||
|
||||
// Start the long connection in a goroutine
|
||||
go func() {
|
||||
if err := client.Start(context.Background()); err != nil {
|
||||
logger.Errorf(context.Background(), "[IM] WeCom long connection stopped: %v", err)
|
||||
}
|
||||
}()
|
||||
|
||||
default:
|
||||
logger.Warnf(ctx, "[IM] Unknown WeCom mode: %s (expected 'webhook' or 'websocket')", mode)
|
||||
}
|
||||
}
|
||||
|
||||
func registerFeishuAdapter(ctx context.Context, cfg *config.FeishuIMConfig, imService *imPkg.Service) {
|
||||
mode := cfg.Mode
|
||||
if mode == "" {
|
||||
mode = "websocket"
|
||||
}
|
||||
|
||||
// Always register the HTTP adapter (needed for SendReply in both modes)
|
||||
adapter := feishu.NewAdapter(
|
||||
cfg.AppID,
|
||||
cfg.AppSecret,
|
||||
cfg.VerificationToken,
|
||||
cfg.EncryptKey,
|
||||
)
|
||||
imService.RegisterAdapter(adapter)
|
||||
|
||||
switch mode {
|
||||
case "webhook":
|
||||
logger.Infof(ctx, "[IM] Feishu adapter registered (mode=webhook, app_id=%s)", cfg.AppID)
|
||||
|
||||
case "websocket":
|
||||
logger.Infof(ctx, "[IM] Feishu adapter registered (mode=websocket, app_id=%s)", cfg.AppID)
|
||||
|
||||
// Build the message handler
|
||||
handler := func(msgCtx context.Context, msg *imPkg.IncomingMessage) error {
|
||||
return imService.HandleMessage(msgCtx, msg, cfg.TenantID, cfg.AgentID, cfg.KnowledgeBases)
|
||||
}
|
||||
|
||||
client := feishu.NewLongConnClient(
|
||||
cfg.AppID,
|
||||
cfg.AppSecret,
|
||||
handler,
|
||||
)
|
||||
|
||||
// Start the long connection in a goroutine
|
||||
go func() {
|
||||
if err := client.Start(context.Background()); err != nil {
|
||||
logger.Errorf(context.Background(), "[IM] Feishu long connection stopped: %v", err)
|
||||
}
|
||||
}()
|
||||
|
||||
default:
|
||||
logger.Warnf(ctx, "[IM] Unknown Feishu mode: %s (expected 'webhook' or 'websocket')", mode)
|
||||
}
|
||||
}
|
||||
|
||||
132
internal/handler/im.go
Normal file
132
internal/handler/im.go
Normal file
@@ -0,0 +1,132 @@
|
||||
package handler
|
||||
|
||||
import (
|
||||
"context"
|
||||
"net/http"
|
||||
|
||||
"github.com/Tencent/WeKnora/internal/config"
|
||||
"github.com/Tencent/WeKnora/internal/im"
|
||||
"github.com/Tencent/WeKnora/internal/logger"
|
||||
"github.com/gin-gonic/gin"
|
||||
)
|
||||
|
||||
// IMHandler handles IM platform callback requests.
|
||||
type IMHandler struct {
|
||||
imService *im.Service
|
||||
config *config.Config
|
||||
}
|
||||
|
||||
// NewIMHandler creates a new IM handler.
|
||||
func NewIMHandler(imService *im.Service, cfg *config.Config) *IMHandler {
|
||||
return &IMHandler{
|
||||
imService: imService,
|
||||
config: cfg,
|
||||
}
|
||||
}
|
||||
|
||||
// WeComCallback handles WeCom callback requests (both URL verification and message events).
|
||||
func (h *IMHandler) WeComCallback(c *gin.Context) {
|
||||
ctx := c.Request.Context()
|
||||
|
||||
adapter, ok := h.imService.GetAdapter(im.PlatformWeCom)
|
||||
if !ok {
|
||||
logger.Error(ctx, "[IM] WeCom adapter not registered")
|
||||
c.JSON(http.StatusServiceUnavailable, gin.H{"error": "WeCom integration not enabled"})
|
||||
return
|
||||
}
|
||||
|
||||
// Handle URL verification (GET request)
|
||||
if adapter.HandleURLVerification(c) {
|
||||
return
|
||||
}
|
||||
|
||||
// Verify callback signature
|
||||
if err := adapter.VerifyCallback(c); err != nil {
|
||||
logger.Errorf(ctx, "[IM] WeCom callback verification failed: %v", err)
|
||||
c.JSON(http.StatusForbidden, gin.H{"error": "verification failed"})
|
||||
return
|
||||
}
|
||||
|
||||
// Parse the callback message
|
||||
msg, err := adapter.ParseCallback(c)
|
||||
if err != nil {
|
||||
logger.Errorf(ctx, "[IM] WeCom parse callback failed: %v", err)
|
||||
c.JSON(http.StatusBadRequest, gin.H{"error": "parse failed"})
|
||||
return
|
||||
}
|
||||
|
||||
// If nil, it's a non-message event (e.g., system event) - just acknowledge
|
||||
if msg == nil {
|
||||
c.JSON(http.StatusOK, gin.H{"success": true})
|
||||
return
|
||||
}
|
||||
|
||||
// Respond immediately to avoid WeCom timeout, process asynchronously
|
||||
c.JSON(http.StatusOK, gin.H{"success": true})
|
||||
|
||||
// Get config for this platform
|
||||
wecomCfg := h.config.IM.WeCom
|
||||
|
||||
// Detach from gin request context to prevent cancellation after HTTP response.
|
||||
asyncCtx := context.WithoutCancel(ctx)
|
||||
|
||||
// Process message asynchronously
|
||||
go func() {
|
||||
if err := h.imService.HandleMessage(asyncCtx, msg, wecomCfg.TenantID, wecomCfg.AgentID, wecomCfg.KnowledgeBases); err != nil {
|
||||
logger.Errorf(asyncCtx, "[IM] WeCom handle message error: %v", err)
|
||||
}
|
||||
}()
|
||||
}
|
||||
|
||||
// FeishuCallback handles Feishu callback requests (both URL verification and message events).
|
||||
func (h *IMHandler) FeishuCallback(c *gin.Context) {
|
||||
ctx := c.Request.Context()
|
||||
|
||||
adapter, ok := h.imService.GetAdapter(im.PlatformFeishu)
|
||||
if !ok {
|
||||
logger.Error(ctx, "[IM] Feishu adapter not registered")
|
||||
c.JSON(http.StatusServiceUnavailable, gin.H{"error": "Feishu integration not enabled"})
|
||||
return
|
||||
}
|
||||
|
||||
// Handle URL verification (challenge)
|
||||
if adapter.HandleURLVerification(c) {
|
||||
return
|
||||
}
|
||||
|
||||
// Verify callback
|
||||
if err := adapter.VerifyCallback(c); err != nil {
|
||||
logger.Errorf(ctx, "[IM] Feishu callback verification failed: %v", err)
|
||||
c.JSON(http.StatusForbidden, gin.H{"error": "verification failed"})
|
||||
return
|
||||
}
|
||||
|
||||
// Parse the callback message
|
||||
msg, err := adapter.ParseCallback(c)
|
||||
if err != nil {
|
||||
logger.Errorf(ctx, "[IM] Feishu parse callback failed: %v", err)
|
||||
c.JSON(http.StatusBadRequest, gin.H{"error": "parse failed"})
|
||||
return
|
||||
}
|
||||
|
||||
if msg == nil {
|
||||
c.JSON(http.StatusOK, gin.H{"success": true})
|
||||
return
|
||||
}
|
||||
|
||||
// Respond immediately
|
||||
c.JSON(http.StatusOK, gin.H{"success": true})
|
||||
|
||||
// Get config
|
||||
feishuCfg := h.config.IM.Feishu
|
||||
|
||||
// Detach from gin request context to prevent cancellation after HTTP response.
|
||||
asyncCtx := context.WithoutCancel(ctx)
|
||||
|
||||
// Process asynchronously
|
||||
go func() {
|
||||
if err := h.imService.HandleMessage(asyncCtx, msg, feishuCfg.TenantID, feishuCfg.AgentID, feishuCfg.KnowledgeBases); err != nil {
|
||||
logger.Errorf(asyncCtx, "[IM] Feishu handle message error: %v", err)
|
||||
}
|
||||
}()
|
||||
}
|
||||
76
internal/im/adapter.go
Normal file
76
internal/im/adapter.go
Normal file
@@ -0,0 +1,76 @@
|
||||
package im
|
||||
|
||||
import (
|
||||
"context"
|
||||
|
||||
"github.com/gin-gonic/gin"
|
||||
)
|
||||
|
||||
// Platform identifies an IM platform.
|
||||
type Platform string
|
||||
|
||||
const (
|
||||
PlatformWeCom Platform = "wecom"
|
||||
PlatformFeishu Platform = "feishu"
|
||||
)
|
||||
|
||||
// IncomingMessage is the unified message parsed from an IM callback.
|
||||
type IncomingMessage struct {
|
||||
// Platform identifies which IM platform the message comes from.
|
||||
Platform Platform
|
||||
// UserID is the IM-platform user identifier.
|
||||
UserID string
|
||||
// UserName is the display name of the user (optional).
|
||||
UserName string
|
||||
// ChatID is the group/channel ID (empty for direct messages).
|
||||
ChatID string
|
||||
// ChatType distinguishes direct message from group chat.
|
||||
ChatType ChatType
|
||||
// Content is the text content of the message.
|
||||
Content string
|
||||
// MessageID is the IM-platform message identifier (for dedup).
|
||||
MessageID string
|
||||
// Extra holds platform-specific fields (e.g., WeCom stream ID).
|
||||
Extra map[string]string
|
||||
}
|
||||
|
||||
// ChatType represents the IM chat type.
|
||||
type ChatType string
|
||||
|
||||
const (
|
||||
ChatTypeDirect ChatType = "direct"
|
||||
ChatTypeGroup ChatType = "group"
|
||||
)
|
||||
|
||||
// ReplyMessage is what WeKnora sends back to the IM platform.
|
||||
type ReplyMessage struct {
|
||||
// Content is the text content (Markdown).
|
||||
Content string
|
||||
// IsStreaming indicates whether this is a streaming chunk.
|
||||
IsStreaming bool
|
||||
// IsFinal marks the last chunk of a streaming reply.
|
||||
IsFinal bool
|
||||
// Extra holds platform-specific fields.
|
||||
Extra map[string]string
|
||||
}
|
||||
|
||||
// Adapter is the interface every IM platform must implement.
|
||||
type Adapter interface {
|
||||
// Platform returns the platform identifier.
|
||||
Platform() Platform
|
||||
|
||||
// VerifyCallback verifies the signature/token of an incoming callback request.
|
||||
// Returns nil if verification passes.
|
||||
VerifyCallback(c *gin.Context) error
|
||||
|
||||
// ParseCallback parses the raw IM callback request into a unified IncomingMessage.
|
||||
// Returns nil message for non-message events (e.g., URL verification).
|
||||
ParseCallback(c *gin.Context) (*IncomingMessage, error)
|
||||
|
||||
// SendReply sends a reply back to the IM platform.
|
||||
SendReply(ctx context.Context, incoming *IncomingMessage, reply *ReplyMessage) error
|
||||
|
||||
// HandleURLVerification handles the initial URL verification challenge from the IM platform.
|
||||
// Returns true if this request is a verification request and has been handled.
|
||||
HandleURLVerification(c *gin.Context) bool
|
||||
}
|
||||
424
internal/im/feishu/adapter.go
Normal file
424
internal/im/feishu/adapter.go
Normal file
@@ -0,0 +1,424 @@
|
||||
// Package feishu implements the Feishu (飞书/Lark) IM adapter for WeKnora.
|
||||
//
|
||||
// Feishu bot flow:
|
||||
// 1. User sends a message to the bot (direct or @mention in group)
|
||||
// 2. Feishu calls our event subscription URL with the message event
|
||||
// 3. We parse the event, run QA, then call Feishu API to send reply
|
||||
// 4. For streaming: create a card, then use CardKit streaming update API
|
||||
//
|
||||
// Reference: https://open.feishu.cn/document/server-docs/im-v1/message/create
|
||||
package feishu
|
||||
|
||||
import (
|
||||
"bytes"
|
||||
"context"
|
||||
"crypto/aes"
|
||||
"crypto/cipher"
|
||||
"crypto/sha256"
|
||||
"encoding/base64"
|
||||
"encoding/json"
|
||||
"fmt"
|
||||
"io"
|
||||
"net/http"
|
||||
"strings"
|
||||
"sync"
|
||||
"time"
|
||||
|
||||
"github.com/Tencent/WeKnora/internal/im"
|
||||
"github.com/Tencent/WeKnora/internal/logger"
|
||||
"github.com/gin-gonic/gin"
|
||||
)
|
||||
|
||||
var httpClient = &http.Client{Timeout: 10 * time.Second}
|
||||
|
||||
// Adapter implements im.Adapter for Feishu/Lark.
|
||||
type Adapter struct {
|
||||
appID string
|
||||
appSecret string
|
||||
verificationToken string
|
||||
encryptKey string
|
||||
|
||||
// Token cache
|
||||
tokenMu sync.Mutex
|
||||
tokenCache string
|
||||
tokenExpAt time.Time
|
||||
}
|
||||
|
||||
// NewAdapter creates a new Feishu adapter.
|
||||
func NewAdapter(appID, appSecret, verificationToken, encryptKey string) *Adapter {
|
||||
return &Adapter{
|
||||
appID: appID,
|
||||
appSecret: appSecret,
|
||||
verificationToken: verificationToken,
|
||||
encryptKey: encryptKey,
|
||||
}
|
||||
}
|
||||
|
||||
// Platform returns the platform identifier.
|
||||
func (a *Adapter) Platform() im.Platform {
|
||||
return im.PlatformFeishu
|
||||
}
|
||||
|
||||
// VerifyCallback verifies the Feishu event callback by checking the verification token.
|
||||
// If no verification token is configured (e.g., WebSocket mode), skip verification.
|
||||
func (a *Adapter) VerifyCallback(c *gin.Context) error {
|
||||
if a.verificationToken == "" {
|
||||
return nil
|
||||
}
|
||||
|
||||
bodyBytes, err := io.ReadAll(c.Request.Body)
|
||||
if err != nil {
|
||||
return fmt.Errorf("read body: %w", err)
|
||||
}
|
||||
// Always restore body for subsequent reads (ParseCallback)
|
||||
defer func() { c.Request.Body = io.NopCloser(bytes.NewReader(bodyBytes)) }()
|
||||
|
||||
var raw []byte
|
||||
|
||||
// Handle encrypted events
|
||||
var encryptedBody struct {
|
||||
Encrypt string `json:"encrypt"`
|
||||
}
|
||||
if err := json.Unmarshal(bodyBytes, &encryptedBody); err == nil && encryptedBody.Encrypt != "" {
|
||||
decrypted, err := a.decrypt(encryptedBody.Encrypt)
|
||||
if err != nil {
|
||||
return fmt.Errorf("decrypt event for verification: %w", err)
|
||||
}
|
||||
raw = decrypted
|
||||
} else {
|
||||
raw = bodyBytes
|
||||
}
|
||||
|
||||
var eventBody struct {
|
||||
Header *feishuEventHeader `json:"header"`
|
||||
}
|
||||
if err := json.Unmarshal(raw, &eventBody); err != nil {
|
||||
return fmt.Errorf("unmarshal event header: %w", err)
|
||||
}
|
||||
|
||||
if eventBody.Header == nil || eventBody.Header.Token != a.verificationToken {
|
||||
return fmt.Errorf("invalid verification token")
|
||||
}
|
||||
|
||||
return nil
|
||||
}
|
||||
|
||||
// HandleURLVerification handles the Feishu URL verification challenge.
|
||||
func (a *Adapter) HandleURLVerification(c *gin.Context) bool {
|
||||
bodyBytes, err := io.ReadAll(c.Request.Body)
|
||||
if err != nil {
|
||||
return false
|
||||
}
|
||||
c.Request.Body = io.NopCloser(bytes.NewReader(bodyBytes))
|
||||
|
||||
// Try to parse as a challenge request
|
||||
var body map[string]interface{}
|
||||
|
||||
// If encrypted, try to decrypt first
|
||||
var encryptedBody struct {
|
||||
Encrypt string `json:"encrypt"`
|
||||
}
|
||||
if err := json.Unmarshal(bodyBytes, &encryptedBody); err == nil && encryptedBody.Encrypt != "" {
|
||||
decrypted, err := a.decrypt(encryptedBody.Encrypt)
|
||||
if err != nil {
|
||||
logger.Errorf(c.Request.Context(), "[Feishu] Failed to decrypt: %v", err)
|
||||
return false
|
||||
}
|
||||
if err := json.Unmarshal(decrypted, &body); err != nil {
|
||||
return false
|
||||
}
|
||||
} else {
|
||||
if err := json.Unmarshal(bodyBytes, &body); err != nil {
|
||||
return false
|
||||
}
|
||||
}
|
||||
|
||||
// Check if this is a URL verification challenge
|
||||
if challenge, ok := body["challenge"].(string); ok {
|
||||
c.JSON(http.StatusOK, gin.H{"challenge": challenge})
|
||||
return true
|
||||
}
|
||||
|
||||
// Reset body for subsequent reads
|
||||
c.Request.Body = io.NopCloser(bytes.NewReader(bodyBytes))
|
||||
return false
|
||||
}
|
||||
|
||||
// feishuEventBody is the typed structure of a Feishu event callback.
|
||||
type feishuEventBody struct {
|
||||
Header *feishuEventHeader `json:"header"`
|
||||
Event *feishuEvent `json:"event"`
|
||||
}
|
||||
|
||||
type feishuEventHeader struct {
|
||||
EventType string `json:"event_type"`
|
||||
Token string `json:"token"`
|
||||
}
|
||||
|
||||
type feishuEvent struct {
|
||||
Message *feishuMessage `json:"message"`
|
||||
Sender *feishuSender `json:"sender"`
|
||||
}
|
||||
|
||||
type feishuMessage struct {
|
||||
MessageID string `json:"message_id"`
|
||||
MessageType string `json:"message_type"`
|
||||
ChatType string `json:"chat_type"`
|
||||
ChatID string `json:"chat_id"`
|
||||
Content string `json:"content"`
|
||||
}
|
||||
|
||||
type feishuSender struct {
|
||||
SenderID *feishuSenderID `json:"sender_id"`
|
||||
}
|
||||
|
||||
type feishuSenderID struct {
|
||||
OpenID string `json:"open_id"`
|
||||
}
|
||||
|
||||
// ParseCallback parses a Feishu event callback into a unified IncomingMessage.
|
||||
func (a *Adapter) ParseCallback(c *gin.Context) (*im.IncomingMessage, error) {
|
||||
bodyBytes, err := io.ReadAll(c.Request.Body)
|
||||
if err != nil {
|
||||
return nil, fmt.Errorf("read body: %w", err)
|
||||
}
|
||||
|
||||
var raw []byte
|
||||
|
||||
// Handle encrypted events
|
||||
var encryptedBody struct {
|
||||
Encrypt string `json:"encrypt"`
|
||||
}
|
||||
if err := json.Unmarshal(bodyBytes, &encryptedBody); err == nil && encryptedBody.Encrypt != "" {
|
||||
decrypted, err := a.decrypt(encryptedBody.Encrypt)
|
||||
if err != nil {
|
||||
return nil, fmt.Errorf("decrypt event: %w", err)
|
||||
}
|
||||
raw = decrypted
|
||||
} else {
|
||||
raw = bodyBytes
|
||||
}
|
||||
|
||||
var eventBody feishuEventBody
|
||||
if err := json.Unmarshal(raw, &eventBody); err != nil {
|
||||
return nil, fmt.Errorf("unmarshal event: %w", err)
|
||||
}
|
||||
|
||||
// Token verification is handled by VerifyCallback; no need to re-check here.
|
||||
|
||||
// Check event type
|
||||
if eventBody.Header == nil || eventBody.Header.EventType != "im.message.receive_v1" {
|
||||
if eventBody.Header != nil {
|
||||
logger.Infof(c.Request.Context(), "[Feishu] Ignoring event type: %s", eventBody.Header.EventType)
|
||||
}
|
||||
return nil, nil
|
||||
}
|
||||
|
||||
// Extract message info
|
||||
if eventBody.Event == nil || eventBody.Event.Message == nil {
|
||||
return nil, nil
|
||||
}
|
||||
msg := eventBody.Event.Message
|
||||
|
||||
if msg.MessageType != "text" {
|
||||
logger.Infof(c.Request.Context(), "[Feishu] Ignoring non-text message type: %s", msg.MessageType)
|
||||
return nil, nil
|
||||
}
|
||||
|
||||
// Parse text content
|
||||
var textContent struct {
|
||||
Text string `json:"text"`
|
||||
}
|
||||
if err := json.Unmarshal([]byte(msg.Content), &textContent); err != nil {
|
||||
return nil, fmt.Errorf("unmarshal text content: %w", err)
|
||||
}
|
||||
|
||||
// Determine chat type
|
||||
chatType := im.ChatTypeDirect
|
||||
chatID := ""
|
||||
if msg.ChatType == "group" {
|
||||
chatType = im.ChatTypeGroup
|
||||
chatID = msg.ChatID
|
||||
}
|
||||
|
||||
// Get sender info
|
||||
openID := ""
|
||||
if eventBody.Event.Sender != nil && eventBody.Event.Sender.SenderID != nil {
|
||||
openID = eventBody.Event.Sender.SenderID.OpenID
|
||||
}
|
||||
|
||||
// Strip @bot mention from group messages
|
||||
content := textContent.Text
|
||||
if chatType == im.ChatTypeGroup {
|
||||
// Feishu @mentions are in the format @_user_xxx
|
||||
for strings.HasPrefix(content, "@_user_") {
|
||||
idx := strings.Index(content, " ")
|
||||
if idx >= 0 {
|
||||
content = content[idx+1:]
|
||||
} else {
|
||||
break
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
return &im.IncomingMessage{
|
||||
Platform: im.PlatformFeishu,
|
||||
UserID: openID,
|
||||
ChatID: chatID,
|
||||
ChatType: chatType,
|
||||
Content: strings.TrimSpace(content),
|
||||
MessageID: msg.MessageID,
|
||||
}, nil
|
||||
}
|
||||
|
||||
// SendReply sends a reply message via Feishu API.
|
||||
func (a *Adapter) SendReply(ctx context.Context, incoming *im.IncomingMessage, reply *im.ReplyMessage) error {
|
||||
accessToken, err := a.getTenantAccessToken(ctx)
|
||||
if err != nil {
|
||||
return fmt.Errorf("get access token: %w", err)
|
||||
}
|
||||
|
||||
// Determine receive_id_type and receive_id
|
||||
receiveIDType := "open_id"
|
||||
receiveID := incoming.UserID
|
||||
if incoming.ChatType == im.ChatTypeGroup && incoming.ChatID != "" {
|
||||
receiveIDType = "chat_id"
|
||||
receiveID = incoming.ChatID
|
||||
}
|
||||
|
||||
// Build text message
|
||||
content, _ := json.Marshal(map[string]string{"text": reply.Content})
|
||||
payload := map[string]interface{}{
|
||||
"receive_id": receiveID,
|
||||
"msg_type": "text",
|
||||
"content": string(content),
|
||||
}
|
||||
|
||||
payloadBytes, _ := json.Marshal(payload)
|
||||
|
||||
url := fmt.Sprintf("https://open.feishu.cn/open-apis/im/v1/messages?receive_id_type=%s", receiveIDType)
|
||||
req, err := http.NewRequestWithContext(ctx, http.MethodPost, url, bytes.NewReader(payloadBytes))
|
||||
if err != nil {
|
||||
return fmt.Errorf("create request: %w", err)
|
||||
}
|
||||
req.Header.Set("Content-Type", "application/json; charset=utf-8")
|
||||
req.Header.Set("Authorization", "Bearer "+accessToken)
|
||||
|
||||
resp, err := httpClient.Do(req)
|
||||
if err != nil {
|
||||
return fmt.Errorf("send message: %w", err)
|
||||
}
|
||||
defer resp.Body.Close()
|
||||
|
||||
var result struct {
|
||||
Code int `json:"code"`
|
||||
Msg string `json:"msg"`
|
||||
}
|
||||
if err := json.NewDecoder(resp.Body).Decode(&result); err != nil {
|
||||
return fmt.Errorf("decode response: %w", err)
|
||||
}
|
||||
if result.Code != 0 {
|
||||
return fmt.Errorf("feishu api error: code=%d msg=%s", result.Code, result.Msg)
|
||||
}
|
||||
|
||||
return nil
|
||||
}
|
||||
|
||||
// getTenantAccessToken retrieves the Feishu tenant access token with caching.
|
||||
// Feishu tokens expire in 2 hours; we cache with a safety margin.
|
||||
func (a *Adapter) getTenantAccessToken(ctx context.Context) (string, error) {
|
||||
a.tokenMu.Lock()
|
||||
defer a.tokenMu.Unlock()
|
||||
|
||||
if a.tokenCache != "" && time.Now().Before(a.tokenExpAt) {
|
||||
return a.tokenCache, nil
|
||||
}
|
||||
|
||||
payload, _ := json.Marshal(map[string]string{
|
||||
"app_id": a.appID,
|
||||
"app_secret": a.appSecret,
|
||||
})
|
||||
|
||||
req, err := http.NewRequestWithContext(ctx, http.MethodPost,
|
||||
"https://open.feishu.cn/open-apis/auth/v3/tenant_access_token/internal",
|
||||
bytes.NewReader(payload))
|
||||
if err != nil {
|
||||
return "", fmt.Errorf("create request: %w", err)
|
||||
}
|
||||
req.Header.Set("Content-Type", "application/json; charset=utf-8")
|
||||
|
||||
resp, err := httpClient.Do(req)
|
||||
if err != nil {
|
||||
return "", fmt.Errorf("request token: %w", err)
|
||||
}
|
||||
defer resp.Body.Close()
|
||||
|
||||
var result struct {
|
||||
Code int `json:"code"`
|
||||
Msg string `json:"msg"`
|
||||
TenantAccessToken string `json:"tenant_access_token"`
|
||||
Expire int `json:"expire"` // seconds
|
||||
}
|
||||
if err := json.NewDecoder(resp.Body).Decode(&result); err != nil {
|
||||
return "", fmt.Errorf("decode response: %w", err)
|
||||
}
|
||||
if result.Code != 0 {
|
||||
return "", fmt.Errorf("get token error: code=%d msg=%s", result.Code, result.Msg)
|
||||
}
|
||||
|
||||
a.tokenCache = result.TenantAccessToken
|
||||
// Cache with 5-minute safety margin
|
||||
ttl := time.Duration(result.Expire) * time.Second
|
||||
if ttl > 5*time.Minute {
|
||||
ttl -= 5 * time.Minute
|
||||
}
|
||||
a.tokenExpAt = time.Now().Add(ttl)
|
||||
|
||||
return a.tokenCache, nil
|
||||
}
|
||||
|
||||
// decrypt decrypts a Feishu encrypted event body.
|
||||
// Feishu uses AES-256-CBC with SHA-256 of the encrypt key as the AES key.
|
||||
func (a *Adapter) decrypt(encrypted string) ([]byte, error) {
|
||||
if a.encryptKey == "" {
|
||||
return nil, fmt.Errorf("encrypt_key not configured")
|
||||
}
|
||||
|
||||
ciphertext, err := base64.StdEncoding.DecodeString(encrypted)
|
||||
if err != nil {
|
||||
return nil, fmt.Errorf("base64 decode: %w", err)
|
||||
}
|
||||
|
||||
// SHA-256 of encrypt key as AES key
|
||||
keyHash := sha256.Sum256([]byte(a.encryptKey))
|
||||
block, err := aes.NewCipher(keyHash[:])
|
||||
if err != nil {
|
||||
return nil, fmt.Errorf("new cipher: %w", err)
|
||||
}
|
||||
|
||||
if len(ciphertext) < aes.BlockSize {
|
||||
return nil, fmt.Errorf("ciphertext too short")
|
||||
}
|
||||
|
||||
iv := ciphertext[:aes.BlockSize]
|
||||
ciphertext = ciphertext[aes.BlockSize:]
|
||||
|
||||
mode := cipher.NewCBCDecrypter(block, iv)
|
||||
mode.CryptBlocks(ciphertext, ciphertext)
|
||||
|
||||
// Remove and verify PKCS#7 padding
|
||||
if len(ciphertext) == 0 {
|
||||
return nil, fmt.Errorf("empty plaintext")
|
||||
}
|
||||
padLen := int(ciphertext[len(ciphertext)-1])
|
||||
if padLen > aes.BlockSize || padLen == 0 || padLen > len(ciphertext) {
|
||||
return nil, fmt.Errorf("invalid padding")
|
||||
}
|
||||
for i := 0; i < padLen; i++ {
|
||||
if ciphertext[len(ciphertext)-1-i] != byte(padLen) {
|
||||
return nil, fmt.Errorf("invalid padding")
|
||||
}
|
||||
}
|
||||
|
||||
return ciphertext[:len(ciphertext)-padLen], nil
|
||||
}
|
||||
149
internal/im/feishu/longconn.go
Normal file
149
internal/im/feishu/longconn.go
Normal file
@@ -0,0 +1,149 @@
|
||||
package feishu
|
||||
|
||||
import (
|
||||
"context"
|
||||
"encoding/json"
|
||||
"fmt"
|
||||
"strings"
|
||||
|
||||
"github.com/Tencent/WeKnora/internal/im"
|
||||
"github.com/Tencent/WeKnora/internal/logger"
|
||||
"github.com/larksuite/oapi-sdk-go/v3/event/dispatcher"
|
||||
larkim "github.com/larksuite/oapi-sdk-go/v3/service/im/v1"
|
||||
larkws "github.com/larksuite/oapi-sdk-go/v3/ws"
|
||||
)
|
||||
|
||||
// MessageHandler is called when an IM message is received via long connection.
|
||||
type MessageHandler func(ctx context.Context, msg *im.IncomingMessage) error
|
||||
|
||||
// LongConnClient manages a Feishu WebSocket long connection.
|
||||
type LongConnClient struct {
|
||||
appID string
|
||||
wsClient *larkws.Client
|
||||
}
|
||||
|
||||
// NewLongConnClient creates a Feishu long connection client.
|
||||
// When a text message arrives, it converts it to IncomingMessage and calls handler.
|
||||
func NewLongConnClient(appID, appSecret string, handler MessageHandler) *LongConnClient {
|
||||
// Long connection mode does not require verificationToken or encryptKey;
|
||||
// those are only used for webhook signature verification and decryption.
|
||||
eventHandler := dispatcher.NewEventDispatcher("", "").
|
||||
OnP2MessageReceiveV1(func(ctx context.Context, event *larkim.P2MessageReceiveV1) error {
|
||||
msg := convertEvent(event)
|
||||
if msg == nil {
|
||||
return nil
|
||||
}
|
||||
return handler(ctx, msg)
|
||||
})
|
||||
|
||||
sdkLogger := &feishuLoggerAdapter{appID: appID}
|
||||
|
||||
wsClient := larkws.NewClient(appID, appSecret,
|
||||
larkws.WithEventHandler(eventHandler),
|
||||
larkws.WithAutoReconnect(true),
|
||||
larkws.WithLogger(sdkLogger),
|
||||
)
|
||||
|
||||
return &LongConnClient{appID: appID, wsClient: wsClient}
|
||||
}
|
||||
|
||||
// Start begins the WebSocket long connection. It blocks until ctx is cancelled.
|
||||
func (c *LongConnClient) Start(ctx context.Context) error {
|
||||
logger.Infof(ctx, "[IM] Feishu WebSocket connecting (app_id=%s)...", c.appID)
|
||||
return c.wsClient.Start(ctx)
|
||||
}
|
||||
|
||||
// feishuLoggerAdapter bridges the Feishu SDK logger to our unified logger,
|
||||
// replacing raw SDK connection messages with a consistent format.
|
||||
type feishuLoggerAdapter struct {
|
||||
appID string
|
||||
}
|
||||
|
||||
func (l *feishuLoggerAdapter) Debug(ctx context.Context, args ...interface{}) {
|
||||
logger.Debugf(ctx, "[Feishu] %s", fmt.Sprint(args...))
|
||||
}
|
||||
|
||||
func (l *feishuLoggerAdapter) Info(ctx context.Context, args ...interface{}) {
|
||||
msg := fmt.Sprint(args...)
|
||||
if strings.HasPrefix(msg, "connected to ") {
|
||||
logger.Infof(ctx, "[IM] Feishu WebSocket connected successfully (app_id=%s)", l.appID)
|
||||
return
|
||||
}
|
||||
logger.Infof(ctx, "[Feishu] %s", msg)
|
||||
}
|
||||
|
||||
func (l *feishuLoggerAdapter) Warn(ctx context.Context, args ...interface{}) {
|
||||
logger.Warnf(ctx, "[Feishu] %s", fmt.Sprint(args...))
|
||||
}
|
||||
|
||||
func (l *feishuLoggerAdapter) Error(ctx context.Context, args ...interface{}) {
|
||||
logger.Errorf(ctx, "[Feishu] %s", fmt.Sprint(args...))
|
||||
}
|
||||
|
||||
// convertEvent converts a Feishu SDK event to a unified IncomingMessage.
|
||||
// Returns nil for non-text messages.
|
||||
func convertEvent(event *larkim.P2MessageReceiveV1) *im.IncomingMessage {
|
||||
if event == nil || event.Event == nil || event.Event.Message == nil {
|
||||
return nil
|
||||
}
|
||||
|
||||
msg := event.Event.Message
|
||||
if msg.MessageType == nil || *msg.MessageType != "text" {
|
||||
return nil
|
||||
}
|
||||
|
||||
// Parse text content from JSON: {"text": "..."}
|
||||
var textContent struct {
|
||||
Text string `json:"text"`
|
||||
}
|
||||
if msg.Content == nil {
|
||||
return nil
|
||||
}
|
||||
if err := json.Unmarshal([]byte(*msg.Content), &textContent); err != nil {
|
||||
return nil
|
||||
}
|
||||
|
||||
// Sender info
|
||||
openID := ""
|
||||
if event.Event.Sender != nil && event.Event.Sender.SenderId != nil && event.Event.Sender.SenderId.OpenId != nil {
|
||||
openID = *event.Event.Sender.SenderId.OpenId
|
||||
}
|
||||
|
||||
// Chat type
|
||||
chatType := im.ChatTypeDirect
|
||||
chatID := ""
|
||||
if msg.ChatType != nil && *msg.ChatType == "group" {
|
||||
chatType = im.ChatTypeGroup
|
||||
if msg.ChatId != nil {
|
||||
chatID = *msg.ChatId
|
||||
}
|
||||
}
|
||||
|
||||
// Message ID
|
||||
messageID := ""
|
||||
if msg.MessageId != nil {
|
||||
messageID = *msg.MessageId
|
||||
}
|
||||
|
||||
// Strip @bot mention from group messages
|
||||
content := textContent.Text
|
||||
if chatType == im.ChatTypeGroup {
|
||||
for strings.HasPrefix(content, "@_user_") {
|
||||
idx := strings.Index(content, " ")
|
||||
if idx >= 0 {
|
||||
content = content[idx+1:]
|
||||
} else {
|
||||
break
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
return &im.IncomingMessage{
|
||||
Platform: im.PlatformFeishu,
|
||||
UserID: openID,
|
||||
ChatID: chatID,
|
||||
ChatType: chatType,
|
||||
Content: strings.TrimSpace(content),
|
||||
MessageID: messageID,
|
||||
}
|
||||
}
|
||||
378
internal/im/service.go
Normal file
378
internal/im/service.go
Normal file
@@ -0,0 +1,378 @@
|
||||
package im
|
||||
|
||||
import (
|
||||
"context"
|
||||
"fmt"
|
||||
"strings"
|
||||
"sync"
|
||||
"time"
|
||||
|
||||
"github.com/Tencent/WeKnora/internal/event"
|
||||
"github.com/Tencent/WeKnora/internal/logger"
|
||||
"github.com/Tencent/WeKnora/internal/types"
|
||||
"github.com/Tencent/WeKnora/internal/types/interfaces"
|
||||
"github.com/google/uuid"
|
||||
"gorm.io/gorm"
|
||||
)
|
||||
|
||||
const (
|
||||
// qaTimeout is the maximum time to wait for the QA pipeline to complete.
|
||||
qaTimeout = 120 * time.Second
|
||||
// dedupTTL is how long processed message IDs are retained.
|
||||
dedupTTL = 5 * time.Minute
|
||||
// dedupCleanupInterval is how often the dedup map is cleaned.
|
||||
dedupCleanupInterval = 1 * time.Minute
|
||||
// maxContentLength is the maximum allowed message content length.
|
||||
maxContentLength = 4096
|
||||
)
|
||||
|
||||
// Service orchestrates IM message handling:
|
||||
// 1. Receives a unified IncomingMessage from an Adapter
|
||||
// 2. Resolves or creates a WeKnora session for the IM channel
|
||||
// 3. Calls the WeKnora QA pipeline
|
||||
// 4. Collects the streaming answer and sends it back via the Adapter
|
||||
type Service struct {
|
||||
db *gorm.DB
|
||||
sessionService interfaces.SessionService
|
||||
messageService interfaces.MessageService
|
||||
tenantService interfaces.TenantService
|
||||
agentService interfaces.CustomAgentService
|
||||
|
||||
adapters map[Platform]Adapter
|
||||
mu sync.RWMutex
|
||||
|
||||
// processedMsgs tracks recently processed message IDs to prevent duplicate handling.
|
||||
processedMsgs sync.Map
|
||||
|
||||
stopCh chan struct{}
|
||||
}
|
||||
|
||||
// NewService creates a new IM service.
|
||||
func NewService(
|
||||
db *gorm.DB,
|
||||
sessionService interfaces.SessionService,
|
||||
messageService interfaces.MessageService,
|
||||
tenantService interfaces.TenantService,
|
||||
agentService interfaces.CustomAgentService,
|
||||
) *Service {
|
||||
s := &Service{
|
||||
db: db,
|
||||
sessionService: sessionService,
|
||||
messageService: messageService,
|
||||
tenantService: tenantService,
|
||||
agentService: agentService,
|
||||
adapters: make(map[Platform]Adapter),
|
||||
stopCh: make(chan struct{}),
|
||||
}
|
||||
|
||||
// Start periodic dedup cleanup instead of per-message goroutines
|
||||
go s.dedupCleanupLoop()
|
||||
|
||||
return s
|
||||
}
|
||||
|
||||
// Stop gracefully shuts down the service, stopping background goroutines.
|
||||
func (s *Service) Stop() {
|
||||
close(s.stopCh)
|
||||
}
|
||||
|
||||
// dedupCleanupLoop periodically cleans up expired entries from the dedup map.
|
||||
func (s *Service) dedupCleanupLoop() {
|
||||
ticker := time.NewTicker(dedupCleanupInterval)
|
||||
defer ticker.Stop()
|
||||
for {
|
||||
select {
|
||||
case <-ticker.C:
|
||||
cutoff := time.Now().Add(-dedupTTL)
|
||||
s.processedMsgs.Range(func(key, value interface{}) bool {
|
||||
if t, ok := value.(time.Time); ok && t.Before(cutoff) {
|
||||
s.processedMsgs.Delete(key)
|
||||
}
|
||||
return true
|
||||
})
|
||||
case <-s.stopCh:
|
||||
return
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
// RegisterAdapter registers an IM platform adapter.
|
||||
func (s *Service) RegisterAdapter(adapter Adapter) {
|
||||
s.mu.Lock()
|
||||
defer s.mu.Unlock()
|
||||
s.adapters[adapter.Platform()] = adapter
|
||||
}
|
||||
|
||||
// GetAdapter returns the adapter for a given platform.
|
||||
func (s *Service) GetAdapter(platform Platform) (Adapter, bool) {
|
||||
s.mu.RLock()
|
||||
defer s.mu.RUnlock()
|
||||
a, ok := s.adapters[platform]
|
||||
return a, ok
|
||||
}
|
||||
|
||||
// HandleMessage processes an incoming IM message end-to-end:
|
||||
// resolves session, runs QA, sends reply.
|
||||
func (s *Service) HandleMessage(ctx context.Context, msg *IncomingMessage, tenantID uint64, agentID string, kbIDs []string) error {
|
||||
// Dedup: skip if this message was already processed (IM platforms may retry)
|
||||
if msg.MessageID != "" {
|
||||
if _, loaded := s.processedMsgs.LoadOrStore(msg.MessageID, time.Now()); loaded {
|
||||
logger.Infof(ctx, "[IM] Skipping duplicate message: %s", msg.MessageID)
|
||||
return nil
|
||||
}
|
||||
}
|
||||
|
||||
// Reject overly long messages to protect the QA pipeline
|
||||
contentRunes := []rune(msg.Content)
|
||||
if len(contentRunes) > maxContentLength {
|
||||
logger.Warnf(ctx, "[IM] Message too long (%d runes), truncating to %d", len(contentRunes), maxContentLength)
|
||||
msg.Content = string(contentRunes[:maxContentLength])
|
||||
}
|
||||
|
||||
logger.Infof(ctx, "[IM] HandleMessage: platform=%s user=%s chat=%s content_len=%d",
|
||||
msg.Platform, msg.UserID, msg.ChatID, len(msg.Content))
|
||||
|
||||
// 1. Get tenant (once, shared across resolve + QA)
|
||||
tenant, err := s.tenantService.GetTenantByID(ctx, tenantID)
|
||||
if err != nil {
|
||||
return fmt.Errorf("get tenant: %w", err)
|
||||
}
|
||||
sessionCtx := context.WithValue(ctx, types.TenantIDContextKey, tenantID)
|
||||
sessionCtx = context.WithValue(sessionCtx, types.TenantInfoContextKey, tenant)
|
||||
|
||||
// 2. Resolve or create a WeKnora session
|
||||
channelSession, err := s.resolveSession(sessionCtx, msg, tenantID, agentID)
|
||||
if err != nil {
|
||||
return fmt.Errorf("resolve session: %w", err)
|
||||
}
|
||||
|
||||
// 3. Get the WeKnora session
|
||||
session, err := s.sessionService.GetSession(sessionCtx, channelSession.SessionID)
|
||||
if err != nil {
|
||||
return fmt.Errorf("get session: %w", err)
|
||||
}
|
||||
|
||||
// 4. Resolve custom agent (optional)
|
||||
var customAgent *types.CustomAgent
|
||||
if agentID != "" {
|
||||
agent, err := s.agentService.GetAgentByID(sessionCtx, agentID)
|
||||
if err != nil {
|
||||
logger.Warnf(ctx, "[IM] Failed to get agent %s: %v, using default", agentID, err)
|
||||
} else {
|
||||
customAgent = agent
|
||||
}
|
||||
}
|
||||
|
||||
// 5. Run the QA pipeline and collect the full answer
|
||||
answer, err := s.runQA(sessionCtx, session, msg.Content, customAgent, kbIDs)
|
||||
if err != nil {
|
||||
logger.Errorf(ctx, "[IM] QA failed: %v, sending fallback reply", err)
|
||||
answer = "抱歉,处理您的问题时出现了异常,请稍后再试。"
|
||||
}
|
||||
|
||||
// 6. Send the reply back via the platform adapter
|
||||
adapter, ok := s.GetAdapter(msg.Platform)
|
||||
if !ok {
|
||||
return fmt.Errorf("no adapter for platform: %s", msg.Platform)
|
||||
}
|
||||
|
||||
reply := &ReplyMessage{
|
||||
Content: answer,
|
||||
IsFinal: true,
|
||||
}
|
||||
if err := adapter.SendReply(ctx, msg, reply); err != nil {
|
||||
return fmt.Errorf("send reply: %w", err)
|
||||
}
|
||||
|
||||
logger.Infof(ctx, "[IM] Reply sent: platform=%s user=%s answer_len=%d", msg.Platform, msg.UserID, len(answer))
|
||||
return nil
|
||||
}
|
||||
|
||||
// resolveSession finds or creates a ChannelSession for the given IM message.
|
||||
// ctx must already carry TenantIDContextKey and TenantInfoContextKey.
|
||||
func (s *Service) resolveSession(ctx context.Context, msg *IncomingMessage, tenantID uint64, agentID string) (*ChannelSession, error) {
|
||||
var cs ChannelSession
|
||||
result := s.db.Where("platform = ? AND user_id = ? AND chat_id = ? AND tenant_id = ? AND deleted_at IS NULL",
|
||||
string(msg.Platform), msg.UserID, msg.ChatID, tenantID).
|
||||
First(&cs)
|
||||
|
||||
if result.Error == nil {
|
||||
return &cs, nil
|
||||
}
|
||||
|
||||
if result.Error != gorm.ErrRecordNotFound {
|
||||
return nil, fmt.Errorf("query channel session: %w", result.Error)
|
||||
}
|
||||
|
||||
// Create a new WeKnora session
|
||||
title := fmt.Sprintf("IM-%s", msg.Platform)
|
||||
if msg.UserName != "" {
|
||||
title = fmt.Sprintf("IM-%s-%s", msg.Platform, msg.UserName)
|
||||
}
|
||||
|
||||
newSession := &types.Session{
|
||||
TenantID: tenantID,
|
||||
Title: title,
|
||||
Description: fmt.Sprintf("Auto-created from %s IM integration", msg.Platform),
|
||||
}
|
||||
|
||||
createdSession, err := s.sessionService.CreateSession(ctx, newSession)
|
||||
if err != nil {
|
||||
return nil, fmt.Errorf("create session: %w", err)
|
||||
}
|
||||
|
||||
// Create the channel-session mapping; use a unique constraint fallback
|
||||
// to handle concurrent creation attempts for the same channel.
|
||||
cs = ChannelSession{
|
||||
Platform: string(msg.Platform),
|
||||
UserID: msg.UserID,
|
||||
ChatID: msg.ChatID,
|
||||
SessionID: createdSession.ID,
|
||||
TenantID: tenantID,
|
||||
AgentID: agentID,
|
||||
}
|
||||
if err := s.db.Create(&cs).Error; err != nil {
|
||||
// If the insert failed due to unique constraint (concurrent request),
|
||||
// fetch the existing record.
|
||||
var existing ChannelSession
|
||||
if findErr := s.db.Where("platform = ? AND user_id = ? AND chat_id = ? AND tenant_id = ? AND deleted_at IS NULL",
|
||||
string(msg.Platform), msg.UserID, msg.ChatID, tenantID).
|
||||
First(&existing).Error; findErr != nil {
|
||||
return nil, fmt.Errorf("create channel session: %w (lookup fallback: %v)", err, findErr)
|
||||
}
|
||||
return &existing, nil
|
||||
}
|
||||
|
||||
logger.Infof(ctx, "[IM] Created new session mapping: channel=%s/%s/%s -> session=%s",
|
||||
msg.Platform, msg.UserID, msg.ChatID, createdSession.ID)
|
||||
|
||||
return &cs, nil
|
||||
}
|
||||
|
||||
// runQA executes the WeKnora QA pipeline and returns the full answer text.
|
||||
func (s *Service) runQA(ctx context.Context, session *types.Session, query string, customAgent *types.CustomAgent, kbIDs []string) (string, error) {
|
||||
// Add timeout to prevent indefinite blocking
|
||||
ctx, cancel := context.WithTimeout(ctx, qaTimeout)
|
||||
defer cancel()
|
||||
|
||||
eventBus := event.NewEventBus()
|
||||
|
||||
// Thread-safe answer collection
|
||||
var answerMu sync.Mutex
|
||||
var answerBuilder strings.Builder
|
||||
var qaErr error
|
||||
done := make(chan struct{})
|
||||
var closeOnce sync.Once
|
||||
closeDone := func() { closeOnce.Do(func() { close(done) }) }
|
||||
|
||||
eventBus.On(event.EventAgentFinalAnswer, func(ctx context.Context, evt event.Event) error {
|
||||
data, ok := evt.Data.(event.AgentFinalAnswerData)
|
||||
if !ok {
|
||||
return nil
|
||||
}
|
||||
answerMu.Lock()
|
||||
answerBuilder.WriteString(data.Content)
|
||||
answerMu.Unlock()
|
||||
if data.Done {
|
||||
closeDone()
|
||||
}
|
||||
return nil
|
||||
})
|
||||
|
||||
eventBus.On(event.EventError, func(ctx context.Context, evt event.Event) error {
|
||||
data, ok := evt.Data.(event.ErrorData)
|
||||
if !ok {
|
||||
return nil
|
||||
}
|
||||
logger.Errorf(ctx, "[IM] QA error: %s", data.Error)
|
||||
answerMu.Lock()
|
||||
qaErr = fmt.Errorf("QA pipeline error: %s", data.Error)
|
||||
answerMu.Unlock()
|
||||
closeDone()
|
||||
return nil
|
||||
})
|
||||
|
||||
// Determine whether to use agent mode
|
||||
useAgent := customAgent != nil && customAgent.IsAgentMode()
|
||||
|
||||
// Generate a shared RequestID to pair user and assistant messages for history
|
||||
requestID := uuid.New().String()
|
||||
|
||||
// Create user message so it appears in conversation history
|
||||
userMsg, err := s.messageService.CreateMessage(ctx, &types.Message{
|
||||
SessionID: session.ID,
|
||||
Role: "user",
|
||||
Content: query,
|
||||
RequestID: requestID,
|
||||
CreatedAt: time.Now(),
|
||||
IsCompleted: true,
|
||||
})
|
||||
if err != nil {
|
||||
return "", fmt.Errorf("create user message: %w", err)
|
||||
}
|
||||
_ = userMsg
|
||||
|
||||
// Create a placeholder assistant message
|
||||
assistantMsg, err := s.messageService.CreateMessage(ctx, &types.Message{
|
||||
SessionID: session.ID,
|
||||
Role: "assistant",
|
||||
RequestID: requestID,
|
||||
CreatedAt: time.Now(),
|
||||
IsCompleted: false,
|
||||
})
|
||||
if err != nil {
|
||||
return "", fmt.Errorf("create assistant message: %w", err)
|
||||
}
|
||||
|
||||
// Run QA async
|
||||
go func() {
|
||||
var err error
|
||||
if useAgent {
|
||||
err = s.sessionService.AgentQA(ctx, session, query, assistantMsg.ID, "", eventBus, customAgent, kbIDs, nil)
|
||||
} else {
|
||||
err = s.sessionService.KnowledgeQA(ctx, session, query, kbIDs, nil, assistantMsg.ID, "", false, eventBus, customAgent, false)
|
||||
}
|
||||
if err != nil {
|
||||
logger.Errorf(ctx, "[IM] QA execution error: %v", err)
|
||||
answerMu.Lock()
|
||||
qaErr = fmt.Errorf("QA execution error: %w", err)
|
||||
answerMu.Unlock()
|
||||
closeDone()
|
||||
}
|
||||
}()
|
||||
|
||||
// Wait for completion or timeout
|
||||
select {
|
||||
case <-done:
|
||||
case <-ctx.Done():
|
||||
// Mark assistant message as completed to avoid dangling incomplete records
|
||||
assistantMsg.Content = "抱歉,回答超时,请稍后再试。"
|
||||
assistantMsg.IsCompleted = true
|
||||
// Use a fresh context since the original is cancelled
|
||||
if updateErr := s.messageService.UpdateMessage(context.WithoutCancel(ctx), assistantMsg); updateErr != nil {
|
||||
logger.Warnf(ctx, "[IM] Failed to update timed-out assistant message: %v", updateErr)
|
||||
}
|
||||
return "", fmt.Errorf("QA timed out after %v", qaTimeout)
|
||||
}
|
||||
|
||||
answerMu.Lock()
|
||||
answer := answerBuilder.String()
|
||||
qaError := qaErr
|
||||
answerMu.Unlock()
|
||||
|
||||
if answer == "" && qaError != nil {
|
||||
return "", qaError
|
||||
}
|
||||
if answer == "" {
|
||||
answer = "抱歉,我暂时无法回答这个问题。"
|
||||
}
|
||||
|
||||
// Update assistant message with the final answer content
|
||||
assistantMsg.Content = answer
|
||||
assistantMsg.IsCompleted = true
|
||||
if err := s.messageService.UpdateMessage(ctx, assistantMsg); err != nil {
|
||||
logger.Warnf(ctx, "[IM] Failed to update assistant message: %v", err)
|
||||
}
|
||||
|
||||
return answer, nil
|
||||
}
|
||||
40
internal/im/types.go
Normal file
40
internal/im/types.go
Normal file
@@ -0,0 +1,40 @@
|
||||
package im
|
||||
|
||||
import (
|
||||
"time"
|
||||
|
||||
"github.com/Tencent/WeKnora/internal/types"
|
||||
"github.com/google/uuid"
|
||||
"gorm.io/gorm"
|
||||
)
|
||||
|
||||
// ChannelSession maps an IM channel (user+chat combination) to a WeKnora session.
|
||||
// This allows the IM integration to maintain conversation continuity.
|
||||
type ChannelSession struct {
|
||||
ID string `json:"id" gorm:"type:varchar(36);primaryKey;default:uuid_generate_v4()"`
|
||||
Platform string `json:"platform" gorm:"type:varchar(20);not null"`
|
||||
UserID string `json:"user_id" gorm:"type:varchar(128);not null"`
|
||||
ChatID string `json:"chat_id" gorm:"type:varchar(128);not null;default:''"`
|
||||
SessionID string `json:"session_id" gorm:"type:varchar(36);not null;index"`
|
||||
TenantID uint64 `json:"tenant_id" gorm:"not null;index"`
|
||||
AgentID string `json:"agent_id" gorm:"type:varchar(36);default:''"`
|
||||
Status string `json:"status" gorm:"type:varchar(20);not null;default:'active'"`
|
||||
Metadata types.JSON `json:"metadata" gorm:"type:jsonb;default:'{}'"`
|
||||
CreatedAt time.Time `json:"created_at"`
|
||||
UpdatedAt time.Time `json:"updated_at"`
|
||||
DeletedAt gorm.DeletedAt `json:"deleted_at" gorm:"index"`
|
||||
}
|
||||
|
||||
func (ChannelSession) TableName() string {
|
||||
return "im_channel_sessions"
|
||||
}
|
||||
|
||||
func (cs *ChannelSession) BeforeCreate(tx *gorm.DB) error {
|
||||
if cs.ID == "" {
|
||||
cs.ID = uuid.New().String()
|
||||
}
|
||||
if cs.Status == "" {
|
||||
cs.Status = "active"
|
||||
}
|
||||
return nil
|
||||
}
|
||||
369
internal/im/wecom/adapter.go
Normal file
369
internal/im/wecom/adapter.go
Normal file
@@ -0,0 +1,369 @@
|
||||
// Package wecom implements the WeCom (企业微信) IM adapter for WeKnora.
|
||||
//
|
||||
// WeCom Smart Bot flow:
|
||||
// 1. User sends a message to the bot (direct or @mention in group)
|
||||
// 2. WeCom calls our callback URL with the encrypted message
|
||||
// 3. We decrypt, parse, and return an immediate response (or stream response)
|
||||
// 4. For streaming: respond with msgtype="stream", WeCom pulls subsequent chunks via refresh callbacks
|
||||
//
|
||||
// Reference: https://developer.work.weixin.qq.com/document/path/101031
|
||||
package wecom
|
||||
|
||||
import (
|
||||
"bytes"
|
||||
"context"
|
||||
"crypto/aes"
|
||||
"crypto/cipher"
|
||||
"crypto/hmac"
|
||||
"crypto/sha1"
|
||||
"encoding/base64"
|
||||
"encoding/binary"
|
||||
"encoding/json"
|
||||
"encoding/xml"
|
||||
"fmt"
|
||||
"io"
|
||||
"net/http"
|
||||
"sort"
|
||||
"strings"
|
||||
"sync"
|
||||
"time"
|
||||
|
||||
"github.com/Tencent/WeKnora/internal/im"
|
||||
"github.com/Tencent/WeKnora/internal/logger"
|
||||
"github.com/gin-gonic/gin"
|
||||
)
|
||||
|
||||
var httpClient = &http.Client{Timeout: 10 * time.Second}
|
||||
|
||||
// Adapter implements im.Adapter for WeCom.
|
||||
type Adapter struct {
|
||||
corpID string
|
||||
token string
|
||||
encodingAESKey string
|
||||
aesKey []byte
|
||||
agentSecret string
|
||||
corpAgentID int
|
||||
|
||||
// Token cache
|
||||
tokenMu sync.Mutex
|
||||
tokenCache string
|
||||
tokenExpAt time.Time
|
||||
}
|
||||
|
||||
// NewAdapter creates a new WeCom adapter.
|
||||
func NewAdapter(corpID, agentSecret, token, encodingAESKey string, corpAgentID int) (*Adapter, error) {
|
||||
// Decode the AES key from base64
|
||||
aesKey, err := base64.StdEncoding.DecodeString(encodingAESKey + "=")
|
||||
if err != nil {
|
||||
return nil, fmt.Errorf("decode encoding_aes_key: %w", err)
|
||||
}
|
||||
|
||||
return &Adapter{
|
||||
corpID: corpID,
|
||||
token: token,
|
||||
encodingAESKey: encodingAESKey,
|
||||
aesKey: aesKey,
|
||||
agentSecret: agentSecret,
|
||||
corpAgentID: corpAgentID,
|
||||
}, nil
|
||||
}
|
||||
|
||||
// Platform returns the platform identifier.
|
||||
func (a *Adapter) Platform() im.Platform {
|
||||
return im.PlatformWeCom
|
||||
}
|
||||
|
||||
// VerifyCallback verifies the WeCom callback signature.
|
||||
func (a *Adapter) VerifyCallback(c *gin.Context) error {
|
||||
timestamp := c.Query("timestamp")
|
||||
nonce := c.Query("nonce")
|
||||
msgSignature := c.Query("msg_signature")
|
||||
|
||||
// For GET requests (URL verification), use echostr
|
||||
// For POST requests (message callback), use request body's Encrypt field
|
||||
var encrypt string
|
||||
if c.Request.Method == http.MethodGet {
|
||||
encrypt = c.Query("echostr")
|
||||
} else {
|
||||
var body callbackRequestBody
|
||||
bodyBytes, err := io.ReadAll(c.Request.Body)
|
||||
if err != nil {
|
||||
return fmt.Errorf("read request body: %w", err)
|
||||
}
|
||||
c.Request.Body = io.NopCloser(bytes.NewReader(bodyBytes))
|
||||
if err := xml.Unmarshal(bodyBytes, &body); err != nil {
|
||||
return fmt.Errorf("unmarshal xml body: %w", err)
|
||||
}
|
||||
encrypt = body.Encrypt
|
||||
}
|
||||
|
||||
if !a.verifySignature(msgSignature, timestamp, nonce, encrypt) {
|
||||
return fmt.Errorf("invalid signature")
|
||||
}
|
||||
|
||||
return nil
|
||||
}
|
||||
|
||||
// HandleURLVerification handles the WeCom URL verification (GET request).
|
||||
func (a *Adapter) HandleURLVerification(c *gin.Context) bool {
|
||||
if c.Request.Method != http.MethodGet {
|
||||
return false
|
||||
}
|
||||
|
||||
echoStr := c.Query("echostr")
|
||||
if echoStr == "" {
|
||||
return false
|
||||
}
|
||||
|
||||
// Decrypt the echostr and return it
|
||||
decrypted, err := a.decrypt(echoStr)
|
||||
if err != nil {
|
||||
logger.Errorf(c.Request.Context(), "[WeCom] Failed to decrypt echostr: %v", err)
|
||||
c.String(http.StatusBadRequest, "decrypt failed")
|
||||
return true
|
||||
}
|
||||
|
||||
c.String(http.StatusOK, string(decrypted))
|
||||
return true
|
||||
}
|
||||
|
||||
// ParseCallback parses a WeCom callback into a unified IncomingMessage.
|
||||
func (a *Adapter) ParseCallback(c *gin.Context) (*im.IncomingMessage, error) {
|
||||
bodyBytes, err := io.ReadAll(c.Request.Body)
|
||||
if err != nil {
|
||||
return nil, fmt.Errorf("read body: %w", err)
|
||||
}
|
||||
|
||||
var body callbackRequestBody
|
||||
if err := xml.Unmarshal(bodyBytes, &body); err != nil {
|
||||
return nil, fmt.Errorf("unmarshal xml: %w", err)
|
||||
}
|
||||
|
||||
// Decrypt the message
|
||||
decrypted, err := a.decrypt(body.Encrypt)
|
||||
if err != nil {
|
||||
return nil, fmt.Errorf("decrypt message: %w", err)
|
||||
}
|
||||
|
||||
var msg wecomMessage
|
||||
if err := xml.Unmarshal(decrypted, &msg); err != nil {
|
||||
return nil, fmt.Errorf("unmarshal decrypted message: %w", err)
|
||||
}
|
||||
|
||||
// Only handle text messages
|
||||
if msg.MsgType != "text" {
|
||||
logger.Infof(c.Request.Context(), "[WeCom] Ignoring non-text message type: %s", msg.MsgType)
|
||||
return nil, nil
|
||||
}
|
||||
|
||||
chatType := im.ChatTypeDirect
|
||||
chatID := ""
|
||||
content := msg.Content
|
||||
|
||||
// Check if this is a group message (has ChatId field in WeCom smart bot callback)
|
||||
// In group chats, the content may contain @bot mention prefix that should be stripped
|
||||
if msg.ChatID != "" {
|
||||
chatType = im.ChatTypeGroup
|
||||
chatID = msg.ChatID
|
||||
}
|
||||
|
||||
return &im.IncomingMessage{
|
||||
Platform: im.PlatformWeCom,
|
||||
UserID: msg.FromUserName,
|
||||
UserName: msg.FromUserName,
|
||||
ChatID: chatID,
|
||||
ChatType: chatType,
|
||||
Content: strings.TrimSpace(content),
|
||||
MessageID: msg.MsgID,
|
||||
}, nil
|
||||
}
|
||||
|
||||
// SendReply sends a reply message via WeCom API.
|
||||
// For simplicity, this uses the "应用消息" API to proactively send messages.
|
||||
// A production implementation would integrate with the streaming callback model.
|
||||
func (a *Adapter) SendReply(ctx context.Context, incoming *im.IncomingMessage, reply *im.ReplyMessage) error {
|
||||
// Get access token
|
||||
accessToken, err := a.getAccessToken(ctx)
|
||||
if err != nil {
|
||||
return fmt.Errorf("get access token: %w", err)
|
||||
}
|
||||
|
||||
// Build message payload.
|
||||
// Note: /cgi-bin/message/send only supports touser/toparty/totag — it does NOT
|
||||
// support regular group chat IDs (chatid is only for appchat-created groups).
|
||||
// So for both direct and group messages we reply to the user directly.
|
||||
payload := map[string]interface{}{
|
||||
"touser": incoming.UserID,
|
||||
"msgtype": "markdown",
|
||||
"agentid": a.corpAgentID,
|
||||
"markdown": map[string]string{
|
||||
"content": reply.Content,
|
||||
},
|
||||
}
|
||||
|
||||
payloadBytes, err := json.Marshal(payload)
|
||||
if err != nil {
|
||||
return fmt.Errorf("marshal payload: %w", err)
|
||||
}
|
||||
|
||||
sendURL := fmt.Sprintf("https://qyapi.weixin.qq.com/cgi-bin/message/send?access_token=%s", accessToken)
|
||||
req, err := http.NewRequestWithContext(ctx, http.MethodPost, sendURL, bytes.NewReader(payloadBytes))
|
||||
if err != nil {
|
||||
return fmt.Errorf("create request: %w", err)
|
||||
}
|
||||
req.Header.Set("Content-Type", "application/json")
|
||||
|
||||
resp, err := httpClient.Do(req)
|
||||
if err != nil {
|
||||
return fmt.Errorf("send message: %w", err)
|
||||
}
|
||||
defer resp.Body.Close()
|
||||
|
||||
var result struct {
|
||||
ErrCode int `json:"errcode"`
|
||||
ErrMsg string `json:"errmsg"`
|
||||
}
|
||||
if err := json.NewDecoder(resp.Body).Decode(&result); err != nil {
|
||||
return fmt.Errorf("decode response: %w", err)
|
||||
}
|
||||
if result.ErrCode != 0 {
|
||||
return fmt.Errorf("wecom api error: code=%d msg=%s", result.ErrCode, result.ErrMsg)
|
||||
}
|
||||
|
||||
return nil
|
||||
}
|
||||
|
||||
// getAccessToken retrieves the WeCom access token with caching.
|
||||
// WeCom tokens expire in 7200 seconds (2 hours); we cache with a safety margin.
|
||||
func (a *Adapter) getAccessToken(ctx context.Context) (string, error) {
|
||||
a.tokenMu.Lock()
|
||||
defer a.tokenMu.Unlock()
|
||||
|
||||
if a.tokenCache != "" && time.Now().Before(a.tokenExpAt) {
|
||||
return a.tokenCache, nil
|
||||
}
|
||||
|
||||
tokenURL := fmt.Sprintf("https://qyapi.weixin.qq.com/cgi-bin/gettoken?corpid=%s&corpsecret=%s",
|
||||
a.corpID, a.agentSecret)
|
||||
|
||||
req, err := http.NewRequestWithContext(ctx, http.MethodGet, tokenURL, nil)
|
||||
if err != nil {
|
||||
return "", fmt.Errorf("create request: %w", err)
|
||||
}
|
||||
|
||||
resp, err := httpClient.Do(req)
|
||||
if err != nil {
|
||||
return "", fmt.Errorf("request access token: %w", err)
|
||||
}
|
||||
defer resp.Body.Close()
|
||||
|
||||
var result struct {
|
||||
ErrCode int `json:"errcode"`
|
||||
ErrMsg string `json:"errmsg"`
|
||||
AccessToken string `json:"access_token"`
|
||||
ExpiresIn int `json:"expires_in"` // seconds
|
||||
}
|
||||
if err := json.NewDecoder(resp.Body).Decode(&result); err != nil {
|
||||
return "", fmt.Errorf("decode token response: %w", err)
|
||||
}
|
||||
if result.ErrCode != 0 {
|
||||
return "", fmt.Errorf("get token error: code=%d msg=%s", result.ErrCode, result.ErrMsg)
|
||||
}
|
||||
|
||||
a.tokenCache = result.AccessToken
|
||||
// Cache with 5-minute safety margin
|
||||
ttl := time.Duration(result.ExpiresIn) * time.Second
|
||||
if ttl > 5*time.Minute {
|
||||
ttl -= 5 * time.Minute
|
||||
}
|
||||
a.tokenExpAt = time.Now().Add(ttl)
|
||||
|
||||
return a.tokenCache, nil
|
||||
}
|
||||
|
||||
// verifySignature verifies the WeCom callback signature using constant-time comparison.
|
||||
func (a *Adapter) verifySignature(signature, timestamp, nonce, encrypt string) bool {
|
||||
parts := []string{a.token, timestamp, nonce, encrypt}
|
||||
sort.Strings(parts)
|
||||
combined := strings.Join(parts, "")
|
||||
|
||||
hash := sha1.New()
|
||||
hash.Write([]byte(combined))
|
||||
computed := fmt.Sprintf("%x", hash.Sum(nil))
|
||||
|
||||
return hmac.Equal([]byte(computed), []byte(signature))
|
||||
}
|
||||
|
||||
// decrypt decrypts a WeCom AES-encrypted message.
|
||||
func (a *Adapter) decrypt(encrypted string) ([]byte, error) {
|
||||
ciphertext, err := base64.StdEncoding.DecodeString(encrypted)
|
||||
if err != nil {
|
||||
return nil, fmt.Errorf("base64 decode: %w", err)
|
||||
}
|
||||
|
||||
block, err := aes.NewCipher(a.aesKey)
|
||||
if err != nil {
|
||||
return nil, fmt.Errorf("new cipher: %w", err)
|
||||
}
|
||||
|
||||
if len(ciphertext) < aes.BlockSize {
|
||||
return nil, fmt.Errorf("ciphertext too short")
|
||||
}
|
||||
|
||||
iv := a.aesKey[:aes.BlockSize]
|
||||
mode := cipher.NewCBCDecrypter(block, iv)
|
||||
mode.CryptBlocks(ciphertext, ciphertext)
|
||||
|
||||
// Remove and verify PKCS#7 padding
|
||||
padLen := int(ciphertext[len(ciphertext)-1])
|
||||
if padLen > aes.BlockSize || padLen == 0 || padLen > len(ciphertext) {
|
||||
return nil, fmt.Errorf("invalid padding")
|
||||
}
|
||||
for i := 0; i < padLen; i++ {
|
||||
if ciphertext[len(ciphertext)-1-i] != byte(padLen) {
|
||||
return nil, fmt.Errorf("invalid padding")
|
||||
}
|
||||
}
|
||||
plaintext := ciphertext[:len(ciphertext)-padLen]
|
||||
|
||||
// WeCom format: random(16) + msg_len(4) + msg + corp_id
|
||||
if len(plaintext) < 20 {
|
||||
return nil, fmt.Errorf("plaintext too short")
|
||||
}
|
||||
|
||||
msgLen := binary.BigEndian.Uint32(plaintext[16:20])
|
||||
if uint32(len(plaintext)) < 20+msgLen {
|
||||
return nil, fmt.Errorf("message length mismatch")
|
||||
}
|
||||
|
||||
msgBytes := plaintext[20 : 20+msgLen]
|
||||
|
||||
// Verify corp_id from plaintext tail
|
||||
corpIDBytes := plaintext[20+msgLen:]
|
||||
if string(corpIDBytes) != a.corpID {
|
||||
return nil, fmt.Errorf("corp_id mismatch: expected %s, got %s", a.corpID, string(corpIDBytes))
|
||||
}
|
||||
|
||||
return msgBytes, nil
|
||||
}
|
||||
|
||||
// callbackRequestBody is the XML structure of a WeCom callback request body.
|
||||
type callbackRequestBody struct {
|
||||
XMLName xml.Name `xml:"xml"`
|
||||
ToUserName string `xml:"ToUserName"`
|
||||
Encrypt string `xml:"Encrypt"`
|
||||
AgentID string `xml:"AgentID"`
|
||||
}
|
||||
|
||||
// wecomMessage is the decrypted WeCom message structure.
|
||||
type wecomMessage struct {
|
||||
XMLName xml.Name `xml:"xml"`
|
||||
ToUserName string `xml:"ToUserName"`
|
||||
FromUserName string `xml:"FromUserName"`
|
||||
CreateTime int64 `xml:"CreateTime"`
|
||||
MsgType string `xml:"MsgType"`
|
||||
Content string `xml:"Content"`
|
||||
MsgID string `xml:"MsgId"`
|
||||
AgentID string `xml:"AgentID"`
|
||||
ChatID string `xml:"ChatId"`
|
||||
}
|
||||
42
internal/im/wecom/bot_adapter.go
Normal file
42
internal/im/wecom/bot_adapter.go
Normal file
@@ -0,0 +1,42 @@
|
||||
package wecom
|
||||
|
||||
import (
|
||||
"context"
|
||||
"fmt"
|
||||
|
||||
"github.com/Tencent/WeKnora/internal/im"
|
||||
"github.com/gin-gonic/gin"
|
||||
)
|
||||
|
||||
// BotAdapter implements im.Adapter for WeCom intelligent bot (long connection mode).
|
||||
// It delegates SendReply to the WebSocket LongConnClient.
|
||||
// The webhook methods (VerifyCallback, ParseCallback, HandleURLVerification) are no-ops
|
||||
// since messages arrive via WebSocket, not HTTP.
|
||||
type BotAdapter struct {
|
||||
client *LongConnClient
|
||||
}
|
||||
|
||||
// NewBotAdapter creates an adapter backed by a WeCom long connection client.
|
||||
func NewBotAdapter(client *LongConnClient) *BotAdapter {
|
||||
return &BotAdapter{client: client}
|
||||
}
|
||||
|
||||
func (a *BotAdapter) Platform() im.Platform {
|
||||
return im.PlatformWeCom
|
||||
}
|
||||
|
||||
func (a *BotAdapter) VerifyCallback(c *gin.Context) error {
|
||||
return fmt.Errorf("WeCom bot adapter does not support webhook callbacks")
|
||||
}
|
||||
|
||||
func (a *BotAdapter) ParseCallback(c *gin.Context) (*im.IncomingMessage, error) {
|
||||
return nil, fmt.Errorf("WeCom bot adapter does not support webhook callbacks")
|
||||
}
|
||||
|
||||
func (a *BotAdapter) HandleURLVerification(c *gin.Context) bool {
|
||||
return false
|
||||
}
|
||||
|
||||
func (a *BotAdapter) SendReply(ctx context.Context, incoming *im.IncomingMessage, reply *im.ReplyMessage) error {
|
||||
return a.client.SendReply(ctx, incoming, reply)
|
||||
}
|
||||
353
internal/im/wecom/longconn.go
Normal file
353
internal/im/wecom/longconn.go
Normal file
@@ -0,0 +1,353 @@
|
||||
// WeCom Intelligent Bot long connection client.
|
||||
//
|
||||
// Protocol reference: https://developer.work.weixin.qq.com/document/path/101463
|
||||
// Node.js SDK reference: https://github.com/WecomTeam/aibot-node-sdk
|
||||
//
|
||||
// Flow:
|
||||
// 1. Connect to wss://openws.work.weixin.qq.com
|
||||
// 2. Send aibot_subscribe with bot_id + secret
|
||||
// 3. Receive aibot_msg_callback / aibot_event_callback frames
|
||||
// 4. Reply via aibot_respond_msg on the same WebSocket
|
||||
// 5. Heartbeat via ping/pong every 30s
|
||||
package wecom
|
||||
|
||||
import (
|
||||
"context"
|
||||
"encoding/json"
|
||||
"fmt"
|
||||
"math"
|
||||
"strings"
|
||||
"sync"
|
||||
"sync/atomic"
|
||||
"time"
|
||||
|
||||
"github.com/Tencent/WeKnora/internal/im"
|
||||
"github.com/Tencent/WeKnora/internal/logger"
|
||||
ws "github.com/gorilla/websocket"
|
||||
)
|
||||
|
||||
const (
|
||||
wecomWSEndpoint = "wss://openws.work.weixin.qq.com"
|
||||
|
||||
cmdSubscribe = "aibot_subscribe"
|
||||
cmdPing = "ping"
|
||||
cmdMsgCallback = "aibot_msg_callback"
|
||||
cmdEventCallback = "aibot_event_callback"
|
||||
cmdResponse = "aibot_respond_msg"
|
||||
|
||||
defaultHeartbeatInterval = 30 * time.Second
|
||||
defaultReconnectBaseDelay = 1 * time.Second
|
||||
defaultReconnectMaxDelay = 30 * time.Second
|
||||
defaultMaxReconnectAttempts = -1 // infinite
|
||||
)
|
||||
|
||||
// wsFrame is the JSON frame exchanged over the WeCom bot WebSocket.
|
||||
type wsFrame struct {
|
||||
Cmd string `json:"cmd,omitempty"`
|
||||
Headers map[string]string `json:"headers,omitempty"`
|
||||
Body json.RawMessage `json:"body,omitempty"`
|
||||
ErrCode int `json:"errcode,omitempty"`
|
||||
ErrMsg string `json:"errmsg,omitempty"`
|
||||
}
|
||||
|
||||
// botMessage is the body of an aibot_msg_callback frame.
|
||||
type botMessage struct {
|
||||
MsgID string `json:"msgid"`
|
||||
AiBotID string `json:"aibotid"`
|
||||
ChatID string `json:"chatid"`
|
||||
ChatType string `json:"chattype"` // "single" or "group"
|
||||
MsgType string `json:"msgtype"` // "text", "image", "event", ...
|
||||
CreateTime int64 `json:"create_time"`
|
||||
From struct {
|
||||
UserID string `json:"userid"`
|
||||
} `json:"from"`
|
||||
Text struct {
|
||||
Content string `json:"content"`
|
||||
} `json:"text"`
|
||||
}
|
||||
|
||||
// streamReplyBody is the body for a streaming text reply.
|
||||
type streamReplyBody struct {
|
||||
MsgType string `json:"msgtype"`
|
||||
Stream struct {
|
||||
ID string `json:"id"`
|
||||
Finish bool `json:"finish"`
|
||||
Content string `json:"content"`
|
||||
} `json:"stream"`
|
||||
}
|
||||
|
||||
// MessageHandler is called when an IM message is received via long connection.
|
||||
type MessageHandler func(ctx context.Context, msg *im.IncomingMessage) error
|
||||
|
||||
// LongConnClient manages a WeCom intelligent bot WebSocket long connection.
|
||||
type LongConnClient struct {
|
||||
botID string
|
||||
secret string
|
||||
handler MessageHandler
|
||||
|
||||
conn *ws.Conn
|
||||
mu sync.Mutex
|
||||
closed atomic.Bool
|
||||
reqSeq atomic.Int64
|
||||
}
|
||||
|
||||
// NewLongConnClient creates a WeCom long connection client.
|
||||
func NewLongConnClient(botID, secret string, handler MessageHandler) *LongConnClient {
|
||||
return &LongConnClient{
|
||||
botID: botID,
|
||||
secret: secret,
|
||||
handler: handler,
|
||||
}
|
||||
}
|
||||
|
||||
// Start connects and runs the long connection loop. It reconnects automatically on failure.
|
||||
func (c *LongConnClient) Start(ctx context.Context) error {
|
||||
logger.Infof(ctx, "[IM] WeCom WebSocket connecting (bot_id=%s)...", c.botID)
|
||||
|
||||
attempts := 0
|
||||
for {
|
||||
if ctx.Err() != nil {
|
||||
return ctx.Err()
|
||||
}
|
||||
|
||||
err := c.connectAndRun(ctx)
|
||||
if c.closed.Load() {
|
||||
return nil
|
||||
}
|
||||
if ctx.Err() != nil {
|
||||
return ctx.Err()
|
||||
}
|
||||
|
||||
attempts++
|
||||
if defaultMaxReconnectAttempts >= 0 && attempts >= defaultMaxReconnectAttempts {
|
||||
return fmt.Errorf("max reconnect attempts reached: %w", err)
|
||||
}
|
||||
|
||||
delay := reconnectDelay(attempts)
|
||||
logger.Warnf(ctx, "[WeCom] Connection lost (%v), reconnecting in %v (attempt %d)...", err, delay, attempts)
|
||||
|
||||
select {
|
||||
case <-time.After(delay):
|
||||
case <-ctx.Done():
|
||||
return ctx.Err()
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
// Stop gracefully closes the connection.
|
||||
func (c *LongConnClient) Stop() {
|
||||
c.closed.Store(true)
|
||||
c.mu.Lock()
|
||||
defer c.mu.Unlock()
|
||||
if c.conn != nil {
|
||||
_ = c.conn.Close()
|
||||
c.conn = nil
|
||||
}
|
||||
}
|
||||
|
||||
// SendReply sends a text reply through the WebSocket connection.
|
||||
// This is used by the IM service to reply to messages in long connection mode.
|
||||
func (c *LongConnClient) SendReply(ctx context.Context, incoming *im.IncomingMessage, reply *im.ReplyMessage) error {
|
||||
reqID, ok := incoming.Extra["req_id"]
|
||||
if !ok || reqID == "" {
|
||||
return fmt.Errorf("missing req_id in incoming message extra")
|
||||
}
|
||||
|
||||
// Generate a unique stream ID for this reply
|
||||
streamID := fmt.Sprintf("stream_%d", c.reqSeq.Add(1))
|
||||
|
||||
body := streamReplyBody{MsgType: "stream"}
|
||||
body.Stream.ID = streamID
|
||||
body.Stream.Finish = true
|
||||
body.Stream.Content = reply.Content
|
||||
|
||||
bodyBytes, err := json.Marshal(body)
|
||||
if err != nil {
|
||||
return fmt.Errorf("marshal reply body: %w", err)
|
||||
}
|
||||
|
||||
frame := wsFrame{
|
||||
Cmd: cmdResponse,
|
||||
Headers: map[string]string{"req_id": reqID},
|
||||
Body: bodyBytes,
|
||||
}
|
||||
|
||||
return c.writeJSON(frame)
|
||||
}
|
||||
|
||||
func (c *LongConnClient) connectAndRun(ctx context.Context) error {
|
||||
conn, _, err := ws.DefaultDialer.DialContext(ctx, wecomWSEndpoint, nil)
|
||||
if err != nil {
|
||||
return fmt.Errorf("dial: %w", err)
|
||||
}
|
||||
|
||||
c.mu.Lock()
|
||||
c.conn = conn
|
||||
c.mu.Unlock()
|
||||
|
||||
defer func() {
|
||||
c.mu.Lock()
|
||||
c.conn = nil
|
||||
c.mu.Unlock()
|
||||
_ = conn.Close()
|
||||
}()
|
||||
|
||||
// Authenticate
|
||||
if err := c.authenticate(ctx); err != nil {
|
||||
return fmt.Errorf("authenticate: %w", err)
|
||||
}
|
||||
|
||||
logger.Infof(ctx, "[IM] WeCom WebSocket connected successfully (bot_id=%s)", c.botID)
|
||||
|
||||
// Start heartbeat
|
||||
heartbeatCtx, heartbeatCancel := context.WithCancel(ctx)
|
||||
defer heartbeatCancel()
|
||||
go c.heartbeatLoop(heartbeatCtx)
|
||||
|
||||
// Message receive loop
|
||||
for {
|
||||
_, message, err := conn.ReadMessage()
|
||||
if err != nil {
|
||||
return fmt.Errorf("read message: %w", err)
|
||||
}
|
||||
|
||||
var frame wsFrame
|
||||
if err := json.Unmarshal(message, &frame); err != nil {
|
||||
logger.Warnf(ctx, "[WeCom] Failed to unmarshal frame: %v", err)
|
||||
continue
|
||||
}
|
||||
|
||||
switch frame.Cmd {
|
||||
case cmdMsgCallback, cmdEventCallback:
|
||||
// Detach from connection ctx so in-flight messages survive reconnects.
|
||||
go c.handleCallback(context.WithoutCancel(ctx), frame)
|
||||
default:
|
||||
// pong or other control frames — ignore
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
func (c *LongConnClient) authenticate(ctx context.Context) error {
|
||||
authBody, _ := json.Marshal(map[string]string{
|
||||
"bot_id": c.botID,
|
||||
"secret": c.secret,
|
||||
})
|
||||
|
||||
reqID := fmt.Sprintf("%s_%d", cmdSubscribe, time.Now().UnixNano())
|
||||
frame := wsFrame{
|
||||
Cmd: cmdSubscribe,
|
||||
Headers: map[string]string{"req_id": reqID},
|
||||
Body: authBody,
|
||||
}
|
||||
|
||||
if err := c.writeJSON(frame); err != nil {
|
||||
return fmt.Errorf("send subscribe: %w", err)
|
||||
}
|
||||
|
||||
// Read auth response
|
||||
c.mu.Lock()
|
||||
conn := c.conn
|
||||
c.mu.Unlock()
|
||||
if conn == nil {
|
||||
return fmt.Errorf("connection closed")
|
||||
}
|
||||
|
||||
_ = conn.SetReadDeadline(time.Now().Add(10 * time.Second))
|
||||
_, msg, err := conn.ReadMessage()
|
||||
_ = conn.SetReadDeadline(time.Time{}) // clear deadline
|
||||
if err != nil {
|
||||
return fmt.Errorf("read auth response: %w", err)
|
||||
}
|
||||
|
||||
var resp wsFrame
|
||||
if err := json.Unmarshal(msg, &resp); err != nil {
|
||||
return fmt.Errorf("unmarshal auth response: %w", err)
|
||||
}
|
||||
|
||||
if resp.ErrCode != 0 {
|
||||
return fmt.Errorf("auth failed: code=%d msg=%s", resp.ErrCode, resp.ErrMsg)
|
||||
}
|
||||
|
||||
return nil
|
||||
}
|
||||
|
||||
func (c *LongConnClient) heartbeatLoop(ctx context.Context) {
|
||||
ticker := time.NewTicker(defaultHeartbeatInterval)
|
||||
defer ticker.Stop()
|
||||
|
||||
for {
|
||||
select {
|
||||
case <-ctx.Done():
|
||||
return
|
||||
case <-ticker.C:
|
||||
reqID := fmt.Sprintf("%s_%d", cmdPing, time.Now().UnixNano())
|
||||
frame := wsFrame{
|
||||
Cmd: cmdPing,
|
||||
Headers: map[string]string{"req_id": reqID},
|
||||
}
|
||||
if err := c.writeJSON(frame); err != nil {
|
||||
logger.Warnf(ctx, "[WeCom] Heartbeat failed: %v", err)
|
||||
return
|
||||
}
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
func (c *LongConnClient) handleCallback(ctx context.Context, frame wsFrame) {
|
||||
var msg botMessage
|
||||
if err := json.Unmarshal(frame.Body, &msg); err != nil {
|
||||
logger.Warnf(ctx, "[WeCom] Failed to unmarshal callback body: %v", err)
|
||||
return
|
||||
}
|
||||
|
||||
// Only handle text messages for now
|
||||
if msg.MsgType != "text" {
|
||||
logger.Infof(ctx, "[WeCom] Ignoring non-text message type: %s", msg.MsgType)
|
||||
return
|
||||
}
|
||||
|
||||
chatType := im.ChatTypeDirect
|
||||
chatID := ""
|
||||
if msg.ChatType == "group" {
|
||||
chatType = im.ChatTypeGroup
|
||||
chatID = msg.ChatID
|
||||
}
|
||||
|
||||
// Preserve req_id in Extra for reply routing
|
||||
reqID := ""
|
||||
if frame.Headers != nil {
|
||||
reqID = frame.Headers["req_id"]
|
||||
}
|
||||
|
||||
incoming := &im.IncomingMessage{
|
||||
Platform: im.PlatformWeCom,
|
||||
UserID: msg.From.UserID,
|
||||
UserName: msg.From.UserID,
|
||||
ChatID: chatID,
|
||||
ChatType: chatType,
|
||||
Content: strings.TrimSpace(msg.Text.Content),
|
||||
MessageID: msg.MsgID,
|
||||
Extra: map[string]string{"req_id": reqID},
|
||||
}
|
||||
|
||||
if err := c.handler(ctx, incoming); err != nil {
|
||||
logger.Errorf(ctx, "[WeCom] Handle message error: %v", err)
|
||||
}
|
||||
}
|
||||
|
||||
func (c *LongConnClient) writeJSON(v interface{}) error {
|
||||
c.mu.Lock()
|
||||
defer c.mu.Unlock()
|
||||
if c.conn == nil {
|
||||
return fmt.Errorf("connection closed")
|
||||
}
|
||||
return c.conn.WriteJSON(v)
|
||||
}
|
||||
|
||||
func reconnectDelay(attempt int) time.Duration {
|
||||
delay := defaultReconnectBaseDelay * time.Duration(math.Pow(2, float64(attempt-1)))
|
||||
if delay > defaultReconnectMaxDelay {
|
||||
delay = defaultReconnectMaxDelay
|
||||
}
|
||||
return delay
|
||||
}
|
||||
@@ -59,6 +59,7 @@ type RouterParams struct {
|
||||
CustomAgentHandler *handler.CustomAgentHandler
|
||||
SkillHandler *handler.SkillHandler
|
||||
OrganizationHandler *handler.OrganizationHandler
|
||||
IMHandler *handler.IMHandler
|
||||
}
|
||||
|
||||
// NewRouter 创建新的路由
|
||||
@@ -103,6 +104,9 @@ func NewRouter(params RouterParams) *gin.Engine {
|
||||
serveFrontendStatic(r)
|
||||
}
|
||||
|
||||
// IM 回调路由(在认证中间件之前注册,使用各平台自身的签名验证)
|
||||
RegisterIMRoutes(r, params.IMHandler)
|
||||
|
||||
// 认证中间件
|
||||
r.Use(middleware.Auth(params.TenantService, params.UserService, params.Config))
|
||||
|
||||
@@ -576,6 +580,19 @@ func RegisterOrganizationRoutes(r *gin.RouterGroup, orgHandler *handler.Organiza
|
||||
r.POST("/shared-agents/disabled", orgHandler.SetSharedAgentDisabledByMe)
|
||||
}
|
||||
|
||||
// RegisterIMRoutes registers IM callback routes.
|
||||
// These are registered BEFORE auth middleware since IM platforms use their own signature verification.
|
||||
func RegisterIMRoutes(r *gin.Engine, imHandler *handler.IMHandler) {
|
||||
im := r.Group("/api/v1/im")
|
||||
{
|
||||
// WeCom callback (supports both GET for URL verification and POST for message events)
|
||||
im.GET("/callback/wecom", imHandler.WeComCallback)
|
||||
im.POST("/callback/wecom", imHandler.WeComCallback)
|
||||
// Feishu callback (POST for both URL verification challenge and message events)
|
||||
im.POST("/callback/feishu", imHandler.FeishuCallback)
|
||||
}
|
||||
}
|
||||
|
||||
// serveFrontendStatic registers a middleware that serves the frontend SPA
|
||||
// from the ./web directory if it exists. Must be called BEFORE auth middleware
|
||||
// so static files are served without authentication.
|
||||
|
||||
1
migrations/versioned/000021_im_channel_sessions.down.sql
Normal file
1
migrations/versioned/000021_im_channel_sessions.down.sql
Normal file
@@ -0,0 +1 @@
|
||||
DROP TABLE IF EXISTS im_channel_sessions;
|
||||
44
migrations/versioned/000021_im_channel_sessions.up.sql
Normal file
44
migrations/versioned/000021_im_channel_sessions.up.sql
Normal file
@@ -0,0 +1,44 @@
|
||||
-- Migration: 000021_im_channel_sessions
|
||||
-- Description: Create IM channel-to-session mapping table
|
||||
DO $$ BEGIN RAISE NOTICE '[Migration 000021] Creating table: im_channel_sessions'; END $$;
|
||||
|
||||
CREATE TABLE IF NOT EXISTS im_channel_sessions (
|
||||
id VARCHAR(36) PRIMARY KEY DEFAULT uuid_generate_v4(),
|
||||
platform VARCHAR(20) NOT NULL,
|
||||
user_id VARCHAR(128) NOT NULL,
|
||||
chat_id VARCHAR(128) NOT NULL DEFAULT '',
|
||||
session_id VARCHAR(36) NOT NULL REFERENCES sessions(id) ON DELETE CASCADE,
|
||||
tenant_id BIGINT NOT NULL,
|
||||
agent_id VARCHAR(36) DEFAULT '',
|
||||
status VARCHAR(20) NOT NULL DEFAULT 'active',
|
||||
metadata JSONB DEFAULT '{}',
|
||||
created_at TIMESTAMP WITH TIME ZONE NOT NULL DEFAULT CURRENT_TIMESTAMP,
|
||||
updated_at TIMESTAMP WITH TIME ZONE NOT NULL DEFAULT CURRENT_TIMESTAMP,
|
||||
deleted_at TIMESTAMP WITH TIME ZONE
|
||||
);
|
||||
|
||||
-- Partial unique index: only enforce uniqueness for non-deleted rows
|
||||
CREATE UNIQUE INDEX IF NOT EXISTS idx_channel_lookup
|
||||
ON im_channel_sessions (platform, user_id, chat_id, tenant_id)
|
||||
WHERE deleted_at IS NULL;
|
||||
|
||||
-- Index for tenant-based queries
|
||||
CREATE INDEX IF NOT EXISTS idx_im_channel_tenant ON im_channel_sessions (tenant_id);
|
||||
|
||||
-- Index for session-based queries
|
||||
CREATE INDEX IF NOT EXISTS idx_im_channel_session ON im_channel_sessions (session_id);
|
||||
|
||||
-- Partial index for soft deletes (only index deleted rows)
|
||||
CREATE INDEX IF NOT EXISTS idx_im_channel_deleted ON im_channel_sessions (deleted_at) WHERE deleted_at IS NOT NULL;
|
||||
|
||||
COMMENT ON TABLE im_channel_sessions IS 'Maps IM platform channels to WeKnora conversation sessions';
|
||||
COMMENT ON COLUMN im_channel_sessions.platform IS 'IM platform identifier: wecom, feishu, etc.';
|
||||
COMMENT ON COLUMN im_channel_sessions.user_id IS 'Platform-specific user identifier';
|
||||
COMMENT ON COLUMN im_channel_sessions.chat_id IS 'Platform-specific chat/group identifier, empty for direct messages';
|
||||
COMMENT ON COLUMN im_channel_sessions.session_id IS 'Associated WeKnora session ID';
|
||||
COMMENT ON COLUMN im_channel_sessions.tenant_id IS 'Tenant that owns this channel mapping';
|
||||
COMMENT ON COLUMN im_channel_sessions.agent_id IS 'Custom agent ID used for this channel, empty for default';
|
||||
COMMENT ON COLUMN im_channel_sessions.status IS 'Channel status: active, paused, expired';
|
||||
COMMENT ON COLUMN im_channel_sessions.metadata IS 'Platform-specific extra data (JSON)';
|
||||
|
||||
DO $$ BEGIN RAISE NOTICE '[Migration 000021] im_channel_sessions setup completed successfully!'; END $$;
|
||||
Reference in New Issue
Block a user