// Source file: WeKnora/internal/application/service/user.go
// Last commit: c1816fe6d6 "add oidc" by Windfarer, 2026-03-30 11:13:44 +08:00
// Listing metadata: 859 lines, 25 KiB, Go

package service
import (
"context"
"crypto/rand"
"encoding/base64"
"encoding/json"
"errors"
"fmt"
"io"
"net/http"
"net/url"
"os"
"strings"
"sync"
"time"
"github.com/golang-jwt/jwt/v5"
"github.com/google/uuid"
"golang.org/x/crypto/bcrypt"
apprepo "github.com/Tencent/WeKnora/internal/application/repository"
"github.com/Tencent/WeKnora/internal/config"
"github.com/Tencent/WeKnora/internal/logger"
"github.com/Tencent/WeKnora/internal/types"
"github.com/Tencent/WeKnora/internal/types/interfaces"
secutils "github.com/Tencent/WeKnora/internal/utils"
)
// oidcAuthorizationState is the payload that encodeOIDCAuthorizationState
// marshals to JSON and base64url-encodes to form the OAuth "state" parameter.
type oidcAuthorizationState struct {
// Nonce is a random value generated for each authorization request.
Nonce string `json:"nonce"`
// RedirectURI is the trimmed redirect URI supplied by the caller; it is omitted from the JSON when empty.
RedirectURI string `json:"redirect_uri,omitempty"`
}
// Process-wide JWT signing secret, resolved lazily exactly once.
var (
	jwtSecretOnce sync.Once
	jwtSecret     string
)

