feat: add WeCom and Feishu IM bot integration

- support webhook and websocket modes for both platforms
- add im_channel_sessions migration for channel-session mapping
- register IM adapters and callback routes
- update config and docker-compose for IM env vars
This commit is contained in:
nullkey
2026-03-15 18:45:53 +08:00
committed by lyingbug
parent 0e8f1b1c81
commit 9fa969fb5c
18 changed files with 2239 additions and 4 deletions

View File

@@ -206,7 +206,12 @@ conversation:
- If the user asks in English, respond in English
- If the user asks in Chinese, respond in Chinese
context_template: "{{query}}"
context_template: |
The following is retrieved information:
{{contexts}}
Based on the above retrieved information, answer the following question:
{{query}}
extract_entities_prompt: |
## Task
Extract all entities from the user-provided text that match the following entity types:
@@ -536,3 +541,39 @@ extract:
tenant:
# Enable cross-tenant access (can be enabled for intranet environments)
enable_cross_tenant_access: false
# IM integration configuration (optional)
# Both platforms are disabled by default. To enable WeCom/Feishu bot
# integration, set `enabled: true` on the platform and provide its credentials
# (via the referenced environment variables).
#
# mode:
#   "websocket" — (used when mode is empty) long connection, no public domain needed, SDK maintains persistent connection
#   "webhook"   — platform pushes events to your callback URL, requires public domain
im:
  wecom:
    enabled: false                           # Set to true once credentials are configured
    tenant_id: 10000                         # WeKnora tenant ID to use
    agent_id: ""                             # Default agent ID (optional)
    knowledge_base_ids: []                   # Default knowledge bases (optional), e.g. ["kb-id-1", "kb-id-2"]
    mode: "websocket"                        # "webhook" or "websocket"
    # --- websocket mode (intelligent bot long connection) ---
    bot_id: "${WECOM_BOT_ID}"                # Intelligent bot ID
    bot_secret: "${WECOM_BOT_SECRET}"        # Intelligent bot secret
    # --- webhook mode (self-built app callback) ---
    # corp_id: "${WECOM_CORP_ID}"            # WeCom Corp ID
    # agent_secret: "${WECOM_AGENT_SECRET}"  # App secret
    # token: "${WECOM_TOKEN}"                # Callback token
    # encoding_aes_key: "${WECOM_AES_KEY}"   # Callback EncodingAESKey
    # corp_agent_id: 1000001                 # App AgentID
  feishu:
    enabled: false                           # Set to true once credentials are configured
    tenant_id: 10000                         # WeKnora tenant ID to use
    agent_id: ""                             # Default agent ID (optional)
    knowledge_base_ids: []                   # Default knowledge bases (optional), e.g. ["kb-id-1", "kb-id-2"]
    mode: "websocket"                        # "webhook" or "websocket"
    app_id: "${FEISHU_APP_ID}"               # Feishu App ID
    app_secret: "${FEISHU_APP_SECRET}"       # Feishu App Secret
    # --- The following are only needed in webhook mode; they may be left empty in websocket mode ---
    # verification_token: "${FEISHU_TOKEN}"  # Event subscription Verification Token (webhook only)
    # encrypt_key: "${FEISHU_ENCRYPT_KEY}"   # Event subscription Encrypt Key (webhook only)

View File

@@ -38,8 +38,8 @@ services:
volumes:
- data-files:/data/files
- docreader-tmp:/tmp/docreader:ro
# Optional: mount custom config file
# - ./config/config.yaml:/app/config/config.yaml
# Mount custom config file (required for IM integration)
- ./config/config.yaml:/app/config/config.yaml
# Optional: mount custom skills directory (allows adding skills without rebuilding image)
- ./skills/preloaded:/app/skills/preloaded
healthcheck:
@@ -114,6 +114,11 @@ services:
- SYSTEM_AES_KEY=${SYSTEM_AES_KEY:-}
- CONCURRENCY_POOL_SIZE=${CONCURRENCY_POOL_SIZE:-5}
- JWT_SECRET=${JWT_SECRET:-}
# IM integration
- FEISHU_APP_ID=${FEISHU_APP_ID:-}
- FEISHU_APP_SECRET=${FEISHU_APP_SECRET:-}
- WECOM_BOT_ID=${WECOM_BOT_ID:-}
- WECOM_BOT_SECRET=${WECOM_BOT_SECRET:-}
# File size limit (in MB)
- MAX_FILE_SIZE_MB=${MAX_FILE_SIZE_MB:-50}
# Agent Skills Sandbox

3
go.mod
View File

