mirror of
https://github.com/Tencent/WeKnora.git
synced 2026-06-04 13:30:32 +08:00
859 lines
25 KiB
Go
859 lines
25 KiB
Go
package service
|
|
|
|
import (
|
|
"context"
|
|
"crypto/rand"
|
|
"encoding/base64"
|
|
"encoding/json"
|
|
"errors"
|
|
"fmt"
|
|
"io"
|
|
"net/http"
|
|
"net/url"
|
|
"os"
|
|
"strings"
|
|
"sync"
|
|
"time"
|
|
|
|
"github.com/golang-jwt/jwt/v5"
|
|
"github.com/google/uuid"
|
|
"golang.org/x/crypto/bcrypt"
|
|
|
|
apprepo "github.com/Tencent/WeKnora/internal/application/repository"
|
|
"github.com/Tencent/WeKnora/internal/config"
|
|
"github.com/Tencent/WeKnora/internal/logger"
|
|
"github.com/Tencent/WeKnora/internal/types"
|
|
"github.com/Tencent/WeKnora/internal/types/interfaces"
|
|
secutils "github.com/Tencent/WeKnora/internal/utils"
|
|
)
|
|
|
|
// oidcAuthorizationState is the payload carried in the OIDC "state" parameter.
// encodeOIDCAuthorizationState serializes it as JSON and base64url-encodes it.
type oidcAuthorizationState struct {
	// Nonce is a random value generated for each authorization request.
	Nonce string `json:"nonce"`
	// RedirectURI is the redirect URI the flow was started with.
	// It is omitted from the JSON when empty.
	RedirectURI string `json:"redirect_uri,omitempty"`
}
|
|
|
|
// Process-wide JWT signing secret, resolved lazily and at most once.
var (
	jwtSecretOnce sync.Once
	jwtSecret     string
)

