mirror of
https://github.com/Tencent/WeKnora.git
synced 2026-06-04 13:30:32 +08:00
fix(session): scope session access by user
This commit is contained in:
@@ -2,9 +2,11 @@ package repository
|
||||
|
||||
import (
|
||||
"context"
|
||||
stderrors "errors"
|
||||
"strings"
|
||||
"time"
|
||||
|
||||
apperrors "github.com/Tencent/WeKnora/internal/errors"
|
||||
"github.com/Tencent/WeKnora/internal/types"
|
||||
"github.com/Tencent/WeKnora/internal/types/interfaces"
|
||||
"gorm.io/gorm"
|
||||
@@ -15,6 +17,14 @@ type sessionRepository struct {
|
||||
db *gorm.DB
|
||||
}
|
||||
|
||||
func applySessionUserScope(db *gorm.DB, userID string) *gorm.DB {
|
||||
if userID == "" {
|
||||
return db
|
||||
}
|
||||
// Empty user_id rows are legacy/API-created tenant-level sessions.
|
||||
return db.Where("(user_id = ? OR user_id IS NULL OR user_id = '')", userID)
|
||||
}
|
||||
|
||||
// NewSessionRepository creates a new session repository instance
|
||||
func NewSessionRepository(db *gorm.DB) interfaces.SessionRepository {
|
||||
return &sessionRepository{db: db}
|
||||
@@ -32,19 +42,28 @@ func (r *sessionRepository) Create(ctx context.Context, session *types.Session)
|
||||
}
|
||||
|
||||
// Get retrieves a session by ID
|
||||
func (r *sessionRepository) Get(ctx context.Context, tenantID uint64, id string) (*types.Session, error) {
|
||||
func (r *sessionRepository) Get(ctx context.Context, tenantID uint64, userID string, id string) (*types.Session, error) {
|
||||
var session types.Session
|
||||
err := r.db.WithContext(ctx).Where("tenant_id = ?", tenantID).First(&session, "id = ?", id).Error
|
||||
err := applySessionUserScope(
|
||||
r.db.WithContext(ctx).Where("tenant_id = ? AND id = ?", tenantID, id),
|
||||
userID,
|
||||
).First(&session).Error
|
||||
if err != nil {
|
||||
if stderrors.Is(err, gorm.ErrRecordNotFound) {
|
||||
return nil, apperrors.ErrSessionNotFound
|
||||
}
|
||||
return nil, err
|
||||
}
|
||||
return &session, nil
|
||||
}
|
||||
|
||||
// GetByTenantID retrieves all sessions for a tenant
|
||||
func (r *sessionRepository) GetByTenantID(ctx context.Context, tenantID uint64) ([]*types.Session, error) {
|
||||
func (r *sessionRepository) GetByTenantID(ctx context.Context, tenantID uint64, userID string) ([]*types.Session, error) {
|
||||
var sessions []*types.Session
|
||||
err := r.db.WithContext(ctx).Where("tenant_id = ?", tenantID).Order("updated_at DESC").Find(&sessions).Error
|
||||
err := applySessionUserScope(
|
||||
r.db.WithContext(ctx).Where("tenant_id = ?", tenantID),
|
||||
userID,
|
||||
).Order("updated_at DESC").Find(&sessions).Error
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
@@ -53,20 +72,26 @@ func (r *sessionRepository) GetByTenantID(ctx context.Context, tenantID uint64)
|
||||
|
||||
// GetPagedByTenantID retrieves sessions for a tenant with pagination
|
||||
func (r *sessionRepository) GetPagedByTenantID(
|
||||
ctx context.Context, tenantID uint64, page *types.Pagination,
|
||||
ctx context.Context, tenantID uint64, userID string, page *types.Pagination,
|
||||
) ([]*types.Session, int64, error) {
|
||||
var sessions []*types.Session
|
||||
var total int64
|
||||
|
||||
// First query the total count
|
||||
err := r.db.WithContext(ctx).Model(&types.Session{}).Where("tenant_id = ?", tenantID).Count(&total).Error
|
||||
baseQ := applySessionUserScope(
|
||||
r.db.WithContext(ctx).Model(&types.Session{}).Where("tenant_id = ?", tenantID),
|
||||
userID,
|
||||
)
|
||||
err := baseQ.Count(&total).Error
|
||||
if err != nil {
|
||||
return nil, 0, err
|
||||
}
|
||||
|
||||
// Then query the paginated data
|
||||
err = r.db.WithContext(ctx).
|
||||
Where("tenant_id = ?", tenantID).
|
||||
err = applySessionUserScope(
|
||||
r.db.WithContext(ctx).Where("tenant_id = ?", tenantID),
|
||||
userID,
|
||||
).
|
||||
Order("updated_at DESC").
|
||||
Offset(page.Offset()).
|
||||
Limit(page.Limit()).
|
||||
@@ -201,32 +226,45 @@ func (r *sessionRepository) SetPinned(
|
||||
}
|
||||
|
||||
// Update updates a session
|
||||
func (r *sessionRepository) Update(ctx context.Context, session *types.Session) error {
|
||||
func (r *sessionRepository) Update(ctx context.Context, session *types.Session, userID string) (int64, error) {
|
||||
session.UpdatedAt = time.Now()
|
||||
return r.db.WithContext(ctx).
|
||||
res := applySessionUserScope(r.db.WithContext(ctx).
|
||||
Model(&types.Session{}).
|
||||
Where("tenant_id = ? AND id = ?", session.TenantID, session.ID).
|
||||
Where("tenant_id = ? AND id = ?", session.TenantID, session.ID), userID).
|
||||
Updates(map[string]interface{}{
|
||||
"title": session.Title,
|
||||
"description": session.Description,
|
||||
"updated_at": session.UpdatedAt,
|
||||
}).Error
|
||||
})
|
||||
return res.RowsAffected, res.Error
|
||||
}
|
||||
|
||||
// Delete deletes a session
|
||||
func (r *sessionRepository) Delete(ctx context.Context, tenantID uint64, id string) error {
|
||||
return r.db.WithContext(ctx).Where("tenant_id = ?", tenantID).Delete(&types.Session{}, "id = ?", id).Error
|
||||
func (r *sessionRepository) Delete(ctx context.Context, tenantID uint64, userID string, id string) (int64, error) {
|
||||
res := applySessionUserScope(
|
||||
r.db.WithContext(ctx).Where("tenant_id = ? AND id = ?", tenantID, id),
|
||||
userID,
|
||||
).Delete(&types.Session{})
|
||||
return res.RowsAffected, res.Error
|
||||
}
|
||||
|
||||
// BatchDelete deletes multiple sessions by IDs
|
||||
func (r *sessionRepository) BatchDelete(ctx context.Context, tenantID uint64, ids []string) error {
|
||||
func (r *sessionRepository) BatchDelete(ctx context.Context, tenantID uint64, userID string, ids []string) (int64, error) {
|
||||
if len(ids) == 0 {
|
||||
return nil
|
||||
return 0, nil
|
||||
}
|
||||
return r.db.WithContext(ctx).Where("tenant_id = ? AND id IN ?", tenantID, ids).Delete(&types.Session{}).Error
|
||||
res := applySessionUserScope(
|
||||
r.db.WithContext(ctx).Where("tenant_id = ? AND id IN ?", tenantID, ids),
|
||||
userID,
|
||||
).Delete(&types.Session{})
|
||||
return res.RowsAffected, res.Error
|
||||
}
|
||||
|
||||
// DeleteAllByTenantID deletes all sessions for a tenant
|
||||
func (r *sessionRepository) DeleteAllByTenantID(ctx context.Context, tenantID uint64) error {
|
||||
return r.db.WithContext(ctx).Where("tenant_id = ?", tenantID).Delete(&types.Session{}).Error
|
||||
func (r *sessionRepository) DeleteAllByTenantID(ctx context.Context, tenantID uint64, userID string) (int64, error) {
|
||||
res := applySessionUserScope(
|
||||
r.db.WithContext(ctx).Where("tenant_id = ?", tenantID),
|
||||
userID,
|
||||
).Delete(&types.Session{})
|
||||
return res.RowsAffected, res.Error
|
||||
}
|
||||
|
||||
163
internal/application/repository/session_test.go
Normal file
163
internal/application/repository/session_test.go
Normal file
@@ -0,0 +1,163 @@
|
||||
package repository
|
||||
|
||||
import (
|
||||
"context"
|
||||
"testing"
|
||||
|
||||
apperrors "github.com/Tencent/WeKnora/internal/errors"
|
||||
"github.com/Tencent/WeKnora/internal/types"
|
||||
"github.com/Tencent/WeKnora/internal/types/interfaces"
|
||||
"github.com/stretchr/testify/require"
|
||||
"gorm.io/driver/sqlite"
|
||||
"gorm.io/gorm"
|
||||
)
|
||||
|
||||
func newSessionRepositoryForTest(t *testing.T) (interfaces.SessionRepository, *gorm.DB) {
|
||||
t.Helper()
|
||||
|
||||
db, err := gorm.Open(sqlite.Open(":memory:"), &gorm.Config{})
|
||||
require.NoError(t, err)
|
||||
require.NoError(t, db.AutoMigrate(&types.Session{}))
|
||||
|
||||
return NewSessionRepository(db), db
|
||||
}
|
||||
|
||||
func createSessionForTest(t *testing.T, db *gorm.DB, tenantID uint64, userID string) *types.Session {
|
||||
t.Helper()
|
||||
|
||||
session := &types.Session{
|
||||
TenantID: tenantID,
|
||||
UserID: userID,
|
||||
Title: userID + " session",
|
||||
}
|
||||
if userID == "" {
|
||||
session.Title = "legacy tenant session"
|
||||
}
|
||||
require.NoError(t, db.Create(session).Error)
|
||||
|
||||
return session
|
||||
}
|
||||
|
||||
func countActiveSessionsForTest(t *testing.T, db *gorm.DB, id string) int64 {
|
||||
t.Helper()
|
||||
|
||||
var count int64
|
||||
require.NoError(t, db.Model(&types.Session{}).Where("id = ?", id).Count(&count).Error)
|
||||
return count
|
||||
}
|
||||
|
||||
func sessionIDsForTest(sessions []*types.Session) []string {
|
||||
ids := make([]string, 0, len(sessions))
|
||||
for _, session := range sessions {
|
||||
ids = append(ids, session.ID)
|
||||
}
|
||||
return ids
|
||||
}
|
||||
|
||||
func TestSessionRepositoryGetAndListHonorUserScope(t *testing.T) {
|
||||
repo, db := newSessionRepositoryForTest(t)
|
||||
ctx := context.Background()
|
||||
aliceSession := createSessionForTest(t, db, 1, "alice")
|
||||
bobSession := createSessionForTest(t, db, 1, "bob")
|
||||
legacySession := createSessionForTest(t, db, 1, "")
|
||||
_ = createSessionForTest(t, db, 2, "bob")
|
||||
|
||||
_, err := repo.Get(ctx, 1, "bob", aliceSession.ID)
|
||||
require.ErrorIs(t, err, apperrors.ErrSessionNotFound)
|
||||
|
||||
got, err := repo.Get(ctx, 1, "bob", bobSession.ID)
|
||||
require.NoError(t, err)
|
||||
require.Equal(t, bobSession.ID, got.ID)
|
||||
|
||||
got, err = repo.Get(ctx, 1, "bob", legacySession.ID)
|
||||
require.NoError(t, err)
|
||||
require.Equal(t, legacySession.ID, got.ID)
|
||||
|
||||
sessions, err := repo.GetByTenantID(ctx, 1, "bob")
|
||||
require.NoError(t, err)
|
||||
require.ElementsMatch(t, []string{bobSession.ID, legacySession.ID}, sessionIDsForTest(sessions))
|
||||
|
||||
paged, total, err := repo.GetPagedByTenantID(ctx, 1, "bob", &types.Pagination{Page: 1, PageSize: 10})
|
||||
require.NoError(t, err)
|
||||
require.EqualValues(t, 2, total)
|
||||
require.ElementsMatch(t, []string{bobSession.ID, legacySession.ID}, sessionIDsForTest(paged))
|
||||
}
|
||||
|
||||
func TestSessionRepositoryUpdateHonorsUserScope(t *testing.T) {
|
||||
repo, db := newSessionRepositoryForTest(t)
|
||||
ctx := context.Background()
|
||||
aliceSession := createSessionForTest(t, db, 1, "alice")
|
||||
|
||||
rows, err := repo.Update(ctx, &types.Session{
|
||||
ID: aliceSession.ID,
|
||||
TenantID: aliceSession.TenantID,
|
||||
Title: "bob update attempt",
|
||||
}, "bob")
|
||||
require.NoError(t, err)
|
||||
require.Zero(t, rows)
|
||||
|
||||
var unchanged types.Session
|
||||
require.NoError(t, db.First(&unchanged, "id = ?", aliceSession.ID).Error)
|
||||
require.Equal(t, aliceSession.Title, unchanged.Title)
|
||||
|
||||
rows, err = repo.Update(ctx, &types.Session{
|
||||
ID: aliceSession.ID,
|
||||
TenantID: aliceSession.TenantID,
|
||||
Title: "alice updated session",
|
||||
}, "alice")
|
||||
require.NoError(t, err)
|
||||
require.EqualValues(t, 1, rows)
|
||||
|
||||
var changed types.Session
|
||||
require.NoError(t, db.First(&changed, "id = ?", aliceSession.ID).Error)
|
||||
require.Equal(t, "alice updated session", changed.Title)
|
||||
}
|
||||
|
||||
func TestSessionRepositoryDeleteHonorsUserScope(t *testing.T) {
|
||||
repo, db := newSessionRepositoryForTest(t)
|
||||
ctx := context.Background()
|
||||
aliceSession := createSessionForTest(t, db, 1, "alice")
|
||||
bobSession := createSessionForTest(t, db, 1, "bob")
|
||||
|
||||
rows, err := repo.Delete(ctx, 1, "bob", aliceSession.ID)
|
||||
require.NoError(t, err)
|
||||
require.Zero(t, rows)
|
||||
require.EqualValues(t, 1, countActiveSessionsForTest(t, db, aliceSession.ID))
|
||||
|
||||
rows, err = repo.Delete(ctx, 1, "bob", bobSession.ID)
|
||||
require.NoError(t, err)
|
||||
require.EqualValues(t, 1, rows)
|
||||
require.Zero(t, countActiveSessionsForTest(t, db, bobSession.ID))
|
||||
}
|
||||
|
||||
func TestSessionRepositoryBatchDeleteHonorsUserScope(t *testing.T) {
|
||||
repo, db := newSessionRepositoryForTest(t)
|
||||
ctx := context.Background()
|
||||
aliceSession := createSessionForTest(t, db, 1, "alice")
|
||||
bobSession := createSessionForTest(t, db, 1, "bob")
|
||||
legacySession := createSessionForTest(t, db, 1, "")
|
||||
|
||||
rows, err := repo.BatchDelete(ctx, 1, "bob", []string{aliceSession.ID, bobSession.ID, legacySession.ID})
|
||||
require.NoError(t, err)
|
||||
require.EqualValues(t, 2, rows)
|
||||
require.EqualValues(t, 1, countActiveSessionsForTest(t, db, aliceSession.ID))
|
||||
require.Zero(t, countActiveSessionsForTest(t, db, bobSession.ID))
|
||||
require.Zero(t, countActiveSessionsForTest(t, db, legacySession.ID))
|
||||
}
|
||||
|
||||
func TestSessionRepositoryDeleteAllHonorsUserScope(t *testing.T) {
|
||||
repo, db := newSessionRepositoryForTest(t)
|
||||
ctx := context.Background()
|
||||
aliceSession := createSessionForTest(t, db, 1, "alice")
|
||||
bobSession := createSessionForTest(t, db, 1, "bob")
|
||||
legacySession := createSessionForTest(t, db, 1, "")
|
||||
otherTenantSession := createSessionForTest(t, db, 2, "bob")
|
||||
|
||||
rows, err := repo.DeleteAllByTenantID(ctx, 1, "bob")
|
||||
require.NoError(t, err)
|
||||
require.EqualValues(t, 2, rows)
|
||||
require.EqualValues(t, 1, countActiveSessionsForTest(t, db, aliceSession.ID))
|
||||
require.Zero(t, countActiveSessionsForTest(t, db, bobSession.ID))
|
||||
require.Zero(t, countActiveSessionsForTest(t, db, legacySession.ID))
|
||||
require.EqualValues(t, 1, countActiveSessionsForTest(t, db, otherTenantSession.ID))
|
||||
}
|
||||
@@ -22,12 +22,12 @@ var regThinkIndex = regexp.MustCompile(`(?s)<think>.*?</think>`)
|
||||
// It reads the chat history knowledge base configuration from the tenant's ChatHistoryConfig,
|
||||
// which is managed via the settings UI.
|
||||
type messageService struct {
|
||||
messageRepo interfaces.MessageRepository // Repository for message storage operations
|
||||
sessionRepo interfaces.SessionRepository // Repository for session validation
|
||||
tenantService interfaces.TenantService // Service for tenant operations (read ChatHistoryConfig)
|
||||
kbService interfaces.KnowledgeBaseService // Service for knowledge base operations (search chat history KB)
|
||||
knowService interfaces.KnowledgeService // Service for knowledge operations (index/delete passages)
|
||||
modelService interfaces.ModelService // Service for model operations (rerank model)
|
||||
messageRepo interfaces.MessageRepository // Repository for message storage operations
|
||||
sessionRepo interfaces.SessionRepository // Repository for session validation
|
||||
tenantService interfaces.TenantService // Service for tenant operations (read ChatHistoryConfig)
|
||||
kbService interfaces.KnowledgeBaseService // Service for knowledge base operations (search chat history KB)
|
||||
knowService interfaces.KnowledgeService // Service for knowledge operations (index/delete passages)
|
||||
modelService interfaces.ModelService // Service for model operations (rerank model)
|
||||
}
|
||||
|
||||
// NewMessageService creates a new message service instance with the required repositories
|
||||
@@ -64,6 +64,15 @@ func sessionTenantIDForLookup(ctx context.Context) (uint64, bool) {
|
||||
return 0, false
|
||||
}
|
||||
|
||||
func sessionUserIDForLookup(ctx context.Context) string {
|
||||
if ctx.Value(types.SessionTenantIDContextKey) != nil {
|
||||
// Shared-agent pipelines resolve the session owner tenant first; keep that internal lookup tenant-scoped.
|
||||
return ""
|
||||
}
|
||||
userID, _ := types.UserIDFromContext(ctx)
|
||||
return userID
|
||||
}
|
||||
|
||||
// CreateMessage creates a new message within an existing session
|
||||
func (s *messageService) CreateMessage(ctx context.Context, message *types.Message) (*types.Message, error) {
|
||||
logger.Info(ctx, "Start creating message")
|
||||
@@ -71,7 +80,7 @@ func (s *messageService) CreateMessage(ctx context.Context, message *types.Messa
|
||||
|
||||
tenantID := types.MustTenantIDFromContext(ctx)
|
||||
logger.Infof(ctx, "Checking if session exists, tenant ID: %d, session ID: %s", tenantID, message.SessionID)
|
||||
_, err := s.sessionRepo.Get(ctx, tenantID, message.SessionID)
|
||||
_, err := s.sessionRepo.Get(ctx, tenantID, sessionUserIDForLookup(ctx), message.SessionID)
|
||||
if err != nil {
|
||||
logger.Errorf(ctx, "Failed to get session: %v", err)
|
||||
return nil, err
|
||||
@@ -97,7 +106,7 @@ func (s *messageService) GetMessage(ctx context.Context, sessionID string, messa
|
||||
|
||||
tenantID := types.MustTenantIDFromContext(ctx)
|
||||
logger.Infof(ctx, "Checking if session exists, tenant ID: %d", tenantID)
|
||||
_, err := s.sessionRepo.Get(ctx, tenantID, sessionID)
|
||||
_, err := s.sessionRepo.Get(ctx, tenantID, sessionUserIDForLookup(ctx), sessionID)
|
||||
if err != nil {
|
||||
logger.Errorf(ctx, "Failed to get session: %v", err)
|
||||
return nil, err
|
||||
@@ -126,7 +135,7 @@ func (s *messageService) GetMessagesBySession(ctx context.Context,
|
||||
|
||||
tenantID := types.MustTenantIDFromContext(ctx)
|
||||
logger.Infof(ctx, "Checking if session exists, tenant ID: %d", tenantID)
|
||||
_, err := s.sessionRepo.Get(ctx, tenantID, sessionID)
|
||||
_, err := s.sessionRepo.Get(ctx, tenantID, sessionUserIDForLookup(ctx), sessionID)
|
||||
if err != nil {
|
||||
logger.Errorf(ctx, "Failed to get session: %v", err)
|
||||
return nil, err
|
||||
@@ -160,7 +169,7 @@ func (s *messageService) GetRecentMessagesBySession(ctx context.Context,
|
||||
return nil, errors.New("tenant ID not found in context")
|
||||
}
|
||||
logger.Infof(ctx, "Checking if session exists, tenant ID: %d", tenantID)
|
||||
_, err := s.sessionRepo.Get(ctx, tenantID, sessionID)
|
||||
_, err := s.sessionRepo.Get(ctx, tenantID, sessionUserIDForLookup(ctx), sessionID)
|
||||
if err != nil {
|
||||
logger.Errorf(ctx, "Failed to get session: %v", err)
|
||||
return nil, err
|
||||
@@ -193,7 +202,7 @@ func (s *messageService) GetMessagesBySessionBeforeTime(ctx context.Context,
|
||||
return nil, errors.New("tenant ID not found in context")
|
||||
}
|
||||
logger.Infof(ctx, "Checking if session exists, tenant ID: %d", tenantID)
|
||||
_, err := s.sessionRepo.Get(ctx, tenantID, sessionID)
|
||||
_, err := s.sessionRepo.Get(ctx, tenantID, sessionUserIDForLookup(ctx), sessionID)
|
||||
if err != nil {
|
||||
logger.Errorf(ctx, "Failed to get session: %v", err)
|
||||
return nil, err
|
||||
@@ -221,7 +230,7 @@ func (s *messageService) UpdateMessage(ctx context.Context, message *types.Messa
|
||||
|
||||
tenantID := types.MustTenantIDFromContext(ctx)
|
||||
logger.Infof(ctx, "Checking if session exists, tenant ID: %d", tenantID)
|
||||
_, err := s.sessionRepo.Get(ctx, tenantID, message.SessionID)
|
||||
_, err := s.sessionRepo.Get(ctx, tenantID, sessionUserIDForLookup(ctx), message.SessionID)
|
||||
if err != nil {
|
||||
logger.Errorf(ctx, "Failed to get session: %v", err)
|
||||
return err
|
||||
@@ -258,7 +267,7 @@ func (s *messageService) DeleteMessage(ctx context.Context, sessionID string, me
|
||||
|
||||
tenantID := types.MustTenantIDFromContext(ctx)
|
||||
logger.Infof(ctx, "Checking if session exists, tenant ID: %d", tenantID)
|
||||
_, err := s.sessionRepo.Get(ctx, tenantID, sessionID)
|
||||
_, err := s.sessionRepo.Get(ctx, tenantID, sessionUserIDForLookup(ctx), sessionID)
|
||||
if err != nil {
|
||||
logger.Errorf(ctx, "Failed to get session: %v", err)
|
||||
return err
|
||||
@@ -298,7 +307,7 @@ func (s *messageService) ClearSessionMessages(ctx context.Context, sessionID str
|
||||
logger.Infof(ctx, "Start clearing all messages for session: %s", sessionID)
|
||||
|
||||
tenantID := types.MustTenantIDFromContext(ctx)
|
||||
if _, err := s.sessionRepo.Get(ctx, tenantID, sessionID); err != nil {
|
||||
if _, err := s.sessionRepo.Get(ctx, tenantID, sessionUserIDForLookup(ctx), sessionID); err != nil {
|
||||
logger.Errorf(ctx, "Failed to get session: %v", err)
|
||||
return err
|
||||
}
|
||||
|
||||
@@ -2,11 +2,12 @@ package service
|
||||
|
||||
import (
|
||||
"context"
|
||||
"errors"
|
||||
stderrors "errors"
|
||||
"fmt"
|
||||
"strings"
|
||||
|
||||
"github.com/Tencent/WeKnora/internal/config"
|
||||
apperrors "github.com/Tencent/WeKnora/internal/errors"
|
||||
"github.com/Tencent/WeKnora/internal/event"
|
||||
"github.com/Tencent/WeKnora/internal/logger"
|
||||
"github.com/Tencent/WeKnora/internal/models/chat"
|
||||
@@ -17,6 +18,11 @@ import (
|
||||
chatpipeline "github.com/Tencent/WeKnora/internal/application/service/chat_pipeline"
|
||||
)
|
||||
|
||||
func sessionUserIDFromContext(ctx context.Context) string {
|
||||
userID, _ := types.UserIDFromContext(ctx)
|
||||
return userID
|
||||
}
|
||||
|
||||
// generateEventID generates a unique event ID with type suffix for better traceability
|
||||
func generateEventID(suffix string) string {
|
||||
return fmt.Sprintf("%s-%s", uuid.New().String()[:8], suffix)
|
||||
@@ -84,7 +90,7 @@ func (s *sessionService) CreateSession(ctx context.Context, session *types.Sessi
|
||||
// Validate tenant ID
|
||||
if session.TenantID == 0 {
|
||||
logger.Error(ctx, "Failed to create session: tenant ID cannot be empty")
|
||||
return nil, errors.New("tenant ID is required")
|
||||
return nil, stderrors.New("tenant ID is required")
|
||||
}
|
||||
|
||||
logger.Infof(ctx, "Creating session, tenant ID: %d", session.TenantID)
|
||||
@@ -106,15 +112,16 @@ func (s *sessionService) GetSession(ctx context.Context, id string) (*types.Sess
|
||||
// Validate session ID
|
||||
if id == "" {
|
||||
logger.Error(ctx, "Failed to get session: session ID cannot be empty")
|
||||
return nil, errors.New("session id is required")
|
||||
return nil, stderrors.New("session id is required")
|
||||
}
|
||||
|
||||
// Get tenant ID from context
|
||||
tenantID := types.MustTenantIDFromContext(ctx)
|
||||
userID := sessionUserIDFromContext(ctx)
|
||||
logger.Infof(ctx, "Retrieving session, ID: %s, tenant ID: %d", id, tenantID)
|
||||
|
||||
// Get session from repository
|
||||
session, err := s.sessionRepo.Get(ctx, tenantID, id)
|
||||
session, err := s.sessionRepo.Get(ctx, tenantID, userID, id)
|
||||
if err != nil {
|
||||
logger.ErrorWithFields(ctx, err, map[string]interface{}{
|
||||
"session_id": id,
|
||||
@@ -131,10 +138,11 @@ func (s *sessionService) GetSession(ctx context.Context, id string) (*types.Sess
|
||||
func (s *sessionService) GetSessionsByTenant(ctx context.Context) ([]*types.Session, error) {
|
||||
// Get tenant ID from context
|
||||
tenantID := types.MustTenantIDFromContext(ctx)
|
||||
userID := sessionUserIDFromContext(ctx)
|
||||
logger.Infof(ctx, "Retrieving all sessions for tenant, tenant ID: %d", tenantID)
|
||||
|
||||
// Get sessions from repository
|
||||
sessions, err := s.sessionRepo.GetByTenantID(ctx, tenantID)
|
||||
sessions, err := s.sessionRepo.GetByTenantID(ctx, tenantID, userID)
|
||||
if err != nil {
|
||||
logger.ErrorWithFields(ctx, err, map[string]interface{}{
|
||||
"tenant_id": tenantID,
|
||||
@@ -154,8 +162,9 @@ func (s *sessionService) GetPagedSessionsByTenant(ctx context.Context,
|
||||
) (*types.PageResult, error) {
|
||||
// Get tenant ID from context
|
||||
tenantID := types.MustTenantIDFromContext(ctx)
|
||||
userID := sessionUserIDFromContext(ctx)
|
||||
// Get paged sessions from repository
|
||||
sessions, total, err := s.sessionRepo.GetPagedByTenantID(ctx, tenantID, pagination)
|
||||
sessions, total, err := s.sessionRepo.GetPagedByTenantID(ctx, tenantID, userID, pagination)
|
||||
if err != nil {
|
||||
logger.ErrorWithFields(ctx, err, map[string]interface{}{
|
||||
"tenant_id": tenantID,
|
||||
@@ -204,10 +213,10 @@ func (s *sessionService) SetSessionPinned(
|
||||
ctx context.Context, sessionID string, pinned bool,
|
||||
) (int64, error) {
|
||||
if sessionID == "" {
|
||||
return 0, errors.New("session id is required")
|
||||
return 0, stderrors.New("session id is required")
|
||||
}
|
||||
tenantID := types.MustTenantIDFromContext(ctx)
|
||||
userID, _ := types.UserIDFromContext(ctx)
|
||||
userID := sessionUserIDFromContext(ctx)
|
||||
return s.sessionRepo.SetPinned(ctx, tenantID, userID, sessionID, pinned)
|
||||
}
|
||||
|
||||
@@ -216,11 +225,16 @@ func (s *sessionService) UpdateSession(ctx context.Context, session *types.Sessi
|
||||
// Validate session ID
|
||||
if session.ID == "" {
|
||||
logger.Error(ctx, "Failed to update session: session ID cannot be empty")
|
||||
return errors.New("session id is required")
|
||||
return stderrors.New("session id is required")
|
||||
}
|
||||
|
||||
// Update session in repository
|
||||
err := s.sessionRepo.Update(ctx, session)
|
||||
userID := sessionUserIDFromContext(ctx)
|
||||
if _, err := s.sessionRepo.Get(ctx, session.TenantID, userID, session.ID); err != nil {
|
||||
return err
|
||||
}
|
||||
|
||||
_, err := s.sessionRepo.Update(ctx, session, userID)
|
||||
if err != nil {
|
||||
logger.ErrorWithFields(ctx, err, map[string]interface{}{
|
||||
"session_id": session.ID,
|
||||
@@ -238,11 +252,16 @@ func (s *sessionService) DeleteSession(ctx context.Context, id string) error {
|
||||
// Validate session ID
|
||||
if id == "" {
|
||||
logger.Error(ctx, "Failed to delete session: session ID cannot be empty")
|
||||
return errors.New("session id is required")
|
||||
return stderrors.New("session id is required")
|
||||
}
|
||||
|
||||
// Get tenant ID from context
|
||||
tenantID := types.MustTenantIDFromContext(ctx)
|
||||
userID := sessionUserIDFromContext(ctx)
|
||||
|
||||
if _, err := s.sessionRepo.Get(ctx, tenantID, userID, id); err != nil {
|
||||
return err
|
||||
}
|
||||
|
||||
// Cleanup chat history knowledge entries for this session (async, best-effort).
|
||||
// Use WithoutCancel so the goroutine survives after the HTTP request context is done.
|
||||
@@ -266,7 +285,7 @@ func (s *sessionService) DeleteSession(ctx context.Context, id string) error {
|
||||
}
|
||||
|
||||
// Delete session from repository
|
||||
err := s.sessionRepo.Delete(ctx, tenantID, id)
|
||||
rows, err := s.sessionRepo.Delete(ctx, tenantID, userID, id)
|
||||
if err != nil {
|
||||
logger.ErrorWithFields(ctx, err, map[string]interface{}{
|
||||
"session_id": id,
|
||||
@@ -274,6 +293,9 @@ func (s *sessionService) DeleteSession(ctx context.Context, id string) error {
|
||||
})
|
||||
return err
|
||||
}
|
||||
if rows == 0 {
|
||||
return apperrors.ErrSessionNotFound
|
||||
}
|
||||
|
||||
return nil
|
||||
}
|
||||
@@ -282,15 +304,28 @@ func (s *sessionService) DeleteSession(ctx context.Context, id string) error {
|
||||
func (s *sessionService) BatchDeleteSessions(ctx context.Context, ids []string) error {
|
||||
if len(ids) == 0 {
|
||||
logger.Error(ctx, "Failed to batch delete sessions: IDs list is empty")
|
||||
return errors.New("session ids are required")
|
||||
return stderrors.New("session ids are required")
|
||||
}
|
||||
|
||||
// Get tenant ID from context
|
||||
tenantID := types.MustTenantIDFromContext(ctx)
|
||||
userID := sessionUserIDFromContext(ctx)
|
||||
|
||||
visibleIDs := make([]string, 0, len(ids))
|
||||
for _, id := range ids {
|
||||
if _, err := s.sessionRepo.Get(ctx, tenantID, userID, id); err == nil {
|
||||
visibleIDs = append(visibleIDs, id)
|
||||
} else if !stderrors.Is(err, apperrors.ErrSessionNotFound) {
|
||||
return err
|
||||
}
|
||||
}
|
||||
if len(visibleIDs) == 0 {
|
||||
return apperrors.ErrSessionNotFound
|
||||
}
|
||||
|
||||
// Cleanup associated resources for each session
|
||||
bgCtx := context.WithoutCancel(ctx)
|
||||
for _, id := range ids {
|
||||
for _, id := range visibleIDs {
|
||||
// Cleanup chat history knowledge entries (async, best-effort)
|
||||
go func(sessionID string) {
|
||||
knowledgeIDs, err := s.messageRepo.GetKnowledgeIDsBySessionID(bgCtx, sessionID)
|
||||
@@ -311,9 +346,9 @@ func (s *sessionService) BatchDeleteSessions(ctx context.Context, ids []string)
|
||||
}
|
||||
|
||||
// Batch delete sessions from repository
|
||||
if err := s.sessionRepo.BatchDelete(ctx, tenantID, ids); err != nil {
|
||||
if _, err := s.sessionRepo.BatchDelete(ctx, tenantID, userID, visibleIDs); err != nil {
|
||||
logger.ErrorWithFields(ctx, err, map[string]interface{}{
|
||||
"session_ids": ids,
|
||||
"session_ids": visibleIDs,
|
||||
"tenant_id": tenantID,
|
||||
})
|
||||
return err
|
||||
@@ -325,9 +360,10 @@ func (s *sessionService) BatchDeleteSessions(ctx context.Context, ids []string)
|
||||
// DeleteAllSessions deletes all sessions for the current tenant
|
||||
func (s *sessionService) DeleteAllSessions(ctx context.Context) error {
|
||||
tenantID := types.MustTenantIDFromContext(ctx)
|
||||
userID := sessionUserIDFromContext(ctx)
|
||||
logger.Infof(ctx, "Deleting all sessions for tenant %d", tenantID)
|
||||
|
||||
sessions, err := s.sessionRepo.GetByTenantID(ctx, tenantID)
|
||||
sessions, err := s.sessionRepo.GetByTenantID(ctx, tenantID, userID)
|
||||
if err != nil {
|
||||
logger.Warnf(ctx, "Failed to list sessions for cleanup: %v", err)
|
||||
} else {
|
||||
@@ -353,7 +389,7 @@ func (s *sessionService) DeleteAllSessions(ctx context.Context) error {
|
||||
}
|
||||
}
|
||||
|
||||
if err := s.sessionRepo.DeleteAllByTenantID(ctx, tenantID); err != nil {
|
||||
if _, err := s.sessionRepo.DeleteAllByTenantID(ctx, tenantID, userID); err != nil {
|
||||
logger.ErrorWithFields(ctx, err, map[string]interface{}{
|
||||
"tenant_id": tenantID,
|
||||
})
|
||||
@@ -371,7 +407,7 @@ func (s *sessionService) GenerateTitle(ctx context.Context,
|
||||
) (string, error) {
|
||||
if session == nil {
|
||||
logger.Error(ctx, "Failed to generate title: session cannot be empty")
|
||||
return "", errors.New("session cannot be empty")
|
||||
return "", stderrors.New("session cannot be empty")
|
||||
}
|
||||
|
||||
// Skip if title already exists
|
||||
@@ -401,7 +437,7 @@ func (s *sessionService) GenerateTitle(ctx context.Context,
|
||||
// Ensure a user message was found
|
||||
if message == nil {
|
||||
logger.Error(ctx, "No user message found, cannot generate title")
|
||||
return "", errors.New("no user message found")
|
||||
return "", stderrors.New("no user message found")
|
||||
}
|
||||
|
||||
// Use provided modelID, or fallback to first available KnowledgeQA model
|
||||
@@ -423,7 +459,7 @@ func (s *sessionService) GenerateTitle(ctx context.Context,
|
||||
}
|
||||
if modelID == "" {
|
||||
logger.Error(ctx, "No KnowledgeQA model found")
|
||||
return "", errors.New("no KnowledgeQA model available for title generation")
|
||||
return "", stderrors.New("no KnowledgeQA model available for title generation")
|
||||
}
|
||||
} else {
|
||||
logger.Infof(ctx, "Using specified model for title generation: %s", modelID)
|
||||
@@ -464,7 +500,7 @@ func (s *sessionService) GenerateTitle(ctx context.Context,
|
||||
session.Title = strings.TrimPrefix(response.Content, "<think>\n\n</think>")
|
||||
|
||||
// Update session with new title
|
||||
err = s.sessionRepo.Update(ctx, session)
|
||||
_, err = s.sessionRepo.Update(ctx, session, session.UserID)
|
||||
if err != nil {
|
||||
logger.ErrorWithFields(ctx, err, nil)
|
||||
return "", err
|
||||
@@ -485,7 +521,7 @@ func (s *sessionService) GenerateTitleAsync(
|
||||
eventBus *event.EventBus,
|
||||
) {
|
||||
// Use context tenant (effective tenant when using shared agent) so ListModels/GetChatModel find the agent's model.
|
||||
// sessionRepo.Update uses session.TenantID in WHERE, so the session row is updated correctly regardless of ctx.
|
||||
// The session row itself is still updated by its persisted tenant/user owner scope.
|
||||
tenantID := ctx.Value(types.TenantIDContextKey)
|
||||
requestID := ctx.Value(types.RequestIDContextKey)
|
||||
language := ctx.Value(types.LanguageContextKey)
|
||||
|
||||
97
internal/application/service/session_user_scope_test.go
Normal file
97
internal/application/service/session_user_scope_test.go
Normal file
@@ -0,0 +1,97 @@
|
||||
package service
|
||||
|
||||
import (
|
||||
"context"
|
||||
"testing"
|
||||
|
||||
"github.com/Tencent/WeKnora/internal/application/repository"
|
||||
apperrors "github.com/Tencent/WeKnora/internal/errors"
|
||||
"github.com/Tencent/WeKnora/internal/types"
|
||||
"github.com/stretchr/testify/require"
|
||||
"gorm.io/driver/sqlite"
|
||||
"gorm.io/gorm"
|
||||
)
|
||||
|
||||
func testSessionScopeContext(tenantID uint64, userID string) context.Context {
|
||||
ctx := context.WithValue(context.Background(), types.TenantIDContextKey, tenantID)
|
||||
if userID != "" {
|
||||
ctx = context.WithValue(ctx, types.UserIDContextKey, userID)
|
||||
}
|
||||
return ctx
|
||||
}
|
||||
|
||||
func newTestSessionService(t *testing.T) (*sessionService, *gorm.DB) {
|
||||
t.Helper()
|
||||
|
||||
db, err := gorm.Open(sqlite.Open(":memory:"), &gorm.Config{})
|
||||
require.NoError(t, err)
|
||||
require.NoError(t, db.AutoMigrate(&types.Session{}))
|
||||
|
||||
return &sessionService{
|
||||
sessionRepo: repository.NewSessionRepository(db),
|
||||
}, db
|
||||
}
|
||||
|
||||
func TestGetSessionIsScopedToCurrentUser(t *testing.T) {
|
||||
svc, db := newTestSessionService(t)
|
||||
aliceSession := &types.Session{
|
||||
TenantID: 1,
|
||||
UserID: "alice",
|
||||
Title: "alice private session",
|
||||
}
|
||||
require.NoError(t, db.Create(aliceSession).Error)
|
||||
bobSession := &types.Session{
|
||||
TenantID: 1,
|
||||
UserID: "bob",
|
||||
Title: "bob private session",
|
||||
}
|
||||
require.NoError(t, db.Create(bobSession).Error)
|
||||
legacySession := &types.Session{
|
||||
TenantID: 1,
|
||||
Title: "legacy tenant session",
|
||||
}
|
||||
require.NoError(t, db.Create(legacySession).Error)
|
||||
|
||||
_, err := svc.GetSession(testSessionScopeContext(1, "bob"), aliceSession.ID)
|
||||
require.ErrorIs(t, err, apperrors.ErrSessionNotFound)
|
||||
|
||||
got, err := svc.GetSession(testSessionScopeContext(1, "bob"), bobSession.ID)
|
||||
require.NoError(t, err)
|
||||
require.Equal(t, bobSession.ID, got.ID)
|
||||
|
||||
got, err = svc.GetSession(testSessionScopeContext(1, "bob"), legacySession.ID)
|
||||
require.NoError(t, err)
|
||||
require.Equal(t, legacySession.ID, got.ID)
|
||||
}
|
||||
|
||||
func TestUpdateSessionIsScopedToCurrentUserAndAllowsNoOp(t *testing.T) {
|
||||
svc, db := newTestSessionService(t)
|
||||
aliceSession := &types.Session{
|
||||
TenantID: 1,
|
||||
UserID: "alice",
|
||||
Title: "alice private session",
|
||||
Description: "original description",
|
||||
}
|
||||
require.NoError(t, db.Create(aliceSession).Error)
|
||||
|
||||
err := svc.UpdateSession(testSessionScopeContext(1, "bob"), &types.Session{
|
||||
ID: aliceSession.ID,
|
||||
TenantID: 1,
|
||||
Title: "bob update attempt",
|
||||
Description: "should not be saved",
|
||||
})
|
||||
require.ErrorIs(t, err, apperrors.ErrSessionNotFound)
|
||||
|
||||
var unchanged types.Session
|
||||
require.NoError(t, db.First(&unchanged, "id = ?", aliceSession.ID).Error)
|
||||
require.Equal(t, aliceSession.Title, unchanged.Title)
|
||||
require.Equal(t, aliceSession.Description, unchanged.Description)
|
||||
|
||||
err = svc.UpdateSession(testSessionScopeContext(1, "alice"), &types.Session{
|
||||
ID: aliceSession.ID,
|
||||
TenantID: 1,
|
||||
Title: aliceSession.Title,
|
||||
Description: aliceSession.Description,
|
||||
})
|
||||
require.NoError(t, err)
|
||||
}
|
||||
@@ -442,6 +442,11 @@ func (h *Handler) BatchDeleteSessions(c *gin.Context) {
|
||||
}
|
||||
|
||||
if err := h.sessionService.BatchDeleteSessions(ctx, sanitizedIDs); err != nil {
|
||||
if err == errors.ErrSessionNotFound {
|
||||
logger.Warnf(ctx, "No visible sessions found for batch delete")
|
||||
c.Error(errors.NewNotFoundError(err.Error()))
|
||||
return
|
||||
}
|
||||
logger.ErrorWithFields(ctx, err, nil)
|
||||
c.Error(errors.NewInternalServerError(err.Error()))
|
||||
return
|
||||
|
||||
@@ -55,25 +55,25 @@ type SessionService interface {
|
||||
type SessionRepository interface {
|
||||
// Create creates a session
|
||||
Create(ctx context.Context, session *types.Session) (*types.Session, error)
|
||||
// Get gets a session
|
||||
Get(ctx context.Context, tenantID uint64, id string) (*types.Session, error)
|
||||
// GetByTenantID gets all sessions of a tenant
|
||||
GetByTenantID(ctx context.Context, tenantID uint64) ([]*types.Session, error)
|
||||
// GetPagedByTenantID gets paged sessions of a tenant
|
||||
GetPagedByTenantID(ctx context.Context, tenantID uint64, page *types.Pagination) ([]*types.Session, int64, error)
|
||||
// Get gets a session visible to the tenant/user scope.
|
||||
Get(ctx context.Context, tenantID uint64, userID string, id string) (*types.Session, error)
|
||||
// GetByTenantID gets all sessions visible to the tenant/user scope.
|
||||
GetByTenantID(ctx context.Context, tenantID uint64, userID string) ([]*types.Session, error)
|
||||
// GetPagedByTenantID gets paged sessions visible to the tenant/user scope.
|
||||
GetPagedByTenantID(ctx context.Context, tenantID uint64, userID string, page *types.Pagination) ([]*types.Session, int64, error)
|
||||
// QueryPaged lists sessions with filters, user-scoped ownership and pin-aware ordering.
|
||||
QueryPaged(ctx context.Context, q *types.SessionListQuery) ([]*types.SessionListItem, int64, error)
|
||||
// Update updates a session
|
||||
Update(ctx context.Context, session *types.Session) error
|
||||
// Update updates a session visible to the tenant/user scope.
|
||||
Update(ctx context.Context, session *types.Session, userID string) (int64, error)
|
||||
// SetPinned pins or unpins a session row scoped by tenant.
|
||||
// userID, when non-empty, is enforced so users cannot pin sessions they don't own.
|
||||
// Returns the number of rows affected; 0 means the session doesn't exist or is
|
||||
// not visible to this caller.
|
||||
SetPinned(ctx context.Context, tenantID uint64, userID string, id string, pinned bool) (int64, error)
|
||||
// Delete deletes a session
|
||||
Delete(ctx context.Context, tenantID uint64, id string) error
|
||||
// BatchDelete deletes multiple sessions by IDs
|
||||
BatchDelete(ctx context.Context, tenantID uint64, ids []string) error
|
||||
// DeleteAllByTenantID deletes all sessions for a tenant
|
||||
DeleteAllByTenantID(ctx context.Context, tenantID uint64) error
|
||||
// Delete deletes a session visible to the tenant/user scope.
|
||||
Delete(ctx context.Context, tenantID uint64, userID string, id string) (int64, error)
|
||||
// BatchDelete deletes multiple sessions visible to the tenant/user scope.
|
||||
BatchDelete(ctx context.Context, tenantID uint64, userID string, ids []string) (int64, error)
|
||||
// DeleteAllByTenantID deletes all sessions visible to the tenant/user scope.
|
||||
DeleteAllByTenantID(ctx context.Context, tenantID uint64, userID string) (int64, error)
|
||||
}
|
||||
|
||||
Reference in New Issue
Block a user