mirror of
https://github.com/Tencent/WeKnora.git
synced 2026-06-04 13:30:32 +08:00
fix(embedding): support native Gemini embeddings
This commit is contained in:
@@ -435,6 +435,16 @@ const fallbackProviderOptions = computed(() => [
|
||||
description: t('model.editor.providers.openrouter.description'),
|
||||
modelTypes: ['chat', 'embedding']
|
||||
},
|
||||
{
|
||||
value: 'gemini',
|
||||
label: t('model.editor.providers.gemini.label'),
|
||||
defaultUrls: {
|
||||
chat: 'https://generativelanguage.googleapis.com/v1beta/openai',
|
||||
embedding: 'https://generativelanguage.googleapis.com/v1beta'
|
||||
},
|
||||
description: t('model.editor.providers.gemini.description'),
|
||||
modelTypes: ['chat', 'embedding']
|
||||
},
|
||||
{
|
||||
value: 'siliconflow',
|
||||
label: t('model.editor.providers.siliconflow.label'),
|
||||
|
||||
@@ -218,6 +218,19 @@ func newEmbedder(config Config, pooler EmbedderPooler, ollamaService *ollama.Oll
|
||||
}
|
||||
embedder, err = nvEmb, nErr
|
||||
return embedder, err
|
||||
case provider.ProviderGemini:
|
||||
geminiEmb, gErr := NewGeminiEmbedder(config.APIKey,
|
||||
config.BaseURL,
|
||||
config.ModelName,
|
||||
config.TruncatePromptTokens,
|
||||
config.Dimensions,
|
||||
config.ModelID,
|
||||
pooler)
|
||||
if geminiEmb != nil {
|
||||
geminiEmb.SetCustomHeaders(config.CustomHeaders)
|
||||
}
|
||||
embedder, err = geminiEmb, gErr
|
||||
return embedder, err
|
||||
case provider.ProviderZhipu:
|
||||
zhipuEmb, zErr := NewZhipuEmbedder(config.APIKey,
|
||||
config.BaseURL,
|
||||
|
||||
227
internal/models/embedding/gemini.go
Normal file
227
internal/models/embedding/gemini.go
Normal file
@@ -0,0 +1,227 @@
|
||||
package embedding
|
||||
|
||||
import (
|
||||
"bytes"
|
||||
"context"
|
||||
"encoding/json"
|
||||
"fmt"
|
||||
"io"
|
||||
"net/http"
|
||||
"strings"
|
||||
"time"
|
||||
|
||||
"github.com/Tencent/WeKnora/internal/logger"
|
||||
secutils "github.com/Tencent/WeKnora/internal/utils"
|
||||
)
|
||||
|
||||
// geminiEmbeddingBaseURL is the API root NewGeminiEmbedder falls back to when
// it is given an empty baseURL.
const geminiEmbeddingBaseURL = "https://generativelanguage.googleapis.com/v1beta"
|
||||
|
||||
// GeminiEmbedder implements text vectorization using the native Gemini
|
||||
// embedContent / batchEmbedContents REST API.
|
||||
type GeminiEmbedder struct {
|
||||
apiKey string
|
||||
baseURL string
|
||||
modelName string
|
||||
truncatePromptTokens int
|
||||
dimensions int
|
||||
modelID string
|
||||
httpClient *http.Client
|
||||
timeout time.Duration
|
||||
maxRetries int
|
||||
customHeaders map[string]string
|
||||
EmbedderPooler
|
||||
}
|
||||
|
||||
type geminiBatchEmbedRequest struct {
|
||||
Requests []geminiEmbedRequest `json:"requests"`
|
||||
}
|
||||
|
||||
type geminiEmbedRequest struct {
|
||||
Model string `json:"model"`
|
||||
Content geminiContent `json:"content"`
|
||||
TaskType string `json:"taskType,omitempty"`
|
||||
OutputDimensionality int `json:"output_dimensionality,omitempty"`
|
||||
}
|
||||
|
||||
type geminiContent struct {
|
||||
Parts []geminiPart `json:"parts"`
|
||||
}
|
||||
|
||||
// geminiPart is one text fragment inside a geminiContent.
type geminiPart struct {
	Text string `json:"text"`
}
|
||||
|
||||
type geminiBatchEmbedResponse struct {
|
||||
Embeddings []geminiEmbedding `json:"embeddings"`
|
||||
}
|
||||
|
||||
// geminiEmbedding holds the vector returned for a single input.
type geminiEmbedding struct {
	Values []float32 `json:"values"`
}
|
||||
|
||||
func NewGeminiEmbedder(apiKey, baseURL, modelName string,
|
||||
truncatePromptTokens int, dimensions int, modelID string, pooler EmbedderPooler,
|
||||
) (*GeminiEmbedder, error) {
|
||||
if modelName == "" {
|
||||
return nil, fmt.Errorf("model name is required")
|
||||
}
|
||||
if truncatePromptTokens == 0 {
|
||||
truncatePromptTokens = 511
|
||||
}
|
||||
if baseURL == "" {
|
||||
baseURL = geminiEmbeddingBaseURL
|
||||
}
|
||||
|
||||
baseURL = strings.TrimRight(baseURL, "/")
|
||||
if strings.HasSuffix(baseURL, "/openai") {
|
||||
baseURL = strings.TrimSuffix(baseURL, "/openai")
|
||||
}
|
||||
|
||||
timeout := 60 * time.Second
|
||||
return &GeminiEmbedder{
|
||||
apiKey: apiKey,
|
||||
baseURL: baseURL,
|
||||
modelName: strings.TrimPrefix(modelName, "models/"),
|
||||
truncatePromptTokens: truncatePromptTokens,
|
||||
dimensions: dimensions,
|
||||
modelID: modelID,
|
||||
httpClient: &http.Client{Timeout: timeout},
|
||||
timeout: timeout,
|
||||
maxRetries: 3,
|
||||
EmbedderPooler: pooler,
|
||||
}, nil
|
||||
}
|
||||
|
||||
func (e *GeminiEmbedder) SetCustomHeaders(headers map[string]string) {
|
||||
e.customHeaders = headers
|
||||
}
|
||||
|
||||
func (e *GeminiEmbedder) Embed(ctx context.Context, text string) ([]float32, error) {
|
||||
embeddings, err := e.BatchEmbed(ctx, []string{text})
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
if len(embeddings) == 0 {
|
||||
return nil, fmt.Errorf("no embedding returned")
|
||||
}
|
||||
return embeddings[0], nil
|
||||
}
|
||||
|
||||
func (e *GeminiEmbedder) BatchEmbed(ctx context.Context, texts []string) ([][]float32, error) {
|
||||
if len(texts) == 0 {
|
||||
return nil, nil
|
||||
}
|
||||
|
||||
requests := make([]geminiEmbedRequest, 0, len(texts))
|
||||
for _, text := range texts {
|
||||
requests = append(requests, geminiEmbedRequest{
|
||||
Model: "models/" + e.modelName,
|
||||
Content: geminiContent{Parts: []geminiPart{
|
||||
{Text: text},
|
||||
}},
|
||||
OutputDimensionality: e.dimensions,
|
||||
})
|
||||
}
|
||||
|
||||
jsonData, err := json.Marshal(geminiBatchEmbedRequest{Requests: requests})
|
||||
if err != nil {
|
||||
logger.GetLogger(ctx).Errorf("GeminiEmbedder BatchEmbed marshal request error: %v", err)
|
||||
return nil, fmt.Errorf("marshal request: %w", err)
|
||||
}
|
||||
|
||||
logger.GetLogger(ctx).Debugf("GeminiEmbedder BatchEmbed: model=%s, input_count=%d",
|
||||
e.modelName, len(texts))
|
||||
|
||||
resp, err := e.doRequestWithRetry(ctx, jsonData)
|
||||
if err != nil {
|
||||
logger.GetLogger(ctx).Errorf("GeminiEmbedder BatchEmbed send request error: %v", err)
|
||||
return nil, fmt.Errorf("send request: %w", err)
|
||||
}
|
||||
if resp.Body != nil {
|
||||
defer resp.Body.Close()
|
||||
}
|
||||
|
||||
body, err := io.ReadAll(resp.Body)
|
||||
if err != nil {
|
||||
logger.GetLogger(ctx).Errorf("GeminiEmbedder BatchEmbed read response error: %v", err)
|
||||
return nil, fmt.Errorf("read response: %w", err)
|
||||
}
|
||||
|
||||
if resp.StatusCode != http.StatusOK {
|
||||
bodyStr := string(body)
|
||||
if len(bodyStr) > 1000 {
|
||||
bodyStr = bodyStr[:1000] + "... (truncated)"
|
||||
}
|
||||
logger.GetLogger(ctx).Errorf("GeminiEmbedder BatchEmbed API error: Http Status %s, Response Body: %s", resp.Status, bodyStr)
|
||||
return nil, fmt.Errorf("Gemini BatchEmbed API error: Http Status %s, Response: %s", resp.Status, bodyStr)
|
||||
}
|
||||
|
||||
var response geminiBatchEmbedResponse
|
||||
if err := json.Unmarshal(body, &response); err != nil {
|
||||
logger.GetLogger(ctx).Errorf("GeminiEmbedder BatchEmbed unmarshal response error: %v", err)
|
||||
return nil, fmt.Errorf("unmarshal response: %w", err)
|
||||
}
|
||||
if len(response.Embeddings) != len(texts) {
|
||||
return nil, fmt.Errorf("Gemini BatchEmbed returned %d embeddings for %d inputs", len(response.Embeddings), len(texts))
|
||||
}
|
||||
|
||||
embeddings := make([][]float32, 0, len(response.Embeddings))
|
||||
for _, embedding := range response.Embeddings {
|
||||
embeddings = append(embeddings, embedding.Values)
|
||||
}
|
||||
return embeddings, nil
|
||||
}
|
||||
|
||||
func (e *GeminiEmbedder) doRequestWithRetry(ctx context.Context, jsonData []byte) (*http.Response, error) {
|
||||
var resp *http.Response
|
||||
var err error
|
||||
url := fmt.Sprintf("%s/models/%s:batchEmbedContents", e.baseURL, e.modelName)
|
||||
|
||||
for i := 0; i <= e.maxRetries; i++ {
|
||||
if i > 0 {
|
||||
backoffTime := time.Duration(1<<uint(i-1)) * time.Second
|
||||
if backoffTime > 10*time.Second {
|
||||
backoffTime = 10 * time.Second
|
||||
}
|
||||
logger.GetLogger(ctx).
|
||||
Infof("GeminiEmbedder retrying request (%d/%d), waiting %v", i, e.maxRetries, backoffTime)
|
||||
|
||||
select {
|
||||
case <-time.After(backoffTime):
|
||||
case <-ctx.Done():
|
||||
return nil, ctx.Err()
|
||||
}
|
||||
}
|
||||
|
||||
var req *http.Request
|
||||
req, err = http.NewRequestWithContext(ctx, "POST", url, bytes.NewReader(jsonData))
|
||||
if err != nil {
|
||||
logger.GetLogger(ctx).Errorf("GeminiEmbedder failed to create request: %v", err)
|
||||
continue
|
||||
}
|
||||
req.Header.Set("Content-Type", "application/json")
|
||||
req.Header.Set("x-goog-api-key", e.apiKey)
|
||||
secutils.ApplyCustomHeaders(req, e.customHeaders)
|
||||
|
||||
resp, err = e.httpClient.Do(req)
|
||||
if err == nil {
|
||||
return resp, nil
|
||||
}
|
||||
|
||||
logger.GetLogger(ctx).Errorf("GeminiEmbedder request failed (attempt %d/%d): %v", i+1, e.maxRetries+1, err)
|
||||
}
|
||||
|
||||
return nil, err
|
||||
}
|
||||
|
||||
func (e *GeminiEmbedder) GetModelName() string {
|
||||
return e.modelName
|
||||
}
|
||||
|
||||
func (e *GeminiEmbedder) GetDimensions() int {
|
||||
return e.dimensions
|
||||
}
|
||||
|
||||
func (e *GeminiEmbedder) GetModelID() string {
|
||||
return e.modelID
|
||||
}
|
||||
84
internal/models/embedding/gemini_test.go
Normal file
84
internal/models/embedding/gemini_test.go
Normal file
@@ -0,0 +1,84 @@
|
||||
package embedding
|
||||
|
||||
import (
|
||||
"context"
|
||||
"encoding/json"
|
||||
"net/http"
|
||||
"net/http/httptest"
|
||||
"strings"
|
||||
"testing"
|
||||
)
|
||||
|
||||
func TestGeminiEmbedderBatchEmbedUsesNativeAPI(t *testing.T) {
|
||||
var gotPath string
|
||||
var gotAPIKey string
|
||||
var gotReq geminiBatchEmbedRequest
|
||||
|
||||
server := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
|
||||
gotPath = r.URL.Path
|
||||
gotAPIKey = r.Header.Get("x-goog-api-key")
|
||||
if err := json.NewDecoder(r.Body).Decode(&gotReq); err != nil {
|
||||
t.Fatalf("decode request: %v", err)
|
||||
}
|
||||
|
||||
w.Header().Set("Content-Type", "application/json")
|
||||
_, _ = w.Write([]byte(`{
|
||||
"embeddings": [
|
||||
{"values": [0.1, 0.2]},
|
||||
{"values": [0.3, 0.4]}
|
||||
]
|
||||
}`))
|
||||
}))
|
||||
defer server.Close()
|
||||
|
||||
embedder, err := NewGeminiEmbedder("test-key", server.URL+"/openai", "gemini-embedding-2",
|
||||
0, 768, "model-id", nil)
|
||||
if err != nil {
|
||||
t.Fatalf("NewGeminiEmbedder: %v", err)
|
||||
}
|
||||
|
||||
embeddings, err := embedder.BatchEmbed(context.Background(), []string{"hello", "world"})
|
||||
if err != nil {
|
||||
t.Fatalf("BatchEmbed: %v", err)
|
||||
}
|
||||
|
||||
if gotPath != "/models/gemini-embedding-2:batchEmbedContents" {
|
||||
t.Fatalf("path = %q, want native batchEmbedContents path", gotPath)
|
||||
}
|
||||
if gotAPIKey != "test-key" {
|
||||
t.Fatalf("x-goog-api-key = %q", gotAPIKey)
|
||||
}
|
||||
if len(gotReq.Requests) != 2 {
|
||||
t.Fatalf("requests len = %d", len(gotReq.Requests))
|
||||
}
|
||||
if gotReq.Requests[0].Model != "models/gemini-embedding-2" {
|
||||
t.Fatalf("request model = %q", gotReq.Requests[0].Model)
|
||||
}
|
||||
if gotReq.Requests[0].OutputDimensionality != 768 {
|
||||
t.Fatalf("output_dimensionality = %d", gotReq.Requests[0].OutputDimensionality)
|
||||
}
|
||||
if gotReq.Requests[0].Content.Parts[0].Text != "hello" {
|
||||
t.Fatalf("first text = %q", gotReq.Requests[0].Content.Parts[0].Text)
|
||||
}
|
||||
if len(embeddings) != 2 || len(embeddings[0]) != 2 || embeddings[1][1] != 0.4 {
|
||||
t.Fatalf("unexpected embeddings: %#v", embeddings)
|
||||
}
|
||||
}
|
||||
|
||||
func TestGeminiEmbedderReturnsAPIErrorBody(t *testing.T) {
|
||||
server := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
|
||||
http.Error(w, `{"error":"not found"}`, http.StatusNotFound)
|
||||
}))
|
||||
defer server.Close()
|
||||
|
||||
embedder, err := NewGeminiEmbedder("test-key", server.URL, "gemini-embedding-2",
|
||||
0, 0, "model-id", nil)
|
||||
if err != nil {
|
||||
t.Fatalf("NewGeminiEmbedder: %v", err)
|
||||
}
|
||||
|
||||
_, err = embedder.BatchEmbed(context.Background(), []string{"hello"})
|
||||
if err == nil || !strings.Contains(err.Error(), "404") {
|
||||
t.Fatalf("expected 404 error, got %v", err)
|
||||
}
|
||||
}
|
||||
@@ -25,12 +25,14 @@ func (p *GeminiProvider) Info() ProviderInfo {
|
||||
return ProviderInfo{
|
||||
Name: ProviderGemini,
|
||||
DisplayName: "Google Gemini",
|
||||
Description: "gemini-3-flash-preview, gemini-2.5-pro, etc.",
|
||||
Description: "gemini-3-flash-preview, gemini-2.5-pro, gemini-embedding-2, etc.",
|
||||
DefaultURLs: map[types.ModelType]string{
|
||||
types.ModelTypeKnowledgeQA: GeminiOpenAICompatBaseURL,
|
||||
types.ModelTypeEmbedding: GeminiBaseURL,
|
||||
},
|
||||
ModelTypes: []types.ModelType{
|
||||
types.ModelTypeKnowledgeQA,
|
||||
types.ModelTypeEmbedding,
|
||||
},
|
||||
RequiresAuth: true,
|
||||
}
|
||||
|
||||
@@ -256,4 +256,20 @@ func TestListByModelType(t *testing.T) {
|
||||
|
||||
assert.True(t, found, "OpenRouter should support embedding")
|
||||
})
|
||||
|
||||
t.Run("embedding models include gemini", func(t *testing.T) {
|
||||
providers := ListByModelType(types.ModelTypeEmbedding)
|
||||
assert.NotEmpty(t, providers)
|
||||
|
||||
found := false
|
||||
for _, p := range providers {
|
||||
if p.Name == ProviderGemini {
|
||||
found = true
|
||||
assert.Equal(t, GeminiBaseURL, p.GetDefaultURL(types.ModelTypeEmbedding))
|
||||
break
|
||||
}
|
||||
}
|
||||
|
||||
assert.True(t, found, "Gemini should support embedding via the native Gemini API")
|
||||
})
|
||||
}
|
||||
|
||||
@@ -11,6 +11,7 @@ var reservedHeaderKeys = map[string]struct{}{
|
||||
"authorization": {},
|
||||
"api-key": {},
|
||||
"x-api-key": {},
|
||||
"x-goog-api-key": {},
|
||||
"content-type": {},
|
||||
"content-length": {},
|
||||
"accept-encoding": {},
|
||||
|
||||
@@ -12,13 +12,15 @@ func TestApplyCustomHeaders_SkipReserved(t *testing.T) {
|
||||
req, _ := http.NewRequest("GET", "https://example.com", nil)
|
||||
req.Header.Set("Authorization", "Bearer original")
|
||||
req.Header.Set("Content-Type", "application/json")
|
||||
req.Header.Set("x-goog-api-key", "google-original")
|
||||
|
||||
ApplyCustomHeaders(req, map[string]string{
|
||||
"Authorization": "Bearer injected",
|
||||
"Content-Type": "text/plain",
|
||||
"X-Trace-Id": "trace-123",
|
||||
"X-Route": "edge",
|
||||
"": "empty-key-should-be-skipped",
|
||||
"Authorization": "Bearer injected",
|
||||
"Content-Type": "text/plain",
|
||||
"x-goog-api-key": "google-injected",
|
||||
"X-Trace-Id": "trace-123",
|
||||
"X-Route": "edge",
|
||||
"": "empty-key-should-be-skipped",
|
||||
})
|
||||
|
||||
if got := req.Header.Get("Authorization"); got != "Bearer original" {
|
||||
@@ -27,6 +29,9 @@ func TestApplyCustomHeaders_SkipReserved(t *testing.T) {
|
||||
if got := req.Header.Get("Content-Type"); got != "application/json" {
|
||||
t.Fatalf("content-type overwritten: %q", got)
|
||||
}
|
||||
if got := req.Header.Get("x-goog-api-key"); got != "google-original" {
|
||||
t.Fatalf("x-goog-api-key overwritten: %q", got)
|
||||
}
|
||||
if got := req.Header.Get("X-Trace-Id"); got != "trace-123" {
|
||||
t.Fatalf("X-Trace-Id not injected: %q", got)
|
||||
}
|
||||
|
||||
Reference in New Issue
Block a user