// getJwtSecret returns the secret used to sign and verify JWTs.
//
// The first call resolves the secret: the JWT_SECRET environment variable
// (with surrounding whitespace trimmed) is used when non-empty; otherwise 32
// random bytes are generated and base64-encoded. Every later call returns the
// same value. It panics if the random bytes cannot be generated.
func getJwtSecret() string {
	jwtSecretOnce.Do(func() {
		jwtSecret = strings.TrimSpace(os.Getenv("JWT_SECRET"))
		if jwtSecret != "" {
			return
		}

		// No configured secret: fall back to a random one.
		var seed [32]byte
		if _, err := rand.Read(seed[:]); err != nil {
			panic(fmt.Sprintf("failed to generate JWT secret: %v", err))
		}
		jwtSecret = base64.StdEncoding.EncodeToString(seed[:])
	})
	return jwtSecret
}
|
|
|
|
// userService implements the UserService interface
|
|
type userService struct {
|
|
userRepo interfaces.UserRepository
|
|
tokenRepo interfaces.AuthTokenRepository
|
|
tenantService interfaces.TenantService
|
|
config *config.Config
|
|
}
|
|
|
|
// NewUserService creates a new user service instance
|
|
func NewUserService(
|
|
configInfo *config.Config,
|
|
userRepo interfaces.UserRepository,
|
|
tokenRepo interfaces.AuthTokenRepository,
|
|
tenantService interfaces.TenantService,
|
|
) interfaces.UserService {
|
|
return &userService{
|
|
userRepo: userRepo,
|
|
tokenRepo: tokenRepo,
|
|
tenantService: tenantService,
|
|
config: configInfo,
|
|
}
|
|
}
|
|
|
|
// Register creates a new user account
|
|
func (s *userService) Register(ctx context.Context, req *types.RegisterRequest) (*types.User, error) {
|
|
logger.Info(ctx, "Start user registration")
|
|
|
|
// Validate input
|
|
if req.Username == "" || req.Email == "" || req.Password == "" {
|
|
return nil, errors.New("username, email and password are required")
|
|
}
|
|
|
|
// Check if user already exists
|
|
existingUser, _ := s.userRepo.GetUserByEmail(ctx, req.Email)
|
|
if existingUser != nil {
|
|
return nil, errors.New("user with this email already exists")
|
|
}
|
|
|
|
existingUser, _ = s.userRepo.GetUserByUsername(ctx, req.Username)
|
|
if existingUser != nil {
|
|
return nil, errors.New("user with this username already exists")
|
|
}
|
|
|
|
// Hash password
|
|
hashedPassword, err := bcrypt.GenerateFromPassword([]byte(req.Password), bcrypt.DefaultCost)
|
|
if err != nil {
|
|
logger.Errorf(ctx, "Failed to hash password: %v", err)
|
|
return nil, errors.New("failed to process password")
|
|
}
|
|
|
|
// Create default tenant for the user
|
|
// Note: RetrieverEngines is left empty - system will use defaults from RETRIEVE_DRIVER env
|
|
tenant := &types.Tenant{
|
|
Name: fmt.Sprintf("%s's Workspace", secutils.SanitizeForLog(req.Username)),
|
|
Description: "Default workspace",
|
|
Status: "active",
|
|
}
|
|
|
|
createdTenant, err := s.tenantService.CreateTenant(ctx, tenant)
|
|
if err != nil {
|
|
logger.Errorf(ctx, "Failed to create tenant")
|
|
return nil, errors.New("failed to create workspace")
|
|
}
|
|
|
|
// Create user
|
|
user := &types.User{
|
|
ID: uuid.New().String(),
|
|
Username: req.Username,
|
|
Email: req.Email,
|
|
PasswordHash: string(hashedPassword),
|
|
TenantID: createdTenant.ID,
|
|
IsActive: true,
|
|
CreatedAt: time.Now(),
|
|
UpdatedAt: time.Now(),
|
|
}
|
|
|
|
err = s.userRepo.CreateUser(ctx, user)
|
|
if err != nil {
|
|
logger.Errorf(ctx, "Failed to create user: %v", err)
|
|
return nil, errors.New("failed to create user")
|
|
}
|
|
|
|
logger.Info(ctx, "User registered successfully")
|
|
return user, nil
|
|
}
|
|
|
|
// Login authenticates a user and returns tokens
|
|
func (s *userService) Login(ctx context.Context, req *types.LoginRequest) (*types.LoginResponse, error) {
|
|
logger.Info(ctx, "Start user login")
|
|
// Get user by email
|
|
user, err := s.userRepo.GetUserByEmail(ctx, req.Email)
|
|
if err != nil {
|
|
logger.Errorf(ctx, "Failed to get user by email: %v", err)
|
|
return &types.LoginResponse{
|
|
Success: false,
|
|
Message: "Invalid email or password",
|
|
}, nil
|
|
}
|
|
if user == nil {
|
|
logger.Warn(ctx, "User not found for email")
|
|
return &types.LoginResponse{
|
|
Success: false,
|
|
Message: "Invalid email or password",
|
|
}, nil
|
|
}
|
|
|
|
// Check if user is active
|
|
if !user.IsActive {
|
|
logger.Warn(ctx, "User account is disabled")
|
|
return &types.LoginResponse{
|
|
Success: false,
|
|
Message: "Account is disabled",
|
|
}, nil
|
|
}
|
|
|
|
// Verify password
|
|
err = bcrypt.CompareHashAndPassword([]byte(user.PasswordHash), []byte(req.Password))
|
|
if err != nil {
|
|
logger.Warn(ctx, "Password verification failed")
|
|
return &types.LoginResponse{
|
|
Success: false,
|
|
Message: "Invalid email or password",
|
|
}, nil
|
|
}
|
|
logger.Info(ctx, "Password verification successful")
|
|
|
|
// Generate tokens
|
|
logger.Info(ctx, "Generating tokens")
|
|
accessToken, refreshToken, err := s.GenerateTokens(ctx, user)
|
|
if err != nil {
|
|
logger.Errorf(ctx, "Failed to generate tokens: %v", err)
|
|
return &types.LoginResponse{
|
|
Success: false,
|
|
Message: "Login failed",
|
|
}, nil
|
|
}
|
|
logger.Info(ctx, "Tokens generated successfully")
|
|
|
|
// Get tenant information
|
|
tenant, err := s.tenantService.GetTenantByID(ctx, user.TenantID)
|
|
if err != nil {
|
|
logger.Warn(ctx, "Failed to get tenant info")
|
|
} else {
|
|
logger.Info(ctx, "Tenant information retrieved successfully")
|
|
}
|
|
|
|
logger.Info(ctx, "User logged in successfully")
|
|
return &types.LoginResponse{
|
|
Success: true,
|
|
Message: "Login successful",
|
|
User: user,
|
|
Tenant: tenant,
|
|
Token: accessToken,
|
|
RefreshToken: refreshToken,
|
|
}, nil
|
|
}
|
|
|
|
// GetOIDCAuthorizationURL builds the OIDC authorization URL.
|
|
func (s *userService) GetOIDCAuthorizationURL(ctx context.Context, redirectURI string) (*types.OIDCAuthURLResponse, error) {
|
|
cfg, err := s.getOIDCConfig(ctx)
|
|
if err != nil {
|
|
return nil, err
|
|
}
|
|
if strings.TrimSpace(redirectURI) == "" {
|
|
return nil, errors.New("redirect_uri is required")
|
|
}
|
|
|
|
nonce, err := generateRandomString(24)
|
|
if err != nil {
|
|
return nil, fmt.Errorf("failed to generate state: %w", err)
|
|
}
|
|
|
|
state, err := encodeOIDCAuthorizationState(&oidcAuthorizationState{
|
|
Nonce: nonce,
|
|
RedirectURI: strings.TrimSpace(redirectURI),
|
|
})
|
|
if err != nil {
|
|
return nil, fmt.Errorf("failed to encode OIDC state: %w", err)
|
|
}
|
|
|
|
query := url.Values{}
|
|
query.Set("response_type", "code")
|
|
query.Set("client_id", cfg.ClientID)
|
|
query.Set("redirect_uri", redirectURI)
|
|
query.Set("scope", strings.Join(cfg.Scopes, " "))
|
|
query.Set("state", state)
|
|
|
|
authURL := cfg.AuthorizationEndpoint
|
|
if strings.Contains(authURL, "?") {
|
|
authURL += "&" + query.Encode()
|
|
} else {
|
|
authURL += "?" + query.Encode()
|
|
}
|
|
|
|
return &types.OIDCAuthURLResponse{
|
|
Success: true,
|
|
ProviderDisplayName: cfg.ProviderDisplayName,
|
|
AuthorizationURL: authURL,
|
|
State: state,
|
|
}, nil
|
|
}
|
|
|
|
// LoginWithOIDC exchanges code for tokens, loads user info, provisions user if needed, and returns local login tokens.
|
|
func (s *userService) LoginWithOIDC(ctx context.Context, code, redirectURI string) (*types.OIDCCallbackResponse, error) {
|
|
if strings.TrimSpace(code) == "" {
|
|
return nil, errors.New("code is required")
|
|
}
|
|
if strings.TrimSpace(redirectURI) == "" {
|
|
return nil, errors.New("redirect_uri is required")
|
|
}
|
|
|
|
cfg, err := s.getOIDCConfig(ctx)
|
|
if err != nil {
|
|
return nil, err
|
|
}
|
|
|
|
tokenResp, err := s.exchangeOIDCCode(ctx, cfg, code, redirectURI)
|
|
if err != nil {
|
|
return nil, err
|
|
}
|
|
|
|
userInfo, err := s.resolveOIDCUserInfo(ctx, cfg, tokenResp)
|
|
if err != nil {
|
|
return nil, err
|
|
}
|
|
if strings.TrimSpace(userInfo.Email) == "" {
|
|
return nil, errors.New("OIDC provider did not return email")
|
|
}
|
|
|
|
user, err := s.userRepo.GetUserByEmail(ctx, userInfo.Email)
|
|
if err != nil && !isUserLookupNotFound(err) {
|
|
return nil, fmt.Errorf("failed to query user by email: %w", err)
|
|
}
|
|
isNewUser := false
|
|
if isUserLookupNotFound(err) || user == nil {
|
|
user, err = s.provisionOIDCUser(ctx, userInfo)
|
|
if err != nil {
|
|
return nil, err
|
|
}
|
|
isNewUser = true
|
|
}
|
|
|
|
if !user.IsActive {
|
|
return &types.OIDCCallbackResponse{Success: false, Message: "Account is disabled"}, nil
|
|
}
|
|
|
|
accessToken, refreshToken, err := s.GenerateTokens(ctx, user)
|
|
if err != nil {
|
|
return nil, fmt.Errorf("failed to generate local tokens: %w", err)
|
|
}
|
|
|
|
tenant, err := s.tenantService.GetTenantByID(ctx, user.TenantID)
|
|
if err != nil {
|
|
logger.Warnf(ctx, "Failed to get tenant info after OIDC login: %v", err)
|
|
}
|
|
|
|
return &types.OIDCCallbackResponse{
|
|
Success: true,
|
|
Message: "Login successful",
|
|
User: user,
|
|
Tenant: tenant,
|
|
Token: accessToken,
|
|
RefreshToken: refreshToken,
|
|
IsNewUser: isNewUser,
|
|
}, nil
|
|
}
|
|
|
|
// GetUserByID gets a user by ID
|
|
func (s *userService) GetUserByID(ctx context.Context, id string) (*types.User, error) {
|
|
return s.userRepo.GetUserByID(ctx, id)
|
|
}
|
|
|
|
// GetUserByEmail gets a user by email
|
|
func (s *userService) GetUserByEmail(ctx context.Context, email string) (*types.User, error) {
|
|
return s.userRepo.GetUserByEmail(ctx, email)
|
|
}
|
|
|
|
// GetUserByUsername gets a user by username
|
|
func (s *userService) GetUserByUsername(ctx context.Context, username string) (*types.User, error) {
|
|
return s.userRepo.GetUserByUsername(ctx, username)
|
|
}
|
|
|
|
// GetUserByTenantID gets the first user (owner) of a tenant
|
|
func (s *userService) GetUserByTenantID(ctx context.Context, tenantID uint64) (*types.User, error) {
|
|
return s.userRepo.GetUserByTenantID(ctx, tenantID)
|
|
}
|
|
|
|
// UpdateUser updates user information
|
|
func (s *userService) UpdateUser(ctx context.Context, user *types.User) error {
|
|
user.UpdatedAt = time.Now()
|
|
return s.userRepo.UpdateUser(ctx, user)
|
|
}
|
|
|
|
// DeleteUser deletes a user
|
|
func (s *userService) DeleteUser(ctx context.Context, id string) error {
|
|
return s.userRepo.DeleteUser(ctx, id)
|
|
}
|
|
|
|
// ChangePassword changes user password
|
|
func (s *userService) ChangePassword(ctx context.Context, userID string, oldPassword, newPassword string) error {
|
|
user, err := s.userRepo.GetUserByID(ctx, userID)
|
|
if err != nil {
|
|
return err
|
|
}
|
|
|
|
// Verify old password
|
|
err = bcrypt.CompareHashAndPassword([]byte(user.PasswordHash), []byte(oldPassword))
|
|
if err != nil {
|
|
return errors.New("invalid old password")
|
|
}
|
|
|
|
// Hash new password
|
|
hashedPassword, err := bcrypt.GenerateFromPassword([]byte(newPassword), bcrypt.DefaultCost)
|
|
if err != nil {
|
|
return err
|
|
}
|
|
|
|
user.PasswordHash = string(hashedPassword)
|
|
user.UpdatedAt = time.Now()
|
|
|
|
return s.userRepo.UpdateUser(ctx, user)
|
|
}
|
|
|
|
// ValidatePassword validates user password
|
|
func (s *userService) ValidatePassword(ctx context.Context, userID string, password string) error {
|
|
user, err := s.userRepo.GetUserByID(ctx, userID)
|
|
if err != nil {
|
|
return err
|
|
}
|
|
|
|
return bcrypt.CompareHashAndPassword([]byte(user.PasswordHash), []byte(password))
|
|
}
|
|
|
|
// GenerateTokens generates access and refresh tokens for user
|
|
func (s *userService) GenerateTokens(
|
|
ctx context.Context,
|
|
user *types.User,
|
|
) (accessToken, refreshToken string, err error) {
|
|
// Generate access token (expires in 24 hours)
|
|
accessClaims := jwt.MapClaims{
|
|
"user_id": user.ID,
|
|
"email": user.Email,
|
|
"tenant_id": user.TenantID,
|
|
"exp": time.Now().Add(24 * time.Hour).Unix(),
|
|
"iat": time.Now().Unix(),
|
|
"type": "access",
|
|
}
|
|
|
|
accessTokenObj := jwt.NewWithClaims(jwt.SigningMethodHS256, accessClaims)
|
|
accessToken, err = accessTokenObj.SignedString([]byte(getJwtSecret()))
|
|
if err != nil {
|
|
return "", "", err
|
|
}
|
|
|
|
// Generate refresh token (expires in 7 days)
|
|
refreshClaims := jwt.MapClaims{
|
|
"user_id": user.ID,
|
|
"exp": time.Now().Add(7 * 24 * time.Hour).Unix(),
|
|
"iat": time.Now().Unix(),
|
|
"type": "refresh",
|
|
}
|
|
|
|
refreshTokenObj := jwt.NewWithClaims(jwt.SigningMethodHS256, refreshClaims)
|
|
refreshToken, err = refreshTokenObj.SignedString([]byte(getJwtSecret()))
|
|
if err != nil {
|
|
return "", "", err
|
|
}
|
|
|
|
// Store tokens in database
|
|
accessTokenRecord := &types.AuthToken{
|
|
ID: uuid.New().String(),
|
|
UserID: user.ID,
|
|
Token: accessToken,
|
|
TokenType: "access_token",
|
|
ExpiresAt: time.Now().Add(24 * time.Hour),
|
|
CreatedAt: time.Now(),
|
|
UpdatedAt: time.Now(),
|
|
}
|
|
|
|
refreshTokenRecord := &types.AuthToken{
|
|
ID: uuid.New().String(),
|
|
UserID: user.ID,
|
|
Token: refreshToken,
|
|
TokenType: "refresh_token",
|
|
ExpiresAt: time.Now().Add(7 * 24 * time.Hour),
|
|
CreatedAt: time.Now(),
|
|
UpdatedAt: time.Now(),
|
|
}
|
|
|
|
_ = s.tokenRepo.CreateToken(ctx, accessTokenRecord)
|
|
_ = s.tokenRepo.CreateToken(ctx, refreshTokenRecord)
|
|
|
|
return accessToken, refreshToken, nil
|
|
}
|
|
|
|
// ValidateToken validates an access token
|
|
func (s *userService) ValidateToken(ctx context.Context, tokenString string) (*types.User, error) {
|
|
token, err := jwt.Parse(tokenString, func(token *jwt.Token) (interface{}, error) {
|
|
if _, ok := token.Method.(*jwt.SigningMethodHMAC); !ok {
|
|
return nil, fmt.Errorf("unexpected signing method: %v", token.Header["alg"])
|
|
}
|
|
return []byte(getJwtSecret()), nil
|
|
})
|
|
|
|
if err != nil || !token.Valid {
|
|
return nil, errors.New("invalid token")
|
|
}
|
|
|
|
claims, ok := token.Claims.(jwt.MapClaims)
|
|
if !ok {
|
|
return nil, errors.New("invalid token claims")
|
|
}
|
|
|
|
userID, ok := claims["user_id"].(string)
|
|
if !ok {
|
|
return nil, errors.New("invalid user ID in token")
|
|
}
|
|
|
|
// Check if token is revoked
|
|
tokenRecord, err := s.tokenRepo.GetTokenByValue(ctx, tokenString)
|
|
if err != nil || tokenRecord == nil || tokenRecord.IsRevoked {
|
|
return nil, errors.New("token is revoked")
|
|
}
|
|
|
|
return s.userRepo.GetUserByID(ctx, userID)
|
|
}
|
|
|
|
// RefreshToken refreshes access token using refresh token
|
|
func (s *userService) RefreshToken(
|
|
ctx context.Context,
|
|
refreshTokenString string,
|
|
) (accessToken, newRefreshToken string, err error) {
|
|
token, err := jwt.Parse(refreshTokenString, func(token *jwt.Token) (interface{}, error) {
|
|
if _, ok := token.Method.(*jwt.SigningMethodHMAC); !ok {
|
|
return nil, fmt.Errorf("unexpected signing method: %v", token.Header["alg"])
|
|
}
|
|
return []byte(getJwtSecret()), nil
|
|
})
|
|
|
|
if err != nil || !token.Valid {
|
|
return "", "", errors.New("invalid refresh token")
|
|
}
|
|
|
|
claims, ok := token.Claims.(jwt.MapClaims)
|
|
if !ok {
|
|
return "", "", errors.New("invalid token claims")
|
|
}
|
|
|
|
tokenType, ok := claims["type"].(string)
|
|
if !ok || tokenType != "refresh" {
|
|
return "", "", errors.New("not a refresh token")
|
|
}
|
|
|
|
userID, ok := claims["user_id"].(string)
|
|
if !ok {
|
|
return "", "", errors.New("invalid user ID in token")
|
|
}
|
|
|
|
// Check if token is revoked
|
|
tokenRecord, err := s.tokenRepo.GetTokenByValue(ctx, refreshTokenString)
|
|
if err != nil || tokenRecord == nil || tokenRecord.IsRevoked {
|
|
return "", "", errors.New("refresh token is revoked")
|
|
}
|
|
|
|
// Get user
|
|
user, err := s.userRepo.GetUserByID(ctx, userID)
|
|
if err != nil {
|
|
return "", "", err
|
|
}
|
|
|
|
// Revoke old refresh token
|
|
tokenRecord.IsRevoked = true
|
|
_ = s.tokenRepo.UpdateToken(ctx, tokenRecord)
|
|
|
|
// Generate new tokens
|
|
return s.GenerateTokens(ctx, user)
|
|
}
|
|
|
|
// RevokeToken revokes a token
|
|
func (s *userService) RevokeToken(ctx context.Context, tokenString string) error {
|
|
tokenRecord, err := s.tokenRepo.GetTokenByValue(ctx, tokenString)
|
|
if err != nil {
|
|
return err
|
|
}
|
|
|
|
tokenRecord.IsRevoked = true
|
|
tokenRecord.UpdatedAt = time.Now()
|
|
|
|
return s.tokenRepo.UpdateToken(ctx, tokenRecord)
|
|
}
|
|
|
|
// GetCurrentUser gets current user from context
|
|
func (s *userService) GetCurrentUser(ctx context.Context) (*types.User, error) {
|
|
user, ok := ctx.Value(types.UserContextKey).(*types.User)
|
|
if !ok {
|
|
return nil, errors.New("user not found in context")
|
|
}
|
|
|
|
return user, nil
|
|
}
|
|
|
|
// SearchUsers searches users by username or email
|
|
func (s *userService) SearchUsers(ctx context.Context, query string, limit int) ([]*types.User, error) {
|
|
if query == "" {
|
|
return []*types.User{}, nil
|
|
}
|
|
return s.userRepo.SearchUsers(ctx, query, limit)
|
|
}
|
|
|
|
// oidcDiscoveryDocument holds the fields read from the provider's discovery
// document; any other fields in the JSON are ignored when decoding.
type oidcDiscoveryDocument struct {
	// AuthorizationEndpoint is decoded from "authorization_endpoint".
	AuthorizationEndpoint string `json:"authorization_endpoint"`
	// TokenEndpoint is decoded from "token_endpoint".
	TokenEndpoint string `json:"token_endpoint"`
	// UserInfoEndpoint is decoded from "userinfo_endpoint".
	UserInfoEndpoint string `json:"userinfo_endpoint"`
}
|
|
|
|
// oidcTokenResponse holds the fields read from the provider's token endpoint
// response; any other fields in the JSON are ignored when decoding.
type oidcTokenResponse struct {
	// AccessToken is decoded from "access_token".
	AccessToken string `json:"access_token"`
	// IDToken is decoded from "id_token".
	IDToken string `json:"id_token"`
	// TokenType is decoded from "token_type".
	TokenType string `json:"token_type"`
}
|
|
|
|
func (s *userService) getOIDCConfig(ctx context.Context) (*config.OIDCAuthConfig, error) {
|
|
if s.config == nil || s.config.OIDCAuth == nil || !s.config.OIDCAuth.Enable {
|
|
return nil, errors.New("OIDC login is disabled")
|
|
}
|
|
cfg := *s.config.OIDCAuth
|
|
if cfg.UserInfoMapping == nil {
|
|
cfg.UserInfoMapping = &config.OIDCUserInfoMapping{Username: "name", Email: "email"}
|
|
}
|
|
if err := s.populateOIDCEndpoints(ctx, &cfg); err != nil {
|
|
return nil, err
|
|
}
|
|
return &cfg, nil
|
|
}
|
|
|
|
func (s *userService) populateOIDCEndpoints(ctx context.Context, cfg *config.OIDCAuthConfig) error {
|
|
if strings.TrimSpace(cfg.AuthorizationEndpoint) != "" && strings.TrimSpace(cfg.TokenEndpoint) != "" {
|
|
return nil
|
|
}
|
|
if strings.TrimSpace(cfg.DiscoveryURL) == "" {
|
|
return errors.New("OIDC discovery_url or explicit endpoints are required")
|
|
}
|
|
|
|
req, err := http.NewRequestWithContext(ctx, http.MethodGet, cfg.DiscoveryURL, nil)
|
|
if err != nil {
|
|
return fmt.Errorf("failed to create OIDC discovery request: %w", err)
|
|
}
|
|
|
|
resp, err := http.DefaultClient.Do(req)
|
|
if err != nil {
|
|
return fmt.Errorf("failed to load OIDC discovery document: %w", err)
|
|
}
|
|
defer resp.Body.Close()
|
|
if resp.StatusCode < 200 || resp.StatusCode >= 300 {
|
|
body, _ := io.ReadAll(io.LimitReader(resp.Body, 2048))
|
|
return fmt.Errorf("OIDC discovery request failed: status=%d body=%s", resp.StatusCode, strings.TrimSpace(string(body)))
|
|
}
|
|
|
|
var doc oidcDiscoveryDocument
|
|
if err := json.NewDecoder(resp.Body).Decode(&doc); err != nil {
|
|
return fmt.Errorf("failed to decode OIDC discovery document: %w", err)
|
|
}
|
|
if cfg.AuthorizationEndpoint == "" {
|
|
cfg.AuthorizationEndpoint = doc.AuthorizationEndpoint
|
|
}
|
|
if cfg.TokenEndpoint == "" {
|
|
cfg.TokenEndpoint = doc.TokenEndpoint
|
|
}
|
|
if cfg.UserInfoEndpoint == "" {
|
|
cfg.UserInfoEndpoint = doc.UserInfoEndpoint
|
|
}
|
|
if cfg.AuthorizationEndpoint == "" || cfg.TokenEndpoint == "" {
|
|
return errors.New("OIDC discovery document missing required endpoints")
|
|
}
|
|
return nil
|
|
}
|
|
|
|
func (s *userService) exchangeOIDCCode(ctx context.Context, cfg *config.OIDCAuthConfig, code, redirectURI string) (*oidcTokenResponse, error) {
|
|
form := url.Values{}
|
|
form.Set("grant_type", "authorization_code")
|
|
form.Set("code", code)
|
|
form.Set("redirect_uri", redirectURI)
|
|
form.Set("client_id", cfg.ClientID)
|
|
form.Set("client_secret", cfg.ClientSecret)
|
|
|
|
req, err := http.NewRequestWithContext(ctx, http.MethodPost, cfg.TokenEndpoint, strings.NewReader(form.Encode()))
|
|
if err != nil {
|
|
return nil, fmt.Errorf("failed to create OIDC token request: %w", err)
|
|
}
|
|
req.Header.Set("Content-Type", "application/x-www-form-urlencoded")
|
|
req.Header.Set("Accept", "application/json")
|
|
|
|
resp, err := http.DefaultClient.Do(req)
|
|
if err != nil {
|
|
return nil, fmt.Errorf("failed to exchange OIDC code: %w", err)
|
|
}
|
|
defer resp.Body.Close()
|
|
if resp.StatusCode < 200 || resp.StatusCode >= 300 {
|
|
body, _ := io.ReadAll(io.LimitReader(resp.Body, 2048))
|
|
return nil, fmt.Errorf("OIDC token exchange failed: status=%d body=%s", resp.StatusCode, strings.TrimSpace(string(body)))
|
|
}
|
|
|
|
var tokenResp oidcTokenResponse
|
|
if err := json.NewDecoder(resp.Body).Decode(&tokenResp); err != nil {
|
|
return nil, fmt.Errorf("failed to decode OIDC token response: %w", err)
|
|
}
|
|
if strings.TrimSpace(tokenResp.AccessToken) == "" && strings.TrimSpace(tokenResp.IDToken) == "" {
|
|
return nil, errors.New("OIDC token response missing access_token and id_token")
|
|
}
|
|
return &tokenResp, nil
|
|
}
|
|
|
|
func (s *userService) resolveOIDCUserInfo(ctx context.Context, cfg *config.OIDCAuthConfig, tokenResp *oidcTokenResponse) (*types.OIDCUserInfo, error) {
|
|
claims := map[string]interface{}{}
|
|
|
|
if strings.TrimSpace(tokenResp.IDToken) != "" {
|
|
idTokenClaims, err := decodeJWTClaims(tokenResp.IDToken)
|
|
if err != nil {
|
|
logger.Warnf(ctx, "Failed to decode OIDC id_token claims: %v", err)
|
|
} else {
|
|
for k, v := range idTokenClaims {
|
|
claims[k] = v
|
|
}
|
|
}
|
|
}
|
|
|
|
if strings.TrimSpace(cfg.UserInfoEndpoint) != "" && strings.TrimSpace(tokenResp.AccessToken) != "" {
|
|
userInfoClaims, err := s.fetchOIDCUserInfo(ctx, cfg.UserInfoEndpoint, tokenResp.AccessToken)
|
|
if err != nil {
|
|
logger.Warnf(ctx, "Failed to fetch OIDC userinfo, fallback to id_token claims: %v", err)
|
|
} else {
|
|
for k, v := range userInfoClaims {
|
|
claims[k] = v
|
|
}
|
|
}
|
|
}
|
|
|
|
info := &types.OIDCUserInfo{Claims: claims}
|
|
if sub, _ := claims["sub"].(string); sub != "" {
|
|
info.Subject = sub
|
|
}
|
|
info.Username = extractClaimAsString(claims, cfg.UserInfoMapping.Username)
|
|
info.Email = extractClaimAsString(claims, cfg.UserInfoMapping.Email)
|
|
if info.Username == "" {
|
|
info.Username = extractClaimAsString(claims, "preferred_username")
|
|
}
|
|
if info.Username == "" {
|
|
info.Username = extractClaimAsString(claims, "name")
|
|
}
|
|
if info.Username == "" && info.Email != "" {
|
|
info.Username = strings.Split(info.Email, "@")[0]
|
|
}
|
|
return info, nil
|
|
}
|
|
|
|
func (s *userService) fetchOIDCUserInfo(ctx context.Context, endpoint, accessToken string) (map[string]interface{}, error) {
|
|
req, err := http.NewRequestWithContext(ctx, http.MethodGet, endpoint, nil)
|
|
if err != nil {
|
|
return nil, err
|
|
}
|
|
req.Header.Set("Authorization", "Bearer "+accessToken)
|
|
req.Header.Set("Accept", "application/json")
|
|
|
|
resp, err := http.DefaultClient.Do(req)
|
|
if err != nil {
|
|
return nil, err
|
|
}
|
|
defer resp.Body.Close()
|
|
if resp.StatusCode < 200 || resp.StatusCode >= 300 {
|
|
body, _ := io.ReadAll(io.LimitReader(resp.Body, 2048))
|
|
return nil, fmt.Errorf("userinfo request failed: status=%d body=%s", resp.StatusCode, strings.TrimSpace(string(body)))
|
|
}
|
|
|
|
var claims map[string]interface{}
|
|
if err := json.NewDecoder(resp.Body).Decode(&claims); err != nil {
|
|
return nil, err
|
|
}
|
|
return claims, nil
|
|
}
|
|
|
|
func (s *userService) provisionOIDCUser(ctx context.Context, info *types.OIDCUserInfo) (*types.User, error) {
|
|
username := s.generateOIDCUsername(ctx, info)
|
|
randomPassword, err := generateRandomString(32)
|
|
if err != nil {
|
|
return nil, fmt.Errorf("failed to generate password for OIDC user: %w", err)
|
|
}
|
|
|
|
user, err := s.Register(ctx, &types.RegisterRequest{
|
|
Username: username,
|
|
Email: info.Email,
|
|
Password: randomPassword,
|
|
})
|
|
if err != nil {
|
|
return nil, fmt.Errorf("failed to auto-provision OIDC user: %w", err)
|
|
}
|
|
return user, nil
|
|
}
|
|
|
|
func (s *userService) generateOIDCUsername(ctx context.Context, info *types.OIDCUserInfo) string {
|
|
base := sanitizeUsernameCandidate(info.Username)
|
|
if base == "" {
|
|
base = sanitizeUsernameCandidate(strings.Split(info.Email, "@")[0])
|
|
}
|
|
if base == "" {
|
|
base = "oidc-user"
|
|
}
|
|
|
|
candidate := base
|
|
for i := 0; i < 20; i++ {
|
|
existing, err := s.userRepo.GetUserByUsername(ctx, candidate)
|
|
if isUserLookupNotFound(err) || (err == nil && existing == nil) {
|
|
return candidate
|
|
}
|
|
if err != nil && !isUserLookupNotFound(err) {
|
|
logger.Warnf(ctx, "Failed to check existing OIDC username %q: %v", candidate, err)
|
|
}
|
|
candidate = fmt.Sprintf("%s-%d", base, i+1)
|
|
}
|
|
return fmt.Sprintf("%s-%d", base, time.Now().Unix())
|
|
}
|
|
|
|
// generateRandomString returns the unpadded base64url encoding of length
// bytes read from crypto/rand. Note that length counts random bytes, not
// output characters: the returned string is longer than length. The error
// from reading random bytes is returned unchanged.
func generateRandomString(length int) (string, error) {
	raw := make([]byte, length)
	_, err := rand.Read(raw)
	if err != nil {
		return "", err
	}

	encoded := base64.RawURLEncoding.EncodeToString(raw)
	return encoded, nil
}
|
|
|
|
func encodeOIDCAuthorizationState(state *oidcAuthorizationState) (string, error) {
|
|
payload, err := json.Marshal(state)
|
|
if err != nil {
|
|
return "", err
|
|
}
|
|
return base64.RawURLEncoding.EncodeToString(payload), nil
|
|
}
|
|
|
|
// decodeJWTClaims returns the claims held in the payload segment of a compact
// JWT. It does NOT verify the signature: it only base64url-decodes (unpadded)
// the text between the first and second "." and parses it as JSON.
//
// An error is returned when the token contains no "." at all, when the
// payload is not valid unpadded base64url, or when it is not valid JSON for a
// claims object.
func decodeJWTClaims(token string) (map[string]interface{}, error) {
	_, afterHeader, found := strings.Cut(token, ".")
	if !found {
		return nil, errors.New("invalid JWT format")
	}
	// The payload ends at the next dot, or at the end of the token if there is none.
	segment, _, _ := strings.Cut(afterHeader, ".")

	decoded, err := base64.RawURLEncoding.DecodeString(segment)
	if err != nil {
		return nil, err
	}

	var claims map[string]interface{}
	if err = json.Unmarshal(decoded, &claims); err != nil {
		return nil, err
	}
	return claims, nil
}
|
|
|
|
// extractClaimAsString returns the claim stored under key as a trimmed string.
//
// key is trimmed before the lookup. The result is "" when key is blank, when
// the claim is absent, or when its value is nil. String values are returned
// trimmed. Numeric values decoded from JSON arrive as float64; integral ones
// are rendered as plain decimal digits (for example 1234567, not the
// 1.234567e+06 that fmt.Sprint would produce) so numeric identifiers stay
// usable. All other values use their default fmt formatting.
func extractClaimAsString(claims map[string]interface{}, key string) string {
	key = strings.TrimSpace(key)
	if key == "" {
		return ""
	}

	value, ok := claims[key]
	if !ok || value == nil {
		return ""
	}

	switch v := value.(type) {
	case string:
		return strings.TrimSpace(v)
	case float64:
		// float64 represents every integer up to 2^53 exactly; within that
		// range an integral value can be printed without an exponent.
		const maxExactInt = 1 << 53
		if v >= -maxExactInt && v <= maxExactInt && v == float64(int64(v)) {
			return fmt.Sprintf("%d", int64(v))
		}
		return strings.TrimSpace(fmt.Sprint(v))
	default:
		return strings.TrimSpace(fmt.Sprint(v))
	}
}
|
|
|
|
// sanitizeUsernameCandidate normalizes an arbitrary string into a username.
//
// The input is lower-cased and trimmed. The characters a-z, 0-9, '_' and '.'
// are kept; every run of any other characters becomes a single '-'. Leading
// and trailing '-', '.' and '_' are stripped. A result longer than 50 bytes
// is cut to 50 and stripped again. The result may be empty.
func sanitizeUsernameCandidate(value string) string {
	const (
		edgeChars = "-._"
		maxLength = 50
	)

	lowered := strings.TrimSpace(strings.ToLower(value))
	if lowered == "" {
		return ""
	}

	var out strings.Builder
	prevWasDash := false
	for _, ch := range lowered {
		switch {
		case ch >= 'a' && ch <= 'z', ch >= '0' && ch <= '9', ch == '_', ch == '.':
			out.WriteRune(ch)
			prevWasDash = false
		case !prevWasDash:
			// First disallowed character of a run: emit one separator.
			out.WriteByte('-')
			prevWasDash = true
		}
	}

	cleaned := strings.Trim(out.String(), edgeChars)
	if len(cleaned) > maxLength {
		// Only ASCII was written to the builder, so slicing by bytes is safe.
		cleaned = strings.Trim(cleaned[:maxLength], edgeChars)
	}
	return cleaned
}
|
|
|
|
func isUserLookupNotFound(err error) bool {
|
|
if err == nil {
|
|
return false
|
|
}
|
|
return errors.Is(err, apprepo.ErrUserNotFound) || strings.Contains(strings.ToLower(err.Error()), "user not found")
|
|
}
|