feat: add registry dual-map and engine factory for VectorStore

- Extend RetrieveEngineRegistry with byStoreID map for DB store engines
- Add StoreRegistry interface and EngineFactory function type
- Create engine factory with per-driver SDK client construction
- Load VectorStore records from DB into registry at startup
- Wire Create/Delete to dynamically register/unregister engines
- Add comprehensive unit tests for registry and service
This commit is contained in:
ochan.kwon
2026-04-15 12:44:52 +09:00
committed by lyingbug
parent 4f55994266
commit c5fc05f3d0
7 changed files with 868 additions and 32 deletions

View File

@@ -8,40 +8,51 @@ import (
"github.com/Tencent/WeKnora/internal/types/interfaces"
)
// RetrieveEngineRegistry implements the retrieval engine registry
// RetrieveEngineRegistry implements the retrieval engine registry.
// It maintains two maps:
// - byEngineType: env stores registered via RETRIEVE_DRIVER (backward compatible)
// - byStoreID: DB stores registered via VectorStore table (instance-based)
//
// Implements both interfaces.RetrieveEngineRegistry and interfaces.StoreRegistry.
type RetrieveEngineRegistry struct {
repositories map[types.RetrieverEngineType]interfaces.RetrieveEngineService
byEngineType map[types.RetrieverEngineType]interfaces.RetrieveEngineService
byStoreID map[string]interfaces.RetrieveEngineService
mu sync.RWMutex
}
// NewRetrieveEngineRegistry creates a new retrieval engine registry
func NewRetrieveEngineRegistry() interfaces.RetrieveEngineRegistry {
return &RetrieveEngineRegistry{
repositories: make(map[types.RetrieverEngineType]interfaces.RetrieveEngineService),
byEngineType: make(map[types.RetrieverEngineType]interfaces.RetrieveEngineService),
byStoreID: make(map[string]interfaces.RetrieveEngineService),
}
}
// Register registers a retrieval engine service
// --- interfaces.RetrieveEngineRegistry methods (unchanged behavior) ---
// Register registers a retrieval engine service by engine type.
// Returns an error if the engine type is already registered.
func (r *RetrieveEngineRegistry) Register(repo interfaces.RetrieveEngineService) error {
r.mu.Lock()
defer r.mu.Unlock()
if _, exists := r.repositories[repo.EngineType()]; exists {
if _, exists := r.byEngineType[repo.EngineType()]; exists {
return fmt.Errorf("repository type %s already registered", repo.EngineType())
}
r.repositories[repo.EngineType()] = repo
r.byEngineType[repo.EngineType()] = repo
return nil
}
// GetRetrieveEngineService retrieves a retrieval engine service by type
// GetRetrieveEngineService retrieves a retrieval engine service by type.
// Only searches the byEngineType map (env stores).
func (r *RetrieveEngineRegistry) GetRetrieveEngineService(repoType types.RetrieverEngineType) (
interfaces.RetrieveEngineService, error,
) {
r.mu.RLock()
defer r.mu.RUnlock()
repo, exists := r.repositories[repoType]
repo, exists := r.byEngineType[repoType]
if !exists {
return nil, fmt.Errorf("repository of type %s not found", repoType)
}
@@ -49,16 +60,55 @@ func (r *RetrieveEngineRegistry) GetRetrieveEngineService(repoType types.Retriev
return repo, nil
}
// GetAllRetrieveEngineServices retrieves all registered retrieval engine services
// GetAllRetrieveEngineServices retrieves all registered retrieval engine services.
// Only returns byEngineType entries (env stores) for backward compatibility.
func (r *RetrieveEngineRegistry) GetAllRetrieveEngineServices() []interfaces.RetrieveEngineService {
r.mu.RLock()
defer r.mu.RUnlock()
// Create a copy to avoid modifying the original map
result := make([]interfaces.RetrieveEngineService, 0, len(r.repositories))
for _, v := range r.repositories {
result := make([]interfaces.RetrieveEngineService, 0, len(r.byEngineType))
for _, v := range r.byEngineType {
result = append(result, v)
}
return result
}
// --- interfaces.StoreRegistry methods (new, for VectorStore-based engines) ---
// RegisterWithStoreID registers an engine service by VectorStore ID.
// Unlike Register(), the same EngineType can be registered multiple times
// with different StoreIDs (e.g., two Elasticsearch clusters).
// Upsert semantics: existing entry is overwritten silently.
func (r *RetrieveEngineRegistry) RegisterWithStoreID(storeID string, svc interfaces.RetrieveEngineService) {
r.mu.Lock()
defer r.mu.Unlock()
r.byStoreID[storeID] = svc
}
// GetByStoreID retrieves an engine service by VectorStore ID.
// Callers must verify tenant ownership before using the returned service.
func (r *RetrieveEngineRegistry) GetByStoreID(storeID string) (interfaces.RetrieveEngineService, error) {
r.mu.RLock()
defer r.mu.RUnlock()
svc, exists := r.byStoreID[storeID]
if !exists {
return nil, fmt.Errorf("store %s not found in registry", storeID)
}
return svc, nil
}
// UnregisterByStoreID removes an engine service from the byStoreID map.
// Idempotent: returns silently if the storeID is not found.
//
// NOTE: gRPC-based clients (Qdrant, Milvus) hold connections that are not closed here.
// Known Phase 1 limitation — store deletion is rare, connections cleaned up on process exit.
// Phase 2 should add Close() to RetrieveEngineService interface and call it here.
func (r *RetrieveEngineRegistry) UnregisterByStoreID(storeID string) {
r.mu.Lock()
defer r.mu.Unlock()
delete(r.byStoreID, storeID)
}

View File

@@ -0,0 +1,237 @@
package retriever
import (
"context"
"fmt"
"sync"
"testing"
"github.com/Tencent/WeKnora/internal/models/embedding"
"github.com/Tencent/WeKnora/internal/types"
"github.com/Tencent/WeKnora/internal/types/interfaces"
"github.com/stretchr/testify/assert"
"github.com/stretchr/testify/require"
)
// mockEngineService is a minimal mock for testing registry operations.
// Only EngineType() is meaningful; all other methods are no-ops.
type mockEngineService struct {
engineType types.RetrieverEngineType
}
func (m *mockEngineService) EngineType() types.RetrieverEngineType { return m.engineType }
func (m *mockEngineService) Retrieve(_ context.Context, _ types.RetrieveParams) ([]*types.RetrieveResult, error) {
return nil, nil
}
func (m *mockEngineService) Support() []types.RetrieverType { return nil }
func (m *mockEngineService) Index(_ context.Context, _ embedding.Embedder, _ *types.IndexInfo, _ []types.RetrieverType) error {
return nil
}
func (m *mockEngineService) BatchIndex(_ context.Context, _ embedding.Embedder, _ []*types.IndexInfo, _ []types.RetrieverType) error {
return nil
}
func (m *mockEngineService) EstimateStorageSize(_ context.Context, _ embedding.Embedder, _ []*types.IndexInfo, _ []types.RetrieverType) int64 {
return 0
}
func (m *mockEngineService) CopyIndices(_ context.Context, _ string, _ map[string]string, _ map[string]string, _ string, _ int, _ string) error {
return nil
}
func (m *mockEngineService) DeleteByChunkIDList(_ context.Context, _ []string, _ int, _ string) error {
return nil
}
func (m *mockEngineService) DeleteBySourceIDList(_ context.Context, _ []string, _ int, _ string) error {
return nil
}
func (m *mockEngineService) DeleteByKnowledgeIDList(_ context.Context, _ []string, _ int, _ string) error {
return nil
}
func (m *mockEngineService) BatchUpdateChunkEnabledStatus(_ context.Context, _ map[string]bool) error {
return nil
}
func (m *mockEngineService) BatchUpdateChunkTagID(_ context.Context, _ map[string]string) error {
return nil
}
func newMock(engineType types.RetrieverEngineType) interfaces.RetrieveEngineService {
return &mockEngineService{engineType: engineType}
}
// --- Register (byEngineType) tests ---
func TestRegistry_Register(t *testing.T) {
reg := NewRetrieveEngineRegistry().(*RetrieveEngineRegistry)
t.Run("success", func(t *testing.T) {
err := reg.Register(newMock(types.PostgresRetrieverEngineType))
assert.NoError(t, err)
})
t.Run("duplicate engine type returns error", func(t *testing.T) {
err := reg.Register(newMock(types.PostgresRetrieverEngineType))
assert.Error(t, err)
assert.Contains(t, err.Error(), "already registered")
})
}
func TestRegistry_GetRetrieveEngineService(t *testing.T) {
reg := NewRetrieveEngineRegistry().(*RetrieveEngineRegistry)
_ = reg.Register(newMock(types.PostgresRetrieverEngineType))
t.Run("found", func(t *testing.T) {
svc, err := reg.GetRetrieveEngineService(types.PostgresRetrieverEngineType)
assert.NoError(t, err)
assert.NotNil(t, svc)
})
t.Run("not found", func(t *testing.T) {
_, err := reg.GetRetrieveEngineService(types.QdrantRetrieverEngineType)
assert.Error(t, err)
assert.Contains(t, err.Error(), "not found")
})
}
func TestRegistry_GetAllRetrieveEngineServices(t *testing.T) {
reg := NewRetrieveEngineRegistry().(*RetrieveEngineRegistry)
_ = reg.Register(newMock(types.PostgresRetrieverEngineType))
_ = reg.Register(newMock(types.ElasticsearchRetrieverEngineType))
t.Run("returns all byEngineType entries", func(t *testing.T) {
all := reg.GetAllRetrieveEngineServices()
assert.Len(t, all, 2)
})
t.Run("returns copy - modifying result does not affect registry", func(t *testing.T) {
all := reg.GetAllRetrieveEngineServices()
all = append(all, newMock(types.QdrantRetrieverEngineType))
assert.Len(t, reg.GetAllRetrieveEngineServices(), 2)
})
}
// --- RegisterWithStoreID (byStoreID) tests ---
func TestRegistry_RegisterWithStoreID(t *testing.T) {
reg := NewRetrieveEngineRegistry().(*RetrieveEngineRegistry)
t.Run("success", func(t *testing.T) {
reg.RegisterWithStoreID("store-1", newMock(types.PostgresRetrieverEngineType))
svc, err := reg.GetByStoreID("store-1")
assert.NoError(t, err)
assert.NotNil(t, svc)
})
t.Run("upsert overwrites existing", func(t *testing.T) {
newSvc := newMock(types.ElasticsearchRetrieverEngineType)
reg.RegisterWithStoreID("store-1", newSvc)
svc, err := reg.GetByStoreID("store-1")
assert.NoError(t, err)
assert.Equal(t, types.ElasticsearchRetrieverEngineType, svc.EngineType())
})
t.Run("same engine type different store IDs", func(t *testing.T) {
reg.RegisterWithStoreID("es-hot", newMock(types.ElasticsearchRetrieverEngineType))
reg.RegisterWithStoreID("es-warm", newMock(types.ElasticsearchRetrieverEngineType))
svc1, err1 := reg.GetByStoreID("es-hot")
svc2, err2 := reg.GetByStoreID("es-warm")
assert.NoError(t, err1)
assert.NoError(t, err2)
assert.NotSame(t, svc1, svc2)
})
}
func TestRegistry_GetByStoreID(t *testing.T) {
reg := NewRetrieveEngineRegistry().(*RetrieveEngineRegistry)
reg.RegisterWithStoreID("store-1", newMock(types.PostgresRetrieverEngineType))
t.Run("found", func(t *testing.T) {
svc, err := reg.GetByStoreID("store-1")
assert.NoError(t, err)
assert.NotNil(t, svc)
})
t.Run("not found", func(t *testing.T) {
_, err := reg.GetByStoreID("nonexistent")
assert.Error(t, err)
assert.Contains(t, err.Error(), "not found")
})
}
func TestRegistry_UnregisterByStoreID(t *testing.T) {
reg := NewRetrieveEngineRegistry().(*RetrieveEngineRegistry)
reg.RegisterWithStoreID("store-1", newMock(types.PostgresRetrieverEngineType))
t.Run("removes registered store", func(t *testing.T) {
reg.UnregisterByStoreID("store-1")
_, err := reg.GetByStoreID("store-1")
assert.Error(t, err)
})
t.Run("idempotent on nonexistent store", func(t *testing.T) {
reg.UnregisterByStoreID("nonexistent") // should not panic
})
}
// --- Dual map isolation tests ---
func TestRegistry_DualMapIsolation(t *testing.T) {
reg := NewRetrieveEngineRegistry().(*RetrieveEngineRegistry)
_ = reg.Register(newMock(types.PostgresRetrieverEngineType))
reg.RegisterWithStoreID("store-pg", newMock(types.PostgresRetrieverEngineType))
reg.RegisterWithStoreID("store-es", newMock(types.ElasticsearchRetrieverEngineType))
t.Run("GetAllRetrieveEngineServices returns only byEngineType", func(t *testing.T) {
all := reg.GetAllRetrieveEngineServices()
assert.Len(t, all, 1)
})
t.Run("byStoreID does not affect byEngineType lookup", func(t *testing.T) {
_, err := reg.GetRetrieveEngineService(types.ElasticsearchRetrieverEngineType)
assert.Error(t, err) // ES is only in byStoreID, not byEngineType
})
t.Run("unregister byStoreID does not affect byEngineType", func(t *testing.T) {
reg.UnregisterByStoreID("store-pg")
svc, err := reg.GetRetrieveEngineService(types.PostgresRetrieverEngineType)
assert.NoError(t, err)
assert.NotNil(t, svc)
})
}
// --- Concurrency test ---
func TestRegistry_ConcurrentAccess(t *testing.T) {
reg := NewRetrieveEngineRegistry().(*RetrieveEngineRegistry)
const goroutines = 10
var wg sync.WaitGroup
wg.Add(goroutines * 3)
for i := 0; i < goroutines; i++ {
storeID := fmt.Sprintf("store-%d", i)
go func() {
defer wg.Done()
reg.RegisterWithStoreID(storeID, newMock(types.PostgresRetrieverEngineType))
}()
go func() {
defer wg.Done()
_, _ = reg.GetByStoreID(storeID)
}()
go func() {
defer wg.Done()
reg.UnregisterByStoreID(storeID)
}()
}
wg.Wait()
}
// --- Interface compliance ---
func TestRegistry_ImplementsStoreRegistry(t *testing.T) {
reg := NewRetrieveEngineRegistry()
concreteReg, ok := reg.(*RetrieveEngineRegistry)
require.True(t, ok)
var _ interfaces.StoreRegistry = concreteReg
}

View File

@@ -3,6 +3,7 @@ package service
import (
"context"
"os"
"time"
"github.com/Tencent/WeKnora/internal/errors"
"github.com/Tencent/WeKnora/internal/logger"
@@ -13,14 +14,22 @@ import (
// vectorStoreService implements interfaces.VectorStoreService
type vectorStoreService struct {
repo interfaces.VectorStoreRepository
repo interfaces.VectorStoreRepository
storeRegistry interfaces.StoreRegistry // for dynamic registry updates on CRUD
factory interfaces.EngineFactory // creates engine services from VectorStore config
}
// NewVectorStoreService creates a new vector store service
func NewVectorStoreService(
repo interfaces.VectorStoreRepository,
storeRegistry interfaces.StoreRegistry,
factory interfaces.EngineFactory,
) interfaces.VectorStoreService {
return &vectorStoreService{repo: repo}
return &vectorStoreService{
repo: repo,
storeRegistry: storeRegistry,
factory: factory,
}
}
// CreateStore validates and creates a new vector store.
@@ -60,10 +69,20 @@ func (s *vectorStoreService) CreateStore(ctx context.Context, store *types.Vecto
// 5. Persist
logger.Infof(ctx, "Creating vector store: tenant=%d, name=%s, engine=%s",
store.TenantID, secutils.SanitizeForLog(store.Name), store.EngineType)
return s.repo.Create(ctx, store)
if err := s.repo.Create(ctx, store); err != nil {
return err
}
// 6. Register in registry (best-effort; failure doesn't roll back DB).
// The store is already persisted, and will be loaded on next app restart (self-healing).
s.registerInRegistry(ctx, store)
return nil
}
// UpdateStore updates an existing vector store (name only).
// NOTE: If connection_config or index_config become mutable in the future,
// registry re-registration must be added here (unregister old + register new).
func (s *vectorStoreService) UpdateStore(ctx context.Context, store *types.VectorStore) error {
if store.TenantID == 0 {
return errors.NewValidationError("tenant_id is required")
@@ -79,8 +98,17 @@ func (s *vectorStoreService) UpdateStore(ctx context.Context, store *types.Vecto
// DeleteStore deletes a vector store by tenant + id.
// Phase 2: KB binding check will be added here.
func (s *vectorStoreService) DeleteStore(ctx context.Context, tenantID uint64, id string) error {
logger.Infof(ctx, "Deleting vector store: tenant=%d, id=%s", tenantID, id)
return s.repo.Delete(ctx, tenantID, id)
if err := s.repo.Delete(ctx, tenantID, id); err != nil {
return err
}
// Unregister from registry (idempotent)
if s.storeRegistry != nil {
s.storeRegistry.UnregisterByStoreID(id)
}
logger.Infof(ctx, "Deleted vector store: tenant=%d, id=%s", tenantID, id)
return nil
}
// SaveDetectedVersion updates the connection_config.version for a stored vector store.
@@ -91,6 +119,28 @@ func (s *vectorStoreService) SaveDetectedVersion(ctx context.Context, store *typ
return s.repo.UpdateConnectionConfig(ctx, &updated)
}
// registerInRegistry creates an engine service and registers it in the registry.
// Logs and skips on failure — the store is already persisted in DB,
// and will be loaded on next app restart (self-healing).
func (s *vectorStoreService) registerInRegistry(ctx context.Context, store *types.VectorStore) {
if s.storeRegistry == nil || s.factory == nil {
return
}
// Use a short timeout for engine creation to avoid blocking on unreachable hosts
// (e.g., gRPC dial to Qdrant/Milvus). The store is already persisted in DB,
// so it will be loaded on next app restart if this times out.
factoryCtx, cancel := context.WithTimeout(ctx, 10*time.Second)
defer cancel()
svc, err := s.factory(factoryCtx, *store)
if err != nil {
logger.Warnf(ctx, "Failed to create engine for store %s, will be available after restart: %v", store.ID, err)
return
}
s.storeRegistry.RegisterWithStoreID(store.ID, svc)
}
// validateConnectionConfig validates required fields per engine type.
func validateConnectionConfig(engineType types.RetrieverEngineType, config types.ConnectionConfig) error {
switch engineType {

View File

@@ -5,7 +5,9 @@ import (
"testing"
"github.com/Tencent/WeKnora/internal/errors"
"github.com/Tencent/WeKnora/internal/models/embedding"
"github.com/Tencent/WeKnora/internal/types"
"github.com/Tencent/WeKnora/internal/types/interfaces"
"github.com/stretchr/testify/assert"
"github.com/stretchr/testify/require"
)
@@ -71,13 +73,88 @@ func (m *mockVectorStoreRepo) ExistsByEndpointAndIndex(
return m.existsByEndpoint, nil
}
// ---------------------------------------------------------------------------
// Mock StoreRegistry
// ---------------------------------------------------------------------------
type mockStoreRegistry struct {
registered map[string]bool
unregistered []string
}
func newMockStoreRegistry() *mockStoreRegistry {
return &mockStoreRegistry{registered: make(map[string]bool)}
}
func (m *mockStoreRegistry) RegisterWithStoreID(storeID string, _ interfaces.RetrieveEngineService) {
m.registered[storeID] = true
}
func (m *mockStoreRegistry) GetByStoreID(storeID string) (interfaces.RetrieveEngineService, error) {
return nil, nil
}
func (m *mockStoreRegistry) UnregisterByStoreID(storeID string) {
m.unregistered = append(m.unregistered, storeID)
delete(m.registered, storeID)
}
// ---------------------------------------------------------------------------
// Mock EngineFactory
// ---------------------------------------------------------------------------
func mockEngineFactory(err error) interfaces.EngineFactory {
return func(_ context.Context, _ types.VectorStore) (interfaces.RetrieveEngineService, error) {
if err != nil {
return nil, err
}
return &mockEngineService{}, nil
}
}
// mockEngineService satisfies interfaces.RetrieveEngineService minimally.
type mockEngineService struct{}
func (m *mockEngineService) EngineType() types.RetrieverEngineType { return "mock" }
func (m *mockEngineService) Retrieve(_ context.Context, _ types.RetrieveParams) ([]*types.RetrieveResult, error) {
return nil, nil
}
func (m *mockEngineService) Support() []types.RetrieverType { return nil }
func (m *mockEngineService) Index(_ context.Context, _ embedding.Embedder, _ *types.IndexInfo, _ []types.RetrieverType) error {
return nil
}
func (m *mockEngineService) BatchIndex(_ context.Context, _ embedding.Embedder, _ []*types.IndexInfo, _ []types.RetrieverType) error {
return nil
}
func (m *mockEngineService) EstimateStorageSize(_ context.Context, _ embedding.Embedder, _ []*types.IndexInfo, _ []types.RetrieverType) int64 {
return 0
}
func (m *mockEngineService) CopyIndices(_ context.Context, _ string, _ map[string]string, _ map[string]string, _ string, _ int, _ string) error {
return nil
}
func (m *mockEngineService) DeleteByChunkIDList(_ context.Context, _ []string, _ int, _ string) error {
return nil
}
func (m *mockEngineService) DeleteBySourceIDList(_ context.Context, _ []string, _ int, _ string) error {
return nil
}
func (m *mockEngineService) DeleteByKnowledgeIDList(_ context.Context, _ []string, _ int, _ string) error {
return nil
}
func (m *mockEngineService) BatchUpdateChunkEnabledStatus(_ context.Context, _ map[string]bool) error {
return nil
}
func (m *mockEngineService) BatchUpdateChunkTagID(_ context.Context, _ map[string]string) error {
return nil
}
// ---------------------------------------------------------------------------
// CreateStore tests
// ---------------------------------------------------------------------------
func TestCreateStore_Success(t *testing.T) {
repo := &mockVectorStoreRepo{}
svc := NewVectorStoreService(repo)
svc := NewVectorStoreService(repo, nil, nil)
store := &types.VectorStore{
TenantID: 1,
@@ -95,7 +172,7 @@ func TestCreateStore_Success(t *testing.T) {
func TestCreateStore_ValidationError(t *testing.T) {
repo := &mockVectorStoreRepo{}
svc := NewVectorStoreService(repo)
svc := NewVectorStoreService(repo, nil, nil)
tests := []struct {
name string
@@ -127,7 +204,7 @@ func TestCreateStore_ValidationError(t *testing.T) {
func TestCreateStore_ConnectionConfigValidation(t *testing.T) {
repo := &mockVectorStoreRepo{}
svc := NewVectorStoreService(repo)
svc := NewVectorStoreService(repo, nil, nil)
tests := []struct {
name string
@@ -213,7 +290,7 @@ func TestCreateStore_ConnectionConfigValidation(t *testing.T) {
func TestCreateStore_DuplicateCheck_DBStore(t *testing.T) {
repo := &mockVectorStoreRepo{existsByEndpoint: true}
svc := NewVectorStoreService(repo)
svc := NewVectorStoreService(repo, nil, nil)
store := &types.VectorStore{
TenantID: 1,
@@ -236,7 +313,7 @@ func TestCreateStore_DuplicateCheck_DBError(t *testing.T) {
repo := &mockVectorStoreRepo{
existsByEndpointErr: assert.AnError,
}
svc := NewVectorStoreService(repo)
svc := NewVectorStoreService(repo, nil, nil)
store := &types.VectorStore{
TenantID: 1,
@@ -260,7 +337,7 @@ func TestCreateStore_DuplicateCheck_EnvStore(t *testing.T) {
t.Setenv("ELASTICSEARCH_INDEX", "xwrag_default")
repo := &mockVectorStoreRepo{existsByEndpoint: false} // no DB duplicate
svc := NewVectorStoreService(repo)
svc := NewVectorStoreService(repo, nil, nil)
store := &types.VectorStore{
TenantID: 1,
@@ -290,7 +367,7 @@ func TestCreateStore_DuplicateCheck_EnvStore_DifferentIndex_Allowed(t *testing.T
t.Setenv("ELASTICSEARCH_INDEX", "xwrag_default")
repo := &mockVectorStoreRepo{existsByEndpoint: false}
svc := NewVectorStoreService(repo)
svc := NewVectorStoreService(repo, nil, nil)
store := &types.VectorStore{
TenantID: 1,
@@ -310,7 +387,7 @@ func TestCreateStore_DuplicateCheck_EnvStore_DifferentIndex_Allowed(t *testing.T
func TestCreateStore_DifferentEndpointSameIndex_Allowed(t *testing.T) {
repo := &mockVectorStoreRepo{existsByEndpoint: false}
svc := NewVectorStoreService(repo)
svc := NewVectorStoreService(repo, nil, nil)
store := &types.VectorStore{
TenantID: 1,
@@ -328,13 +405,125 @@ func TestCreateStore_DifferentEndpointSameIndex_Allowed(t *testing.T) {
assert.NoError(t, err)
}
// ---------------------------------------------------------------------------
// CreateStore + Registry integration tests
// ---------------------------------------------------------------------------
func TestCreateStore_RegistersInRegistry(t *testing.T) {
repo := &mockVectorStoreRepo{}
registry := newMockStoreRegistry()
factory := mockEngineFactory(nil)
svc := NewVectorStoreService(repo, registry, factory)
store := &types.VectorStore{
TenantID: 1,
Name: "test-es",
EngineType: types.ElasticsearchRetrieverEngineType,
ConnectionConfig: types.ConnectionConfig{
Addr: "http://es:9200",
},
}
err := svc.CreateStore(context.Background(), store)
require.NoError(t, err)
// Store should be persisted AND registered in registry
assert.Len(t, repo.stores, 1)
assert.True(t, registry.registered[store.ID])
}
func TestCreateStore_RegistryFailureDoesNotRollBackDB(t *testing.T) {
repo := &mockVectorStoreRepo{}
registry := newMockStoreRegistry()
factory := mockEngineFactory(assert.AnError) // factory fails
svc := NewVectorStoreService(repo, registry, factory)
store := &types.VectorStore{
TenantID: 1,
Name: "test-es",
EngineType: types.ElasticsearchRetrieverEngineType,
ConnectionConfig: types.ConnectionConfig{
Addr: "http://es:9200",
},
}
// CreateStore should succeed even if registry fails (best-effort + self-healing)
err := svc.CreateStore(context.Background(), store)
assert.NoError(t, err)
// DB should have the store
assert.Len(t, repo.stores, 1)
// Registry should NOT have it (factory failed)
assert.False(t, registry.registered[store.ID])
}
func TestCreateStore_NilRegistryAndFactory(t *testing.T) {
repo := &mockVectorStoreRepo{}
svc := NewVectorStoreService(repo, nil, nil) // no registry
store := &types.VectorStore{
TenantID: 1,
Name: "test-es",
EngineType: types.ElasticsearchRetrieverEngineType,
ConnectionConfig: types.ConnectionConfig{
Addr: "http://es:9200",
},
}
// Should work fine without registry (degrades gracefully)
err := svc.CreateStore(context.Background(), store)
assert.NoError(t, err)
assert.Len(t, repo.stores, 1)
}
// ---------------------------------------------------------------------------
// DeleteStore + Registry integration tests
// ---------------------------------------------------------------------------
func TestDeleteStore_UnregistersFromRegistry(t *testing.T) {
repo := &mockVectorStoreRepo{}
registry := newMockStoreRegistry()
registry.registered["store-1"] = true
svc := NewVectorStoreService(repo, registry, nil)
err := svc.DeleteStore(context.Background(), 1, "store-1")
require.NoError(t, err)
// Should be unregistered
assert.Contains(t, registry.unregistered, "store-1")
assert.False(t, registry.registered["store-1"])
}
func TestDeleteStore_NilRegistryGraceful(t *testing.T) {
repo := &mockVectorStoreRepo{}
svc := NewVectorStoreService(repo, nil, nil)
// Should not panic with nil registry
err := svc.DeleteStore(context.Background(), 1, "store-1")
assert.NoError(t, err)
}
func TestDeleteStore_RepoErrorSkipsUnregister(t *testing.T) {
repo := &mockVectorStoreRepo{deleteErr: assert.AnError}
registry := newMockStoreRegistry()
registry.registered["store-1"] = true
svc := NewVectorStoreService(repo, registry, nil)
err := svc.DeleteStore(context.Background(), 1, "store-1")
assert.Error(t, err)
// Registry should NOT be touched if DB delete fails
assert.True(t, registry.registered["store-1"])
assert.Empty(t, registry.unregistered)
}
// ---------------------------------------------------------------------------
// UpdateStore tests
// ---------------------------------------------------------------------------
func TestUpdateStore_Success(t *testing.T) {
repo := &mockVectorStoreRepo{}
svc := NewVectorStoreService(repo)
svc := NewVectorStoreService(repo, nil, nil)
store := &types.VectorStore{
ID: "test-id",
@@ -348,7 +537,7 @@ func TestUpdateStore_Success(t *testing.T) {
func TestUpdateStore_ValidationError(t *testing.T) {
repo := &mockVectorStoreRepo{}
svc := NewVectorStoreService(repo)
svc := NewVectorStoreService(repo, nil, nil)
tests := []struct {
name string
@@ -378,7 +567,7 @@ func TestUpdateStore_ValidationError(t *testing.T) {
func TestDeleteStore_Success(t *testing.T) {
repo := &mockVectorStoreRepo{}
svc := NewVectorStoreService(repo)
svc := NewVectorStoreService(repo, nil, nil)
err := svc.DeleteStore(context.Background(), 1, "test-id")
assert.NoError(t, err)
@@ -386,19 +575,63 @@ func TestDeleteStore_Success(t *testing.T) {
func TestDeleteStore_RepoError(t *testing.T) {
repo := &mockVectorStoreRepo{deleteErr: assert.AnError}
svc := NewVectorStoreService(repo)
svc := NewVectorStoreService(repo, nil, nil)
err := svc.DeleteStore(context.Background(), 1, "test-id")
assert.Error(t, err)
}
// ---------------------------------------------------------------------------
// SaveDetectedVersion tests
// ---------------------------------------------------------------------------
func TestSaveDetectedVersion_Success(t *testing.T) {
repo := &mockVectorStoreRepo{}
svc := NewVectorStoreService(repo, nil, nil)
store := &types.VectorStore{
ID: "store-1",
TenantID: 1,
ConnectionConfig: types.ConnectionConfig{Addr: "http://es:9200"},
}
err := svc.SaveDetectedVersion(context.Background(), store, "7.10.1")
assert.NoError(t, err)
}
func TestSaveDetectedVersion_RepoError(t *testing.T) {
repo := &mockVectorStoreRepo{updateErr: assert.AnError}
svc := NewVectorStoreService(repo, nil, nil)
store := &types.VectorStore{ID: "store-1", TenantID: 1}
err := svc.SaveDetectedVersion(context.Background(), store, "8.11.0")
assert.Error(t, err)
}
func TestSaveDetectedVersion_DoesNotMutateOriginal(t *testing.T) {
repo := &mockVectorStoreRepo{}
svc := NewVectorStoreService(repo, nil, nil)
store := &types.VectorStore{
ID: "store-1",
TenantID: 1,
ConnectionConfig: types.ConnectionConfig{Version: "old"},
}
err := svc.SaveDetectedVersion(context.Background(), store, "new")
require.NoError(t, err)
// Original store must not be mutated
assert.Equal(t, "old", store.ConnectionConfig.Version)
}
// ---------------------------------------------------------------------------
// TestConnection tests
// ---------------------------------------------------------------------------
func TestTestConnection_UnsupportedEngineType(t *testing.T) {
repo := &mockVectorStoreRepo{}
svc := NewVectorStoreService(repo)
svc := NewVectorStoreService(repo, nil, nil)
_, err := svc.TestConnection(context.Background(), "unknown_engine", types.ConnectionConfig{})
require.Error(t, err)
@@ -410,7 +643,7 @@ func TestTestConnection_UnsupportedEngineType(t *testing.T) {
func TestTestConnection_SQLiteAlwaysSucceeds(t *testing.T) {
repo := &mockVectorStoreRepo{}
svc := NewVectorStoreService(repo)
svc := NewVectorStoreService(repo, nil, nil)
version, err := svc.TestConnection(context.Background(), types.SQLiteRetrieverEngineType, types.ConnectionConfig{})
assert.NoError(t, err)
@@ -419,7 +652,7 @@ func TestTestConnection_SQLiteAlwaysSucceeds(t *testing.T) {
func TestTestConnection_PostgresDefaultConnection(t *testing.T) {
repo := &mockVectorStoreRepo{}
svc := NewVectorStoreService(repo)
svc := NewVectorStoreService(repo, nil, nil)
version, err := svc.TestConnection(context.Background(), types.PostgresRetrieverEngineType,
types.ConnectionConfig{UseDefaultConnection: true})

View File

@@ -191,6 +191,16 @@ func BuildContainer(container *dig.Container) *dig.Container {
must(container.Provide(repository.NewVectorStoreRepository))
must(container.Provide(service.NewWebSearchService))
must(container.Provide(service.NewWebSearchProviderService))
must(container.Provide(NewEngineFactory))
// StoreRegistry: same instance as RetrieveEngineRegistry, exposed as StoreRegistry interface.
// NewRetrieveEngineRegistry always returns *retriever.RetrieveEngineRegistry which implements both.
must(container.Provide(func(r interfaces.RetrieveEngineRegistry) (interfaces.StoreRegistry, error) {
sr, ok := r.(*retriever.RetrieveEngineRegistry)
if !ok {
return nil, fmt.Errorf("registry does not implement StoreRegistry")
}
return sr, nil
}))
must(container.Provide(service.NewVectorStoreService))
// Agent service layer (requires event bus, web search service)
@@ -923,9 +933,43 @@ func initRetrieveEngineRegistry(db *gorm.DB, cfg *config.Config) (interfaces.Ret
}
}
}
// ─── DB store registration (byStoreID) ───
if storeReg, ok := registry.(*retriever.RetrieveEngineRegistry); ok {
loadDBStoresIntoRegistry(storeReg, db, cfg)
}
return registry, nil
}
// loadDBStoresIntoRegistry loads VectorStore records from DB and registers them
// in the registry's byStoreID map. Failures are logged and skipped (non-fatal).
func loadDBStoresIntoRegistry(storeRegistry interfaces.StoreRegistry, db *gorm.DB, cfg *config.Config) {
ctx := context.Background()
log := logger.GetLogger(ctx)
var stores []types.VectorStore
// GORM soft delete automatically adds "deleted_at IS NULL" condition
if err := db.Find(&stores).Error; err != nil {
log.Warnf("Failed to load vector stores from DB: %v", err)
return
}
if len(stores) == 0 {
return
}
log.Infof("Loading %d vector store(s) from database", len(stores))
for _, store := range stores {
svc, err := createEngineServiceFromStore(ctx, store, db, cfg)
if err != nil {
log.Errorf("Failed to create engine for store %s (%s): %v", store.ID, store.Name, err)
continue
}
storeRegistry.RegisterWithStoreID(store.ID, svc)
log.Infof("Registered DB vector store: id=%s, name=%s, engine=%s", store.ID, store.Name, store.EngineType)
}
}
// initAntsPool initializes the goroutine pool
// Creates a managed goroutine pool for concurrent task execution
// Parameters:

View File

@@ -0,0 +1,205 @@
package container
import (
"context"
"fmt"
"strings"
"time"
esv7 "github.com/elastic/go-elasticsearch/v7"
"github.com/elastic/go-elasticsearch/v8"
"github.com/milvus-io/milvus/client/v2/milvusclient"
"github.com/qdrant/go-client/qdrant"
"github.com/weaviate/weaviate-go-client/v5/weaviate"
"github.com/weaviate/weaviate-go-client/v5/weaviate/auth"
wgrpc "github.com/weaviate/weaviate-go-client/v5/weaviate/grpc"
"google.golang.org/grpc"
"gorm.io/gorm"
elasticsearchRepoV7 "github.com/Tencent/WeKnora/internal/application/repository/retriever/elasticsearch/v7"
elasticsearchRepoV8 "github.com/Tencent/WeKnora/internal/application/repository/retriever/elasticsearch/v8"
milvusRepo "github.com/Tencent/WeKnora/internal/application/repository/retriever/milvus"
postgresRepo "github.com/Tencent/WeKnora/internal/application/repository/retriever/postgres"
qdrantRepo "github.com/Tencent/WeKnora/internal/application/repository/retriever/qdrant"
sqliteRetrieverRepo "github.com/Tencent/WeKnora/internal/application/repository/retriever/sqlite"
weaviateRepo "github.com/Tencent/WeKnora/internal/application/repository/retriever/weaviate"
"github.com/Tencent/WeKnora/internal/application/service/retriever"
"github.com/Tencent/WeKnora/internal/config"
"github.com/Tencent/WeKnora/internal/types"
"github.com/Tencent/WeKnora/internal/types/interfaces"
)
// NewEngineFactory returns an EngineFactory function closed over db and cfg.
// Registered in dig and injected into VectorStoreService for dynamic registry updates.
func NewEngineFactory(db *gorm.DB, cfg *config.Config) interfaces.EngineFactory {
return func(ctx context.Context, store types.VectorStore) (interfaces.RetrieveEngineService, error) {
return createEngineServiceFromStore(ctx, store, db, cfg)
}
}
// createEngineServiceFromStore creates a RetrieveEngineService from a VectorStore's config.
// This is the DB store counterpart of the env-based initialization in initRetrieveEngineRegistry.
func createEngineServiceFromStore(
ctx context.Context,
store types.VectorStore,
db *gorm.DB,
cfg *config.Config,
) (interfaces.RetrieveEngineService, error) {
switch store.EngineType {
case types.PostgresRetrieverEngineType:
return createPostgresEngine(store, db)
case types.ElasticsearchRetrieverEngineType:
return createElasticsearchEngine(store, cfg)
case types.QdrantRetrieverEngineType:
return createQdrantEngine(store)
case types.MilvusRetrieverEngineType:
return createMilvusEngine(ctx, store)
case types.WeaviateRetrieverEngineType:
return createWeaviateEngine(store)
case types.SQLiteRetrieverEngineType:
return createSQLiteEngine(store, db)
default:
return nil, fmt.Errorf("unsupported engine type: %s", store.EngineType)
}
}
func createPostgresEngine(store types.VectorStore, db *gorm.DB) (interfaces.RetrieveEngineService, error) {
if store.ConnectionConfig.UseDefaultConnection {
repo := postgresRepo.NewPostgresRetrieveEngineRepository(db)
return retriever.NewKVHybridRetrieveEngine(repo, types.PostgresRetrieverEngineType), nil
}
// Phase 1: only UseDefaultConnection is supported.
// Custom connections require connection pool management and migration handling.
return nil, fmt.Errorf("custom postgres connections not yet supported; use use_default_connection=true")
}
func createSQLiteEngine(_ types.VectorStore, db *gorm.DB) (interfaces.RetrieveEngineService, error) {
repo := sqliteRetrieverRepo.NewSQLiteRetrieveEngineRepository(db)
return retriever.NewKVHybridRetrieveEngine(repo, types.SQLiteRetrieverEngineType), nil
}
func createElasticsearchEngine(store types.VectorStore, cfg *config.Config) (interfaces.RetrieveEngineService, error) {
cc := store.ConnectionConfig
// Version-based v7/v8 SDK selection.
// Version is auto-detected by PR2's TestConnection and saved to connection_config.
// Empty version defaults to v8 (latest SDK).
if isESv7(cc.Version) {
return createElasticsearchV7Engine(cc, cfg)
}
return createElasticsearchV8Engine(cc, cfg)
}
// isESv7 checks if the detected ES version is 7.x.
func isESv7(version string) bool {
return strings.HasPrefix(version, "7.")
}
func createElasticsearchV8Engine(cc types.ConnectionConfig, cfg *config.Config) (interfaces.RetrieveEngineService, error) {
client, err := elasticsearch.NewTypedClient(elasticsearch.Config{
Addresses: []string{cc.Addr},
Username: cc.Username,
Password: cc.Password,
})
if err != nil {
return nil, fmt.Errorf("create elasticsearch v8 client: %w", err)
}
repo := elasticsearchRepoV8.NewElasticsearchEngineRepository(client, cfg)
return retriever.NewKVHybridRetrieveEngine(repo, types.ElasticsearchRetrieverEngineType), nil
}
func createElasticsearchV7Engine(cc types.ConnectionConfig, cfg *config.Config) (interfaces.RetrieveEngineService, error) {
client, err := esv7.NewClient(esv7.Config{
Addresses: []string{cc.Addr},
Username: cc.Username,
Password: cc.Password,
})
if err != nil {
return nil, fmt.Errorf("create elasticsearch v7 client: %w", err)
}
repo := elasticsearchRepoV7.NewElasticsearchEngineRepository(client, cfg)
return retriever.NewKVHybridRetrieveEngine(repo, types.ElasticsearchRetrieverEngineType), nil
}
func createQdrantEngine(store types.VectorStore) (interfaces.RetrieveEngineService, error) {
cc := store.ConnectionConfig
port := cc.Port
if port == 0 {
port = 6334
}
client, err := qdrant.NewClient(&qdrant.Config{
Host: cc.Host,
Port: port,
APIKey: cc.APIKey,
UseTLS: cc.UseTLS,
})
if err != nil {
return nil, fmt.Errorf("create qdrant client: %w", err)
}
repo := qdrantRepo.NewQdrantRetrieveEngineRepository(client)
return retriever.NewKVHybridRetrieveEngine(repo, types.QdrantRetrieverEngineType), nil
}
func createMilvusEngine(ctx context.Context, store types.VectorStore) (interfaces.RetrieveEngineService, error) {
cc := store.ConnectionConfig
addr := cc.Addr
if addr == "" {
addr = "localhost:19530"
}
milvusCfg := milvusclient.ClientConfig{
Address: addr,
DialOptions: []grpc.DialOption{grpc.WithTimeout(5 * time.Second)},
}
if cc.Username != "" {
milvusCfg.Username = cc.Username
}
if cc.Password != "" {
milvusCfg.Password = cc.Password
}
// NOTE: Milvus DBName is not yet in ConnectionConfig.
// Phase 1 limitation — only the default database is used.
client, err := milvusclient.New(ctx, &milvusCfg)
if err != nil {
return nil, fmt.Errorf("create milvus client: %w", err)
}
repo := milvusRepo.NewMilvusRetrieveEngineRepository(client)
return retriever.NewKVHybridRetrieveEngine(repo, types.MilvusRetrieverEngineType), nil
}
func createWeaviateEngine(store types.VectorStore) (interfaces.RetrieveEngineService, error) {
cc := store.ConnectionConfig
host := cc.Host
if host == "" {
host = "weaviate:8080"
}
grpcAddress := cc.GrpcAddress
if grpcAddress == "" {
grpcAddress = "weaviate:50051"
}
scheme := cc.Scheme
if scheme == "" {
scheme = "http"
}
weaviateCfg := weaviate.Config{
Host: host,
GrpcConfig: &wgrpc.Config{
Host: grpcAddress,
},
Scheme: scheme,
}
// Unlike the env path (which checks WEAVIATE_AUTH_ENABLED), the factory uses
// APIKey directly — if a user provides it, they intend to use it.
if cc.APIKey != "" {
weaviateCfg.AuthConfig = auth.ApiKey{Value: cc.APIKey}
}
client, err := weaviate.NewClient(weaviateCfg)
if err != nil {
return nil, fmt.Errorf("create weaviate client: %w", err)
}
repo := weaviateRepo.NewWeaviateRetrieveEngineRepository(client)
return retriever.NewKVHybridRetrieveEngine(repo, types.WeaviateRetrieverEngineType), nil
}

View File

@@ -6,6 +6,23 @@ import (
"github.com/Tencent/WeKnora/internal/types"
)
// StoreRegistry provides VectorStore-based engine registration/lookup.
// Separated from RetrieveEngineRegistry to avoid changing the existing interface
// used by 6 services (17 call sites). Phase 2 may merge into RetrieveEngineRegistry.
type StoreRegistry interface {
// RegisterWithStoreID registers an engine service by VectorStore ID.
// Upsert semantics: existing entry is overwritten silently.
RegisterWithStoreID(storeID string, svc RetrieveEngineService)
// GetByStoreID retrieves an engine service by VectorStore ID.
GetByStoreID(storeID string) (RetrieveEngineService, error)
// UnregisterByStoreID removes an engine service by VectorStore ID (idempotent).
UnregisterByStoreID(storeID string)
}
// EngineFactory creates a RetrieveEngineService from a VectorStore's config.
// Defined as a function type to avoid circular imports between container and service packages.
type EngineFactory func(ctx context.Context, store types.VectorStore) (RetrieveEngineService, error)
// VectorStoreService defines the service interface for vector store management.
// Tenant isolation is enforced by the handler layer (getOwnedStore pattern).
type VectorStoreService interface {