Files
WeKnora/internal/handler/auth.go
lyingbug 3611a8607d feat(rbac): add multi-use share-link invitations for invite_only mode (#1504)
* feat(rbac): add multi-use share-link invitations for invite_only mode

When the auth.registration_mode toggle flips to invite_only, the public
/auth/register endpoint returns 403 — but until now there was no
channel for unregistered users to actually join a tenant. The existing
invitation table required a registered invitee, so new emails could
not enter the system at all.

This change introduces a share-link model:

- Owner generates a multi-use registration link with a per-tenant role
- Recipient opens the link, registers with their own email, and is
  added to the tenant on the same request
- Token is stored in plaintext (not hashed) so the management UI can
  re-display the URL on demand without a "copy now or revoke" trap.
  Threat model is bounded by 7-day TTL, revocability, and the fact
  that all the link grants is membership in one tenant.
- accepted_count tracks how many users joined through the link so the
  Owner can tell whether a link is fresh or has already been spread

Per-user invitations (registered email -> in-app inbox accept) are
unchanged. Share-link rows use empty invitee_user_id as discriminator;
the partial unique index on (tenant_id, invitee_user_id) was relaxed
to skip empty values so multiple share links can coexist per tenant.

Frontend reuses Login.vue for /register?token=xxx instead of a
parallel page, and adds icon-only revoke / remove actions plus a
distinct "active" status tag for share-link rows.

Migration 000054 adds the token column, the accepted_count column,
and the relaxed pending unique index in a single step.

* feat(auth): update invitation token handling and add rate limiting

- Refactor the invitation token retrieval to use POST instead of GET, enhancing security by preventing the plaintext token from appearing in logs and browser history.
- Update the API endpoint from `/auth/invitations/:token` to `/auth/invitations/lookup` to reflect the new method.
- Introduce a rate limiter for unauthenticated share-link endpoints to mitigate brute-force attacks and abuse, ensuring a maximum of 30 requests per minute per IP.
- Adjust the registration flow to maintain user experience while securing the invitation process.
2026-05-27 19:42:01 +08:00

855 lines
30 KiB
Go
Raw Blame History

This file contains ambiguous Unicode characters
This file contains Unicode characters that might be confused with other characters. If you think that this is intentional, you can safely ignore this warning. Use the Escape button to reveal them.
package handler
import (
"context"
"crypto/rand"
"encoding/base64"
"encoding/json"
"fmt"
"net/http"
"strings"
"github.com/gin-gonic/gin"
"github.com/Tencent/WeKnora/internal/config"
"github.com/Tencent/WeKnora/internal/errors"
"github.com/Tencent/WeKnora/internal/logger"
"github.com/Tencent/WeKnora/internal/types"
"github.com/Tencent/WeKnora/internal/types/interfaces"
secutils "github.com/Tencent/WeKnora/internal/utils"
)
// AuthHandler implements HTTP request handlers for user authentication
// Provides functionality for user registration, login, logout, and token management
// through the REST API endpoints
type AuthHandler struct {
userService interfaces.UserService
tenantService interfaces.TenantService
configInfo *config.Config
systemSettingSvc interfaces.SystemSettingService
// invitationSvc is required for the share-link registration path
// (POST /auth/register-by-invite). When nil — e.g. legacy test
// fixtures — the share-link endpoints respond 503 rather than
// blocking the rest of the auth surface.
invitationSvc interfaces.TenantInvitationService
}
// NewAuthHandler creates a new auth handler instance with the provided services
// Parameters:
// - userService: An implementation of the UserService interface for business logic
// - tenantService: An implementation of the TenantService interface for tenant management
// - systemSettingSvc: 3-tier resolver for runtime-tunable settings such as
// auth.registration_mode (P3). When DB has a row, it overrides cfg's
// startup value; otherwise we fall back to cfg.Auth.RegistrationMode
// (which already accounted for the legacy DISABLE_REGISTRATION env coerce
// during config load). Mismatch impossible by construction since the
// handler always passes cfg's value as the def parameter to GetString.
//
// Returns a pointer to the newly created AuthHandler
func NewAuthHandler(configInfo *config.Config,
userService interfaces.UserService, tenantService interfaces.TenantService,
systemSettingSvc interfaces.SystemSettingService,
invitationSvc interfaces.TenantInvitationService,
) *AuthHandler {
// Boot-time guard: a nil-or-empty Auth section silently disables the
// invite_only gate (see Register below). Emit a loud one-shot log
// pointing at the misconfiguration so operators notice on startup
// instead of discovering it the day someone hits /auth/register.
if configInfo == nil || configInfo.Auth == nil {
logger.Errorf(context.Background(),
"[auth] AuthHandler constructed with nil/incomplete config (cfg=%v); "+
"registration_mode enforcement is disabled. This is almost certainly a wiring bug.",
configInfo)
}
return &AuthHandler{
configInfo: configInfo,
userService: userService,
tenantService: tenantService,
systemSettingSvc: systemSettingSvc,
invitationSvc: invitationSvc,
}
}
// resolveRegistrationMode returns the currently active registration mode.
// Priority: DB system_settings > cfg (which already absorbed the legacy
// DISABLE_REGISTRATION env coerce at startup) > "self_serve" hard default.
//
// Centralised here so /auth/register and /auth/config stay in lock-step —
// otherwise a SystemAdmin's UI edit could affect one path and not the other.
func (h *AuthHandler) resolveRegistrationMode(ctx context.Context) string {
// cfg-derived default: empty is impossible after applyAuthAndTenantDefaults,
// but be defensive in case AuthHandler was constructed before that ran
// (the NewAuthHandler guard already logged in that case).
def := config.AuthRegistrationModeSelfServe
if h.configInfo != nil && h.configInfo.Auth != nil {
if m := strings.TrimSpace(h.configInfo.Auth.RegistrationMode); m != "" {
def = m
}
}
if h.systemSettingSvc == nil {
return def
}
// envName = "" because DISABLE_REGISTRATION is a boolean and
// auth.registration_mode is a string — the legacy env was already
// coerced into `def` above. Mixing the two semantics at the resolver
// layer would mean a UI delete (DB row absent) silently flipped to
// the legacy boolean read again, which is surprising.
return h.systemSettingSvc.GetString(ctx, "auth.registration_mode", "", def)
}
// Register godoc
// @Summary 用户注册
// @Description 注册新用户账号
// @Tags 认证
// @Accept json
// @Produce json
// @Param request body types.RegisterRequest true "注册请求参数"
// @Success 201 {object} types.RegisterResponse
// @Failure 400 {object} errors.AppError "请求参数错误"
// @Failure 403 {object} errors.AppError "注册功能已禁用"
// @Router /auth/register [post]
func (h *AuthHandler) Register(c *gin.Context) {
ctx := c.Request.Context()
logger.Info(ctx, "Start user registration")
// 当 auth.registration_mode=invite_only 时public 注册被关闭。
// 优先级DB system_settings > cfg.Auth.RegistrationMode > "self_serve"。
// SystemAdmin 通过「全局设置」UI 实时切换 self_serve / invite_only立即
// 生效,不需要重启服务。历史变量 DISABLE_REGISTRATION=true 仍在 config
// 启动阶段被等价提升为 invite_onlyapplyAuthAndTenantDefaults
// 作为 cfg-default 进入 resolveRegistrationMode。
if h.resolveRegistrationMode(ctx) == config.AuthRegistrationModeInviteOnly {
logger.Warn(ctx, "Registration rejected: auth.registration_mode=invite_only")
appErr := errors.NewForbiddenError("Registration is invite-only")
c.Error(appErr)
return
}
var req types.RegisterRequest
if err := c.ShouldBindJSON(&req); err != nil {
logger.Error(ctx, "Failed to parse registration request parameters", err)
appErr := errors.NewValidationError("Invalid registration parameters").WithDetails(err.Error())
c.Error(appErr)
return
}
req.Username = secutils.SanitizeForLog(req.Username)
req.Email = secutils.SanitizeForLog(req.Email)
req.Password = secutils.SanitizeForLog(req.Password)
// Validate required fields
if req.Username == "" || req.Email == "" || req.Password == "" {
logger.Error(ctx, "Missing required registration fields")
appErr := errors.NewValidationError("Username, email and password are required")
c.Error(appErr)
return
}
req.Username = secutils.SanitizeForLog(req.Username)
req.Email = secutils.SanitizeForLog(req.Email)
// Call service to register user
user, err := h.userService.Register(ctx, &req)
if err != nil {
logger.Errorf(ctx, "Failed to register user: %v", err)
appErr := errors.NewBadRequestError(err.Error())
c.Error(appErr)
return
}
// Return success response
response := &types.RegisterResponse{
Success: true,
Message: "Registration successful",
User: user,
}
logger.Infof(ctx, "User registered successfully: %s", secutils.SanitizeForLog(user.Email))
c.JSON(http.StatusCreated, response)
}
// Login godoc
// @Summary 用户登录
// @Description 用户登录并获取访问令牌
// @Tags 认证
// @Accept json
// @Produce json
// @Param request body types.LoginRequest true "登录请求参数"
// @Success 200 {object} types.LoginResponse
// @Failure 401 {object} errors.AppError "认证失败"
// @Router /auth/login [post]
func (h *AuthHandler) Login(c *gin.Context) {
ctx := c.Request.Context()
logger.Info(ctx, "Start user login")
var req types.LoginRequest
if err := c.ShouldBindJSON(&req); err != nil {
logger.Error(ctx, "Failed to parse login request parameters", err)
appErr := errors.NewValidationError("Invalid login parameters").WithDetails(err.Error())
c.Error(appErr)
return
}
email := secutils.SanitizeForLog(req.Email)
// Validate required fields
if req.Email == "" || req.Password == "" {
logger.Error(ctx, "Missing required login fields")
appErr := errors.NewValidationError("Email and password are required")
c.Error(appErr)
return
}
// Call service to authenticate user
response, err := h.userService.Login(ctx, &req)
if err != nil {
logger.Errorf(ctx, "Failed to login user: %v", err)
appErr := errors.NewUnauthorizedError("Login failed").WithDetails(err.Error())
c.Error(appErr)
return
}
// Check if login was successful
if !response.Success {
logger.Warnf(ctx, "Login failed: %s", response.Message)
c.JSON(http.StatusUnauthorized, response)
return
}
// User is already in the correct format from service
logger.Infof(ctx, "User logged in successfully, email: %s", email)
c.JSON(http.StatusOK, response)
}
// GetOIDCAuthorizationURL godoc
// @Summary 获取OIDC授权地址
// @Description 根据后端OIDC配置生成第三方登录跳转地址
// @Tags 认证
// @Accept json
// @Produce json
// @Param redirect_uri query string true "OIDC回调地址"
// @Success 200 {object} types.OIDCAuthURLResponse
// @Failure 400 {object} errors.AppError "请求参数错误"
// @Failure 403 {object} errors.AppError "OIDC未启用"
// @Router /auth/oidc/url [get]
func (h *AuthHandler) GetOIDCAuthorizationURL(c *gin.Context) {
ctx := c.Request.Context()
redirectURI := strings.TrimSpace(c.Query("redirect_uri"))
if redirectURI == "" {
appErr := errors.NewValidationError("redirect_uri is required")
c.Error(appErr)
return
}
resp, err := h.userService.GetOIDCAuthorizationURL(ctx, redirectURI)
if err != nil {
logger.Errorf(ctx, "Failed to generate OIDC authorization URL: %v", err)
appErr := errors.NewForbiddenError("OIDC authorization unavailable").WithDetails(err.Error())
c.Error(appErr)
return
}
c.JSON(http.StatusOK, resp)
}
// GetOIDCConfig godoc
// @Summary 获取OIDC登录配置
// @Description 返回OIDC是否启用以及provider展示名称供前端决定是否展示OIDC登录入口
// @Tags 认证
// @Accept json
// @Produce json
// @Success 200 {object} types.OIDCConfigResponse
// @Router /auth/oidc/config [get]
func (h *AuthHandler) GetOIDCConfig(c *gin.Context) {
providerDisplayName := ""
enabled := false
if h.configInfo != nil && h.configInfo.OIDCAuth != nil {
enabled = h.configInfo.OIDCAuth.Enable
providerDisplayName = strings.TrimSpace(h.configInfo.OIDCAuth.ProviderDisplayName)
}
c.JSON(http.StatusOK, &types.OIDCConfigResponse{
Success: true,
Enabled: enabled,
ProviderDisplayName: providerDisplayName,
})
}
// OIDCRedirectCallback godoc
// @Summary OIDC登录重定向回调
// @Description 接收OIDC provider回调并由后端完成code交换随后重定向回前端登录页
// @Tags 认证
// @Accept json
// @Produce json
// @Param code query string false "OIDC授权码"
// @Param state query string false "OIDC状态"
// @Param error query string false "OIDC错误码"
// @Success 302
// @Router /auth/oidc/callback [get]
func (h *AuthHandler) OIDCRedirectCallback(c *gin.Context) {
ctx := c.Request.Context()
frontendRedirectURI := "/"
if providerError := strings.TrimSpace(c.Query("error")); providerError != "" {
redirectURL := frontendRedirectURI + "#oidc_error=" + urlQueryEscape(providerError)
if description := strings.TrimSpace(c.Query("error_description")); description != "" {
redirectURL += "&oidc_error_description=" + urlQueryEscape(description)
}
c.Redirect(http.StatusFound, redirectURL)
return
}
state := strings.TrimSpace(c.Query("state"))
decodedState, err := decodeOIDCState(state)
if err != nil {
logger.Errorf(ctx, "Failed to decode OIDC state: %v", err)
c.Redirect(http.StatusFound, frontendRedirectURI+"#oidc_error="+urlQueryEscape("invalid_state"))
return
}
code := strings.TrimSpace(c.Query("code"))
if code == "" {
c.Redirect(http.StatusFound, frontendRedirectURI+"#oidc_error="+urlQueryEscape("missing_code"))
return
}
resp, err := h.userService.LoginWithOIDC(ctx, code, strings.TrimSpace(decodedState.RedirectURI))
if err != nil {
logger.Errorf(ctx, "Failed to complete OIDC login via redirect callback: %v", err)
c.Redirect(http.StatusFound, frontendRedirectURI+"#oidc_error="+urlQueryEscape("login_failed")+"&oidc_error_description="+urlQueryEscape(err.Error()))
return
}
if !resp.Success {
c.Redirect(http.StatusFound, frontendRedirectURI+"#oidc_error="+urlQueryEscape("login_failed")+"&oidc_error_description="+urlQueryEscape(resp.Message))
return
}
payload, err := encodeOIDCCallbackPayload(resp)
if err != nil {
logger.Errorf(ctx, "Failed to encode OIDC callback payload: %v", err)
c.Redirect(http.StatusFound, frontendRedirectURI+"#oidc_error="+urlQueryEscape("payload_encode_failed"))
return
}
c.Redirect(http.StatusFound, frontendRedirectURI+"#oidc_result="+urlQueryEscape(payload))
}
func encodeOIDCCallbackPayload(resp *types.OIDCCallbackResponse) (string, error) {
payload, err := json.Marshal(resp)
if err != nil {
return "", err
}
return base64.RawURLEncoding.EncodeToString(payload), nil
}
type oidcStatePayload struct {
Nonce string `json:"nonce"`
RedirectURI string `json:"redirect_uri,omitempty"`
}
func decodeOIDCState(raw string) (*oidcStatePayload, error) {
decoded, err := base64.RawURLEncoding.DecodeString(strings.TrimSpace(raw))
if err != nil {
return nil, err
}
var payload oidcStatePayload
if err := json.Unmarshal(decoded, &payload); err != nil {
return nil, err
}
if strings.TrimSpace(payload.RedirectURI) == "" {
return nil, errors.NewValidationError("state.redirect_uri is required")
}
return &payload, nil
}
func urlQueryEscape(value string) string {
replacer := strings.NewReplacer(
"%", "%25",
" ", "%20",
"#", "%23",
"&", "%26",
"+", "%2B",
"=", "%3D",
"?", "%3F",
)
return replacer.Replace(value)
}
// Logout godoc
// @Summary 用户登出
// @Description 撤销当前访问令牌并登出
// @Tags 认证
// @Accept json
// @Produce json
// @Success 200 {object} map[string]interface{} "登出成功"
// @Failure 400 {object} errors.AppError "请求参数错误"
// @Security Bearer
// @Router /auth/logout [post]
func (h *AuthHandler) Logout(c *gin.Context) {
ctx := c.Request.Context()
logger.Info(ctx, "Start user logout")
// Extract token from Authorization header
authHeader := c.GetHeader("Authorization")
if authHeader == "" {
logger.Error(ctx, "Missing Authorization header")
appErr := errors.NewValidationError("Authorization header is required")
c.Error(appErr)
return
}
// Parse Bearer token
tokenParts := strings.Split(authHeader, " ")
if len(tokenParts) != 2 || tokenParts[0] != "Bearer" {
logger.Error(ctx, "Invalid Authorization header format")
appErr := errors.NewValidationError("Invalid Authorization header format")
c.Error(appErr)
return
}
token := tokenParts[1]
// Revoke token
err := h.userService.RevokeToken(ctx, token)
if err != nil {
logger.Errorf(ctx, "Failed to revoke token: %v", err)
appErr := errors.NewInternalServerError("Logout failed").WithDetails(err.Error())
c.Error(appErr)
return
}
logger.Info(ctx, "User logged out successfully")
c.JSON(http.StatusOK, gin.H{
"success": true,
"message": "Logout successful",
})
}
// RefreshToken godoc
// @Summary 刷新令牌
// @Description 使用刷新令牌获取新的访问令牌
// @Tags 认证
// @Accept json
// @Produce json
// @Param request body object{refreshToken=string} true "刷新令牌"
// @Success 200 {object} map[string]interface{} "新令牌"
// @Failure 401 {object} errors.AppError "令牌无效"
// @Router /auth/refresh [post]
func (h *AuthHandler) RefreshToken(c *gin.Context) {
ctx := c.Request.Context()
logger.Info(ctx, "Start token refresh")
var req struct {
RefreshToken string `json:"refreshToken" binding:"required"`
}
if err := c.ShouldBindJSON(&req); err != nil {
logger.Error(ctx, "Failed to parse refresh token request", err)
appErr := errors.NewValidationError("Invalid refresh token request").WithDetails(err.Error())
c.Error(appErr)
return
}
// Call service to refresh token
accessToken, newRefreshToken, err := h.userService.RefreshToken(ctx, req.RefreshToken)
if err != nil {
logger.Errorf(ctx, "Failed to refresh token: %v", err)
appErr := errors.NewUnauthorizedError("Token refresh failed").WithDetails(err.Error())
c.Error(appErr)
return
}
logger.Info(ctx, "Token refreshed successfully")
c.JSON(http.StatusOK, gin.H{
"success": true,
"message": "Token refreshed successfully",
"access_token": accessToken,
"refresh_token": newRefreshToken,
})
}
// GetCurrentUser godoc
// @Summary 获取当前用户信息
// @Description 获取当前登录用户的详细信息
// @Tags 认证
// @Accept json
// @Produce json
// @Success 200 {object} map[string]interface{} "用户信息"
// @Failure 401 {object} errors.AppError "未授权"
// @Security Bearer
// @Router /auth/me [get]
func (h *AuthHandler) GetCurrentUser(c *gin.Context) {
ctx := c.Request.Context()
// Get current user from service (which extracts from context)
user, err := h.userService.GetCurrentUser(ctx)
if err != nil {
logger.Errorf(ctx, "Failed to get current user: %v", err)
appErr := errors.NewUnauthorizedError("Failed to get user information").WithDetails(err.Error())
c.Error(appErr)
return
}
// Get tenant information for the *active* tenant (the one the
// auth middleware resolved against the X-Tenant-ID header), not
// the user's home tenant. user.TenantID is the row stored on the
// users table at signup time and never changes; reading it here
// would make /auth/me always return the home tenant even after
// the user switched into a peer tenant. The frontend then re-keys
// `authStore.tenant.id` to the home tenant, and every UI gate
// computed against it (currentTenantRole, isOwner, ...) leaks
// the wrong role. Pull the active tenant id from context instead.
var tenant *types.Tenant
activeTenantID, _ := types.TenantIDFromContext(ctx)
if activeTenantID == 0 {
activeTenantID = user.TenantID
}
if activeTenantID > 0 {
tenant, err = h.tenantService.GetTenantByID(ctx, activeTenantID)
if err != nil {
logger.Warnf(ctx, "Failed to get tenant info for user %s, tenant ID %d: %v", user.Email, activeTenantID, err)
// Don't fail the request if tenant info is not available
}
}
userInfo := user.ToUserInfo()
userInfo.CanAccessAllTenants = user.CanAccessAllTenants && h.configInfo.Tenant.EnableCrossTenantAccess
// 同步返回当前用户的 memberships让前端在页面刷新仅命中 /auth/me
// 后也能恢复 currentTenantRole避免角色信息只在 login 那一刻可用。
memberships := h.userService.BuildLoginMemberships(ctx, user, tenant)
c.JSON(http.StatusOK, gin.H{
"success": true,
"data": gin.H{
"user": userInfo,
"tenant": tenant,
"memberships": memberships,
},
})
}
// updateMyPreferencesRequest is the body for PUT /auth/me/preferences.
// Fields are pointers so the handler can distinguish "key not present"
// (preserve existing value) from "explicit false". See
// types.UserPreferences for the persistence-layer counterpart.
type updateMyPreferencesRequest struct {
EnableMemory *bool `json:"enable_memory"`
// LastActiveTenantID lets the SPA persist "after a fresh login,
// drop me back into this workspace" across devices. Send a positive
// tenant id to set / replace, or 0 to clear. Membership is validated
// at next login, not here. Nil = field omitted from the PATCH and
// stays untouched.
LastActiveTenantID *uint64 `json:"last_active_tenant_id"`
}
// UpdateMyPreferences godoc
// @Summary 更新当前用户的个性化设置
// @Description 按 PATCH 语义合并用户偏好(仅覆盖请求体里出现的字段,其余字段保持不变),
// @Description 数据存放在 users.preferences (JSON),跨设备/浏览器自动同步。
// @Tags 认证
// @Accept json
// @Produce json
// @Param request body updateMyPreferencesRequest true "Preferences patch"
// @Success 200 {object} map[string]interface{} "更新后的偏好"
// @Failure 400 {object} errors.AppError "请求参数错误"
// @Failure 401 {object} errors.AppError "未授权"
// @Security Bearer
// @Router /auth/me/preferences [put]
func (h *AuthHandler) UpdateMyPreferences(c *gin.Context) {
ctx := c.Request.Context()
user, err := h.userService.GetCurrentUser(ctx)
if err != nil {
appErr := errors.NewUnauthorizedError("Failed to get user information").WithDetails(err.Error())
c.Error(appErr)
return
}
var req updateMyPreferencesRequest
if err := c.ShouldBindJSON(&req); err != nil {
appErr := errors.NewValidationError("Invalid preferences request").WithDetails(err.Error())
c.Error(appErr)
return
}
patch := types.UserPreferences{
EnableMemory: req.EnableMemory,
LastActiveTenantID: req.LastActiveTenantID,
}
prefs, err := h.userService.UpdateUserPreferences(ctx, user.ID, patch)
if err != nil {
logger.Errorf(ctx, "Failed to update preferences for user %s: %v", user.Email, err)
appErr := errors.NewBadRequestError("Failed to update preferences").WithDetails(err.Error())
c.Error(appErr)
return
}
c.JSON(http.StatusOK, gin.H{
"success": true,
"data": prefs,
})
}
// ChangePassword godoc
// @Summary 修改密码
// @Description 修改当前用户的登录密码
// @Tags 认证
// @Accept json
// @Produce json
// @Param request body object{old_password=string,new_password=string} true "密码修改请求"
// @Success 200 {object} map[string]interface{} "修改成功"
// @Failure 400 {object} errors.AppError "请求参数错误"
// @Security Bearer
// @Router /auth/change-password [post]
func (h *AuthHandler) ChangePassword(c *gin.Context) {
ctx := c.Request.Context()
logger.Info(ctx, "Start password change")
var req struct {
OldPassword string `json:"old_password" binding:"required"`
NewPassword string `json:"new_password" binding:"required,min=6"`
}
if err := c.ShouldBindJSON(&req); err != nil {
logger.Error(ctx, "Failed to parse password change request", err)
appErr := errors.NewValidationError("Invalid password change request").WithDetails(err.Error())
c.Error(appErr)
return
}
// Get current user
user, err := h.userService.GetCurrentUser(ctx)
if err != nil {
logger.Errorf(ctx, "Failed to get current user: %v", err)
appErr := errors.NewUnauthorizedError("Failed to get user information").WithDetails(err.Error())
c.Error(appErr)
return
}
// Change password
err = h.userService.ChangePassword(ctx, user.ID, req.OldPassword, req.NewPassword)
if err != nil {
logger.Errorf(ctx, "Failed to change password: %v", err)
appErr := errors.NewBadRequestError("Password change failed").WithDetails(err.Error())
c.Error(appErr)
return
}
logger.Infof(ctx, "Password changed successfully for user: %s", user.Email)
c.JSON(http.StatusOK, gin.H{
"success": true,
"message": "Password changed successfully",
})
}
// GetAuthConfig godoc
// @Summary 获取认证配置
// @Description 返回当前部署的注册模式等公开认证配置,供前端决定是否展示注册入口
// @Tags 认证
// @Accept json
// @Produce json
// @Success 200 {object} map[string]interface{} "认证配置"
// @Router /auth/config [get]
//
// GetAuthConfig is intentionally a no-auth endpoint: the frontend reads
// it on app load to decide whether to show the Register tab. We expose
// only what the UI strictly needs (registration_mode); other config
// stays internal.
func (h *AuthHandler) GetAuthConfig(c *gin.Context) {
// Same source-of-truth as Register's gate, so the UI hide-the-button
// signal can never disagree with the API enforcement signal.
mode := h.resolveRegistrationMode(c.Request.Context())
c.JSON(http.StatusOK, gin.H{
"success": true,
"registration_mode": mode,
})
}
// SwitchTenant godoc
// @Summary 切换激活租户
// @Description 为当前用户在目标租户重新签发访问令牌;要求该用户在目标租户存在 active 成员关系
// @Tags 认证
// @Accept json
// @Produce json
// @Param request body object{tenant_id=integer,refresh_token=string} true "切换请求"
// @Success 200 {object} types.LoginResponse
// @Failure 400 {object} errors.AppError "参数错误"
// @Failure 403 {object} errors.AppError "无该租户成员关系"
// @Security Bearer
// @Router /auth/switch-tenant [post]
//
// SwitchTenant is the v1 backend hook for the tenant-switcher UI added
// in PR 3. The current PR ships the endpoint so multi-tenant tests can
// exercise the membership flow end-to-end before the frontend lands.
func (h *AuthHandler) SwitchTenant(c *gin.Context) {
ctx := c.Request.Context()
var req struct {
TenantID uint64 `json:"tenant_id" binding:"required"`
RefreshToken string `json:"refresh_token"`
}
if err := c.ShouldBindJSON(&req); err != nil {
appErr := errors.NewValidationError("Invalid switch-tenant request").WithDetails(err.Error())
c.Error(appErr)
return
}
user, err := h.userService.GetCurrentUser(ctx)
if err != nil || user == nil {
appErr := errors.NewUnauthorizedError("not authenticated")
c.Error(appErr)
return
}
resp, err := h.userService.SwitchTenant(ctx, user, req.TenantID, req.RefreshToken)
if err != nil {
logger.Errorf(ctx, "SwitchTenant failed user=%s target=%d: %v", user.ID, req.TenantID, err)
appErr := errors.NewForbiddenError("switch tenant failed").WithDetails(err.Error())
c.Error(appErr)
return
}
c.JSON(http.StatusOK, resp)
}
// AutoSetup godoc
// @Summary 自动初始化Lite 桌面版)
// @Description Lite 版专用:首次启动时自动创建默认用户和租户并返回令牌,后续启动直接签发令牌,免除手动注册/登录流程
// @Tags 认证
// @Accept json
// @Produce json
// @Success 200 {object} types.LoginResponse
// @Failure 403 {object} errors.AppError "非 Lite 版本"
// @Router /auth/auto-setup [post]
func (h *AuthHandler) AutoSetup(c *gin.Context) {
ctx := c.Request.Context()
if Edition != "lite" {
appErr := errors.NewForbiddenError("auto-setup is only available in lite edition")
c.Error(appErr)
return
}
const defaultEmail = "admin@weknora.local"
user, _ := h.userService.GetUserByEmail(ctx, defaultEmail)
if user == nil {
logger.Info(ctx, "Auto-setup: creating default user and tenant for lite edition")
randomBytes := make([]byte, 24)
if _, err := rand.Read(randomBytes); err != nil {
appErr := errors.NewInternalServerError("auto-setup failed: unable to generate credentials")
c.Error(appErr)
return
}
randomPassword := base64.RawURLEncoding.EncodeToString(randomBytes)
randomUsername := fmt.Sprintf("user_%s", base64.RawURLEncoding.EncodeToString(randomBytes[:6]))
_, err := h.userService.Register(ctx, &types.RegisterRequest{
Username: randomUsername,
Email: defaultEmail,
Password: randomPassword,
})
if err != nil {
logger.Errorf(ctx, "Auto-setup: failed to register default user: %v", err)
appErr := errors.NewInternalServerError("auto-setup failed").WithDetails(err.Error())
c.Error(appErr)
return
}
user, _ = h.userService.GetUserByEmail(ctx, defaultEmail)
if user == nil {
appErr := errors.NewInternalServerError("auto-setup failed: user not found after registration")
c.Error(appErr)
return
}
}
accessToken, refreshToken, err := h.userService.GenerateTokens(ctx, user)
if err != nil {
logger.Errorf(ctx, "Auto-setup: failed to generate tokens: %v", err)
appErr := errors.NewInternalServerError("auto-setup failed").WithDetails(err.Error())
c.Error(appErr)
return
}
tenant, _ := h.tenantService.GetTenantByID(ctx, user.TenantID)
logger.Info(ctx, "Auto-setup: completed successfully")
c.JSON(http.StatusOK, &types.LoginResponse{
Success: true,
Message: "Auto-setup successful",
User: user,
ActiveTenant: tenant,
Memberships: []types.Membership{{
TenantID: user.TenantID,
TenantName: tenantNameOrEmpty(tenant),
Role: types.TenantRoleOwner,
}},
Token: accessToken,
RefreshToken: refreshToken,
})
}
// tenantNameOrEmpty returns t.Name when t is non-nil, "" otherwise.
// Used by AutoSetup to populate Membership.TenantName without crashing
// if the tenant lookup failed.
func tenantNameOrEmpty(t *types.Tenant) string {
if t == nil {
return ""
}
return t.Name
}
// ValidateToken godoc
// @Summary 验证令牌
// @Description 验证访问令牌是否有效
// @Tags 认证
// @Accept json
// @Produce json
// @Success 200 {object} map[string]interface{} "令牌有效"
// @Failure 401 {object} errors.AppError "令牌无效"
// @Security Bearer
// @Router /auth/validate [get]
func (h *AuthHandler) ValidateToken(c *gin.Context) {
ctx := c.Request.Context()
logger.Info(ctx, "Start token validation")
// Extract token from Authorization header
authHeader := c.GetHeader("Authorization")
if authHeader == "" {
logger.Error(ctx, "Missing Authorization header")
appErr := errors.NewValidationError("Authorization header is required")
c.Error(appErr)
return
}
// Parse Bearer token
tokenParts := strings.Split(authHeader, " ")
if len(tokenParts) != 2 || tokenParts[0] != "Bearer" {
logger.Error(ctx, "Invalid Authorization header format")
appErr := errors.NewValidationError("Invalid Authorization header format")
c.Error(appErr)
return
}
token := tokenParts[1]
// Validate token
user, _, err := h.userService.ValidateToken(ctx, token)
if err != nil {
logger.Errorf(ctx, "Failed to validate token: %v", err)
appErr := errors.NewUnauthorizedError("Token validation failed").WithDetails(err.Error())
c.Error(appErr)
return
}
logger.Infof(ctx, "Token validated successfully for user: %s", user.Email)
c.JSON(http.StatusOK, gin.H{
"success": true,
"message": "Token is valid",
"user": user.ToUserInfo(),
})
}