feat(auth): 连接docreader支持auth

This commit is contained in:
Li Xianggang
2026-05-16 17:49:31 +08:00
committed by lyingbug
parent c87e35b34b
commit 5a02e22f52
9 changed files with 363 additions and 26 deletions

View File

@@ -437,6 +437,15 @@ DOCREADER_ADDR=docreader:50051
# Docreader 连接方式
DOCREADER_TRANSPORT=grpc
# gRPC TLS 配置(可选)
# GRPC_TLS_ENABLED=false
# GRPC_TLS_CERT=/path/to/server.crt
# GRPC_TLS_KEY=/path/to/server.key
# GRPC_TLS_CA=/path/to/ca.crt
# gRPC 认证 Token(可选,客户端和服务端需配置相同的值)
# GRPC_AUTH_TOKEN=your-secret-token
# Docreader 中 DOCX 解析的最大页数,默认 0(不限制)
# 设为正整数(如 500)可限制超大 Word 文档的解析开销;超过页数的内容将不会继续解析
# DOCREADER_DOCX_MAX_PAGES=0

View File

@@ -58,4 +58,13 @@ CONCURRENCY_POOL_SIZE=3
DOCREADER_ADDR=127.0.0.1:50051
# Docreader 传输方式
DOCREADER_TRANSPORT=grpc
DOCREADER_TRANSPORT=grpc
# gRPC TLS 配置(可选)
# GRPC_TLS_ENABLED=false
# GRPC_TLS_CERT=/path/to/server.crt
# GRPC_TLS_KEY=/path/to/server.key
# GRPC_TLS_CA=/path/to/ca.crt
# gRPC 认证 Token(可选,客户端和服务端需配置相同的值)
# GRPC_AUTH_TOKEN=your-secret-token

View File

@@ -199,6 +199,11 @@ services:
- DOCREADER_PDF_RENDER_MAX_WORKERS=${DOCREADER_PDF_RENDER_MAX_WORKERS:-1}
- DOCREADER_PDF_RENDER_DPI=${DOCREADER_PDF_RENDER_DPI:-200}
- DOCREADER_PDF_JPEG_QUALITY=${DOCREADER_PDF_JPEG_QUALITY:-90}
- GRPC_TLS_ENABLED=${GRPC_TLS_ENABLED:-false}
- GRPC_TLS_CERT=${GRPC_TLS_CERT:-}
- GRPC_TLS_KEY=${GRPC_TLS_KEY:-}
- GRPC_TLS_CA=${GRPC_TLS_CA:-}
- GRPC_AUTH_TOKEN=${GRPC_AUTH_TOKEN:-}
healthcheck:
test: ["CMD", "grpc_health_probe", "-addr=localhost:50051"]
interval: 30s

View File

@@ -205,6 +205,11 @@ services:
- DOCREADER_PDF_RENDER_MAX_WORKERS=${DOCREADER_PDF_RENDER_MAX_WORKERS:-1}
- DOCREADER_PDF_RENDER_DPI=${DOCREADER_PDF_RENDER_DPI:-200}
- DOCREADER_PDF_JPEG_QUALITY=${DOCREADER_PDF_JPEG_QUALITY:-90}
- GRPC_TLS_ENABLED=${GRPC_TLS_ENABLED:-false}
- GRPC_TLS_CERT=${GRPC_TLS_CERT:-}
- GRPC_TLS_KEY=${GRPC_TLS_KEY:-}
- GRPC_TLS_CA=${GRPC_TLS_CA:-}
- GRPC_AUTH_TOKEN=${GRPC_AUTH_TOKEN:-}
healthcheck:
test: ["CMD", "grpc_health_probe", "-addr=localhost:50051"]
interval: 30s

111
docreader/auth.py Normal file
View File

