Files
WeKnora/internal/container/container.go
ochan.kwon 40b74e2efa feat(retriever): activate OpenSearch k-NN driver (PR 3 of 3)
Phase 3 (#1440) gate flip. PR 1 (#1445) + PR 2a (#1481) + PR 2b (#1482)
laid the type prep + driver skeleton + read/write paths as gated dead
code; this PR wires every activation surface so opensearch becomes a
registerable VectorStore engine.

Activation wiring
- internal/types: validEngineTypes / GetVectorStoreTypes (with HNSW
  bounds + knn_engine enum + Immutable hints) / retrieverEngineMapping /
  buildEnvStoreForDriver — every gated surface now recognises
  "opensearch". IndexConfig grows four omitempty HNSW fields (HNSWM /
  HNSWEFConstruction / HNSWEFSearch / KNNEngine), keeping other engines'
  serialised config byte-identical.
- internal/container: createOpenSearchEngine + the switch case in
  createEngineServiceFromStore; the RETRIEVE_DRIVER=opensearch env path
  in initRetrieveEngineRegistry; NewEngineFactory now closes over the
  AuditLogService (the EngineFactory type itself is unchanged).
- internal/application/service/vectorstore_healthcheck.go: a
  testOpenSearchConnection case so CreateStore's connectivity probe
  accepts opensearch instead of returning 400.
- internal/application/repository/retriever/opensearch/transport.go:
  NewOpenSearchClient is exported so the factory and env path can build
  the TLS-hardened client; healthcheck.go reuses the unexported
  probeVersion / probeKNNPlugin for the service-layer probe.

Service-layer validation
- validateOpenSearchIndexConfig validates the HNSW caps (m 2-100,
  ef_construction 2-4096, ef_search 1-10000, knn_engine ∈ lucene|faiss).
  Shards/replicas continue to be enforced by the flat ValidateIndexConfig.
  Create-only: UpdateStore mutates the name only.
- validateConnectionConfig requires addr for opensearch.

Sync implementations (stubs.go shrinks)
- CopyIndices (copy.go) mirrors the Elasticsearch / Qdrant pattern —
  search → BatchSave with the source_id remap for generated questions —
  so dim/keyword routing and the source_id contract come from BatchSave
  for free. embeddingMap is keyed by the *target* SourceID because
  OpenSearch's BatchSave looks up embeddings by SourceID
  (lookupEmbedding), not by chunk_id (the ES driver's convention).
  Pagination is from/size; copies larger than max_result_window
  (default 10000) need the scroll-based async path that lands later.
- BatchUpdateChunkEnabledStatus / BatchUpdateChunkTagID (bulk_update.go)
  group the input by target value and issue one _update_by_query per
  group over the cross-dim <base>_* pattern. Caller values flow through
  bound script params only — never string-interpolated into the Painless
  source — closing the script-injection surface.
- inspectByQueryResponse (byquery.go) mirrors inspectBulkResponse: the
  full failure reason goes to the debug log only; the returned error
  carries the bounded id + type.
- UpdateByQueryParams.Refresh is *bool in opensearch-go v4.6.0 (the same
  shape as DeleteByQuery's quirk), so refresh=wait_for is not
  expressible; we use refresh=true.

Driver-owned audit (DIP)
- A new opensearch.AuditSink interface (with nopSink + WithAuditSink
  functional option) lets the driver emit opensearch.index_created and
  opensearch.reindex_executed events without importing any service
  package — the service layer implements the interface. NewRepository
  takes opts, so existing 4-arg test call sites keep compiling unchanged.
- internal/container/audit_sink.go bridges AuditSink to AuditLogService.
  When the context carries no tenant (the env-path registration ctx
  during boot, for example) the adapter skips the emit with a warning
  rather than silently writing tenant_id=0, which would collide with the
  system-scope sentinel.

Frontend + polish
- FieldSchema (frontend/src/api/vector-store.ts) gains min/max/enum/
  immutable. VectorStoreSettings.vue is now schema-driven: a closed
  `enum` renders a t-select; number inputs use the schema's `:min`/`:max`
  and fall back to the legacy replica-vs-shard heuristic only when the
  schema does not pin them; a danger-coloured warning fires when
  insecure_skip_verify is toggled on (the switch and warning are wrapped
  in a vertical stack so the warning sits on its own row below the switch).
- i18n: labels for hnsw_m / hnsw_ef_construction / hnsw_ef_search /
  knn_engine / insecure_skip_verify plus the warning copy in en-US,
  ko-KR, zh-CN, ru-RU.
- docker-compose.dev.yml: an opensearch profile (single-node 3.3.2 with
  security plugin disabled for dev only). OpenSearch Dashboards lives in a
  separate, opt-in opensearch-ui profile so the heavy UI container is not
  forced up alongside the cluster (the driver e2e is fully curl-verifiable
  against :9200). The new docs/dev/opensearch-integration-test.md covers the
  end-to-end exercise and the single-node guidance (set replicas=0 to keep
  the cluster Green).

Gating-guard tests flipped
- The "OpenSearch is NOT in validEngineTypes / mapping / types list /
  env builder / stubs" guard tests from PR 1 / PR 2 are replaced by
  their positive counterparts in this PR. The test suite was the
  activation checklist; the activation flip is its diff.

Backward compatibility
- Additive everywhere. IndexConfig's new HNSW fields are omitempty so
  other engines' serialised config is byte-identical. Existing
  Elasticsearch / Qdrant / Milvus / Weaviate / Doris / TencentVectorDB
  stores are untouched. No migrations.

Test plan
- go build ./... clean
- go vet ./... clean
- gofmt -l clean on touched files
- go test ./... — only TestOssEnsureBucket_CreateFails (Aliyun OSS
  endpoint), the docreader gRPC tests, and the doris SQL-shape tests
  fail; all three are pre-existing on upstream/main and untouched by
  this PR.
- New tests across internal/types, opensearch, service and container —
  including a full end-to-end env-path test that exercises
  initRetrieveEngineRegistry with RETRIEVE_DRIVER=opensearch against an
  httptest cluster.
2026-05-29 16:32:27 +08:00

1396 lines
54 KiB
Go
Raw Blame History

This file contains ambiguous Unicode characters
This file contains Unicode characters that might be confused with other characters. If you think that this is intentional, you can safely ignore this warning. Use the Escape button to reveal them.
// Package container implements dependency injection container setup
// Provides centralized configuration for services, repositories, and handlers
// This package is responsible for wiring up all dependencies and ensuring proper lifecycle management
package container
import (
"context"
"database/sql"
"errors"
"fmt"
"net/url"
"os"
"path/filepath"
"slices"
"strconv"
"strings"
"time"
sqlite_vec "github.com/asg017/sqlite-vec-go-bindings/cgo"
_ "github.com/duckdb/duckdb-go/v2"
esv7 "github.com/elastic/go-elasticsearch/v7"
"github.com/elastic/go-elasticsearch/v8"
_ "github.com/go-sql-driver/mysql" // 给 Doris (database/sql) 注册 MySQL 协议驱动
"github.com/milvus-io/milvus/client/v2/milvusclient"
"github.com/neo4j/neo4j-go-driver/v6/neo4j"
"github.com/panjf2000/ants/v2"
"github.com/qdrant/go-client/qdrant"
"github.com/redis/go-redis/v9"
"go.uber.org/dig"
"google.golang.org/grpc"
"gorm.io/driver/postgres"
"gorm.io/driver/sqlite"
"gorm.io/gorm"
"github.com/Tencent/WeKnora/internal/agent/approval"
"github.com/Tencent/WeKnora/internal/application/repository"
memoryRepo "github.com/Tencent/WeKnora/internal/application/repository/memory/neo4j"
dorisRepo "github.com/Tencent/WeKnora/internal/application/repository/retriever/doris"
elasticsearchRepoV7 "github.com/Tencent/WeKnora/internal/application/repository/retriever/elasticsearch/v7"
elasticsearchRepoV8 "github.com/Tencent/WeKnora/internal/application/repository/retriever/elasticsearch/v8"
milvusRepo "github.com/Tencent/WeKnora/internal/application/repository/retriever/milvus"
neo4jRepo "github.com/Tencent/WeKnora/internal/application/repository/retriever/neo4j"
openSearchRepo "github.com/Tencent/WeKnora/internal/application/repository/retriever/opensearch"
postgresRepo "github.com/Tencent/WeKnora/internal/application/repository/retriever/postgres"
qdrantRepo "github.com/Tencent/WeKnora/internal/application/repository/retriever/qdrant"
sqliteRetrieverRepo "github.com/Tencent/WeKnora/internal/application/repository/retriever/sqlite"
tencentVectorDBRepo "github.com/Tencent/WeKnora/internal/application/repository/retriever/tencentvectordb"
weaviateRepo "github.com/Tencent/WeKnora/internal/application/repository/retriever/weaviate"
"github.com/Tencent/WeKnora/internal/application/service"
chatpipeline "github.com/Tencent/WeKnora/internal/application/service/chat_pipeline"
"github.com/Tencent/WeKnora/internal/application/service/file"
memoryService "github.com/Tencent/WeKnora/internal/application/service/memory"
"github.com/Tencent/WeKnora/internal/application/service/retriever"
"github.com/Tencent/WeKnora/internal/config"
"github.com/Tencent/WeKnora/internal/database"
"github.com/Tencent/WeKnora/internal/datasource"
feishuConnector "github.com/Tencent/WeKnora/internal/datasource/connector/feishu"
notionConnector "github.com/Tencent/WeKnora/internal/datasource/connector/notion"
yuqueConnector "github.com/Tencent/WeKnora/internal/datasource/connector/yuque"
"github.com/Tencent/WeKnora/internal/event"
"github.com/Tencent/WeKnora/internal/handler"
"github.com/Tencent/WeKnora/internal/handler/session"
imPkg "github.com/Tencent/WeKnora/internal/im"
"github.com/Tencent/WeKnora/internal/im/dingtalk"
"github.com/Tencent/WeKnora/internal/im/feishu"
"github.com/Tencent/WeKnora/internal/im/mattermost"
"github.com/Tencent/WeKnora/internal/im/slack"
"github.com/Tencent/WeKnora/internal/im/telegram"
"github.com/Tencent/WeKnora/internal/im/wechat"
"github.com/Tencent/WeKnora/internal/im/wecom"
"github.com/Tencent/WeKnora/internal/infrastructure/docparser"
infra_web_search "github.com/Tencent/WeKnora/internal/infrastructure/web_search"
"github.com/Tencent/WeKnora/internal/logger"
"github.com/Tencent/WeKnora/internal/mcp"
"github.com/Tencent/WeKnora/internal/models/embedding"
"github.com/Tencent/WeKnora/internal/models/utils/ollama"
"github.com/Tencent/WeKnora/internal/router"
"github.com/Tencent/WeKnora/internal/stream"
"github.com/Tencent/WeKnora/internal/tracing"
"github.com/Tencent/WeKnora/internal/tracing/langfuse"
"github.com/Tencent/WeKnora/internal/types"
"github.com/Tencent/WeKnora/internal/types/interfaces"
"github.com/tencent/vectordatabase-sdk-go/tcvectordb"
"github.com/weaviate/weaviate-go-client/v5/weaviate"
"github.com/weaviate/weaviate-go-client/v5/weaviate/auth"
wgrpc "github.com/weaviate/weaviate-go-client/v5/weaviate/grpc"
)
// BuildContainer constructs the dependency injection container
// Registers all components, services, repositories and handlers needed by the application
// Creates a fully configured application container with proper dependency resolution
// Parameters:
// - container: Base dig container to add dependencies to
//
// Returns:
// - Configured container with all application dependencies registered
func BuildContainer(container *dig.Container) *dig.Container {
ctx := context.Background()
logger.Debugf(ctx, "[Container] Starting container initialization...")
// Register resource cleaner for proper cleanup of resources
must(container.Provide(NewResourceCleaner, dig.As(new(interfaces.ResourceCleaner))))
// Core infrastructure configuration
logger.Debugf(ctx, "[Container] Registering core infrastructure...")
must(container.Provide(config.LoadConfig))
must(container.Provide(initTracer))
must(container.Provide(initLangfuse))
must(container.Provide(initDatabase))
must(container.Provide(initFileService))
must(container.Provide(initRedisClient))
must(container.Provide(initAntsPool))
// Register tracer cleanup handler (tracer needs to be available for cleanup registration)
must(container.Invoke(registerTracerCleanup))
must(container.Invoke(registerLangfuseCleanup))
// Register goroutine pool cleanup handler
must(container.Invoke(registerPoolCleanup))
// Initialize retrieval engine registry for search capabilities
logger.Debugf(ctx, "[Container] Registering retrieval engine registry...")
must(container.Provide(initRetrieveEngineRegistry))
// External service clients
logger.Debugf(ctx, "[Container] Registering external service clients...")
must(container.Provide(initDocReaderClient))
must(container.Provide(docparser.NewImageResolver))
must(container.Provide(initOllamaService))
must(container.Provide(initNeo4jClient))
must(container.Provide(stream.NewStreamManager))
logger.Debugf(ctx, "[Container] Initializing DuckDB...")
must(container.Provide(NewDuckDB))
logger.Debugf(ctx, "[Container] DuckDB registered")
// Data repositories layer
logger.Debugf(ctx, "[Container] Registering repositories...")
must(container.Provide(repository.NewTenantRepository))
must(container.Provide(repository.NewTenantMemberRepository))
must(container.Provide(repository.NewTenantInvitationRepository))
must(container.Provide(repository.NewAuditLogRepository))
must(container.Provide(repository.NewKnowledgeBaseRepository))
must(container.Provide(repository.NewKnowledgeRepository))
must(container.Provide(repository.NewKnowledgeSpanRepository))
must(container.Provide(repository.NewChunkRepository))
must(container.Provide(repository.NewKnowledgeTagRepository))
must(container.Provide(repository.NewSessionRepository))
must(container.Provide(repository.NewMessageRepository))
must(container.Provide(repository.NewModelRepository))
must(container.Provide(repository.NewUserRepository))
must(container.Provide(repository.NewAuthTokenRepository))
must(container.Provide(repository.NewSystemSettingRepository))
must(container.Provide(neo4jRepo.NewNeo4jRepository))
must(container.Provide(memoryRepo.NewMemoryRepository))
must(container.Provide(repository.NewMCPServiceRepository))
must(container.Provide(repository.NewMCPToolApprovalRepository))
must(container.Provide(repository.NewCustomAgentRepository))
must(container.Provide(repository.NewOrganizationRepository))
must(container.Provide(repository.NewKBShareRepository))
must(container.Provide(repository.NewAgentShareRepository))
must(container.Provide(repository.NewTenantDisabledSharedAgentRepository))
must(container.Provide(repository.NewUserResourceFavoriteRepository))
must(container.Provide(service.NewWebSearchStateService))
must(container.Provide(repository.NewDataSourceRepository))
must(container.Provide(repository.NewSyncLogRepository))
must(container.Provide(repository.NewWikiPageRepository))
must(container.Provide(repository.NewWikiLogEntryRepository))
must(container.Provide(repository.NewTaskPendingOpsRepository))
must(container.Provide(repository.NewTaskDeadLetterRepository))
// MCP manager for managing MCP client connections
logger.Debugf(ctx, "[Container] Registering MCP manager...")
must(container.Provide(mcp.NewMCPManager))
// Business service layer
logger.Debugf(ctx, "[Container] Registering business services...")
must(container.Provide(service.NewTenantService))
must(container.Provide(service.NewTenantMemberService))
must(container.Provide(service.NewTenantInvitationService))
must(container.Provide(service.NewAuditLogService))
must(container.Provide(service.NewAuditLogRetentionRunner))
must(container.Provide(service.NewKnowledgeBaseService))
must(container.Provide(service.NewOrganizationService))
must(container.Provide(service.NewKBShareService)) // KBShareService must be registered before KnowledgeService and KnowledgeTagService
must(container.Provide(service.NewAgentShareService))
must(container.Provide(service.NewKnowledgeService))
must(container.Provide(service.NewSpanTracker))
must(container.Provide(service.NewChunkService))
must(container.Provide(service.NewKnowledgeTagService))
must(container.Provide(embedding.NewBatchEmbedder))
must(container.Provide(service.NewModelService))
must(container.Provide(service.NewDatasetService))
must(container.Provide(service.NewEvaluationService))
must(container.Provide(service.NewUserService))
must(container.Provide(service.NewSystemSettingService))
must(container.Provide(service.NewWeKnoraCloudService))
// Extract services - register individual extracters with names
must(container.Provide(service.NewChunkExtractService, dig.Name("chunkExtractor")))
must(container.Provide(service.NewDataTableSummaryService, dig.Name("dataTableSummary")))
must(container.Provide(service.NewImageMultimodalService, dig.Name("imageMultimodal")))
must(container.Provide(service.NewKnowledgePostProcessService, dig.Name("knowledgePostProcess")))
must(container.Provide(service.NewMessageService))
must(container.Provide(service.NewMCPServiceService))
must(container.Provide(service.NewMCPToolApprovalService))
must(container.Provide(service.NewCustomAgentService))
must(container.Provide(service.NewUserResourceFavoriteService))
must(container.Provide(memoryService.NewMemoryService))
must(container.Provide(service.NewWikiPageService))
must(container.Provide(service.NewWikiLogEntryService))
must(container.Provide(service.NewWikiIngestService, dig.Name("wikiIngest")))
must(container.Provide(service.NewWikiLintService))
// Web search service (needed by AgentService)
logger.Debugf(ctx, "[Container] Registering web search registry and providers...")
must(container.Provide(infra_web_search.NewRegistry))
must(container.Invoke(registerWebSearchProviders))
must(container.Provide(repository.NewWebSearchProviderRepository))
must(container.Provide(repository.NewVectorStoreRepository))
// TenantStoreOwnership adapter used by the retriever factory functions
// to verify that a resolved VectorStore belongs to the caller's tenant.
must(container.Provide(retriever.NewVectorStoreRepoOwnership))
must(container.Provide(service.NewWebSearchService))
must(container.Provide(service.NewWebSearchProviderService))
must(container.Provide(NewEngineFactory))
// StoreRegistry: same instance as RetrieveEngineRegistry, exposed as StoreRegistry interface.
// NewRetrieveEngineRegistry always returns *retriever.RetrieveEngineRegistry which implements both.
must(container.Provide(func(r interfaces.RetrieveEngineRegistry) (interfaces.StoreRegistry, error) {
sr, ok := r.(*retriever.RetrieveEngineRegistry)
if !ok {
return nil, fmt.Errorf("registry does not implement StoreRegistry")
}
return sr, nil
}))
must(container.Provide(service.NewVectorStoreService))
// Agent service layer (requires event bus, web search service)
// SessionService is passed as parameter to CreateAgentEngine method when creating AgentService
logger.Debugf(ctx, "[Container] Registering event bus and agent service...")
must(container.Provide(event.NewEventBus))
must(container.Provide(func(cfg *config.Config, s interfaces.MCPToolApprovalService, rdb *redis.Client) *approval.Gate {
return approval.NewGate(cfg, &approval.Adapter{Svc: s}, rdb)
}))
// Expose Gate as MCPApproval interface so AgentService and others can depend on the abstraction.
must(container.Provide(func(g *approval.Gate) approval.MCPApproval { return g }))
must(container.Provide(service.NewAgentService))
// Session service (depends on agent service)
// SessionService is created after AgentService and passes itself to AgentService.CreateAgentEngine when needed
logger.Debugf(ctx, "[Container] Registering session service...")
must(container.Provide(service.NewSessionService))
logger.Debugf(ctx, "[Container] Registering task enqueuer...")
redisAvailable := os.Getenv("REDIS_ADDR") != ""
if redisAvailable {
must(container.Provide(router.NewAsyncqClient, dig.As(new(interfaces.TaskEnqueuer))))
must(container.Provide(router.NewAsynqServer))
// Asynq inspector for cancel-by-knowledge-id (best-effort
// dequeue of pending/scheduled/retry tasks + active-task cancel).
must(container.Provide(router.NewAsynqInspector))
must(container.Provide(router.NewAsynqTaskInspector))
} else {
syncExec := router.NewSyncTaskExecutor()
must(container.Provide(func() interfaces.TaskEnqueuer { return syncExec }))
must(container.Provide(func() *router.SyncTaskExecutor { return syncExec }))
// Lite mode: no Redis means no asynq inspector. SyncTaskExecutor
// dispatches inline goroutines that the checkpoint-based abort
// already handles.
must(container.Provide(router.NewNoopTaskInspector))
}
// Chat pipeline components for processing chat requests
logger.Debugf(ctx, "[Container] Registering chat pipeline plugins...")
// Data source sync framework
logger.Debugf(ctx, "[Container] Registering data source sync framework...")
must(container.Provide(initConnectorRegistry))
must(container.Provide(datasource.NewScheduler))
must(container.Provide(service.NewDataSourceService))
must(container.Invoke(startDataSourceScheduler))
logger.Debugf(ctx, "[Container] Data source sync framework registered")
must(container.Invoke(startAuditLogRetention))
logger.Debugf(ctx, "[Container] Audit log retention runner registered")
must(container.Provide(service.NewHousekeepingService))
must(container.Invoke(startHousekeepingService))
logger.Debugf(ctx, "[Container] Knowledge housekeeping runner registered")
must(container.Provide(chatpipeline.NewEventManager))
must(container.Invoke(chatpipeline.NewPluginSearch))
must(container.Invoke(chatpipeline.NewPluginRerank))
must(container.Invoke(chatpipeline.NewPluginWebFetch))
must(container.Invoke(chatpipeline.NewPluginMerge))
must(container.Invoke(chatpipeline.NewPluginDataAnalysis))
must(container.Invoke(chatpipeline.NewPluginIntoChatMessage))
must(container.Invoke(chatpipeline.NewPluginChatCompletion))
must(container.Invoke(chatpipeline.NewPluginChatCompletionStream))
must(container.Invoke(chatpipeline.NewPluginFilterTopK))
must(container.Invoke(chatpipeline.NewPluginQueryUnderstand))
must(container.Invoke(chatpipeline.NewPluginLoadHistory))
must(container.Invoke(chatpipeline.NewPluginExtractEntity))
must(container.Invoke(chatpipeline.NewPluginSearchEntity))
must(container.Invoke(chatpipeline.NewPluginSearchParallel))
must(container.Invoke(chatpipeline.NewPluginWikiBoost))
must(container.Invoke(chatpipeline.NewMemoryPlugin))
logger.Debugf(ctx, "[Container] Chat pipeline plugins registered")
// HTTP handlers layer
logger.Debugf(ctx, "[Container] Registering HTTP handlers...")
must(container.Provide(handler.NewTenantHandler))
must(container.Provide(handler.NewTenantMemberHandler))
must(container.Provide(handler.NewTenantInvitationHandler))
must(container.Provide(handler.NewAuditLogHandler))
must(container.Provide(handler.NewKnowledgeBaseHandler))
must(container.Provide(handler.NewKnowledgeHandler))
must(container.Provide(handler.NewChunkHandler))
must(container.Provide(handler.NewFAQHandler))
must(container.Provide(handler.NewTagHandler))
must(container.Provide(session.NewHandler))
must(container.Provide(handler.NewMessageHandler))
must(container.Provide(handler.NewModelHandler))
must(container.Provide(handler.NewEvaluationHandler))
must(container.Provide(handler.NewInitializationHandler))
must(container.Provide(handler.NewAuthHandler))
must(container.Provide(handler.NewSystemHandler))
must(container.Provide(handler.NewMCPServiceHandler))
must(container.Provide(handler.NewMCPCredentialsHandler))
must(container.Provide(handler.NewModelCredentialsHandler))
must(container.Provide(handler.NewWebSearchProviderCredentialsHandler))
must(container.Provide(handler.NewDataSourceCredentialsHandler))
must(container.Provide(handler.NewWebSearchHandler))
must(container.Provide(handler.NewWebSearchProviderHandler))
must(container.Provide(handler.NewVectorStoreHandler))
must(container.Provide(handler.NewCustomAgentHandler))
must(container.Provide(handler.NewUserResourceFavoriteHandler))
must(container.Provide(service.NewSkillService))
must(container.Provide(handler.NewSkillHandler))
must(container.Provide(handler.NewOrganizationHandler))
// Data source handler
must(container.Provide(handler.NewDataSourceHandler))
// Wiki page handler
must(container.Provide(handler.NewWikiPageHandler))
// IM integration
logger.Debugf(ctx, "[Container] Registering IM integration...")
must(container.Provide(imPkg.NewService))
must(container.Invoke(registerIMAdapterFactories))
must(container.Provide(handler.NewIMHandler))
must(container.Provide(handler.NewWeKnoraCloudHandler))
logger.Debugf(ctx, "[Container] HTTP handlers registered")
// Router configuration
logger.Debugf(ctx, "[Container] Registering router and starting task server...")
must(container.Provide(router.NewRouter))
if redisAvailable {
must(container.Invoke(router.RunAsynqServer))
} else {
must(container.Invoke(router.RegisterSyncHandlers))
}
logger.Infof(ctx, "[Container] Container initialization completed successfully")
return container
}
// must is a helper function for error handling
// Panics if the error is not nil, useful for configuration steps that must succeed
// Parameters:
// - err: Error to check
func must(err error) {
if err != nil {
panic(err)
}
}
// initTracer initializes OpenTelemetry tracer
// Sets up distributed tracing for observability across the application
// Parameters:
// - None
//
// Returns:
// - Configured tracer instance
// - Error if initialization fails
func initTracer() (*tracing.Tracer, error) {
return tracing.InitTracer()
}
// initLangfuse initializes the Langfuse ingestion client.
// Configuration is read from LANGFUSE_* environment variables (see
// docs/langfuse.md). Returns a disabled manager if credentials are absent —
// never an error — so deployments that don't use Langfuse are unaffected.
func initLangfuse() (*langfuse.Manager, error) {
cfg := langfuse.LoadConfigFromEnv()
return langfuse.Init(cfg)
}
func initRedisClient() (*redis.Client, error) {
redisAddr := os.Getenv("REDIS_ADDR")
if redisAddr == "" {
logger.Infof(context.Background(), "[Redis] No REDIS_ADDR configured, Redis disabled (Lite mode)")
return nil, nil
}
db, err := strconv.Atoi(os.Getenv("REDIS_DB"))
if err != nil {
db = 0
}
client := redis.NewClient(&redis.Options{
Addr: redisAddr,
Username: os.Getenv("REDIS_USERNAME"),
Password: os.Getenv("REDIS_PASSWORD"),
DB: db,
})
_, err = client.Ping(context.Background()).Result()
if err != nil {
return nil, fmt.Errorf("连接Redis失败: %w", err)
}
return client, nil
}
// initDatabase initializes database connection
// Creates and configures database connection based on environment configuration
// Supports multiple database backends (PostgreSQL)
// Parameters:
// - cfg: Application configuration
//
// Returns:
// - Configured database connection
// - Error if connection fails
func initDatabase(cfg *config.Config) (*gorm.DB, error) {
var dialector gorm.Dialector
var migrateDSN string
var sqliteDBPath string
switch os.Getenv("DB_DRIVER") {
case "postgres":
// DSN for GORM (key-value format)
gormDSN := fmt.Sprintf(
"host=%s port=%s user=%s password=%s dbname=%s sslmode=%s TimeZone=UTC",
os.Getenv("DB_HOST"),
os.Getenv("DB_PORT"),
os.Getenv("DB_USER"),
os.Getenv("DB_PASSWORD"),
os.Getenv("DB_NAME"),
"disable",
)
dialector = postgres.Open(gormDSN)
// DSN for golang-migrate (URL format)
// URL-encode password to handle special characters like !@#
dbPassword := os.Getenv("DB_PASSWORD")
encodedPassword := url.QueryEscape(dbPassword)
// Check if postgres is in RETRIEVE_DRIVER to determine skip_embedding
retrieveDriver := strings.Split(os.Getenv("RETRIEVE_DRIVER"), ",")
skipEmbedding := "true"
if slices.Contains(retrieveDriver, "postgres") {
skipEmbedding = "false"
}
logger.Infof(context.Background(), "Skip embedding: %s", skipEmbedding)
migrateDSN = fmt.Sprintf(
"postgres://%s:%s@%s:%s/%s?sslmode=disable&options=-c%%20app.skip_embedding=%s",
os.Getenv("DB_USER"),
encodedPassword, // Use encoded password
os.Getenv("DB_HOST"),
os.Getenv("DB_PORT"),
os.Getenv("DB_NAME"),
skipEmbedding,
)
// Debug log (don't log password)
logger.Infof(context.Background(), "DB Config: user=%s host=%s port=%s dbname=%s",
os.Getenv("DB_USER"),
os.Getenv("DB_HOST"),
os.Getenv("DB_PORT"),
os.Getenv("DB_NAME"),
)
case "sqlite":
dbPath := os.Getenv("DB_PATH")
if dbPath == "" {
dbPath = "./data/weknora.db"
}
if dir := filepath.Dir(dbPath); dir != "." && dir != "" {
if err := os.MkdirAll(dir, 0o755); err != nil {
return nil, fmt.Errorf("failed to create SQLite data directory %s: %w", dir, err)
}
}
sqlite_vec.Auto()
dsn := dbPath + "?_journal_mode=WAL&_busy_timeout=5000&_foreign_keys=on"
dialector = sqlite.Open(dsn)
sqliteDBPath = dbPath
migrateDSN = "sqlite3://" + dbPath
logger.Infof(context.Background(), "DB Config: driver=sqlite path=%s", dbPath)
default:
return nil, fmt.Errorf("unsupported database driver: %s", os.Getenv("DB_DRIVER"))
}
db, err := gorm.Open(dialector, &gorm.Config{
NowFunc: func() time.Time {
return time.Now().UTC()
},
})
if err != nil {
return nil, err
}
// Sanity check: dialect-specific code in services (notably the
// vector_stores delete guard) compares Dialector.Name() to "postgres" /
// "sqlite" string literals. A future driver swap that produces a
// different name (e.g., a wrapper dialect for managed PG) would silently
// fall back to the SQLite path, dropping the row-level X-lock. Catching
// the mismatch at startup is loud and inexpensive.
if name := db.Dialector.Name(); name != "postgres" && name != "sqlite" {
return nil, fmt.Errorf(
"unsupported gorm dialector %q; expected postgres or sqlite "+
"(see vectorStoreService.isPostgres for impact)", name)
}
if os.Getenv("DB_DRIVER") == "sqlite" {
sqlDB, err := db.DB()
if err != nil {
return nil, fmt.Errorf("failed to get underlying sql.DB: %w", err)
}
if err := sqlDB.Ping(); err != nil {
return nil, fmt.Errorf("failed to ping SQLite database: %w", err)
}
}
// Run database migrations automatically (optional, can be disabled via env var)
// To disable auto-migration, set AUTO_MIGRATE=false
// To enable auto-recovery from dirty state, set AUTO_RECOVER_DIRTY=true
if os.Getenv("AUTO_MIGRATE") != "false" {
logger.Infof(context.Background(), "Running database migrations...")
autoRecover := os.Getenv("AUTO_RECOVER_DIRTY") != "false"
migrationOpts := database.MigrationOptions{
AutoRecoverDirty: autoRecover,
SQLiteDBPath: sqliteDBPath,
}
// Run base migrations (all versioned migrations including embeddings)
// The embeddings migration will be conditionally executed based on skip_embedding parameter in DSN
if err := database.RunMigrationsWithOptions(migrateDSN, migrationOpts); err != nil {
// Log warning but don't fail startup - migrations might be handled externally
logger.Warnf(context.Background(), "Database migration failed: %v", err)
logger.Warnf(
context.Background(),
"Continuing with application startup. Please run migrations manually if needed.",
)
}
// Post-migration: resolve __pending_env__ storage provider markers for historical KBs.
// The SQL migration marks KBs that have documents but no provider with "__pending_env__";
// we replace that with the actual STORAGE_TYPE from the environment.
resolveStorageProviderPending(db)
// Post-migration: declarative built-in models from config/builtin_models.yaml (optional).
if err := types.LoadBuiltinModelsConfig(context.Background(), db, config.ConfigDir()); err != nil {
logger.Warnf(context.Background(), "Load builtin models config failed: %v", err)
}
} else {
logger.Infof(context.Background(), "Auto-migration is disabled (AUTO_MIGRATE=false)")
}
// Get underlying SQL DB object
sqlDB, err := db.DB()
if err != nil {
return nil, err
}
// Configure connection pool parameters
if os.Getenv("DB_DRIVER") == "sqlite" {
// SQLite only supports one concurrent writer even in WAL mode.
// Limiting to a single open connection serialises all DB access and
// prevents "database is locked" errors from concurrent goroutines.
sqlDB.SetMaxOpenConns(1)
} else {
sqlDB.SetMaxIdleConns(10)
}
sqlDB.SetConnMaxLifetime(time.Duration(10) * time.Minute)
return db, nil
}
// resolveStorageProviderPending replaces the "__pending_env__" sentinel in
// knowledge_bases.storage_provider_config with the actual STORAGE_TYPE from the environment.
// This runs once after SQL migrations to bind historical KBs to their real storage provider.
func resolveStorageProviderPending(db *gorm.DB) {
storageType := strings.TrimSpace(os.Getenv("STORAGE_TYPE"))
if storageType == "" {
storageType = "local"
}
storageType = strings.ToLower(storageType)
result := db.Exec(
`UPDATE knowledge_bases SET storage_provider_config = ? WHERE storage_provider_config IS NOT NULL AND storage_provider_config->>'provider' = '__pending_env__'`,
fmt.Sprintf(`{"provider":"%s"}`, storageType),
)
if result.Error != nil {
logger.Warnf(context.Background(), "Failed to resolve __pending_env__ storage providers: %v", result.Error)
} else if result.RowsAffected > 0 {
logger.Infof(context.Background(), "Resolved %d knowledge bases with __pending_env__ storage provider → %s", result.RowsAffected, storageType)
}
// Sync PostgreSQL sequences with actual MAX values to prevent duplicate key
// errors. The old code assigned seq_id via SELECT MAX()+1 in application
// code, which could push values past the DB sequence counter.
syncSequences(db)
// Reset any pending tasks left over from previous aborted runs (Lite App mode)
resetPendingTasks(db)
}
// syncSequences ensures PostgreSQL sequences for auto-increment columns (seq_id)
// are at least as high as the current MAX value in each table. This is needed
// because older code assigned seq_id via application-level MAX()+1, which could
// advance values past the DB sequence counter and cause duplicate key errors.
func syncSequences(db *gorm.DB) {
if db.Dialector.Name() != "postgres" {
return
}
pairs := [][2]string{
{"chunks", "chunks_seq_id_seq"},
{"knowledge_tags", "knowledge_tags_seq_id_seq"},
}
for _, p := range pairs {
table, seq := p[0], p[1]
sql := fmt.Sprintf(
`SELECT setval('%s', GREATEST(nextval('%s'), (SELECT COALESCE(MAX(seq_id), 0) FROM %s)))`,
seq, seq, table,
)
if err := db.Exec(sql).Error; err != nil {
logger.Warnf(context.Background(), "Failed to sync sequence %s: %v", seq, err)
} else {
logger.Infof(context.Background(), "Synced sequence %s with table %s", seq, table)
}
}
}
// initFileService initializes file storage service
// Creates the appropriate file storage service based on configuration
// Supports multiple storage backends (MinIO, COS, local filesystem)
// Parameters:
// - cfg: Application configuration
//
// Returns:
// - Configured file service implementation
// - Error if initialization fails
func initFileService(cfg *config.Config) (interfaces.FileService, error) {
storageType := strings.TrimSpace(os.Getenv("STORAGE_TYPE"))
if storageType == "" {
storageType = "local"
}
switch storageType {
case "minio":
if os.Getenv("MINIO_ENDPOINT") == "" ||
os.Getenv("MINIO_ACCESS_KEY_ID") == "" ||
os.Getenv("MINIO_SECRET_ACCESS_KEY") == "" ||
os.Getenv("MINIO_BUCKET_NAME") == "" {
return nil, fmt.Errorf("missing MinIO configuration")
}
return file.NewMinioFileService(
os.Getenv("MINIO_ENDPOINT"),
os.Getenv("MINIO_ACCESS_KEY_ID"),
os.Getenv("MINIO_SECRET_ACCESS_KEY"),
os.Getenv("MINIO_BUCKET_NAME"),
strings.EqualFold(os.Getenv("MINIO_USE_SSL"), "true"),
)
case "cos":
if os.Getenv("COS_BUCKET_NAME") == "" ||
os.Getenv("COS_REGION") == "" ||
os.Getenv("COS_SECRET_ID") == "" ||
os.Getenv("COS_SECRET_KEY") == "" ||
os.Getenv("COS_PATH_PREFIX") == "" {
return nil, fmt.Errorf("missing COS configuration")
}
return file.NewCosFileServiceWithTempBucket(
os.Getenv("COS_BUCKET_NAME"),
os.Getenv("COS_REGION"),
os.Getenv("COS_SECRET_ID"),
os.Getenv("COS_SECRET_KEY"),
os.Getenv("COS_PATH_PREFIX"),
os.Getenv("COS_TEMP_BUCKET_NAME"),
os.Getenv("COS_TEMP_REGION"),
)
case "tos":
if os.Getenv("TOS_ENDPOINT") == "" ||
os.Getenv("TOS_REGION") == "" ||
os.Getenv("TOS_ACCESS_KEY") == "" ||
os.Getenv("TOS_SECRET_KEY") == "" ||
os.Getenv("TOS_BUCKET_NAME") == "" {
return nil, fmt.Errorf("missing TOS configuration")
}
return file.NewTosFileServiceWithTempBucket(
os.Getenv("TOS_ENDPOINT"),
os.Getenv("TOS_REGION"),
os.Getenv("TOS_ACCESS_KEY"),
os.Getenv("TOS_SECRET_KEY"),
os.Getenv("TOS_BUCKET_NAME"),
os.Getenv("TOS_PATH_PREFIX"),
os.Getenv("TOS_TEMP_BUCKET_NAME"), // 可选:临时桶名称(桶需配置生命周期规则自动过期)
os.Getenv("TOS_TEMP_REGION"), // 可选:临时桶 region默认与主桶相同
)
case "s3":
if os.Getenv("S3_ENDPOINT") == "" ||
os.Getenv("S3_REGION") == "" ||
os.Getenv("S3_ACCESS_KEY") == "" ||
os.Getenv("S3_SECRET_KEY") == "" ||
os.Getenv("S3_BUCKET_NAME") == "" {
return nil, fmt.Errorf("missing S3 configuration")
}
pathPrefix := os.Getenv("S3_PATH_PREFIX")
if pathPrefix == "" {
pathPrefix = "weknora/"
}
return file.NewS3FileService(
os.Getenv("S3_ENDPOINT"),
os.Getenv("S3_ACCESS_KEY"),
os.Getenv("S3_SECRET_KEY"),
os.Getenv("S3_BUCKET_NAME"),
os.Getenv("S3_REGION"),
pathPrefix,
)
case "obs":
if os.Getenv("OBS_ENDPOINT") == "" ||
os.Getenv("OBS_ACCESS_KEY") == "" ||
os.Getenv("OBS_SECRET_KEY") == "" ||
os.Getenv("OBS_BUCKET_NAME") == "" {
return nil, fmt.Errorf("missing OBS configuration")
}
obsRegion := os.Getenv("OBS_REGION")
obsPathPrefix := os.Getenv("OBS_PATH_PREFIX")
if obsPathPrefix == "" {
obsPathPrefix = "weknora/"
}
return file.NewObsFileService(
os.Getenv("OBS_ENDPOINT"),
obsRegion,
os.Getenv("OBS_ACCESS_KEY"),
os.Getenv("OBS_SECRET_KEY"),
os.Getenv("OBS_BUCKET_NAME"),
obsPathPrefix,
)
case "oss":
if os.Getenv("OSS_ENDPOINT") == "" ||
os.Getenv("OSS_REGION") == "" ||
os.Getenv("OSS_ACCESS_KEY") == "" ||
os.Getenv("OSS_SECRET_KEY") == "" ||
os.Getenv("OSS_BUCKET_NAME") == "" {
return nil, fmt.Errorf("missing OSS configuration")
}
pathPrefix := os.Getenv("OSS_PATH_PREFIX")
if pathPrefix == "" {
pathPrefix = "weknora/"
}
return file.NewOssFileServiceWithTempBucket(
os.Getenv("OSS_ENDPOINT"),
os.Getenv("OSS_REGION"),
os.Getenv("OSS_ACCESS_KEY"),
os.Getenv("OSS_SECRET_KEY"),
os.Getenv("OSS_BUCKET_NAME"),
pathPrefix,
os.Getenv("OSS_TEMP_BUCKET_NAME"),
os.Getenv("OSS_TEMP_REGION"),
)
case "local":
baseDir := os.Getenv("LOCAL_STORAGE_BASE_DIR")
if baseDir == "" {
baseDir = "/data/files"
}
externalURL := strings.TrimSpace(os.Getenv("APP_EXTERNAL_URL"))
return file.NewLocalFileService(baseDir, externalURL), nil
case "dummy":
return file.NewDummyFileService(), nil
default:
return nil, fmt.Errorf("unsupported storage type: %s", storageType)
}
}
// initRetrieveEngineRegistry initializes the retrieval engine registry
// Sets up and configures various search engine backends based on configuration
// Supports multiple retrieval engines (PostgreSQL, ElasticsearchV7, ElasticsearchV8)
// Parameters:
// - db: Database connection
// - cfg: Application configuration
//
// Returns:
// - Configured retrieval engine registry
// - Error if initialization fails
func initRetrieveEngineRegistry(
db *gorm.DB, cfg *config.Config, auditSvc interfaces.AuditLogService,
) (interfaces.RetrieveEngineRegistry, error) {
registry := retriever.NewRetrieveEngineRegistry()
retrieveDriver := strings.Split(os.Getenv("RETRIEVE_DRIVER"), ",")
log := logger.GetLogger(context.Background())
// Audit sink for OpenSearch driver events (index created / reindex). Driver
// events fire under a tenant-scoped ctx at indexing time; the env-path
// registration ctx below has no tenant, so those emits self-skip.
auditSink := newAuditSinkAdapter(auditSvc)
if slices.Contains(retrieveDriver, "postgres") {
postgresRepo := postgresRepo.NewPostgresRetrieveEngineRepository(db)
if err := registry.Register(
retriever.NewKVHybridRetrieveEngine(postgresRepo, types.PostgresRetrieverEngineType),
); err != nil {
log.Errorf("Register postgres retrieve engine failed: %v", err)
} else {
log.Infof("Register postgres retrieve engine success")
}
}
if slices.Contains(retrieveDriver, "sqlite") {
sqliteRepo := sqliteRetrieverRepo.NewSQLiteRetrieveEngineRepository(db)
if err := registry.Register(
retriever.NewKVHybridRetrieveEngine(sqliteRepo, types.SQLiteRetrieverEngineType),
); err != nil {
log.Errorf("Register sqlite retrieve engine failed: %v", err)
} else {
log.Infof("Register sqlite retrieve engine success")
}
}
if slices.Contains(retrieveDriver, "elasticsearch_v8") {
client, err := elasticsearch.NewTypedClient(elasticsearch.Config{
Addresses: []string{os.Getenv("ELASTICSEARCH_ADDR")},
Username: os.Getenv("ELASTICSEARCH_USERNAME"),
Password: os.Getenv("ELASTICSEARCH_PASSWORD"),
})
if err != nil {
log.Errorf("Create elasticsearch_v8 client failed: %v", err)
} else {
elasticsearchRepo := elasticsearchRepoV8.NewElasticsearchEngineRepository(client, cfg, nil)
if err := registry.Register(
retriever.NewKVHybridRetrieveEngine(
elasticsearchRepo, types.ElasticsearchRetrieverEngineType,
),
); err != nil {
log.Errorf("Register elasticsearch_v8 retrieve engine failed: %v", err)
} else {
log.Infof("Register elasticsearch_v8 retrieve engine success")
}
}
}
if slices.Contains(retrieveDriver, "elasticsearch_v7") {
client, err := esv7.NewClient(esv7.Config{
Addresses: []string{os.Getenv("ELASTICSEARCH_ADDR")},
Username: os.Getenv("ELASTICSEARCH_USERNAME"),
Password: os.Getenv("ELASTICSEARCH_PASSWORD"),
})
if err != nil {
log.Errorf("Create elasticsearch_v7 client failed: %v", err)
} else {
elasticsearchRepo := elasticsearchRepoV7.NewElasticsearchEngineRepository(client, cfg, nil)
if err := registry.Register(
retriever.NewKVHybridRetrieveEngine(
elasticsearchRepo, types.ElasticsearchRetrieverEngineType,
),
); err != nil {
log.Errorf("Register elasticsearch_v7 retrieve engine failed: %v", err)
} else {
log.Infof("Register elasticsearch_v7 retrieve engine success")
}
}
}
if slices.Contains(retrieveDriver, "opensearch") {
cc := &types.ConnectionConfig{
Addr: os.Getenv("OPENSEARCH_ADDR"),
Username: os.Getenv("OPENSEARCH_USERNAME"),
Password: os.Getenv("OPENSEARCH_PASSWORD"),
InsecureSkipVerify: strings.EqualFold(os.Getenv("OPENSEARCH_INSECURE_SKIP_VERIFY"), "true"),
}
client, err := openSearchRepo.NewOpenSearchClient(cc)
if err != nil {
log.Errorf("Create opensearch client failed: %v", err)
} else if repo, err := openSearchRepo.NewRepository(
context.Background(), client, "", nil, openSearchRepo.WithAuditSink(auditSink),
); err != nil {
log.Errorf("Create opensearch repository failed: %v", err)
} else if err := registry.Register(
retriever.NewKVHybridRetrieveEngine(repo, types.OpenSearchRetrieverEngineType),
); err != nil {
log.Errorf("Register opensearch retrieve engine failed: %v", err)
} else {
log.Infof("Register opensearch retrieve engine success")
}
}
if slices.Contains(retrieveDriver, "qdrant") {
qdrantHost := os.Getenv("QDRANT_HOST")
if qdrantHost == "" {
qdrantHost = "localhost"
}
qdrantPort := 6334 // Default port
if portStr := os.Getenv("QDRANT_PORT"); portStr != "" {
if port, err := strconv.Atoi(portStr); err == nil {
qdrantPort = port
}
}
// API key for authentication (optional)
qdrantAPIKey := os.Getenv("QDRANT_API_KEY")
// TLS configuration (optional, defaults to false)
// Enable TLS unless explicitly set to "false" or "0" (case insensitive)
qdrantUseTLS := false
if useTLSStr := os.Getenv("QDRANT_USE_TLS"); useTLSStr != "" {
useTLSLower := strings.ToLower(strings.TrimSpace(useTLSStr))
qdrantUseTLS = useTLSLower != "false" && useTLSLower != "0"
}
log.Infof("Connecting to Qdrant at %s:%d (TLS: %v)", qdrantHost, qdrantPort, qdrantUseTLS)
client, err := qdrant.NewClient(&qdrant.Config{
Host: qdrantHost,
Port: qdrantPort,
APIKey: qdrantAPIKey,
UseTLS: qdrantUseTLS,
})
if err != nil {
log.Errorf("Create qdrant client failed: %v", err)
} else {
qdrantRepository := qdrantRepo.NewQdrantRetrieveEngineRepository(client, nil)
if err := registry.Register(
retriever.NewKVHybridRetrieveEngine(
qdrantRepository, types.QdrantRetrieverEngineType,
),
); err != nil {
log.Errorf("Register qdrant retrieve engine failed: %v", err)
} else {
log.Infof("Register qdrant retrieve engine success")
}
}
}
if slices.Contains(retrieveDriver, "weaviate") {
weaviateHost := os.Getenv("WEAVIATE_HOST")
if weaviateHost == "" {
// Docker compose default (service name inside network)
weaviateHost = "weaviate:8080"
}
weaviateGrpcAddress := os.Getenv("WEAVIATE_GRPC_ADDRESS")
if weaviateGrpcAddress == "" {
weaviateGrpcAddress = "weaviate:50051"
}
weaviateScheme := os.Getenv("WEAVIATE_SCHEME")
if weaviateScheme == "" {
weaviateScheme = "http"
}
var authConfig auth.Config
if strings.EqualFold(strings.TrimSpace(os.Getenv("WEAVIATE_AUTH_ENABLED")), "true") {
if apiKey := strings.TrimSpace(os.Getenv("WEAVIATE_API_KEY")); apiKey != "" {
authConfig = auth.ApiKey{Value: apiKey}
}
}
weaviateClient, err := weaviate.NewClient(weaviate.Config{
Host: weaviateHost,
GrpcConfig: &wgrpc.Config{
Host: weaviateGrpcAddress,
},
Scheme: weaviateScheme,
AuthConfig: authConfig,
})
if err != nil {
log.Errorf("Create weaviate client failed: %v", err)
} else {
weaviateRepository := weaviateRepo.NewWeaviateRetrieveEngineRepository(weaviateClient, nil)
if err := registry.Register(
retriever.NewKVHybridRetrieveEngine(
weaviateRepository, types.WeaviateRetrieverEngineType,
),
); err != nil {
log.Errorf("Register weaviate retrieve engine failed: %v", err)
} else {
log.Infof("Register weaviate retrieve engine success")
}
}
}
if slices.Contains(retrieveDriver, "milvus") {
milvusCfg := milvusclient.ClientConfig{
DialOptions: []grpc.DialOption{grpc.WithTimeout(5 * time.Second)},
}
milvusAddress := os.Getenv("MILVUS_ADDRESS")
if milvusAddress == "" {
milvusAddress = "localhost:19530"
}
milvusCfg.Address = milvusAddress
milvusUsername := os.Getenv("MILVUS_USERNAME")
if milvusUsername != "" {
milvusCfg.Username = milvusUsername
}
milvusPassword := os.Getenv("MILVUS_PASSWORD")
if milvusPassword != "" {
milvusCfg.Password = milvusPassword
}
milvusDBName := os.Getenv("MILVUS_DB_NAME")
if milvusDBName != "" {
milvusCfg.DBName = milvusDBName
}
milvusCli, err := milvusclient.New(context.Background(), &milvusCfg)
if err != nil {
log.Errorf("Create milvus client failed: %v", err)
} else {
milvusRepository := milvusRepo.NewMilvusRetrieveEngineRepository(milvusCli, nil)
if err := registry.Register(
retriever.NewKVHybridRetrieveEngine(
milvusRepository, types.MilvusRetrieverEngineType,
),
); err != nil {
log.Errorf("Register milvus retrieve engine failed: %v", err)
} else {
log.Infof("Register milvus retrieve engine success")
}
}
}
if slices.Contains(retrieveDriver, "doris") {
dorisAddr := os.Getenv("DORIS_ADDR")
if dorisAddr == "" {
// docker-compose 默认服务名 + Doris FE MySQL 端口
dorisAddr = "doris-fe:9030"
}
dorisDatabase := os.Getenv("DORIS_DATABASE")
if dorisDatabase == "" {
dorisDatabase = "weknora"
}
dorisUsername := os.Getenv("DORIS_USERNAME")
if dorisUsername == "" {
dorisUsername = "root"
}
dorisPassword := os.Getenv("DORIS_PASSWORD")
dorisHTTPPort := 8030
if portStr := os.Getenv("DORIS_HTTP_PORT"); portStr != "" {
if port, err := strconv.Atoi(portStr); err == nil {
dorisHTTPPort = port
}
}
dsn := fmt.Sprintf("%s:%s@tcp(%s)/%s?charset=utf8mb4&parseTime=true&loc=Local&interpolateParams=true",
dorisUsername, dorisPassword, dorisAddr, dorisDatabase)
dorisDB, err := sql.Open("mysql", dsn)
if err != nil {
log.Errorf("Create doris client failed: %v", err)
} else {
dorisDB.SetMaxOpenConns(20)
dorisDB.SetMaxIdleConns(5)
dorisDB.SetConnMaxLifetime(time.Hour)
httpBase := "http://" + hostFromAddr(dorisAddr) + ":" + strconv.Itoa(dorisHTTPPort)
dorisRepository := dorisRepo.NewDorisRetrieveEngineRepository(
dorisDB, httpBase, dorisUsername, dorisPassword, dorisDatabase, nil,
)
if err := registry.Register(
retriever.NewKVHybridRetrieveEngine(
dorisRepository, types.DorisRetrieverEngineType,
),
); err != nil {
log.Errorf("Register doris retrieve engine failed: %v", err)
} else {
log.Infof("Register doris retrieve engine success: %s db=%s", dorisAddr, dorisDatabase)
}
}
}
if slices.Contains(retrieveDriver, "tencent_vectordb") {
addr := os.Getenv("TENCENT_VECTORDB_ADDR")
username := os.Getenv("TENCENT_VECTORDB_USERNAME")
apiKey := os.Getenv("TENCENT_VECTORDB_API_KEY")
if addr == "" || username == "" || apiKey == "" {
log.Errorf("Missing Tencent VectorDB configuration")
} else {
client, err := tcvectordb.NewRpcClient(addr, username, apiKey, &tcvectordb.ClientOption{
ReadConsistency: tcvectordb.EventualConsistency,
Timeout: 10 * time.Second,
})
if err != nil {
log.Errorf("Create tencent_vectordb client failed: %v", err)
} else {
tencentRepository := tencentVectorDBRepo.NewTencentVectorDBRetrieveEngineRepository(
client,
os.Getenv("TENCENT_VECTORDB_DATABASE"),
nil,
)
if err := registry.Register(
retriever.NewKVHybridRetrieveEngine(
tencentRepository, types.TencentVectorDBRetrieverEngineType,
),
); err != nil {
log.Errorf("Register tencent_vectordb retrieve engine failed: %v", err)
} else {
log.Infof("Register tencent_vectordb retrieve engine success")
}
}
}
}
// ─── DB store registration (byStoreID) ───
if storeReg, ok := registry.(*retriever.RetrieveEngineRegistry); ok {
loadDBStoresIntoRegistry(storeReg, db, cfg, auditSink)
}
return registry, nil
}
// loadDBStoresIntoRegistry loads VectorStore records from DB and registers them
// in the registry's byStoreID map. Failures are logged and skipped (non-fatal).
func loadDBStoresIntoRegistry(
storeRegistry interfaces.StoreRegistry, db *gorm.DB, cfg *config.Config, auditSink openSearchRepo.AuditSink,
) {
ctx := context.Background()
log := logger.GetLogger(ctx)
var stores []types.VectorStore
// GORM soft delete automatically adds "deleted_at IS NULL" condition
if err := db.Find(&stores).Error; err != nil {
log.Warnf("Failed to load vector stores from DB: %v", err)
return
}
if len(stores) == 0 {
return
}
log.Infof("Loading %d vector store(s) from database", len(stores))
for _, store := range stores {
svc, err := createEngineServiceFromStore(ctx, store, db, cfg, auditSink)
if err != nil {
log.Errorf("Failed to create engine for store %s (%s): %v", store.ID, store.Name, err)
continue
}
storeRegistry.RegisterWithStoreID(store.ID, svc)
log.Infof("Registered DB vector store: id=%s, name=%s, engine=%s", store.ID, store.Name, store.EngineType)
}
}
// initAntsPool initializes the goroutine pool
// Creates a managed goroutine pool for concurrent task execution
// Parameters:
// - cfg: Application configuration
//
// Returns:
// - Configured goroutine pool
// - Error if initialization fails
func initAntsPool(cfg *config.Config) (*ants.Pool, error) {
// Default to 5 if not specified in config
poolSize := os.Getenv("CONCURRENCY_POOL_SIZE")
if poolSize == "" {
poolSize = "5"
}
poolSizeInt, err := strconv.Atoi(poolSize)
if err != nil {
return nil, err
}
// Set up the pool with pre-allocation for better performance
return ants.NewPool(poolSizeInt, ants.WithPreAlloc(true))
}
// registerPoolCleanup registers the goroutine pool for cleanup
// Ensures proper cleanup of the goroutine pool when application shuts down
// Parameters:
// - pool: Goroutine pool
// - cleaner: Resource cleaner
func registerPoolCleanup(pool *ants.Pool, cleaner interfaces.ResourceCleaner) {
cleaner.RegisterWithName("AntsPool", func() error {
pool.Release()
return nil
})
}
// registerTracerCleanup registers the tracer for cleanup
// Ensures proper cleanup of the tracer when application shuts down
// Parameters:
// - tracer: Tracer instance
// - cleaner: Resource cleaner
func registerTracerCleanup(tracer *tracing.Tracer, cleaner interfaces.ResourceCleaner) {
// Register the cleanup function - actual context will be provided during cleanup
cleaner.RegisterWithName("Tracer", func() error {
// Create context for cleanup with longer timeout for tracer shutdown
return tracer.Cleanup(context.Background())
})
}
// registerLangfuseCleanup ensures buffered Langfuse events are flushed on
// shutdown. A 5-second timeout matches other external-service cleanups and
// balances data durability against a slow remote endpoint holding up exit.
func registerLangfuseCleanup(mgr *langfuse.Manager, cleaner interfaces.ResourceCleaner) {
if mgr == nil {
return
}
cleaner.RegisterWithName("Langfuse", func() error {
ctx, cancel := context.WithTimeout(context.Background(), 5*time.Second)
defer cancel()
return mgr.Shutdown(ctx)
})
}
// initDocReaderClient initializes the DocumentReader client (lightweight API).
func initDocReaderClient(cfg *config.Config) (interfaces.DocumentReader, error) {
addr := strings.TrimSpace(os.Getenv("DOCREADER_ADDR"))
transport := strings.TrimSpace(os.Getenv("DOCREADER_TRANSPORT"))
if transport == "" {
transport = "grpc"
}
if addr == "" {
logger.Infof(context.Background(), "[DocConverter] No DOCREADER_ADDR configured, starting disconnected")
}
transport = strings.ToLower(transport)
switch transport {
case "http", "https":
if addr != "" && !strings.HasPrefix(addr, "http://") && !strings.HasPrefix(addr, "https://") {
addr = "http://" + addr
}
return docparser.NewHTTPDocumentReader(addr)
default:
return docparser.NewGRPCDocumentReader(addr)
}
}
// initOllamaService initializes the Ollama service client
// Creates a client for interacting with Ollama API for model inference
// Parameters:
// - None
//
// Returns:
// - Configured Ollama service client
// - Error if initialization fails
func initOllamaService() (*ollama.OllamaService, error) {
// Get Ollama service from existing factory function
return ollama.GetOllamaService()
}
func initNeo4jClient() (neo4j.Driver, error) {
ctx := context.Background()
if strings.ToLower(os.Getenv("NEO4J_ENABLE")) != "true" {
logger.Debugf(ctx, "NOT SUPPORT RETRIEVE GRAPH")
return nil, nil
}
uri := os.Getenv("NEO4J_URI")
username := os.Getenv("NEO4J_USERNAME")
password := os.Getenv("NEO4J_PASSWORD")
// Retry configuration
maxRetries := 30 // Max retry attempts
retryInterval := 2 * time.Second // Wait between retries
var driver neo4j.Driver
var err error
for attempt := 1; attempt <= maxRetries; attempt++ {
driver, err = neo4j.NewDriver(uri, neo4j.BasicAuth(username, password, ""))
if err != nil {
logger.Warnf(ctx, "Failed to create Neo4j driver (attempt %d/%d): %v", attempt, maxRetries, err)
time.Sleep(retryInterval)
continue
}
err = driver.VerifyAuthentication(ctx, nil)
if err == nil {
if attempt > 1 {
logger.Infof(ctx, "Successfully connected to Neo4j after %d attempts", attempt)
}
return driver, nil
}
logger.Warnf(ctx, "Failed to verify Neo4j authentication (attempt %d/%d): %v", attempt, maxRetries, err)
driver.Close(ctx)
time.Sleep(retryInterval)
}
return nil, fmt.Errorf("failed to connect to Neo4j after %d attempts: %w", maxRetries, err)
}
func NewDuckDB() (*sql.DB, error) {
sqlDB, err := sql.Open("duckdb", ":memory:")
if err != nil {
return nil, fmt.Errorf("failed to open duckdb: %w", err)
}
// Try to install and load required extensions.
// - spatial: used for st_read_meta() to enumerate layer (sheet) names from .xlsx/.xls
// - excel: used for read_xlsx() which gives proper type inference per sheet
bgCtx := context.Background()
for _, ext := range []string{"spatial", "excel"} {
if _, err := sqlDB.ExecContext(bgCtx, fmt.Sprintf("INSTALL %s;", ext)); err != nil {
logger.Warnf(bgCtx, "[DuckDB] Failed to install %s extension: %v", ext, err)
}
if _, err := sqlDB.ExecContext(bgCtx, fmt.Sprintf("LOAD %s;", ext)); err != nil {
logger.Warnf(bgCtx, "[DuckDB] Failed to load %s extension: %v", ext, err)
}
}
return sqlDB, nil
}
// registerWebSearchProviders registers all web search provider types to the registry.
// Each provider type is registered with its factory function that accepts parameters.
// Provider instances are created on-demand when tenants configure them.
func registerWebSearchProviders(registry *infra_web_search.Registry) {
registry.Register("duckduckgo", infra_web_search.NewDuckDuckGoProvider)
registry.Register("google", infra_web_search.NewGoogleProvider)
registry.Register("bing", infra_web_search.NewBingProvider)
registry.Register("tavily", infra_web_search.NewTavilyProvider)
registry.Register("ollama", infra_web_search.NewOllamaProvider)
registry.Register("baidu", infra_web_search.NewBaiduProvider)
registry.Register("searxng", infra_web_search.NewSearxngProvider)
}
// registerIMAdapterFactories registers adapter factories for each IM platform
// and loads enabled channels from the database. Each platform's factory lives
// in its own subpackage to keep this file focused on wiring.
func registerIMAdapterFactories(imService *imPkg.Service) {
imService.RegisterAdapterFactory("wecom", wecom.NewFactory())
imService.RegisterAdapterFactory("feishu", feishu.NewFactory())
imService.RegisterAdapterFactory("slack", slack.NewFactory())
imService.RegisterAdapterFactory("telegram", telegram.NewFactory())
imService.RegisterAdapterFactory("dingtalk", dingtalk.NewFactory())
imService.RegisterAdapterFactory("mattermost", mattermost.NewFactory())
imService.RegisterAdapterFactory("wechat", wechat.NewFactory())
// Load and start all enabled channels from database
if err := imService.LoadAndStartChannels(); err != nil {
logger.Warnf(context.Background(), "[IM] Failed to load channels from database: %v", err)
}
}
// initConnectorRegistry creates and populates the connector registry with all available connectors.
// Aggregates registration errors via errors.Join so a misconfigured or duplicated connector fails
// container initialization loudly instead of silently disabling the feature at runtime.
func initConnectorRegistry() (*datasource.ConnectorRegistry, error) {
registry := datasource.NewConnectorRegistry()
var errs error
if err := registry.Register(feishuConnector.NewConnector()); err != nil {
errs = errors.Join(errs, fmt.Errorf("register feishu connector: %w", err))
}
if err := registry.Register(notionConnector.NewConnector()); err != nil {
errs = errors.Join(errs, fmt.Errorf("register notion connector: %w", err))
}
if err := registry.Register(yuqueConnector.NewConnector()); err != nil {
errs = errors.Join(errs, fmt.Errorf("register yuque connector: %w", err))
}
// Future connectors will be registered here:
// if err := registry.Register(confluenceConnector.NewConnector()); err != nil { ... }
// if err := registry.Register(githubConnector.NewConnector()); err != nil { ... }
if errs != nil {
return nil, errs
}
return registry, nil
}
// startDataSourceScheduler starts the data source cron scheduler and registers cleanup.
func startDataSourceScheduler(scheduler *datasource.Scheduler, cleaner interfaces.ResourceCleaner) {
if err := scheduler.Start(context.Background()); err != nil {
logger.Warnf(context.Background(), "[Container] data source scheduler start failed: %v", err)
}
cleaner.RegisterWithName("DataSourceScheduler", func() error {
scheduler.Stop()
return nil
})
}
// startHousekeepingService starts the knowledge housekeeping cron and registers
// cleanup. This is the safety net that recovers any knowledge stuck in
// "processing" past a configurable threshold (see HousekeepingService for
// rationale). Best-effort: a startup error is logged but does NOT abort the
// container — the rest of the system stays usable.
func startHousekeepingService(svc *service.HousekeepingService, cleaner interfaces.ResourceCleaner) {
if svc == nil {
return
}
if err := svc.Start(context.Background()); err != nil {
logger.Warnf(context.Background(), "[Container] housekeeping start failed: %v", err)
}
cleaner.RegisterWithName("KnowledgeHousekeeping", func() error {
svc.Stop()
return nil
})
}
// startAuditLogRetention spins up the daily audit_logs purge sweep
// and registers shutdown cleanup. Mirrors the data-source-scheduler
// pattern: container init kicks the goroutine, ResourceCleaner stops
// it during graceful shutdown so a SIGTERM during a sweep doesn't
// orphan the goroutine.
//
// retention_days <= 0 is the configured way to disable retention;
// the runner short-circuits Start() on that path so we don't need
// to gate the wiring here.
func startAuditLogRetention(
runner *service.AuditLogRetentionRunner, cleaner interfaces.ResourceCleaner,
) {
runner.Start(context.Background())
cleaner.RegisterWithName("AuditLogRetentionRunner", func() error {
runner.Stop()
return nil
})
}