// getJwtSecret returns the secret used to sign and verify JWTs.
//
// On first use it reads JWT_SECRET from the environment (surrounding
// whitespace removed). When that is empty, 32 random bytes are drawn from
// crypto/rand and base64-encoded instead. The resolved value is cached for
// the lifetime of the process. It panics if random generation fails.
func getJwtSecret() string {
	jwtSecretOnce.Do(func() {
		secret := strings.TrimSpace(os.Getenv("JWT_SECRET"))
		if secret == "" {
			var seed [32]byte
			if _, err := rand.Read(seed[:]); err != nil {
				panic(fmt.Sprintf("failed to generate JWT secret: %v", err))
			}
			secret = base64.StdEncoding.EncodeToString(seed[:])
		}
		jwtSecret = secret
	})
	return jwtSecret
}
// userService implements the UserService interface.
// NewUserService constructs it and returns it as interfaces.UserService.
type userService struct {
// userRepo is used to look up, create, update, delete and search users.
userRepo interfaces.UserRepository
// tokenRepo stores issued tokens and is consulted for their revocation flag.
tokenRepo interfaces.AuthTokenRepository
// tenantService creates the default tenant on registration and loads tenants on login.
tenantService interfaces.TenantService
// config is read for the OIDC settings; getOIDCConfig tolerates it being nil.
config *config.Config
}
// NewUserService wires the configuration, repositories and tenant service
// into a userService and returns it behind the UserService interface.
func NewUserService(
	configInfo *config.Config,
	userRepo interfaces.UserRepository,
	tokenRepo interfaces.AuthTokenRepository,
	tenantService interfaces.TenantService,
) interfaces.UserService {
	svc := new(userService)
	svc.config = configInfo
	svc.userRepo = userRepo
	svc.tokenRepo = tokenRepo
	svc.tenantService = tenantService
	return svc
}
// Register creates a new user account together with its default tenant.
//
// The request must carry a non-empty username, email and password, and
// neither the email nor the username may already be taken. The password is
// stored as a bcrypt hash. On success the persisted user is returned; on any
// failure a nil user and an error are returned.
func (s *userService) Register(ctx context.Context, req *types.RegisterRequest) (*types.User, error) {
	logger.Info(ctx, "Start user registration")

	// Validate input
	if req == nil || req.Username == "" || req.Email == "" || req.Password == "" {
		return nil, errors.New("username, email and password are required")
	}

	// Check if user already exists. A lookup failure other than "not found"
	// aborts the registration: continuing would create a tenant and a user
	// without knowing whether the account is a duplicate.
	existingUser, err := s.userRepo.GetUserByEmail(ctx, req.Email)
	if err != nil && !isUserLookupNotFound(err) {
		logger.Errorf(ctx, "Failed to check existing user by email: %v", err)
		return nil, errors.New("failed to check existing user")
	}
	if existingUser != nil {
		return nil, errors.New("user with this email already exists")
	}

	existingUser, err = s.userRepo.GetUserByUsername(ctx, req.Username)
	if err != nil && !isUserLookupNotFound(err) {
		logger.Errorf(ctx, "Failed to check existing user by username: %v", err)
		return nil, errors.New("failed to check existing user")
	}
	if existingUser != nil {
		return nil, errors.New("user with this username already exists")
	}

	// Hash password
	hashedPassword, err := bcrypt.GenerateFromPassword([]byte(req.Password), bcrypt.DefaultCost)
	if err != nil {
		logger.Errorf(ctx, "Failed to hash password: %v", err)
		return nil, errors.New("failed to process password")
	}

	// Create default tenant for the user
	// Note: RetrieverEngines is left empty - system will use defaults from RETRIEVE_DRIVER env
	tenant := &types.Tenant{
		Name:        fmt.Sprintf("%s's Workspace", secutils.SanitizeForLog(req.Username)),
		Description: "Default workspace",
		Status:      "active",
	}
	createdTenant, err := s.tenantService.CreateTenant(ctx, tenant)
	if err != nil {
		logger.Errorf(ctx, "Failed to create tenant: %v", err)
		return nil, errors.New("failed to create workspace")
	}

	// Create user. A single timestamp keeps CreatedAt and UpdatedAt equal.
	now := time.Now()
	user := &types.User{
		ID:           uuid.New().String(),
		Username:     req.Username,
		Email:        req.Email,
		PasswordHash: string(hashedPassword),
		TenantID:     createdTenant.ID,
		IsActive:     true,
		CreatedAt:    now,
		UpdatedAt:    now,
	}
	// NOTE(review): if this insert fails, the tenant created above is left
	// behind; cleaning it up needs a tenant-deletion call not visible here.
	if err = s.userRepo.CreateUser(ctx, user); err != nil {
		logger.Errorf(ctx, "Failed to create user: %v", err)
		return nil, errors.New("failed to create user")
	}

	logger.Info(ctx, "User registered successfully")
	return user, nil
}
// Login authenticates a user by email and password.
//
// Authentication failures (unknown email, lookup error, disabled account,
// wrong password, token generation failure) are reported through a
// LoginResponse with Success set to false and a nil error. On success the
// response carries the user, the tenant returned by the tenant lookup, and a
// freshly issued access/refresh token pair.
func (s *userService) Login(ctx context.Context, req *types.LoginRequest) (*types.LoginResponse, error) {
	logger.Info(ctx, "Start user login")

	// reject builds the unsuccessful response shared by every failure path.
	reject := func(message string) (*types.LoginResponse, error) {
		return &types.LoginResponse{Success: false, Message: message}, nil
	}
	const badCredentials = "Invalid email or password"

	// Resolve the account and make sure it may log in.
	account, lookupErr := s.userRepo.GetUserByEmail(ctx, req.Email)
	switch {
	case lookupErr != nil:
		logger.Errorf(ctx, "Failed to get user by email: %v", lookupErr)
		return reject(badCredentials)
	case account == nil:
		logger.Warn(ctx, "User not found for email")
		return reject(badCredentials)
	case !account.IsActive:
		logger.Warn(ctx, "User account is disabled")
		return reject("Account is disabled")
	}

	// Compare the supplied password against the stored bcrypt hash.
	if mismatch := bcrypt.CompareHashAndPassword([]byte(account.PasswordHash), []byte(req.Password)); mismatch != nil {
		logger.Warn(ctx, "Password verification failed")
		return reject(badCredentials)
	}
	logger.Info(ctx, "Password verification successful")

	// Issue the token pair.
	logger.Info(ctx, "Generating tokens")
	access, refresh, tokenErr := s.GenerateTokens(ctx, account)
	if tokenErr != nil {
		logger.Errorf(ctx, "Failed to generate tokens: %v", tokenErr)
		return reject("Login failed")
	}
	logger.Info(ctx, "Tokens generated successfully")

	// A tenant lookup failure is logged but does not fail the login.
	tenant, tenantErr := s.tenantService.GetTenantByID(ctx, account.TenantID)
	if tenantErr == nil {
		logger.Info(ctx, "Tenant information retrieved successfully")
	} else {
		logger.Warn(ctx, "Failed to get tenant info")
	}

	logger.Info(ctx, "User logged in successfully")
	response := &types.LoginResponse{
		Success:      true,
		Message:      "Login successful",
		User:         account,
		Tenant:       tenant,
		Token:        access,
		RefreshToken: refresh,
	}
	return response, nil
}
// GetOIDCAuthorizationURL builds the OIDC authorization URL.
//
// It resolves the OIDC configuration, generates a random nonce, encodes the
// nonce and redirect URI into the "state" parameter, and appends the
// authorization-code request parameters to the provider's authorization
// endpoint. The redirect URI is trimmed once, so the value sent to the
// provider and the value embedded in the state are identical.
//
// It returns an error when OIDC is disabled or misconfigured, when
// redirectURI is blank, or when the state cannot be generated.
func (s *userService) GetOIDCAuthorizationURL(ctx context.Context, redirectURI string) (*types.OIDCAuthURLResponse, error) {
	cfg, err := s.getOIDCConfig(ctx)
	if err != nil {
		return nil, err
	}

	redirectURI = strings.TrimSpace(redirectURI)
	if redirectURI == "" {
		return nil, errors.New("redirect_uri is required")
	}

	nonce, err := generateRandomString(24)
	if err != nil {
		return nil, fmt.Errorf("failed to generate state: %w", err)
	}
	state, err := encodeOIDCAuthorizationState(&oidcAuthorizationState{
		Nonce:       nonce,
		RedirectURI: redirectURI,
	})
	if err != nil {
		return nil, fmt.Errorf("failed to encode OIDC state: %w", err)
	}

	query := url.Values{}
	query.Set("response_type", "code")
	query.Set("client_id", cfg.ClientID)
	query.Set("redirect_uri", redirectURI)
	query.Set("scope", strings.Join(cfg.Scopes, " "))
	query.Set("state", state)

	// The configured endpoint may already carry query parameters.
	separator := "?"
	if strings.Contains(cfg.AuthorizationEndpoint, "?") {
		separator = "&"
	}
	authURL := cfg.AuthorizationEndpoint + separator + query.Encode()

	return &types.OIDCAuthURLResponse{
		Success:             true,
		ProviderDisplayName: cfg.ProviderDisplayName,
		AuthorizationURL:    authURL,
		State:               state,
	}, nil
}
// LoginWithOIDC exchanges code for tokens, loads user info, provisions user if needed, and returns local login tokens.
//
// code and redirectURI are trimmed before use, so the values forwarded to
// the token endpoint are the same ones that passed validation. A local user
// is matched by the email reported by the provider; when none exists one is
// provisioned. A disabled account yields a response with Success set to
// false and a nil error; every other failure is returned as an error.
func (s *userService) LoginWithOIDC(ctx context.Context, code, redirectURI string) (*types.OIDCCallbackResponse, error) {
	code = strings.TrimSpace(code)
	if code == "" {
		return nil, errors.New("code is required")
	}
	redirectURI = strings.TrimSpace(redirectURI)
	if redirectURI == "" {
		return nil, errors.New("redirect_uri is required")
	}

	cfg, err := s.getOIDCConfig(ctx)
	if err != nil {
		return nil, err
	}
	tokenResp, err := s.exchangeOIDCCode(ctx, cfg, code, redirectURI)
	if err != nil {
		return nil, err
	}
	userInfo, err := s.resolveOIDCUserInfo(ctx, cfg, tokenResp)
	if err != nil {
		return nil, err
	}
	if strings.TrimSpace(userInfo.Email) == "" {
		return nil, errors.New("OIDC provider did not return email")
	}

	// NOTE(review): accounts are linked purely by email and the provider's
	// "email_verified" claim is not consulted. Confirm the configured provider
	// only releases verified addresses before relying on this for linking.
	user, err := s.userRepo.GetUserByEmail(ctx, userInfo.Email)
	if err != nil && !isUserLookupNotFound(err) {
		return nil, fmt.Errorf("failed to query user by email: %w", err)
	}

	isNewUser := false
	if user == nil || isUserLookupNotFound(err) {
		user, err = s.provisionOIDCUser(ctx, userInfo)
		if err != nil {
			return nil, err
		}
		isNewUser = true
	}

	if !user.IsActive {
		return &types.OIDCCallbackResponse{Success: false, Message: "Account is disabled"}, nil
	}

	accessToken, refreshToken, err := s.GenerateTokens(ctx, user)
	if err != nil {
		return nil, fmt.Errorf("failed to generate local tokens: %w", err)
	}

	// A tenant lookup failure is logged but does not fail the login.
	tenant, err := s.tenantService.GetTenantByID(ctx, user.TenantID)
	if err != nil {
		logger.Warnf(ctx, "Failed to get tenant info after OIDC login: %v", err)
	}

	return &types.OIDCCallbackResponse{
		Success:      true,
		Message:      "Login successful",
		User:         user,
		Tenant:       tenant,
		Token:        accessToken,
		RefreshToken: refreshToken,
		IsNewUser:    isNewUser,
	}, nil
}
// GetUserByID looks up a user by ID, delegating to the user repository.
func (s *userService) GetUserByID(ctx context.Context, id string) (*types.User, error) {
	found, err := s.userRepo.GetUserByID(ctx, id)
	return found, err
}
// GetUserByEmail looks up a user by email address, delegating to the user repository.
func (s *userService) GetUserByEmail(ctx context.Context, email string) (*types.User, error) {
	found, err := s.userRepo.GetUserByEmail(ctx, email)
	return found, err
}
// GetUserByUsername looks up a user by username, delegating to the user repository.
func (s *userService) GetUserByUsername(ctx context.Context, username string) (*types.User, error) {
	found, err := s.userRepo.GetUserByUsername(ctx, username)
	return found, err
}
// GetUserByTenantID returns the first user (owner) of a tenant, delegating to the user repository.
func (s *userService) GetUserByTenantID(ctx context.Context, tenantID uint64) (*types.User, error) {
	owner, err := s.userRepo.GetUserByTenantID(ctx, tenantID)
	return owner, err
}
// UpdateUser stamps the user with the current time as UpdatedAt and
// persists it through the user repository.
func (s *userService) UpdateUser(ctx context.Context, user *types.User) error {
	stamp := time.Now()
	user.UpdatedAt = stamp
	err := s.userRepo.UpdateUser(ctx, user)
	return err
}
// DeleteUser removes the user with the given ID through the user repository.
func (s *userService) DeleteUser(ctx context.Context, id string) error {
	err := s.userRepo.DeleteUser(ctx, id)
	return err
}
// ChangePassword replaces a user's password after verifying the old one.
//
// It returns an error when newPassword is empty, when the user cannot be
// loaded, when oldPassword does not match the stored bcrypt hash, or when
// hashing or persisting the new password fails.
func (s *userService) ChangePassword(ctx context.Context, userID string, oldPassword, newPassword string) error {
	// Register refuses empty passwords; a password change must not allow one either.
	if newPassword == "" {
		return errors.New("new password is required")
	}

	user, err := s.userRepo.GetUserByID(ctx, userID)
	if err != nil {
		return err
	}
	// Other lookups in this file treat a (nil, nil) result as possible.
	if user == nil {
		return errors.New("user not found")
	}

	// Verify old password
	if err = bcrypt.CompareHashAndPassword([]byte(user.PasswordHash), []byte(oldPassword)); err != nil {
		return errors.New("invalid old password")
	}

	// Hash new password
	hashedPassword, err := bcrypt.GenerateFromPassword([]byte(newPassword), bcrypt.DefaultCost)
	if err != nil {
		return err
	}

	user.PasswordHash = string(hashedPassword)
	user.UpdatedAt = time.Now()
	return s.userRepo.UpdateUser(ctx, user)
}
// ValidatePassword checks password against the stored bcrypt hash of the
// user identified by userID.
//
// It returns nil when the password matches, the repository error when the
// user cannot be loaded, an error when no user is returned, and the bcrypt
// comparison error otherwise.
func (s *userService) ValidatePassword(ctx context.Context, userID string, password string) error {
	user, err := s.userRepo.GetUserByID(ctx, userID)
	if err != nil {
		return err
	}
	// Other lookups in this file treat a (nil, nil) result as possible.
	if user == nil {
		return errors.New("user not found")
	}
	return bcrypt.CompareHashAndPassword([]byte(user.PasswordHash), []byte(password))
}
// GenerateTokens generates access and refresh tokens for user.
//
// The access token expires after 24 hours and the refresh token after 7
// days. Both are HS256 JWTs signed with the process JWT secret. Each carries
// a unique "jti" claim (equal to its stored record ID), so two tokens issued
// for the same user within the same second are still distinct strings.
//
// Both tokens are persisted through the token repository. ValidateToken and
// RefreshToken reject tokens that have no stored record, so a persistence
// failure is returned as an error rather than handing out unusable tokens.
func (s *userService) GenerateTokens(
	ctx context.Context,
	user *types.User,
) (accessToken, refreshToken string, err error) {
	if user == nil {
		return "", "", errors.New("user is required to generate tokens")
	}

	const (
		accessTTL  = 24 * time.Hour
		refreshTTL = 7 * 24 * time.Hour
	)
	// One timestamp keeps the JWT claims and the stored records consistent.
	now := time.Now()
	signingKey := []byte(getJwtSecret())
	accessID := uuid.New().String()
	refreshID := uuid.New().String()

	// Generate access token (expires in 24 hours)
	accessClaims := jwt.MapClaims{
		"jti":       accessID,
		"user_id":   user.ID,
		"email":     user.Email,
		"tenant_id": user.TenantID,
		"exp":       now.Add(accessTTL).Unix(),
		"iat":       now.Unix(),
		"type":      "access",
	}
	accessToken, err = jwt.NewWithClaims(jwt.SigningMethodHS256, accessClaims).SignedString(signingKey)
	if err != nil {
		return "", "", fmt.Errorf("failed to sign access token: %w", err)
	}

	// Generate refresh token (expires in 7 days)
	refreshClaims := jwt.MapClaims{
		"jti":     refreshID,
		"user_id": user.ID,
		"exp":     now.Add(refreshTTL).Unix(),
		"iat":     now.Unix(),
		"type":    "refresh",
	}
	refreshToken, err = jwt.NewWithClaims(jwt.SigningMethodHS256, refreshClaims).SignedString(signingKey)
	if err != nil {
		return "", "", fmt.Errorf("failed to sign refresh token: %w", err)
	}

	// Store tokens in database
	records := []*types.AuthToken{
		{
			ID:        accessID,
			UserID:    user.ID,
			Token:     accessToken,
			TokenType: "access_token",
			ExpiresAt: now.Add(accessTTL),
			CreatedAt: now,
			UpdatedAt: now,
		},
		{
			ID:        refreshID,
			UserID:    user.ID,
			Token:     refreshToken,
			TokenType: "refresh_token",
			ExpiresAt: now.Add(refreshTTL),
			CreatedAt: now,
			UpdatedAt: now,
		},
	}
	for _, record := range records {
		if err = s.tokenRepo.CreateToken(ctx, record); err != nil {
			return "", "", fmt.Errorf("failed to store %s: %w", record.TokenType, err)
		}
	}

	return accessToken, refreshToken, nil
}
// ValidateToken validates an access token and returns the user it belongs to.
//
// The token must be an HMAC-signed JWT that verifies against the process
// secret, carry type "access" and a string user_id claim, and have a stored
// record that is not revoked. Refresh tokens are rejected here: they are
// signed with the same secret but may only be used with RefreshToken.
func (s *userService) ValidateToken(ctx context.Context, tokenString string) (*types.User, error) {
	token, err := jwt.Parse(tokenString, func(token *jwt.Token) (interface{}, error) {
		// Only HMAC algorithms are acceptable for this shared-secret scheme.
		if _, ok := token.Method.(*jwt.SigningMethodHMAC); !ok {
			return nil, fmt.Errorf("unexpected signing method: %v", token.Header["alg"])
		}
		return []byte(getJwtSecret()), nil
	})
	if err != nil || !token.Valid {
		return nil, errors.New("invalid token")
	}

	claims, ok := token.Claims.(jwt.MapClaims)
	if !ok {
		return nil, errors.New("invalid token claims")
	}

	// GenerateTokens marks access tokens with type "access"; anything else
	// (notably a refresh token) must not authenticate a request.
	if tokenType, _ := claims["type"].(string); tokenType != "access" {
		return nil, errors.New("not an access token")
	}

	userID, ok := claims["user_id"].(string)
	if !ok {
		return nil, errors.New("invalid user ID in token")
	}

	// Check if token is revoked. A missing record or a lookup failure is
	// treated the same as a revoked token.
	tokenRecord, err := s.tokenRepo.GetTokenByValue(ctx, tokenString)
	if err != nil || tokenRecord == nil || tokenRecord.IsRevoked {
		return nil, errors.New("token is revoked")
	}

	return s.userRepo.GetUserByID(ctx, userID)
}
// RefreshToken exchanges a valid refresh token for a new access/refresh pair.
//
// The supplied token must be an HMAC-signed JWT of type "refresh" with a
// stored record that is not revoked, and its user must exist and be active.
// The old refresh token is revoked before new tokens are issued; if the
// revocation cannot be persisted, no new tokens are issued, so a refresh
// token cannot be replayed.
func (s *userService) RefreshToken(
	ctx context.Context,
	refreshTokenString string,
) (accessToken, newRefreshToken string, err error) {
	token, err := jwt.Parse(refreshTokenString, func(token *jwt.Token) (interface{}, error) {
		// Only HMAC algorithms are acceptable for this shared-secret scheme.
		if _, ok := token.Method.(*jwt.SigningMethodHMAC); !ok {
			return nil, fmt.Errorf("unexpected signing method: %v", token.Header["alg"])
		}
		return []byte(getJwtSecret()), nil
	})
	if err != nil || !token.Valid {
		return "", "", errors.New("invalid refresh token")
	}

	claims, ok := token.Claims.(jwt.MapClaims)
	if !ok {
		return "", "", errors.New("invalid token claims")
	}
	tokenType, ok := claims["type"].(string)
	if !ok || tokenType != "refresh" {
		return "", "", errors.New("not a refresh token")
	}
	userID, ok := claims["user_id"].(string)
	if !ok {
		return "", "", errors.New("invalid user ID in token")
	}

	// Check if token is revoked. A missing record or a lookup failure is
	// treated the same as a revoked token.
	tokenRecord, err := s.tokenRepo.GetTokenByValue(ctx, refreshTokenString)
	if err != nil || tokenRecord == nil || tokenRecord.IsRevoked {
		return "", "", errors.New("refresh token is revoked")
	}

	// Get user
	user, err := s.userRepo.GetUserByID(ctx, userID)
	if err != nil {
		return "", "", err
	}
	if user == nil {
		return "", "", errors.New("user not found")
	}
	// Login refuses disabled accounts; refreshing must not keep them signed in.
	if !user.IsActive {
		return "", "", errors.New("account is disabled")
	}

	// Revoke old refresh token
	tokenRecord.IsRevoked = true
	tokenRecord.UpdatedAt = time.Now()
	if err = s.tokenRepo.UpdateToken(ctx, tokenRecord); err != nil {
		return "", "", fmt.Errorf("failed to revoke old refresh token: %w", err)
	}

	// Generate new tokens
	return s.GenerateTokens(ctx, user)
}
// RevokeToken marks the stored record of tokenString as revoked.
//
// It returns the repository error when the lookup or the update fails, and
// an error when no record exists for the token.
func (s *userService) RevokeToken(ctx context.Context, tokenString string) error {
	tokenRecord, err := s.tokenRepo.GetTokenByValue(ctx, tokenString)
	if err != nil {
		return err
	}
	// ValidateToken and RefreshToken both allow for a nil record without an error.
	if tokenRecord == nil {
		return errors.New("token not found")
	}

	tokenRecord.IsRevoked = true
	tokenRecord.UpdatedAt = time.Now()
	return s.tokenRepo.UpdateToken(ctx, tokenRecord)
}
// GetCurrentUser returns the *types.User stored in ctx under
// types.UserContextKey.
//
// It returns an error when the key is absent, holds a value of another
// type, or holds a nil user pointer.
func (s *userService) GetCurrentUser(ctx context.Context) (*types.User, error) {
	user, ok := ctx.Value(types.UserContextKey).(*types.User)
	// A typed nil pointer passes the type assertion, so check it explicitly.
	if !ok || user == nil {
		return nil, errors.New("user not found in context")
	}
	return user, nil
}
// SearchUsers forwards a non-empty query and the limit to the user
// repository's search. An empty query short-circuits to an empty, non-nil
// slice without touching the repository.
func (s *userService) SearchUsers(ctx context.Context, query string, limit int) ([]*types.User, error) {
	if query != "" {
		matches, err := s.userRepo.SearchUsers(ctx, query, limit)
		return matches, err
	}
	return make([]*types.User, 0), nil
}
// oidcDiscoveryDocument holds the fields read from the provider's discovery
// document; populateOIDCEndpoints decodes the discovery response into it and
// uses the values to fill endpoints that are not configured explicitly.
type oidcDiscoveryDocument struct {
AuthorizationEndpoint string `json:"authorization_endpoint"`
TokenEndpoint string `json:"token_endpoint"`
UserInfoEndpoint string `json:"userinfo_endpoint"`
}
// oidcTokenResponse holds the fields decoded from the token endpoint's JSON
// response in exchangeOIDCCode.
type oidcTokenResponse struct {
// AccessToken is sent as the Bearer credential when fetching userinfo.
AccessToken string `json:"access_token"`
// IDToken is a JWT whose payload is decoded for user claims.
IDToken string `json:"id_token"`
// TokenType is decoded but not read by the code visible in this file.
TokenType string `json:"token_type"`
}
// getOIDCConfig returns a per-call copy of the OIDC configuration with
// defaults applied and endpoints resolved.
//
// It returns an error when OIDC login is not enabled or when the endpoints
// cannot be determined. The claim mapping is copied before defaults are
// applied, so the shared application configuration is never modified. A
// missing mapping defaults to username "name" and email "email"; a mapping
// with a blank email key also defaults to "email", because login cannot
// proceed without an email claim.
func (s *userService) getOIDCConfig(ctx context.Context) (*config.OIDCAuthConfig, error) {
	if s.config == nil || s.config.OIDCAuth == nil || !s.config.OIDCAuth.Enable {
		return nil, errors.New("OIDC login is disabled")
	}

	cfg := *s.config.OIDCAuth

	mapping := config.OIDCUserInfoMapping{Username: "name", Email: "email"}
	if cfg.UserInfoMapping != nil {
		mapping = *cfg.UserInfoMapping
		// A blank username key is left alone: resolveOIDCUserInfo has its own
		// fallbacks for the username, but none for the email.
		if strings.TrimSpace(mapping.Email) == "" {
			mapping.Email = "email"
		}
	}
	cfg.UserInfoMapping = &mapping

	if err := s.populateOIDCEndpoints(ctx, &cfg); err != nil {
		return nil, err
	}
	return &cfg, nil
}
// populateOIDCEndpoints makes sure cfg has an authorization and a token
// endpoint, loading the provider's discovery document when either is missing.
//
// Configured endpoint values are trimmed first, so a whitespace-only value
// counts as missing everywhere. Explicitly configured endpoints always win
// over discovered ones. The discovery request is bounded by a 10 second
// timeout and the decoded response by 1 MiB. cfg is modified in place.
func (s *userService) populateOIDCEndpoints(ctx context.Context, cfg *config.OIDCAuthConfig) error {
	cfg.AuthorizationEndpoint = strings.TrimSpace(cfg.AuthorizationEndpoint)
	cfg.TokenEndpoint = strings.TrimSpace(cfg.TokenEndpoint)
	cfg.UserInfoEndpoint = strings.TrimSpace(cfg.UserInfoEndpoint)
	if cfg.AuthorizationEndpoint != "" && cfg.TokenEndpoint != "" {
		return nil
	}

	discoveryURL := strings.TrimSpace(cfg.DiscoveryURL)
	if discoveryURL == "" {
		return errors.New("OIDC discovery_url or explicit endpoints are required")
	}

	const (
		discoveryTimeout  = 10 * time.Second
		maxDiscoveryBytes = 1 << 20
	)
	// http.DefaultClient has no timeout of its own; bound the call here.
	ctx, cancel := context.WithTimeout(ctx, discoveryTimeout)
	defer cancel()

	req, err := http.NewRequestWithContext(ctx, http.MethodGet, discoveryURL, nil)
	if err != nil {
		return fmt.Errorf("failed to create OIDC discovery request: %w", err)
	}
	req.Header.Set("Accept", "application/json")

	resp, err := http.DefaultClient.Do(req)
	if err != nil {
		return fmt.Errorf("failed to load OIDC discovery document: %w", err)
	}
	defer resp.Body.Close()

	if resp.StatusCode < 200 || resp.StatusCode >= 300 {
		body, _ := io.ReadAll(io.LimitReader(resp.Body, 2048))
		return fmt.Errorf("OIDC discovery request failed: status=%d body=%s", resp.StatusCode, strings.TrimSpace(string(body)))
	}

	var doc oidcDiscoveryDocument
	if err := json.NewDecoder(io.LimitReader(resp.Body, maxDiscoveryBytes)).Decode(&doc); err != nil {
		return fmt.Errorf("failed to decode OIDC discovery document: %w", err)
	}

	if cfg.AuthorizationEndpoint == "" {
		cfg.AuthorizationEndpoint = strings.TrimSpace(doc.AuthorizationEndpoint)
	}
	if cfg.TokenEndpoint == "" {
		cfg.TokenEndpoint = strings.TrimSpace(doc.TokenEndpoint)
	}
	if cfg.UserInfoEndpoint == "" {
		cfg.UserInfoEndpoint = strings.TrimSpace(doc.UserInfoEndpoint)
	}
	if cfg.AuthorizationEndpoint == "" || cfg.TokenEndpoint == "" {
		return errors.New("OIDC discovery document missing required endpoints")
	}
	return nil
}
// exchangeOIDCCode redeems an authorization code at the provider's token
// endpoint and returns the decoded token response.
//
// The client credentials are sent in the form body together with the code
// and redirect URI. The request is bounded by a 10 second timeout and the
// decoded response by 1 MiB. It returns an error on transport failure, a
// non-2xx status, an undecodable body, or a response that carries neither an
// access token nor an ID token.
func (s *userService) exchangeOIDCCode(ctx context.Context, cfg *config.OIDCAuthConfig, code, redirectURI string) (*oidcTokenResponse, error) {
	const (
		exchangeTimeout  = 10 * time.Second
		maxResponseBytes = 1 << 20
	)
	// http.DefaultClient has no timeout of its own; bound the call here.
	ctx, cancel := context.WithTimeout(ctx, exchangeTimeout)
	defer cancel()

	form := url.Values{}
	form.Set("grant_type", "authorization_code")
	form.Set("code", code)
	form.Set("redirect_uri", redirectURI)
	form.Set("client_id", cfg.ClientID)
	form.Set("client_secret", cfg.ClientSecret)

	req, err := http.NewRequestWithContext(ctx, http.MethodPost, cfg.TokenEndpoint, strings.NewReader(form.Encode()))
	if err != nil {
		return nil, fmt.Errorf("failed to create OIDC token request: %w", err)
	}
	req.Header.Set("Content-Type", "application/x-www-form-urlencoded")
	req.Header.Set("Accept", "application/json")

	resp, err := http.DefaultClient.Do(req)
	if err != nil {
		return nil, fmt.Errorf("failed to exchange OIDC code: %w", err)
	}
	defer resp.Body.Close()

	if resp.StatusCode < 200 || resp.StatusCode >= 300 {
		body, _ := io.ReadAll(io.LimitReader(resp.Body, 2048))
		return nil, fmt.Errorf("OIDC token exchange failed: status=%d body=%s", resp.StatusCode, strings.TrimSpace(string(body)))
	}

	var tokenResp oidcTokenResponse
	if err := json.NewDecoder(io.LimitReader(resp.Body, maxResponseBytes)).Decode(&tokenResp); err != nil {
		return nil, fmt.Errorf("failed to decode OIDC token response: %w", err)
	}
	if strings.TrimSpace(tokenResp.AccessToken) == "" && strings.TrimSpace(tokenResp.IDToken) == "" {
		return nil, errors.New("OIDC token response missing access_token and id_token")
	}
	return &tokenResp, nil
}
// resolveOIDCUserInfo assembles the user's identity from the ID token claims
// and, when available, the userinfo endpoint.
//
// Claims decoded from the ID token form the base. Userinfo claims are merged
// on top and win on conflict, except when both sources carry a string "sub"
// and the two differ: the userinfo response is then ignored. A failure in
// either source is logged and the other source is still used.
//
// Username and email are read through the configured claim mapping
// (defaulting to "name" and "email" when no mapping is set). The username
// falls back to "preferred_username", then "name", then the local part of
// the email. The returned error is always nil.
func (s *userService) resolveOIDCUserInfo(ctx context.Context, cfg *config.OIDCAuthConfig, tokenResp *oidcTokenResponse) (*types.OIDCUserInfo, error) {
	claims := map[string]interface{}{}
	idTokenSubject := ""

	if strings.TrimSpace(tokenResp.IDToken) != "" {
		idTokenClaims, err := decodeJWTClaims(tokenResp.IDToken)
		if err != nil {
			logger.Warnf(ctx, "Failed to decode OIDC id_token claims: %v", err)
		} else {
			idTokenSubject, _ = idTokenClaims["sub"].(string)
			for k, v := range idTokenClaims {
				claims[k] = v
			}
		}
	}

	if strings.TrimSpace(cfg.UserInfoEndpoint) != "" && strings.TrimSpace(tokenResp.AccessToken) != "" {
		userInfoClaims, err := s.fetchOIDCUserInfo(ctx, cfg.UserInfoEndpoint, tokenResp.AccessToken)
		userInfoSubject, _ := userInfoClaims["sub"].(string)
		switch {
		case err != nil:
			logger.Warnf(ctx, "Failed to fetch OIDC userinfo, fallback to id_token claims: %v", err)
		case idTokenSubject != "" && userInfoSubject != "" && userInfoSubject != idTokenSubject:
			// A userinfo response for a different subject must not be used.
			logger.Warnf(ctx, "Ignoring OIDC userinfo response: sub does not match id_token")
		default:
			for k, v := range userInfoClaims {
				claims[k] = v
			}
		}
	}

	info := &types.OIDCUserInfo{Claims: claims}
	if sub, _ := claims["sub"].(string); sub != "" {
		info.Subject = sub
	}

	// Guard against a config that reaches this point without a mapping.
	usernameKey, emailKey := "name", "email"
	if cfg.UserInfoMapping != nil {
		usernameKey, emailKey = cfg.UserInfoMapping.Username, cfg.UserInfoMapping.Email
	}
	info.Username = extractClaimAsString(claims, usernameKey)
	info.Email = extractClaimAsString(claims, emailKey)

	if info.Username == "" {
		info.Username = extractClaimAsString(claims, "preferred_username")
	}
	if info.Username == "" {
		info.Username = extractClaimAsString(claims, "name")
	}
	if info.Username == "" && info.Email != "" {
		info.Username = strings.Split(info.Email, "@")[0]
	}
	return info, nil
}
// fetchOIDCUserInfo calls the provider's userinfo endpoint with the access
// token as a Bearer credential and returns the decoded JSON claims.
//
// The request is bounded by a 10 second timeout and the decoded response by
// 1 MiB. It returns an error on transport failure, a non-2xx status, or an
// undecodable body.
func (s *userService) fetchOIDCUserInfo(ctx context.Context, endpoint, accessToken string) (map[string]interface{}, error) {
	const (
		userInfoTimeout  = 10 * time.Second
		maxResponseBytes = 1 << 20
	)
	// http.DefaultClient has no timeout of its own; bound the call here.
	ctx, cancel := context.WithTimeout(ctx, userInfoTimeout)
	defer cancel()

	req, err := http.NewRequestWithContext(ctx, http.MethodGet, endpoint, nil)
	if err != nil {
		return nil, fmt.Errorf("failed to create userinfo request: %w", err)
	}
	req.Header.Set("Authorization", "Bearer "+accessToken)
	req.Header.Set("Accept", "application/json")

	resp, err := http.DefaultClient.Do(req)
	if err != nil {
		return nil, fmt.Errorf("failed to call userinfo endpoint: %w", err)
	}
	defer resp.Body.Close()

	if resp.StatusCode < 200 || resp.StatusCode >= 300 {
		body, _ := io.ReadAll(io.LimitReader(resp.Body, 2048))
		return nil, fmt.Errorf("userinfo request failed: status=%d body=%s", resp.StatusCode, strings.TrimSpace(string(body)))
	}

	var claims map[string]interface{}
	if err := json.NewDecoder(io.LimitReader(resp.Body, maxResponseBytes)).Decode(&claims); err != nil {
		return nil, fmt.Errorf("failed to decode userinfo response: %w", err)
	}
	return claims, nil
}
// provisionOIDCUser registers a local account for an OIDC identity.
//
// It picks a username via generateOIDCUsername, generates a random password
// from 32 random bytes, and delegates to Register with the email from info.
func (s *userService) provisionOIDCUser(ctx context.Context, info *types.OIDCUserInfo) (*types.User, error) {
	chosenName := s.generateOIDCUsername(ctx, info)

	secret, genErr := generateRandomString(32)
	if genErr != nil {
		return nil, fmt.Errorf("failed to generate password for OIDC user: %w", genErr)
	}

	registration := types.RegisterRequest{
		Username: chosenName,
		Email:    info.Email,
		Password: secret,
	}
	created, regErr := s.Register(ctx, &registration)
	if regErr != nil {
		return nil, fmt.Errorf("failed to auto-provision OIDC user: %w", regErr)
	}
	return created, nil
}
// generateOIDCUsername derives a free local username for an OIDC identity.
//
// The base name is the sanitized provider username, else the sanitized local
// part of the email, else "oidc-user". The base itself is tried first, then
// base-1 through base-19. A candidate is accepted when the lookup reports
// "not found" or returns no user without an error. Other lookup errors are
// logged and the next candidate is tried. When all 20 attempts fail, the
// base suffixed with the current Unix time is returned without a lookup.
func (s *userService) generateOIDCUsername(ctx context.Context, info *types.OIDCUserInfo) string {
	base := sanitizeUsernameCandidate(info.Username)
	if base == "" {
		base = sanitizeUsernameCandidate(strings.Split(info.Email, "@")[0])
	}
	if base == "" {
		base = "oidc-user"
	}

	for attempt := 0; attempt < 20; attempt++ {
		candidate := base
		if attempt > 0 {
			candidate = fmt.Sprintf("%s-%d", base, attempt)
		}

		existing, err := s.userRepo.GetUserByUsername(ctx, candidate)
		switch {
		case isUserLookupNotFound(err):
			return candidate
		case err != nil:
			logger.Warnf(ctx, "Failed to check existing OIDC username %q: %v", candidate, err)
		case existing == nil:
			return candidate
		}
	}

	return fmt.Sprintf("%s-%d", base, time.Now().Unix())
}
// generateRandomString reads length bytes from crypto/rand and returns them
// encoded as unpadded, URL-safe base64. The result is therefore longer than
// length characters for any positive length.
func generateRandomString(length int) (string, error) {
	raw := make([]byte, length)
	_, err := rand.Read(raw)
	if err != nil {
		return "", err
	}
	encoded := base64.RawURLEncoding.EncodeToString(raw)
	return encoded, nil
}
// encodeOIDCAuthorizationState marshals state to JSON and returns it as
// unpadded, URL-safe base64, ready to be used as the "state" query parameter.
func encodeOIDCAuthorizationState(state *oidcAuthorizationState) (string, error) {
	raw, marshalErr := json.Marshal(state)
	if marshalErr == nil {
		return base64.RawURLEncoding.EncodeToString(raw), nil
	}
	return "", marshalErr
}
// decodeJWTClaims extracts the claims from a JWT without verifying its
// signature. The second dot-separated segment is decoded as unpadded,
// URL-safe base64 and unmarshalled as a JSON object.
//
// It returns an error when the token has fewer than two segments, or when
// the payload is not valid base64 or not valid JSON.
func decodeJWTClaims(token string) (map[string]interface{}, error) {
	// Only the payload segment matters; anything after it stays unsplit.
	segments := strings.SplitN(token, ".", 3)
	if len(segments) < 2 {
		return nil, errors.New("invalid JWT format")
	}

	raw, decodeErr := base64.RawURLEncoding.DecodeString(segments[1])
	if decodeErr != nil {
		return nil, decodeErr
	}

	var claims map[string]interface{}
	if unmarshalErr := json.Unmarshal(raw, &claims); unmarshalErr != nil {
		return nil, unmarshalErr
	}
	return claims, nil
}
// extractClaimAsString returns the claim stored under key as a trimmed string.
//
// It returns "" when key is blank, when the claim is absent, or when its
// value is nil. String values are trimmed. Whole-number float64 values,
// which is how encoding/json decodes JSON numbers such as numeric user IDs,
// are rendered in plain decimal form. Any other value is rendered with
// fmt.Sprint and trimmed.
func extractClaimAsString(claims map[string]interface{}, key string) string {
	key = strings.TrimSpace(key)
	if key == "" {
		return ""
	}
	value, ok := claims[key]
	if !ok || value == nil {
		return ""
	}
	switch v := value.(type) {
	case string:
		return strings.TrimSpace(v)
	case float64:
		// fmt.Sprint switches to scientific notation at 1e6 ("1.2345678e+07"),
		// which corrupts numeric identifiers. The range check keeps the int64
		// conversion exact and also rejects NaN and infinities.
		if v > -1e15 && v < 1e15 && v == float64(int64(v)) {
			return fmt.Sprintf("%d", int64(v))
		}
		return strings.TrimSpace(fmt.Sprint(v))
	default:
		return strings.TrimSpace(fmt.Sprint(v))
	}
}
// sanitizeUsernameCandidate normalizes a free-form name into a username.
//
// The input is lower-cased and trimmed. Characters other than a-z, 0-9, '_'
// and '.' are replaced by '-', with each run of such characters collapsed to
// a single '-'. Leading and trailing '-', '.' and '_' are removed. The
// result is capped at 50 characters and trimmed again after the cut. It
// returns "" when nothing usable remains.
func sanitizeUsernameCandidate(value string) string {
	const (
		edgeChars = "-._"
		maxLen    = 50
	)

	lowered := strings.TrimSpace(strings.ToLower(value))
	if lowered == "" {
		return ""
	}

	var out strings.Builder
	prevWasDash := false
	for _, ch := range lowered {
		switch {
		case ch >= 'a' && ch <= 'z', ch >= '0' && ch <= '9', ch == '_', ch == '.':
			out.WriteRune(ch)
			prevWasDash = false
		case !prevWasDash:
			out.WriteByte('-')
			prevWasDash = true
		}
	}

	cleaned := strings.Trim(out.String(), edgeChars)
	if len(cleaned) <= maxLen {
		return cleaned
	}
	// Only ASCII is ever written, so slicing by byte cannot split a rune.
	return strings.Trim(cleaned[:maxLen], edgeChars)
}
// isUserLookupNotFound reports whether err means "no such user": either it
// wraps apprepo.ErrUserNotFound, or its message contains "user not found"
// (case-insensitive). A nil error is never a not-found error.
func isUserLookupNotFound(err error) bool {
	switch {
	case err == nil:
		return false
	case errors.Is(err, apprepo.ErrUserNotFound):
		return true
	default:
		message := strings.ToLower(err.Error())
		return strings.Contains(message, "user not found")
	}
}