@@ -0,0 +1,111 @@
"""gRPC TLS 和认证模块
环境变量配置:
TLS 相关:
GRPC_TLS_ENABLED: 是否启用 TLS(true/false,默认 false)
GRPC_TLS_CERT: TLS 证书文件路径
GRPC_TLS_KEY: TLS 私钥文件路径
GRPC_TLS_CA: CA 证书路径(可选,用于 mTLS 双向认证)
认证相关:
GRPC_AUTH_TOKEN: 认证 Token(如果设置则启用认证)
"""
import logging
import os
from typing import Optional
import grpc
logger = logging.getLogger(__name__)
def load_tls_credentials() -> Optional[grpc.ServerCredentials]:
    """Build gRPC server TLS credentials from environment variables.

    Environment:
        GRPC_TLS_ENABLED: TLS is attempted only when this equals "true"
            (case-insensitive); the default is "false".
        GRPC_TLS_CERT: path to the certificate chain file.
        GRPC_TLS_KEY: path to the private key file.
        GRPC_TLS_CA: optional path to a CA certificate. When set, client
            certificates are required (mutual TLS).

    Returns:
        The server credentials, or None when TLS is disabled, when the
        certificate or key path is missing, or when loading fails.

    NOTE(review): every failure path returns None, the same value as
    "TLS disabled". A caller that falls back to an insecure port on None
    will serve plaintext even though GRPC_TLS_ENABLED=true -- confirm this
    is intended.
    """
    if os.getenv("GRPC_TLS_ENABLED", "false").lower() != "true":
        logger.info("TLS disabled (GRPC_TLS_ENABLED is not 'true')")
        return None

    cert_file = os.getenv("GRPC_TLS_CERT")
    key_file = os.getenv("GRPC_TLS_KEY")
    if not (cert_file and key_file):
        logger.warning(
            "TLS enabled but certificate not configured (GRPC_TLS_CERT or GRPC_TLS_KEY not set)"
        )
        return None

    def _slurp(path):
        # Read a PEM file as raw bytes.
        with open(path, "rb") as handle:
            return handle.read()

    try:
        cert_pem = _slurp(cert_file)
        key_pem = _slurp(key_file)
        key_cert_pairs = [(key_pem, cert_pem)]

        ca_file = os.getenv("GRPC_TLS_CA")
        if not ca_file:
            # Server-side TLS only: clients are not asked for a certificate.
            server_creds = grpc.ssl_server_credentials(key_cert_pairs)
            logger.info("TLS enabled")
            return server_creds

        # A CA was supplied: require and verify client certificates.
        ca_pem = _slurp(ca_file)
        server_creds = grpc.ssl_server_credentials(
            key_cert_pairs,
            root_certificates=ca_pem,
            require_client_auth=True,
        )
        logger.info("TLS enabled with mTLS (mutual authentication)")
        return server_creds
    except FileNotFoundError as exc:
        logger.error(f"TLS certificate file not found: {exc}")
        return None
    except Exception as exc:
        logger.error(f"Failed to load TLS credentials: {exc}")
        return None
class AuthInterceptor(grpc.ServerInterceptor):
    """Server interceptor that checks a shared token on incoming RPCs.

    Environment:
        GRPC_AUTH_TOKEN: the expected token. When it is unset or empty the
            interceptor passes every call through unchanged.

    Clients send the token in call metadata:
        - key: "authorization"
        - value: "Bearer <token>" or the bare "<token>"

    The standard health-checking RPCs (grpc.health.v1.Health Check/Watch)
    are exempt so that health probes keep working without a token.
    """

    # Only the standard health service is exempt. Matching the fully-qualified
    # method name keeps an unrelated RPC that merely ends in "/Check" or
    # "/Watch" from skipping authentication.
    _HEALTH_METHODS = frozenset(
        {
            "/grpc.health.v1.Health/Check",
            "/grpc.health.v1.Health/Watch",
        }
    )
    _BEARER_PREFIX = "Bearer "

    def __init__(self):
        """Read the expected token from GRPC_AUTH_TOKEN and log the mode."""
        self.auth_token = os.getenv("GRPC_AUTH_TOKEN")
        if self.auth_token:
            logger.info("Token authentication enabled")
        else:
            logger.warning("Token authentication disabled (GRPC_AUTH_TOKEN not set)")

    def intercept_service(self, continuation, handler_call_details):
        """Authenticate one incoming call.

        Args:
            continuation: callable that yields the next handler in the chain.
            handler_call_details: call description; ``method`` and
                ``invocation_metadata`` are read from it.

        Returns:
            The handler from ``continuation`` when authentication is disabled,
            the call is a health check, or the token matches; otherwise a
            handler that rejects the call with UNAUTHENTICATED.
        """
        if not self.auth_token:
            return continuation(handler_call_details)

        method = handler_call_details.method
        if method in self._HEALTH_METHODS:
            return continuation(handler_call_details)

        metadata = dict(handler_call_details.invocation_metadata or [])
        supplied = metadata.get("authorization", "")
        if supplied.startswith(self._BEARER_PREFIX):
            supplied = supplied[len(self._BEARER_PREFIX):]

        if not self._token_matches(supplied):
            logger.warning("Authentication failed for method: %s", method)
            return self._unauthenticated_handler()

        return continuation(handler_call_details)

    def _token_matches(self, supplied):
        """Compare the supplied token with the configured one in constant time."""
        # Local import keeps this class self-contained within the module.
        import hmac

        # Compare as bytes: compare_digest rejects non-ASCII str arguments.
        return hmac.compare_digest(
            supplied.encode("utf-8"), self.auth_token.encode("utf-8")
        )

    def _unauthenticated_handler(self):
        """Return a handler that terminates the call with UNAUTHENTICATED."""

        def deny(request, context):
            # abort() ends the RPC with the given status instead of relying on
            # how the server treats a missing response message.
            context.abort(
                grpc.StatusCode.UNAUTHENTICATED,
                "Invalid or missing authentication token",
            )

        return grpc.unary_unary_rpc_method_handler(deny)