@@ -22,7 +22,9 @@ require (
github.com/golang-migrate/migrate/v4 v4.19.1
github.com/google/jsonschema-go v0.4.2
github.com/google/uuid v1.6.0
github.com/gorilla/websocket v1.5.0
github.com/hibiken/asynq v0.25.1
github.com/larksuite/oapi-sdk-go/v3 v3.5.3
github.com/mark3labs/mcp-go v0.43.0
github.com/milvus-io/milvus/client/v2 v2.6.2
github.com/minio/minio-go/v7 v7.0.91
@@ -167,7 +169,6 @@ require (
github.com/google/s2a-go v0.1.9 // indirect
github.com/googleapis/enterprise-certificate-proxy v0.3.7 // indirect
github.com/googleapis/gax-go/v2 v2.16.0 // indirect
github.com/gorilla/websocket v1.5.0 // indirect
github.com/grpc-ecosystem/go-grpc-middleware v1.4.0 // indirect
github.com/grpc-ecosystem/go-grpc-prometheus v1.2.0 // indirect
github.com/grpc-ecosystem/grpc-gateway v1.16.0 // indirect

2
go.sum
View File

@@ -2085,6 +2085,8 @@ github.com/kylelemons/godebug v1.1.0 h1:RPNrshWIDI6G2gRW9EHilWtl7Z6Sb1BR0xunSBf0
github.com/kylelemons/godebug v1.1.0/go.mod h1:9/0rRGxNHcop5bhtWyNeEfOS8JIWk580+fNqagV/RAw=
github.com/labstack/echo/v4 v4.5.0/go.mod h1:czIriw4a0C1dFun+ObrXp7ok03xON0N1awStJ6ArI7Y=
github.com/labstack/gommon v0.3.0/go.mod h1:MULnywXg0yavhxWKc+lOruYdAhDwPK9wf0OL7NoOu+k=
github.com/larksuite/oapi-sdk-go/v3 v3.5.3 h1:xvf8Dv29kBXC5/DNDCLhHkAFW8l/0LlQJimO5Zn+JUk=
github.com/larksuite/oapi-sdk-go/v3 v3.5.3/go.mod h1:ZEplY+kwuIrj/nqw5uSCINNATcH3KdxSN7y+UxYY5fI=
github.com/ledongthuc/pdf v0.0.0-20220302134840-0c2507a12d80 h1:6Yzfa6GP0rIo/kULo2bwGEkFvCePZ3qHDDTC3/J9Swo=
github.com/ledongthuc/pdf v0.0.0-20220302134840-0c2507a12d80/go.mod h1:imJHygn/1yfhB7XSJJKlFZKl/J+dCPAknuiaGOshXAs=
github.com/leodido/go-urn v1.2.0/go.mod h1:+8+nEpDfqqsY+g338gtMEUOtuK+4dEMhiQEgxpxOKII=

View File

@@ -27,6 +27,46 @@ type Config struct {
ExtractManager *ExtractManagerConfig `yaml:"extract" json:"extract"`
WebSearch *WebSearchConfig `yaml:"web_search" json:"web_search"`
PromptTemplates *PromptTemplatesConfig `yaml:"prompt_templates" json:"prompt_templates"`
IM *IMConfig `yaml:"im" json:"im"`
}
// IMConfig groups the IM (instant messaging) bot integration settings, one
// optional section per supported platform. It is read from the "im" key of the
// configuration.
type IMConfig struct {
// WeCom holds the WeCom settings; nil when the "wecom" section is absent.
WeCom *WeComIMConfig `yaml:"wecom" json:"wecom"`
// Feishu holds the Feishu settings; nil when the "feishu" section is absent.
Feishu *FeishuIMConfig `yaml:"feishu" json:"feishu"`
}
// WeComIMConfig is the WeCom (WeChat Work) bot integration configuration.
type WeComIMConfig struct {
// Enabled turns the WeCom integration on; when false no adapter is registered.
Enabled bool `yaml:"enabled" json:"enabled"`
// TenantID is the WeKnora tenant on whose behalf incoming messages are handled.
TenantID uint64 `yaml:"tenant_id" json:"tenant_id"`
// AgentID is the default agent ID passed along with each message (optional).
AgentID string `yaml:"agent_id" json:"agent_id"`
// KnowledgeBases lists the default knowledge base IDs passed along with each message (optional).
KnowledgeBases []string `yaml:"knowledge_base_ids" json:"knowledge_base_ids"`
// Mode selects the transport: "webhook" (requires public domain) or "websocket"
// (long connection via intelligent bot, no public domain needed).
// An empty value is treated as "websocket" by registerWeComAdapter.
Mode string `yaml:"mode" json:"mode"`
// --- Webhook mode fields (self-built app callback) ---
CorpID string `yaml:"corp_id" json:"corp_id"`
AgentSecret string `yaml:"agent_secret" json:"agent_secret"`
Token string `yaml:"token" json:"token"`
EncodingAESKey string `yaml:"encoding_aes_key" json:"encoding_aes_key"`
CorpAgentID int `yaml:"corp_agent_id" json:"corp_agent_id"`
// --- WebSocket mode fields (intelligent bot long connection) ---
BotID string `yaml:"bot_id" json:"bot_id"`
BotSecret string `yaml:"bot_secret" json:"bot_secret"`
}
// FeishuIMConfig is the Feishu (Lark) bot integration configuration.
type FeishuIMConfig struct {
// Enabled turns the Feishu integration on; when false no adapter is registered.
Enabled bool `yaml:"enabled" json:"enabled"`
// TenantID is the WeKnora tenant on whose behalf incoming messages are handled.
TenantID uint64 `yaml:"tenant_id" json:"tenant_id"`
// AgentID is the default agent ID passed along with each message (optional).
AgentID string `yaml:"agent_id" json:"agent_id"`
// KnowledgeBases lists the default knowledge base IDs passed along with each message (optional).
KnowledgeBases []string `yaml:"knowledge_base_ids" json:"knowledge_base_ids"`
// AppID and AppSecret are the Feishu app credentials; they are used in both modes.
AppID string `yaml:"app_id" json:"app_id"`
AppSecret string `yaml:"app_secret" json:"app_secret"`
// VerificationToken and EncryptKey are the event-subscription secrets used to
// verify and decrypt webhook callbacks; the long connection client does not use them.
VerificationToken string `yaml:"verification_token" json:"verification_token"`
EncryptKey string `yaml:"encrypt_key" json:"encrypt_key"`
// Mode: "websocket" (default, long connection, no public domain needed) or "webhook" (requires public domain)
Mode string `yaml:"mode" json:"mode"`
}
// DocReaderConfig configures the document parser client (gRPC or HTTP).

View File

@@ -52,6 +52,9 @@ import (
"github.com/Tencent/WeKnora/internal/event"
"github.com/Tencent/WeKnora/internal/handler"
"github.com/Tencent/WeKnora/internal/handler/session"
imPkg "github.com/Tencent/WeKnora/internal/im"
"github.com/Tencent/WeKnora/internal/im/feishu"
"github.com/Tencent/WeKnora/internal/im/wecom"
"github.com/Tencent/WeKnora/internal/infrastructure/docparser"
"github.com/Tencent/WeKnora/internal/logger"
"github.com/Tencent/WeKnora/internal/mcp"
@@ -235,6 +238,12 @@ func BuildContainer(container *dig.Container) *dig.Container {
must(container.Provide(service.NewSkillService))
must(container.Provide(handler.NewSkillHandler))
must(container.Provide(handler.NewOrganizationHandler))
// IM integration
logger.Debugf(ctx, "[Container] Registering IM integration...")
must(container.Provide(imPkg.NewService))
must(container.Invoke(registerIMAdapters))
must(container.Provide(handler.NewIMHandler))
logger.Debugf(ctx, "[Container] HTTP handlers registered")
// Router configuration
@@ -937,3 +946,115 @@ func registerWebSearchProviders(registry *web_search.Registry) {
return web_search.NewBingProvider()
})
}
// registerIMAdapters wires up the IM platform adapters that are enabled in the
// configuration. A platform section that is missing or has enabled=false is
// skipped. Platforms running in "websocket" mode additionally get their long
// connection client started in a background goroutine by the per-platform
// registration helpers.
func registerIMAdapters(cfg *config.Config, imService *imPkg.Service) {
	ctx := context.Background()

	imCfg := cfg.IM
	if imCfg == nil {
		logger.Infof(ctx, "[IM] No IM configuration found, skipping adapter registration")
		return
	}

	// WeCom
	if wecomCfg := imCfg.WeCom; wecomCfg != nil && wecomCfg.Enabled {
		registerWeComAdapter(ctx, wecomCfg, imService)
	}

	// Feishu
	if feishuCfg := imCfg.Feishu; feishuCfg != nil && feishuCfg.Enabled {
		registerFeishuAdapter(ctx, feishuCfg, imService)
	}
}
// registerWeComAdapter registers the WeCom adapter that matches cfg.Mode.
//
//   - "webhook": builds the callback adapter from the self-built app credentials;
//     a construction failure is logged and the adapter is not registered.
//   - "websocket" (also used when Mode is empty): builds a long connection
//     client, registers a bot adapter backed by it so replies can be sent over
//     the same connection, and starts the client in a background goroutine.
//   - anything else: logs a warning and registers nothing.
func registerWeComAdapter(ctx context.Context, cfg *config.WeComIMConfig, imService *imPkg.Service) {
	selected := cfg.Mode
	if selected == "" {
		selected = "websocket"
	}

	if selected == "webhook" {
		webhookAdapter, err := wecom.NewAdapter(
			cfg.CorpID,
			cfg.AgentSecret,
			cfg.Token,
			cfg.EncodingAESKey,
			cfg.CorpAgentID,
		)
		if err != nil {
			logger.Warnf(ctx, "[IM] Failed to create WeCom webhook adapter: %v", err)
			return
		}
		imService.RegisterAdapter(webhookAdapter)
		logger.Infof(ctx, "[IM] WeCom adapter registered (mode=webhook, corp_id=%s)", cfg.CorpID)
		return
	}

	if selected != "websocket" {
		logger.Warnf(ctx, "[IM] Unknown WeCom mode: %s (expected 'webhook' or 'websocket')", selected)
		return
	}

	// Incoming long-connection messages are forwarded to the IM service together
	// with the tenant/agent/knowledge-base defaults from the configuration.
	onMessage := func(msgCtx context.Context, msg *imPkg.IncomingMessage) error {
		return imService.HandleMessage(msgCtx, msg, cfg.TenantID, cfg.AgentID, cfg.KnowledgeBases)
	}
	longConn := wecom.NewLongConnClient(cfg.BotID, cfg.BotSecret, onMessage)

	// The bot adapter lets the service send replies through the WebSocket client.
	imService.RegisterAdapter(wecom.NewBotAdapter(longConn))
	logger.Infof(ctx, "[IM] WeCom adapter registered (mode=websocket, bot_id=%s)", cfg.BotID)

	// Run the long connection in the background so container setup is not blocked.
	go func() {
		bg := context.Background()
		if err := longConn.Start(bg); err != nil {
			logger.Errorf(bg, "[IM] WeCom long connection stopped: %v", err)
		}
	}()
}
// registerFeishuAdapter registers the Feishu HTTP adapter and, in "websocket"
// mode (also used when Mode is empty), starts a long connection client in a
// background goroutine.
//
// The HTTP adapter is registered before the mode is inspected because it is
// what sends replies in both modes; as a consequence it is registered even
// when the mode is unknown (in which case only a warning is logged).
func registerFeishuAdapter(ctx context.Context, cfg *config.FeishuIMConfig, imService *imPkg.Service) {
	selected := cfg.Mode
	if selected == "" {
		selected = "websocket"
	}

	httpAdapter := feishu.NewAdapter(cfg.AppID, cfg.AppSecret, cfg.VerificationToken, cfg.EncryptKey)
	imService.RegisterAdapter(httpAdapter)

	if selected == "webhook" {
		logger.Infof(ctx, "[IM] Feishu adapter registered (mode=webhook, app_id=%s)", cfg.AppID)
		return
	}
	if selected != "websocket" {
		logger.Warnf(ctx, "[IM] Unknown Feishu mode: %s (expected 'webhook' or 'websocket')", selected)
		return
	}

	logger.Infof(ctx, "[IM] Feishu adapter registered (mode=websocket, app_id=%s)", cfg.AppID)

	// Incoming long-connection messages are forwarded to the IM service together
	// with the tenant/agent/knowledge-base defaults from the configuration.
	onMessage := func(msgCtx context.Context, msg *imPkg.IncomingMessage) error {
		return imService.HandleMessage(msgCtx, msg, cfg.TenantID, cfg.AgentID, cfg.KnowledgeBases)
	}
	longConn := feishu.NewLongConnClient(cfg.AppID, cfg.AppSecret, onMessage)

	// Run the long connection in the background so container setup is not blocked.
	go func() {
		bg := context.Background()
		if err := longConn.Start(bg); err != nil {
			logger.Errorf(bg, "[IM] Feishu long connection stopped: %v", err)
		}
	}()
}

132
internal/handler/im.go Normal file
View File

@@ -0,0 +1,132 @@
package handler
import (
"context"
"net/http"
"github.com/Tencent/WeKnora/internal/config"
"github.com/Tencent/WeKnora/internal/im"
"github.com/Tencent/WeKnora/internal/logger"
"github.com/gin-gonic/gin"
)
// IMHandler handles IM platform callback requests.
type IMHandler struct {
// imService provides the registered platform adapters and processes parsed messages.
imService *im.Service
// config is read for the per-platform tenant, agent and knowledge base defaults.
config *config.Config
}
// NewIMHandler builds an IMHandler around the given IM service and
// application configuration.
func NewIMHandler(imService *im.Service, cfg *config.Config) *IMHandler {
	h := new(IMHandler)
	h.imService = imService
	h.config = cfg
	return h
}
// WeComCallback is the HTTP entry point for WeCom callbacks. It serves both
// the URL verification request and message events.
//
// Responses:
//   - 503 when no WeCom adapter is registered
//   - 403 when callback verification fails
//   - 400 when the callback cannot be parsed
//   - 200 otherwise; message events are acknowledged first and processed in a
//     background goroutine so the platform does not time out waiting for the
//     answer.
func (h *IMHandler) WeComCallback(c *gin.Context) {
	reqCtx := c.Request.Context()

	wecomAdapter, registered := h.imService.GetAdapter(im.PlatformWeCom)
	if !registered {
		logger.Error(reqCtx, "[IM] WeCom adapter not registered")
		c.JSON(http.StatusServiceUnavailable, gin.H{"error": "WeCom integration not enabled"})
		return
	}

	// URL verification requests are answered entirely by the adapter.
	if handled := wecomAdapter.HandleURLVerification(c); handled {
		return
	}

	if verifyErr := wecomAdapter.VerifyCallback(c); verifyErr != nil {
		logger.Errorf(reqCtx, "[IM] WeCom callback verification failed: %v", verifyErr)
		c.JSON(http.StatusForbidden, gin.H{"error": "verification failed"})
		return
	}

	incoming, parseErr := wecomAdapter.ParseCallback(c)
	if parseErr != nil {
		logger.Errorf(reqCtx, "[IM] WeCom parse callback failed: %v", parseErr)
		c.JSON(http.StatusBadRequest, gin.H{"error": "parse failed"})
		return
	}

	// Acknowledge right away. A nil message is a non-message event (e.g. a
	// system event) and needs nothing beyond the acknowledgement.
	c.JSON(http.StatusOK, gin.H{"success": true})
	if incoming == nil {
		return
	}

	platformCfg := h.config.IM.WeCom

	// The request context is cancelled once the response is written, so the
	// background work runs on a copy that is detached from that cancellation.
	detached := context.WithoutCancel(reqCtx)
	go func() {
		err := h.imService.HandleMessage(
			detached, incoming, platformCfg.TenantID, platformCfg.AgentID, platformCfg.KnowledgeBases,
		)
		if err != nil {
			logger.Errorf(detached, "[IM] WeCom handle message error: %v", err)
		}
	}()
}
// FeishuCallback is the HTTP entry point for Feishu callbacks. It serves both
// the URL verification challenge and message events.
//
// Responses:
//   - 503 when no Feishu adapter is registered
//   - 403 when callback verification fails
//   - 400 when the callback cannot be parsed
//   - 200 otherwise; message events are acknowledged first and processed in a
//     background goroutine.
func (h *IMHandler) FeishuCallback(c *gin.Context) {
	reqCtx := c.Request.Context()

	feishuAdapter, registered := h.imService.GetAdapter(im.PlatformFeishu)
	if !registered {
		logger.Error(reqCtx, "[IM] Feishu adapter not registered")
		c.JSON(http.StatusServiceUnavailable, gin.H{"error": "Feishu integration not enabled"})
		return
	}

	// Challenge requests are answered entirely by the adapter.
	if handled := feishuAdapter.HandleURLVerification(c); handled {
		return
	}

	if verifyErr := feishuAdapter.VerifyCallback(c); verifyErr != nil {
		logger.Errorf(reqCtx, "[IM] Feishu callback verification failed: %v", verifyErr)
		c.JSON(http.StatusForbidden, gin.H{"error": "verification failed"})
		return
	}

	incoming, parseErr := feishuAdapter.ParseCallback(c)
	if parseErr != nil {
		logger.Errorf(reqCtx, "[IM] Feishu parse callback failed: %v", parseErr)
		c.JSON(http.StatusBadRequest, gin.H{"error": "parse failed"})
		return
	}

	// Acknowledge right away. A nil message means there is nothing to process.
	c.JSON(http.StatusOK, gin.H{"success": true})
	if incoming == nil {
		return
	}

	platformCfg := h.config.IM.Feishu

	// The request context is cancelled once the response is written, so the
	// background work runs on a copy that is detached from that cancellation.
	detached := context.WithoutCancel(reqCtx)
	go func() {
		err := h.imService.HandleMessage(
			detached, incoming, platformCfg.TenantID, platformCfg.AgentID, platformCfg.KnowledgeBases,
		)
		if err != nil {
			logger.Errorf(detached, "[IM] Feishu handle message error: %v", err)
		}
	}()
}

76
internal/im/adapter.go Normal file
View File

@@ -0,0 +1,76 @@
package im
import (
"context"
"github.com/gin-gonic/gin"
)
// Platform identifies an IM platform. Adapters report their Platform, and
// callers use it to look up the adapter for a given platform.
type Platform string
const (
// PlatformWeCom is the identifier for WeCom (WeChat Work).
PlatformWeCom Platform = "wecom"
// PlatformFeishu is the identifier for Feishu (Lark).
PlatformFeishu Platform = "feishu"
)
// IncomingMessage is the unified, platform-independent form of a message
// received from an IM platform, whether it arrived through an HTTP callback or
// a long connection.
type IncomingMessage struct {
// Platform identifies which IM platform the message comes from.
Platform Platform
// UserID is the IM-platform user identifier.
UserID string
// UserName is the display name of the user (optional).
UserName string
// ChatID is the group/channel ID (empty for direct messages).
ChatID string
// ChatType distinguishes direct message from group chat.
ChatType ChatType
// Content is the text content of the message.
Content string
// MessageID is the IM-platform message identifier (for dedup).
MessageID string
// Extra holds platform-specific fields (e.g., WeCom stream ID).
Extra map[string]string
}
// ChatType represents the IM chat type.
type ChatType string
const (
// ChatTypeDirect is a one-to-one conversation between a user and the bot.
ChatTypeDirect ChatType = "direct"
// ChatTypeGroup is a group conversation.
ChatTypeGroup ChatType = "group"
)
// ReplyMessage is what WeKnora sends back to the IM platform through
// Adapter.SendReply.
type ReplyMessage struct {
// Content is the text content (Markdown).
Content string
// IsStreaming indicates whether this is a streaming chunk.
IsStreaming bool
// IsFinal marks the last chunk of a streaming reply.
IsFinal bool
// Extra holds platform-specific fields.
Extra map[string]string
}
// Adapter is the interface every IM platform must implement.
//
// The HTTP callback handlers call the request-facing methods in this order:
// HandleURLVerification, then VerifyCallback, then ParseCallback.
type Adapter interface {
// Platform returns the platform identifier.
Platform() Platform
// VerifyCallback verifies the signature/token of an incoming callback request.
// Returns nil if verification passes.
VerifyCallback(c *gin.Context) error
// ParseCallback parses the raw IM callback request into a unified IncomingMessage.
// Returns nil message for non-message events (e.g., URL verification).
ParseCallback(c *gin.Context) (*IncomingMessage, error)
// SendReply sends a reply back to the IM platform.
// incoming is the message being answered; it supplies the reply target.
SendReply(ctx context.Context, incoming *IncomingMessage, reply *ReplyMessage) error
// HandleURLVerification handles the initial URL verification challenge from the IM platform.
// Returns true if this request is a verification request and has been handled.
HandleURLVerification(c *gin.Context) bool
}

View File

@@ -0,0 +1,424 @@
// Package feishu implements the Feishu (飞书/Lark) IM adapter for WeKnora.
//
// Feishu bot flow:
// 1. User sends a message to the bot (direct or @mention in group)
// 2. Feishu calls our event subscription URL with the message event
// 3. We parse the event, run QA, then call Feishu API to send reply
// 4. For streaming: create a card, then use CardKit streaming update API
//
// Reference: https://open.feishu.cn/document/server-docs/im-v1/message/create
package feishu
import (
"bytes"
"context"
"crypto/aes"
"crypto/cipher"
"crypto/sha256"
"encoding/base64"
"encoding/json"
"fmt"
"io"
"net/http"
"strings"
"sync"
"time"
"github.com/Tencent/WeKnora/internal/im"
"github.com/Tencent/WeKnora/internal/logger"
"github.com/gin-gonic/gin"
)
// httpClient is the package-level HTTP client used for Feishu API calls; every
// request is bounded by a 10-second timeout.
var httpClient = &http.Client{Timeout: 10 * time.Second}
// Adapter implements im.Adapter for Feishu/Lark.
type Adapter struct {
// appID and appSecret are the app credentials used to obtain the tenant access token.
appID string
appSecret string
// verificationToken is compared against the token carried in event callbacks;
// when empty, VerifyCallback skips verification.
verificationToken string
// encryptKey is used to decrypt encrypted event bodies; decrypt fails when it is empty.
encryptKey string
// Tenant access token cache. tokenMu guards tokenCache and tokenExpAt;
// tokenExpAt is the time after which the cached token is no longer reused.
tokenMu sync.Mutex
tokenCache string
tokenExpAt time.Time
}
// NewAdapter builds a Feishu adapter from the app credentials and the
// event-subscription secrets. verificationToken and encryptKey may be empty
// when callbacks are neither verified nor encrypted.
func NewAdapter(appID, appSecret, verificationToken, encryptKey string) *Adapter {
	adapter := new(Adapter)
	adapter.appID = appID
	adapter.appSecret = appSecret
	adapter.verificationToken = verificationToken
	adapter.encryptKey = encryptKey
	return adapter
}
// Platform returns the platform identifier, which is always im.PlatformFeishu
// for this adapter.
func (a *Adapter) Platform() im.Platform {
return im.PlatformFeishu
}
// VerifyCallback checks that the event callback carries the configured
// verification token in its header.
//
// When no verification token is configured (e.g., WebSocket mode) the check is
// skipped and nil is returned without touching the request body. Otherwise the
// body is read, decrypted when it is an encrypted envelope, and the
// header token is compared. The request body is restored before returning so
// that ParseCallback can read it again.
func (a *Adapter) VerifyCallback(c *gin.Context) error {
	if a.verificationToken == "" {
		return nil
	}

	body, readErr := io.ReadAll(c.Request.Body)
	if readErr != nil {
		return fmt.Errorf("read body: %w", readErr)
	}
	// Put the body back whatever the outcome, for subsequent readers.
	defer func() { c.Request.Body = io.NopCloser(bytes.NewReader(body)) }()

	// An encrypted callback wraps the real event in {"encrypt": "..."}.
	payload := body
	var envelope struct {
		Encrypt string `json:"encrypt"`
	}
	if json.Unmarshal(body, &envelope) == nil && envelope.Encrypt != "" {
		plain, decErr := a.decrypt(envelope.Encrypt)
		if decErr != nil {
			return fmt.Errorf("decrypt event for verification: %w", decErr)
		}
		payload = plain
	}

	var parsed struct {
		Header *feishuEventHeader `json:"header"`
	}
	if err := json.Unmarshal(payload, &parsed); err != nil {
		return fmt.Errorf("unmarshal event header: %w", err)
	}

	if parsed.Header != nil && parsed.Header.Token == a.verificationToken {
		return nil
	}
	return fmt.Errorf("invalid verification token")
}
// HandleURLVerification handles the Feishu URL verification challenge.
//
// It reads the request body (restoring it afterwards so later readers still
// see it), decrypts it when it is an encrypted envelope, and looks for a string
// "challenge" field. When found, the challenge is echoed back with HTTP 200 and
// true is returned.
//
// When a verification token is configured, the challenge is only answered if
// the request's "token" field matches it; otherwise the request is treated as
// not handled so that the regular callback verification can reject it.
//
// Returns false, without writing a response, when the body cannot be read,
// decrypted or parsed, when it is not a challenge request, or when the token
// does not match.
func (a *Adapter) HandleURLVerification(c *gin.Context) bool {
	bodyBytes, err := io.ReadAll(c.Request.Body)
	if err != nil {
		return false
	}
	// Restore the body once, up front; nothing below consumes it again.
	c.Request.Body = io.NopCloser(bytes.NewReader(bodyBytes))

	// If the body is an encrypted envelope, work on the decrypted payload.
	payload := bodyBytes
	var encryptedBody struct {
		Encrypt string `json:"encrypt"`
	}
	if err := json.Unmarshal(bodyBytes, &encryptedBody); err == nil && encryptedBody.Encrypt != "" {
		decrypted, err := a.decrypt(encryptedBody.Encrypt)
		if err != nil {
			logger.Errorf(c.Request.Context(), "[Feishu] Failed to decrypt: %v", err)
			return false
		}
		payload = decrypted
	}

	var body map[string]interface{}
	if err := json.Unmarshal(payload, &body); err != nil {
		return false
	}

	// Only requests carrying a string challenge are verification requests.
	challenge, ok := body["challenge"].(string)
	if !ok {
		return false
	}

	// Do not echo challenges for requests that fail the token check.
	if a.verificationToken != "" {
		if token, _ := body["token"].(string); token != a.verificationToken {
			logger.Warnf(c.Request.Context(), "[Feishu] URL verification rejected: verification token mismatch")
			return false
		}
	}

	c.JSON(http.StatusOK, gin.H{"challenge": challenge})
	return true
}
// feishuEventBody is the typed structure of a Feishu event callback.
// Only the fields this package reads are declared.
type feishuEventBody struct {
Header *feishuEventHeader `json:"header"`
Event *feishuEvent `json:"event"`
}
// feishuEventHeader carries the event type and the verification token of a callback.
type feishuEventHeader struct {
EventType string `json:"event_type"`
Token string `json:"token"`
}
// feishuEvent is the event payload: the message and who sent it.
type feishuEvent struct {
Message *feishuMessage `json:"message"`
Sender *feishuSender `json:"sender"`
}
// feishuMessage describes the received message. Content is itself a JSON
// document encoded as a string (for text messages: {"text": "..."}).
type feishuMessage struct {
MessageID string `json:"message_id"`
MessageType string `json:"message_type"`
ChatType string `json:"chat_type"`
ChatID string `json:"chat_id"`
Content string `json:"content"`
}
// feishuSender wraps the sender's identifiers.
type feishuSender struct {
SenderID *feishuSenderID `json:"sender_id"`
}
// feishuSenderID holds the sender's open_id, which is used as the unified UserID.
type feishuSenderID struct {
OpenID string `json:"open_id"`
}
// ParseCallback converts a Feishu event callback into a unified
// im.IncomingMessage.
//
// The body is decrypted first when it is an encrypted envelope. A nil message
// with a nil error is returned for anything that is not a text
// "im.message.receive_v1" event. For group chats the ChatID is filled in and
// leading "@_user_..." mention placeholders are removed from the text.
//
// Token verification is the job of VerifyCallback and is not repeated here.
func (a *Adapter) ParseCallback(c *gin.Context) (*im.IncomingMessage, error) {
	body, readErr := io.ReadAll(c.Request.Body)
	if readErr != nil {
		return nil, fmt.Errorf("read body: %w", readErr)
	}

	// An encrypted callback wraps the real event in {"encrypt": "..."}.
	payload := body
	var envelope struct {
		Encrypt string `json:"encrypt"`
	}
	if json.Unmarshal(body, &envelope) == nil && envelope.Encrypt != "" {
		plain, decErr := a.decrypt(envelope.Encrypt)
		if decErr != nil {
			return nil, fmt.Errorf("decrypt event: %w", decErr)
		}
		payload = plain
	}

	var parsed feishuEventBody
	if err := json.Unmarshal(payload, &parsed); err != nil {
		return nil, fmt.Errorf("unmarshal event: %w", err)
	}

	// Only message-received events are of interest.
	header := parsed.Header
	if header == nil {
		return nil, nil
	}
	if header.EventType != "im.message.receive_v1" {
		logger.Infof(c.Request.Context(), "[Feishu] Ignoring event type: %s", header.EventType)
		return nil, nil
	}

	if parsed.Event == nil || parsed.Event.Message == nil {
		return nil, nil
	}
	message := parsed.Event.Message
	if message.MessageType != "text" {
		logger.Infof(c.Request.Context(), "[Feishu] Ignoring non-text message type: %s", message.MessageType)
		return nil, nil
	}

	// The message content is a JSON string of the form {"text": "..."}.
	var textBody struct {
		Text string `json:"text"`
	}
	if err := json.Unmarshal([]byte(message.Content), &textBody); err != nil {
		return nil, fmt.Errorf("unmarshal text content: %w", err)
	}

	incoming := &im.IncomingMessage{
		Platform:  im.PlatformFeishu,
		ChatType:  im.ChatTypeDirect,
		MessageID: message.MessageID,
	}
	if sender := parsed.Event.Sender; sender != nil && sender.SenderID != nil {
		incoming.UserID = sender.SenderID.OpenID
	}

	text := textBody.Text
	if message.ChatType == "group" {
		incoming.ChatType = im.ChatTypeGroup
		incoming.ChatID = message.ChatID

		// Drop leading mention placeholders (format: @_user_xxx) one at a time.
		for strings.HasPrefix(text, "@_user_") {
			sep := strings.Index(text, " ")
			if sep < 0 {
				break
			}
			text = text[sep+1:]
		}
	}
	incoming.Content = strings.TrimSpace(text)

	return incoming, nil
}
// SendReply posts reply.Content as a plain text message through the Feishu
// "create message" API.
//
// The reply goes to the group chat (by chat_id) when the incoming message came
// from a group and carries a chat ID; otherwise it goes to the sender (by
// open_id). An error is returned when the access token cannot be obtained, the
// HTTP call fails, the response cannot be decoded, or Feishu reports a
// non-zero code.
func (a *Adapter) SendReply(ctx context.Context, incoming *im.IncomingMessage, reply *im.ReplyMessage) error {
	token, err := a.getTenantAccessToken(ctx)
	if err != nil {
		return fmt.Errorf("get access token: %w", err)
	}

	// Pick the addressing scheme: sender by default, chat for group messages.
	idType, target := "open_id", incoming.UserID
	if incoming.ChatType == im.ChatTypeGroup && incoming.ChatID != "" {
		idType, target = "chat_id", incoming.ChatID
	}

	// The API expects "content" to be a JSON document encoded as a string.
	textJSON, _ := json.Marshal(map[string]string{"text": reply.Content})
	reqBody, _ := json.Marshal(map[string]interface{}{
		"receive_id": target,
		"msg_type":   "text",
		"content":    string(textJSON),
	})

	endpoint := "https://open.feishu.cn/open-apis/im/v1/messages?receive_id_type=" + idType
	req, err := http.NewRequestWithContext(ctx, http.MethodPost, endpoint, bytes.NewReader(reqBody))
	if err != nil {
		return fmt.Errorf("create request: %w", err)
	}
	req.Header.Set("Content-Type", "application/json; charset=utf-8")
	req.Header.Set("Authorization", "Bearer "+token)

	resp, err := httpClient.Do(req)
	if err != nil {
		return fmt.Errorf("send message: %w", err)
	}
	defer resp.Body.Close()

	var apiResult struct {
		Code int    `json:"code"`
		Msg  string `json:"msg"`
	}
	if decodeErr := json.NewDecoder(resp.Body).Decode(&apiResult); decodeErr != nil {
		return fmt.Errorf("decode response: %w", decodeErr)
	}
	if apiResult.Code == 0 {
		return nil
	}
	return fmt.Errorf("feishu api error: code=%d msg=%s", apiResult.Code, apiResult.Msg)
}
// getTenantAccessToken returns the Feishu tenant access token, reusing the
// cached one while it is still considered valid.
//
// The token mutex is held for the whole call, including the HTTP request, so
// concurrent callers wait for a single refresh. A freshly fetched token is
// cached for its reported lifetime minus a 5-minute safety margin (the margin
// is only applied when the lifetime is longer than 5 minutes).
func (a *Adapter) getTenantAccessToken(ctx context.Context) (string, error) {
	a.tokenMu.Lock()
	defer a.tokenMu.Unlock()

	if cached := a.tokenCache; cached != "" && time.Now().Before(a.tokenExpAt) {
		return cached, nil
	}

	credentials, _ := json.Marshal(map[string]string{
		"app_id":     a.appID,
		"app_secret": a.appSecret,
	})
	const endpoint = "https://open.feishu.cn/open-apis/auth/v3/tenant_access_token/internal"
	req, err := http.NewRequestWithContext(ctx, http.MethodPost, endpoint, bytes.NewReader(credentials))
	if err != nil {
		return "", fmt.Errorf("create request: %w", err)
	}
	req.Header.Set("Content-Type", "application/json; charset=utf-8")

	resp, err := httpClient.Do(req)
	if err != nil {
		return "", fmt.Errorf("request token: %w", err)
	}
	defer resp.Body.Close()

	var apiResult struct {
		Code              int    `json:"code"`
		Msg               string `json:"msg"`
		TenantAccessToken string `json:"tenant_access_token"`
		Expire            int    `json:"expire"` // seconds
	}
	if decodeErr := json.NewDecoder(resp.Body).Decode(&apiResult); decodeErr != nil {
		return "", fmt.Errorf("decode response: %w", decodeErr)
	}
	if apiResult.Code != 0 {
		return "", fmt.Errorf("get token error: code=%d msg=%s", apiResult.Code, apiResult.Msg)
	}

	a.tokenCache = apiResult.TenantAccessToken

	// Expire the cache a little early so a token is never used right at its limit.
	lifetime := time.Duration(apiResult.Expire) * time.Second
	if margin := 5 * time.Minute; lifetime > margin {
		lifetime -= margin
	}
	a.tokenExpAt = time.Now().Add(lifetime)

	return a.tokenCache, nil
}
// decrypt decrypts a Feishu encrypted event body.
//
// encrypted is the base64 (standard encoding) value of the callback's
// "encrypt" field. After decoding, the first AES block is the IV and the rest
// is the ciphertext, which is decrypted with AES-256-CBC using the SHA-256 of
// the encrypt key as the AES key. PKCS#7 padding is validated and removed.
//
// The input comes from an HTTP request, so every malformed shape is reported
// as an error: missing encrypt key, invalid base64, data shorter than one
// block, an empty ciphertext, a ciphertext that is not a whole number of
// blocks, or invalid padding.
func (a *Adapter) decrypt(encrypted string) ([]byte, error) {
	if a.encryptKey == "" {
		return nil, fmt.Errorf("encrypt_key not configured")
	}

	raw, err := base64.StdEncoding.DecodeString(encrypted)
	if err != nil {
		return nil, fmt.Errorf("base64 decode: %w", err)
	}

	// SHA-256 of encrypt key as AES key
	keyHash := sha256.Sum256([]byte(a.encryptKey))
	block, err := aes.NewCipher(keyHash[:])
	if err != nil {
		return nil, fmt.Errorf("new cipher: %w", err)
	}

	if len(raw) < aes.BlockSize {
		return nil, fmt.Errorf("ciphertext too short")
	}
	iv := raw[:aes.BlockSize]
	ciphertext := raw[aes.BlockSize:]

	if len(ciphertext) == 0 {
		return nil, fmt.Errorf("empty plaintext")
	}
	// CryptBlocks panics on input that is not a multiple of the block size, so
	// reject such input here instead of letting a crafted request crash the handler.
	if len(ciphertext)%aes.BlockSize != 0 {
		return nil, fmt.Errorf("ciphertext is not a multiple of the block size")
	}

	// Decrypt in place; raw is a private copy produced by the base64 decoder.
	cipher.NewCBCDecrypter(block, iv).CryptBlocks(ciphertext, ciphertext)

	// Remove and verify PKCS#7 padding: the last byte gives the pad length and
	// every padding byte must carry that same value.
	padLen := int(ciphertext[len(ciphertext)-1])
	if padLen > aes.BlockSize || padLen == 0 || padLen > len(ciphertext) {
		return nil, fmt.Errorf("invalid padding")
	}
	for i := 0; i < padLen; i++ {
		if ciphertext[len(ciphertext)-1-i] != byte(padLen) {
			return nil, fmt.Errorf("invalid padding")
		}
	}

	return ciphertext[:len(ciphertext)-padLen], nil
}

View File

@@ -0,0 +1,149 @@
package feishu
import (
"context"
"encoding/json"
"fmt"
"strings"
"github.com/Tencent/WeKnora/internal/im"
"github.com/Tencent/WeKnora/internal/logger"
"github.com/larksuite/oapi-sdk-go/v3/event/dispatcher"
larkim "github.com/larksuite/oapi-sdk-go/v3/service/im/v1"
larkws "github.com/larksuite/oapi-sdk-go/v3/ws"
)
// MessageHandler is called when an IM message is received via long connection.
// It receives the already-converted unified message; the error it returns is
// passed back to the Feishu SDK event dispatcher.
type MessageHandler func(ctx context.Context, msg *im.IncomingMessage) error
// LongConnClient manages a Feishu WebSocket long connection.
type LongConnClient struct {
// appID is kept only for log messages.
appID string
// wsClient is the Feishu SDK WebSocket client that does the actual work.
wsClient *larkws.Client
}
// NewLongConnClient builds a Feishu long connection client for the given app.
//
// Received message events are converted to im.IncomingMessage and handed to
// handler; events that do not convert (non-text messages, malformed content)
// are silently acknowledged. The underlying SDK client is configured with
// auto-reconnect and logs through this package's logger bridge.
func NewLongConnClient(appID, appSecret string, handler MessageHandler) *LongConnClient {
	onReceive := func(ctx context.Context, event *larkim.P2MessageReceiveV1) error {
		if incoming := convertEvent(event); incoming != nil {
			return handler(ctx, incoming)
		}
		return nil
	}

	// Long connection mode does not require verificationToken or encryptKey;
	// those are only used for webhook signature verification and decryption.
	events := dispatcher.NewEventDispatcher("", "").OnP2MessageReceiveV1(onReceive)

	client := larkws.NewClient(appID, appSecret,
		larkws.WithEventHandler(events),
		larkws.WithAutoReconnect(true),
		larkws.WithLogger(&feishuLoggerAdapter{appID: appID}),
	)

	return &LongConnClient{appID: appID, wsClient: client}
}
// Start logs the connection attempt and then starts the WebSocket long
// connection by delegating to the Feishu SDK client, returning its error.
// NOTE(review): presumably this blocks until ctx is cancelled or the SDK gives
// up on the connection — confirm against the SDK's ws.Client.Start.
func (c *LongConnClient) Start(ctx context.Context) error {
logger.Infof(ctx, "[IM] Feishu WebSocket connecting (app_id=%s)...", c.appID)
return c.wsClient.Start(ctx)
}
// feishuLoggerAdapter routes the Feishu SDK's log output through the project
// logger with a "[Feishu]" prefix. The SDK's raw "connected to ..." message is
// replaced by a uniform "[IM] Feishu WebSocket connected successfully" line.
type feishuLoggerAdapter struct {
	appID string
}

// Debug forwards an SDK debug message.
func (l *feishuLoggerAdapter) Debug(ctx context.Context, args ...interface{}) {
	text := fmt.Sprint(args...)
	logger.Debugf(ctx, "[Feishu] %s", text)
}

// Info forwards an SDK info message, rewriting the connection-established one.
func (l *feishuLoggerAdapter) Info(ctx context.Context, args ...interface{}) {
	text := fmt.Sprint(args...)
	if !strings.HasPrefix(text, "connected to ") {
		logger.Infof(ctx, "[Feishu] %s", text)
		return
	}
	logger.Infof(ctx, "[IM] Feishu WebSocket connected successfully (app_id=%s)", l.appID)
}

// Warn forwards an SDK warning.
func (l *feishuLoggerAdapter) Warn(ctx context.Context, args ...interface{}) {
	text := fmt.Sprint(args...)
	logger.Warnf(ctx, "[Feishu] %s", text)
}

// Error forwards an SDK error message.
func (l *feishuLoggerAdapter) Error(ctx context.Context, args ...interface{}) {
	text := fmt.Sprint(args...)
	logger.Errorf(ctx, "[Feishu] %s", text)
}
// convertEvent turns a Feishu SDK message-received event into a unified
// im.IncomingMessage.
//
// It returns nil when the event or its message is missing, when the message is
// not of type "text", or when the content is absent or not valid JSON. For
// group chats the chat ID is filled in and leading "@_user_..." mention
// placeholders are removed from the text.
func convertEvent(event *larkim.P2MessageReceiveV1) *im.IncomingMessage {
	if event == nil || event.Event == nil {
		return nil
	}
	payload := event.Event

	message := payload.Message
	if message == nil {
		return nil
	}
	if message.MessageType == nil || *message.MessageType != "text" {
		return nil
	}
	if message.Content == nil {
		return nil
	}

	// Text content arrives as JSON: {"text": "..."}
	var textBody struct {
		Text string `json:"text"`
	}
	if json.Unmarshal([]byte(*message.Content), &textBody) != nil {
		return nil
	}

	incoming := &im.IncomingMessage{
		Platform: im.PlatformFeishu,
		ChatType: im.ChatTypeDirect,
	}

	if sender := payload.Sender; sender != nil && sender.SenderId != nil && sender.SenderId.OpenId != nil {
		incoming.UserID = *sender.SenderId.OpenId
	}
	if message.MessageId != nil {
		incoming.MessageID = *message.MessageId
	}

	text := textBody.Text
	if message.ChatType != nil && *message.ChatType == "group" {
		incoming.ChatType = im.ChatTypeGroup
		if message.ChatId != nil {
			incoming.ChatID = *message.ChatId
		}

		// Drop leading mention placeholders one at a time.
		for strings.HasPrefix(text, "@_user_") {
			sep := strings.Index(text, " ")
			if sep < 0 {
				break
			}
			text = text[sep+1:]
		}
	}
	incoming.Content = strings.TrimSpace(text)

	return incoming
}

378
internal/im/service.go Normal file
View File

@@ -0,0 +1,378 @@
package im
import (
	"context"
	"errors"
	"fmt"
	"strings"
	"sync"
	"time"

	"github.com/Tencent/WeKnora/internal/event"
	"github.com/Tencent/WeKnora/internal/logger"
	"github.com/Tencent/WeKnora/internal/types"
	"github.com/Tencent/WeKnora/internal/types/interfaces"
	"github.com/google/uuid"
	"gorm.io/gorm"
)
// Tunables for the IM service.
const (
// qaTimeout is the maximum time to wait for the QA pipeline to complete.
qaTimeout = 120 * time.Second
// dedupTTL is how long processed message IDs are retained.
dedupTTL = 5 * time.Minute
// dedupCleanupInterval is how often the dedup map is cleaned.
dedupCleanupInterval = 1 * time.Minute
// maxContentLength is the maximum allowed message content length, counted in
// runes; HandleMessage truncates longer content to this many runes.
maxContentLength = 4096
)
// Service orchestrates IM message handling:
// 1. Receives a unified IncomingMessage from an Adapter
// 2. Resolves or creates a WeKnora session for the IM channel
// 3. Calls the WeKnora QA pipeline
// 4. Collects the streaming answer and sends it back via the Adapter
type Service struct {
// db is used to look up and create ChannelSession rows.
db *gorm.DB
// sessionService creates/loads sessions and runs the QA pipelines.
sessionService interfaces.SessionService
// messageService persists the user and assistant messages of each exchange.
messageService interfaces.MessageService
// tenantService loads the tenant that an incoming message is handled under.
tenantService interfaces.TenantService
// agentService resolves the optional custom agent by ID.
agentService interfaces.CustomAgentService
// adapters maps each platform to its registered adapter.
adapters map[Platform]Adapter
// mu guards adapters.
mu sync.RWMutex
// processedMsgs tracks recently processed message IDs to prevent duplicate handling.
// Values are the time.Time at which the ID was first seen.
processedMsgs sync.Map
// stopCh is closed by Stop to terminate the dedup cleanup goroutine.
stopCh chan struct{}
}
// NewService wires an IM Service around the given database handle and
// collaborating services, starts the background dedup cleanup goroutine, and
// returns the service. The goroutine runs until Stop is called.
func NewService(
	db *gorm.DB,
	sessionService interfaces.SessionService,
	messageService interfaces.MessageService,
	tenantService interfaces.TenantService,
	agentService interfaces.CustomAgentService,
) *Service {
	svc := new(Service)
	svc.db = db
	svc.sessionService = sessionService
	svc.messageService = messageService
	svc.tenantService = tenantService
	svc.agentService = agentService
	svc.adapters = map[Platform]Adapter{}
	svc.stopCh = make(chan struct{})

	// One long-lived cleanup goroutine instead of one goroutine per message.
	go svc.dedupCleanupLoop()

	return svc
}
// Stop gracefully shuts down the service by closing stopCh, which terminates
// the dedup cleanup goroutine.
//
// Stop is idempotent: closing an already-closed channel panics, so the channel
// is only closed when it is still open. The check-and-close runs under s.mu so
// concurrent callers cannot both reach close.
func (s *Service) Stop() {
	s.mu.Lock()
	defer s.mu.Unlock()

	select {
	case <-s.stopCh:
		// Already stopped; nothing to do.
	default:
		close(s.stopCh)
	}
}
// dedupCleanupLoop runs until stopCh is closed. On every tick of
// dedupCleanupInterval it removes entries of processedMsgs whose stored
// timestamp is older than dedupTTL.
func (s *Service) dedupCleanupLoop() {
	ticker := time.NewTicker(dedupCleanupInterval)
	defer ticker.Stop()

	evictExpired := func() {
		cutoff := time.Now().Add(-dedupTTL)
		s.processedMsgs.Range(func(id, storedAt interface{}) bool {
			seenAt, isTime := storedAt.(time.Time)
			if isTime && seenAt.Before(cutoff) {
				s.processedMsgs.Delete(id)
			}
			// Always keep iterating; one entry never aborts the sweep.
			return true
		})
	}

	for {
		select {
		case <-s.stopCh:
			return
		case <-ticker.C:
			evictExpired()
		}
	}
}
// RegisterAdapter stores adapter under the platform it reports, replacing any
// adapter previously registered for that platform.
func (s *Service) RegisterAdapter(adapter Adapter) {
	s.mu.Lock()
	defer s.mu.Unlock()

	platform := adapter.Platform()
	s.adapters[platform] = adapter
}
// GetAdapter looks up the adapter registered for platform. The boolean result
// reports whether an adapter was found.
func (s *Service) GetAdapter(platform Platform) (Adapter, bool) {
	s.mu.RLock()
	adapter, found := s.adapters[platform]
	s.mu.RUnlock()
	return adapter, found
}
// HandleMessage processes an incoming IM message end-to-end:
// resolves session, runs QA, sends reply.
//
// Parameters:
//   - msg: the unified incoming message; must not be nil. Its Content is
//     truncated in place to maxContentLength runes when longer.
//   - tenantID: the WeKnora tenant the message is handled under.
//   - agentID: optional custom agent ID; when it cannot be loaded the default
//     pipeline is used.
//   - kbIDs: knowledge bases passed to the QA pipeline.
//
// It returns nil for duplicate messages (same platform and message ID seen
// within the dedup window). When the QA pipeline fails, a fallback apology is
// sent to the user and the error is only logged. Errors are returned for a nil
// message, a missing adapter, tenant/session failures, and reply delivery
// failures.
func (s *Service) HandleMessage(ctx context.Context, msg *IncomingMessage, tenantID uint64, agentID string, kbIDs []string) error {
	if msg == nil {
		return fmt.Errorf("nil incoming message")
	}

	// Resolve the adapter first: without one the answer cannot be delivered,
	// so there is no point in creating sessions or running the QA pipeline.
	adapter, ok := s.GetAdapter(msg.Platform)
	if !ok {
		return fmt.Errorf("no adapter for platform: %s", msg.Platform)
	}

	// Dedup: skip if this message was already processed (IM platforms may retry).
	// The key is scoped by platform so IDs from different platforms cannot collide.
	if msg.MessageID != "" {
		dedupKey := string(msg.Platform) + ":" + msg.MessageID
		if _, loaded := s.processedMsgs.LoadOrStore(dedupKey, time.Now()); loaded {
			logger.Infof(ctx, "[IM] Skipping duplicate message: %s", msg.MessageID)
			return nil
		}
	}

	// Truncate overly long messages (counted in runes) to protect the QA pipeline.
	contentRunes := []rune(msg.Content)
	if len(contentRunes) > maxContentLength {
		logger.Warnf(ctx, "[IM] Message too long (%d runes), truncating to %d", len(contentRunes), maxContentLength)
		msg.Content = string(contentRunes[:maxContentLength])
	}

	logger.Infof(ctx, "[IM] HandleMessage: platform=%s user=%s chat=%s content_len=%d",
		msg.Platform, msg.UserID, msg.ChatID, len(msg.Content))

	// 1. Get tenant (once, shared across resolve + QA)
	tenant, err := s.tenantService.GetTenantByID(ctx, tenantID)
	if err != nil {
		return fmt.Errorf("get tenant: %w", err)
	}
	sessionCtx := context.WithValue(ctx, types.TenantIDContextKey, tenantID)
	sessionCtx = context.WithValue(sessionCtx, types.TenantInfoContextKey, tenant)

	// 2. Resolve or create a WeKnora session
	channelSession, err := s.resolveSession(sessionCtx, msg, tenantID, agentID)
	if err != nil {
		return fmt.Errorf("resolve session: %w", err)
	}

	// 3. Get the WeKnora session
	session, err := s.sessionService.GetSession(sessionCtx, channelSession.SessionID)
	if err != nil {
		return fmt.Errorf("get session: %w", err)
	}

	// 4. Resolve custom agent (optional); a lookup failure falls back to the default.
	var customAgent *types.CustomAgent
	if agentID != "" {
		agent, agentErr := s.agentService.GetAgentByID(sessionCtx, agentID)
		if agentErr != nil {
			logger.Warnf(ctx, "[IM] Failed to get agent %s: %v, using default", agentID, agentErr)
		} else {
			customAgent = agent
		}
	}

	// 5. Run the QA pipeline and collect the full answer
	answer, err := s.runQA(sessionCtx, session, msg.Content, customAgent, kbIDs)
	if err != nil {
		logger.Errorf(ctx, "[IM] QA failed: %v, sending fallback reply", err)
		answer = "抱歉,处理您的问题时出现了异常,请稍后再试。"
	}

	// 6. Send the reply back via the platform adapter
	reply := &ReplyMessage{
		Content: answer,
		IsFinal: true,
	}
	if err := adapter.SendReply(ctx, msg, reply); err != nil {
		return fmt.Errorf("send reply: %w", err)
	}

	logger.Infof(ctx, "[IM] Reply sent: platform=%s user=%s answer_len=%d", msg.Platform, msg.UserID, len(answer))
	return nil
}
// resolveSession finds or creates a ChannelSession for the given IM message.
// ctx must already carry TenantIDContextKey and TenantInfoContextKey.
//
// The mapping is keyed by (platform, user_id, chat_id, tenant_id) among
// non-deleted rows. When no mapping exists, a new WeKnora session is created
// and a mapping row pointing at it is inserted. If that insert fails, the
// mapping is looked up once more so that a concurrent creator's row is reused.
//
// All database calls are bound to ctx so they observe its cancellation and
// deadline, and the not-found check uses errors.Is so a wrapped
// gorm.ErrRecordNotFound is still recognised.
func (s *Service) resolveSession(ctx context.Context, msg *IncomingMessage, tenantID uint64, agentID string) (*ChannelSession, error) {
	db := s.db.WithContext(ctx)

	// lookup fetches the live mapping for this channel, if any.
	lookup := func() (*ChannelSession, error) {
		var found ChannelSession
		err := db.Where("platform = ? AND user_id = ? AND chat_id = ? AND tenant_id = ? AND deleted_at IS NULL",
			string(msg.Platform), msg.UserID, msg.ChatID, tenantID).
			First(&found).Error
		if err != nil {
			return nil, err
		}
		return &found, nil
	}

	existing, err := lookup()
	if err == nil {
		return existing, nil
	}
	if !errors.Is(err, gorm.ErrRecordNotFound) {
		return nil, fmt.Errorf("query channel session: %w", err)
	}

	// Create a new WeKnora session
	title := fmt.Sprintf("IM-%s", msg.Platform)
	if msg.UserName != "" {
		title = fmt.Sprintf("IM-%s-%s", msg.Platform, msg.UserName)
	}
	newSession := &types.Session{
		TenantID:    tenantID,
		Title:       title,
		Description: fmt.Sprintf("Auto-created from %s IM integration", msg.Platform),
	}
	createdSession, err := s.sessionService.CreateSession(ctx, newSession)
	if err != nil {
		return nil, fmt.Errorf("create session: %w", err)
	}

	// Create the channel-session mapping; use a unique constraint fallback
	// to handle concurrent creation attempts for the same channel.
	cs := ChannelSession{
		Platform:  string(msg.Platform),
		UserID:    msg.UserID,
		ChatID:    msg.ChatID,
		SessionID: createdSession.ID,
		TenantID:  tenantID,
		AgentID:   agentID,
	}
	if createErr := db.Create(&cs).Error; createErr != nil {
		// If the insert failed due to unique constraint (concurrent request),
		// fetch the existing record.
		// NOTE(review): in that case createdSession is left without a mapping;
		// consider deleting it once a suitable session-service call is confirmed.
		winner, findErr := lookup()
		if findErr != nil {
			return nil, fmt.Errorf("create channel session: %w (lookup fallback: %v)", createErr, findErr)
		}
		return winner, nil
	}

	logger.Infof(ctx, "[IM] Created new session mapping: channel=%s/%s/%s -> session=%s",
		msg.Platform, msg.UserID, msg.ChatID, createdSession.ID)
	return &cs, nil
}
// runQA executes the WeKnora QA pipeline and returns the full answer text.
//
// It persists the user message and a placeholder assistant message (paired by
// a shared request ID), runs AgentQA or KnowledgeQA in a goroutine depending
// on whether customAgent is in agent mode, and accumulates the content of
// final-answer events until one is marked Done, an error event arrives, the
// pipeline call itself fails, or qaTimeout elapses.
//
// The assistant message is always marked completed before returning: with the
// answer on success, with a timeout notice on timeout, and with a failure
// notice when the pipeline errors without producing any text. This keeps the
// conversation history free of permanently incomplete records.
//
// An error is returned when persisting either message fails, on timeout, and
// when the pipeline reports an error without any answer text. When the
// pipeline finishes without text and without error, a default apology is
// returned as the answer.
func (s *Service) runQA(ctx context.Context, session *types.Session, query string, customAgent *types.CustomAgent, kbIDs []string) (string, error) {
	// Add timeout to prevent indefinite blocking
	ctx, cancel := context.WithTimeout(ctx, qaTimeout)
	defer cancel()

	eventBus := event.NewEventBus()

	// Thread-safe answer collection: handlers and the QA goroutine may run
	// concurrently with this function.
	var answerMu sync.Mutex
	var answerBuilder strings.Builder
	var qaErr error

	// done is closed exactly once, by whichever completion signal fires first.
	done := make(chan struct{})
	var closeOnce sync.Once
	closeDone := func() { closeOnce.Do(func() { close(done) }) }

	eventBus.On(event.EventAgentFinalAnswer, func(ctx context.Context, evt event.Event) error {
		data, ok := evt.Data.(event.AgentFinalAnswerData)
		if !ok {
			return nil
		}
		answerMu.Lock()
		answerBuilder.WriteString(data.Content)
		answerMu.Unlock()
		if data.Done {
			closeDone()
		}
		return nil
	})

	eventBus.On(event.EventError, func(ctx context.Context, evt event.Event) error {
		data, ok := evt.Data.(event.ErrorData)
		if !ok {
			return nil
		}
		logger.Errorf(ctx, "[IM] QA error: %s", data.Error)
		answerMu.Lock()
		qaErr = fmt.Errorf("QA pipeline error: %s", data.Error)
		answerMu.Unlock()
		closeDone()
		return nil
	})

	// Determine whether to use agent mode
	useAgent := customAgent != nil && customAgent.IsAgentMode()

	// Generate a shared RequestID to pair user and assistant messages for history
	requestID := uuid.New().String()

	// Create user message so it appears in conversation history
	if _, err := s.messageService.CreateMessage(ctx, &types.Message{
		SessionID:   session.ID,
		Role:        "user",
		Content:     query,
		RequestID:   requestID,
		CreatedAt:   time.Now(),
		IsCompleted: true,
	}); err != nil {
		return "", fmt.Errorf("create user message: %w", err)
	}

	// Create a placeholder assistant message
	assistantMsg, err := s.messageService.CreateMessage(ctx, &types.Message{
		SessionID:   session.ID,
		Role:        "assistant",
		RequestID:   requestID,
		CreatedAt:   time.Now(),
		IsCompleted: false,
	})
	if err != nil {
		return "", fmt.Errorf("create assistant message: %w", err)
	}

	// Run QA async
	go func() {
		var runErr error
		if useAgent {
			runErr = s.sessionService.AgentQA(ctx, session, query, assistantMsg.ID, "", eventBus, customAgent, kbIDs, nil)
		} else {
			runErr = s.sessionService.KnowledgeQA(ctx, session, query, kbIDs, nil, assistantMsg.ID, "", false, eventBus, customAgent, false)
		}
		if runErr != nil {
			logger.Errorf(ctx, "[IM] QA execution error: %v", runErr)
			answerMu.Lock()
			qaErr = fmt.Errorf("QA execution error: %w", runErr)
			answerMu.Unlock()
			closeDone()
		}
	}()

	// Wait for completion or timeout
	select {
	case <-done:
	case <-ctx.Done():
		// Mark assistant message as completed to avoid dangling incomplete records
		assistantMsg.Content = "抱歉,回答超时,请稍后再试。"
		assistantMsg.IsCompleted = true
		// Use a fresh context since the original is cancelled
		if updateErr := s.messageService.UpdateMessage(context.WithoutCancel(ctx), assistantMsg); updateErr != nil {
			logger.Warnf(ctx, "[IM] Failed to update timed-out assistant message: %v", updateErr)
		}
		return "", fmt.Errorf("QA timed out after %v", qaTimeout)
	}

	answerMu.Lock()
	answer := answerBuilder.String()
	qaError := qaErr
	answerMu.Unlock()

	if answer == "" && qaError != nil {
		// Close out the placeholder as the timeout branch does, so a failed
		// exchange does not leave an incomplete assistant record behind. The
		// update is detached from ctx because the failure may have been caused
		// by its cancellation.
		assistantMsg.Content = "抱歉,处理您的问题时出现了异常,请稍后再试。"
		assistantMsg.IsCompleted = true
		if updateErr := s.messageService.UpdateMessage(context.WithoutCancel(ctx), assistantMsg); updateErr != nil {
			logger.Warnf(ctx, "[IM] Failed to update failed assistant message: %v", updateErr)
		}
		return "", qaError
	}
	if answer == "" {
		answer = "抱歉,我暂时无法回答这个问题。"
	}

	// Update assistant message with the final answer content
	assistantMsg.Content = answer
	assistantMsg.IsCompleted = true
	if err := s.messageService.UpdateMessage(ctx, assistantMsg); err != nil {
		logger.Warnf(ctx, "[IM] Failed to update assistant message: %v", err)
	}

	return answer, nil
}

40
internal/im/types.go Normal file
View File

@@ -0,0 +1,40 @@
package im
import (
"time"
"github.com/Tencent/WeKnora/internal/types"
"github.com/google/uuid"
"gorm.io/gorm"
)
// ChannelSession maps an IM channel (user+chat combination) to a WeKnora session.
// This allows the IM integration to maintain conversation continuity.
// Rows are stored in the table named by TableName.
type ChannelSession struct {
// ID is the primary key; BeforeCreate fills it with a UUID when empty.
ID string `json:"id" gorm:"type:varchar(36);primaryKey;default:uuid_generate_v4()"`
// Platform is the IM platform identifier.
Platform string `json:"platform" gorm:"type:varchar(20);not null"`
// UserID is the platform-side user identifier.
UserID string `json:"user_id" gorm:"type:varchar(128);not null"`
// ChatID is the platform-side chat identifier; it defaults to the empty string.
ChatID string `json:"chat_id" gorm:"type:varchar(128);not null;default:''"`
// SessionID is the WeKnora session this channel is mapped to.
SessionID string `json:"session_id" gorm:"type:varchar(36);not null;index"`
// TenantID is the WeKnora tenant that owns the mapping.
TenantID uint64 `json:"tenant_id" gorm:"not null;index"`
// AgentID is the custom agent recorded for the mapping, if any.
AgentID string `json:"agent_id" gorm:"type:varchar(36);default:''"`
// Status defaults to "active"; BeforeCreate also sets it when empty.
Status string `json:"status" gorm:"type:varchar(20);not null;default:'active'"`
// Metadata holds free-form JSON attached to the mapping.
Metadata types.JSON `json:"metadata" gorm:"type:jsonb;default:'{}'"`
CreatedAt time.Time `json:"created_at"`
UpdatedAt time.Time `json:"updated_at"`
// DeletedAt is GORM's soft-delete marker.
DeletedAt gorm.DeletedAt `json:"deleted_at" gorm:"index"`
}
// TableName returns the database table that stores ChannelSession rows.
func (ChannelSession) TableName() string {
	const table = "im_channel_sessions"
	return table
}
// BeforeCreate is the GORM hook run before insert. It assigns a fresh UUID when
// ID is empty and sets Status to "active" when it is empty. It always returns nil.
func (cs *ChannelSession) BeforeCreate(tx *gorm.DB) error {
	if len(cs.ID) == 0 {
		cs.ID = uuid.New().String()
	}
	if len(cs.Status) == 0 {
		cs.Status = "active"
	}
	return nil
}

View File

@@ -0,0 +1,369 @@
// Package wecom implements the WeCom (企业微信) IM adapter for WeKnora.
//
// WeCom Smart Bot flow:
// 1. User sends a message to the bot (direct or @mention in group)
// 2. WeCom calls our callback URL with the encrypted message
// 3. We decrypt, parse, and return an immediate response (or stream response)
// 4. For streaming: respond with msgtype="stream", WeCom pulls subsequent chunks via refresh callbacks
//
// Reference: https://developer.work.weixin.qq.com/document/path/101031
package wecom
import (
"bytes"
"context"
"crypto/aes"
"crypto/cipher"
"crypto/hmac"
"crypto/sha1"
"encoding/base64"
"encoding/binary"
"encoding/json"
"encoding/xml"
"fmt"
"io"
"net/http"
"sort"
"strings"
"sync"
"time"
"github.com/Tencent/WeKnora/internal/im"
"github.com/Tencent/WeKnora/internal/logger"
"github.com/gin-gonic/gin"
)
// httpClient is the shared HTTP client for outbound WeCom API calls made by
// this package; every request is bounded by a 10-second timeout.
var httpClient = &http.Client{Timeout: 10 * time.Second}
// Adapter implements im.Adapter for WeCom.
type Adapter struct {
// corpID is sent when requesting access tokens and is compared against the
// trailer of every decrypted callback payload.
corpID string
// token is the callback token mixed into the signature computation.
token string
// encodingAESKey is the configured key string as passed to NewAdapter.
encodingAESKey string
// aesKey is encodingAESKey decoded from base64; its first block doubles as the CBC IV.
aesKey []byte
// agentSecret is the secret used to obtain access tokens.
agentSecret string
// corpAgentID is sent as "agentid" when sending messages.
corpAgentID int
// Token cache
// tokenMu guards tokenCache and tokenExpAt.
tokenMu sync.Mutex
tokenCache string
tokenExpAt time.Time
}
// NewAdapter builds a WeCom adapter from the app credentials and callback
// settings. The encoding AES key is base64-decoded after appending one "="
// of padding; a decoding failure is returned as an error.
func NewAdapter(corpID, agentSecret, token, encodingAESKey string, corpAgentID int) (*Adapter, error) {
	paddedKey := encodingAESKey + "="
	keyBytes, decodeErr := base64.StdEncoding.DecodeString(paddedKey)
	if decodeErr != nil {
		return nil, fmt.Errorf("decode encoding_aes_key: %w", decodeErr)
	}

	adapter := new(Adapter)
	adapter.corpID = corpID
	adapter.token = token
	adapter.encodingAESKey = encodingAESKey
	adapter.aesKey = keyBytes
	adapter.agentSecret = agentSecret
	adapter.corpAgentID = corpAgentID
	return adapter, nil
}
// Platform reports the platform this adapter serves: im.PlatformWeCom.
func (a *Adapter) Platform() (platform im.Platform) {
	platform = im.PlatformWeCom
	return platform
}
// VerifyCallback checks the msg_signature query parameter of a WeCom callback.
//
// The signed payload is the echostr query parameter for GET requests (URL
// verification) and the Encrypt element of the XML body otherwise. When the
// body is read it is restored on the request so later handlers can read it
// again. An error is returned when the body cannot be read or parsed, or when
// the signature does not match.
func (a *Adapter) VerifyCallback(c *gin.Context) error {
	var encrypt string
	switch c.Request.Method {
	case http.MethodGet:
		encrypt = c.Query("echostr")
	default:
		raw, readErr := io.ReadAll(c.Request.Body)
		if readErr != nil {
			return fmt.Errorf("read request body: %w", readErr)
		}
		// Put the bytes back so ParseCallback can consume the body afterwards.
		c.Request.Body = io.NopCloser(bytes.NewReader(raw))

		var envelope callbackRequestBody
		if xmlErr := xml.Unmarshal(raw, &envelope); xmlErr != nil {
			return fmt.Errorf("unmarshal xml body: %w", xmlErr)
		}
		encrypt = envelope.Encrypt
	}

	signature := c.Query("msg_signature")
	timestamp := c.Query("timestamp")
	nonce := c.Query("nonce")
	if a.verifySignature(signature, timestamp, nonce, encrypt) {
		return nil
	}
	return fmt.Errorf("invalid signature")
}
// HandleURLVerification answers the WeCom URL verification handshake.
//
// It returns false, without writing a response, unless the request is a GET
// carrying a non-empty echostr. Otherwise it decrypts echostr and writes the
// plaintext with 200, or writes "decrypt failed" with 400, and returns true.
func (a *Adapter) HandleURLVerification(c *gin.Context) bool {
	if c.Request.Method != http.MethodGet {
		return false
	}
	echo := c.Query("echostr")
	if len(echo) == 0 {
		return false
	}

	plain, err := a.decrypt(echo)
	if err == nil {
		c.String(http.StatusOK, string(plain))
		return true
	}

	logger.Errorf(c.Request.Context(), "[WeCom] Failed to decrypt echostr: %v", err)
	c.String(http.StatusBadRequest, "decrypt failed")
	return true
}
// ParseCallback reads an encrypted WeCom callback from the request body,
// decrypts it, and converts it into a unified im.IncomingMessage.
//
// Non-text messages are logged and reported as (nil, nil). A message with a
// non-empty ChatId is classified as a group message and carries that chat ID;
// all others are direct messages with an empty ChatID. The sender's
// FromUserName is used for both UserID and UserName.
func (a *Adapter) ParseCallback(c *gin.Context) (*im.IncomingMessage, error) {
	raw, err := io.ReadAll(c.Request.Body)
	if err != nil {
		return nil, fmt.Errorf("read body: %w", err)
	}

	var envelope callbackRequestBody
	if err = xml.Unmarshal(raw, &envelope); err != nil {
		return nil, fmt.Errorf("unmarshal xml: %w", err)
	}

	plain, err := a.decrypt(envelope.Encrypt)
	if err != nil {
		return nil, fmt.Errorf("decrypt message: %w", err)
	}

	var parsed wecomMessage
	if err = xml.Unmarshal(plain, &parsed); err != nil {
		return nil, fmt.Errorf("unmarshal decrypted message: %w", err)
	}

	if parsed.MsgType != "text" {
		logger.Infof(c.Request.Context(), "[WeCom] Ignoring non-text message type: %s", parsed.MsgType)
		return nil, nil
	}

	incoming := &im.IncomingMessage{
		Platform:  im.PlatformWeCom,
		UserID:    parsed.FromUserName,
		UserName:  parsed.FromUserName,
		ChatID:    "",
		ChatType:  im.ChatTypeDirect,
		Content:   strings.TrimSpace(parsed.Content),
		MessageID: parsed.MsgID,
	}
	// A ChatId in the payload marks the message as coming from a group chat.
	if parsed.ChatID != "" {
		incoming.ChatType = im.ChatTypeGroup
		incoming.ChatID = parsed.ChatID
	}
	return incoming, nil
}
// SendReply sends a reply message via WeCom API.
// For simplicity, this uses the application-message API to proactively send messages.
// A production implementation would integrate with the streaming callback model.
//
// The reply is delivered as a markdown message addressed to incoming.UserID.
// The access token is attached as a properly escaped query parameter. When the
// API reports the token as invalid (40014) or expired (42001), the cached token
// is dropped so the next call fetches a fresh one instead of failing until the
// local expiry time.
//
// An error is returned when the token cannot be obtained, the request fails,
// the response cannot be decoded, or the API returns a non-zero errcode.
func (a *Adapter) SendReply(ctx context.Context, incoming *im.IncomingMessage, reply *im.ReplyMessage) error {
	// Get access token
	accessToken, err := a.getAccessToken(ctx)
	if err != nil {
		return fmt.Errorf("get access token: %w", err)
	}

	// Build message payload.
	// Note: /cgi-bin/message/send only supports touser/toparty/totag — it does NOT
	// support regular group chat IDs (chatid is only for appchat-created groups).
	// So for both direct and group messages we reply to the user directly.
	payload := map[string]interface{}{
		"touser":  incoming.UserID,
		"msgtype": "markdown",
		"agentid": a.corpAgentID,
		"markdown": map[string]string{
			"content": reply.Content,
		},
	}
	payloadBytes, err := json.Marshal(payload)
	if err != nil {
		return fmt.Errorf("marshal payload: %w", err)
	}

	req, err := http.NewRequestWithContext(ctx, http.MethodPost,
		"https://qyapi.weixin.qq.com/cgi-bin/message/send", bytes.NewReader(payloadBytes))
	if err != nil {
		return fmt.Errorf("create request: %w", err)
	}
	// Attach the token through url.Values so it is always query-escaped.
	query := req.URL.Query()
	query.Set("access_token", accessToken)
	req.URL.RawQuery = query.Encode()
	req.Header.Set("Content-Type", "application/json")

	resp, err := httpClient.Do(req)
	if err != nil {
		return fmt.Errorf("send message: %w", err)
	}
	defer resp.Body.Close()

	var result struct {
		ErrCode int    `json:"errcode"`
		ErrMsg  string `json:"errmsg"`
	}
	if err := json.NewDecoder(resp.Body).Decode(&result); err != nil {
		return fmt.Errorf("decode response: %w", err)
	}

	switch result.ErrCode {
	case 0:
		return nil
	case 40014, 42001:
		// The server no longer accepts this token: forget it, unless another
		// goroutine has already replaced it with a newer one.
		a.tokenMu.Lock()
		if a.tokenCache == accessToken {
			a.tokenCache = ""
			a.tokenExpAt = time.Time{}
		}
		a.tokenMu.Unlock()
	}
	return fmt.Errorf("wecom api error: code=%d msg=%s", result.ErrCode, result.ErrMsg)
}
// getAccessToken retrieves the WeCom access token with caching.
//
// A cached token is returned while it is non-empty and not past its recorded
// expiry. Otherwise a new token is requested with the corp ID and agent secret
// (sent as escaped query parameters), cached, and returned. The cache lifetime
// is the reported expires_in minus a 5-minute safety margin when expires_in
// exceeds 5 minutes.
//
// tokenMu is held for the whole call, so concurrent callers share one refresh.
// An error is returned when the request fails, the response cannot be decoded,
// the API reports a non-zero errcode, or the response carries no token.
func (a *Adapter) getAccessToken(ctx context.Context) (string, error) {
	a.tokenMu.Lock()
	defer a.tokenMu.Unlock()

	if a.tokenCache != "" && time.Now().Before(a.tokenExpAt) {
		return a.tokenCache, nil
	}

	req, err := http.NewRequestWithContext(ctx, http.MethodGet,
		"https://qyapi.weixin.qq.com/cgi-bin/gettoken", nil)
	if err != nil {
		return "", fmt.Errorf("create request: %w", err)
	}
	// Build the query through url.Values so credentials containing reserved
	// characters are escaped instead of corrupting the URL.
	query := req.URL.Query()
	query.Set("corpid", a.corpID)
	query.Set("corpsecret", a.agentSecret)
	req.URL.RawQuery = query.Encode()

	resp, err := httpClient.Do(req)
	if err != nil {
		return "", fmt.Errorf("request access token: %w", err)
	}
	defer resp.Body.Close()

	var result struct {
		ErrCode     int    `json:"errcode"`
		ErrMsg      string `json:"errmsg"`
		AccessToken string `json:"access_token"`
		ExpiresIn   int    `json:"expires_in"` // seconds
	}
	if err := json.NewDecoder(resp.Body).Decode(&result); err != nil {
		return "", fmt.Errorf("decode token response: %w", err)
	}
	if result.ErrCode != 0 {
		return "", fmt.Errorf("get token error: code=%d msg=%s", result.ErrCode, result.ErrMsg)
	}
	if result.AccessToken == "" {
		// Never hand an empty token to callers as if the refresh had succeeded.
		return "", fmt.Errorf("get token error: empty access_token in response")
	}

	a.tokenCache = result.AccessToken
	// Cache with 5-minute safety margin
	ttl := time.Duration(result.ExpiresIn) * time.Second
	if ttl > 5*time.Minute {
		ttl -= 5 * time.Minute
	}
	a.tokenExpAt = time.Now().Add(ttl)

	return a.tokenCache, nil
}
// verifySignature reports whether signature equals the lowercase hex SHA-1 of
// the callback token, timestamp, nonce and encrypted payload, sorted
// lexicographically and concatenated. The comparison is constant-time.
func (a *Adapter) verifySignature(signature, timestamp, nonce, encrypt string) bool {
	fields := []string{a.token, timestamp, nonce, encrypt}
	sort.Strings(fields)

	digest := sha1.Sum([]byte(strings.Join(fields, "")))
	expected := fmt.Sprintf("%x", digest[:])

	return hmac.Equal([]byte(expected), []byte(signature))
}
// decrypt decrypts a WeCom AES-encrypted message.
//
// encrypted is base64 text. It is decrypted with AES-CBC using aesKey, with the
// first 16 bytes of the key as IV. The plaintext layout is
// random(16) + msg_len(4, big-endian) + msg + corp_id; the msg bytes are
// returned after the trailing corp_id has been checked against the adapter's.
//
// The input may come from an unauthenticated caller, so every malformed shape
// is reported as an error rather than allowed to panic:
//   - ciphertext that is empty or not a whole number of AES blocks,
//   - padding outside 1..32 (WeCom pads to 32-byte boundaries, not to the
//     16-byte AES block) or with inconsistent padding bytes,
//   - a declared message length that exceeds the plaintext (checked in 64-bit
//     arithmetic so a huge length cannot wrap around).
func (a *Adapter) decrypt(encrypted string) ([]byte, error) {
	// WeCom's PKCS#7 variant pads to a multiple of 32 bytes.
	const wecomPadBlockSize = 32

	ciphertext, err := base64.StdEncoding.DecodeString(encrypted)
	if err != nil {
		return nil, fmt.Errorf("base64 decode: %w", err)
	}

	block, err := aes.NewCipher(a.aesKey)
	if err != nil {
		return nil, fmt.Errorf("new cipher: %w", err)
	}

	if len(ciphertext) < aes.BlockSize {
		return nil, fmt.Errorf("ciphertext too short")
	}
	// CryptBlocks panics on partial blocks, so reject them up front.
	if len(ciphertext)%aes.BlockSize != 0 {
		return nil, fmt.Errorf("ciphertext is not a multiple of the block size")
	}

	iv := a.aesKey[:aes.BlockSize]
	mode := cipher.NewCBCDecrypter(block, iv)
	mode.CryptBlocks(ciphertext, ciphertext)

	// Remove and verify PKCS#7 padding
	padLen := int(ciphertext[len(ciphertext)-1])
	if padLen == 0 || padLen > wecomPadBlockSize || padLen > len(ciphertext) {
		return nil, fmt.Errorf("invalid padding")
	}
	for i := 0; i < padLen; i++ {
		if ciphertext[len(ciphertext)-1-i] != byte(padLen) {
			return nil, fmt.Errorf("invalid padding")
		}
	}
	plaintext := ciphertext[:len(ciphertext)-padLen]

	// WeCom format: random(16) + msg_len(4) + msg + corp_id
	if len(plaintext) < 20 {
		return nil, fmt.Errorf("plaintext too short")
	}
	// Widen before adding: 20+msgLen in uint32 would wrap for lengths near 2^32.
	msgLen := uint64(binary.BigEndian.Uint32(plaintext[16:20]))
	if uint64(len(plaintext)) < 20+msgLen {
		return nil, fmt.Errorf("message length mismatch")
	}
	msgEnd := 20 + int(msgLen) // safe: msgLen <= len(plaintext)
	msgBytes := plaintext[20:msgEnd]

	// Verify corp_id from plaintext tail
	corpIDBytes := plaintext[msgEnd:]
	if string(corpIDBytes) != a.corpID {
		return nil, fmt.Errorf("corp_id mismatch: expected %s, got %s", a.corpID, string(corpIDBytes))
	}

	return msgBytes, nil
}
// callbackRequestBody is the XML structure of a WeCom callback request body.
// Only Encrypt is read by this package: it is the signed, AES-encrypted payload.
type callbackRequestBody struct {
XMLName xml.Name `xml:"xml"`
ToUserName string `xml:"ToUserName"`
// Encrypt is the base64 ciphertext passed to decrypt and covered by the signature.
Encrypt string `xml:"Encrypt"`
AgentID string `xml:"AgentID"`
}
// wecomMessage is the decrypted WeCom message structure.
type wecomMessage struct {
XMLName xml.Name `xml:"xml"`
ToUserName string `xml:"ToUserName"`
// FromUserName is used as both UserID and UserName of the unified message.
FromUserName string `xml:"FromUserName"`
CreateTime int64 `xml:"CreateTime"`
// MsgType selects handling; ParseCallback only processes "text".
MsgType string `xml:"MsgType"`
Content string `xml:"Content"`
// MsgID becomes the unified MessageID.
MsgID string `xml:"MsgId"`
AgentID string `xml:"AgentID"`
// ChatID, when non-empty, marks the message as a group message.
ChatID string `xml:"ChatId"`
}

View File

@@ -0,0 +1,42 @@
package wecom
import (
"context"
"fmt"
"github.com/Tencent/WeKnora/internal/im"
"github.com/gin-gonic/gin"
)
// BotAdapter implements im.Adapter for WeCom intelligent bot (long connection mode).
// It delegates SendReply to the WebSocket LongConnClient.
// The webhook methods do not process requests, since messages arrive via
// WebSocket, not HTTP: VerifyCallback and ParseCallback return an error and
// HandleURLVerification returns false.
type BotAdapter struct {
// client is the long connection that replies are written to.
client *LongConnClient
}
// NewBotAdapter wraps a WeCom long connection client in an adapter that sends
// replies through that client.
func NewBotAdapter(client *LongConnClient) *BotAdapter {
	adapter := new(BotAdapter)
	adapter.client = client
	return adapter
}
// Platform reports the platform this adapter serves: im.PlatformWeCom.
func (a *BotAdapter) Platform() (platform im.Platform) {
	platform = im.PlatformWeCom
	return platform
}
// VerifyCallback always fails: the bot adapter receives messages over the
// long connection and has no webhook callback to verify.
func (a *BotAdapter) VerifyCallback(c *gin.Context) error {
	unsupported := fmt.Errorf("WeCom bot adapter does not support webhook callbacks")
	return unsupported
}
// ParseCallback always fails with a nil message: the bot adapter receives
// messages over the long connection and has no webhook callback to parse.
func (a *BotAdapter) ParseCallback(c *gin.Context) (*im.IncomingMessage, error) {
	unsupported := fmt.Errorf("WeCom bot adapter does not support webhook callbacks")
	return nil, unsupported
}
// HandleURLVerification never handles the request: it writes nothing and
// reports false, since the bot adapter has no webhook URL to verify.
func (a *BotAdapter) HandleURLVerification(c *gin.Context) (handled bool) {
	handled = false
	return handled
}
// SendReply forwards the reply to the underlying long connection client and
// returns its result unchanged.
func (a *BotAdapter) SendReply(ctx context.Context, incoming *im.IncomingMessage, reply *im.ReplyMessage) error {
	sendErr := a.client.SendReply(ctx, incoming, reply)
	return sendErr
}

View File

@@ -0,0 +1,353 @@
// WeCom Intelligent Bot long connection client.
//
// Protocol reference: https://developer.work.weixin.qq.com/document/path/101463
// Node.js SDK reference: https://github.com/WecomTeam/aibot-node-sdk
//
// Flow:
// 1. Connect to wss://openws.work.weixin.qq.com
// 2. Send aibot_subscribe with bot_id + secret
// 3. Receive aibot_msg_callback / aibot_event_callback frames
// 4. Reply via aibot_respond_msg on the same WebSocket
// 5. Heartbeat via ping/pong every 30s
package wecom
import (
"context"
"encoding/json"
"fmt"
"math"
"strings"
"sync"
"sync/atomic"
"time"
"github.com/Tencent/WeKnora/internal/im"
"github.com/Tencent/WeKnora/internal/logger"
ws "github.com/gorilla/websocket"
)
// Protocol constants and connection tunables for the WeCom bot long connection.
const (
// wecomWSEndpoint is the WebSocket URL dialed by connectAndRun.
wecomWSEndpoint = "wss://openws.work.weixin.qq.com"
// cmd* are the values of the "cmd" field in exchanged frames.
cmdSubscribe = "aibot_subscribe"
cmdPing = "ping"
cmdMsgCallback = "aibot_msg_callback"
cmdEventCallback = "aibot_event_callback"
cmdResponse = "aibot_respond_msg"
// defaultHeartbeatInterval is the period between ping frames.
defaultHeartbeatInterval = 30 * time.Second
// defaultReconnectBaseDelay is the delay before the first reconnect; it
// doubles per attempt up to defaultReconnectMaxDelay.
defaultReconnectBaseDelay = 1 * time.Second
defaultReconnectMaxDelay = 30 * time.Second
// defaultMaxReconnectAttempts bounds reconnect attempts; a negative value disables the bound.
defaultMaxReconnectAttempts = -1 // infinite
)
// wsFrame is the JSON frame exchanged over the WeCom bot WebSocket.
type wsFrame struct {
// Cmd names the frame's command (one of the cmd* constants for frames we send or handle).
Cmd string `json:"cmd,omitempty"`
// Headers carries per-frame metadata; this client reads and writes "req_id".
Headers map[string]string `json:"headers,omitempty"`
// Body is the command-specific payload, kept raw and decoded per command.
Body json.RawMessage `json:"body,omitempty"`
// ErrCode and ErrMsg are checked on the subscribe response; non-zero ErrCode means failure.
ErrCode int `json:"errcode,omitempty"`
ErrMsg string `json:"errmsg,omitempty"`
}
// botMessage is the body of an aibot_msg_callback frame.
// handleCallback decodes callback frame bodies into this type and only
// processes those whose MsgType is "text".
type botMessage struct {
MsgID string `json:"msgid"`
AiBotID string `json:"aibotid"`
ChatID string `json:"chatid"`
ChatType string `json:"chattype"` // "single" or "group"
MsgType string `json:"msgtype"` // "text", "image", "event", ...
CreateTime int64 `json:"create_time"`
// From identifies the sender.
From struct {
UserID string `json:"userid"`
} `json:"from"`
// Text holds the message text; it is read when MsgType is "text".
Text struct {
Content string `json:"content"`
} `json:"text"`
}
// streamReplyBody is the body for a streaming text reply.
// SendReply sends a single frame with MsgType "stream" and Finish set to true.
type streamReplyBody struct {
MsgType string `json:"msgtype"`
Stream struct {
// ID identifies the stream; SendReply generates a new one per reply.
ID string `json:"id"`
// Finish marks the last chunk of the stream.
Finish bool `json:"finish"`
Content string `json:"content"`
} `json:"stream"`
}
// MessageHandler is called when an IM message is received via long connection.
// It is invoked from a per-message goroutine; a returned error is only logged.
type MessageHandler func(ctx context.Context, msg *im.IncomingMessage) error
// LongConnClient manages a WeCom intelligent bot WebSocket long connection.
type LongConnClient struct {
// botID and secret are the credentials sent in the subscribe frame.
botID string
secret string
// handler receives every text message delivered over the connection.
handler MessageHandler
// conn is the current connection, or nil while disconnected.
conn *ws.Conn
// mu guards conn and serializes writes to it.
mu sync.Mutex
// closed is set by Stop so Start returns instead of reconnecting.
closed atomic.Bool
// reqSeq is a counter used to generate stream IDs for replies.
reqSeq atomic.Int64
}
// NewLongConnClient builds a WeCom long connection client for the given bot
// credentials. handler is invoked for each text message received; no
// connection is opened until Start is called.
func NewLongConnClient(botID, secret string, handler MessageHandler) *LongConnClient {
	client := new(LongConnClient)
	client.botID = botID
	client.secret = secret
	client.handler = handler
	return client
}
// Start runs the long connection until the client is stopped or ctx ends,
// reconnecting with a growing delay whenever the connection is lost.
//
// It returns nil after Stop, ctx.Err() once ctx is done, or an error when the
// configured maximum number of reconnect attempts is reached.
func (c *LongConnClient) Start(ctx context.Context) error {
	logger.Infof(ctx, "[IM] WeCom WebSocket connecting (bot_id=%s)...", c.botID)

	for attempts := 1; ; attempts++ {
		if ctxErr := ctx.Err(); ctxErr != nil {
			return ctxErr
		}

		runErr := c.connectAndRun(ctx)

		// A deliberate Stop is a clean exit, not a failure.
		if c.closed.Load() {
			return nil
		}
		if ctxErr := ctx.Err(); ctxErr != nil {
			return ctxErr
		}

		if defaultMaxReconnectAttempts >= 0 && attempts >= defaultMaxReconnectAttempts {
			return fmt.Errorf("max reconnect attempts reached: %w", runErr)
		}

		wait := reconnectDelay(attempts)
		logger.Warnf(ctx, "[WeCom] Connection lost (%v), reconnecting in %v (attempt %d)...", runErr, wait, attempts)

		select {
		case <-ctx.Done():
			return ctx.Err()
		case <-time.After(wait):
		}
	}
}
// Stop marks the client as closed, so Start returns instead of reconnecting,
// and closes the current connection if there is one.
func (c *LongConnClient) Stop() {
	c.closed.Store(true)

	c.mu.Lock()
	defer c.mu.Unlock()

	if c.conn == nil {
		return
	}
	_ = c.conn.Close()
	c.conn = nil
}
// SendReply writes reply to the WebSocket as a single, finished stream frame.
//
// The frame is routed with the "req_id" stored in incoming.Extra by
// handleCallback; an error is returned when that value is missing or empty,
// when the body cannot be marshalled, or when the write fails.
func (c *LongConnClient) SendReply(ctx context.Context, incoming *im.IncomingMessage, reply *im.ReplyMessage) error {
	// Indexing a map yields "" for a missing key, so one check covers both cases.
	reqID := incoming.Extra["req_id"]
	if reqID == "" {
		return fmt.Errorf("missing req_id in incoming message extra")
	}

	var payload streamReplyBody
	payload.MsgType = "stream"
	// Each reply gets its own stream ID from the client-wide counter.
	payload.Stream.ID = fmt.Sprintf("stream_%d", c.reqSeq.Add(1))
	payload.Stream.Finish = true
	payload.Stream.Content = reply.Content

	raw, err := json.Marshal(payload)
	if err != nil {
		return fmt.Errorf("marshal reply body: %w", err)
	}

	return c.writeJSON(wsFrame{
		Cmd:     cmdResponse,
		Headers: map[string]string{"req_id": reqID},
		Body:    raw,
	})
}
// connectAndRun dials the WeCom endpoint, authenticates, starts the heartbeat
// and then reads frames until the connection fails.
//
// Message and event callback frames are dispatched to handleCallback in their
// own goroutine; all other frames are ignored. Frames that fail to decode are
// logged and skipped. The function always returns a non-nil error describing
// why the connection ended.
//
// A blocking read does not observe ctx, so a watcher goroutine closes the
// connection when ctx is cancelled. That makes the pending read fail and lets
// the caller return promptly instead of waiting for the peer to close.
func (c *LongConnClient) connectAndRun(ctx context.Context) error {
	conn, _, err := ws.DefaultDialer.DialContext(ctx, wecomWSEndpoint, nil)
	if err != nil {
		return fmt.Errorf("dial: %w", err)
	}

	c.mu.Lock()
	c.conn = conn
	c.mu.Unlock()

	defer func() {
		c.mu.Lock()
		c.conn = nil
		c.mu.Unlock()
		_ = conn.Close()
	}()

	// Close the connection on cancellation to unblock reads. watchDone stops
	// the watcher when this function returns for any other reason.
	watchDone := make(chan struct{})
	defer close(watchDone)
	go func() {
		select {
		case <-ctx.Done():
			_ = conn.Close()
		case <-watchDone:
		}
	}()

	// Authenticate
	if err := c.authenticate(ctx); err != nil {
		return fmt.Errorf("authenticate: %w", err)
	}

	logger.Infof(ctx, "[IM] WeCom WebSocket connected successfully (bot_id=%s)", c.botID)

	// Start heartbeat; it is cancelled when this connection ends.
	heartbeatCtx, heartbeatCancel := context.WithCancel(ctx)
	defer heartbeatCancel()
	go c.heartbeatLoop(heartbeatCtx)

	// Message receive loop
	for {
		_, message, err := conn.ReadMessage()
		if err != nil {
			return fmt.Errorf("read message: %w", err)
		}

		var frame wsFrame
		if err := json.Unmarshal(message, &frame); err != nil {
			logger.Warnf(ctx, "[WeCom] Failed to unmarshal frame: %v", err)
			continue
		}

		switch frame.Cmd {
		case cmdMsgCallback, cmdEventCallback:
			// Detach from connection ctx so in-flight messages survive reconnects.
			go c.handleCallback(context.WithoutCancel(ctx), frame)
		default:
			// pong or other control frames — ignore
		}
	}
}
// authenticate sends the subscribe frame with the bot credentials and waits up
// to 10 seconds for the response frame. It returns an error when the frame
// cannot be sent, the connection is gone, the response cannot be read or
// decoded, or the response carries a non-zero errcode.
func (c *LongConnClient) authenticate(ctx context.Context) error {
	credentials := map[string]string{
		"bot_id": c.botID,
		"secret": c.secret,
	}
	authBody, _ := json.Marshal(credentials)

	subscribe := wsFrame{
		Cmd:     cmdSubscribe,
		Headers: map[string]string{"req_id": fmt.Sprintf("%s_%d", cmdSubscribe, time.Now().UnixNano())},
		Body:    authBody,
	}
	if err := c.writeJSON(subscribe); err != nil {
		return fmt.Errorf("send subscribe: %w", err)
	}

	// Take a snapshot of the connection under the lock, then read without it.
	c.mu.Lock()
	activeConn := c.conn
	c.mu.Unlock()
	if activeConn == nil {
		return fmt.Errorf("connection closed")
	}

	// Bound the wait for the auth response, then clear the deadline again so
	// the receive loop is not affected.
	_ = activeConn.SetReadDeadline(time.Now().Add(10 * time.Second))
	_, raw, readErr := activeConn.ReadMessage()
	_ = activeConn.SetReadDeadline(time.Time{})
	if readErr != nil {
		return fmt.Errorf("read auth response: %w", readErr)
	}

	var reply wsFrame
	if err := json.Unmarshal(raw, &reply); err != nil {
		return fmt.Errorf("unmarshal auth response: %w", err)
	}
	if reply.ErrCode == 0 {
		return nil
	}
	return fmt.Errorf("auth failed: code=%d msg=%s", reply.ErrCode, reply.ErrMsg)
}
// heartbeatLoop sends a ping frame every defaultHeartbeatInterval until ctx is
// done. It also stops, after logging a warning, the first time a ping cannot
// be written.
func (c *LongConnClient) heartbeatLoop(ctx context.Context) {
	ticker := time.NewTicker(defaultHeartbeatInterval)
	defer ticker.Stop()

	sendPing := func() error {
		ping := wsFrame{
			Cmd:     cmdPing,
			Headers: map[string]string{"req_id": fmt.Sprintf("%s_%d", cmdPing, time.Now().UnixNano())},
		}
		return c.writeJSON(ping)
	}

	for {
		select {
		case <-ticker.C:
			if err := sendPing(); err != nil {
				logger.Warnf(ctx, "[WeCom] Heartbeat failed: %v", err)
				return
			}
		case <-ctx.Done():
			return
		}
	}
}
// handleCallback decodes a callback frame, converts text messages into a
// unified im.IncomingMessage and passes them to the handler.
//
// Frames whose body cannot be decoded and non-text messages are logged and
// dropped. ChatID is set only for group chats. The frame's "req_id" header is
// stored in Extra so the reply can be routed; handler errors are logged.
func (c *LongConnClient) handleCallback(ctx context.Context, frame wsFrame) {
	var body botMessage
	if err := json.Unmarshal(frame.Body, &body); err != nil {
		logger.Warnf(ctx, "[WeCom] Failed to unmarshal callback body: %v", err)
		return
	}

	// Only text messages are handled for now.
	if body.MsgType != "text" {
		logger.Infof(ctx, "[WeCom] Ignoring non-text message type: %s", body.MsgType)
		return
	}

	incoming := &im.IncomingMessage{
		Platform:  im.PlatformWeCom,
		UserID:    body.From.UserID,
		UserName:  body.From.UserID,
		ChatID:    "",
		ChatType:  im.ChatTypeDirect,
		Content:   strings.TrimSpace(body.Text.Content),
		MessageID: body.MsgID,
		// Reading from a nil Headers map yields "", same as a missing key.
		Extra: map[string]string{"req_id": frame.Headers["req_id"]},
	}
	if body.ChatType == "group" {
		incoming.ChatType = im.ChatTypeGroup
		incoming.ChatID = body.ChatID
	}

	if err := c.handler(ctx, incoming); err != nil {
		logger.Errorf(ctx, "[WeCom] Handle message error: %v", err)
	}
}
// writeJSON writes v to the current connection as JSON while holding c.mu,
// so writes issued through this method never overlap. It returns an error if
// no connection is set.
func (c *LongConnClient) writeJSON(v interface{}) error {
	c.mu.Lock()
	defer c.mu.Unlock()

	if ws := c.conn; ws != nil {
		return ws.WriteJSON(v)
	}
	return fmt.Errorf("connection closed")
}
// reconnectDelay returns the exponential back-off to wait before reconnect
// attempt number `attempt` (1-based): defaultReconnectBaseDelay doubled for
// every attempt after the first, capped at defaultReconnectMaxDelay.
//
// The product is computed in float64 and clamped before converting to
// time.Duration. Converting an out-of-range float to an integer is
// implementation-dependent in Go and integer multiplication wraps, so the
// uncapped value for a large attempt count could come out zero or negative,
// slip past the cap, and cause reconnects with no delay at all.
//
// A non-positive attempt yields a zero delay.
func reconnectDelay(attempt int) time.Duration {
	if attempt < 1 {
		return 0
	}
	// math.Pow saturates to +Inf for huge exponents, which the comparison
	// below maps to the maximum delay.
	backoff := float64(defaultReconnectBaseDelay) * math.Pow(2, float64(attempt-1))
	if backoff >= float64(defaultReconnectMaxDelay) {
		return defaultReconnectMaxDelay
	}
	return time.Duration(backoff)
}

View File

@@ -59,6 +59,7 @@ type RouterParams struct {
CustomAgentHandler *handler.CustomAgentHandler
SkillHandler *handler.SkillHandler
OrganizationHandler *handler.OrganizationHandler
IMHandler *handler.IMHandler
}
// NewRouter 创建新的路由
@@ -103,6 +104,9 @@ func NewRouter(params RouterParams) *gin.Engine {
serveFrontendStatic(r)
}
// IM 回调路由(在认证中间件之前注册,使用各平台自身的签名验证)
RegisterIMRoutes(r, params.IMHandler)
// 认证中间件
r.Use(middleware.Auth(params.TenantService, params.UserService, params.Config))
@@ -576,6 +580,19 @@ func RegisterOrganizationRoutes(r *gin.RouterGroup, orgHandler *handler.Organiza
r.POST("/shared-agents/disabled", orgHandler.SetSharedAgentDisabledByMe)
}
// RegisterIMRoutes mounts the IM platform callback endpoints under
// /api/v1/im/callback.
//
// It must be called BEFORE the auth middleware is installed, because IM
// platforms authenticate their callbacks with their own signature
// verification.
func RegisterIMRoutes(r *gin.Engine, imHandler *handler.IMHandler) {
	callbacks := r.Group("/api/v1/im/callback")

	// WeCom: GET is used for URL verification, POST for message events.
	callbacks.GET("/wecom", imHandler.WeComCallback)
	callbacks.POST("/wecom", imHandler.WeComCallback)

	// Feishu: POST carries both the URL verification challenge and message events.
	callbacks.POST("/feishu", imHandler.FeishuCallback)
}
// serveFrontendStatic registers a middleware that serves the frontend SPA
// from the ./web directory if it exists. Must be called BEFORE auth middleware
// so static files are served without authentication.

View File

@@ -0,0 +1 @@
-- Drops the im_channel_sessions table; a no-op when the table is absent.
-- NOTE(review): presumably the rollback counterpart of the migration that
-- creates im_channel_sessions — confirm against the migration file names.
DROP TABLE IF EXISTS im_channel_sessions;

View File

@@ -0,0 +1,44 @@
-- Migration: 000021_im_channel_sessions
-- Purpose:   Create the table that maps an IM channel (platform + user + chat)
--            to a WeKnora conversation session, plus its indexes and comments.
DO $$
BEGIN
    RAISE NOTICE '[Migration 000021] Creating table: im_channel_sessions';
END
$$;

-- NOTE(review): uuid_generate_v4() is provided by the uuid-ossp extension,
-- presumably enabled by an earlier migration — confirm.
CREATE TABLE IF NOT EXISTS im_channel_sessions (
    id         VARCHAR(36)              PRIMARY KEY DEFAULT uuid_generate_v4(),
    platform   VARCHAR(20)              NOT NULL,
    user_id    VARCHAR(128)             NOT NULL,
    chat_id    VARCHAR(128)             NOT NULL DEFAULT '',
    session_id VARCHAR(36)              NOT NULL REFERENCES sessions(id) ON DELETE CASCADE,
    tenant_id  BIGINT                   NOT NULL,
    agent_id   VARCHAR(36)              DEFAULT '',
    status     VARCHAR(20)              NOT NULL DEFAULT 'active',
    metadata   JSONB                    DEFAULT '{}',
    created_at TIMESTAMP WITH TIME ZONE NOT NULL DEFAULT CURRENT_TIMESTAMP,
    updated_at TIMESTAMP WITH TIME ZONE NOT NULL DEFAULT CURRENT_TIMESTAMP,
    deleted_at TIMESTAMP WITH TIME ZONE
);

-- One live mapping per (platform, user, chat, tenant): uniqueness is enforced
-- only for rows that have not been soft-deleted.
CREATE UNIQUE INDEX IF NOT EXISTS idx_channel_lookup
    ON im_channel_sessions (platform, user_id, chat_id, tenant_id)
    WHERE deleted_at IS NULL;

-- Lookup by tenant.
CREATE INDEX IF NOT EXISTS idx_im_channel_tenant
    ON im_channel_sessions (tenant_id);

-- Lookup by session.
CREATE INDEX IF NOT EXISTS idx_im_channel_session
    ON im_channel_sessions (session_id);

-- Soft-delete support: index only the rows that have been deleted.
CREATE INDEX IF NOT EXISTS idx_im_channel_deleted
    ON im_channel_sessions (deleted_at)
    WHERE deleted_at IS NOT NULL;

-- Catalog documentation for the table and its columns.
COMMENT ON TABLE im_channel_sessions IS 'Maps IM platform channels to WeKnora conversation sessions';
COMMENT ON COLUMN im_channel_sessions.platform IS 'IM platform identifier: wecom, feishu, etc.';
COMMENT ON COLUMN im_channel_sessions.user_id IS 'Platform-specific user identifier';
COMMENT ON COLUMN im_channel_sessions.chat_id IS 'Platform-specific chat/group identifier, empty for direct messages';
COMMENT ON COLUMN im_channel_sessions.session_id IS 'Associated WeKnora session ID';
COMMENT ON COLUMN im_channel_sessions.tenant_id IS 'Tenant that owns this channel mapping';
COMMENT ON COLUMN im_channel_sessions.agent_id IS 'Custom agent ID used for this channel, empty for default';
COMMENT ON COLUMN im_channel_sessions.status IS 'Channel status: active, paused, expired';
COMMENT ON COLUMN im_channel_sessions.metadata IS 'Platform-specific extra data (JSON)';

DO $$
BEGIN
    RAISE NOTICE '[Migration 000021] im_channel_sessions setup completed successfully!';
END
$$;