fix(embedding): support native Gemini embeddings

This commit is contained in:
Yeongpil Yoon
2026-06-02 17:32:29 +09:00
committed by lyingbug
parent 482686d17e
commit 8f8a276120
8 changed files with 364 additions and 6 deletions

View File

@@ -435,6 +435,16 @@ const fallbackProviderOptions = computed(() => [
description: t('model.editor.providers.openrouter.description'),
modelTypes: ['chat', 'embedding']
},
{
value: 'gemini',
label: t('model.editor.providers.gemini.label'),
defaultUrls: {
chat: 'https://generativelanguage.googleapis.com/v1beta/openai',
embedding: 'https://generativelanguage.googleapis.com/v1beta'
},
description: t('model.editor.providers.gemini.description'),
modelTypes: ['chat', 'embedding']
},
{
value: 'siliconflow',
label: t('model.editor.providers.siliconflow.label'),

View File

@@ -218,6 +218,19 @@ func newEmbedder(config Config, pooler EmbedderPooler, ollamaService *ollama.Oll
}
embedder, err = nvEmb, nErr
return embedder, err
case provider.ProviderGemini:
geminiEmb, gErr := NewGeminiEmbedder(config.APIKey,
config.BaseURL,
config.ModelName,
config.TruncatePromptTokens,
config.Dimensions,
config.ModelID,
pooler)
if geminiEmb != nil {
geminiEmb.SetCustomHeaders(config.CustomHeaders)
}
embedder, err = geminiEmb, gErr
return embedder, err
case provider.ProviderZhipu:
zhipuEmb, zErr := NewZhipuEmbedder(config.APIKey,
config.BaseURL,

View File

@@ -0,0 +1,227 @@
package embedding
import (
"bytes"
"context"
"encoding/json"
"fmt"
"io"
"net/http"
"strings"
"time"
"github.com/Tencent/WeKnora/internal/logger"
secutils "github.com/Tencent/WeKnora/internal/utils"
)
// geminiEmbeddingBaseURL is the native Gemini API root that NewGeminiEmbedder
// falls back to when it is given an empty base URL.
const geminiEmbeddingBaseURL = "https://generativelanguage.googleapis.com/v1beta"
// GeminiEmbedder implements text vectorization using the native Gemini
// embedContent / batchEmbedContents REST API.
//
// In this file every call is sent to the batchEmbedContents endpoint; Embed is
// implemented as a one-element batch.
type GeminiEmbedder struct {
// apiKey is sent on every request in the "x-goog-api-key" header.
apiKey string
// baseURL is the API root without a trailing slash; the request URL is
// built as baseURL + "/models/" + modelName + ":batchEmbedContents".
baseURL string
// modelName is stored without the "models/" resource prefix.
modelName string
// truncatePromptTokens is stored by the constructor (defaulting to 511 when
// 0 is passed) but is not read anywhere in this file.
truncatePromptTokens int
// dimensions is sent as the per-request output dimensionality; a zero value
// is omitted from the JSON payload.
dimensions int
// modelID is an opaque identifier returned unchanged by GetModelID.
modelID string
// httpClient performs the HTTP calls; the constructor sets its Timeout.
httpClient *http.Client
// timeout mirrors the value given to httpClient.Timeout by the constructor.
timeout time.Duration
// maxRetries is the number of additional attempts after the first one when
// the HTTP round trip itself fails.
maxRetries int
// customHeaders are passed to secutils.ApplyCustomHeaders for every request.
customHeaders map[string]string
// EmbedderPooler is embedded; its contract is declared outside this file.
EmbedderPooler
}
// geminiBatchEmbedRequest is the JSON body posted to the batchEmbedContents
// endpoint: one embed request per input text.
type geminiBatchEmbedRequest struct {
Requests []geminiEmbedRequest `json:"requests"`
}
// geminiEmbedRequest describes a single text to embed inside a batch.
type geminiEmbedRequest struct {
// Model is the fully qualified model resource; BatchEmbed fills it with
// "models/" + the embedder's model name.
Model string `json:"model"`
// Content carries the text to embed.
Content geminiContent `json:"content"`
// TaskType is optional and omitted when empty; BatchEmbed does not set it.
TaskType string `json:"taskType,omitempty"`
// OutputDimensionality is omitted when zero.
OutputDimensionality int `json:"output_dimensionality,omitempty"`
}
// geminiContent wraps the parts that make up one piece of content to embed.
type geminiContent struct {
Parts []geminiPart `json:"parts"`
}
// geminiPart is a single text part of a geminiContent.
type geminiPart struct {
Text string `json:"text"`
}
// geminiBatchEmbedResponse is the JSON body decoded from a successful
// batchEmbedContents call. BatchEmbed requires one entry per input text and
// maps entry i to input i.
type geminiBatchEmbedResponse struct {
Embeddings []geminiEmbedding `json:"embeddings"`
}
// geminiEmbedding holds the vector returned for one input.
type geminiEmbedding struct {
Values []float32 `json:"values"`
}
// NewGeminiEmbedder builds an embedder that targets the native Gemini
// batchEmbedContents endpoint.
//
// Parameters:
//   - apiKey: sent as the "x-goog-api-key" header on every request.
//   - baseURL: API root. Empty means geminiEmbeddingBaseURL. Trailing slashes
//     and a trailing "/openai" segment are stripped, so the OpenAI-compatible
//     chat URL can be reused for embeddings.
//   - modelName: model identifier, with or without the "models/" prefix.
//   - truncatePromptTokens: 0 selects the default of 511.
//   - dimensions: requested output dimensionality; 0 leaves it unset.
//   - modelID: opaque identifier returned by GetModelID.
//   - pooler: embedded EmbedderPooler implementation (may be nil).
//
// It returns an error when the model name is empty after normalization.
func NewGeminiEmbedder(apiKey, baseURL, modelName string,
	truncatePromptTokens int, dimensions int, modelID string, pooler EmbedderPooler,
) (*GeminiEmbedder, error) {
	// Normalize BEFORE validating: inputs such as "models/" or "  " would
	// otherwise pass the emptiness check and later produce a request URL of
	// the form ".../models/:batchEmbedContents".
	modelName = strings.TrimPrefix(strings.TrimSpace(modelName), "models/")
	if modelName == "" {
		return nil, fmt.Errorf("model name is required")
	}
	if truncatePromptTokens == 0 {
		truncatePromptTokens = 511
	}

	// TrimSuffix is a no-op when the suffix is absent, so no HasSuffix guard
	// is needed. The second TrimRight covers inputs like ".../v1beta//openai".
	baseURL = strings.TrimRight(strings.TrimSpace(baseURL), "/")
	baseURL = strings.TrimSuffix(baseURL, "/openai")
	baseURL = strings.TrimRight(baseURL, "/")
	if baseURL == "" {
		// Covers both "not configured" and values that normalize to nothing
		// (for example "/" or "/openai").
		baseURL = geminiEmbeddingBaseURL
	}

	timeout := 60 * time.Second
	return &GeminiEmbedder{
		apiKey:               apiKey,
		baseURL:              baseURL,
		modelName:            modelName,
		truncatePromptTokens: truncatePromptTokens,
		dimensions:           dimensions,
		modelID:              modelID,
		httpClient:           &http.Client{Timeout: timeout},
		timeout:              timeout,
		maxRetries:           3,
		EmbedderPooler:       pooler,
	}, nil
}
// SetCustomHeaders stores extra HTTP headers that are handed to
// secutils.ApplyCustomHeaders for every outgoing request. The map is kept by
// reference (no copy is made), so later mutations by the caller are visible to
// the embedder.
func (e *GeminiEmbedder) SetCustomHeaders(headers map[string]string) {
e.customHeaders = headers
}
// Embed vectorizes a single text by delegating to BatchEmbed with a
// one-element batch and returning the first vector. Errors from BatchEmbed are
// returned unchanged; an empty result is reported as an error.
func (e *GeminiEmbedder) Embed(ctx context.Context, text string) ([]float32, error) {
	vectors, err := e.BatchEmbed(ctx, []string{text})
	switch {
	case err != nil:
		return nil, err
	case len(vectors) == 0:
		return nil, fmt.Errorf("no embedding returned")
	default:
		return vectors[0], nil
	}
}
// BatchEmbed vectorizes texts with a single batchEmbedContents call and
// returns one vector per input, in input order.
//
// An empty input yields (nil, nil) without touching the network. Any non-200
// status is turned into an error that carries the status line and up to 1000
// bytes of the response body. A response whose embedding count differs from
// the input count is rejected.
func (e *GeminiEmbedder) BatchEmbed(ctx context.Context, texts []string) ([][]float32, error) {
	inputCount := len(texts)
	if inputCount == 0 {
		return nil, nil
	}

	// Build one embed request per input; every entry targets the same model.
	qualifiedModel := "models/" + e.modelName
	payload := geminiBatchEmbedRequest{Requests: make([]geminiEmbedRequest, inputCount)}
	for i := range texts {
		payload.Requests[i] = geminiEmbedRequest{
			Model:                qualifiedModel,
			Content:              geminiContent{Parts: []geminiPart{{Text: texts[i]}}},
			OutputDimensionality: e.dimensions,
		}
	}

	encoded, err := json.Marshal(payload)
	if err != nil {
		logger.GetLogger(ctx).Errorf("GeminiEmbedder BatchEmbed marshal request error: %v", err)
		return nil, fmt.Errorf("marshal request: %w", err)
	}

	logger.GetLogger(ctx).Debugf("GeminiEmbedder BatchEmbed: model=%s, input_count=%d",
		e.modelName, inputCount)

	httpResp, err := e.doRequestWithRetry(ctx, encoded)
	if err != nil {
		logger.GetLogger(ctx).Errorf("GeminiEmbedder BatchEmbed send request error: %v", err)
		return nil, fmt.Errorf("send request: %w", err)
	}
	if httpResp.Body != nil {
		defer httpResp.Body.Close()
	}

	raw, err := io.ReadAll(httpResp.Body)
	if err != nil {
		logger.GetLogger(ctx).Errorf("GeminiEmbedder BatchEmbed read response error: %v", err)
		return nil, fmt.Errorf("read response: %w", err)
	}

	if httpResp.StatusCode != http.StatusOK {
		// Cap the echoed body so a huge error page cannot flood logs or errors.
		const maxEchoedBody = 1000
		excerpt := string(raw)
		if len(excerpt) > maxEchoedBody {
			excerpt = excerpt[:maxEchoedBody] + "... (truncated)"
		}
		logger.GetLogger(ctx).Errorf("GeminiEmbedder BatchEmbed API error: Http Status %s, Response Body: %s", httpResp.Status, excerpt)
		return nil, fmt.Errorf("Gemini BatchEmbed API error: Http Status %s, Response: %s", httpResp.Status, excerpt)
	}

	var parsed geminiBatchEmbedResponse
	if err := json.Unmarshal(raw, &parsed); err != nil {
		logger.GetLogger(ctx).Errorf("GeminiEmbedder BatchEmbed unmarshal response error: %v", err)
		return nil, fmt.Errorf("unmarshal response: %w", err)
	}

	if got := len(parsed.Embeddings); got != inputCount {
		return nil, fmt.Errorf("Gemini BatchEmbed returned %d embeddings for %d inputs", got, inputCount)
	}

	vectors := make([][]float32, len(parsed.Embeddings))
	for i := range parsed.Embeddings {
		vectors[i] = parsed.Embeddings[i].Values
	}
	return vectors, nil
}
// doRequestWithRetry posts jsonData to the model's batchEmbedContents endpoint
// and returns the first response the HTTP client delivers, whatever its status
// code; the caller inspects the status and must close the body.
//
// Only transport-level failures (httpClient.Do returning an error) are
// retried, up to maxRetries additional attempts, with exponential backoff of
// 1s, 2s, 4s, ... capped at 10s. Two situations end the loop early:
//   - the request cannot be constructed (for example an unparsable URL): this
//     is deterministic, so retrying would only add backoff delay;
//   - ctx is done: ctx.Err() is returned.
//
// A fresh request is built for every attempt because the body reader is
// consumed by each send.
func (e *GeminiEmbedder) doRequestWithRetry(ctx context.Context, jsonData []byte) (*http.Response, error) {
	endpoint := fmt.Sprintf("%s/models/%s:batchEmbedContents", e.baseURL, e.modelName)

	// Guarantee at least one attempt so the function can never return
	// (nil, nil), which the caller would dereference.
	retries := e.maxRetries
	if retries < 0 {
		retries = 0
	}

	var lastErr error
	for attempt := 0; attempt <= retries; attempt++ {
		if attempt > 0 {
			backoffTime := time.Duration(1<<uint(attempt-1)) * time.Second
			if backoffTime > 10*time.Second {
				backoffTime = 10 * time.Second
			}
			logger.GetLogger(ctx).
				Infof("GeminiEmbedder retrying request (%d/%d), waiting %v", attempt, retries, backoffTime)
			select {
			case <-time.After(backoffTime):
			case <-ctx.Done():
				return nil, ctx.Err()
			}
		}

		req, err := http.NewRequestWithContext(ctx, http.MethodPost, endpoint, bytes.NewReader(jsonData))
		if err != nil {
			// Not retryable: the same inputs would fail the same way again.
			logger.GetLogger(ctx).Errorf("GeminiEmbedder failed to create request: %v", err)
			return nil, fmt.Errorf("create request: %w", err)
		}
		req.Header.Set("Content-Type", "application/json")
		req.Header.Set("x-goog-api-key", e.apiKey)
		secutils.ApplyCustomHeaders(req, e.customHeaders)

		resp, err := e.httpClient.Do(req)
		if err == nil {
			return resp, nil
		}
		lastErr = err
		logger.GetLogger(ctx).Errorf("GeminiEmbedder request failed (attempt %d/%d): %v", attempt+1, retries+1, err)

		// Stop right away once the caller has given up, instead of announcing
		// a retry that will never happen.
		if ctxErr := ctx.Err(); ctxErr != nil {
			return nil, ctxErr
		}
	}
	return nil, lastErr
}
// GetModelName returns the model name as stored by the constructor, i.e.
// without the "models/" prefix.
func (e *GeminiEmbedder) GetModelName() string {
return e.modelName
}
// GetDimensions returns the configured output dimensionality. This is the
// value passed to the constructor, not one measured from API responses.
func (e *GeminiEmbedder) GetDimensions() int {
return e.dimensions
}
// GetModelID returns the model identifier supplied to the constructor.
func (e *GeminiEmbedder) GetModelID() string {
return e.modelID
}

View File

@@ -0,0 +1,84 @@
package embedding
import (
"context"
"encoding/json"
"net/http"
"net/http/httptest"
"strings"
"testing"
)
// TestGeminiEmbedderBatchEmbedUsesNativeAPI verifies that BatchEmbed posts to
// the native batchEmbedContents path (with a trailing "/openai" stripped from
// the base URL), authenticates with x-goog-api-key, sends one request entry
// per input, and returns the vectors in order.
func TestGeminiEmbedderBatchEmbedUsesNativeAPI(t *testing.T) {
	var gotPath string
	var gotAPIKey string
	var gotReq geminiBatchEmbedRequest

	server := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
		gotPath = r.URL.Path
		gotAPIKey = r.Header.Get("x-goog-api-key")
		if err := json.NewDecoder(r.Body).Decode(&gotReq); err != nil {
			// The handler runs on a server goroutine, where t.Fatalf must not
			// be called. Record the failure and answer with an error status so
			// the client call below fails too.
			t.Errorf("decode request: %v", err)
			http.Error(w, "bad request", http.StatusBadRequest)
			return
		}
		w.Header().Set("Content-Type", "application/json")
		_, _ = w.Write([]byte(`{
"embeddings": [
{"values": [0.1, 0.2]},
{"values": [0.3, 0.4]}
]
}`))
	}))
	defer server.Close()

	embedder, err := NewGeminiEmbedder("test-key", server.URL+"/openai", "gemini-embedding-2",
		0, 768, "model-id", nil)
	if err != nil {
		t.Fatalf("NewGeminiEmbedder: %v", err)
	}

	embeddings, err := embedder.BatchEmbed(context.Background(), []string{"hello", "world"})
	if err != nil {
		t.Fatalf("BatchEmbed: %v", err)
	}

	if gotPath != "/models/gemini-embedding-2:batchEmbedContents" {
		t.Fatalf("path = %q, want native batchEmbedContents path", gotPath)
	}
	if gotAPIKey != "test-key" {
		t.Fatalf("x-goog-api-key = %q", gotAPIKey)
	}
	if len(gotReq.Requests) != 2 {
		t.Fatalf("requests len = %d", len(gotReq.Requests))
	}
	if gotReq.Requests[0].Model != "models/gemini-embedding-2" {
		t.Fatalf("request model = %q", gotReq.Requests[0].Model)
	}
	if gotReq.Requests[0].OutputDimensionality != 768 {
		t.Fatalf("output_dimensionality = %d", gotReq.Requests[0].OutputDimensionality)
	}
	// Check the length first so a malformed payload fails the test instead of
	// panicking on the index below.
	if len(gotReq.Requests[0].Content.Parts) != 1 {
		t.Fatalf("parts len = %d", len(gotReq.Requests[0].Content.Parts))
	}
	if gotReq.Requests[0].Content.Parts[0].Text != "hello" {
		t.Fatalf("first text = %q", gotReq.Requests[0].Content.Parts[0].Text)
	}
	if len(embeddings) != 2 || len(embeddings[0]) != 2 || embeddings[1][1] != 0.4 {
		t.Fatalf("unexpected embeddings: %#v", embeddings)
	}
}
// TestGeminiEmbedderReturnsAPIErrorBody checks that a non-200 reply from the
// API surfaces as an error whose message includes the HTTP status code.
func TestGeminiEmbedderReturnsAPIErrorBody(t *testing.T) {
	alwaysNotFound := func(w http.ResponseWriter, _ *http.Request) {
		http.Error(w, `{"error":"not found"}`, http.StatusNotFound)
	}
	srv := httptest.NewServer(http.HandlerFunc(alwaysNotFound))
	defer srv.Close()

	emb, err := NewGeminiEmbedder("test-key", srv.URL, "gemini-embedding-2",
		0, 0, "model-id", nil)
	if err != nil {
		t.Fatalf("NewGeminiEmbedder: %v", err)
	}

	_, err = emb.BatchEmbed(context.Background(), []string{"hello"})
	if err != nil && strings.Contains(err.Error(), "404") {
		return
	}
	t.Fatalf("expected 404 error, got %v", err)
}