103
docreader/client/auth.go Normal file
View File

@@ -0,0 +1,103 @@
package client
import (
	"context"
	"crypto/tls"
	"crypto/x509"
	"fmt"
	"os"
	"strings"

	"google.golang.org/grpc"
	"google.golang.org/grpc/credentials"
	"google.golang.org/grpc/credentials/insecure"
)
// AuthConfig holds the TLS and token settings used when dialing the
// docreader gRPC server.
type AuthConfig struct {
	// TLSEnabled selects TLS transport credentials instead of insecure ones.
	TLSEnabled bool
	// CertFile and KeyFile are paths to the client certificate and its private
	// key; the pair is loaded only when both are non-empty.
	CertFile, KeyFile string
	// CAFile is the path to a PEM file used as the root CA pool when non-empty.
	CAFile string
	// AuthToken, when non-empty, is sent as a bearer token on every RPC.
	AuthToken string
}
// LoadAuthConfigFromEnv builds an AuthConfig from the process environment.
//
// It reads GRPC_TLS_ENABLED, GRPC_TLS_CERT, GRPC_TLS_KEY, GRPC_TLS_CA and
// GRPC_AUTH_TOKEN. TLS is enabled when GRPC_TLS_ENABLED equals "true",
// compared case-insensitively, so values such as "True" or "TRUE" are treated
// the same way as by the docreader server, which lower-cases the value before
// comparing. Unset variables yield empty strings.
func LoadAuthConfigFromEnv() *AuthConfig {
	return &AuthConfig{
		TLSEnabled: strings.EqualFold(os.Getenv("GRPC_TLS_ENABLED"), "true"),
		CertFile:   os.Getenv("GRPC_TLS_CERT"),
		KeyFile:    os.Getenv("GRPC_TLS_KEY"),
		CAFile:     os.Getenv("GRPC_TLS_CA"),
		AuthToken:  os.Getenv("GRPC_AUTH_TOKEN"),
	}
}
// BuildDialOptions assembles the gRPC dial options for this configuration.
//
// The result always contains the round_robin default service config and the
// send/receive message size limits set to maxMsgSize, followed by transport
// credentials (TLS when TLSEnabled is set, insecure otherwise). When AuthToken
// is non-empty, per-RPC bearer-token credentials are appended as well.
//
// It returns an error only when the TLS credentials cannot be built.
func (c *AuthConfig) BuildDialOptions(maxMsgSize int) ([]grpc.DialOption, error) {
	serviceCfg := grpc.WithDefaultServiceConfig(`{"loadBalancingPolicy":"round_robin"}`)
	callLimits := grpc.WithDefaultCallOptions(
		grpc.MaxCallRecvMsgSize(maxMsgSize),
		grpc.MaxCallSendMsgSize(maxMsgSize),
	)
	dialOpts := []grpc.DialOption{serviceCfg, callLimits}

	if !c.TLSEnabled {
		dialOpts = append(dialOpts, grpc.WithTransportCredentials(insecure.NewCredentials()))
	} else {
		tlsCreds, err := c.buildTLSCredentials()
		if err != nil {
			return nil, fmt.Errorf("failed to build TLS credentials: %w", err)
		}
		dialOpts = append(dialOpts, grpc.WithTransportCredentials(tlsCreds))
		Logger.Printf("INFO: TLS enabled for gRPC client")
	}

	if token := c.AuthToken; token != "" {
		dialOpts = append(dialOpts, grpc.WithPerRPCCredentials(&tokenAuth{token: token}))
		Logger.Printf("INFO: Token authentication enabled for gRPC client")
	}

	return dialOpts, nil
}
// buildTLSCredentials creates client transport credentials from the
// configured file paths, with TLS 1.2 as the minimum protocol version.
//
// When CAFile is set, its PEM contents become the root CA pool. When both
// CertFile and KeyFile are set, the key pair is loaded as the client
// certificate; if only one of the two is set, no client certificate is loaded.
func (c *AuthConfig) buildTLSCredentials() (credentials.TransportCredentials, error) {
	cfg := &tls.Config{MinVersion: tls.VersionTLS12}

	if caPath := c.CAFile; caPath != "" {
		pemBytes, err := os.ReadFile(caPath)
		if err != nil {
			return nil, fmt.Errorf("failed to read CA certificate: %w", err)
		}
		roots := x509.NewCertPool()
		if ok := roots.AppendCertsFromPEM(pemBytes); !ok {
			return nil, fmt.Errorf("failed to parse CA certificate")
		}
		cfg.RootCAs = roots
	}

	haveKeyPair := c.CertFile != "" && c.KeyFile != ""
	if haveKeyPair {
		pair, err := tls.LoadX509KeyPair(c.CertFile, c.KeyFile)
		if err != nil {
			return nil, fmt.Errorf("failed to load client certificate: %w", err)
		}
		cfg.Certificates = []tls.Certificate{pair}
		Logger.Printf("INFO: mTLS enabled (client certificate loaded)")
	}

	return credentials.NewTLS(cfg), nil
}
// tokenAuth supplies a static bearer token as per-RPC credentials.
type tokenAuth struct {
	token string
}

