mirror of
https://github.com/Tencent/WeKnora.git
synced 2026-06-04 13:30:32 +08:00
167 lines
4.5 KiB
Go
167 lines
4.5 KiB
Go
package chatpipline
|
|
|
|
import (
|
|
"context"
|
|
"encoding/json"
|
|
"regexp"
|
|
"strings"
|
|
|
|
"github.com/Tencent/WeKnora/internal/config"
|
|
"github.com/Tencent/WeKnora/internal/models/chat"
|
|
"github.com/Tencent/WeKnora/internal/types"
|
|
"github.com/Tencent/WeKnora/internal/types/interfaces"
|
|
)
|
|
|
|
// PluginPreprocess Query preprocessing plugin
|
|
type PluginPreprocess struct {
|
|
config *config.Config
|
|
modelService interfaces.ModelService
|
|
}
|
|
|
|
// multiSpaceRegex matches any run of one or more whitespace characters. It is
// compiled once at package scope and used to collapse such runs into a single
// space during query normalization.
var multiSpaceRegex = regexp.MustCompile(`\s+`)
|
|
|
|
// NewPluginPreprocess Creates a new query preprocessing plugin
|
|
func NewPluginPreprocess(
|
|
eventManager *EventManager,
|
|
config *config.Config,
|
|
cleaner interfaces.ResourceCleaner,
|
|
modelService interfaces.ModelService,
|
|
) *PluginPreprocess {
|
|
res := &PluginPreprocess{
|
|
config: config,
|
|
modelService: modelService,
|
|
}
|
|
|
|
eventManager.Register(res)
|
|
return res
|
|
}
|
|
|
|
// ActivationEvents Register activation events
|
|
func (p *PluginPreprocess) ActivationEvents() []types.EventType {
|
|
return []types.EventType{types.PREPROCESS_QUERY}
|
|
}
|
|
|
|
// OnEvent Process events
|
|
func (p *PluginPreprocess) OnEvent(
|
|
ctx context.Context,
|
|
eventType types.EventType,
|
|
chatManage *types.ChatManage,
|
|
next func() *PluginError,
|
|
) *PluginError {
|
|
rawQuery := strings.TrimSpace(chatManage.RewriteQuery)
|
|
if rawQuery == "" {
|
|
return next()
|
|
}
|
|
|
|
pipelineInfo(ctx, "Preprocess", "input", map[string]interface{}{
|
|
"session_id": chatManage.SessionID,
|
|
"rewrite_query": rawQuery,
|
|
})
|
|
|
|
// Lightweight normalization: just collapse multiple spaces
|
|
processed := multiSpaceRegex.ReplaceAllString(rawQuery, " ")
|
|
processed = strings.TrimSpace(processed)
|
|
|
|
chatManage.ProcessedQuery = processed
|
|
chatManage.QueryIntent = p.detectIntentLLM(ctx, chatManage, processed)
|
|
|
|
pipelineInfo(ctx, "Preprocess", "output", map[string]interface{}{
|
|
"session_id": chatManage.SessionID,
|
|
"processed_query": processed,
|
|
"query_intent": chatManage.QueryIntent,
|
|
})
|
|
|
|
return next()
|
|
}
|
|
|
|
// intentResp is the JSON object the intent classifier is asked to reply with.
type intentResp struct {
	// Intent is the label reported by the model.
	Intent string `json:"intent"`
	// Confidence is the score reported by the model alongside the label.
	Confidence float64 `json:"confidence"`
}
|
|
|
|
// detectIntentLLM detects the intent of a query using an LLM
|
|
func (p *PluginPreprocess) detectIntentLLM(ctx context.Context, chatManage *types.ChatManage, text string) string {
|
|
if p.modelService == nil || chatManage.ChatModelID == "" {
|
|
pipelineWarn(
|
|
ctx,
|
|
"IntentDetect",
|
|
"skip",
|
|
map[string]interface{}{"reason": "no_model", "session_id": chatManage.SessionID},
|
|
)
|
|
return "general"
|
|
}
|
|
chatModel, err := p.modelService.GetChatModel(ctx, chatManage.ChatModelID)
|
|
if err != nil {
|
|
pipelineWarn(
|
|
ctx,
|
|
"IntentDetect",
|
|
"get_model_failed",
|
|
map[string]interface{}{"error": err.Error(), "model_id": chatManage.ChatModelID},
|
|
)
|
|
return "general"
|
|
}
|
|
pipelineInfo(
|
|
ctx,
|
|
"IntentDetect",
|
|
"start",
|
|
map[string]interface{}{"session_id": chatManage.SessionID, "model_id": chatManage.ChatModelID},
|
|
)
|
|
sys := "You are a query intent classifier. Classify the user's query into one of: definition, howto, compare, qa, general. Respond ONLY with a JSON object {\"intent\": \"...\", \"confidence\": 0.0 } inside a markdown fenced block."
|
|
usr := text
|
|
think := false
|
|
resp, err := chatModel.Chat(ctx, []chat.Message{
|
|
{Role: "system", Content: sys},
|
|
{Role: "user", Content: usr},
|
|
}, &chat.ChatOptions{Temperature: 0.0, MaxCompletionTokens: 64, Thinking: &think})
|
|
if err != nil || resp.Content == "" {
|
|
pipelineWarn(ctx, "IntentDetect", "model_call_failed", map[string]interface{}{"error": err})
|
|
return "general"
|
|
}
|
|
body := extractJSONBody(resp.Content)
|
|
var ir intentResp
|
|
if err := json.Unmarshal([]byte(body), &ir); err != nil {
|
|
pipelineWarn(ctx, "IntentDetect", "parse_failed", map[string]interface{}{"body": body, "error": err.Error()})
|
|
return "general"
|
|
}
|
|
pipelineInfo(
|
|
ctx,
|
|
"IntentDetect",
|
|
"result",
|
|
map[string]interface{}{"intent": ir.Intent, "confidence": ir.Confidence},
|
|
)
|
|
switch strings.ToLower(strings.TrimSpace(ir.Intent)) {
|
|
case "definition", "howto", "compare", "qa", "general":
|
|
return strings.ToLower(ir.Intent)
|
|
default:
|
|
return "general"
|
|
}
|
|
}
|
|
|
|
// extractJSONBody returns the first JSON object embedded in text, so that a
// reply wrapped in a markdown fence or surrounded by prose can still be
// decoded.
//
// Starting at the first "{", it scans for the matching "}" while skipping
// braces that appear inside JSON string literals, and returns that span. This
// keeps trailing text or a second object out of the result. If the first
// object never closes, it falls back to the span from the first "{" to the
// last "}". When text contains no such span, it returns "{}".
func extractJSONBody(text string) string {
	t := strings.TrimSpace(text)

	start := strings.IndexByte(t, '{')
	if start < 0 {
		return "{}"
	}

	// Byte-wise scanning is safe: the structural characters are ASCII and
	// never occur inside a multi-byte UTF-8 sequence.
	depth := 0
	inString := false
	escaped := false
	for i := start; i < len(t); i++ {
		c := t[i]
		if inString {
			switch {
			case escaped:
				escaped = false
			case c == '\\':
				escaped = true
			case c == '"':
				inString = false
			}
			continue
		}
		switch c {
		case '"':
			inString = true
		case '{':
			depth++
		case '}':
			depth--
			if depth == 0 {
				return t[start : i+1]
			}
		}
	}

	// No balanced object was found; keep the lenient first-to-last slice.
	if end := strings.LastIndexByte(t, '}'); end > start {
		return t[start : end+1]
	}
	return "{}"
}
|
|
|
|
// Close Releases resources
|
|
func (p *PluginPreprocess) Close() {
|
|
}
|
|
|
|
// ShutdownHandler Returns shutdown function
|
|
func (p *PluginPreprocess) ShutdownHandler() func() {
|
|
return func() {
|
|
p.Close()
|
|
}
|
|
}
|