mirror of
https://github.com/Tencent/WeKnora.git
synced 2026-06-04 13:30:32 +08:00
feat: add registry dual-map and engine factory for VectorStore
- Extend RetrieveEngineRegistry with byStoreID map for DB store engines - Add StoreRegistry interface and EngineFactory function type - Create engine factory with per-driver SDK client construction - Load VectorStore records from DB into registry at startup - Wire Create/Delete to dynamically register/unregister engines - Add comprehensive unit tests for registry and service
This commit is contained in:
@@ -8,40 +8,51 @@ import (
|
||||
"github.com/Tencent/WeKnora/internal/types/interfaces"
|
||||
)
|
||||
|
||||
// RetrieveEngineRegistry implements the retrieval engine registry
|
||||
// RetrieveEngineRegistry implements the retrieval engine registry.
|
||||
// It maintains two maps:
|
||||
// - byEngineType: env stores registered via RETRIEVE_DRIVER (backward compatible)
|
||||
// - byStoreID: DB stores registered via VectorStore table (instance-based)
|
||||
//
|
||||
// Implements both interfaces.RetrieveEngineRegistry and interfaces.StoreRegistry.
|
||||
type RetrieveEngineRegistry struct {
|
||||
repositories map[types.RetrieverEngineType]interfaces.RetrieveEngineService
|
||||
byEngineType map[types.RetrieverEngineType]interfaces.RetrieveEngineService
|
||||
byStoreID map[string]interfaces.RetrieveEngineService
|
||||
mu sync.RWMutex
|
||||
}
|
||||
|
||||
// NewRetrieveEngineRegistry creates a new retrieval engine registry
|
||||
func NewRetrieveEngineRegistry() interfaces.RetrieveEngineRegistry {
|
||||
return &RetrieveEngineRegistry{
|
||||
repositories: make(map[types.RetrieverEngineType]interfaces.RetrieveEngineService),
|
||||
byEngineType: make(map[types.RetrieverEngineType]interfaces.RetrieveEngineService),
|
||||
byStoreID: make(map[string]interfaces.RetrieveEngineService),
|
||||
}
|
||||
}
|
||||
|
||||
// Register registers a retrieval engine service
|
||||
// --- interfaces.RetrieveEngineRegistry methods (unchanged behavior) ---
|
||||
|
||||
// Register registers a retrieval engine service by engine type.
|
||||
// Returns an error if the engine type is already registered.
|
||||
func (r *RetrieveEngineRegistry) Register(repo interfaces.RetrieveEngineService) error {
|
||||
r.mu.Lock()
|
||||
defer r.mu.Unlock()
|
||||
|
||||
if _, exists := r.repositories[repo.EngineType()]; exists {
|
||||
if _, exists := r.byEngineType[repo.EngineType()]; exists {
|
||||
return fmt.Errorf("repository type %s already registered", repo.EngineType())
|
||||
}
|
||||
|
||||
r.repositories[repo.EngineType()] = repo
|
||||
r.byEngineType[repo.EngineType()] = repo
|
||||
return nil
|
||||
}
|
||||
|
||||
// GetRetrieveEngineService retrieves a retrieval engine service by type
|
||||
// GetRetrieveEngineService retrieves a retrieval engine service by type.
|
||||
// Only searches the byEngineType map (env stores).
|
||||
func (r *RetrieveEngineRegistry) GetRetrieveEngineService(repoType types.RetrieverEngineType) (
|
||||
interfaces.RetrieveEngineService, error,
|
||||
) {
|
||||
r.mu.RLock()
|
||||
defer r.mu.RUnlock()
|
||||
|
||||
repo, exists := r.repositories[repoType]
|
||||
repo, exists := r.byEngineType[repoType]
|
||||
if !exists {
|
||||
return nil, fmt.Errorf("repository of type %s not found", repoType)
|
||||
}
|
||||
@@ -49,16 +60,55 @@ func (r *RetrieveEngineRegistry) GetRetrieveEngineService(repoType types.Retriev
|
||||
return repo, nil
|
||||
}
|
||||
|
||||
// GetAllRetrieveEngineServices retrieves all registered retrieval engine services
|
||||
// GetAllRetrieveEngineServices retrieves all registered retrieval engine services.
|
||||
// Only returns byEngineType entries (env stores) for backward compatibility.
|
||||
func (r *RetrieveEngineRegistry) GetAllRetrieveEngineServices() []interfaces.RetrieveEngineService {
|
||||
r.mu.RLock()
|
||||
defer r.mu.RUnlock()
|
||||
|
||||
// Create a copy to avoid modifying the original map
|
||||
result := make([]interfaces.RetrieveEngineService, 0, len(r.repositories))
|
||||
for _, v := range r.repositories {
|
||||
result := make([]interfaces.RetrieveEngineService, 0, len(r.byEngineType))
|
||||
for _, v := range r.byEngineType {
|
||||
result = append(result, v)
|
||||
}
|
||||
|
||||
return result
|
||||
}
|
||||
|
||||
// --- interfaces.StoreRegistry methods (new, for VectorStore-based engines) ---
|
||||
|
||||
// RegisterWithStoreID registers an engine service by VectorStore ID.
|
||||
// Unlike Register(), the same EngineType can be registered multiple times
|
||||
// with different StoreIDs (e.g., two Elasticsearch clusters).
|
||||
// Upsert semantics: existing entry is overwritten silently.
|
||||
func (r *RetrieveEngineRegistry) RegisterWithStoreID(storeID string, svc interfaces.RetrieveEngineService) {
|
||||
r.mu.Lock()
|
||||
defer r.mu.Unlock()
|
||||
|
||||
r.byStoreID[storeID] = svc
|
||||
}
|
||||
|
||||
// GetByStoreID retrieves an engine service by VectorStore ID.
|
||||
// Callers must verify tenant ownership before using the returned service.
|
||||
func (r *RetrieveEngineRegistry) GetByStoreID(storeID string) (interfaces.RetrieveEngineService, error) {
|
||||
r.mu.RLock()
|
||||
defer r.mu.RUnlock()
|
||||
|
||||
svc, exists := r.byStoreID[storeID]
|
||||
if !exists {
|
||||
return nil, fmt.Errorf("store %s not found in registry", storeID)
|
||||
}
|
||||
return svc, nil
|
||||
}
|
||||
|
||||
// UnregisterByStoreID removes an engine service from the byStoreID map.
|
||||
// Idempotent: returns silently if the storeID is not found.
|
||||
//
|
||||
// NOTE: gRPC-based clients (Qdrant, Milvus) hold connections that are not closed here.
|
||||
// Known Phase 1 limitation — store deletion is rare, connections cleaned up on process exit.
|
||||
// Phase 2 should add Close() to RetrieveEngineService interface and call it here.
|
||||
func (r *RetrieveEngineRegistry) UnregisterByStoreID(storeID string) {
|
||||
r.mu.Lock()
|
||||
defer r.mu.Unlock()
|
||||
|
||||
delete(r.byStoreID, storeID)
|
||||
}
|
||||
|
||||
237
internal/application/service/retriever/registry_test.go
Normal file
237
internal/application/service/retriever/registry_test.go
Normal file
@@ -0,0 +1,237 @@
|
||||
package retriever
|
||||
|
||||
import (
|
||||
"context"
|
||||
"fmt"
|
||||
"sync"
|
||||
"testing"
|
||||
|
||||
"github.com/Tencent/WeKnora/internal/models/embedding"
|
||||
"github.com/Tencent/WeKnora/internal/types"
|
||||
"github.com/Tencent/WeKnora/internal/types/interfaces"
|
||||
"github.com/stretchr/testify/assert"
|
||||
"github.com/stretchr/testify/require"
|
||||
)
|
||||
|
||||
// mockEngineService is a minimal mock for testing registry operations.
|
||||
// Only EngineType() is meaningful; all other methods are no-ops.
|
||||
type mockEngineService struct {
|
||||
engineType types.RetrieverEngineType
|
||||
}
|
||||
|
||||
func (m *mockEngineService) EngineType() types.RetrieverEngineType { return m.engineType }
|
||||
func (m *mockEngineService) Retrieve(_ context.Context, _ types.RetrieveParams) ([]*types.RetrieveResult, error) {
|
||||
return nil, nil
|
||||
}
|
||||
func (m *mockEngineService) Support() []types.RetrieverType { return nil }
|
||||
func (m *mockEngineService) Index(_ context.Context, _ embedding.Embedder, _ *types.IndexInfo, _ []types.RetrieverType) error {
|
||||
return nil
|
||||
}
|
||||
func (m *mockEngineService) BatchIndex(_ context.Context, _ embedding.Embedder, _ []*types.IndexInfo, _ []types.RetrieverType) error {
|
||||
return nil
|
||||
}
|
||||
func (m *mockEngineService) EstimateStorageSize(_ context.Context, _ embedding.Embedder, _ []*types.IndexInfo, _ []types.RetrieverType) int64 {
|
||||
return 0
|
||||
}
|
||||
func (m *mockEngineService) CopyIndices(_ context.Context, _ string, _ map[string]string, _ map[string]string, _ string, _ int, _ string) error {
|
||||
return nil
|
||||
}
|
||||
func (m *mockEngineService) DeleteByChunkIDList(_ context.Context, _ []string, _ int, _ string) error {
|
||||
return nil
|
||||
}
|
||||
func (m *mockEngineService) DeleteBySourceIDList(_ context.Context, _ []string, _ int, _ string) error {
|
||||
return nil
|
||||
}
|
||||
func (m *mockEngineService) DeleteByKnowledgeIDList(_ context.Context, _ []string, _ int, _ string) error {
|
||||
return nil
|
||||
}
|
||||
func (m *mockEngineService) BatchUpdateChunkEnabledStatus(_ context.Context, _ map[string]bool) error {
|
||||
return nil
|
||||
}
|
||||
func (m *mockEngineService) BatchUpdateChunkTagID(_ context.Context, _ map[string]string) error {
|
||||
return nil
|
||||
}
|
||||
|
||||
func newMock(engineType types.RetrieverEngineType) interfaces.RetrieveEngineService {
|
||||
return &mockEngineService{engineType: engineType}
|
||||
}
|
||||
|
||||
// --- Register (byEngineType) tests ---
|
||||
|
||||
func TestRegistry_Register(t *testing.T) {
|
||||
reg := NewRetrieveEngineRegistry().(*RetrieveEngineRegistry)
|
||||
|
||||
t.Run("success", func(t *testing.T) {
|
||||
err := reg.Register(newMock(types.PostgresRetrieverEngineType))
|
||||
assert.NoError(t, err)
|
||||
})
|
||||
|
||||
t.Run("duplicate engine type returns error", func(t *testing.T) {
|
||||
err := reg.Register(newMock(types.PostgresRetrieverEngineType))
|
||||
assert.Error(t, err)
|
||||
assert.Contains(t, err.Error(), "already registered")
|
||||
})
|
||||
}
|
||||
|
||||
func TestRegistry_GetRetrieveEngineService(t *testing.T) {
|
||||
reg := NewRetrieveEngineRegistry().(*RetrieveEngineRegistry)
|
||||
_ = reg.Register(newMock(types.PostgresRetrieverEngineType))
|
||||
|
||||
t.Run("found", func(t *testing.T) {
|
||||
svc, err := reg.GetRetrieveEngineService(types.PostgresRetrieverEngineType)
|
||||
assert.NoError(t, err)
|
||||
assert.NotNil(t, svc)
|
||||
})
|
||||
|
||||
t.Run("not found", func(t *testing.T) {
|
||||
_, err := reg.GetRetrieveEngineService(types.QdrantRetrieverEngineType)
|
||||
assert.Error(t, err)
|
||||
assert.Contains(t, err.Error(), "not found")
|
||||
})
|
||||
}
|
||||
|
||||
func TestRegistry_GetAllRetrieveEngineServices(t *testing.T) {
|
||||
reg := NewRetrieveEngineRegistry().(*RetrieveEngineRegistry)
|
||||
_ = reg.Register(newMock(types.PostgresRetrieverEngineType))
|
||||
_ = reg.Register(newMock(types.ElasticsearchRetrieverEngineType))
|
||||
|
||||
t.Run("returns all byEngineType entries", func(t *testing.T) {
|
||||
all := reg.GetAllRetrieveEngineServices()
|
||||
assert.Len(t, all, 2)
|
||||
})
|
||||
|
||||
t.Run("returns copy - modifying result does not affect registry", func(t *testing.T) {
|
||||
all := reg.GetAllRetrieveEngineServices()
|
||||
all = append(all, newMock(types.QdrantRetrieverEngineType))
|
||||
assert.Len(t, reg.GetAllRetrieveEngineServices(), 2)
|
||||
})
|
||||
}
|
||||
|
||||
// --- RegisterWithStoreID (byStoreID) tests ---
|
||||
|
||||
func TestRegistry_RegisterWithStoreID(t *testing.T) {
|
||||
reg := NewRetrieveEngineRegistry().(*RetrieveEngineRegistry)
|
||||
|
||||
t.Run("success", func(t *testing.T) {
|
||||
reg.RegisterWithStoreID("store-1", newMock(types.PostgresRetrieverEngineType))
|
||||
svc, err := reg.GetByStoreID("store-1")
|
||||
assert.NoError(t, err)
|
||||
assert.NotNil(t, svc)
|
||||
})
|
||||
|
||||
t.Run("upsert overwrites existing", func(t *testing.T) {
|
||||
newSvc := newMock(types.ElasticsearchRetrieverEngineType)
|
||||
reg.RegisterWithStoreID("store-1", newSvc)
|
||||
svc, err := reg.GetByStoreID("store-1")
|
||||
assert.NoError(t, err)
|
||||
assert.Equal(t, types.ElasticsearchRetrieverEngineType, svc.EngineType())
|
||||
})
|
||||
|
||||
t.Run("same engine type different store IDs", func(t *testing.T) {
|
||||
reg.RegisterWithStoreID("es-hot", newMock(types.ElasticsearchRetrieverEngineType))
|
||||
reg.RegisterWithStoreID("es-warm", newMock(types.ElasticsearchRetrieverEngineType))
|
||||
|
||||
svc1, err1 := reg.GetByStoreID("es-hot")
|
||||
svc2, err2 := reg.GetByStoreID("es-warm")
|
||||
assert.NoError(t, err1)
|
||||
assert.NoError(t, err2)
|
||||
assert.NotSame(t, svc1, svc2)
|
||||
})
|
||||
}
|
||||
|
||||
func TestRegistry_GetByStoreID(t *testing.T) {
|
||||
reg := NewRetrieveEngineRegistry().(*RetrieveEngineRegistry)
|
||||
reg.RegisterWithStoreID("store-1", newMock(types.PostgresRetrieverEngineType))
|
||||
|
||||
t.Run("found", func(t *testing.T) {
|
||||
svc, err := reg.GetByStoreID("store-1")
|
||||
assert.NoError(t, err)
|
||||
assert.NotNil(t, svc)
|
||||
})
|
||||
|
||||
t.Run("not found", func(t *testing.T) {
|
||||
_, err := reg.GetByStoreID("nonexistent")
|
||||
assert.Error(t, err)
|
||||
assert.Contains(t, err.Error(), "not found")
|
||||
})
|
||||
}
|
||||
|
||||
func TestRegistry_UnregisterByStoreID(t *testing.T) {
|
||||
reg := NewRetrieveEngineRegistry().(*RetrieveEngineRegistry)
|
||||
reg.RegisterWithStoreID("store-1", newMock(types.PostgresRetrieverEngineType))
|
||||
|
||||
t.Run("removes registered store", func(t *testing.T) {
|
||||
reg.UnregisterByStoreID("store-1")
|
||||
_, err := reg.GetByStoreID("store-1")
|
||||
assert.Error(t, err)
|
||||
})
|
||||
|
||||
t.Run("idempotent on nonexistent store", func(t *testing.T) {
|
||||
reg.UnregisterByStoreID("nonexistent") // should not panic
|
||||
})
|
||||
}
|
||||
|
||||
// --- Dual map isolation tests ---
|
||||
|
||||
func TestRegistry_DualMapIsolation(t *testing.T) {
|
||||
reg := NewRetrieveEngineRegistry().(*RetrieveEngineRegistry)
|
||||
|
||||
_ = reg.Register(newMock(types.PostgresRetrieverEngineType))
|
||||
reg.RegisterWithStoreID("store-pg", newMock(types.PostgresRetrieverEngineType))
|
||||
reg.RegisterWithStoreID("store-es", newMock(types.ElasticsearchRetrieverEngineType))
|
||||
|
||||
t.Run("GetAllRetrieveEngineServices returns only byEngineType", func(t *testing.T) {
|
||||
all := reg.GetAllRetrieveEngineServices()
|
||||
assert.Len(t, all, 1)
|
||||
})
|
||||
|
||||
t.Run("byStoreID does not affect byEngineType lookup", func(t *testing.T) {
|
||||
_, err := reg.GetRetrieveEngineService(types.ElasticsearchRetrieverEngineType)
|
||||
assert.Error(t, err) // ES is only in byStoreID, not byEngineType
|
||||
})
|
||||
|
||||
t.Run("unregister byStoreID does not affect byEngineType", func(t *testing.T) {
|
||||
reg.UnregisterByStoreID("store-pg")
|
||||
svc, err := reg.GetRetrieveEngineService(types.PostgresRetrieverEngineType)
|
||||
assert.NoError(t, err)
|
||||
assert.NotNil(t, svc)
|
||||
})
|
||||
}
|
||||
|
||||
// --- Concurrency test ---
|
||||
|
||||
func TestRegistry_ConcurrentAccess(t *testing.T) {
|
||||
reg := NewRetrieveEngineRegistry().(*RetrieveEngineRegistry)
|
||||
const goroutines = 10
|
||||
|
||||
var wg sync.WaitGroup
|
||||
wg.Add(goroutines * 3)
|
||||
|
||||
for i := 0; i < goroutines; i++ {
|
||||
storeID := fmt.Sprintf("store-%d", i)
|
||||
go func() {
|
||||
defer wg.Done()
|
||||
reg.RegisterWithStoreID(storeID, newMock(types.PostgresRetrieverEngineType))
|
||||
}()
|
||||
go func() {
|
||||
defer wg.Done()
|
||||
_, _ = reg.GetByStoreID(storeID)
|
||||
}()
|
||||
go func() {
|
||||
defer wg.Done()
|
||||
reg.UnregisterByStoreID(storeID)
|
||||
}()
|
||||
}
|
||||
|
||||
wg.Wait()
|
||||
}
|
||||
|
||||
// --- Interface compliance ---
|
||||
|
||||
func TestRegistry_ImplementsStoreRegistry(t *testing.T) {
|
||||
reg := NewRetrieveEngineRegistry()
|
||||
concreteReg, ok := reg.(*RetrieveEngineRegistry)
|
||||
require.True(t, ok)
|
||||
|
||||
var _ interfaces.StoreRegistry = concreteReg
|
||||
}
|
||||
@@ -3,6 +3,7 @@ package service
|
||||
import (
|
||||
"context"
|
||||
"os"
|
||||
"time"
|
||||
|
||||
"github.com/Tencent/WeKnora/internal/errors"
|
||||
"github.com/Tencent/WeKnora/internal/logger"
|
||||
@@ -13,14 +14,22 @@ import (
|
||||
|
||||
// vectorStoreService implements interfaces.VectorStoreService
|
||||
type vectorStoreService struct {
|
||||
repo interfaces.VectorStoreRepository
|
||||
repo interfaces.VectorStoreRepository
|
||||
storeRegistry interfaces.StoreRegistry // for dynamic registry updates on CRUD
|
||||
factory interfaces.EngineFactory // creates engine services from VectorStore config
|
||||
}
|
||||
|
||||
// NewVectorStoreService creates a new vector store service
|
||||
func NewVectorStoreService(
|
||||
repo interfaces.VectorStoreRepository,
|
||||
storeRegistry interfaces.StoreRegistry,
|
||||
factory interfaces.EngineFactory,
|
||||
) interfaces.VectorStoreService {
|
||||
return &vectorStoreService{repo: repo}
|
||||
return &vectorStoreService{
|
||||
repo: repo,
|
||||
storeRegistry: storeRegistry,
|
||||
factory: factory,
|
||||
}
|
||||
}
|
||||
|
||||
// CreateStore validates and creates a new vector store.
|
||||
@@ -60,10 +69,20 @@ func (s *vectorStoreService) CreateStore(ctx context.Context, store *types.Vecto
|
||||
// 5. Persist
|
||||
logger.Infof(ctx, "Creating vector store: tenant=%d, name=%s, engine=%s",
|
||||
store.TenantID, secutils.SanitizeForLog(store.Name), store.EngineType)
|
||||
return s.repo.Create(ctx, store)
|
||||
if err := s.repo.Create(ctx, store); err != nil {
|
||||
return err
|
||||
}
|
||||
|
||||
// 6. Register in registry (best-effort; failure doesn't roll back DB).
|
||||
// The store is already persisted, and will be loaded on next app restart (self-healing).
|
||||
s.registerInRegistry(ctx, store)
|
||||
|
||||
return nil
|
||||
}
|
||||
|
||||
// UpdateStore updates an existing vector store (name only).
|
||||
// NOTE: If connection_config or index_config become mutable in the future,
|
||||
// registry re-registration must be added here (unregister old + register new).
|
||||
func (s *vectorStoreService) UpdateStore(ctx context.Context, store *types.VectorStore) error {
|
||||
if store.TenantID == 0 {
|
||||
return errors.NewValidationError("tenant_id is required")
|
||||
@@ -79,8 +98,17 @@ func (s *vectorStoreService) UpdateStore(ctx context.Context, store *types.Vecto
|
||||
// DeleteStore deletes a vector store by tenant + id.
|
||||
// Phase 2: KB binding check will be added here.
|
||||
func (s *vectorStoreService) DeleteStore(ctx context.Context, tenantID uint64, id string) error {
|
||||
logger.Infof(ctx, "Deleting vector store: tenant=%d, id=%s", tenantID, id)
|
||||
return s.repo.Delete(ctx, tenantID, id)
|
||||
if err := s.repo.Delete(ctx, tenantID, id); err != nil {
|
||||
return err
|
||||
}
|
||||
|
||||
// Unregister from registry (idempotent)
|
||||
if s.storeRegistry != nil {
|
||||
s.storeRegistry.UnregisterByStoreID(id)
|
||||
}
|
||||
|
||||
logger.Infof(ctx, "Deleted vector store: tenant=%d, id=%s", tenantID, id)
|
||||
return nil
|
||||
}
|
||||
|
||||
// SaveDetectedVersion updates the connection_config.version for a stored vector store.
|
||||
@@ -91,6 +119,28 @@ func (s *vectorStoreService) SaveDetectedVersion(ctx context.Context, store *typ
|
||||
return s.repo.UpdateConnectionConfig(ctx, &updated)
|
||||
}
|
||||
|
||||
// registerInRegistry creates an engine service and registers it in the registry.
|
||||
// Logs and skips on failure — the store is already persisted in DB,
|
||||
// and will be loaded on next app restart (self-healing).
|
||||
func (s *vectorStoreService) registerInRegistry(ctx context.Context, store *types.VectorStore) {
|
||||
if s.storeRegistry == nil || s.factory == nil {
|
||||
return
|
||||
}
|
||||
|
||||
// Use a short timeout for engine creation to avoid blocking on unreachable hosts
|
||||
// (e.g., gRPC dial to Qdrant/Milvus). The store is already persisted in DB,
|
||||
// so it will be loaded on next app restart if this times out.
|
||||
factoryCtx, cancel := context.WithTimeout(ctx, 10*time.Second)
|
||||
defer cancel()
|
||||
|
||||
svc, err := s.factory(factoryCtx, *store)
|
||||
if err != nil {
|
||||
logger.Warnf(ctx, "Failed to create engine for store %s, will be available after restart: %v", store.ID, err)
|
||||
return
|
||||
}
|
||||
s.storeRegistry.RegisterWithStoreID(store.ID, svc)
|
||||
}
|
||||
|
||||
// validateConnectionConfig validates required fields per engine type.
|
||||
func validateConnectionConfig(engineType types.RetrieverEngineType, config types.ConnectionConfig) error {
|
||||
switch engineType {
|
||||
|
||||
@@ -5,7 +5,9 @@ import (
|
||||
"testing"
|
||||
|
||||
"github.com/Tencent/WeKnora/internal/errors"
|
||||
"github.com/Tencent/WeKnora/internal/models/embedding"
|
||||
"github.com/Tencent/WeKnora/internal/types"
|
||||
"github.com/Tencent/WeKnora/internal/types/interfaces"
|
||||
"github.com/stretchr/testify/assert"
|
||||
"github.com/stretchr/testify/require"
|
||||
)
|
||||
@@ -71,13 +73,88 @@ func (m *mockVectorStoreRepo) ExistsByEndpointAndIndex(
|
||||
return m.existsByEndpoint, nil
|
||||
}
|
||||
|
||||
// ---------------------------------------------------------------------------
|
||||
// Mock StoreRegistry
|
||||
// ---------------------------------------------------------------------------
|
||||
|
||||
type mockStoreRegistry struct {
|
||||
registered map[string]bool
|
||||
unregistered []string
|
||||
}
|
||||
|
||||
func newMockStoreRegistry() *mockStoreRegistry {
|
||||
return &mockStoreRegistry{registered: make(map[string]bool)}
|
||||
}
|
||||
|
||||
func (m *mockStoreRegistry) RegisterWithStoreID(storeID string, _ interfaces.RetrieveEngineService) {
|
||||
m.registered[storeID] = true
|
||||
}
|
||||
|
||||
func (m *mockStoreRegistry) GetByStoreID(storeID string) (interfaces.RetrieveEngineService, error) {
|
||||
return nil, nil
|
||||
}
|
||||
|
||||
func (m *mockStoreRegistry) UnregisterByStoreID(storeID string) {
|
||||
m.unregistered = append(m.unregistered, storeID)
|
||||
delete(m.registered, storeID)
|
||||
}
|
||||
|
||||
// ---------------------------------------------------------------------------
|
||||
// Mock EngineFactory
|
||||
// ---------------------------------------------------------------------------
|
||||
|
||||
func mockEngineFactory(err error) interfaces.EngineFactory {
|
||||
return func(_ context.Context, _ types.VectorStore) (interfaces.RetrieveEngineService, error) {
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
return &mockEngineService{}, nil
|
||||
}
|
||||
}
|
||||
|
||||
// mockEngineService satisfies interfaces.RetrieveEngineService minimally.
|
||||
type mockEngineService struct{}
|
||||
|
||||
func (m *mockEngineService) EngineType() types.RetrieverEngineType { return "mock" }
|
||||
func (m *mockEngineService) Retrieve(_ context.Context, _ types.RetrieveParams) ([]*types.RetrieveResult, error) {
|
||||
return nil, nil
|
||||
}
|
||||
func (m *mockEngineService) Support() []types.RetrieverType { return nil }
|
||||
func (m *mockEngineService) Index(_ context.Context, _ embedding.Embedder, _ *types.IndexInfo, _ []types.RetrieverType) error {
|
||||
return nil
|
||||
}
|
||||
func (m *mockEngineService) BatchIndex(_ context.Context, _ embedding.Embedder, _ []*types.IndexInfo, _ []types.RetrieverType) error {
|
||||
return nil
|
||||
}
|
||||
func (m *mockEngineService) EstimateStorageSize(_ context.Context, _ embedding.Embedder, _ []*types.IndexInfo, _ []types.RetrieverType) int64 {
|
||||
return 0
|
||||
}
|
||||
func (m *mockEngineService) CopyIndices(_ context.Context, _ string, _ map[string]string, _ map[string]string, _ string, _ int, _ string) error {
|
||||
return nil
|
||||
}
|
||||
func (m *mockEngineService) DeleteByChunkIDList(_ context.Context, _ []string, _ int, _ string) error {
|
||||
return nil
|
||||
}
|
||||
func (m *mockEngineService) DeleteBySourceIDList(_ context.Context, _ []string, _ int, _ string) error {
|
||||
return nil
|
||||
}
|
||||
func (m *mockEngineService) DeleteByKnowledgeIDList(_ context.Context, _ []string, _ int, _ string) error {
|
||||
return nil
|
||||
}
|
||||
func (m *mockEngineService) BatchUpdateChunkEnabledStatus(_ context.Context, _ map[string]bool) error {
|
||||
return nil
|
||||
}
|
||||
func (m *mockEngineService) BatchUpdateChunkTagID(_ context.Context, _ map[string]string) error {
|
||||
return nil
|
||||
}
|
||||
|
||||
// ---------------------------------------------------------------------------
|
||||
// CreateStore tests
|
||||
// ---------------------------------------------------------------------------
|
||||
|
||||
func TestCreateStore_Success(t *testing.T) {
|
||||
repo := &mockVectorStoreRepo{}
|
||||
svc := NewVectorStoreService(repo)
|
||||
svc := NewVectorStoreService(repo, nil, nil)
|
||||
|
||||
store := &types.VectorStore{
|
||||
TenantID: 1,
|
||||
@@ -95,7 +172,7 @@ func TestCreateStore_Success(t *testing.T) {
|
||||
|
||||
func TestCreateStore_ValidationError(t *testing.T) {
|
||||
repo := &mockVectorStoreRepo{}
|
||||
svc := NewVectorStoreService(repo)
|
||||
svc := NewVectorStoreService(repo, nil, nil)
|
||||
|
||||
tests := []struct {
|
||||
name string
|
||||
@@ -127,7 +204,7 @@ func TestCreateStore_ValidationError(t *testing.T) {
|
||||
|
||||
func TestCreateStore_ConnectionConfigValidation(t *testing.T) {
|
||||
repo := &mockVectorStoreRepo{}
|
||||
svc := NewVectorStoreService(repo)
|
||||
svc := NewVectorStoreService(repo, nil, nil)
|
||||
|
||||
tests := []struct {
|
||||
name string
|
||||
@@ -213,7 +290,7 @@ func TestCreateStore_ConnectionConfigValidation(t *testing.T) {
|
||||
|
||||
func TestCreateStore_DuplicateCheck_DBStore(t *testing.T) {
|
||||
repo := &mockVectorStoreRepo{existsByEndpoint: true}
|
||||
svc := NewVectorStoreService(repo)
|
||||
svc := NewVectorStoreService(repo, nil, nil)
|
||||
|
||||
store := &types.VectorStore{
|
||||
TenantID: 1,
|
||||
@@ -236,7 +313,7 @@ func TestCreateStore_DuplicateCheck_DBError(t *testing.T) {
|
||||
repo := &mockVectorStoreRepo{
|
||||
existsByEndpointErr: assert.AnError,
|
||||
}
|
||||
svc := NewVectorStoreService(repo)
|
||||
svc := NewVectorStoreService(repo, nil, nil)
|
||||
|
||||
store := &types.VectorStore{
|
||||
TenantID: 1,
|
||||
@@ -260,7 +337,7 @@ func TestCreateStore_DuplicateCheck_EnvStore(t *testing.T) {
|
||||
t.Setenv("ELASTICSEARCH_INDEX", "xwrag_default")
|
||||
|
||||
repo := &mockVectorStoreRepo{existsByEndpoint: false} // no DB duplicate
|
||||
svc := NewVectorStoreService(repo)
|
||||
svc := NewVectorStoreService(repo, nil, nil)
|
||||
|
||||
store := &types.VectorStore{
|
||||
TenantID: 1,
|
||||
@@ -290,7 +367,7 @@ func TestCreateStore_DuplicateCheck_EnvStore_DifferentIndex_Allowed(t *testing.T
|
||||
t.Setenv("ELASTICSEARCH_INDEX", "xwrag_default")
|
||||
|
||||
repo := &mockVectorStoreRepo{existsByEndpoint: false}
|
||||
svc := NewVectorStoreService(repo)
|
||||
svc := NewVectorStoreService(repo, nil, nil)
|
||||
|
||||
store := &types.VectorStore{
|
||||
TenantID: 1,
|
||||
@@ -310,7 +387,7 @@ func TestCreateStore_DuplicateCheck_EnvStore_DifferentIndex_Allowed(t *testing.T
|
||||
|
||||
func TestCreateStore_DifferentEndpointSameIndex_Allowed(t *testing.T) {
|
||||
repo := &mockVectorStoreRepo{existsByEndpoint: false}
|
||||
svc := NewVectorStoreService(repo)
|
||||
svc := NewVectorStoreService(repo, nil, nil)
|
||||
|
||||
store := &types.VectorStore{
|
||||
TenantID: 1,
|
||||
@@ -328,13 +405,125 @@ func TestCreateStore_DifferentEndpointSameIndex_Allowed(t *testing.T) {
|
||||
assert.NoError(t, err)
|
||||
}
|
||||
|
||||
// ---------------------------------------------------------------------------
|
||||
// CreateStore + Registry integration tests
|
||||
// ---------------------------------------------------------------------------
|
||||
|
||||
func TestCreateStore_RegistersInRegistry(t *testing.T) {
|
||||
repo := &mockVectorStoreRepo{}
|
||||
registry := newMockStoreRegistry()
|
||||
factory := mockEngineFactory(nil)
|
||||
svc := NewVectorStoreService(repo, registry, factory)
|
||||
|
||||
store := &types.VectorStore{
|
||||
TenantID: 1,
|
||||
Name: "test-es",
|
||||
EngineType: types.ElasticsearchRetrieverEngineType,
|
||||
ConnectionConfig: types.ConnectionConfig{
|
||||
Addr: "http://es:9200",
|
||||
},
|
||||
}
|
||||
|
||||
err := svc.CreateStore(context.Background(), store)
|
||||
require.NoError(t, err)
|
||||
|
||||
// Store should be persisted AND registered in registry
|
||||
assert.Len(t, repo.stores, 1)
|
||||
assert.True(t, registry.registered[store.ID])
|
||||
}
|
||||
|
||||
func TestCreateStore_RegistryFailureDoesNotRollBackDB(t *testing.T) {
|
||||
repo := &mockVectorStoreRepo{}
|
||||
registry := newMockStoreRegistry()
|
||||
factory := mockEngineFactory(assert.AnError) // factory fails
|
||||
svc := NewVectorStoreService(repo, registry, factory)
|
||||
|
||||
store := &types.VectorStore{
|
||||
TenantID: 1,
|
||||
Name: "test-es",
|
||||
EngineType: types.ElasticsearchRetrieverEngineType,
|
||||
ConnectionConfig: types.ConnectionConfig{
|
||||
Addr: "http://es:9200",
|
||||
},
|
||||
}
|
||||
|
||||
// CreateStore should succeed even if registry fails (best-effort + self-healing)
|
||||
err := svc.CreateStore(context.Background(), store)
|
||||
assert.NoError(t, err)
|
||||
|
||||
// DB should have the store
|
||||
assert.Len(t, repo.stores, 1)
|
||||
// Registry should NOT have it (factory failed)
|
||||
assert.False(t, registry.registered[store.ID])
|
||||
}
|
||||
|
||||
func TestCreateStore_NilRegistryAndFactory(t *testing.T) {
|
||||
repo := &mockVectorStoreRepo{}
|
||||
svc := NewVectorStoreService(repo, nil, nil) // no registry
|
||||
|
||||
store := &types.VectorStore{
|
||||
TenantID: 1,
|
||||
Name: "test-es",
|
||||
EngineType: types.ElasticsearchRetrieverEngineType,
|
||||
ConnectionConfig: types.ConnectionConfig{
|
||||
Addr: "http://es:9200",
|
||||
},
|
||||
}
|
||||
|
||||
// Should work fine without registry (degrades gracefully)
|
||||
err := svc.CreateStore(context.Background(), store)
|
||||
assert.NoError(t, err)
|
||||
assert.Len(t, repo.stores, 1)
|
||||
}
|
||||
|
||||
// ---------------------------------------------------------------------------
|
||||
// DeleteStore + Registry integration tests
|
||||
// ---------------------------------------------------------------------------
|
||||
|
||||
func TestDeleteStore_UnregistersFromRegistry(t *testing.T) {
|
||||
repo := &mockVectorStoreRepo{}
|
||||
registry := newMockStoreRegistry()
|
||||
registry.registered["store-1"] = true
|
||||
svc := NewVectorStoreService(repo, registry, nil)
|
||||
|
||||
err := svc.DeleteStore(context.Background(), 1, "store-1")
|
||||
require.NoError(t, err)
|
||||
|
||||
// Should be unregistered
|
||||
assert.Contains(t, registry.unregistered, "store-1")
|
||||
assert.False(t, registry.registered["store-1"])
|
||||
}
|
||||
|
||||
func TestDeleteStore_NilRegistryGraceful(t *testing.T) {
|
||||
repo := &mockVectorStoreRepo{}
|
||||
svc := NewVectorStoreService(repo, nil, nil)
|
||||
|
||||
// Should not panic with nil registry
|
||||
err := svc.DeleteStore(context.Background(), 1, "store-1")
|
||||
assert.NoError(t, err)
|
||||
}
|
||||
|
||||
func TestDeleteStore_RepoErrorSkipsUnregister(t *testing.T) {
|
||||
repo := &mockVectorStoreRepo{deleteErr: assert.AnError}
|
||||
registry := newMockStoreRegistry()
|
||||
registry.registered["store-1"] = true
|
||||
svc := NewVectorStoreService(repo, registry, nil)
|
||||
|
||||
err := svc.DeleteStore(context.Background(), 1, "store-1")
|
||||
assert.Error(t, err)
|
||||
|
||||
// Registry should NOT be touched if DB delete fails
|
||||
assert.True(t, registry.registered["store-1"])
|
||||
assert.Empty(t, registry.unregistered)
|
||||
}
|
||||
|
||||
// ---------------------------------------------------------------------------
|
||||
// UpdateStore tests
|
||||
// ---------------------------------------------------------------------------
|
||||
|
||||
func TestUpdateStore_Success(t *testing.T) {
|
||||
repo := &mockVectorStoreRepo{}
|
||||
svc := NewVectorStoreService(repo)
|
||||
svc := NewVectorStoreService(repo, nil, nil)
|
||||
|
||||
store := &types.VectorStore{
|
||||
ID: "test-id",
|
||||
@@ -348,7 +537,7 @@ func TestUpdateStore_Success(t *testing.T) {
|
||||
|
||||
func TestUpdateStore_ValidationError(t *testing.T) {
|
||||
repo := &mockVectorStoreRepo{}
|
||||
svc := NewVectorStoreService(repo)
|
||||
svc := NewVectorStoreService(repo, nil, nil)
|
||||
|
||||
tests := []struct {
|
||||
name string
|
||||
@@ -378,7 +567,7 @@ func TestUpdateStore_ValidationError(t *testing.T) {
|
||||
|
||||
func TestDeleteStore_Success(t *testing.T) {
|
||||
repo := &mockVectorStoreRepo{}
|
||||
svc := NewVectorStoreService(repo)
|
||||
svc := NewVectorStoreService(repo, nil, nil)
|
||||
|
||||
err := svc.DeleteStore(context.Background(), 1, "test-id")
|
||||
assert.NoError(t, err)
|
||||
@@ -386,19 +575,63 @@ func TestDeleteStore_Success(t *testing.T) {
|
||||
|
||||
func TestDeleteStore_RepoError(t *testing.T) {
|
||||
repo := &mockVectorStoreRepo{deleteErr: assert.AnError}
|
||||
svc := NewVectorStoreService(repo)
|
||||
svc := NewVectorStoreService(repo, nil, nil)
|
||||
|
||||
err := svc.DeleteStore(context.Background(), 1, "test-id")
|
||||
assert.Error(t, err)
|
||||
}
|
||||
|
||||
// ---------------------------------------------------------------------------
|
||||
// SaveDetectedVersion tests
|
||||
// ---------------------------------------------------------------------------
|
||||
|
||||
func TestSaveDetectedVersion_Success(t *testing.T) {
|
||||
repo := &mockVectorStoreRepo{}
|
||||
svc := NewVectorStoreService(repo, nil, nil)
|
||||
|
||||
store := &types.VectorStore{
|
||||
ID: "store-1",
|
||||
TenantID: 1,
|
||||
ConnectionConfig: types.ConnectionConfig{Addr: "http://es:9200"},
|
||||
}
|
||||
|
||||
err := svc.SaveDetectedVersion(context.Background(), store, "7.10.1")
|
||||
assert.NoError(t, err)
|
||||
}
|
||||
|
||||
func TestSaveDetectedVersion_RepoError(t *testing.T) {
|
||||
repo := &mockVectorStoreRepo{updateErr: assert.AnError}
|
||||
svc := NewVectorStoreService(repo, nil, nil)
|
||||
|
||||
store := &types.VectorStore{ID: "store-1", TenantID: 1}
|
||||
err := svc.SaveDetectedVersion(context.Background(), store, "8.11.0")
|
||||
assert.Error(t, err)
|
||||
}
|
||||
|
||||
func TestSaveDetectedVersion_DoesNotMutateOriginal(t *testing.T) {
|
||||
repo := &mockVectorStoreRepo{}
|
||||
svc := NewVectorStoreService(repo, nil, nil)
|
||||
|
||||
store := &types.VectorStore{
|
||||
ID: "store-1",
|
||||
TenantID: 1,
|
||||
ConnectionConfig: types.ConnectionConfig{Version: "old"},
|
||||
}
|
||||
|
||||
err := svc.SaveDetectedVersion(context.Background(), store, "new")
|
||||
require.NoError(t, err)
|
||||
|
||||
// Original store must not be mutated
|
||||
assert.Equal(t, "old", store.ConnectionConfig.Version)
|
||||
}
|
||||
|
||||
// ---------------------------------------------------------------------------
|
||||
// TestConnection tests
|
||||
// ---------------------------------------------------------------------------
|
||||
|
||||
func TestTestConnection_UnsupportedEngineType(t *testing.T) {
|
||||
repo := &mockVectorStoreRepo{}
|
||||
svc := NewVectorStoreService(repo)
|
||||
svc := NewVectorStoreService(repo, nil, nil)
|
||||
|
||||
_, err := svc.TestConnection(context.Background(), "unknown_engine", types.ConnectionConfig{})
|
||||
require.Error(t, err)
|
||||
@@ -410,7 +643,7 @@ func TestTestConnection_UnsupportedEngineType(t *testing.T) {
|
||||
|
||||
func TestTestConnection_SQLiteAlwaysSucceeds(t *testing.T) {
|
||||
repo := &mockVectorStoreRepo{}
|
||||
svc := NewVectorStoreService(repo)
|
||||
svc := NewVectorStoreService(repo, nil, nil)
|
||||
|
||||
version, err := svc.TestConnection(context.Background(), types.SQLiteRetrieverEngineType, types.ConnectionConfig{})
|
||||
assert.NoError(t, err)
|
||||
@@ -419,7 +652,7 @@ func TestTestConnection_SQLiteAlwaysSucceeds(t *testing.T) {
|
||||
|
||||
func TestTestConnection_PostgresDefaultConnection(t *testing.T) {
|
||||
repo := &mockVectorStoreRepo{}
|
||||
svc := NewVectorStoreService(repo)
|
||||
svc := NewVectorStoreService(repo, nil, nil)
|
||||
|
||||
version, err := svc.TestConnection(context.Background(), types.PostgresRetrieverEngineType,
|
||||
types.ConnectionConfig{UseDefaultConnection: true})
|
||||
|
||||
@@ -191,6 +191,16 @@ func BuildContainer(container *dig.Container) *dig.Container {
|
||||
must(container.Provide(repository.NewVectorStoreRepository))
|
||||
must(container.Provide(service.NewWebSearchService))
|
||||
must(container.Provide(service.NewWebSearchProviderService))
|
||||
must(container.Provide(NewEngineFactory))
|
||||
// StoreRegistry: same instance as RetrieveEngineRegistry, exposed as StoreRegistry interface.
|
||||
// NewRetrieveEngineRegistry always returns *retriever.RetrieveEngineRegistry which implements both.
|
||||
must(container.Provide(func(r interfaces.RetrieveEngineRegistry) (interfaces.StoreRegistry, error) {
|
||||
sr, ok := r.(*retriever.RetrieveEngineRegistry)
|
||||
if !ok {
|
||||
return nil, fmt.Errorf("registry does not implement StoreRegistry")
|
||||
}
|
||||
return sr, nil
|
||||
}))
|
||||
must(container.Provide(service.NewVectorStoreService))
|
||||
|
||||
// Agent service layer (requires event bus, web search service)
|
||||
@@ -923,9 +933,43 @@ func initRetrieveEngineRegistry(db *gorm.DB, cfg *config.Config) (interfaces.Ret
|
||||
}
|
||||
}
|
||||
}
|
||||
// ─── DB store registration (byStoreID) ───
|
||||
if storeReg, ok := registry.(*retriever.RetrieveEngineRegistry); ok {
|
||||
loadDBStoresIntoRegistry(storeReg, db, cfg)
|
||||
}
|
||||
|
||||
return registry, nil
|
||||
}
|
||||
|
||||
// loadDBStoresIntoRegistry loads VectorStore records from DB and registers them
|
||||
// in the registry's byStoreID map. Failures are logged and skipped (non-fatal).
|
||||
func loadDBStoresIntoRegistry(storeRegistry interfaces.StoreRegistry, db *gorm.DB, cfg *config.Config) {
|
||||
ctx := context.Background()
|
||||
log := logger.GetLogger(ctx)
|
||||
|
||||
var stores []types.VectorStore
|
||||
// GORM soft delete automatically adds "deleted_at IS NULL" condition
|
||||
if err := db.Find(&stores).Error; err != nil {
|
||||
log.Warnf("Failed to load vector stores from DB: %v", err)
|
||||
return
|
||||
}
|
||||
|
||||
if len(stores) == 0 {
|
||||
return
|
||||
}
|
||||
|
||||
log.Infof("Loading %d vector store(s) from database", len(stores))
|
||||
for _, store := range stores {
|
||||
svc, err := createEngineServiceFromStore(ctx, store, db, cfg)
|
||||
if err != nil {
|
||||
log.Errorf("Failed to create engine for store %s (%s): %v", store.ID, store.Name, err)
|
||||
continue
|
||||
}
|
||||
storeRegistry.RegisterWithStoreID(store.ID, svc)
|
||||
log.Infof("Registered DB vector store: id=%s, name=%s, engine=%s", store.ID, store.Name, store.EngineType)
|
||||
}
|
||||
}
|
||||
|
||||
// initAntsPool initializes the goroutine pool
|
||||
// Creates a managed goroutine pool for concurrent task execution
|
||||
// Parameters:
|
||||
|
||||
205
internal/container/engine_factory.go
Normal file
205
internal/container/engine_factory.go
Normal file
@@ -0,0 +1,205 @@
|
||||
package container
|
||||
|
||||
import (
|
||||
"context"
|
||||
"fmt"
|
||||
"strings"
|
||||
"time"
|
||||
|
||||
esv7 "github.com/elastic/go-elasticsearch/v7"
|
||||
"github.com/elastic/go-elasticsearch/v8"
|
||||
"github.com/milvus-io/milvus/client/v2/milvusclient"
|
||||
"github.com/qdrant/go-client/qdrant"
|
||||
"github.com/weaviate/weaviate-go-client/v5/weaviate"
|
||||
"github.com/weaviate/weaviate-go-client/v5/weaviate/auth"
|
||||
wgrpc "github.com/weaviate/weaviate-go-client/v5/weaviate/grpc"
|
||||
"google.golang.org/grpc"
|
||||
"gorm.io/gorm"
|
||||
|
||||
elasticsearchRepoV7 "github.com/Tencent/WeKnora/internal/application/repository/retriever/elasticsearch/v7"
|
||||
elasticsearchRepoV8 "github.com/Tencent/WeKnora/internal/application/repository/retriever/elasticsearch/v8"
|
||||
milvusRepo "github.com/Tencent/WeKnora/internal/application/repository/retriever/milvus"
|
||||
postgresRepo "github.com/Tencent/WeKnora/internal/application/repository/retriever/postgres"
|
||||
qdrantRepo "github.com/Tencent/WeKnora/internal/application/repository/retriever/qdrant"
|
||||
sqliteRetrieverRepo "github.com/Tencent/WeKnora/internal/application/repository/retriever/sqlite"
|
||||
weaviateRepo "github.com/Tencent/WeKnora/internal/application/repository/retriever/weaviate"
|
||||
"github.com/Tencent/WeKnora/internal/application/service/retriever"
|
||||
"github.com/Tencent/WeKnora/internal/config"
|
||||
"github.com/Tencent/WeKnora/internal/types"
|
||||
"github.com/Tencent/WeKnora/internal/types/interfaces"
|
||||
)
|
||||
|
||||
// NewEngineFactory returns an EngineFactory function closed over db and cfg.
|
||||
// Registered in dig and injected into VectorStoreService for dynamic registry updates.
|
||||
func NewEngineFactory(db *gorm.DB, cfg *config.Config) interfaces.EngineFactory {
|
||||
return func(ctx context.Context, store types.VectorStore) (interfaces.RetrieveEngineService, error) {
|
||||
return createEngineServiceFromStore(ctx, store, db, cfg)
|
||||
}
|
||||
}
|
||||
|
||||
// createEngineServiceFromStore creates a RetrieveEngineService from a VectorStore's config.
|
||||
// This is the DB store counterpart of the env-based initialization in initRetrieveEngineRegistry.
|
||||
func createEngineServiceFromStore(
|
||||
ctx context.Context,
|
||||
store types.VectorStore,
|
||||
db *gorm.DB,
|
||||
cfg *config.Config,
|
||||
) (interfaces.RetrieveEngineService, error) {
|
||||
switch store.EngineType {
|
||||
case types.PostgresRetrieverEngineType:
|
||||
return createPostgresEngine(store, db)
|
||||
case types.ElasticsearchRetrieverEngineType:
|
||||
return createElasticsearchEngine(store, cfg)
|
||||
case types.QdrantRetrieverEngineType:
|
||||
return createQdrantEngine(store)
|
||||
case types.MilvusRetrieverEngineType:
|
||||
return createMilvusEngine(ctx, store)
|
||||
case types.WeaviateRetrieverEngineType:
|
||||
return createWeaviateEngine(store)
|
||||
case types.SQLiteRetrieverEngineType:
|
||||
return createSQLiteEngine(store, db)
|
||||
default:
|
||||
return nil, fmt.Errorf("unsupported engine type: %s", store.EngineType)
|
||||
}
|
||||
}
|
||||
|
||||
func createPostgresEngine(store types.VectorStore, db *gorm.DB) (interfaces.RetrieveEngineService, error) {
|
||||
if store.ConnectionConfig.UseDefaultConnection {
|
||||
repo := postgresRepo.NewPostgresRetrieveEngineRepository(db)
|
||||
return retriever.NewKVHybridRetrieveEngine(repo, types.PostgresRetrieverEngineType), nil
|
||||
}
|
||||
// Phase 1: only UseDefaultConnection is supported.
|
||||
// Custom connections require connection pool management and migration handling.
|
||||
return nil, fmt.Errorf("custom postgres connections not yet supported; use use_default_connection=true")
|
||||
}
|
||||
|
||||
func createSQLiteEngine(_ types.VectorStore, db *gorm.DB) (interfaces.RetrieveEngineService, error) {
|
||||
repo := sqliteRetrieverRepo.NewSQLiteRetrieveEngineRepository(db)
|
||||
return retriever.NewKVHybridRetrieveEngine(repo, types.SQLiteRetrieverEngineType), nil
|
||||
}
|
||||
|
||||
func createElasticsearchEngine(store types.VectorStore, cfg *config.Config) (interfaces.RetrieveEngineService, error) {
|
||||
cc := store.ConnectionConfig
|
||||
// Version-based v7/v8 SDK selection.
|
||||
// Version is auto-detected by PR2's TestConnection and saved to connection_config.
|
||||
// Empty version defaults to v8 (latest SDK).
|
||||
if isESv7(cc.Version) {
|
||||
return createElasticsearchV7Engine(cc, cfg)
|
||||
}
|
||||
return createElasticsearchV8Engine(cc, cfg)
|
||||
}
|
||||
|
||||
// isESv7 checks if the detected ES version is 7.x.
|
||||
func isESv7(version string) bool {
|
||||
return strings.HasPrefix(version, "7.")
|
||||
}
|
||||
|
||||
func createElasticsearchV8Engine(cc types.ConnectionConfig, cfg *config.Config) (interfaces.RetrieveEngineService, error) {
|
||||
client, err := elasticsearch.NewTypedClient(elasticsearch.Config{
|
||||
Addresses: []string{cc.Addr},
|
||||
Username: cc.Username,
|
||||
Password: cc.Password,
|
||||
})
|
||||
if err != nil {
|
||||
return nil, fmt.Errorf("create elasticsearch v8 client: %w", err)
|
||||
}
|
||||
repo := elasticsearchRepoV8.NewElasticsearchEngineRepository(client, cfg)
|
||||
return retriever.NewKVHybridRetrieveEngine(repo, types.ElasticsearchRetrieverEngineType), nil
|
||||
}
|
||||
|
||||
func createElasticsearchV7Engine(cc types.ConnectionConfig, cfg *config.Config) (interfaces.RetrieveEngineService, error) {
|
||||
client, err := esv7.NewClient(esv7.Config{
|
||||
Addresses: []string{cc.Addr},
|
||||
Username: cc.Username,
|
||||
Password: cc.Password,
|
||||
})
|
||||
if err != nil {
|
||||
return nil, fmt.Errorf("create elasticsearch v7 client: %w", err)
|
||||
}
|
||||
repo := elasticsearchRepoV7.NewElasticsearchEngineRepository(client, cfg)
|
||||
return retriever.NewKVHybridRetrieveEngine(repo, types.ElasticsearchRetrieverEngineType), nil
|
||||
}
|
||||
|
||||
func createQdrantEngine(store types.VectorStore) (interfaces.RetrieveEngineService, error) {
|
||||
cc := store.ConnectionConfig
|
||||
port := cc.Port
|
||||
if port == 0 {
|
||||
port = 6334
|
||||
}
|
||||
|
||||
client, err := qdrant.NewClient(&qdrant.Config{
|
||||
Host: cc.Host,
|
||||
Port: port,
|
||||
APIKey: cc.APIKey,
|
||||
UseTLS: cc.UseTLS,
|
||||
})
|
||||
if err != nil {
|
||||
return nil, fmt.Errorf("create qdrant client: %w", err)
|
||||
}
|
||||
repo := qdrantRepo.NewQdrantRetrieveEngineRepository(client)
|
||||
return retriever.NewKVHybridRetrieveEngine(repo, types.QdrantRetrieverEngineType), nil
|
||||
}
|
||||
|
||||
func createMilvusEngine(ctx context.Context, store types.VectorStore) (interfaces.RetrieveEngineService, error) {
|
||||
cc := store.ConnectionConfig
|
||||
addr := cc.Addr
|
||||
if addr == "" {
|
||||
addr = "localhost:19530"
|
||||
}
|
||||
|
||||
milvusCfg := milvusclient.ClientConfig{
|
||||
Address: addr,
|
||||
DialOptions: []grpc.DialOption{grpc.WithTimeout(5 * time.Second)},
|
||||
}
|
||||
if cc.Username != "" {
|
||||
milvusCfg.Username = cc.Username
|
||||
}
|
||||
if cc.Password != "" {
|
||||
milvusCfg.Password = cc.Password
|
||||
}
|
||||
// NOTE: Milvus DBName is not yet in ConnectionConfig.
|
||||
// Phase 1 limitation — only the default database is used.
|
||||
|
||||
client, err := milvusclient.New(ctx, &milvusCfg)
|
||||
if err != nil {
|
||||
return nil, fmt.Errorf("create milvus client: %w", err)
|
||||
}
|
||||
repo := milvusRepo.NewMilvusRetrieveEngineRepository(client)
|
||||
return retriever.NewKVHybridRetrieveEngine(repo, types.MilvusRetrieverEngineType), nil
|
||||
}
|
||||
|
||||
func createWeaviateEngine(store types.VectorStore) (interfaces.RetrieveEngineService, error) {
|
||||
cc := store.ConnectionConfig
|
||||
host := cc.Host
|
||||
if host == "" {
|
||||
host = "weaviate:8080"
|
||||
}
|
||||
grpcAddress := cc.GrpcAddress
|
||||
if grpcAddress == "" {
|
||||
grpcAddress = "weaviate:50051"
|
||||
}
|
||||
scheme := cc.Scheme
|
||||
if scheme == "" {
|
||||
scheme = "http"
|
||||
}
|
||||
|
||||
weaviateCfg := weaviate.Config{
|
||||
Host: host,
|
||||
GrpcConfig: &wgrpc.Config{
|
||||
Host: grpcAddress,
|
||||
},
|
||||
Scheme: scheme,
|
||||
}
|
||||
// Unlike the env path (which checks WEAVIATE_AUTH_ENABLED), the factory uses
|
||||
// APIKey directly — if a user provides it, they intend to use it.
|
||||
if cc.APIKey != "" {
|
||||
weaviateCfg.AuthConfig = auth.ApiKey{Value: cc.APIKey}
|
||||
}
|
||||
|
||||
client, err := weaviate.NewClient(weaviateCfg)
|
||||
if err != nil {
|
||||
return nil, fmt.Errorf("create weaviate client: %w", err)
|
||||
}
|
||||
repo := weaviateRepo.NewWeaviateRetrieveEngineRepository(client)
|
||||
return retriever.NewKVHybridRetrieveEngine(repo, types.WeaviateRetrieverEngineType), nil
|
||||
}
|
||||
@@ -6,6 +6,23 @@ import (
|
||||
"github.com/Tencent/WeKnora/internal/types"
|
||||
)
|
||||
|
||||
// StoreRegistry provides VectorStore-based engine registration/lookup.
|
||||
// Separated from RetrieveEngineRegistry to avoid changing the existing interface
|
||||
// used by 6 services (17 call sites). Phase 2 may merge into RetrieveEngineRegistry.
|
||||
type StoreRegistry interface {
|
||||
// RegisterWithStoreID registers an engine service by VectorStore ID.
|
||||
// Upsert semantics: existing entry is overwritten silently.
|
||||
RegisterWithStoreID(storeID string, svc RetrieveEngineService)
|
||||
// GetByStoreID retrieves an engine service by VectorStore ID.
|
||||
GetByStoreID(storeID string) (RetrieveEngineService, error)
|
||||
// UnregisterByStoreID removes an engine service by VectorStore ID (idempotent).
|
||||
UnregisterByStoreID(storeID string)
|
||||
}
|
||||
|
||||
// EngineFactory creates a RetrieveEngineService from a VectorStore's config.
|
||||
// Defined as a function type to avoid circular imports between container and service packages.
|
||||
type EngineFactory func(ctx context.Context, store types.VectorStore) (RetrieveEngineService, error)
|
||||
|
||||
// VectorStoreService defines the service interface for vector store management.
|
||||
// Tenant isolation is enforced by the handler layer (getOwnedStore pattern).
|
||||
type VectorStoreService interface {
|
||||
|
||||
Reference in New Issue
Block a user