View File

@@ -25,12 +25,14 @@ func (p *GeminiProvider) Info() ProviderInfo {
return ProviderInfo{
Name: ProviderGemini,
DisplayName: "Google Gemini",
Description: "gemini-3-flash-preview, gemini-2.5-pro, etc.",
Description: "gemini-3-flash-preview, gemini-2.5-pro, gemini-embedding-2, etc.",
DefaultURLs: map[types.ModelType]string{
types.ModelTypeKnowledgeQA: GeminiOpenAICompatBaseURL,
types.ModelTypeEmbedding: GeminiBaseURL,
},
ModelTypes: []types.ModelType{
types.ModelTypeKnowledgeQA,
types.ModelTypeEmbedding,
},
RequiresAuth: true,
}

View File

@@ -256,4 +256,20 @@ func TestListByModelType(t *testing.T) {
assert.True(t, found, "OpenRouter should support embedding")
})
t.Run("embedding models include gemini", func(t *testing.T) {
providers := ListByModelType(types.ModelTypeEmbedding)
assert.NotEmpty(t, providers)
found := false
for _, p := range providers {
if p.Name == ProviderGemini {
found = true
assert.Equal(t, GeminiBaseURL, p.GetDefaultURL(types.ModelTypeEmbedding))
break
}
}
assert.True(t, found, "Gemini should support embedding via the native Gemini API")
})
}

