fix(session): scope session access by user

This commit is contained in:
wolfkill
2026-05-13 16:22:06 +08:00
committed by lyingbug
parent 7478e3cddb
commit 8f4e5a459f
7 changed files with 418 additions and 70 deletions

View File

@@ -2,9 +2,11 @@ package repository
import (
"context"
stderrors "errors"
"strings"
"time"
apperrors "github.com/Tencent/WeKnora/internal/errors"
"github.com/Tencent/WeKnora/internal/types"
"github.com/Tencent/WeKnora/internal/types/interfaces"
"gorm.io/gorm"
@@ -15,6 +17,14 @@ type sessionRepository struct {
db *gorm.DB
}
func applySessionUserScope(db *gorm.DB, userID string) *gorm.DB {
if userID == "" {
return db
}
// Empty user_id rows are legacy/API-created tenant-level sessions.
return db.Where("(user_id = ? OR user_id IS NULL OR user_id = '')", userID)
}
// NewSessionRepository creates a new session repository instance
func NewSessionRepository(db *gorm.DB) interfaces.SessionRepository {
return &sessionRepository{db: db}
@@ -32,19 +42,28 @@ func (r *sessionRepository) Create(ctx context.Context, session *types.Session)
}
// Get retrieves a session by ID
func (r *sessionRepository) Get(ctx context.Context, tenantID uint64, id string) (*types.Session, error) {
func (r *sessionRepository) Get(ctx context.Context, tenantID uint64, userID string, id string) (*types.Session, error) {
var session types.Session
err := r.db.WithContext(ctx).Where("tenant_id = ?", tenantID).First(&session, "id = ?", id).Error
err := applySessionUserScope(
r.db.WithContext(ctx).Where("tenant_id = ? AND id = ?", tenantID, id),
userID,
).First(&session).Error
if err != nil {
if stderrors.Is(err, gorm.ErrRecordNotFound) {
return nil, apperrors.ErrSessionNotFound
}
return nil, err
}
return &session, nil
}
// GetByTenantID retrieves all sessions for a tenant
func (r *sessionRepository) GetByTenantID(ctx context.Context, tenantID uint64) ([]*types.Session, error) {
func (r *sessionRepository) GetByTenantID(ctx context.Context, tenantID uint64, userID string) ([]*types.Session, error) {
var sessions []*types.Session
err := r.db.WithContext(ctx).Where("tenant_id = ?", tenantID).Order("updated_at DESC").Find(&sessions).Error
err := applySessionUserScope(
r.db.WithContext(ctx).Where("tenant_id = ?", tenantID),
userID,
).Order("updated_at DESC").Find(&sessions).Error
if err != nil {
return nil, err
}
@@ -53,20 +72,26 @@ func (r *sessionRepository) GetByTenantID(ctx context.Context, tenantID uint64)
// GetPagedByTenantID retrieves sessions for a tenant with pagination
func (r *sessionRepository) GetPagedByTenantID(
ctx context.Context, tenantID uint64, page *types.Pagination,
ctx context.Context, tenantID uint64, userID string, page *types.Pagination,
) ([]*types.Session, int64, error) {
var sessions []*types.Session
var total int64
// First query the total count
err := r.db.WithContext(ctx).Model(&types.Session{}).Where("tenant_id = ?", tenantID).Count(&total).Error
baseQ := applySessionUserScope(
r.db.WithContext(ctx).Model(&types.Session{}).Where("tenant_id = ?", tenantID),
userID,
)
err := baseQ.Count(&total).Error
if err != nil {
return nil, 0, err
}
// Then query the paginated data
err = r.db.WithContext(ctx).
Where("tenant_id = ?", tenantID).
err = applySessionUserScope(
r.db.WithContext(ctx).Where("tenant_id = ?", tenantID),
userID,
).
Order("updated_at DESC").
Offset(page.Offset()).
Limit(page.Limit()).
@@ -201,32 +226,45 @@ func (r *sessionRepository) SetPinned(
}
// Update updates a session
func (r *sessionRepository) Update(ctx context.Context, session *types.Session) error {
func (r *sessionRepository) Update(ctx context.Context, session *types.Session, userID string) (int64, error) {
session.UpdatedAt = time.Now()
return r.db.WithContext(ctx).
res := applySessionUserScope(r.db.WithContext(ctx).
Model(&types.Session{}).
Where("tenant_id = ? AND id = ?", session.TenantID, session.ID).
Where("tenant_id = ? AND id = ?", session.TenantID, session.ID), userID).
Updates(map[string]interface{}{
"title": session.Title,
"description": session.Description,
"updated_at": session.UpdatedAt,
}).Error
})
return res.RowsAffected, res.Error
}
// Delete deletes a session
func (r *sessionRepository) Delete(ctx context.Context, tenantID uint64, id string) error {
return r.db.WithContext(ctx).Where("tenant_id = ?", tenantID).Delete(&types.Session{}, "id = ?", id).Error
func (r *sessionRepository) Delete(ctx context.Context, tenantID uint64, userID string, id string) (int64, error) {
res := applySessionUserScope(
r.db.WithContext(ctx).Where("tenant_id = ? AND id = ?", tenantID, id),
userID,
).Delete(&types.Session{})
return res.RowsAffected, res.Error
}
// BatchDelete deletes multiple sessions by IDs
func (r *sessionRepository) BatchDelete(ctx context.Context, tenantID uint64, ids []string) error {
func (r *sessionRepository) BatchDelete(ctx context.Context, tenantID uint64, userID string, ids []string) (int64, error) {
if len(ids) == 0 {
return nil
return 0, nil
}
return r.db.WithContext(ctx).Where("tenant_id = ? AND id IN ?", tenantID, ids).Delete(&types.Session{}).Error
res := applySessionUserScope(
r.db.WithContext(ctx).Where("tenant_id = ? AND id IN ?", tenantID, ids),
userID,
).Delete(&types.Session{})
return res.RowsAffected, res.Error
}
// DeleteAllByTenantID deletes all sessions for a tenant
func (r *sessionRepository) DeleteAllByTenantID(ctx context.Context, tenantID uint64) error {
return r.db.WithContext(ctx).Where("tenant_id = ?", tenantID).Delete(&types.Session{}).Error
func (r *sessionRepository) DeleteAllByTenantID(ctx context.Context, tenantID uint64, userID string) (int64, error) {
res := applySessionUserScope(
r.db.WithContext(ctx).Where("tenant_id = ?", tenantID),
userID,
).Delete(&types.Session{})
return res.RowsAffected, res.Error
}

View File

@@ -0,0 +1,163 @@
package repository
import (
"context"
"testing"
apperrors "github.com/Tencent/WeKnora/internal/errors"
"github.com/Tencent/WeKnora/internal/types"
"github.com/Tencent/WeKnora/internal/types/interfaces"
"github.com/stretchr/testify/require"
"gorm.io/driver/sqlite"
"gorm.io/gorm"
)
func newSessionRepositoryForTest(t *testing.T) (interfaces.SessionRepository, *gorm.DB) {
t.Helper()
db, err := gorm.Open(sqlite.Open(":memory:"), &gorm.Config{})
require.NoError(t, err)
require.NoError(t, db.AutoMigrate(&types.Session{}))
return NewSessionRepository(db), db
}
func createSessionForTest(t *testing.T, db *gorm.DB, tenantID uint64, userID string) *types.Session {
t.Helper()
session := &types.Session{
TenantID: tenantID,
UserID: userID,
Title: userID + " session",
}
if userID == "" {
session.Title = "legacy tenant session"
}
require.NoError(t, db.Create(session).Error)
return session
}
func countActiveSessionsForTest(t *testing.T, db *gorm.DB, id string) int64 {
t.Helper()
var count int64
require.NoError(t, db.Model(&types.Session{}).Where("id = ?", id).Count(&count).Error)
return count
}
func sessionIDsForTest(sessions []*types.Session) []string {
ids := make([]string, 0, len(sessions))
for _, session := range sessions {
ids = append(ids, session.ID)
}
return ids
}
func TestSessionRepositoryGetAndListHonorUserScope(t *testing.T) {
repo, db := newSessionRepositoryForTest(t)
ctx := context.Background()
aliceSession := createSessionForTest(t, db, 1, "alice")
bobSession := createSessionForTest(t, db, 1, "bob")
legacySession := createSessionForTest(t, db, 1, "")
_ = createSessionForTest(t, db, 2, "bob")
_, err := repo.Get(ctx, 1, "bob", aliceSession.ID)
require.ErrorIs(t, err, apperrors.ErrSessionNotFound)
got, err := repo.Get(ctx, 1, "bob", bobSession.ID)
require.NoError(t, err)
require.Equal(t, bobSession.ID, got.ID)
got, err = repo.Get(ctx, 1, "bob", legacySession.ID)
require.NoError(t, err)
require.Equal(t, legacySession.ID, got.ID)
sessions, err := repo.GetByTenantID(ctx, 1, "bob")
require.NoError(t, err)
require.ElementsMatch(t, []string{bobSession.ID, legacySession.ID}, sessionIDsForTest(sessions))
paged, total, err := repo.GetPagedByTenantID(ctx, 1, "bob", &types.Pagination{Page: 1, PageSize: 10})
require.NoError(t, err)
require.EqualValues(t, 2, total)
require.ElementsMatch(t, []string{bobSession.ID, legacySession.ID}, sessionIDsForTest(paged))
}
func TestSessionRepositoryUpdateHonorsUserScope(t *testing.T) {
repo, db := newSessionRepositoryForTest(t)
ctx := context.Background()
aliceSession := createSessionForTest(t, db, 1, "alice")
rows, err := repo.Update(ctx, &types.Session{
ID: aliceSession.ID,
TenantID: aliceSession.TenantID,
Title: "bob update attempt",
}, "bob")
require.NoError(t, err)
require.Zero(t, rows)
var unchanged types.Session
require.NoError(t, db.First(&unchanged, "id = ?", aliceSession.ID).Error)
require.Equal(t, aliceSession.Title, unchanged.Title)
rows, err = repo.Update(ctx, &types.Session{
ID: aliceSession.ID,
TenantID: aliceSession.TenantID,
Title: "alice updated session",
}, "alice")
require.NoError(t, err)
require.EqualValues(t, 1, rows)
var changed types.Session
require.NoError(t, db.First(&changed, "id = ?", aliceSession.ID).Error)
require.Equal(t, "alice updated session", changed.Title)
}
func TestSessionRepositoryDeleteHonorsUserScope(t *testing.T) {
repo, db := newSessionRepositoryForTest(t)
ctx := context.Background()
aliceSession := createSessionForTest(t, db, 1, "alice")
bobSession := createSessionForTest(t, db, 1, "bob")
rows, err := repo.Delete(ctx, 1, "bob", aliceSession.ID)
require.NoError(t, err)
require.Zero(t, rows)
require.EqualValues(t, 1, countActiveSessionsForTest(t, db, aliceSession.ID))
rows, err = repo.Delete(ctx, 1, "bob", bobSession.ID)
require.NoError(t, err)
require.EqualValues(t, 1, rows)
require.Zero(t, countActiveSessionsForTest(t, db, bobSession.ID))
}
func TestSessionRepositoryBatchDeleteHonorsUserScope(t *testing.T) {
repo, db := newSessionRepositoryForTest(t)
ctx := context.Background()
aliceSession := createSessionForTest(t, db, 1, "alice")
bobSession := createSessionForTest(t, db, 1, "bob")
legacySession := createSessionForTest(t, db, 1, "")
rows, err := repo.BatchDelete(ctx, 1, "bob", []string{aliceSession.ID, bobSession.ID, legacySession.ID})
require.NoError(t, err)
require.EqualValues(t, 2, rows)
require.EqualValues(t, 1, countActiveSessionsForTest(t, db, aliceSession.ID))
require.Zero(t, countActiveSessionsForTest(t, db, bobSession.ID))
require.Zero(t, countActiveSessionsForTest(t, db, legacySession.ID))
}
func TestSessionRepositoryDeleteAllHonorsUserScope(t *testing.T) {
repo, db := newSessionRepositoryForTest(t)
ctx := context.Background()
aliceSession := createSessionForTest(t, db, 1, "alice")
bobSession := createSessionForTest(t, db, 1, "bob")
legacySession := createSessionForTest(t, db, 1, "")
otherTenantSession := createSessionForTest(t, db, 2, "bob")
rows, err := repo.DeleteAllByTenantID(ctx, 1, "bob")
require.NoError(t, err)
require.EqualValues(t, 2, rows)
require.EqualValues(t, 1, countActiveSessionsForTest(t, db, aliceSession.ID))
require.Zero(t, countActiveSessionsForTest(t, db, bobSession.ID))
require.Zero(t, countActiveSessionsForTest(t, db, legacySession.ID))
require.EqualValues(t, 1, countActiveSessionsForTest(t, db, otherTenantSession.ID))
}

View File

@@ -22,12 +22,12 @@ var regThinkIndex = regexp.MustCompile(`(?s)<think>.*?</think>`)
// It reads the chat history knowledge base configuration from the tenant's ChatHistoryConfig,
// which is managed via the settings UI.
type messageService struct {
messageRepo interfaces.MessageRepository // Repository for message storage operations
sessionRepo interfaces.SessionRepository // Repository for session validation
tenantService interfaces.TenantService // Service for tenant operations (read ChatHistoryConfig)
kbService interfaces.KnowledgeBaseService // Service for knowledge base operations (search chat history KB)
knowService interfaces.KnowledgeService // Service for knowledge operations (index/delete passages)
modelService interfaces.ModelService // Service for model operations (rerank model)
messageRepo interfaces.MessageRepository // Repository for message storage operations
sessionRepo interfaces.SessionRepository // Repository for session validation
tenantService interfaces.TenantService // Service for tenant operations (read ChatHistoryConfig)
kbService interfaces.KnowledgeBaseService // Service for knowledge base operations (search chat history KB)
knowService interfaces.KnowledgeService // Service for knowledge operations (index/delete passages)
modelService interfaces.ModelService // Service for model operations (rerank model)
}
// NewMessageService creates a new message service instance with the required repositories
@@ -64,6 +64,15 @@ func sessionTenantIDForLookup(ctx context.Context) (uint64, bool) {
return 0, false
}
func sessionUserIDForLookup(ctx context.Context) string {
if ctx.Value(types.SessionTenantIDContextKey) != nil {
// Shared-agent pipelines resolve the session owner tenant first; keep that internal lookup tenant-scoped.
return ""
}
userID, _ := types.UserIDFromContext(ctx)
return userID
}
// CreateMessage creates a new message within an existing session
func (s *messageService) CreateMessage(ctx context.Context, message *types.Message) (*types.Message, error) {
logger.Info(ctx, "Start creating message")
@@ -71,7 +80,7 @@ func (s *messageService) CreateMessage(ctx context.Context, message *types.Messa
tenantID := types.MustTenantIDFromContext(ctx)
logger.Infof(ctx, "Checking if session exists, tenant ID: %d, session ID: %s", tenantID, message.SessionID)
_, err := s.sessionRepo.Get(ctx, tenantID, message.SessionID)
_, err := s.sessionRepo.Get(ctx, tenantID, sessionUserIDForLookup(ctx), message.SessionID)
if err != nil {
logger.Errorf(ctx, "Failed to get session: %v", err)
return nil, err
@@ -97,7 +106,7 @@ func (s *messageService) GetMessage(ctx context.Context, sessionID string, messa
tenantID := types.MustTenantIDFromContext(ctx)
logger.Infof(ctx, "Checking if session exists, tenant ID: %d", tenantID)
_, err := s.sessionRepo.Get(ctx, tenantID, sessionID)
_, err := s.sessionRepo.Get(ctx, tenantID, sessionUserIDForLookup(ctx), sessionID)
if err != nil {
logger.Errorf(ctx, "Failed to get session: %v", err)
return nil, err
@@ -126,7 +135,7 @@ func (s *messageService) GetMessagesBySession(ctx context.Context,
tenantID := types.MustTenantIDFromContext(ctx)
logger.Infof(ctx, "Checking if session exists, tenant ID: %d", tenantID)
_, err := s.sessionRepo.Get(ctx, tenantID, sessionID)
_, err := s.sessionRepo.Get(ctx, tenantID, sessionUserIDForLookup(ctx), sessionID)
if err != nil {
logger.Errorf(ctx, "Failed to get session: %v", err)
return nil, err
@@ -160,7 +169,7 @@ func (s *messageService) GetRecentMessagesBySession(ctx context.Context,
return nil, errors.New("tenant ID not found in context")
}
logger.Infof(ctx, "Checking if session exists, tenant ID: %d", tenantID)
_, err := s.sessionRepo.Get(ctx, tenantID, sessionID)
_, err := s.sessionRepo.Get(ctx, tenantID, sessionUserIDForLookup(ctx), sessionID)
if err != nil {
logger.Errorf(ctx, "Failed to get session: %v", err)
return nil, err
@@ -193,7 +202,7 @@ func (s *messageService) GetMessagesBySessionBeforeTime(ctx context.Context,
return nil, errors.New("tenant ID not found in context")
}
logger.Infof(ctx, "Checking if session exists, tenant ID: %d", tenantID)
_, err := s.sessionRepo.Get(ctx, tenantID, sessionID)
_, err := s.sessionRepo.Get(ctx, tenantID, sessionUserIDForLookup(ctx), sessionID)
if err != nil {
logger.Errorf(ctx, "Failed to get session: %v", err)
return nil, err
@@ -221,7 +230,7 @@ func (s *messageService) UpdateMessage(ctx context.Context, message *types.Messa
tenantID := types.MustTenantIDFromContext(ctx)
logger.Infof(ctx, "Checking if session exists, tenant ID: %d", tenantID)
_, err := s.sessionRepo.Get(ctx, tenantID, message.SessionID)
_, err := s.sessionRepo.Get(ctx, tenantID, sessionUserIDForLookup(ctx), message.SessionID)
if err != nil {
logger.Errorf(ctx, "Failed to get session: %v", err)
return err
@@ -258,7 +267,7 @@ func (s *messageService) DeleteMessage(ctx context.Context, sessionID string, me
tenantID := types.MustTenantIDFromContext(ctx)
logger.Infof(ctx, "Checking if session exists, tenant ID: %d", tenantID)
_, err := s.sessionRepo.Get(ctx, tenantID, sessionID)
_, err := s.sessionRepo.Get(ctx, tenantID, sessionUserIDForLookup(ctx), sessionID)
if err != nil {
logger.Errorf(ctx, "Failed to get session: %v", err)
return err
@@ -298,7 +307,7 @@ func (s *messageService) ClearSessionMessages(ctx context.Context, sessionID str
logger.Infof(ctx, "Start clearing all messages for session: %s", sessionID)
tenantID := types.MustTenantIDFromContext(ctx)
if _, err := s.sessionRepo.Get(ctx, tenantID, sessionID); err != nil {
if _, err := s.sessionRepo.Get(ctx, tenantID, sessionUserIDForLookup(ctx), sessionID); err != nil {
logger.Errorf(ctx, "Failed to get session: %v", err)
return err
}

View File

@@ -2,11 +2,12 @@ package service
import (
"context"
"errors"
stderrors "errors"
"fmt"
"strings"
"github.com/Tencent/WeKnora/internal/config"
apperrors "github.com/Tencent/WeKnora/internal/errors"
"github.com/Tencent/WeKnora/internal/event"
"github.com/Tencent/WeKnora/internal/logger"
"github.com/Tencent/WeKnora/internal/models/chat"
@@ -17,6 +18,11 @@ import (
chatpipeline "github.com/Tencent/WeKnora/internal/application/service/chat_pipeline"
)
func sessionUserIDFromContext(ctx context.Context) string {
userID, _ := types.UserIDFromContext(ctx)
return userID
}
// generateEventID generates a unique event ID with type suffix for better traceability
func generateEventID(suffix string) string {
return fmt.Sprintf("%s-%s", uuid.New().String()[:8], suffix)
@@ -84,7 +90,7 @@ func (s *sessionService) CreateSession(ctx context.Context, session *types.Sessi
// Validate tenant ID
if session.TenantID == 0 {
logger.Error(ctx, "Failed to create session: tenant ID cannot be empty")
return nil, errors.New("tenant ID is required")
return nil, stderrors.New("tenant ID is required")
}
logger.Infof(ctx, "Creating session, tenant ID: %d", session.TenantID)
@@ -106,15 +112,16 @@ func (s *sessionService) GetSession(ctx context.Context, id string) (*types.Sess
// Validate session ID
if id == "" {
logger.Error(ctx, "Failed to get session: session ID cannot be empty")
return nil, errors.New("session id is required")
return nil, stderrors.New("session id is required")
}
// Get tenant ID from context
tenantID := types.MustTenantIDFromContext(ctx)
userID := sessionUserIDFromContext(ctx)
logger.Infof(ctx, "Retrieving session, ID: %s, tenant ID: %d", id, tenantID)
// Get session from repository
session, err := s.sessionRepo.Get(ctx, tenantID, id)
session, err := s.sessionRepo.Get(ctx, tenantID, userID, id)
if err != nil {
logger.ErrorWithFields(ctx, err, map[string]interface{}{
"session_id": id,
@@ -131,10 +138,11 @@ func (s *sessionService) GetSession(ctx context.Context, id string) (*types.Sess
func (s *sessionService) GetSessionsByTenant(ctx context.Context) ([]*types.Session, error) {
// Get tenant ID from context
tenantID := types.MustTenantIDFromContext(ctx)
userID := sessionUserIDFromContext(ctx)
logger.Infof(ctx, "Retrieving all sessions for tenant, tenant ID: %d", tenantID)
// Get sessions from repository
sessions, err := s.sessionRepo.GetByTenantID(ctx, tenantID)
sessions, err := s.sessionRepo.GetByTenantID(ctx, tenantID, userID)
if err != nil {
logger.ErrorWithFields(ctx, err, map[string]interface{}{
"tenant_id": tenantID,
@@ -154,8 +162,9 @@ func (s *sessionService) GetPagedSessionsByTenant(ctx context.Context,
) (*types.PageResult, error) {
// Get tenant ID from context
tenantID := types.MustTenantIDFromContext(ctx)
userID := sessionUserIDFromContext(ctx)
// Get paged sessions from repository
sessions, total, err := s.sessionRepo.GetPagedByTenantID(ctx, tenantID, pagination)
sessions, total, err := s.sessionRepo.GetPagedByTenantID(ctx, tenantID, userID, pagination)
if err != nil {
logger.ErrorWithFields(ctx, err, map[string]interface{}{
"tenant_id": tenantID,
@@ -204,10 +213,10 @@ func (s *sessionService) SetSessionPinned(
ctx context.Context, sessionID string, pinned bool,
) (int64, error) {
if sessionID == "" {
return 0, errors.New("session id is required")
return 0, stderrors.New("session id is required")
}
tenantID := types.MustTenantIDFromContext(ctx)
userID, _ := types.UserIDFromContext(ctx)
userID := sessionUserIDFromContext(ctx)
return s.sessionRepo.SetPinned(ctx, tenantID, userID, sessionID, pinned)
}
@@ -216,11 +225,16 @@ func (s *sessionService) UpdateSession(ctx context.Context, session *types.Sessi
// Validate session ID
if session.ID == "" {
logger.Error(ctx, "Failed to update session: session ID cannot be empty")
return errors.New("session id is required")
return stderrors.New("session id is required")
}
// Update session in repository
err := s.sessionRepo.Update(ctx, session)
userID := sessionUserIDFromContext(ctx)
if _, err := s.sessionRepo.Get(ctx, session.TenantID, userID, session.ID); err != nil {
return err
}
_, err := s.sessionRepo.Update(ctx, session, userID)
if err != nil {
logger.ErrorWithFields(ctx, err, map[string]interface{}{
"session_id": session.ID,
@@ -238,11 +252,16 @@ func (s *sessionService) DeleteSession(ctx context.Context, id string) error {
// Validate session ID
if id == "" {
logger.Error(ctx, "Failed to delete session: session ID cannot be empty")
return errors.New("session id is required")
return stderrors.New("session id is required")
}
// Get tenant ID from context
tenantID := types.MustTenantIDFromContext(ctx)
userID := sessionUserIDFromContext(ctx)
if _, err := s.sessionRepo.Get(ctx, tenantID, userID, id); err != nil {
return err
}
// Cleanup chat history knowledge entries for this session (async, best-effort).
// Use WithoutCancel so the goroutine survives after the HTTP request context is done.
@@ -266,7 +285,7 @@ func (s *sessionService) DeleteSession(ctx context.Context, id string) error {
}
// Delete session from repository
err := s.sessionRepo.Delete(ctx, tenantID, id)
rows, err := s.sessionRepo.Delete(ctx, tenantID, userID, id)
if err != nil {
logger.ErrorWithFields(ctx, err, map[string]interface{}{
"session_id": id,
@@ -274,6 +293,9 @@ func (s *sessionService) DeleteSession(ctx context.Context, id string) error {
})
return err
}
if rows == 0 {
return apperrors.ErrSessionNotFound
}
return nil
}
@@ -282,15 +304,28 @@ func (s *sessionService) DeleteSession(ctx context.Context, id string) error {
func (s *sessionService) BatchDeleteSessions(ctx context.Context, ids []string) error {
if len(ids) == 0 {
logger.Error(ctx, "Failed to batch delete sessions: IDs list is empty")
return errors.New("session ids are required")
return stderrors.New("session ids are required")
}
// Get tenant ID from context
tenantID := types.MustTenantIDFromContext(ctx)
userID := sessionUserIDFromContext(ctx)
visibleIDs := make([]string, 0, len(ids))
for _, id := range ids {
if _, err := s.sessionRepo.Get(ctx, tenantID, userID, id); err == nil {
visibleIDs = append(visibleIDs, id)
} else if !stderrors.Is(err, apperrors.ErrSessionNotFound) {
return err
}
}
if len(visibleIDs) == 0 {
return apperrors.ErrSessionNotFound
}
// Cleanup associated resources for each session
bgCtx := context.WithoutCancel(ctx)
for _, id := range ids {
for _, id := range visibleIDs {
// Cleanup chat history knowledge entries (async, best-effort)
go func(sessionID string) {
knowledgeIDs, err := s.messageRepo.GetKnowledgeIDsBySessionID(bgCtx, sessionID)
@@ -311,9 +346,9 @@ func (s *sessionService) BatchDeleteSessions(ctx context.Context, ids []string)
}
// Batch delete sessions from repository
if err := s.sessionRepo.BatchDelete(ctx, tenantID, ids); err != nil {
if _, err := s.sessionRepo.BatchDelete(ctx, tenantID, userID, visibleIDs); err != nil {
logger.ErrorWithFields(ctx, err, map[string]interface{}{
"session_ids": ids,
"session_ids": visibleIDs,
"tenant_id": tenantID,
})
return err
@@ -325,9 +360,10 @@ func (s *sessionService) BatchDeleteSessions(ctx context.Context, ids []string)
// DeleteAllSessions deletes all sessions for the current tenant
func (s *sessionService) DeleteAllSessions(ctx context.Context) error {
tenantID := types.MustTenantIDFromContext(ctx)
userID := sessionUserIDFromContext(ctx)
logger.Infof(ctx, "Deleting all sessions for tenant %d", tenantID)
sessions, err := s.sessionRepo.GetByTenantID(ctx, tenantID)
sessions, err := s.sessionRepo.GetByTenantID(ctx, tenantID, userID)
if err != nil {
logger.Warnf(ctx, "Failed to list sessions for cleanup: %v", err)
} else {
@@ -353,7 +389,7 @@ func (s *sessionService) DeleteAllSessions(ctx context.Context) error {
}
}
if err := s.sessionRepo.DeleteAllByTenantID(ctx, tenantID); err != nil {
if _, err := s.sessionRepo.DeleteAllByTenantID(ctx, tenantID, userID); err != nil {
logger.ErrorWithFields(ctx, err, map[string]interface{}{
"tenant_id": tenantID,
})
@@ -371,7 +407,7 @@ func (s *sessionService) GenerateTitle(ctx context.Context,
) (string, error) {
if session == nil {
logger.Error(ctx, "Failed to generate title: session cannot be empty")
return "", errors.New("session cannot be empty")
return "", stderrors.New("session cannot be empty")
}
// Skip if title already exists
@@ -401,7 +437,7 @@ func (s *sessionService) GenerateTitle(ctx context.Context,
// Ensure a user message was found
if message == nil {
logger.Error(ctx, "No user message found, cannot generate title")
return "", errors.New("no user message found")
return "", stderrors.New("no user message found")
}
// Use provided modelID, or fallback to first available KnowledgeQA model
@@ -423,7 +459,7 @@ func (s *sessionService) GenerateTitle(ctx context.Context,
}
if modelID == "" {
logger.Error(ctx, "No KnowledgeQA model found")
return "", errors.New("no KnowledgeQA model available for title generation")
return "", stderrors.New("no KnowledgeQA model available for title generation")
}
} else {
logger.Infof(ctx, "Using specified model for title generation: %s", modelID)
@@ -464,7 +500,7 @@ func (s *sessionService) GenerateTitle(ctx context.Context,
session.Title = strings.TrimPrefix(response.Content, "<think>\n\n</think>")
// Update session with new title
err = s.sessionRepo.Update(ctx, session)
_, err = s.sessionRepo.Update(ctx, session, session.UserID)
if err != nil {
logger.ErrorWithFields(ctx, err, nil)
return "", err
@@ -485,7 +521,7 @@ func (s *sessionService) GenerateTitleAsync(
eventBus *event.EventBus,
) {
// Use context tenant (effective tenant when using shared agent) so ListModels/GetChatModel find the agent's model.
// sessionRepo.Update uses session.TenantID in WHERE, so the session row is updated correctly regardless of ctx.
// The session row itself is still updated by its persisted tenant/user owner scope.
tenantID := ctx.Value(types.TenantIDContextKey)
requestID := ctx.Value(types.RequestIDContextKey)
language := ctx.Value(types.LanguageContextKey)

View File

@@ -0,0 +1,97 @@
package service
import (
"context"
"testing"
"github.com/Tencent/WeKnora/internal/application/repository"
apperrors "github.com/Tencent/WeKnora/internal/errors"
"github.com/Tencent/WeKnora/internal/types"
"github.com/stretchr/testify/require"
"gorm.io/driver/sqlite"
"gorm.io/gorm"
)
func testSessionScopeContext(tenantID uint64, userID string) context.Context {
ctx := context.WithValue(context.Background(), types.TenantIDContextKey, tenantID)
if userID != "" {
ctx = context.WithValue(ctx, types.UserIDContextKey, userID)
}
return ctx
}
func newTestSessionService(t *testing.T) (*sessionService, *gorm.DB) {
t.Helper()
db, err := gorm.Open(sqlite.Open(":memory:"), &gorm.Config{})
require.NoError(t, err)
require.NoError(t, db.AutoMigrate(&types.Session{}))
return &sessionService{
sessionRepo: repository.NewSessionRepository(db),
}, db
}
func TestGetSessionIsScopedToCurrentUser(t *testing.T) {
svc, db := newTestSessionService(t)
aliceSession := &types.Session{
TenantID: 1,
UserID: "alice",
Title: "alice private session",
}
require.NoError(t, db.Create(aliceSession).Error)
bobSession := &types.Session{
TenantID: 1,
UserID: "bob",
Title: "bob private session",
}
require.NoError(t, db.Create(bobSession).Error)
legacySession := &types.Session{
TenantID: 1,
Title: "legacy tenant session",
}
require.NoError(t, db.Create(legacySession).Error)
_, err := svc.GetSession(testSessionScopeContext(1, "bob"), aliceSession.ID)
require.ErrorIs(t, err, apperrors.ErrSessionNotFound)
got, err := svc.GetSession(testSessionScopeContext(1, "bob"), bobSession.ID)
require.NoError(t, err)
require.Equal(t, bobSession.ID, got.ID)
got, err = svc.GetSession(testSessionScopeContext(1, "bob"), legacySession.ID)
require.NoError(t, err)
require.Equal(t, legacySession.ID, got.ID)
}
func TestUpdateSessionIsScopedToCurrentUserAndAllowsNoOp(t *testing.T) {
svc, db := newTestSessionService(t)
aliceSession := &types.Session{
TenantID: 1,
UserID: "alice",
Title: "alice private session",
Description: "original description",
}
require.NoError(t, db.Create(aliceSession).Error)
err := svc.UpdateSession(testSessionScopeContext(1, "bob"), &types.Session{
ID: aliceSession.ID,
TenantID: 1,
Title: "bob update attempt",
Description: "should not be saved",
})
require.ErrorIs(t, err, apperrors.ErrSessionNotFound)
var unchanged types.Session
require.NoError(t, db.First(&unchanged, "id = ?", aliceSession.ID).Error)
require.Equal(t, aliceSession.Title, unchanged.Title)
require.Equal(t, aliceSession.Description, unchanged.Description)
err = svc.UpdateSession(testSessionScopeContext(1, "alice"), &types.Session{
ID: aliceSession.ID,
TenantID: 1,
Title: aliceSession.Title,
Description: aliceSession.Description,
})
require.NoError(t, err)
}

View File

@@ -442,6 +442,11 @@ func (h *Handler) BatchDeleteSessions(c *gin.Context) {
}
if err := h.sessionService.BatchDeleteSessions(ctx, sanitizedIDs); err != nil {
if err == errors.ErrSessionNotFound {
logger.Warnf(ctx, "No visible sessions found for batch delete")
c.Error(errors.NewNotFoundError(err.Error()))
return
}
logger.ErrorWithFields(ctx, err, nil)
c.Error(errors.NewInternalServerError(err.Error()))
return

View File

@@ -55,25 +55,25 @@ type SessionService interface {
type SessionRepository interface {
// Create creates a session
Create(ctx context.Context, session *types.Session) (*types.Session, error)
// Get gets a session
Get(ctx context.Context, tenantID uint64, id string) (*types.Session, error)
// GetByTenantID gets all sessions of a tenant
GetByTenantID(ctx context.Context, tenantID uint64) ([]*types.Session, error)
// GetPagedByTenantID gets paged sessions of a tenant
GetPagedByTenantID(ctx context.Context, tenantID uint64, page *types.Pagination) ([]*types.Session, int64, error)
// Get gets a session visible to the tenant/user scope.
Get(ctx context.Context, tenantID uint64, userID string, id string) (*types.Session, error)
// GetByTenantID gets all sessions visible to the tenant/user scope.
GetByTenantID(ctx context.Context, tenantID uint64, userID string) ([]*types.Session, error)
// GetPagedByTenantID gets paged sessions visible to the tenant/user scope.
GetPagedByTenantID(ctx context.Context, tenantID uint64, userID string, page *types.Pagination) ([]*types.Session, int64, error)
// QueryPaged lists sessions with filters, user-scoped ownership and pin-aware ordering.
QueryPaged(ctx context.Context, q *types.SessionListQuery) ([]*types.SessionListItem, int64, error)
// Update updates a session
Update(ctx context.Context, session *types.Session) error
// Update updates a session visible to the tenant/user scope.
Update(ctx context.Context, session *types.Session, userID string) (int64, error)
// SetPinned pins or unpins a session row scoped by tenant.
// userID, when non-empty, is enforced so users cannot pin sessions they don't own.
// Returns the number of rows affected; 0 means the session doesn't exist or is
// not visible to this caller.
SetPinned(ctx context.Context, tenantID uint64, userID string, id string, pinned bool) (int64, error)
// Delete deletes a session
Delete(ctx context.Context, tenantID uint64, id string) error
// BatchDelete deletes multiple sessions by IDs
BatchDelete(ctx context.Context, tenantID uint64, ids []string) error
// DeleteAllByTenantID deletes all sessions for a tenant
DeleteAllByTenantID(ctx context.Context, tenantID uint64) error
// Delete deletes a session visible to the tenant/user scope.
Delete(ctx context.Context, tenantID uint64, userID string, id string) (int64, error)
// BatchDelete deletes multiple sessions visible to the tenant/user scope.
BatchDelete(ctx context.Context, tenantID uint64, userID string, ids []string) (int64, error)
// DeleteAllByTenantID deletes all sessions visible to the tenant/user scope.
DeleteAllByTenantID(ctx context.Context, tenantID uint64, userID string) (int64, error)
}