fix(audit_log): prevent deadlock in Stop() when Start() is never called

Add a 'started' flag to AuditLogRetentionRunner so that Stop() returns immediately when Start() was never invoked. Previously Stop() waited on a doneCh that nobody would ever close, which could deadlock graceful shutdown. Also add tests that verify Stop() returns promptly both when it is called before Start() and when Start() short-circuited because the service is nil.

Refs: #1303
This commit is contained in:
wizardchen
2026-05-17 22:27:00 +08:00
committed by lyingbug
parent 884ac44283
commit 14f62a6e14
6 changed files with 173 additions and 33 deletions

View File

@@ -87,11 +87,26 @@ func (s *auditLogService) LogDenied(
actorUserID, actorRole string,
requiredRole types.TenantRole,
) error {
requestPath := ""
rawPath := ""
requestMethod := ""
// dedupPath keys both the sliding-window dedup AND the persisted
// request_path column. We prefer the route TEMPLATE (gin's
// c.FullPath, e.g. "/api/v1/knowledge-bases/:id") over the raw URL:
// without this, an attacker iterating UUIDs in the URL produces a
// fresh dedup key per request, defeating the window and ballooning
// audit_logs. The raw URL is preserved inside the Details JSON for
// forensics, so we don't lose "which resource was probed".
dedupPath := ""
if c != nil && c.Request != nil {
requestPath = c.Request.URL.Path
rawPath = c.Request.URL.Path
requestMethod = c.Request.Method
dedupPath = c.FullPath()
if dedupPath == "" {
// No matched route (e.g. 404 path that still hit the
// middleware via a catch-all). Fall back to the raw path so
// the row remains non-empty for the audit reader.
dedupPath = rawPath
}
}
// Dedup probe: skip the durable write if this exact tuple already
@@ -100,18 +115,22 @@ func (s *auditLogService) LogDenied(
// "skip the audit because the count failed".
since := s.now().Add(-denyDedupWindow)
if n, err := s.repo.CountSinceForDedup(
ctx, tenantID, actorUserID, types.AuditActionAccessDenied, requestPath, since,
ctx, tenantID, actorUserID, types.AuditActionAccessDenied, dedupPath, since,
); err == nil && n > 0 {
return nil
}
details, _ := json.Marshal(map[string]string{"required_role": string(requiredRole)})
detailsMap := map[string]string{"required_role": string(requiredRole)}
if rawPath != "" && rawPath != dedupPath {
detailsMap["raw_path"] = rawPath
}
details, _ := json.Marshal(detailsMap)
return s.Log(ctx, &types.AuditLog{
TenantID: tenantID,
ActorUserID: actorUserID,
ActorRole: actorRole,
Action: types.AuditActionAccessDenied,
RequestPath: requestPath,
RequestPath: dedupPath,
RequestMethod: requestMethod,
Outcome: types.AuditOutcomeDenied,
Details: types.JSON(details),

View File

@@ -3,6 +3,7 @@ package service
import (
"context"
"sync"
"sync/atomic"
"time"
"github.com/Tencent/WeKnora/internal/config"
@@ -29,6 +30,13 @@ type AuditLogRetentionRunner struct {
stopOnce sync.Once
stopCh chan struct{}
doneCh chan struct{}
// started is set inside startOnce.Do BEFORE doneCh is wired to a
// goroutine, so Stop() can tell "Start was never called" apart from
// "Start is running" without blocking on doneCh. Without this, a
// runner that was constructed but never Start()'d (early container
// init failure, test setup that skips Start) would deadlock Stop()
// on a doneCh nobody ever closes.
started atomic.Bool
}
// auditLogPurgeInterval is the gap between sweeps. 24h is enough for
@@ -76,6 +84,7 @@ func (r *AuditLogRetentionRunner) Start(ctx context.Context) {
return
}
r.startOnce.Do(func() {
r.started.Store(true)
if r.retentionDays <= 0 {
logger.Infof(ctx,
"[audit-retention] disabled (retention_days=%d)", r.retentionDays)
@@ -90,11 +99,15 @@ func (r *AuditLogRetentionRunner) Start(ctx context.Context) {
}
// Stop signals the loop to exit and blocks until it returns. Idempotent.
// If Start was never called, Stop returns immediately.
// If Start was never called, Stop returns immediately (no doneCh to
// wait on — see the `started` flag in the struct comment).
func (r *AuditLogRetentionRunner) Stop() {
if r == nil {
return
}
if !r.started.Load() {
return
}
r.stopOnce.Do(func() {
close(r.stopCh)
})

View File

@@ -166,9 +166,10 @@ func TestAuditLogRetentionRunner_StartIsIdempotent(t *testing.T) {
func TestAuditLogRetentionRunner_NilSvcShortCircuits(t *testing.T) {
// Defensive: a misconfigured container (audit service couldn't
// be constructed) must not crash the app. Start with nil svc is
// a no-op. We don't call Stop here because Start never closes
// doneCh on the nil-svc path (returns before the once block runs);
// blocking on doneCh inside Stop would hang the test.
// a no-op, and the subsequent Stop must NOT hang on a doneCh
// nobody ever closed (regression: when startOnce.Do returned early
// without closing doneCh, Stop would block forever, deadlocking
// graceful shutdown).
r := &AuditLogRetentionRunner{
retentionDays: 90,
interval: time.Millisecond,
@@ -176,6 +177,40 @@ func TestAuditLogRetentionRunner_NilSvcShortCircuits(t *testing.T) {
doneCh: make(chan struct{}),
}
r.Start(context.Background())
done := make(chan struct{})
go func() {
r.Stop()
close(done)
}()
select {
case <-done:
case <-time.After(2 * time.Second):
t.Fatal("Stop() hung after Start() short-circuited on nil svc")
}
}
// TestAuditLogRetentionRunner_StopBeforeStart verifies that calling Stop
// on a runner whose Start was never invoked returns promptly instead of
// blocking on doneCh. Teardown can reach Stop before Start has ever run
// (early init failure, test cleanup), so this ordering must be a no-op.
func TestAuditLogRetentionRunner_StopBeforeStart(t *testing.T) {
	runner := &AuditLogRetentionRunner{
		svc:           &purgeCountingService{},
		retentionDays: 90,
		interval:      time.Millisecond,
		stopCh:        make(chan struct{}),
		doneCh:        make(chan struct{}),
	}

	// Run Stop on its own goroutine so a hang surfaces as a test failure
	// via the timeout below rather than stalling the whole test binary.
	stopReturned := make(chan struct{})
	go func() {
		runner.Stop()
		close(stopReturned)
	}()

	timeout := time.After(2 * time.Second)
	select {
	case <-timeout:
		t.Fatal("Stop() hung when called before Start()")
	case <-stopReturned:
		// Stop returned without Start ever having been called.
	}
}
// retentionRunnerWithImmediateStartup builds a runner whose startup

View File

@@ -11,8 +11,34 @@ import (
"github.com/Tencent/WeKnora/internal/logger"
"github.com/Tencent/WeKnora/internal/types"
"github.com/Tencent/WeKnora/internal/types/interfaces"
"gorm.io/gorm"
)
// isDuplicateMembership reports whether err looks like the
// unique-constraint violation raised by the tenant_members partial
// unique index when two concurrent AddMember / EnsureOwner calls both
// get past the in-service Get() check. Callers use it purely for error
// translation (e.g. mapping to ErrMembershipAlreadyExists so a handler
// can answer 409 instead of an opaque 500); the database has already
// rejected the second insert, so no invariant is relaxed here.
//
// Detection happens in two steps:
//  1. errors.Is against gorm.ErrDuplicatedKey, which covers the
//     dialect-translated case (gorm ≥1.25 with TranslateError enabled).
//  2. A case-insensitive substring match on "duplicate" or
//     "unique constraint" as a fallback for raw drivers that do not
//     surface the sentinel — Postgres "duplicate key value violates
//     unique constraint", SQLite "UNIQUE constraint failed" and MySQL
//     "Duplicate entry" each contain at least one of those tokens.
//
// A nil err is never a duplicate.
func isDuplicateMembership(err error) bool {
	switch {
	case err == nil:
		return false
	case errors.Is(err, gorm.ErrDuplicatedKey):
		return true
	}

	lowered := strings.ToLower(err.Error())
	for _, token := range []string{"duplicate", "unique constraint"} {
		if strings.Contains(lowered, token) {
			return true
		}
	}
	return false
}
// Sentinel errors returned by tenantMemberService. Callers compare with
// errors.Is to render appropriate HTTP responses (404 / 409 / 403).
var (
@@ -114,6 +140,14 @@ func (s *tenantMemberService) AddMember(
JoinedAt: time.Now(),
}
if err := s.repo.Create(ctx, member); err != nil {
// TOCTOU race: a concurrent AddMember / EnsureOwner slipped past
// the Get above. The DB's partial unique index on
// (user_id, tenant_id) WHERE deleted_at IS NULL caught it; map
// to the same sentinel the in-service check would have returned
// so callers get a clean 409 instead of an opaque 500.
if isDuplicateMembership(err) {
return nil, ErrMembershipAlreadyExists
}
return nil, err
}
s.emitAudit(ctx, &types.AuditLog{
@@ -152,6 +186,20 @@ func (s *tenantMemberService) EnsureOwner(
JoinedAt: time.Now(),
}
if err := s.repo.Create(ctx, member); err != nil {
// Idempotent contract: if a concurrent Ensure/AddMember beat us
// (two simultaneous registrations of the same user, or the
// orphan-tenant self-heal path firing on parallel JWTs), the
// partial unique index rejects the second insert. Re-read and
// return the winning row so EnsureOwner stays observably
// idempotent.
if isDuplicateMembership(err) {
if winner, getErr := s.repo.Get(ctx, userID, tenantID); getErr == nil && winner != nil {
logger.Infof(ctx,
"EnsureOwner lost race for user=%s tenant=%d, returning winning row (role=%s)",
userID, tenantID, winner.Role)
return winner, nil
}
}
return nil, err
}
logger.Infof(ctx, "Bootstrapped owner membership for user=%s tenant=%d", userID, tenantID)

View File

@@ -280,6 +280,21 @@ func TestTenantMemberService_AddMember_RejectsDuplicate(t *testing.T) {
}
}
// TestTenantMemberService_AddMember_MapsDuplicateKeyRace simulates the
// TOCTOU race in AddMember: the pre-insert Get() finds no row, a
// concurrent insert lands first, and our Create then fails with the
// database's duplicate-key error. The service is expected to translate
// that failure into ErrMembershipAlreadyExists so the handler can reply
// with 409 rather than a generic 500.
func TestTenantMemberService_AddMember_MapsDuplicateKeyRace(t *testing.T) {
	service, fakeRepo := newServiceWithRepo()

	// Make the repository's Create fail the way a duplicate-key
	// violation on the unique index is reported.
	dupKeyErr := errors.New(
		"ERROR: duplicate key value violates unique constraint \"idx_tenant_members_user_tenant_unique\"")
	fakeRepo.failCreate = dupKeyErr

	_, gotErr := service.AddMember(context.Background(), "u_race", 1, types.TenantRoleContributor, nil)
	if errors.Is(gotErr, ErrMembershipAlreadyExists) {
		return
	}
	t.Fatalf("want ErrMembershipAlreadyExists on duplicate-key race, got %v", gotErr)
}
func TestTenantMemberService_EnsureOwner_Idempotent(t *testing.T) {
svc, repo := newServiceWithRepo()
ctx := context.Background()

View File

@@ -83,36 +83,46 @@ func Auth(
crossTenantSwitch := targetTenantID != user.TenantID
tenantHeader := c.GetHeader("X-Tenant-ID")
if tenantHeader != "" {
// 解析目标租户ID
// 解析目标租户ID。畸形 / 零值必须显式拒绝:静默忽略会让坏掉的
// 前端/SDK 悄悄写错租户,反而看不到问题。与 RequirePathTenantMatch
// 中对 :id 的校验保持一致(非空、可解析、>0
parsedTenantID, err := strconv.ParseUint(tenantHeader, 10, 64)
if err == nil {
// 检查用户是否有权限访问目标租户:自家租户、跨租户超管、或
// 有 active membership 行——三选一,由 IsTenantAccessible
// 统一判定。
if IsTenantAccessible(c.Request.Context(), user, parsedTenantID, memberService, cfg) {
// 验证目标租户是否存在
targetTenant, err := tenantService.GetTenantByID(c.Request.Context(), parsedTenantID)
if err == nil && targetTenant != nil {
targetTenantID = parsedTenantID
crossTenantSwitch = parsedTenantID != user.TenantID
log.Printf("User %s switching to tenant %d", user.ID, targetTenantID)
} else {
log.Printf("Error getting target tenant by ID: %v, tenantID: %d", err, parsedTenantID)
c.JSON(http.StatusBadRequest, gin.H{
"error": "Invalid target tenant ID",
})
c.Abort()
return
}
if err != nil || parsedTenantID == 0 {
logger.Warnf(c.Request.Context(),
"Invalid X-Tenant-ID header from user=%s: %q (err=%v)",
user.ID, tenantHeader, err)
c.JSON(http.StatusBadRequest, gin.H{
"error": "Invalid X-Tenant-ID header",
})
c.Abort()
return
}
// 检查用户是否有权限访问目标租户:自家租户、跨租户超管、或
// 有 active membership 行——三选一,由 IsTenantAccessible
// 统一判定。
if IsTenantAccessible(c.Request.Context(), user, parsedTenantID, memberService, cfg) {
// 验证目标租户是否存在
targetTenant, err := tenantService.GetTenantByID(c.Request.Context(), parsedTenantID)
if err == nil && targetTenant != nil {
targetTenantID = parsedTenantID
crossTenantSwitch = parsedTenantID != user.TenantID
log.Printf("User %s switching to tenant %d", user.ID, targetTenantID)
} else {
// 用户没有权限访问目标租户
log.Printf("User %s attempted to access tenant %d without permission", user.ID, parsedTenantID)
c.JSON(http.StatusForbidden, gin.H{
"error": "Forbidden: insufficient permissions to access target tenant",
log.Printf("Error getting target tenant by ID: %v, tenantID: %d", err, parsedTenantID)
c.JSON(http.StatusBadRequest, gin.H{
"error": "Invalid target tenant ID",
})
c.Abort()
return
}
} else {
// 用户没有权限访问目标租户
log.Printf("User %s attempted to access tenant %d without permission", user.ID, parsedTenantID)
c.JSON(http.StatusForbidden, gin.H{
"error": "Forbidden: insufficient permissions to access target tenant",
})
c.Abort()
return
}
}