// GetRequestMetadata returns the metadata attached to each RPC: a single
// "authorization" entry holding "Bearer " followed by the token. The ctx and
// uri arguments are not used.
func (t *tokenAuth) GetRequestMetadata(ctx context.Context, uri ...string) (map[string]string, error) {
	md := make(map[string]string, 1)
	md["authorization"] = "Bearer " + t.token
	return md, nil
}

// RequireTransportSecurity reports false, so these credentials do not demand
// a TLS-protected connection.
func (t *tokenAuth) RequireTransportSecurity() bool {
	const needsTLS = false
	return needsTLS
}

View File

@@ -9,7 +9,6 @@ import (
"github.com/Tencent/WeKnora/docreader/proto"
"google.golang.org/grpc"
"google.golang.org/grpc/credentials/insecure"
"google.golang.org/grpc/resolver"
)
@@ -40,17 +39,25 @@ type Client struct {
}
// NewClient creates a Client for addr using the auth configuration returned
// by LoadAuthConfigFromEnv.
func NewClient(addr string) (*Client, error) {
	return NewClientWithAuth(addr, LoadAuthConfigFromEnv())
}
func NewClientWithAuth(addr string, authConfig *AuthConfig) (*Client, error) {
Logger.Printf("INFO: Creating new DocReader client connecting to %s", addr)
maxMsgSize := getMaxMessageSize()
opts := []grpc.DialOption{
grpc.WithTransportCredentials(insecure.NewCredentials()),
grpc.WithDefaultServiceConfig(`{"loadBalancingPolicy":"round_robin"}`),
grpc.WithDefaultCallOptions(
grpc.MaxCallRecvMsgSize(maxMsgSize),
grpc.MaxCallSendMsgSize(maxMsgSize),
),
if authConfig == nil {
authConfig = &AuthConfig{}
}
opts, err := authConfig.BuildDialOptions(maxMsgSize)
if err != nil {
Logger.Printf("ERROR: Failed to build dial options: %v", err)
return nil, err
}
resolver.SetDefaultScheme("dns")
startTime := time.Now()

View File

@@ -11,6 +11,7 @@ import grpc
from grpc_health.v1 import health_pb2_grpc
from grpc_health.v1.health import HealthServicer
from docreader.auth import AuthInterceptor, load_tls_credentials
from docreader import config
from docreader.config import CONFIG
from docreader.parser import Parser
@@ -51,7 +52,9 @@ logger.info("Initializing server logging, level=%s", _level_name)
init_logging_request_id()
def _resolve_images(images: dict, request_id: str, storage_map: dict | None = None) -> tuple[str, list]:
def _resolve_images(
images: dict, request_id: str, storage_map: dict | None = None
) -> tuple[str, list]:
"""Resolve document images into inline bytes for the Go App to persist.
``images`` is a dict of {relative_path: raw_data} where raw_data is
@@ -69,8 +72,12 @@ def _resolve_images(images: dict, request_id: str, storage_map: dict | None = No
return "", []
mime_map = {
".png": "image/png", ".jpg": "image/jpeg", ".jpeg": "image/jpeg",
".gif": "image/gif", ".webp": "image/webp", ".bmp": "image/bmp",
".png": "image/png",
".jpg": "image/jpeg",
".jpeg": "image/jpeg",
".gif": "image/gif",
".webp": "image/webp",
".bmp": "image/bmp",
}
refs = []
@@ -84,12 +91,14 @@ def _resolve_images(images: dict, request_id: str, storage_map: dict | None = No
ext = os.path.splitext(fname)[1].lower()
mime = mime_map.get(ext, "application/octet-stream")
refs.append(ImageRef(
filename=fname,
original_ref=ref_path,
mime_type=mime,
image_data=img_bytes,
))
refs.append(
ImageRef(
filename=fname,
original_ref=ref_path,
mime_type=mime,
image_data=img_bytes,
)
)
logger.info("Resolved %d images (mode=inline)", len(refs))
return "", refs
@@ -126,7 +135,9 @@ class DocReaderServicer(docreader_pb2_grpc.DocReaderServicer):
)
logger.info(
"Read(File): file=%s, type=%s, size=%d bytes",
request.file_name, file_type, len(request.file_content),
request.file_name,
file_type,
len(request.file_content),
)
result = self.parser.parse_file(
request.file_name,
@@ -143,19 +154,20 @@ class DocReaderServicer(docreader_pb2_grpc.DocReaderServicer):
return ReadResponse(error=error_msg)
_c = to_valid_utf8_text
image_dir, image_refs = _resolve_images(
result.images, request_id
)
image_dir, image_refs = _resolve_images(result.images, request_id)
response = ReadResponse(
markdown_content=_c(result.content),
image_refs=image_refs,
image_dir_path=image_dir,
metadata={k: _c(str(v)) for k, v in result.metadata.items()} if result.metadata else {},
metadata={k: _c(str(v)) for k, v in result.metadata.items()}
if result.metadata
else {},
)
logger.info(
"Read response: content_len=%d, images=%d",
len(result.content), len(image_refs),
len(result.content),
len(image_refs),
)
return response
@@ -184,12 +196,15 @@ class DocReaderServicer(docreader_pb2_grpc.DocReaderServicer):
def main():
config.print_config()
interceptors = [AuthInterceptor()]
server = grpc.server(
futures.ThreadPoolExecutor(max_workers=CONFIG.grpc_max_workers),
options=[
("grpc.max_send_message_length", CONFIG.grpc_max_file_size_mb),
("grpc.max_receive_message_length", CONFIG.grpc_max_file_size_mb),
],
interceptors=interceptors,
)
docreader_pb2_grpc.add_DocReaderServicer_to_server(DocReaderServicer(), server)
@@ -197,7 +212,16 @@ def main():
health_servicer = HealthServicer()
health_pb2_grpc.add_HealthServicer_to_server(health_servicer, server)
server.add_insecure_port(f"[::]:{CONFIG.grpc_port}")
tls_credentials = load_tls_credentials()
if tls_credentials:
server.add_secure_port(f"[::]:{CONFIG.grpc_port}", tls_credentials)
logger.info("Server starting on port %d with TLS", CONFIG.grpc_port)
else:
server.add_insecure_port(f"[::]:{CONFIG.grpc_port}")
logger.warning(
"Server starting on port %d WITHOUT TLS (insecure mode)", CONFIG.grpc_port
)
server.start()
logger.info("Server started on port %d", CONFIG.grpc_port)

View File

@@ -2,6 +2,8 @@ package docparser
import (
"context"
"crypto/tls"
"crypto/x509"
"fmt"
"os"
"strconv"
@@ -12,6 +14,7 @@ import (
"github.com/Tencent/WeKnora/internal/logger"
"github.com/Tencent/WeKnora/internal/types"
"google.golang.org/grpc"
"google.golang.org/grpc/credentials"
"google.golang.org/grpc/credentials/insecure"
"google.golang.org/grpc/resolver"
)
@@ -47,13 +50,29 @@ func (p *GRPCDocumentReader) connect(addr string) error {
maxMsgSize := getMaxMessageSize()
opts := []grpc.DialOption{
grpc.WithTransportCredentials(insecure.NewCredentials()),
grpc.WithDefaultServiceConfig(`{"loadBalancingPolicy":"round_robin"}`),
grpc.WithDefaultCallOptions(
grpc.MaxCallRecvMsgSize(maxMsgSize),
grpc.MaxCallSendMsgSize(maxMsgSize),
),
}
if os.Getenv("GRPC_TLS_ENABLED") == "true" {
creds, err := buildTLSCredentials()
if err != nil {
return fmt.Errorf("failed to build TLS credentials: %w", err)
}
opts = append(opts, grpc.WithTransportCredentials(creds))
logger.Infof(context.Background(), "TLS enabled for docreader gRPC client")
} else {
opts = append(opts, grpc.WithTransportCredentials(insecure.NewCredentials()))
}
if authToken := os.Getenv("GRPC_AUTH_TOKEN"); authToken != "" {
opts = append(opts, grpc.WithPerRPCCredentials(&tokenAuth{token: authToken}))
logger.Infof(context.Background(), "Token authentication enabled for docreader gRPC client")
}
resolver.SetDefaultScheme("dns")
start := time.Now()
@@ -173,3 +192,48 @@ func fromProtoReadResponse(resp *proto.ReadResponse) *types.ReadResult {
return result
}
// buildTLSCredentials creates client transport credentials from environment
// variables, with TLS 1.2 as the minimum protocol version.
//
// When GRPC_TLS_CA is set, the PEM file it names becomes the root CA pool.
// When both GRPC_TLS_CERT and GRPC_TLS_KEY are set, the key pair is loaded as
// the client certificate; if only one of them is set, no client certificate
// is loaded.
func buildTLSCredentials() (credentials.TransportCredentials, error) {
	cfg := &tls.Config{MinVersion: tls.VersionTLS12}

	caPath := os.Getenv("GRPC_TLS_CA")
	if caPath != "" {
		pemBytes, err := os.ReadFile(caPath)
		if err != nil {
			return nil, fmt.Errorf("failed to read CA certificate: %w", err)
		}
		roots := x509.NewCertPool()
		if ok := roots.AppendCertsFromPEM(pemBytes); !ok {
			return nil, fmt.Errorf("failed to parse CA certificate")
		}
		cfg.RootCAs = roots
	}

	certPath, keyPath := os.Getenv("GRPC_TLS_CERT"), os.Getenv("GRPC_TLS_KEY")
	if certPath != "" && keyPath != "" {
		pair, err := tls.LoadX509KeyPair(certPath, keyPath)
		if err != nil {
			return nil, fmt.Errorf("failed to load client certificate: %w", err)
		}
		cfg.Certificates = []tls.Certificate{pair}
		logger.Infof(context.Background(), "mTLS enabled (client certificate loaded)")
	}

	return credentials.NewTLS(cfg), nil
}
// tokenAuth supplies a static bearer token as per-RPC credentials for the
// docreader gRPC client.
type tokenAuth struct {
	token string
}

// GetRequestMetadata returns the metadata attached to each RPC: one
// "authorization" entry whose value is "Bearer " followed by the token. The
// ctx and uri arguments are not used.
func (t *tokenAuth) GetRequestMetadata(ctx context.Context, uri ...string) (map[string]string, error) {
	header := "Bearer " + t.token
	return map[string]string{"authorization": header}, nil
}

// RequireTransportSecurity reports false, so these credentials do not demand
// a TLS-protected connection.
func (t *tokenAuth) RequireTransportSecurity() bool { return false }