package handler import ( "encoding/base64" "encoding/json" "net/http" "os" "strings" "github.com/gin-gonic/gin" "github.com/Tencent/WeKnora/internal/config" "github.com/Tencent/WeKnora/internal/errors" "github.com/Tencent/WeKnora/internal/logger" "github.com/Tencent/WeKnora/internal/types" "github.com/Tencent/WeKnora/internal/types/interfaces" secutils "github.com/Tencent/WeKnora/internal/utils" ) // AuthHandler implements HTTP request handlers for user authentication // Provides functionality for user registration, login, logout, and token management // through the REST API endpoints type AuthHandler struct { userService interfaces.UserService tenantService interfaces.TenantService configInfo *config.Config } // NewAuthHandler creates a new auth handler instance with the provided services // Parameters: // - userService: An implementation of the UserService interface for business logic // - tenantService: An implementation of the TenantService interface for tenant management // // Returns a pointer to the newly created AuthHandler func NewAuthHandler(configInfo *config.Config, userService interfaces.UserService, tenantService interfaces.TenantService) *AuthHandler { return &AuthHandler{ configInfo: configInfo, userService: userService, tenantService: tenantService, } } // Register godoc // @Summary 用户注册 // @Description 注册新用户账号 // @Tags 认证 // @Accept json // @Produce json // @Param request body types.RegisterRequest true "注册请求参数" // @Success 201 {object} types.RegisterResponse // @Failure 400 {object} errors.AppError "请求参数错误" // @Failure 403 {object} errors.AppError "注册功能已禁用" // @Router /auth/register [post] func (h *AuthHandler) Register(c *gin.Context) { ctx := c.Request.Context() logger.Info(ctx, "Start user registration") // 通过环境变量 DISABLE_REGISTRATION=true 禁止注册 if os.Getenv("DISABLE_REGISTRATION") == "true" { logger.Warn(ctx, "Registration is disabled by DISABLE_REGISTRATION env") appErr := errors.NewForbiddenError("Registration is disabled") c.Error(appErr) return } var req types.RegisterRequest if err := c.ShouldBindJSON(&req); err != nil { logger.Error(ctx, "Failed to parse registration request parameters", err) appErr := errors.NewValidationError("Invalid registration parameters").WithDetails(err.Error()) c.Error(appErr) return } req.Username = secutils.SanitizeForLog(req.Username) req.Email = secutils.SanitizeForLog(req.Email) req.Password = secutils.SanitizeForLog(req.Password) // Validate required fields if req.Username == "" || req.Email == "" || req.Password == "" { logger.Error(ctx, "Missing required registration fields") appErr := errors.NewValidationError("Username, email and password are required") c.Error(appErr) return } req.Username = secutils.SanitizeForLog(req.Username) req.Email = secutils.SanitizeForLog(req.Email) // Call service to register user user, err := h.userService.Register(ctx, &req) if err != nil { logger.Errorf(ctx, "Failed to register user: %v", err) appErr := errors.NewBadRequestError(err.Error()) c.Error(appErr) return } // Return success response response := &types.RegisterResponse{ Success: true, Message: "Registration successful", User: user, } logger.Infof(ctx, "User registered successfully: %s", secutils.SanitizeForLog(user.Email)) c.JSON(http.StatusCreated, response) } // Login godoc // @Summary 用户登录 // @Description 用户登录并获取访问令牌 // @Tags 认证 // @Accept json // @Produce json // @Param request body types.LoginRequest true "登录请求参数" // @Success 200 {object} types.LoginResponse // @Failure 401 {object} errors.AppError "认证失败" // @Router /auth/login [post] func (h *AuthHandler) Login(c *gin.Context) { ctx := c.Request.Context() logger.Info(ctx, "Start user login") var req types.LoginRequest if err := c.ShouldBindJSON(&req); err != nil { logger.Error(ctx, "Failed to parse login request parameters", err) appErr := errors.NewValidationError("Invalid login parameters").WithDetails(err.Error()) c.Error(appErr) return } email := secutils.SanitizeForLog(req.Email) // Validate required fields if req.Email == "" || req.Password == "" { logger.Error(ctx, "Missing required login fields") appErr := errors.NewValidationError("Email and password are required") c.Error(appErr) return } // Call service to authenticate user response, err := h.userService.Login(ctx, &req) if err != nil { logger.Errorf(ctx, "Failed to login user: %v", err) appErr := errors.NewUnauthorizedError("Login failed").WithDetails(err.Error()) c.Error(appErr) return } // Check if login was successful if !response.Success { logger.Warnf(ctx, "Login failed: %s", response.Message) c.JSON(http.StatusUnauthorized, response) return } // User is already in the correct format from service logger.Infof(ctx, "User logged in successfully, email: %s", email) c.JSON(http.StatusOK, response) } // GetOIDCAuthorizationURL godoc // @Summary 获取OIDC授权地址 // @Description 根据后端OIDC配置生成第三方登录跳转地址 // @Tags 认证 // @Accept json // @Produce json // @Param redirect_uri query string true "OIDC回调地址" // @Success 200 {object} types.OIDCAuthURLResponse // @Failure 400 {object} errors.AppError "请求参数错误" // @Failure 403 {object} errors.AppError "OIDC未启用" // @Router /auth/oidc/url [get] func (h *AuthHandler) GetOIDCAuthorizationURL(c *gin.Context) { ctx := c.Request.Context() redirectURI := strings.TrimSpace(c.Query("redirect_uri")) if redirectURI == "" { appErr := errors.NewValidationError("redirect_uri is required") c.Error(appErr) return } resp, err := h.userService.GetOIDCAuthorizationURL(ctx, redirectURI) if err != nil { logger.Errorf(ctx, "Failed to generate OIDC authorization URL: %v", err) appErr := errors.NewForbiddenError("OIDC authorization unavailable").WithDetails(err.Error()) c.Error(appErr) return } c.JSON(http.StatusOK, resp) } // GetOIDCConfig godoc // @Summary 获取OIDC登录配置 // @Description 返回OIDC是否启用以及provider展示名称,供前端决定是否展示OIDC登录入口 // @Tags 认证 // @Accept json // @Produce json // @Success 200 {object} types.OIDCConfigResponse // @Router /auth/oidc/config [get] func (h *AuthHandler) GetOIDCConfig(c *gin.Context) { providerDisplayName := "" enabled := false if h.configInfo != nil && h.configInfo.OIDCAuth != nil { enabled = h.configInfo.OIDCAuth.Enable providerDisplayName = strings.TrimSpace(h.configInfo.OIDCAuth.ProviderDisplayName) } c.JSON(http.StatusOK, &types.OIDCConfigResponse{ Success: true, Enabled: enabled, ProviderDisplayName: providerDisplayName, }) } // OIDCRedirectCallback godoc // @Summary OIDC登录重定向回调 // @Description 接收OIDC provider回调并由后端完成code交换,随后重定向回前端登录页 // @Tags 认证 // @Accept json // @Produce json // @Param code query string false "OIDC授权码" // @Param state query string false "OIDC状态" // @Param error query string false "OIDC错误码" // @Success 302 // @Router /auth/oidc/callback [get] func (h *AuthHandler) OIDCRedirectCallback(c *gin.Context) { ctx := c.Request.Context() frontendRedirectURI := "/" if providerError := strings.TrimSpace(c.Query("error")); providerError != "" { redirectURL := frontendRedirectURI + "#oidc_error=" + urlQueryEscape(providerError) if description := strings.TrimSpace(c.Query("error_description")); description != "" { redirectURL += "&oidc_error_description=" + urlQueryEscape(description) } c.Redirect(http.StatusFound, redirectURL) return } state := strings.TrimSpace(c.Query("state")) decodedState, err := decodeOIDCState(state) if err != nil { logger.Errorf(ctx, "Failed to decode OIDC state: %v", err) c.Redirect(http.StatusFound, frontendRedirectURI+"#oidc_error="+urlQueryEscape("invalid_state")) return } code := strings.TrimSpace(c.Query("code")) if code == "" { c.Redirect(http.StatusFound, frontendRedirectURI+"#oidc_error="+urlQueryEscape("missing_code")) return } resp, err := h.userService.LoginWithOIDC(ctx, code, strings.TrimSpace(decodedState.RedirectURI)) if err != nil { logger.Errorf(ctx, "Failed to complete OIDC login via redirect callback: %v", err) c.Redirect(http.StatusFound, frontendRedirectURI+"#oidc_error="+urlQueryEscape("login_failed")+"&oidc_error_description="+urlQueryEscape(err.Error())) return } if !resp.Success { c.Redirect(http.StatusFound, frontendRedirectURI+"#oidc_error="+urlQueryEscape("login_failed")+"&oidc_error_description="+urlQueryEscape(resp.Message)) return } payload, err := encodeOIDCCallbackPayload(resp) if err != nil { logger.Errorf(ctx, "Failed to encode OIDC callback payload: %v", err) c.Redirect(http.StatusFound, frontendRedirectURI+"#oidc_error="+urlQueryEscape("payload_encode_failed")) return } c.Redirect(http.StatusFound, frontendRedirectURI+"#oidc_result="+urlQueryEscape(payload)) } func encodeOIDCCallbackPayload(resp *types.OIDCCallbackResponse) (string, error) { payload, err := json.Marshal(resp) if err != nil { return "", err } return base64.RawURLEncoding.EncodeToString(payload), nil } type oidcStatePayload struct { Nonce string `json:"nonce"` RedirectURI string `json:"redirect_uri,omitempty"` } func decodeOIDCState(raw string) (*oidcStatePayload, error) { decoded, err := base64.RawURLEncoding.DecodeString(strings.TrimSpace(raw)) if err != nil { return nil, err } var payload oidcStatePayload if err := json.Unmarshal(decoded, &payload); err != nil { return nil, err } if strings.TrimSpace(payload.RedirectURI) == "" { return nil, errors.NewValidationError("state.redirect_uri is required") } return &payload, nil } func urlQueryEscape(value string) string { replacer := strings.NewReplacer( "%", "%25", " ", "%20", "#", "%23", "&", "%26", "+", "%2B", "=", "%3D", "?", "%3F", ) return replacer.Replace(value) } // Logout godoc // @Summary 用户登出 // @Description 撤销当前访问令牌并登出 // @Tags 认证 // @Accept json // @Produce json // @Success 200 {object} map[string]interface{} "登出成功" // @Failure 400 {object} errors.AppError "请求参数错误" // @Security Bearer // @Router /auth/logout [post] func (h *AuthHandler) Logout(c *gin.Context) { ctx := c.Request.Context() logger.Info(ctx, "Start user logout") // Extract token from Authorization header authHeader := c.GetHeader("Authorization") if authHeader == "" { logger.Error(ctx, "Missing Authorization header") appErr := errors.NewValidationError("Authorization header is required") c.Error(appErr) return } // Parse Bearer token tokenParts := strings.Split(authHeader, " ") if len(tokenParts) != 2 || tokenParts[0] != "Bearer" { logger.Error(ctx, "Invalid Authorization header format") appErr := errors.NewValidationError("Invalid Authorization header format") c.Error(appErr) return } token := tokenParts[1] // Revoke token err := h.userService.RevokeToken(ctx, token) if err != nil { logger.Errorf(ctx, "Failed to revoke token: %v", err) appErr := errors.NewInternalServerError("Logout failed").WithDetails(err.Error()) c.Error(appErr) return } logger.Info(ctx, "User logged out successfully") c.JSON(http.StatusOK, gin.H{ "success": true, "message": "Logout successful", }) } // RefreshToken godoc // @Summary 刷新令牌 // @Description 使用刷新令牌获取新的访问令牌 // @Tags 认证 // @Accept json // @Produce json // @Param request body object{refreshToken=string} true "刷新令牌" // @Success 200 {object} map[string]interface{} "新令牌" // @Failure 401 {object} errors.AppError "令牌无效" // @Router /auth/refresh [post] func (h *AuthHandler) RefreshToken(c *gin.Context) { ctx := c.Request.Context() logger.Info(ctx, "Start token refresh") var req struct { RefreshToken string `json:"refreshToken" binding:"required"` } if err := c.ShouldBindJSON(&req); err != nil { logger.Error(ctx, "Failed to parse refresh token request", err) appErr := errors.NewValidationError("Invalid refresh token request").WithDetails(err.Error()) c.Error(appErr) return } // Call service to refresh token accessToken, newRefreshToken, err := h.userService.RefreshToken(ctx, req.RefreshToken) if err != nil { logger.Errorf(ctx, "Failed to refresh token: %v", err) appErr := errors.NewUnauthorizedError("Token refresh failed").WithDetails(err.Error()) c.Error(appErr) return } logger.Info(ctx, "Token refreshed successfully") c.JSON(http.StatusOK, gin.H{ "success": true, "message": "Token refreshed successfully", "access_token": accessToken, "refresh_token": newRefreshToken, }) } // GetCurrentUser godoc // @Summary 获取当前用户信息 // @Description 获取当前登录用户的详细信息 // @Tags 认证 // @Accept json // @Produce json // @Success 200 {object} map[string]interface{} "用户信息" // @Failure 401 {object} errors.AppError "未授权" // @Security Bearer // @Router /auth/me [get] func (h *AuthHandler) GetCurrentUser(c *gin.Context) { ctx := c.Request.Context() // Get current user from service (which extracts from context) user, err := h.userService.GetCurrentUser(ctx) if err != nil { logger.Errorf(ctx, "Failed to get current user: %v", err) appErr := errors.NewUnauthorizedError("Failed to get user information").WithDetails(err.Error()) c.Error(appErr) return } // Get tenant information var tenant *types.Tenant if user.TenantID > 0 { tenant, err = h.tenantService.GetTenantByID(ctx, user.TenantID) if err != nil { logger.Warnf(ctx, "Failed to get tenant info for user %s, tenant ID %d: %v", user.Email, user.TenantID, err) // Don't fail the request if tenant info is not available } } userInfo := user.ToUserInfo() userInfo.CanAccessAllTenants = user.CanAccessAllTenants && h.configInfo.Tenant.EnableCrossTenantAccess c.JSON(http.StatusOK, gin.H{ "success": true, "data": gin.H{ "user": userInfo, "tenant": tenant, }, }) } // ChangePassword godoc // @Summary 修改密码 // @Description 修改当前用户的登录密码 // @Tags 认证 // @Accept json // @Produce json // @Param request body object{old_password=string,new_password=string} true "密码修改请求" // @Success 200 {object} map[string]interface{} "修改成功" // @Failure 400 {object} errors.AppError "请求参数错误" // @Security Bearer // @Router /auth/change-password [post] func (h *AuthHandler) ChangePassword(c *gin.Context) { ctx := c.Request.Context() logger.Info(ctx, "Start password change") var req struct { OldPassword string `json:"old_password" binding:"required"` NewPassword string `json:"new_password" binding:"required,min=6"` } if err := c.ShouldBindJSON(&req); err != nil { logger.Error(ctx, "Failed to parse password change request", err) appErr := errors.NewValidationError("Invalid password change request").WithDetails(err.Error()) c.Error(appErr) return } // Get current user user, err := h.userService.GetCurrentUser(ctx) if err != nil { logger.Errorf(ctx, "Failed to get current user: %v", err) appErr := errors.NewUnauthorizedError("Failed to get user information").WithDetails(err.Error()) c.Error(appErr) return } // Change password err = h.userService.ChangePassword(ctx, user.ID, req.OldPassword, req.NewPassword) if err != nil { logger.Errorf(ctx, "Failed to change password: %v", err) appErr := errors.NewBadRequestError("Password change failed").WithDetails(err.Error()) c.Error(appErr) return } logger.Infof(ctx, "Password changed successfully for user: %s", user.Email) c.JSON(http.StatusOK, gin.H{ "success": true, "message": "Password changed successfully", }) } // ValidateToken godoc // @Summary 验证令牌 // @Description 验证访问令牌是否有效 // @Tags 认证 // @Accept json // @Produce json // @Success 200 {object} map[string]interface{} "令牌有效" // @Failure 401 {object} errors.AppError "令牌无效" // @Security Bearer // @Router /auth/validate [get] func (h *AuthHandler) ValidateToken(c *gin.Context) { ctx := c.Request.Context() logger.Info(ctx, "Start token validation") // Extract token from Authorization header authHeader := c.GetHeader("Authorization") if authHeader == "" { logger.Error(ctx, "Missing Authorization header") appErr := errors.NewValidationError("Authorization header is required") c.Error(appErr) return } // Parse Bearer token tokenParts := strings.Split(authHeader, " ") if len(tokenParts) != 2 || tokenParts[0] != "Bearer" { logger.Error(ctx, "Invalid Authorization header format") appErr := errors.NewValidationError("Invalid Authorization header format") c.Error(appErr) return } token := tokenParts[1] // Validate token user, err := h.userService.ValidateToken(ctx, token) if err != nil { logger.Errorf(ctx, "Failed to validate token: %v", err) appErr := errors.NewUnauthorizedError("Token validation failed").WithDetails(err.Error()) c.Error(appErr) return } logger.Infof(ctx, "Token validated successfully for user: %s", user.Email) c.JSON(http.StatusOK, gin.H{ "success": true, "message": "Token is valid", "user": user.ToUserInfo(), }) }