mirror of
https://github.com/Tencent/WeKnora.git
synced 2026-06-04 13:30:32 +08:00
* feat(rbac): add multi-use share-link invitations for invite_only mode When the auth.registration_mode toggle flips to invite_only, the public /auth/register endpoint returns 403 — but until now there was no channel for unregistered users to actually join a tenant. The existing invitation table required a registered invitee, so new emails could not enter the system at all. This change introduces a share-link model: - Owner generates a multi-use registration link with a per-tenant role - Recipient opens the link, registers with their own email, and is added to the tenant on the same request - Token is stored in plaintext (not hashed) so the management UI can re-display the URL on demand without a "copy now or revoke" trap. Threat model is bounded by 7-day TTL, revocability, and the fact that all the link grants is membership in one tenant. - accepted_count tracks how many users joined through the link so the Owner can tell whether a link is fresh or has already been spread Per-user invitations (registered email -> in-app inbox accept) are unchanged. Share-link rows use empty invitee_user_id as discriminator; the partial unique index on (tenant_id, invitee_user_id) was relaxed to skip empty values so multiple share links can coexist per tenant. Frontend reuses Login.vue for /register?token=xxx instead of a parallel page, and adds icon-only revoke / remove actions plus a distinct "active" status tag for share-link rows. Migration 000054 adds the token column, the accepted_count column, and the relaxed pending unique index in a single step. * feat(auth): update invitation token handling and add rate limiting - Refactor the invitation token retrieval to use POST instead of GET, enhancing security by preventing the plaintext token from appearing in logs and browser history. - Update the API endpoint from `/auth/invitations/:token` to `/auth/invitations/lookup` to reflect the new method. - Introduce a rate limiter for unauthenticated share-link endpoints to mitigate brute-force attacks and abuse, ensuring a maximum of 30 requests per minute per IP. - Adjust the registration flow to maintain user experience while securing the invitation process.
855 lines
30 KiB
Go
855 lines
30 KiB
Go
package handler
|
||
|
||
import (
|
||
"context"
|
||
"crypto/rand"
|
||
"encoding/base64"
|
||
"encoding/json"
|
||
"fmt"
|
||
"net/http"
|
||
"strings"
|
||
|
||
"github.com/gin-gonic/gin"
|
||
|
||
"github.com/Tencent/WeKnora/internal/config"
|
||
"github.com/Tencent/WeKnora/internal/errors"
|
||
"github.com/Tencent/WeKnora/internal/logger"
|
||
"github.com/Tencent/WeKnora/internal/types"
|
||
"github.com/Tencent/WeKnora/internal/types/interfaces"
|
||
secutils "github.com/Tencent/WeKnora/internal/utils"
|
||
)
|
||
|
||
// AuthHandler implements HTTP request handlers for user authentication
|
||
// Provides functionality for user registration, login, logout, and token management
|
||
// through the REST API endpoints
|
||
type AuthHandler struct {
|
||
userService interfaces.UserService
|
||
tenantService interfaces.TenantService
|
||
configInfo *config.Config
|
||
systemSettingSvc interfaces.SystemSettingService
|
||
// invitationSvc is required for the share-link registration path
|
||
// (POST /auth/register-by-invite). When nil — e.g. legacy test
|
||
// fixtures — the share-link endpoints respond 503 rather than
|
||
// blocking the rest of the auth surface.
|
||
invitationSvc interfaces.TenantInvitationService
|
||
}
|
||
|
||
// NewAuthHandler creates a new auth handler instance with the provided services
|
||
// Parameters:
|
||
// - userService: An implementation of the UserService interface for business logic
|
||
// - tenantService: An implementation of the TenantService interface for tenant management
|
||
// - systemSettingSvc: 3-tier resolver for runtime-tunable settings such as
|
||
// auth.registration_mode (P3). When DB has a row, it overrides cfg's
|
||
// startup value; otherwise we fall back to cfg.Auth.RegistrationMode
|
||
// (which already accounted for the legacy DISABLE_REGISTRATION env coerce
|
||
// during config load). Mismatch impossible by construction since the
|
||
// handler always passes cfg's value as the def parameter to GetString.
|
||
//
|
||
// Returns a pointer to the newly created AuthHandler
|
||
func NewAuthHandler(configInfo *config.Config,
|
||
userService interfaces.UserService, tenantService interfaces.TenantService,
|
||
systemSettingSvc interfaces.SystemSettingService,
|
||
invitationSvc interfaces.TenantInvitationService,
|
||
) *AuthHandler {
|
||
// Boot-time guard: a nil-or-empty Auth section silently disables the
|
||
// invite_only gate (see Register below). Emit a loud one-shot log
|
||
// pointing at the misconfiguration so operators notice on startup
|
||
// instead of discovering it the day someone hits /auth/register.
|
||
if configInfo == nil || configInfo.Auth == nil {
|
||
logger.Errorf(context.Background(),
|
||
"[auth] AuthHandler constructed with nil/incomplete config (cfg=%v); "+
|
||
"registration_mode enforcement is disabled. This is almost certainly a wiring bug.",
|
||
configInfo)
|
||
}
|
||
return &AuthHandler{
|
||
configInfo: configInfo,
|
||
userService: userService,
|
||
tenantService: tenantService,
|
||
systemSettingSvc: systemSettingSvc,
|
||
invitationSvc: invitationSvc,
|
||
}
|
||
}
|
||
|
||
// resolveRegistrationMode returns the currently active registration mode.
|
||
// Priority: DB system_settings > cfg (which already absorbed the legacy
|
||
// DISABLE_REGISTRATION env coerce at startup) > "self_serve" hard default.
|
||
//
|
||
// Centralised here so /auth/register and /auth/config stay in lock-step —
|
||
// otherwise a SystemAdmin's UI edit could affect one path and not the other.
|
||
func (h *AuthHandler) resolveRegistrationMode(ctx context.Context) string {
|
||
// cfg-derived default: empty is impossible after applyAuthAndTenantDefaults,
|
||
// but be defensive in case AuthHandler was constructed before that ran
|
||
// (the NewAuthHandler guard already logged in that case).
|
||
def := config.AuthRegistrationModeSelfServe
|
||
if h.configInfo != nil && h.configInfo.Auth != nil {
|
||
if m := strings.TrimSpace(h.configInfo.Auth.RegistrationMode); m != "" {
|
||
def = m
|
||
}
|
||
}
|
||
if h.systemSettingSvc == nil {
|
||
return def
|
||
}
|
||
// envName = "" because DISABLE_REGISTRATION is a boolean and
|
||
// auth.registration_mode is a string — the legacy env was already
|
||
// coerced into `def` above. Mixing the two semantics at the resolver
|
||
// layer would mean a UI delete (DB row absent) silently flipped to
|
||
// the legacy boolean read again, which is surprising.
|
||
return h.systemSettingSvc.GetString(ctx, "auth.registration_mode", "", def)
|
||
}
|
||
|
||
// Register godoc
|
||
// @Summary 用户注册
|
||
// @Description 注册新用户账号
|
||
// @Tags 认证
|
||
// @Accept json
|
||
// @Produce json
|
||
// @Param request body types.RegisterRequest true "注册请求参数"
|
||
// @Success 201 {object} types.RegisterResponse
|
||
// @Failure 400 {object} errors.AppError "请求参数错误"
|
||
// @Failure 403 {object} errors.AppError "注册功能已禁用"
|
||
// @Router /auth/register [post]
|
||
func (h *AuthHandler) Register(c *gin.Context) {
|
||
ctx := c.Request.Context()
|
||
|
||
logger.Info(ctx, "Start user registration")
|
||
|
||
// 当 auth.registration_mode=invite_only 时,public 注册被关闭。
|
||
// 优先级:DB system_settings > cfg.Auth.RegistrationMode > "self_serve"。
|
||
// SystemAdmin 通过「全局设置」UI 实时切换 self_serve / invite_only,立即
|
||
// 生效,不需要重启服务。历史变量 DISABLE_REGISTRATION=true 仍在 config
|
||
// 启动阶段被等价提升为 invite_only(applyAuthAndTenantDefaults),
|
||
// 作为 cfg-default 进入 resolveRegistrationMode。
|
||
if h.resolveRegistrationMode(ctx) == config.AuthRegistrationModeInviteOnly {
|
||
logger.Warn(ctx, "Registration rejected: auth.registration_mode=invite_only")
|
||
appErr := errors.NewForbiddenError("Registration is invite-only")
|
||
c.Error(appErr)
|
||
return
|
||
}
|
||
|
||
var req types.RegisterRequest
|
||
if err := c.ShouldBindJSON(&req); err != nil {
|
||
logger.Error(ctx, "Failed to parse registration request parameters", err)
|
||
appErr := errors.NewValidationError("Invalid registration parameters").WithDetails(err.Error())
|
||
c.Error(appErr)
|
||
return
|
||
}
|
||
req.Username = secutils.SanitizeForLog(req.Username)
|
||
req.Email = secutils.SanitizeForLog(req.Email)
|
||
req.Password = secutils.SanitizeForLog(req.Password)
|
||
|
||
// Validate required fields
|
||
if req.Username == "" || req.Email == "" || req.Password == "" {
|
||
logger.Error(ctx, "Missing required registration fields")
|
||
appErr := errors.NewValidationError("Username, email and password are required")
|
||
c.Error(appErr)
|
||
return
|
||
}
|
||
req.Username = secutils.SanitizeForLog(req.Username)
|
||
req.Email = secutils.SanitizeForLog(req.Email)
|
||
// Call service to register user
|
||
user, err := h.userService.Register(ctx, &req)
|
||
if err != nil {
|
||
logger.Errorf(ctx, "Failed to register user: %v", err)
|
||
appErr := errors.NewBadRequestError(err.Error())
|
||
c.Error(appErr)
|
||
return
|
||
}
|
||
|
||
// Return success response
|
||
response := &types.RegisterResponse{
|
||
Success: true,
|
||
Message: "Registration successful",
|
||
User: user,
|
||
}
|
||
|
||
logger.Infof(ctx, "User registered successfully: %s", secutils.SanitizeForLog(user.Email))
|
||
c.JSON(http.StatusCreated, response)
|
||
}
|
||
|
||
// Login godoc
|
||
// @Summary 用户登录
|
||
// @Description 用户登录并获取访问令牌
|
||
// @Tags 认证
|
||
// @Accept json
|
||
// @Produce json
|
||
// @Param request body types.LoginRequest true "登录请求参数"
|
||
// @Success 200 {object} types.LoginResponse
|
||
// @Failure 401 {object} errors.AppError "认证失败"
|
||
// @Router /auth/login [post]
|
||
func (h *AuthHandler) Login(c *gin.Context) {
|
||
ctx := c.Request.Context()
|
||
|
||
logger.Info(ctx, "Start user login")
|
||
|
||
var req types.LoginRequest
|
||
if err := c.ShouldBindJSON(&req); err != nil {
|
||
logger.Error(ctx, "Failed to parse login request parameters", err)
|
||
appErr := errors.NewValidationError("Invalid login parameters").WithDetails(err.Error())
|
||
c.Error(appErr)
|
||
return
|
||
}
|
||
email := secutils.SanitizeForLog(req.Email)
|
||
|
||
// Validate required fields
|
||
if req.Email == "" || req.Password == "" {
|
||
logger.Error(ctx, "Missing required login fields")
|
||
appErr := errors.NewValidationError("Email and password are required")
|
||
c.Error(appErr)
|
||
return
|
||
}
|
||
|
||
// Call service to authenticate user
|
||
response, err := h.userService.Login(ctx, &req)
|
||
if err != nil {
|
||
logger.Errorf(ctx, "Failed to login user: %v", err)
|
||
appErr := errors.NewUnauthorizedError("Login failed").WithDetails(err.Error())
|
||
c.Error(appErr)
|
||
return
|
||
}
|
||
|
||
// Check if login was successful
|
||
if !response.Success {
|
||
logger.Warnf(ctx, "Login failed: %s", response.Message)
|
||
c.JSON(http.StatusUnauthorized, response)
|
||
return
|
||
}
|
||
|
||
// User is already in the correct format from service
|
||
|
||
logger.Infof(ctx, "User logged in successfully, email: %s", email)
|
||
c.JSON(http.StatusOK, response)
|
||
}
|
||
|
||
// GetOIDCAuthorizationURL godoc
|
||
// @Summary 获取OIDC授权地址
|
||
// @Description 根据后端OIDC配置生成第三方登录跳转地址
|
||
// @Tags 认证
|
||
// @Accept json
|
||
// @Produce json
|
||
// @Param redirect_uri query string true "OIDC回调地址"
|
||
// @Success 200 {object} types.OIDCAuthURLResponse
|
||
// @Failure 400 {object} errors.AppError "请求参数错误"
|
||
// @Failure 403 {object} errors.AppError "OIDC未启用"
|
||
// @Router /auth/oidc/url [get]
|
||
func (h *AuthHandler) GetOIDCAuthorizationURL(c *gin.Context) {
|
||
ctx := c.Request.Context()
|
||
redirectURI := strings.TrimSpace(c.Query("redirect_uri"))
|
||
if redirectURI == "" {
|
||
appErr := errors.NewValidationError("redirect_uri is required")
|
||
c.Error(appErr)
|
||
return
|
||
}
|
||
|
||
resp, err := h.userService.GetOIDCAuthorizationURL(ctx, redirectURI)
|
||
if err != nil {
|
||
logger.Errorf(ctx, "Failed to generate OIDC authorization URL: %v", err)
|
||
appErr := errors.NewForbiddenError("OIDC authorization unavailable").WithDetails(err.Error())
|
||
c.Error(appErr)
|
||
return
|
||
}
|
||
|
||
c.JSON(http.StatusOK, resp)
|
||
}
|
||
|
||
// GetOIDCConfig godoc
|
||
// @Summary 获取OIDC登录配置
|
||
// @Description 返回OIDC是否启用以及provider展示名称,供前端决定是否展示OIDC登录入口
|
||
// @Tags 认证
|
||
// @Accept json
|
||
// @Produce json
|
||
// @Success 200 {object} types.OIDCConfigResponse
|
||
// @Router /auth/oidc/config [get]
|
||
func (h *AuthHandler) GetOIDCConfig(c *gin.Context) {
|
||
providerDisplayName := ""
|
||
enabled := false
|
||
|
||
if h.configInfo != nil && h.configInfo.OIDCAuth != nil {
|
||
enabled = h.configInfo.OIDCAuth.Enable
|
||
providerDisplayName = strings.TrimSpace(h.configInfo.OIDCAuth.ProviderDisplayName)
|
||
}
|
||
|
||
c.JSON(http.StatusOK, &types.OIDCConfigResponse{
|
||
Success: true,
|
||
Enabled: enabled,
|
||
ProviderDisplayName: providerDisplayName,
|
||
})
|
||
}
|
||
|
||
// OIDCRedirectCallback godoc
|
||
// @Summary OIDC登录重定向回调
|
||
// @Description 接收OIDC provider回调并由后端完成code交换,随后重定向回前端登录页
|
||
// @Tags 认证
|
||
// @Accept json
|
||
// @Produce json
|
||
// @Param code query string false "OIDC授权码"
|
||
// @Param state query string false "OIDC状态"
|
||
// @Param error query string false "OIDC错误码"
|
||
// @Success 302
|
||
// @Router /auth/oidc/callback [get]
|
||
func (h *AuthHandler) OIDCRedirectCallback(c *gin.Context) {
|
||
ctx := c.Request.Context()
|
||
frontendRedirectURI := "/"
|
||
|
||
if providerError := strings.TrimSpace(c.Query("error")); providerError != "" {
|
||
redirectURL := frontendRedirectURI + "#oidc_error=" + urlQueryEscape(providerError)
|
||
if description := strings.TrimSpace(c.Query("error_description")); description != "" {
|
||
redirectURL += "&oidc_error_description=" + urlQueryEscape(description)
|
||
}
|
||
c.Redirect(http.StatusFound, redirectURL)
|
||
return
|
||
}
|
||
|
||
state := strings.TrimSpace(c.Query("state"))
|
||
decodedState, err := decodeOIDCState(state)
|
||
if err != nil {
|
||
logger.Errorf(ctx, "Failed to decode OIDC state: %v", err)
|
||
c.Redirect(http.StatusFound, frontendRedirectURI+"#oidc_error="+urlQueryEscape("invalid_state"))
|
||
return
|
||
}
|
||
|
||
code := strings.TrimSpace(c.Query("code"))
|
||
if code == "" {
|
||
c.Redirect(http.StatusFound, frontendRedirectURI+"#oidc_error="+urlQueryEscape("missing_code"))
|
||
return
|
||
}
|
||
|
||
resp, err := h.userService.LoginWithOIDC(ctx, code, strings.TrimSpace(decodedState.RedirectURI))
|
||
if err != nil {
|
||
logger.Errorf(ctx, "Failed to complete OIDC login via redirect callback: %v", err)
|
||
c.Redirect(http.StatusFound, frontendRedirectURI+"#oidc_error="+urlQueryEscape("login_failed")+"&oidc_error_description="+urlQueryEscape(err.Error()))
|
||
return
|
||
}
|
||
if !resp.Success {
|
||
c.Redirect(http.StatusFound, frontendRedirectURI+"#oidc_error="+urlQueryEscape("login_failed")+"&oidc_error_description="+urlQueryEscape(resp.Message))
|
||
return
|
||
}
|
||
|
||
payload, err := encodeOIDCCallbackPayload(resp)
|
||
if err != nil {
|
||
logger.Errorf(ctx, "Failed to encode OIDC callback payload: %v", err)
|
||
c.Redirect(http.StatusFound, frontendRedirectURI+"#oidc_error="+urlQueryEscape("payload_encode_failed"))
|
||
return
|
||
}
|
||
|
||
c.Redirect(http.StatusFound, frontendRedirectURI+"#oidc_result="+urlQueryEscape(payload))
|
||
}
|
||
|
||
func encodeOIDCCallbackPayload(resp *types.OIDCCallbackResponse) (string, error) {
|
||
payload, err := json.Marshal(resp)
|
||
if err != nil {
|
||
return "", err
|
||
}
|
||
return base64.RawURLEncoding.EncodeToString(payload), nil
|
||
}
|
||
|
||
type oidcStatePayload struct {
|
||
Nonce string `json:"nonce"`
|
||
RedirectURI string `json:"redirect_uri,omitempty"`
|
||
}
|
||
|
||
func decodeOIDCState(raw string) (*oidcStatePayload, error) {
|
||
decoded, err := base64.RawURLEncoding.DecodeString(strings.TrimSpace(raw))
|
||
if err != nil {
|
||
return nil, err
|
||
}
|
||
var payload oidcStatePayload
|
||
if err := json.Unmarshal(decoded, &payload); err != nil {
|
||
return nil, err
|
||
}
|
||
if strings.TrimSpace(payload.RedirectURI) == "" {
|
||
return nil, errors.NewValidationError("state.redirect_uri is required")
|
||
}
|
||
return &payload, nil
|
||
}
|
||
|
||
func urlQueryEscape(value string) string {
|
||
replacer := strings.NewReplacer(
|
||
"%", "%25",
|
||
" ", "%20",
|
||
"#", "%23",
|
||
"&", "%26",
|
||
"+", "%2B",
|
||
"=", "%3D",
|
||
"?", "%3F",
|
||
)
|
||
return replacer.Replace(value)
|
||
}
|
||
|
||
// Logout godoc
|
||
// @Summary 用户登出
|
||
// @Description 撤销当前访问令牌并登出
|
||
// @Tags 认证
|
||
// @Accept json
|
||
// @Produce json
|
||
// @Success 200 {object} map[string]interface{} "登出成功"
|
||
// @Failure 400 {object} errors.AppError "请求参数错误"
|
||
// @Security Bearer
|
||
// @Router /auth/logout [post]
|
||
func (h *AuthHandler) Logout(c *gin.Context) {
|
||
ctx := c.Request.Context()
|
||
|
||
logger.Info(ctx, "Start user logout")
|
||
|
||
// Extract token from Authorization header
|
||
authHeader := c.GetHeader("Authorization")
|
||
if authHeader == "" {
|
||
logger.Error(ctx, "Missing Authorization header")
|
||
appErr := errors.NewValidationError("Authorization header is required")
|
||
c.Error(appErr)
|
||
return
|
||
}
|
||
|
||
// Parse Bearer token
|
||
tokenParts := strings.Split(authHeader, " ")
|
||
if len(tokenParts) != 2 || tokenParts[0] != "Bearer" {
|
||
logger.Error(ctx, "Invalid Authorization header format")
|
||
appErr := errors.NewValidationError("Invalid Authorization header format")
|
||
c.Error(appErr)
|
||
return
|
||
}
|
||
|
||
token := tokenParts[1]
|
||
|
||
// Revoke token
|
||
err := h.userService.RevokeToken(ctx, token)
|
||
if err != nil {
|
||
logger.Errorf(ctx, "Failed to revoke token: %v", err)
|
||
appErr := errors.NewInternalServerError("Logout failed").WithDetails(err.Error())
|
||
c.Error(appErr)
|
||
return
|
||
}
|
||
|
||
logger.Info(ctx, "User logged out successfully")
|
||
c.JSON(http.StatusOK, gin.H{
|
||
"success": true,
|
||
"message": "Logout successful",
|
||
})
|
||
}
|
||
|
||
// RefreshToken godoc
|
||
// @Summary 刷新令牌
|
||
// @Description 使用刷新令牌获取新的访问令牌
|
||
// @Tags 认证
|
||
// @Accept json
|
||
// @Produce json
|
||
// @Param request body object{refreshToken=string} true "刷新令牌"
|
||
// @Success 200 {object} map[string]interface{} "新令牌"
|
||
// @Failure 401 {object} errors.AppError "令牌无效"
|
||
// @Router /auth/refresh [post]
|
||
func (h *AuthHandler) RefreshToken(c *gin.Context) {
|
||
ctx := c.Request.Context()
|
||
|
||
logger.Info(ctx, "Start token refresh")
|
||
|
||
var req struct {
|
||
RefreshToken string `json:"refreshToken" binding:"required"`
|
||
}
|
||
|
||
if err := c.ShouldBindJSON(&req); err != nil {
|
||
logger.Error(ctx, "Failed to parse refresh token request", err)
|
||
appErr := errors.NewValidationError("Invalid refresh token request").WithDetails(err.Error())
|
||
c.Error(appErr)
|
||
return
|
||
}
|
||
|
||
// Call service to refresh token
|
||
accessToken, newRefreshToken, err := h.userService.RefreshToken(ctx, req.RefreshToken)
|
||
if err != nil {
|
||
logger.Errorf(ctx, "Failed to refresh token: %v", err)
|
||
appErr := errors.NewUnauthorizedError("Token refresh failed").WithDetails(err.Error())
|
||
c.Error(appErr)
|
||
return
|
||
}
|
||
|
||
logger.Info(ctx, "Token refreshed successfully")
|
||
c.JSON(http.StatusOK, gin.H{
|
||
"success": true,
|
||
"message": "Token refreshed successfully",
|
||
"access_token": accessToken,
|
||
"refresh_token": newRefreshToken,
|
||
})
|
||
}
|
||
|
||
// GetCurrentUser godoc
|
||
// @Summary 获取当前用户信息
|
||
// @Description 获取当前登录用户的详细信息
|
||
// @Tags 认证
|
||
// @Accept json
|
||
// @Produce json
|
||
// @Success 200 {object} map[string]interface{} "用户信息"
|
||
// @Failure 401 {object} errors.AppError "未授权"
|
||
// @Security Bearer
|
||
// @Router /auth/me [get]
|
||
func (h *AuthHandler) GetCurrentUser(c *gin.Context) {
|
||
ctx := c.Request.Context()
|
||
|
||
// Get current user from service (which extracts from context)
|
||
user, err := h.userService.GetCurrentUser(ctx)
|
||
if err != nil {
|
||
logger.Errorf(ctx, "Failed to get current user: %v", err)
|
||
appErr := errors.NewUnauthorizedError("Failed to get user information").WithDetails(err.Error())
|
||
c.Error(appErr)
|
||
return
|
||
}
|
||
|
||
// Get tenant information for the *active* tenant (the one the
|
||
// auth middleware resolved against the X-Tenant-ID header), not
|
||
// the user's home tenant. user.TenantID is the row stored on the
|
||
// users table at signup time and never changes; reading it here
|
||
// would make /auth/me always return the home tenant even after
|
||
// the user switched into a peer tenant. The frontend then re-keys
|
||
// `authStore.tenant.id` to the home tenant, and every UI gate
|
||
// computed against it (currentTenantRole, isOwner, ...) leaks
|
||
// the wrong role. Pull the active tenant id from context instead.
|
||
var tenant *types.Tenant
|
||
activeTenantID, _ := types.TenantIDFromContext(ctx)
|
||
if activeTenantID == 0 {
|
||
activeTenantID = user.TenantID
|
||
}
|
||
if activeTenantID > 0 {
|
||
tenant, err = h.tenantService.GetTenantByID(ctx, activeTenantID)
|
||
if err != nil {
|
||
logger.Warnf(ctx, "Failed to get tenant info for user %s, tenant ID %d: %v", user.Email, activeTenantID, err)
|
||
// Don't fail the request if tenant info is not available
|
||
}
|
||
}
|
||
userInfo := user.ToUserInfo()
|
||
userInfo.CanAccessAllTenants = user.CanAccessAllTenants && h.configInfo.Tenant.EnableCrossTenantAccess
|
||
// 同步返回当前用户的 memberships,让前端在页面刷新(仅命中 /auth/me)
|
||
// 后也能恢复 currentTenantRole,避免角色信息只在 login 那一刻可用。
|
||
memberships := h.userService.BuildLoginMemberships(ctx, user, tenant)
|
||
c.JSON(http.StatusOK, gin.H{
|
||
"success": true,
|
||
"data": gin.H{
|
||
"user": userInfo,
|
||
"tenant": tenant,
|
||
"memberships": memberships,
|
||
},
|
||
})
|
||
}
|
||
|
||
// updateMyPreferencesRequest is the body for PUT /auth/me/preferences.
|
||
// Fields are pointers so the handler can distinguish "key not present"
|
||
// (preserve existing value) from "explicit false". See
|
||
// types.UserPreferences for the persistence-layer counterpart.
|
||
type updateMyPreferencesRequest struct {
|
||
EnableMemory *bool `json:"enable_memory"`
|
||
// LastActiveTenantID lets the SPA persist "after a fresh login,
|
||
// drop me back into this workspace" across devices. Send a positive
|
||
// tenant id to set / replace, or 0 to clear. Membership is validated
|
||
// at next login, not here. Nil = field omitted from the PATCH and
|
||
// stays untouched.
|
||
LastActiveTenantID *uint64 `json:"last_active_tenant_id"`
|
||
}
|
||
|
||
// UpdateMyPreferences godoc
|
||
// @Summary 更新当前用户的个性化设置
|
||
// @Description 按 PATCH 语义合并用户偏好(仅覆盖请求体里出现的字段,其余字段保持不变),
|
||
// @Description 数据存放在 users.preferences (JSON),跨设备/浏览器自动同步。
|
||
// @Tags 认证
|
||
// @Accept json
|
||
// @Produce json
|
||
// @Param request body updateMyPreferencesRequest true "Preferences patch"
|
||
// @Success 200 {object} map[string]interface{} "更新后的偏好"
|
||
// @Failure 400 {object} errors.AppError "请求参数错误"
|
||
// @Failure 401 {object} errors.AppError "未授权"
|
||
// @Security Bearer
|
||
// @Router /auth/me/preferences [put]
|
||
func (h *AuthHandler) UpdateMyPreferences(c *gin.Context) {
|
||
ctx := c.Request.Context()
|
||
|
||
user, err := h.userService.GetCurrentUser(ctx)
|
||
if err != nil {
|
||
appErr := errors.NewUnauthorizedError("Failed to get user information").WithDetails(err.Error())
|
||
c.Error(appErr)
|
||
return
|
||
}
|
||
|
||
var req updateMyPreferencesRequest
|
||
if err := c.ShouldBindJSON(&req); err != nil {
|
||
appErr := errors.NewValidationError("Invalid preferences request").WithDetails(err.Error())
|
||
c.Error(appErr)
|
||
return
|
||
}
|
||
|
||
patch := types.UserPreferences{
|
||
EnableMemory: req.EnableMemory,
|
||
LastActiveTenantID: req.LastActiveTenantID,
|
||
}
|
||
prefs, err := h.userService.UpdateUserPreferences(ctx, user.ID, patch)
|
||
if err != nil {
|
||
logger.Errorf(ctx, "Failed to update preferences for user %s: %v", user.Email, err)
|
||
appErr := errors.NewBadRequestError("Failed to update preferences").WithDetails(err.Error())
|
||
c.Error(appErr)
|
||
return
|
||
}
|
||
|
||
c.JSON(http.StatusOK, gin.H{
|
||
"success": true,
|
||
"data": prefs,
|
||
})
|
||
}
|
||
|
||
// ChangePassword godoc
|
||
// @Summary 修改密码
|
||
// @Description 修改当前用户的登录密码
|
||
// @Tags 认证
|
||
// @Accept json
|
||
// @Produce json
|
||
// @Param request body object{old_password=string,new_password=string} true "密码修改请求"
|
||
// @Success 200 {object} map[string]interface{} "修改成功"
|
||
// @Failure 400 {object} errors.AppError "请求参数错误"
|
||
// @Security Bearer
|
||
// @Router /auth/change-password [post]
|
||
func (h *AuthHandler) ChangePassword(c *gin.Context) {
|
||
ctx := c.Request.Context()
|
||
|
||
logger.Info(ctx, "Start password change")
|
||
|
||
var req struct {
|
||
OldPassword string `json:"old_password" binding:"required"`
|
||
NewPassword string `json:"new_password" binding:"required,min=6"`
|
||
}
|
||
|
||
if err := c.ShouldBindJSON(&req); err != nil {
|
||
logger.Error(ctx, "Failed to parse password change request", err)
|
||
appErr := errors.NewValidationError("Invalid password change request").WithDetails(err.Error())
|
||
c.Error(appErr)
|
||
return
|
||
}
|
||
|
||
// Get current user
|
||
user, err := h.userService.GetCurrentUser(ctx)
|
||
if err != nil {
|
||
logger.Errorf(ctx, "Failed to get current user: %v", err)
|
||
appErr := errors.NewUnauthorizedError("Failed to get user information").WithDetails(err.Error())
|
||
c.Error(appErr)
|
||
return
|
||
}
|
||
|
||
// Change password
|
||
err = h.userService.ChangePassword(ctx, user.ID, req.OldPassword, req.NewPassword)
|
||
if err != nil {
|
||
logger.Errorf(ctx, "Failed to change password: %v", err)
|
||
appErr := errors.NewBadRequestError("Password change failed").WithDetails(err.Error())
|
||
c.Error(appErr)
|
||
return
|
||
}
|
||
|
||
logger.Infof(ctx, "Password changed successfully for user: %s", user.Email)
|
||
c.JSON(http.StatusOK, gin.H{
|
||
"success": true,
|
||
"message": "Password changed successfully",
|
||
})
|
||
}
|
||
|
||
// GetAuthConfig godoc
|
||
// @Summary 获取认证配置
|
||
// @Description 返回当前部署的注册模式等公开认证配置,供前端决定是否展示注册入口
|
||
// @Tags 认证
|
||
// @Accept json
|
||
// @Produce json
|
||
// @Success 200 {object} map[string]interface{} "认证配置"
|
||
// @Router /auth/config [get]
|
||
//
|
||
// GetAuthConfig is intentionally a no-auth endpoint: the frontend reads
|
||
// it on app load to decide whether to show the Register tab. We expose
|
||
// only what the UI strictly needs (registration_mode); other config
|
||
// stays internal.
|
||
func (h *AuthHandler) GetAuthConfig(c *gin.Context) {
|
||
// Same source-of-truth as Register's gate, so the UI hide-the-button
|
||
// signal can never disagree with the API enforcement signal.
|
||
mode := h.resolveRegistrationMode(c.Request.Context())
|
||
c.JSON(http.StatusOK, gin.H{
|
||
"success": true,
|
||
"registration_mode": mode,
|
||
})
|
||
}
|
||
|
||
// SwitchTenant godoc
|
||
// @Summary 切换激活租户
|
||
// @Description 为当前用户在目标租户重新签发访问令牌;要求该用户在目标租户存在 active 成员关系
|
||
// @Tags 认证
|
||
// @Accept json
|
||
// @Produce json
|
||
// @Param request body object{tenant_id=integer,refresh_token=string} true "切换请求"
|
||
// @Success 200 {object} types.LoginResponse
|
||
// @Failure 400 {object} errors.AppError "参数错误"
|
||
// @Failure 403 {object} errors.AppError "无该租户成员关系"
|
||
// @Security Bearer
|
||
// @Router /auth/switch-tenant [post]
|
||
//
|
||
// SwitchTenant is the v1 backend hook for the tenant-switcher UI added
|
||
// in PR 3. The current PR ships the endpoint so multi-tenant tests can
|
||
// exercise the membership flow end-to-end before the frontend lands.
|
||
func (h *AuthHandler) SwitchTenant(c *gin.Context) {
|
||
ctx := c.Request.Context()
|
||
|
||
var req struct {
|
||
TenantID uint64 `json:"tenant_id" binding:"required"`
|
||
RefreshToken string `json:"refresh_token"`
|
||
}
|
||
if err := c.ShouldBindJSON(&req); err != nil {
|
||
appErr := errors.NewValidationError("Invalid switch-tenant request").WithDetails(err.Error())
|
||
c.Error(appErr)
|
||
return
|
||
}
|
||
|
||
user, err := h.userService.GetCurrentUser(ctx)
|
||
if err != nil || user == nil {
|
||
appErr := errors.NewUnauthorizedError("not authenticated")
|
||
c.Error(appErr)
|
||
return
|
||
}
|
||
|
||
resp, err := h.userService.SwitchTenant(ctx, user, req.TenantID, req.RefreshToken)
|
||
if err != nil {
|
||
logger.Errorf(ctx, "SwitchTenant failed user=%s target=%d: %v", user.ID, req.TenantID, err)
|
||
appErr := errors.NewForbiddenError("switch tenant failed").WithDetails(err.Error())
|
||
c.Error(appErr)
|
||
return
|
||
}
|
||
|
||
c.JSON(http.StatusOK, resp)
|
||
}
|
||
|
||
// AutoSetup godoc
|
||
// @Summary 自动初始化(Lite 桌面版)
|
||
// @Description Lite 版专用:首次启动时自动创建默认用户和租户并返回令牌,后续启动直接签发令牌,免除手动注册/登录流程
|
||
// @Tags 认证
|
||
// @Accept json
|
||
// @Produce json
|
||
// @Success 200 {object} types.LoginResponse
|
||
// @Failure 403 {object} errors.AppError "非 Lite 版本"
|
||
// @Router /auth/auto-setup [post]
|
||
func (h *AuthHandler) AutoSetup(c *gin.Context) {
|
||
ctx := c.Request.Context()
|
||
|
||
if Edition != "lite" {
|
||
appErr := errors.NewForbiddenError("auto-setup is only available in lite edition")
|
||
c.Error(appErr)
|
||
return
|
||
}
|
||
|
||
const defaultEmail = "admin@weknora.local"
|
||
|
||
user, _ := h.userService.GetUserByEmail(ctx, defaultEmail)
|
||
if user == nil {
|
||
logger.Info(ctx, "Auto-setup: creating default user and tenant for lite edition")
|
||
|
||
randomBytes := make([]byte, 24)
|
||
if _, err := rand.Read(randomBytes); err != nil {
|
||
appErr := errors.NewInternalServerError("auto-setup failed: unable to generate credentials")
|
||
c.Error(appErr)
|
||
return
|
||
}
|
||
randomPassword := base64.RawURLEncoding.EncodeToString(randomBytes)
|
||
randomUsername := fmt.Sprintf("user_%s", base64.RawURLEncoding.EncodeToString(randomBytes[:6]))
|
||
|
||
_, err := h.userService.Register(ctx, &types.RegisterRequest{
|
||
Username: randomUsername,
|
||
Email: defaultEmail,
|
||
Password: randomPassword,
|
||
})
|
||
if err != nil {
|
||
logger.Errorf(ctx, "Auto-setup: failed to register default user: %v", err)
|
||
appErr := errors.NewInternalServerError("auto-setup failed").WithDetails(err.Error())
|
||
c.Error(appErr)
|
||
return
|
||
}
|
||
user, _ = h.userService.GetUserByEmail(ctx, defaultEmail)
|
||
if user == nil {
|
||
appErr := errors.NewInternalServerError("auto-setup failed: user not found after registration")
|
||
c.Error(appErr)
|
||
return
|
||
}
|
||
}
|
||
|
||
accessToken, refreshToken, err := h.userService.GenerateTokens(ctx, user)
|
||
if err != nil {
|
||
logger.Errorf(ctx, "Auto-setup: failed to generate tokens: %v", err)
|
||
appErr := errors.NewInternalServerError("auto-setup failed").WithDetails(err.Error())
|
||
c.Error(appErr)
|
||
return
|
||
}
|
||
|
||
tenant, _ := h.tenantService.GetTenantByID(ctx, user.TenantID)
|
||
|
||
logger.Info(ctx, "Auto-setup: completed successfully")
|
||
c.JSON(http.StatusOK, &types.LoginResponse{
|
||
Success: true,
|
||
Message: "Auto-setup successful",
|
||
User: user,
|
||
ActiveTenant: tenant,
|
||
Memberships: []types.Membership{{
|
||
TenantID: user.TenantID,
|
||
TenantName: tenantNameOrEmpty(tenant),
|
||
Role: types.TenantRoleOwner,
|
||
}},
|
||
Token: accessToken,
|
||
RefreshToken: refreshToken,
|
||
})
|
||
}
|
||
|
||
// tenantNameOrEmpty returns t.Name when t is non-nil, "" otherwise.
|
||
// Used by AutoSetup to populate Membership.TenantName without crashing
|
||
// if the tenant lookup failed.
|
||
func tenantNameOrEmpty(t *types.Tenant) string {
|
||
if t == nil {
|
||
return ""
|
||
}
|
||
return t.Name
|
||
}
|
||
|
||
// ValidateToken godoc
|
||
// @Summary 验证令牌
|
||
// @Description 验证访问令牌是否有效
|
||
// @Tags 认证
|
||
// @Accept json
|
||
// @Produce json
|
||
// @Success 200 {object} map[string]interface{} "令牌有效"
|
||
// @Failure 401 {object} errors.AppError "令牌无效"
|
||
// @Security Bearer
|
||
// @Router /auth/validate [get]
|
||
func (h *AuthHandler) ValidateToken(c *gin.Context) {
|
||
ctx := c.Request.Context()
|
||
|
||
logger.Info(ctx, "Start token validation")
|
||
|
||
// Extract token from Authorization header
|
||
authHeader := c.GetHeader("Authorization")
|
||
if authHeader == "" {
|
||
logger.Error(ctx, "Missing Authorization header")
|
||
appErr := errors.NewValidationError("Authorization header is required")
|
||
c.Error(appErr)
|
||
return
|
||
}
|
||
|
||
// Parse Bearer token
|
||
tokenParts := strings.Split(authHeader, " ")
|
||
if len(tokenParts) != 2 || tokenParts[0] != "Bearer" {
|
||
logger.Error(ctx, "Invalid Authorization header format")
|
||
appErr := errors.NewValidationError("Invalid Authorization header format")
|
||
c.Error(appErr)
|
||
return
|
||
}
|
||
|
||
token := tokenParts[1]
|
||
|
||
// Validate token
|
||
user, _, err := h.userService.ValidateToken(ctx, token)
|
||
if err != nil {
|
||
logger.Errorf(ctx, "Failed to validate token: %v", err)
|
||
appErr := errors.NewUnauthorizedError("Token validation failed").WithDetails(err.Error())
|
||
c.Error(appErr)
|
||
return
|
||
}
|
||
|
||
logger.Infof(ctx, "Token validated successfully for user: %s", user.Email)
|
||
c.JSON(http.StatusOK, gin.H{
|
||
"success": true,
|
||
"message": "Token is valid",
|
||
"user": user.ToUserInfo(),
|
||
})
|
||
}
|