diff --git a/frontend/src/components/ModelEditorDialog.vue b/frontend/src/components/ModelEditorDialog.vue
index e2a34fc6..6dc3d228 100644
--- a/frontend/src/components/ModelEditorDialog.vue
+++ b/frontend/src/components/ModelEditorDialog.vue
@@ -435,6 +435,16 @@ const fallbackProviderOptions = computed(() => [
     description: t('model.editor.providers.openrouter.description'),
     modelTypes: ['chat', 'embedding']
   },
+  {
+    value: 'gemini',
+    label: t('model.editor.providers.gemini.label'),
+    defaultUrls: {
+      chat: 'https://generativelanguage.googleapis.com/v1beta/openai',
+      embedding: 'https://generativelanguage.googleapis.com/v1beta'
+    },
+    description: t('model.editor.providers.gemini.description'),
+    modelTypes: ['chat', 'embedding']
+  },
   {
     value: 'siliconflow',
     label: t('model.editor.providers.siliconflow.label'),
diff --git a/internal/models/embedding/embedder.go b/internal/models/embedding/embedder.go
index c6d5b3a2..8279c6fe 100644
--- a/internal/models/embedding/embedder.go
+++ b/internal/models/embedding/embedder.go
@@ -218,6 +218,19 @@ func newEmbedder(config Config, pooler EmbedderPooler, ollamaService *ollama.Oll
 			}
 			embedder, err = nvEmb, nErr
 			return embedder, err
+		case provider.ProviderGemini:
+			geminiEmb, gErr := NewGeminiEmbedder(config.APIKey,
+				config.BaseURL,
+				config.ModelName,
+				config.TruncatePromptTokens,
+				config.Dimensions,
+				config.ModelID,
+				pooler)
+			if geminiEmb != nil {
+				geminiEmb.SetCustomHeaders(config.CustomHeaders)
+			}
+			embedder, err = geminiEmb, gErr
+			return embedder, err
 		case provider.ProviderZhipu:
 			zhipuEmb, zErr := NewZhipuEmbedder(config.APIKey,
 				config.BaseURL,
diff --git a/internal/models/embedding/gemini.go b/internal/models/embedding/gemini.go
new file mode 100644
index 00000000..718f6d2b
--- /dev/null
+++ b/internal/models/embedding/gemini.go
@@ -0,0 +1,242 @@
+package embedding
+
+import (
+	"bytes"
+	"context"
+	"encoding/json"
+	"fmt"
+	"io"
+	"net/http"
+	"strings"
+	"time"
+
+	"github.com/Tencent/WeKnora/internal/logger"
+	secutils "github.com/Tencent/WeKnora/internal/utils"
+)
+
+// geminiEmbeddingBaseURL is the base URL used when the caller supplies none.
+const geminiEmbeddingBaseURL = "https://generativelanguage.googleapis.com/v1beta"
+
+// GeminiEmbedder implements text vectorization using the native Gemini
+// embedContent / batchEmbedContents REST API.
+type GeminiEmbedder struct {
+	apiKey               string
+	baseURL              string
+	modelName            string
+	truncatePromptTokens int
+	dimensions           int
+	modelID              string
+	httpClient           *http.Client
+	timeout              time.Duration
+	maxRetries           int
+	customHeaders        map[string]string
+	EmbedderPooler
+}
+
+type geminiBatchEmbedRequest struct {
+	Requests []geminiEmbedRequest `json:"requests"`
+}
+
+type geminiEmbedRequest struct {
+	Model                string        `json:"model"`
+	Content              geminiContent `json:"content"`
+	TaskType             string        `json:"taskType,omitempty"`
+	OutputDimensionality int           `json:"output_dimensionality,omitempty"`
+}
+
+type geminiContent struct {
+	Parts []geminiPart `json:"parts"`
+}
+
+type geminiPart struct {
+	Text string `json:"text"`
+}
+
+type geminiBatchEmbedResponse struct {
+	Embeddings []geminiEmbedding `json:"embeddings"`
+}
+
+type geminiEmbedding struct {
+	Values []float32 `json:"values"`
+}
+
+// NewGeminiEmbedder creates an embedder that calls the native Gemini API.
+// It returns an error when modelName is empty. An empty baseURL falls back to
+// geminiEmbeddingBaseURL; trailing slashes, a trailing "/openai" segment and a
+// "models/" prefix on modelName are stripped before the values are stored.
+func NewGeminiEmbedder(apiKey, baseURL, modelName string,
+	truncatePromptTokens int, dimensions int, modelID string, pooler EmbedderPooler,
+) (*GeminiEmbedder, error) {
+	if modelName == "" {
+		return nil, fmt.Errorf("model name is required")
+	}
+	if truncatePromptTokens == 0 {
+		truncatePromptTokens = 511
+	}
+	if baseURL == "" {
+		baseURL = geminiEmbeddingBaseURL
+	}
+
+	baseURL = strings.TrimRight(baseURL, "/")
+	baseURL = strings.TrimSuffix(baseURL, "/openai")
+
+	timeout := 60 * time.Second
+	return &GeminiEmbedder{
+		apiKey:               apiKey,
+		baseURL:              baseURL,
+		modelName:            strings.TrimPrefix(modelName, "models/"),
+		truncatePromptTokens: truncatePromptTokens,
+		dimensions:           dimensions,
+		modelID:              modelID,
+		httpClient:           &http.Client{Timeout: timeout},
+		timeout:              timeout,
+		maxRetries:           3,
+		EmbedderPooler:       pooler,
+	}, nil
+}
+
+// SetCustomHeaders stores extra headers that are passed to secutils.ApplyCustomHeaders on each request.
+func (e *GeminiEmbedder) SetCustomHeaders(headers map[string]string) {
+	e.customHeaders = headers
+}
+
+// Embed returns the embedding of a single text by delegating to BatchEmbed.
+func (e *GeminiEmbedder) Embed(ctx context.Context, text string) ([]float32, error) {
+	embeddings, err := e.BatchEmbed(ctx, []string{text})
+	if err != nil {
+		return nil, err
+	}
+	if len(embeddings) == 0 {
+		return nil, fmt.Errorf("no embedding returned")
+	}
+	return embeddings[0], nil
+}
+
+// BatchEmbed sends all texts in one batchEmbedContents request and returns the
+// vectors in the order the API returned them. It returns (nil, nil) for empty
+// input and an error when the API answers with a non-200 status or with a
+// different number of embeddings than inputs.
+func (e *GeminiEmbedder) BatchEmbed(ctx context.Context, texts []string) ([][]float32, error) {
+	if len(texts) == 0 {
+		return nil, nil
+	}
+
+	requests := make([]geminiEmbedRequest, 0, len(texts))
+	for _, text := range texts {
+		requests = append(requests, geminiEmbedRequest{
+			Model: "models/" + e.modelName,
+			Content: geminiContent{Parts: []geminiPart{
+				{Text: text},
+			}},
+			OutputDimensionality: e.dimensions,
+		})
+	}
+
+	jsonData, err := json.Marshal(geminiBatchEmbedRequest{Requests: requests})
+	if err != nil {
+		logger.GetLogger(ctx).Errorf("GeminiEmbedder BatchEmbed marshal request error: %v", err)
+		return nil, fmt.Errorf("marshal request: %w", err)
+	}
+
+	logger.GetLogger(ctx).Debugf("GeminiEmbedder BatchEmbed: model=%s, input_count=%d",
+		e.modelName, len(texts))
+
+	resp, err := e.doRequestWithRetry(ctx, jsonData)
+	if err != nil {
+		logger.GetLogger(ctx).Errorf("GeminiEmbedder BatchEmbed send request error: %v", err)
+		return nil, fmt.Errorf("send request: %w", err)
+	}
+	if resp.Body != nil {
+		defer resp.Body.Close()
+	}
+
+	body, err := io.ReadAll(resp.Body)
+	if err != nil {
+		logger.GetLogger(ctx).Errorf("GeminiEmbedder BatchEmbed read response error: %v", err)
+		return nil, fmt.Errorf("read response: %w", err)
+	}
+
+	if resp.StatusCode != http.StatusOK {
+		bodyStr := string(body)
+		if len(bodyStr) > 1000 {
+			bodyStr = bodyStr[:1000] + "... (truncated)"
+		}
+		logger.GetLogger(ctx).Errorf("GeminiEmbedder BatchEmbed API error: Http Status %s, Response Body: %s", resp.Status, bodyStr)
+		return nil, fmt.Errorf("Gemini BatchEmbed API error: Http Status %s, Response: %s", resp.Status, bodyStr)
+	}
+
+	var response geminiBatchEmbedResponse
+	if err := json.Unmarshal(body, &response); err != nil {
+		logger.GetLogger(ctx).Errorf("GeminiEmbedder BatchEmbed unmarshal response error: %v", err)
+		return nil, fmt.Errorf("unmarshal response: %w", err)
+	}
+	if len(response.Embeddings) != len(texts) {
+		return nil, fmt.Errorf("Gemini BatchEmbed returned %d embeddings for %d inputs", len(response.Embeddings), len(texts))
+	}
+
+	embeddings := make([][]float32, 0, len(response.Embeddings))
+	for _, embedding := range response.Embeddings {
+		embeddings = append(embeddings, embedding.Values)
+	}
+	return embeddings, nil
+}
+
+// doRequestWithRetry POSTs jsonData to the batchEmbedContents endpoint. When
+// httpClient.Do returns an error it retries up to maxRetries times, doubling
+// the wait from 1s up to a 10s cap; any HTTP response is returned unexamined.
+func (e *GeminiEmbedder) doRequestWithRetry(ctx context.Context, jsonData []byte) (*http.Response, error) {
+	var resp *http.Response
+	var err error
+	url := fmt.Sprintf("%s/models/%s:batchEmbedContents", e.baseURL, e.modelName)
+
+	for i := 0; i <= e.maxRetries; i++ {
+		if i > 0 {
+			backoffTime := time.Duration(1<<uint(i-1)) * time.Second
+			if backoffTime > 10*time.Second {
+				backoffTime = 10 * time.Second
+			}
+			logger.GetLogger(ctx).
+				Infof("GeminiEmbedder retrying request (%d/%d), waiting %v", i, e.maxRetries, backoffTime)
+
+			select {
+			case <-time.After(backoffTime):
+			case <-ctx.Done():
+				return nil, ctx.Err()
+			}
+		}
+
+		var req *http.Request
+		req, err = http.NewRequestWithContext(ctx, "POST", url, bytes.NewReader(jsonData))
+		if err != nil {
+			// A request that cannot be built will fail identically on every attempt; do not retry.
+			return nil, fmt.Errorf("create request: %w", err)
+		}
+		req.Header.Set("Content-Type", "application/json")
+		req.Header.Set("x-goog-api-key", e.apiKey)
+		secutils.ApplyCustomHeaders(req, e.customHeaders)
+
+		resp, err = e.httpClient.Do(req)
+		if err == nil {
+			return resp, nil
+		}
+
+		logger.GetLogger(ctx).Errorf("GeminiEmbedder request failed (attempt %d/%d): %v", i+1, e.maxRetries+1, err)
+	}
+
+	return nil, err
+}
+
+// GetModelName returns the model name without the "models/" prefix.
+func (e *GeminiEmbedder) GetModelName() string {
+	return e.modelName
+}
+
+// GetDimensions returns the dimensions value passed to NewGeminiEmbedder.
+func (e *GeminiEmbedder) GetDimensions() int {
+	return e.dimensions
+}
+
+// GetModelID returns the model ID passed to NewGeminiEmbedder.
+func (e *GeminiEmbedder) GetModelID() string {
+	return e.modelID
+}
diff --git a/internal/models/embedding/gemini_test.go b/internal/models/embedding/gemini_test.go
new file mode 100644
index 00000000..6c96cbcf
--- /dev/null
+++ b/internal/models/embedding/gemini_test.go
@@ -0,0 +1,86 @@
+package embedding
+
+import (
+	"context"
+	"encoding/json"
+	"net/http"
+	"net/http/httptest"
+	"strings"
+	"testing"
+)
+
+func TestGeminiEmbedderBatchEmbedUsesNativeAPI(t *testing.T) {
+	var gotPath string
+	var gotAPIKey string
+	var gotReq geminiBatchEmbedRequest
+
+	server := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
+		gotPath = r.URL.Path
+		gotAPIKey = r.Header.Get("x-goog-api-key")
+		if err := json.NewDecoder(r.Body).Decode(&gotReq); err != nil {
+			// The handler runs outside the test goroutine, so FailNow-style calls are not allowed here.
+			t.Errorf("decode request: %v", err)
+			http.Error(w, "bad request", http.StatusBadRequest)
+			return
+		}
+
+		w.Header().Set("Content-Type", "application/json")
+		_, _ = w.Write([]byte(`{
+			"embeddings": [
+				{"values": [0.1, 0.2]},
+				{"values": [0.3, 0.4]}
+			]
+		}`))
+	}))
+	defer server.Close()
+
+	embedder, err := NewGeminiEmbedder("test-key", server.URL+"/openai", "gemini-embedding-2",
+		0, 768, "model-id", nil)
+	if err != nil {
+		t.Fatalf("NewGeminiEmbedder: %v", err)
+	}
+
+	embeddings, err := embedder.BatchEmbed(context.Background(), []string{"hello", "world"})
+	if err != nil {
+		t.Fatalf("BatchEmbed: %v", err)
+	}
+
+	if gotPath != "/models/gemini-embedding-2:batchEmbedContents" {
+		t.Fatalf("path = %q, want native batchEmbedContents path", gotPath)
+	}
+	if gotAPIKey != "test-key" {
+		t.Fatalf("x-goog-api-key = %q", gotAPIKey)
+	}
+	if len(gotReq.Requests) != 2 {
+		t.Fatalf("requests len = %d", len(gotReq.Requests))
+	}
+	if gotReq.Requests[0].Model != "models/gemini-embedding-2" {
+		t.Fatalf("request model = %q", gotReq.Requests[0].Model)
+	}
+	if gotReq.Requests[0].OutputDimensionality != 768 {
+		t.Fatalf("output_dimensionality = %d", gotReq.Requests[0].OutputDimensionality)
+	}
+	if gotReq.Requests[0].Content.Parts[0].Text != "hello" {
+		t.Fatalf("first text = %q", gotReq.Requests[0].Content.Parts[0].Text)
+	}
+	if len(embeddings) != 2 || len(embeddings[0]) != 2 || embeddings[1][1] != 0.4 {
+		t.Fatalf("unexpected embeddings: %#v", embeddings)
+	}
+}
+
+func TestGeminiEmbedderReturnsAPIErrorBody(t *testing.T) {
+	server := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
+		http.Error(w, `{"error":"not found"}`, http.StatusNotFound)
+	}))
+	defer server.Close()
+
+	embedder, err := NewGeminiEmbedder("test-key", server.URL, "gemini-embedding-2",
+		0, 0, "model-id", nil)
+	if err != nil {
+		t.Fatalf("NewGeminiEmbedder: %v", err)
+	}
+
+	_, err = embedder.BatchEmbed(context.Background(), []string{"hello"})
+	if err == nil || !strings.Contains(err.Error(), "404") {
+		t.Fatalf("expected 404 error, got %v", err)
+	}
+}
diff --git a/internal/models/provider/gemini.go b/internal/models/provider/gemini.go
index b4d4a936..e1dde9ab 100644
--- a/internal/models/provider/gemini.go
+++ b/internal/models/provider/gemini.go
@@ -25,12 +25,14 @@ func (p *GeminiProvider) Info() ProviderInfo {
 	return ProviderInfo{
 		Name:        ProviderGemini,
 		DisplayName: "Google Gemini",
-		Description: "gemini-3-flash-preview, gemini-2.5-pro, etc.",
+		Description: "gemini-3-flash-preview, gemini-2.5-pro, gemini-embedding-2, etc.",
 		DefaultURLs: map[types.ModelType]string{
 			types.ModelTypeKnowledgeQA: GeminiOpenAICompatBaseURL,
+			types.ModelTypeEmbedding:   GeminiBaseURL,
 		},
 		ModelTypes: []types.ModelType{
 			types.ModelTypeKnowledgeQA,
+			types.ModelTypeEmbedding,
 		},
 		RequiresAuth: true,
 	}
diff --git a/internal/models/provider/provider_test.go b/internal/models/provider/provider_test.go
index 5b35b1cc..918a94b5 100644
--- a/internal/models/provider/provider_test.go
+++ b/internal/models/provider/provider_test.go
@@ -256,4 +256,20 @@ func TestListByModelType(t *testing.T) {
 
 		assert.True(t, found, "OpenRouter should support embedding")
 	})