View File

@@ -11,6 +11,7 @@ var reservedHeaderKeys = map[string]struct{}{
"authorization": {},
"api-key": {},
"x-api-key": {},
"x-goog-api-key": {},
"content-type": {},
"content-length": {},
"accept-encoding": {},

View File

@@ -12,13 +12,15 @@ func TestApplyCustomHeaders_SkipReserved(t *testing.T) {
req, _ := http.NewRequest("GET", "https://example.com", nil)
req.Header.Set("Authorization", "Bearer original")
req.Header.Set("Content-Type", "application/json")
req.Header.Set("x-goog-api-key", "google-original")
ApplyCustomHeaders(req, map[string]string{
"Authorization": "Bearer injected",
"Content-Type": "text/plain",
"X-Trace-Id": "trace-123",
"X-Route": "edge",
"": "empty-key-should-be-skipped",
"Authorization": "Bearer injected",
"Content-Type": "text/plain",
"x-goog-api-key": "google-injected",
"X-Trace-Id": "trace-123",
"X-Route": "edge",
"": "empty-key-should-be-skipped",
})
if got := req.Header.Get("Authorization"); got != "Bearer original" {
@@ -27,6 +29,9 @@ func TestApplyCustomHeaders_SkipReserved(t *testing.T) {
if got := req.Header.Get("Content-Type"); got != "application/json" {
t.Fatalf("content-type overwritten: %q", got)
}
if got := req.Header.Get("x-goog-api-key"); got != "google-original" {
t.Fatalf("x-goog-api-key overwritten: %q", got)
}
if got := req.Header.Get("X-Trace-Id"); got != "trace-123" {
t.Fatalf("X-Trace-Id not injected: %q", got)
}