Files
WeKnora/internal/handler/model.go
2026-05-28 19:22:17 +08:00

466 lines
15 KiB
Go

package handler
import (
"net/http"
"github.com/Tencent/WeKnora/internal/application/service"
"github.com/Tencent/WeKnora/internal/errors"
"github.com/Tencent/WeKnora/internal/handler/dto"
"github.com/Tencent/WeKnora/internal/logger"
"github.com/Tencent/WeKnora/internal/models/provider"
"github.com/Tencent/WeKnora/internal/types"
"github.com/Tencent/WeKnora/internal/types/interfaces"
secutils "github.com/Tencent/WeKnora/internal/utils"
"github.com/gin-gonic/gin"
)
// ModelHandler handles HTTP requests for model-related operations
// It implements the necessary methods to create, retrieve, update, and delete models
type ModelHandler struct {
service interfaces.ModelService
}
// NewModelHandler creates a new instance of ModelHandler
// It requires a model service implementation that handles business logic
// Parameters:
// - service: An implementation of the ModelService interface
//
// Returns a pointer to the newly created ModelHandler
func NewModelHandler(service interfaces.ModelService) *ModelHandler {
return &ModelHandler{service: service}
}
// Per-response redaction/stripping for Model now lives in
// dto.NewModelResponse — handlers must use it for every body that contains a
// model. The previous hideSensitiveInfo helper has been removed.
// CreateModelRequest defines the structure for model creation requests
// Contains all fields required to create a new model in the system
type CreateModelRequest struct {
Name string `json:"name" binding:"required"`
DisplayName string `json:"display_name"`
Type types.ModelType `json:"type" binding:"required"`
Source types.ModelSource `json:"source" binding:"required"`
Description string `json:"description"`
Parameters types.ModelParameters `json:"parameters" binding:"required"`
}
// CreateModel godoc
// @Summary 创建模型
// @Description 创建新的模型配置
// @Tags 模型管理
// @Accept json
// @Produce json
// @Param request body CreateModelRequest true "模型信息"
// @Success 201 {object} map[string]interface{} "创建的模型"
// @Failure 400 {object} errors.AppError "请求参数错误"
// @Security Bearer
// @Security ApiKeyAuth
// @Router /models [post]
func (h *ModelHandler) CreateModel(c *gin.Context) {
ctx := c.Request.Context()
logger.Info(ctx, "Start creating model")
var req CreateModelRequest
if err := c.ShouldBindJSON(&req); err != nil {
logger.Error(ctx, "Failed to parse request parameters", err)
c.Error(errors.NewBadRequestError(err.Error()))
return
}
tenantID := c.GetUint64(types.TenantIDContextKey.String())
if tenantID == 0 {
logger.Error(ctx, "Tenant ID is empty")
c.Error(errors.NewBadRequestError("Tenant ID cannot be empty"))
return
}
logger.Infof(ctx, "Creating model, Tenant ID: %d, Model name: %s, Model type: %s",
tenantID, secutils.SanitizeForLog(req.Name), secutils.SanitizeForLog(string(req.Type)))
// SSRF validation for model BaseURL
if req.Parameters.BaseURL != "" {
if err := secutils.ValidateURLForSSRF(req.Parameters.BaseURL); err != nil {
logger.Warnf(ctx, "SSRF validation failed for model BaseURL: %v", err)
c.Error(errors.NewBadRequestError(secutils.FormatSSRFError("Base URL", req.Parameters.BaseURL, err)))
return
}
}
model := &types.Model{
TenantID: tenantID,
Name: secutils.SanitizeForLog(req.Name),
DisplayName: secutils.SanitizeForLog(req.DisplayName),
Type: types.ModelType(secutils.SanitizeForLog(string(req.Type))),
Source: req.Source,
Description: secutils.SanitizeForLog(req.Description),
Parameters: req.Parameters,
}
if err := h.service.CreateModel(ctx, model); err != nil {
logger.ErrorWithFields(ctx, err, nil)
c.Error(errors.NewInternalServerError(err.Error()))
return
}
logger.Infof(
ctx,
"Model created successfully, ID: %s, Name: %s",
secutils.SanitizeForLog(model.ID),
secutils.SanitizeForLog(model.Name),
)
c.JSON(http.StatusCreated, gin.H{
"success": true,
"data": dto.NewModelResponse(model),
})
}
// GetModel godoc
// @Summary 获取模型详情
// @Description 根据ID获取模型详情
// @Tags 模型管理
// @Accept json
// @Produce json
// @Param id path string true "模型ID"
// @Success 200 {object} map[string]interface{} "模型详情"
// @Failure 404 {object} errors.AppError "模型不存在"
// @Security Bearer
// @Security ApiKeyAuth
// @Router /models/{id} [get]
func (h *ModelHandler) GetModel(c *gin.Context) {
ctx := c.Request.Context()
logger.Info(ctx, "Start retrieving model")
id := secutils.SanitizeForLog(c.Param("id"))
if id == "" {
logger.Error(ctx, "Model ID is empty")
c.Error(errors.NewBadRequestError("Model ID cannot be empty"))
return
}
logger.Infof(ctx, "Retrieving model, ID: %s", id)
model, err := h.service.GetModelByID(ctx, id)
if err != nil {
if err == service.ErrModelNotFound {
logger.Warnf(ctx, "Model not found, ID: %s", id)
c.Error(errors.NewNotFoundError("Model not found"))
return
}
logger.ErrorWithFields(ctx, err, nil)
c.Error(errors.NewInternalServerError(err.Error()))
return
}
logger.Infof(ctx, "Retrieved model successfully, ID: %s, Name: %s", model.ID, model.Name)
c.JSON(http.StatusOK, gin.H{
"success": true,
"data": dto.NewModelResponse(model),
})
}
// ListModels godoc
// @Summary 获取模型列表
// @Description 获取当前租户的所有模型
// @Tags 模型管理
// @Accept json
// @Produce json
// @Success 200 {object} map[string]interface{} "模型列表"
// @Failure 400 {object} errors.AppError "请求参数错误"
// @Security Bearer
// @Security ApiKeyAuth
// @Router /models [get]
func (h *ModelHandler) ListModels(c *gin.Context) {
ctx := c.Request.Context()
logger.Info(ctx, "Start retrieving model list")
tenantID := c.GetUint64(types.TenantIDContextKey.String())
if tenantID == 0 {
logger.Error(ctx, "Tenant ID is empty")
c.Error(errors.NewBadRequestError("Tenant ID cannot be empty"))
return
}
models, err := h.service.ListModels(ctx)
if err != nil {
logger.ErrorWithFields(ctx, err, nil)
c.Error(errors.NewInternalServerError(err.Error()))
return
}
logger.Infof(ctx, "Retrieved model list successfully, Tenant ID: %d, Total: %d models", tenantID, len(models))
c.JSON(http.StatusOK, gin.H{
"success": true,
"data": dto.NewModelResponses(models),
})
}
// UpdateModelRequest defines the structure for model update requests
// Contains fields that can be updated for an existing model
type UpdateModelRequest struct {
Name string `json:"name"`
DisplayName *string `json:"display_name"`
Description string `json:"description"`
Parameters types.ModelParameters `json:"parameters"`
Source types.ModelSource `json:"source"`
Type types.ModelType `json:"type"`
}
// UpdateModel godoc
// @Summary 更新模型
// @Description 更新模型配置信息
// @Tags 模型管理
// @Accept json
// @Produce json
// @Param id path string true "模型ID"
// @Param request body UpdateModelRequest true "更新信息"
// @Success 200 {object} map[string]interface{} "更新后的模型"
// @Failure 404 {object} errors.AppError "模型不存在"
// @Security Bearer
// @Security ApiKeyAuth
// @Router /models/{id} [put]
func (h *ModelHandler) UpdateModel(c *gin.Context) {
ctx := c.Request.Context()
logger.Info(ctx, "Start updating model")
id := secutils.SanitizeForLog(c.Param("id"))
if id == "" {
logger.Error(ctx, "Model ID is empty")
c.Error(errors.NewBadRequestError("Model ID cannot be empty"))
return
}
var req UpdateModelRequest
if err := c.ShouldBindJSON(&req); err != nil {
logger.Error(ctx, "Failed to parse request parameters", err)
c.Error(errors.NewBadRequestError(err.Error()))
return
}
logger.Infof(ctx, "Retrieving model information, ID: %s", id)
model, err := h.service.GetModelByID(ctx, id)
if err != nil {
if err == service.ErrModelNotFound {
logger.Warnf(ctx, "Model not found, ID: %s", id)
c.Error(errors.NewNotFoundError("Model not found"))
return
}
logger.ErrorWithFields(ctx, err, nil)
c.Error(errors.NewInternalServerError(err.Error()))
return
}
// Update model fields if they are provided in the request
if req.Name != "" {
model.Name = req.Name
}
if req.DisplayName != nil {
model.DisplayName = secutils.SanitizeForLog(*req.DisplayName)
}
model.Description = req.Description
// SSRF validation for updated model BaseURL
if req.Parameters.BaseURL != "" {
if err := secutils.ValidateURLForSSRF(req.Parameters.BaseURL); err != nil {
logger.Warnf(ctx, "SSRF validation failed for model BaseURL: %v", err)
c.Error(errors.NewBadRequestError(secutils.FormatSSRFError("Base URL", req.Parameters.BaseURL, err)))
return
}
}
// Credentials (api_key, app_secret) NEVER flow through this endpoint —
// they live behind the /credentials subresource. Force-preserve them by
// snapshotting the stored values before copying request fields in, so
// that even a misbehaving caller that puts api_key in the body cannot
// clobber a stored credential. Log a warning to spot stale callers.
storedAPIKey := model.Parameters.APIKey
storedAppSecret := model.Parameters.AppSecret
if req.Parameters.APIKey != "" && req.Parameters.APIKey != storedAPIKey {
logger.Warnf(ctx,
"deprecated: api_key in PUT /models/%s body is ignored; use PUT /credentials instead", id)
}
if req.Parameters.AppSecret != "" && req.Parameters.AppSecret != storedAppSecret {
logger.Warnf(ctx,
"deprecated: app_secret in PUT /models/%s body is ignored; use PUT /credentials instead", id)
}
newParams := req.Parameters
newParams.APIKey = storedAPIKey
newParams.AppSecret = storedAppSecret
// Preserve backend-managed fields not sent by the frontend either.
newParams.ParameterSize = model.Parameters.ParameterSize
if newParams.ExtraConfig == nil {
newParams.ExtraConfig = model.Parameters.ExtraConfig
}
model.Parameters = newParams
model.Source = req.Source
model.Type = req.Type
logger.Infof(ctx, "Updating model, ID: %s, Name: %s", id, model.Name)
if err := h.service.UpdateModel(ctx, model); err != nil {
logger.ErrorWithFields(ctx, err, nil)
c.Error(errors.NewInternalServerError(err.Error()))
return
}
logger.Infof(ctx, "Model updated successfully, ID: %s", id)
c.JSON(http.StatusOK, gin.H{
"success": true,
"data": dto.NewModelResponse(model),
})
}
// DeleteModel godoc
// @Summary 删除模型
// @Description 删除指定的模型
// @Tags 模型管理
// @Accept json
// @Produce json
// @Param id path string true "模型ID"
// @Success 200 {object} map[string]interface{} "删除成功"
// @Failure 404 {object} errors.AppError "模型不存在"
// @Security Bearer
// @Security ApiKeyAuth
// @Router /models/{id} [delete]
func (h *ModelHandler) DeleteModel(c *gin.Context) {
ctx := c.Request.Context()
logger.Info(ctx, "Start deleting model")
id := secutils.SanitizeForLog(c.Param("id"))
if id == "" {
logger.Error(ctx, "Model ID is empty")
c.Error(errors.NewBadRequestError("Model ID cannot be empty"))
return
}
logger.Infof(ctx, "Deleting model, ID: %s", id)
if err := h.service.DeleteModel(ctx, id); err != nil {
if err == service.ErrModelNotFound {
logger.Warnf(ctx, "Model not found, ID: %s", id)
c.Error(errors.NewNotFoundError("Model not found"))
return
}
logger.ErrorWithFields(ctx, err, nil)
c.Error(errors.NewInternalServerError(err.Error()))
return
}
logger.Infof(ctx, "Model deleted successfully, ID: %s", id)
c.JSON(http.StatusOK, gin.H{
"success": true,
"message": "Model deleted",
})
}
// ModelProviderDTO 模型厂商信息 DTO
type ModelProviderDTO struct {
Value string `json:"value"` // provider 标识符
Label string `json:"label"` // 显示名称
Description string `json:"description"` // 描述
DefaultURLs map[string]string `json:"defaultUrls"` // 按模型类型区分的默认 URL
ModelTypes []string `json:"modelTypes"` // 支持的模型类型
}
// modelTypeToFrontend 将后端 ModelType 转换为前端兼容的字符串
// KnowledgeQA -> chat, Embedding -> embedding, Rerank -> rerank, VLLM -> vllm
func modelTypeToFrontend(mt types.ModelType) string {
switch mt {
case types.ModelTypeKnowledgeQA:
return "chat"
case types.ModelTypeEmbedding:
return "embedding"
case types.ModelTypeRerank:
return "rerank"
case types.ModelTypeVLLM:
return "vllm"
case types.ModelTypeASR:
return "asr"
default:
return string(mt)
}
}
// ListModelProviders godoc
// @Summary 获取模型厂商列表
// @Description 根据模型类型获取支持的厂商列表及配置信息
// @Tags 模型管理
// @Accept json
// @Produce json
// @Param model_type query string false "模型类型 (chat, embedding, rerank, vllm)"
// @Success 200 {object} map[string]interface{} "厂商列表"
// @Security Bearer
// @Security ApiKeyAuth
// @Router /models/providers [get]
func (h *ModelHandler) ListModelProviders(c *gin.Context) {
ctx := c.Request.Context()
modelType := c.Query("model_type")
logger.Infof(ctx, "Listing model providers for type: %s", secutils.SanitizeForLog(modelType))
// 将前端类型映射到后端类型
// 前端: chat, embedding, rerank, vllm
// 后端: KnowledgeQA, Embedding, Rerank, VLLM
var backendModelType types.ModelType
switch modelType {
case "chat":
backendModelType = types.ModelTypeKnowledgeQA
case "embedding":
backendModelType = types.ModelTypeEmbedding
case "rerank":
backendModelType = types.ModelTypeRerank
case "vllm":
backendModelType = types.ModelTypeVLLM
case "asr":
backendModelType = types.ModelTypeASR
default:
backendModelType = types.ModelType(modelType)
}
var providers []provider.ProviderInfo
if modelType != "" {
// 按模型类型过滤
providers = provider.ListByModelType(backendModelType)
} else {
// 返回所有 provider
providers = provider.List()
}
// 转换为 DTO
result := make([]ModelProviderDTO, 0, len(providers))
for _, p := range providers {
// 转换 DefaultURLs map[types.ModelType]string -> map[string]string
// 使用前端兼容的 key (chat 而不是 KnowledgeQA)
defaultURLs := make(map[string]string)
for mt, url := range p.DefaultURLs {
frontendType := modelTypeToFrontend(mt)
defaultURLs[frontendType] = url
}
// 转换 ModelTypes 为前端兼容格式
modelTypes := make([]string, 0, len(p.ModelTypes))
for _, mt := range p.ModelTypes {
modelTypes = append(modelTypes, modelTypeToFrontend(mt))
}
result = append(result, ModelProviderDTO{
Value: string(p.Name),
Label: p.DisplayName,
Description: p.Description,
DefaultURLs: defaultURLs,
ModelTypes: modelTypes,
})
}
logger.Infof(ctx, "Retrieved %d providers", len(result))
c.JSON(http.StatusOK, gin.H{
"success": true,
"data": result,
})
}