+
+	t.Run("embedding models include gemini", func(t *testing.T) {
+		providers := ListByModelType(types.ModelTypeEmbedding)
+		assert.NotEmpty(t, providers)
+
+		found := false
+		for _, p := range providers {
+			if p.Name == ProviderGemini {
+				found = true
+				assert.Equal(t, GeminiBaseURL, p.GetDefaultURL(types.ModelTypeEmbedding))
+				break
+			}
+		}
+
+		assert.True(t, found, "Gemini should support embedding via the native Gemini API")
+	})
 }
diff --git a/internal/utils/extraheaders.go b/internal/utils/extraheaders.go
index 0a7aab28..fd504466 100644
--- a/internal/utils/extraheaders.go
+++ b/internal/utils/extraheaders.go
@@ -11,6 +11,7 @@ var reservedHeaderKeys = map[string]struct{}{
 	"authorization":   {},
 	"api-key":         {},
 	"x-api-key":       {},
+	"x-goog-api-key":  {},
 	"content-type":    {},
 	"content-length":  {},
 	"accept-encoding": {},
diff --git a/internal/utils/extraheaders_test.go b/internal/utils/extraheaders_test.go
index 4b2d6fb7..e24f972a 100644
--- a/internal/utils/extraheaders_test.go
+++ b/internal/utils/extraheaders_test.go
@@ -12,13 +12,15 @@ func TestApplyCustomHeaders_SkipReserved(t *testing.T) {
 	req, _ := http.NewRequest("GET", "https://example.com", nil)
 	req.Header.Set("Authorization", "Bearer original")
 	req.Header.Set("Content-Type", "application/json")
+	req.Header.Set("x-goog-api-key", "google-original")
 
 	ApplyCustomHeaders(req, map[string]string{
-		"Authorization": "Bearer injected",
-		"Content-Type":  "text/plain",
-		"X-Trace-Id":    "trace-123",
-		"X-Route":       "edge",
-		"":              "empty-key-should-be-skipped",
+		"Authorization":  "Bearer injected",
+		"Content-Type":   "text/plain",
+		"x-goog-api-key": "google-injected",
+		"X-Trace-Id":     "trace-123",
+		"X-Route":        "edge",
+		"":               "empty-key-should-be-skipped",
 	})
 
 	if got := req.Header.Get("Authorization"); got != "Bearer original" {
@@ -27,6 +29,9 @@ func TestApplyCustomHeaders_SkipReserved(t *testing.T) {
 	if got := req.Header.Get("Content-Type"); got != "application/json" {
 		t.Fatalf("content-type overwritten: %q", got)
 	}
+	if got := req.Header.Get("x-goog-api-key"); got != "google-original" {
+		t.Fatalf("x-goog-api-key overwritten: %q", got)
+	}
 	if got := req.Header.Get("X-Trace-Id"); got != "trace-123" {
 		t.Fatalf("X-Trace-Id not injected: %q", got)
 	}