mirror of
https://github.com/fatedier/frp.git
synced 2026-06-04 11:34:23 +08:00
2
.github/workflows/golangci-lint.yml
vendored
2
.github/workflows/golangci-lint.yml
vendored
@@ -32,4 +32,4 @@ jobs:
|
||||
uses: golangci/golangci-lint-action@v9
|
||||
with:
|
||||
# Optional: version of golangci-lint to use in form of v1.2 or v1.2.3 or `latest` to use the latest version
|
||||
version: v2.10
|
||||
version: v2.11
|
||||
|
||||
1
.gitignore
vendored
1
.gitignore
vendored
@@ -18,6 +18,7 @@ release/
|
||||
test/bin/
|
||||
vendor/
|
||||
lastversion/
|
||||
.cache/
|
||||
dist/
|
||||
.idea/
|
||||
.vscode/
|
||||
|
||||
@@ -34,7 +34,7 @@ linters:
|
||||
disabled-checks:
|
||||
- exitAfterDefer
|
||||
gosec:
|
||||
excludes: ["G115", "G117", "G204", "G401", "G402", "G404", "G501", "G703", "G704", "G705"]
|
||||
excludes: ["G115", "G117", "G118", "G204", "G401", "G402", "G404", "G501", "G703", "G704", "G705"]
|
||||
severity: low
|
||||
confidence: low
|
||||
govet:
|
||||
|
||||
15
Makefile
15
Makefile
@@ -2,8 +2,10 @@ export PATH := $(PATH):`go env GOPATH`/bin
|
||||
export GO111MODULE=on
|
||||
LDFLAGS := -s -w
|
||||
NOWEB_TAG = $(shell [ ! -d web/frps/dist ] || [ ! -d web/frpc/dist ] && echo ',noweb')
|
||||
FRP_COMPAT_BASELINE_COUNT ?= 8
|
||||
FRP_COMPAT_FLOOR_VERSION ?= 0.61.0
|
||||
|
||||
.PHONY: web frps-web frpc-web frps frpc
|
||||
.PHONY: web frps-web frpc-web frps frpc e2e-compatibility-smoke e2e-compatibility e2e-compatibility-floor
|
||||
|
||||
all: env fmt web build
|
||||
|
||||
@@ -53,6 +55,15 @@ e2e:
|
||||
e2e-trace:
|
||||
DEBUG=true LOG_LEVEL=trace ./hack/run-e2e.sh
|
||||
|
||||
e2e-compatibility-smoke: build
|
||||
FRP_COMPAT_BASELINE_COUNT=1 ./hack/run-e2e-compatibility.sh
|
||||
|
||||
e2e-compatibility: build
|
||||
FRP_COMPAT_BASELINE_COUNT="$(FRP_COMPAT_BASELINE_COUNT)" ./hack/run-e2e-compatibility.sh
|
||||
|
||||
e2e-compatibility-floor: build
|
||||
FRP_COMPAT_BASELINE_VERSIONS="$(FRP_COMPAT_FLOOR_VERSION)" ./hack/run-e2e-compatibility.sh
|
||||
|
||||
e2e-compatibility-last-frpc:
|
||||
if [ ! -d "./lastversion" ]; then \
|
||||
TARGET_DIRNAME=lastversion ./hack/download.sh; \
|
||||
@@ -73,3 +84,5 @@ clean:
|
||||
rm -f ./bin/frpc
|
||||
rm -f ./bin/frps
|
||||
rm -rf ./lastversion
|
||||
rm -rf ./.cache
|
||||
rm -rf ./.compat
|
||||
|
||||
26
README.md
26
README.md
@@ -13,6 +13,14 @@ frp is an open source project with its ongoing development made possible entirel
|
||||
|
||||
<h3 align="center">Gold Sponsors</h3>
|
||||
<!--gold sponsors start-->
|
||||
<p align="center">
|
||||
<a href="https://jb.gg/frp" target="_blank">
|
||||
<img width="420px" src="https://raw.githubusercontent.com/fatedier/frp/dev/doc/pic/sponsor_jetbrains.jpg">
|
||||
<br>
|
||||
<b>The complete IDE crafted for professional Go developers</b>
|
||||
</a>
|
||||
</p>
|
||||
|
||||
<p align="center">
|
||||
<a href="https://github.com/beclab/Olares" target="_blank">
|
||||
<img width="420px" src="https://raw.githubusercontent.com/fatedier/frp/dev/doc/pic/sponsor_olares.jpeg">
|
||||
@@ -32,24 +40,6 @@ If you're looking for a meeting recording API, consider checking out [Recall.ai]
|
||||
an API that records Zoom, Google Meet, Microsoft Teams, in-person meetings, and more.
|
||||
|
||||
</div>
|
||||
|
||||
<p align="center">
|
||||
<a href="https://requestly.com/?utm_source=github&utm_medium=partnered&utm_campaign=frp" target="_blank">
|
||||
<img width="480px" src="https://github.com/user-attachments/assets/24670320-997d-4d62-9bca-955c59fe883d">
|
||||
<br>
|
||||
<b>Requestly - Free & Open-Source alternative to Postman</b>
|
||||
<br>
|
||||
<sub>All-in-one platform to Test, Mock and Intercept APIs.</sub>
|
||||
</a>
|
||||
</p>
|
||||
|
||||
<p align="center">
|
||||
<a href="https://jb.gg/frp" target="_blank">
|
||||
<img width="420px" src="https://raw.githubusercontent.com/fatedier/frp/dev/doc/pic/sponsor_jetbrains.jpg">
|
||||
<br>
|
||||
<b>The complete IDE crafted for professional Go developers</b>
|
||||
</a>
|
||||
</p>
|
||||
<!--gold sponsors end-->
|
||||
|
||||
## What is frp?
|
||||
|
||||
26
README_zh.md
26
README_zh.md
@@ -15,6 +15,14 @@ frp 是一个完全开源的项目,我们的开发工作完全依靠赞助者
|
||||
|
||||
<h3 align="center">Gold Sponsors</h3>
|
||||
<!--gold sponsors start-->
|
||||
<p align="center">
|
||||
<a href="https://jb.gg/frp" target="_blank">
|
||||
<img width="420px" src="https://raw.githubusercontent.com/fatedier/frp/dev/doc/pic/sponsor_jetbrains.jpg">
|
||||
<br>
|
||||
<b>The complete IDE crafted for professional Go developers</b>
|
||||
</a>
|
||||
</p>
|
||||
|
||||
<p align="center">
|
||||
<a href="https://github.com/beclab/Olares" target="_blank">
|
||||
<img width="420px" src="https://raw.githubusercontent.com/fatedier/frp/dev/doc/pic/sponsor_olares.jpeg">
|
||||
@@ -34,24 +42,6 @@ If you're looking for a meeting recording API, consider checking out [Recall.ai]
|
||||
an API that records Zoom, Google Meet, Microsoft Teams, in-person meetings, and more.
|
||||
|
||||
</div>
|
||||
|
||||
<p align="center">
|
||||
<a href="https://requestly.com/?utm_source=github&utm_medium=partnered&utm_campaign=frp" target="_blank">
|
||||
<img width="480px" src="https://github.com/user-attachments/assets/24670320-997d-4d62-9bca-955c59fe883d">
|
||||
<br>
|
||||
<b>Requestly - Free & Open-Source alternative to Postman</b>
|
||||
<br>
|
||||
<sub>All-in-one platform to Test, Mock and Intercept APIs.</sub>
|
||||
</a>
|
||||
</p>
|
||||
|
||||
<p align="center">
|
||||
<a href="https://jb.gg/frp" target="_blank">
|
||||
<img width="420px" src="https://raw.githubusercontent.com/fatedier/frp/dev/doc/pic/sponsor_jetbrains.jpg">
|
||||
<br>
|
||||
<b>The complete IDE crafted for professional Go developers</b>
|
||||
</a>
|
||||
</p>
|
||||
<!--gold sponsors end-->
|
||||
|
||||
## 为什么使用 frp ?
|
||||
|
||||
22
Release.md
22
Release.md
@@ -1,3 +1,21 @@
|
||||
## Fixes
|
||||
## Compatibility Policy
|
||||
|
||||
* Fixed a configuration-dependent authentication bypass in `type = "http"` proxies when `routeByHTTPUser` is used together with `httpUser` / `httpPassword`. This affected proxy-style requests. Proxy-style authentication failures now return `407 Proxy Authentication Required`.
|
||||
Starting with v0.69.0, each minor release is supported until there are nine newer minor releases. For example, v0.69.0 will be supported until v0.78.0 is released. Within this window, frpc v0.69.0 is guaranteed to work with any frps from v0.61.0 to v0.77.0, and vice versa. Patch releases within the same minor are always compatible. Versions outside the support window may continue to work on a best-effort basis, but compatibility is no longer guaranteed.
|
||||
|
||||
For mixed-version deployments, upgrade frps first, then upgrade frpc. This keeps the server side ready for newer client-side protocol behavior before clients start using it.
|
||||
|
||||
## Notes
|
||||
|
||||
This release introduces wire protocol v2 as a transition path for future frpc/frps protocol changes. The existing wire protocol is difficult to extend without compatibility risk, and upcoming changes, including replacing deprecated stream encryption methods, require a versioned protocol.
|
||||
|
||||
**The default value of `transport.wireProtocol` remains `v1` in this release.** Users can keep the default for now. To test v2 early, upgrade both frpc and frps to versions that support it, then set `transport.wireProtocol = "v2"` in frpc. A v2-enabled frpc cannot connect to an older frps.
|
||||
|
||||
When `transport.wireProtocol = "v2"` is enabled, the control channel uses negotiated AEAD encryption after the login handshake. Both frpc and frps must be upgraded to this release to use v2.
|
||||
|
||||
v1 will be deprecated when v2 becomes the default in a future release. It will continue to be supported until v0.78.0 is released, and may be removed in v0.78.0 or later.
|
||||
|
||||
## Features
|
||||
|
||||
* Added `transport.wireProtocol` for frpc to select the internal message protocol used between frpc and frps. Supported values are `v1` and `v2`.
|
||||
* Added client protocol visibility in the frps dashboard and `/api/clients` API. Online clients now report their negotiated protocol as `v1` or `v2`.
|
||||
* Wire protocol v2 now negotiates AEAD control-channel encryption. Supported algorithms are `xchacha20-poly1305` and `aes-256-gcm`; frpc advertises its preferred order based on local AES-GCM hardware support, and frps selects the first supported algorithm from that list.
|
||||
|
||||
@@ -29,6 +29,8 @@ import (
|
||||
"github.com/samber/lo"
|
||||
|
||||
v1 "github.com/fatedier/frp/pkg/config/v1"
|
||||
"github.com/fatedier/frp/pkg/msg"
|
||||
"github.com/fatedier/frp/pkg/proto/wire"
|
||||
"github.com/fatedier/frp/pkg/transport"
|
||||
netpkg "github.com/fatedier/frp/pkg/util/net"
|
||||
"github.com/fatedier/frp/pkg/util/xlog"
|
||||
@@ -41,6 +43,39 @@ type Connector interface {
|
||||
Close() error
|
||||
}
|
||||
|
||||
type MessageConnector interface {
|
||||
Connect() (*msg.Conn, error)
|
||||
Close() error
|
||||
}
|
||||
|
||||
type messageConnector struct {
|
||||
connector Connector
|
||||
wireProtocol string
|
||||
}
|
||||
|
||||
func newMessageConnector(connector Connector, wireProtocol string) *messageConnector {
|
||||
return &messageConnector{
|
||||
connector: connector,
|
||||
wireProtocol: wireProtocol,
|
||||
}
|
||||
}
|
||||
|
||||
func (c *messageConnector) Connect() (*msg.Conn, error) {
|
||||
conn, err := c.connector.Connect()
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
if err = wire.WriteMagicIfV2(conn, c.wireProtocol); err != nil {
|
||||
conn.Close()
|
||||
return nil, err
|
||||
}
|
||||
return msg.NewConn(conn, msg.NewReadWriter(conn, c.wireProtocol)), nil
|
||||
}
|
||||
|
||||
func (c *messageConnector) Close() error {
|
||||
return c.connector.Close()
|
||||
}
|
||||
|
||||
// defaultConnectorImpl is the default implementation of Connector for normal frpc.
|
||||
type defaultConnectorImpl struct {
|
||||
ctx context.Context
|
||||
|
||||
@@ -27,7 +27,6 @@ import (
|
||||
"github.com/fatedier/frp/pkg/msg"
|
||||
"github.com/fatedier/frp/pkg/naming"
|
||||
"github.com/fatedier/frp/pkg/transport"
|
||||
netpkg "github.com/fatedier/frp/pkg/util/net"
|
||||
"github.com/fatedier/frp/pkg/util/wait"
|
||||
"github.com/fatedier/frp/pkg/util/xlog"
|
||||
"github.com/fatedier/frp/pkg/vnet"
|
||||
@@ -41,13 +40,11 @@ type SessionContext struct {
|
||||
// It should be attached to the login message when reconnecting.
|
||||
RunID string
|
||||
// Underlying control connection. Once conn is closed, the msgDispatcher and the entire Control will exit.
|
||||
Conn net.Conn
|
||||
// Indicates whether the connection is encrypted.
|
||||
ConnEncrypted bool
|
||||
Conn *msg.Conn
|
||||
// Auth runtime used for login, heartbeats, and encryption.
|
||||
Auth *auth.ClientAuth
|
||||
// Connector is used to create new connections, which could be real TCP connections or virtual streams.
|
||||
Connector Connector
|
||||
// Connector is used to create message connections to frps.
|
||||
Connector MessageConnector
|
||||
// Virtual net controller
|
||||
VnetController *vnet.Controller
|
||||
}
|
||||
@@ -91,15 +88,7 @@ func NewControl(ctx context.Context, sessionCtx *SessionContext) (*Control, erro
|
||||
}
|
||||
ctl.lastPong.Store(time.Now())
|
||||
|
||||
if sessionCtx.ConnEncrypted {
|
||||
cryptoRW, err := netpkg.NewCryptoReadWriter(sessionCtx.Conn, sessionCtx.Auth.EncryptionKey())
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
ctl.msgDispatcher = msg.NewDispatcher(cryptoRW)
|
||||
} else {
|
||||
ctl.msgDispatcher = msg.NewDispatcher(sessionCtx.Conn)
|
||||
}
|
||||
ctl.msgDispatcher = msg.NewDispatcher(sessionCtx.Conn)
|
||||
ctl.registerMsgHandlers()
|
||||
ctl.msgTransporter = transport.NewMessageTransporter(ctl.msgDispatcher)
|
||||
|
||||
@@ -139,14 +128,14 @@ func (ctl *Control) handleReqWorkConn(_ msg.Message) {
|
||||
workConn.Close()
|
||||
return
|
||||
}
|
||||
if err = msg.WriteMsg(workConn, m); err != nil {
|
||||
if err = workConn.WriteMsg(m); err != nil {
|
||||
xl.Warnf("work connection write to server error: %v", err)
|
||||
workConn.Close()
|
||||
return
|
||||
}
|
||||
|
||||
var startMsg msg.StartWorkConn
|
||||
if err = msg.ReadMsgInto(workConn, &startMsg); err != nil {
|
||||
if err = workConn.ReadMsgInto(&startMsg); err != nil {
|
||||
xl.Tracef("work connection closed before response StartWorkConn message: %v", err)
|
||||
workConn.Close()
|
||||
return
|
||||
@@ -227,7 +216,7 @@ func (ctl *Control) Done() <-chan struct{} {
|
||||
}
|
||||
|
||||
// connectServer return a new connection to frps
|
||||
func (ctl *Control) connectServer() (net.Conn, error) {
|
||||
func (ctl *Control) connectServer() (*msg.Conn, error) {
|
||||
return ctl.sessionCtx.Connector.Connect()
|
||||
}
|
||||
|
||||
|
||||
220
client/control_session.go
Normal file
220
client/control_session.go
Normal file
@@ -0,0 +1,220 @@
|
||||
// Copyright 2026 The frp Authors
|
||||
//
|
||||
// Licensed under the Apache License, Version 2.0 (the "License");
|
||||
// you may not use this file except in compliance with the License.
|
||||
// You may obtain a copy of the License at
|
||||
//
|
||||
// http://www.apache.org/licenses/LICENSE-2.0
|
||||
//
|
||||
// Unless required by applicable law or agreed to in writing, software
|
||||
// distributed under the License is distributed on an "AS IS" BASIS,
|
||||
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
// See the License for the specific language governing permissions and
|
||||
// limitations under the License.
|
||||
|
||||
package client
|
||||
|
||||
import (
|
||||
"context"
|
||||
"errors"
|
||||
"fmt"
|
||||
"io"
|
||||
"net"
|
||||
"os"
|
||||
"runtime"
|
||||
"time"
|
||||
|
||||
"github.com/samber/lo"
|
||||
|
||||
"github.com/fatedier/frp/pkg/auth"
|
||||
v1 "github.com/fatedier/frp/pkg/config/v1"
|
||||
"github.com/fatedier/frp/pkg/msg"
|
||||
"github.com/fatedier/frp/pkg/proto/wire"
|
||||
netpkg "github.com/fatedier/frp/pkg/util/net"
|
||||
"github.com/fatedier/frp/pkg/util/version"
|
||||
"github.com/fatedier/frp/pkg/vnet"
|
||||
)
|
||||
|
||||
type controlSessionDialer struct {
|
||||
ctx context.Context
|
||||
|
||||
common *v1.ClientCommonConfig
|
||||
auth *auth.ClientAuth
|
||||
clientSpec *msg.ClientSpec
|
||||
vnetController *vnet.Controller
|
||||
|
||||
connectorCreator func(context.Context, *v1.ClientCommonConfig) Connector
|
||||
}
|
||||
|
||||
func (d *controlSessionDialer) Dial(previousRunID string) (*SessionContext, error) {
|
||||
connector := d.connectorCreator(d.ctx, d.common)
|
||||
if err := connector.Open(); err != nil {
|
||||
return nil, err
|
||||
}
|
||||
|
||||
success := false
|
||||
defer func() {
|
||||
if !success {
|
||||
_ = connector.Close()
|
||||
}
|
||||
}()
|
||||
|
||||
conn, err := connector.Connect()
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
defer func() {
|
||||
if !success {
|
||||
_ = conn.Close()
|
||||
}
|
||||
}()
|
||||
|
||||
loginMsg, err := d.buildLoginMsg(previousRunID)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
|
||||
loginResult, err := d.exchangeLogin(conn, loginMsg)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
loginRespMsg := loginResult.resp
|
||||
if loginRespMsg.Error != "" {
|
||||
return nil, errors.New(loginRespMsg.Error)
|
||||
}
|
||||
|
||||
var controlRW io.ReadWriter = conn
|
||||
if d.clientSpec == nil || d.clientSpec.Type != "ssh-tunnel" {
|
||||
controlRW, err = d.newControlReadWriter(conn, loginResult.crypto)
|
||||
if err != nil {
|
||||
return nil, fmt.Errorf("create control crypto read writer: %w", err)
|
||||
}
|
||||
}
|
||||
|
||||
success = true
|
||||
return &SessionContext{
|
||||
Common: d.common,
|
||||
RunID: loginRespMsg.RunID,
|
||||
Conn: msg.NewConn(conn, msg.NewReadWriter(controlRW, d.common.Transport.WireProtocol)),
|
||||
Auth: d.auth,
|
||||
Connector: newMessageConnector(connector, d.common.Transport.WireProtocol),
|
||||
VnetController: d.vnetController,
|
||||
}, nil
|
||||
}
|
||||
|
||||
func (d *controlSessionDialer) buildLoginMsg(previousRunID string) (*msg.Login, error) {
|
||||
hostname, _ := os.Hostname()
|
||||
loginMsg := &msg.Login{
|
||||
Arch: runtime.GOARCH,
|
||||
Os: runtime.GOOS,
|
||||
Hostname: hostname,
|
||||
PoolCount: d.common.Transport.PoolCount,
|
||||
User: d.common.User,
|
||||
ClientID: d.common.ClientID,
|
||||
Version: version.Full(),
|
||||
Timestamp: time.Now().Unix(),
|
||||
RunID: previousRunID,
|
||||
Metas: d.common.Metadatas,
|
||||
}
|
||||
if d.clientSpec != nil {
|
||||
loginMsg.ClientSpec = *d.clientSpec
|
||||
}
|
||||
|
||||
if err := d.auth.Setter.SetLogin(loginMsg); err != nil {
|
||||
return nil, err
|
||||
}
|
||||
return loginMsg, nil
|
||||
}
|
||||
|
||||
type loginExchangeResult struct {
|
||||
resp *msg.LoginResp
|
||||
crypto *wire.CryptoContext
|
||||
}
|
||||
|
||||
func (d *controlSessionDialer) exchangeLogin(conn net.Conn, loginMsg *msg.Login) (*loginExchangeResult, error) {
|
||||
rw := msg.NewV1ReadWriter(conn)
|
||||
var wireConn *wire.Conn
|
||||
var clientHello wire.ClientHello
|
||||
var clientHelloPayload []byte
|
||||
|
||||
if d.common.Transport.WireProtocol == wire.ProtocolV2 {
|
||||
if err := wire.WriteMagic(conn); err != nil {
|
||||
return nil, err
|
||||
}
|
||||
|
||||
wireConn = wire.NewConn(conn)
|
||||
rw = msg.NewV2ReadWriterWithConn(wireConn)
|
||||
var err error
|
||||
clientHello, err = wire.NewClientHello(wire.BootstrapInfo{
|
||||
Transport: d.common.Transport.Protocol,
|
||||
TLS: lo.FromPtr(d.common.Transport.TLS.Enable) || d.common.Transport.Protocol == "wss" || d.common.Transport.Protocol == "quic",
|
||||
TCPMux: lo.FromPtr(d.common.Transport.TCPMux),
|
||||
})
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
clientHelloFrame, err := wire.NewJSONFrame(wire.FrameTypeClientHello, clientHello)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
if err := wireConn.WriteFrame(clientHelloFrame); err != nil {
|
||||
return nil, err
|
||||
}
|
||||
clientHelloPayload = clientHelloFrame.Payload
|
||||
}
|
||||
if err := rw.WriteMsg(loginMsg); err != nil {
|
||||
return nil, err
|
||||
}
|
||||
|
||||
_ = conn.SetReadDeadline(time.Now().Add(10 * time.Second))
|
||||
defer func() {
|
||||
_ = conn.SetReadDeadline(time.Time{})
|
||||
}()
|
||||
|
||||
var cryptoContext *wire.CryptoContext
|
||||
if wireConn != nil {
|
||||
serverHelloFrame, err := wireConn.ReadFrame()
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
if serverHelloFrame.Type != wire.FrameTypeServerHello {
|
||||
return nil, fmt.Errorf("unexpected frame type %d, want %d", serverHelloFrame.Type, wire.FrameTypeServerHello)
|
||||
}
|
||||
var serverHello wire.ServerHello
|
||||
if err := wireConn.UnmarshalFrame(serverHelloFrame, &serverHello); err != nil {
|
||||
return nil, err
|
||||
}
|
||||
if serverHello.Error != "" {
|
||||
return nil, errors.New(serverHello.Error)
|
||||
}
|
||||
cryptoContext, err = wire.NewClientCryptoContext(clientHelloPayload, serverHelloFrame.Payload)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
}
|
||||
|
||||
var loginRespMsg msg.LoginResp
|
||||
if err := rw.ReadMsgInto(&loginRespMsg); err != nil {
|
||||
return nil, err
|
||||
}
|
||||
return &loginExchangeResult{
|
||||
resp: &loginRespMsg,
|
||||
crypto: cryptoContext,
|
||||
}, nil
|
||||
}
|
||||
|
||||
func (d *controlSessionDialer) newControlReadWriter(conn net.Conn, cryptoContext *wire.CryptoContext) (io.ReadWriter, error) {
|
||||
if d.common.Transport.WireProtocol == wire.ProtocolV2 {
|
||||
if cryptoContext == nil {
|
||||
return nil, errors.New("missing v2 crypto negotiation")
|
||||
}
|
||||
return netpkg.NewAEADCryptoReadWriter(
|
||||
conn,
|
||||
d.auth.EncryptionKey(),
|
||||
netpkg.AEADCryptoRoleClient,
|
||||
cryptoContext.Algorithm,
|
||||
cryptoContext.TranscriptHash,
|
||||
)
|
||||
}
|
||||
return netpkg.NewCryptoReadWriter(conn, d.auth.EncryptionKey())
|
||||
}
|
||||
297
client/control_session_test.go
Normal file
297
client/control_session_test.go
Normal file
@@ -0,0 +1,297 @@
|
||||
// Copyright 2026 The frp Authors
|
||||
//
|
||||
// Licensed under the Apache License, Version 2.0 (the "License");
|
||||
// you may not use this file except in compliance with the License.
|
||||
// You may obtain a copy of the License at
|
||||
//
|
||||
// http://www.apache.org/licenses/LICENSE-2.0
|
||||
//
|
||||
// Unless required by applicable law or agreed to in writing, software
|
||||
// distributed under the License is distributed on an "AS IS" BASIS,
|
||||
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
// See the License for the specific language governing permissions and
|
||||
// limitations under the License.
|
||||
|
||||
package client
|
||||
|
||||
import (
|
||||
"context"
|
||||
"fmt"
|
||||
"io"
|
||||
"net"
|
||||
"sync/atomic"
|
||||
"testing"
|
||||
"time"
|
||||
|
||||
"github.com/stretchr/testify/require"
|
||||
|
||||
"github.com/fatedier/frp/pkg/auth"
|
||||
v1 "github.com/fatedier/frp/pkg/config/v1"
|
||||
"github.com/fatedier/frp/pkg/msg"
|
||||
"github.com/fatedier/frp/pkg/proto/wire"
|
||||
netpkg "github.com/fatedier/frp/pkg/util/net"
|
||||
)
|
||||
|
||||
type testConnector struct {
|
||||
conn net.Conn
|
||||
closed atomic.Bool
|
||||
}
|
||||
|
||||
func (c *testConnector) Open() error {
|
||||
return nil
|
||||
}
|
||||
|
||||
func (c *testConnector) Connect() (net.Conn, error) {
|
||||
return c.conn, nil
|
||||
}
|
||||
|
||||
func (c *testConnector) Close() error {
|
||||
c.closed.Store(true)
|
||||
return nil
|
||||
}
|
||||
|
||||
type trackingConn struct {
|
||||
net.Conn
|
||||
closed atomic.Bool
|
||||
}
|
||||
|
||||
func (c *trackingConn) Close() error {
|
||||
c.closed.Store(true)
|
||||
return c.Conn.Close()
|
||||
}
|
||||
|
||||
func newTestControlSessionDialer(t *testing.T, protocol string, connector Connector, clientSpec *msg.ClientSpec) *controlSessionDialer {
|
||||
t.Helper()
|
||||
|
||||
authRuntime, err := auth.BuildClientAuth(&v1.AuthClientConfig{
|
||||
Method: v1.AuthMethodToken,
|
||||
Token: "token",
|
||||
})
|
||||
require.NoError(t, err)
|
||||
|
||||
return &controlSessionDialer{
|
||||
ctx: context.Background(),
|
||||
common: &v1.ClientCommonConfig{
|
||||
User: "test-user",
|
||||
Transport: v1.ClientTransportConfig{
|
||||
Protocol: "tcp",
|
||||
WireProtocol: protocol,
|
||||
},
|
||||
},
|
||||
auth: authRuntime,
|
||||
clientSpec: clientSpec,
|
||||
connectorCreator: func(context.Context, *v1.ClientCommonConfig) Connector {
|
||||
return connector
|
||||
},
|
||||
}
|
||||
}
|
||||
|
||||
func TestControlSessionDialerDialV1(t *testing.T) {
|
||||
clientRaw, serverRaw := net.Pipe()
|
||||
defer serverRaw.Close()
|
||||
|
||||
connector := &testConnector{conn: &trackingConn{Conn: clientRaw}}
|
||||
serverErrCh := make(chan error, 1)
|
||||
go func() {
|
||||
rw := msg.NewV1ReadWriter(serverRaw)
|
||||
var loginMsg msg.Login
|
||||
if err := rw.ReadMsgInto(&loginMsg); err != nil {
|
||||
serverErrCh <- err
|
||||
return
|
||||
}
|
||||
if loginMsg.RunID != "previous-run-id" {
|
||||
serverErrCh <- fmt.Errorf("unexpected previous run id: %s", loginMsg.RunID)
|
||||
return
|
||||
}
|
||||
if loginMsg.User != "test-user" {
|
||||
serverErrCh <- fmt.Errorf("unexpected user: %s", loginMsg.User)
|
||||
return
|
||||
}
|
||||
serverErrCh <- rw.WriteMsg(&msg.LoginResp{RunID: "run-v1"})
|
||||
}()
|
||||
|
||||
dialer := newTestControlSessionDialer(t, wire.ProtocolV1, connector, nil)
|
||||
sessionCtx, err := dialer.Dial("previous-run-id")
|
||||
require.NoError(t, err)
|
||||
defer sessionCtx.Conn.Close()
|
||||
defer sessionCtx.Connector.Close()
|
||||
|
||||
require.Equal(t, "run-v1", sessionCtx.RunID)
|
||||
require.NotNil(t, sessionCtx.Conn)
|
||||
require.NotNil(t, sessionCtx.Connector)
|
||||
require.False(t, connector.closed.Load())
|
||||
require.NoError(t, <-serverErrCh)
|
||||
}
|
||||
|
||||
func TestControlSessionDialerDialV2(t *testing.T) {
|
||||
clientRaw, serverRaw := net.Pipe()
|
||||
defer serverRaw.Close()
|
||||
|
||||
connector := &testConnector{conn: &trackingConn{Conn: clientRaw}}
|
||||
serverErrCh := make(chan error, 1)
|
||||
go func() {
|
||||
magic := make([]byte, len(wire.MagicV2))
|
||||
if _, err := io.ReadFull(serverRaw, magic); err != nil {
|
||||
serverErrCh <- err
|
||||
return
|
||||
}
|
||||
if string(magic) != wire.MagicV2 {
|
||||
serverErrCh <- fmt.Errorf("unexpected magic: %q", string(magic))
|
||||
return
|
||||
}
|
||||
|
||||
wireConn := wire.NewConn(serverRaw)
|
||||
clientHelloFrame, err := wireConn.ReadFrame()
|
||||
if err != nil {
|
||||
serverErrCh <- err
|
||||
return
|
||||
}
|
||||
if clientHelloFrame.Type != wire.FrameTypeClientHello {
|
||||
serverErrCh <- fmt.Errorf("unexpected frame type %d, want %d", clientHelloFrame.Type, wire.FrameTypeClientHello)
|
||||
return
|
||||
}
|
||||
var hello wire.ClientHello
|
||||
if err := wireConn.UnmarshalFrame(clientHelloFrame, &hello); err != nil {
|
||||
serverErrCh <- err
|
||||
return
|
||||
}
|
||||
if err := wire.ValidateClientHello(hello); err != nil {
|
||||
serverErrCh <- err
|
||||
return
|
||||
}
|
||||
|
||||
rw := msg.NewV2ReadWriterWithConn(wireConn)
|
||||
var loginMsg msg.Login
|
||||
if err := rw.ReadMsgInto(&loginMsg); err != nil {
|
||||
serverErrCh <- err
|
||||
return
|
||||
}
|
||||
if loginMsg.User != "test-user" {
|
||||
serverErrCh <- fmt.Errorf("unexpected user: %s", loginMsg.User)
|
||||
return
|
||||
}
|
||||
serverHello, err := wire.NewServerHello(hello)
|
||||
if err != nil {
|
||||
serverErrCh <- err
|
||||
return
|
||||
}
|
||||
serverHelloFrame, err := wire.NewJSONFrame(wire.FrameTypeServerHello, serverHello)
|
||||
if err != nil {
|
||||
serverErrCh <- err
|
||||
return
|
||||
}
|
||||
cryptoContext := wire.NewCryptoContext(
|
||||
serverHello.Selected.Crypto.Algorithm,
|
||||
clientHelloFrame.Payload,
|
||||
serverHelloFrame.Payload,
|
||||
)
|
||||
if err := wireConn.WriteFrame(serverHelloFrame); err != nil {
|
||||
serverErrCh <- err
|
||||
return
|
||||
}
|
||||
if err := rw.WriteMsg(&msg.LoginResp{RunID: "run-v2"}); err != nil {
|
||||
serverErrCh <- err
|
||||
return
|
||||
}
|
||||
|
||||
controlRW, err := netpkg.NewAEADCryptoReadWriter(
|
||||
serverRaw,
|
||||
[]byte("token"),
|
||||
netpkg.AEADCryptoRoleServer,
|
||||
cryptoContext.Algorithm,
|
||||
cryptoContext.TranscriptHash,
|
||||
)
|
||||
if err != nil {
|
||||
serverErrCh <- err
|
||||
return
|
||||
}
|
||||
controlMsgRW := msg.NewReadWriter(controlRW, wire.ProtocolV2)
|
||||
var ping msg.Ping
|
||||
if err := controlMsgRW.ReadMsgInto(&ping); err != nil {
|
||||
serverErrCh <- err
|
||||
return
|
||||
}
|
||||
if ping.PrivilegeKey != "v2-ping" || ping.Timestamp != 12345 {
|
||||
serverErrCh <- fmt.Errorf("unexpected ping: %+v", ping)
|
||||
return
|
||||
}
|
||||
serverErrCh <- nil
|
||||
}()
|
||||
|
||||
dialer := newTestControlSessionDialer(t, wire.ProtocolV2, connector, nil)
|
||||
sessionCtx, err := dialer.Dial("")
|
||||
require.NoError(t, err)
|
||||
defer sessionCtx.Conn.Close()
|
||||
defer sessionCtx.Connector.Close()
|
||||
|
||||
require.Equal(t, "run-v2", sessionCtx.RunID)
|
||||
require.NotNil(t, sessionCtx.Conn)
|
||||
require.NotNil(t, sessionCtx.Connector)
|
||||
require.False(t, connector.closed.Load())
|
||||
require.NoError(t, sessionCtx.Conn.WriteMsg(&msg.Ping{PrivilegeKey: "v2-ping", Timestamp: 12345}))
|
||||
require.NoError(t, <-serverErrCh)
|
||||
}
|
||||
|
||||
func TestControlSessionDialerDialLoginErrorClosesResources(t *testing.T) {
|
||||
clientRaw, serverRaw := net.Pipe()
|
||||
defer serverRaw.Close()
|
||||
|
||||
clientConn := &trackingConn{Conn: clientRaw}
|
||||
connector := &testConnector{conn: clientConn}
|
||||
serverErrCh := make(chan error, 1)
|
||||
go func() {
|
||||
rw := msg.NewV1ReadWriter(serverRaw)
|
||||
var loginMsg msg.Login
|
||||
if err := rw.ReadMsgInto(&loginMsg); err != nil {
|
||||
serverErrCh <- err
|
||||
return
|
||||
}
|
||||
serverErrCh <- rw.WriteMsg(&msg.LoginResp{Error: "login denied"})
|
||||
}()
|
||||
|
||||
dialer := newTestControlSessionDialer(t, wire.ProtocolV1, connector, nil)
|
||||
sessionCtx, err := dialer.Dial("")
|
||||
require.Nil(t, sessionCtx)
|
||||
require.ErrorContains(t, err, "login denied")
|
||||
require.True(t, clientConn.closed.Load())
|
||||
require.True(t, connector.closed.Load())
|
||||
require.NoError(t, <-serverErrCh)
|
||||
}
|
||||
|
||||
func TestControlSessionDialerDialSSHTunnelSkipsControlEncryption(t *testing.T) {
|
||||
clientRaw, serverRaw := net.Pipe()
|
||||
defer serverRaw.Close()
|
||||
|
||||
connector := &testConnector{conn: &trackingConn{Conn: clientRaw}}
|
||||
serverErrCh := make(chan error, 1)
|
||||
go func() {
|
||||
rw := msg.NewV1ReadWriter(serverRaw)
|
||||
var loginMsg msg.Login
|
||||
if err := rw.ReadMsgInto(&loginMsg); err != nil {
|
||||
serverErrCh <- err
|
||||
return
|
||||
}
|
||||
if err := rw.WriteMsg(&msg.LoginResp{RunID: "run-ssh-tunnel"}); err != nil {
|
||||
serverErrCh <- err
|
||||
return
|
||||
}
|
||||
|
||||
_ = serverRaw.SetReadDeadline(time.Now().Add(time.Second))
|
||||
var ping msg.Ping
|
||||
if err := rw.ReadMsgInto(&ping); err != nil {
|
||||
serverErrCh <- err
|
||||
return
|
||||
}
|
||||
serverErrCh <- nil
|
||||
}()
|
||||
|
||||
dialer := newTestControlSessionDialer(t, wire.ProtocolV1, connector, &msg.ClientSpec{Type: "ssh-tunnel"})
|
||||
sessionCtx, err := dialer.Dial("")
|
||||
require.NoError(t, err)
|
||||
defer sessionCtx.Conn.Close()
|
||||
defer sessionCtx.Connector.Close()
|
||||
|
||||
require.Equal(t, "run-ssh-tunnel", sessionCtx.RunID)
|
||||
require.NoError(t, sessionCtx.Conn.WriteMsg(&msg.Ping{}))
|
||||
require.NoError(t, <-serverErrCh)
|
||||
}
|
||||
@@ -21,7 +21,6 @@ import (
|
||||
"net"
|
||||
"net/http"
|
||||
"os"
|
||||
"runtime"
|
||||
"sync"
|
||||
"time"
|
||||
|
||||
@@ -38,7 +37,6 @@ import (
|
||||
httppkg "github.com/fatedier/frp/pkg/util/http"
|
||||
"github.com/fatedier/frp/pkg/util/log"
|
||||
netpkg "github.com/fatedier/frp/pkg/util/net"
|
||||
"github.com/fatedier/frp/pkg/util/version"
|
||||
"github.com/fatedier/frp/pkg/util/wait"
|
||||
"github.com/fatedier/frp/pkg/util/xlog"
|
||||
"github.com/fatedier/frp/pkg/vnet"
|
||||
@@ -303,80 +301,20 @@ func (svr *Service) keepControllerWorking() {
|
||||
), true, svr.ctx.Done())
|
||||
}
|
||||
|
||||
// login creates a connection to frps and registers it self as a client
|
||||
// conn: control connection
|
||||
// session: if it's not nil, using tcp mux
|
||||
func (svr *Service) login() (conn net.Conn, connector Connector, err error) {
|
||||
xl := xlog.FromContextSafe(svr.ctx)
|
||||
connector = svr.connectorCreator(svr.ctx, svr.common)
|
||||
if err = connector.Open(); err != nil {
|
||||
return nil, nil, err
|
||||
}
|
||||
|
||||
defer func() {
|
||||
if err != nil {
|
||||
connector.Close()
|
||||
}
|
||||
}()
|
||||
|
||||
conn, err = connector.Connect()
|
||||
if err != nil {
|
||||
return
|
||||
}
|
||||
|
||||
hostname, _ := os.Hostname()
|
||||
|
||||
loginMsg := &msg.Login{
|
||||
Arch: runtime.GOARCH,
|
||||
Os: runtime.GOOS,
|
||||
Hostname: hostname,
|
||||
PoolCount: svr.common.Transport.PoolCount,
|
||||
User: svr.common.User,
|
||||
ClientID: svr.common.ClientID,
|
||||
Version: version.Full(),
|
||||
Timestamp: time.Now().Unix(),
|
||||
RunID: svr.runID,
|
||||
Metas: svr.common.Metadatas,
|
||||
}
|
||||
if svr.clientSpec != nil {
|
||||
loginMsg.ClientSpec = *svr.clientSpec
|
||||
}
|
||||
|
||||
// Add auth
|
||||
if err = svr.auth.Setter.SetLogin(loginMsg); err != nil {
|
||||
return
|
||||
}
|
||||
|
||||
if err = msg.WriteMsg(conn, loginMsg); err != nil {
|
||||
return
|
||||
}
|
||||
|
||||
var loginRespMsg msg.LoginResp
|
||||
_ = conn.SetReadDeadline(time.Now().Add(10 * time.Second))
|
||||
if err = msg.ReadMsgInto(conn, &loginRespMsg); err != nil {
|
||||
return
|
||||
}
|
||||
_ = conn.SetReadDeadline(time.Time{})
|
||||
|
||||
if loginRespMsg.Error != "" {
|
||||
err = fmt.Errorf("%s", loginRespMsg.Error)
|
||||
xl.Errorf("%s", loginRespMsg.Error)
|
||||
return
|
||||
}
|
||||
|
||||
svr.runID = loginRespMsg.RunID
|
||||
xl.AddPrefix(xlog.LogPrefix{Name: "runID", Value: svr.runID})
|
||||
|
||||
xl.Infof("login to server success, get run id [%s]", loginRespMsg.RunID)
|
||||
return
|
||||
}
|
||||
|
||||
func (svr *Service) loopLoginUntilSuccess(maxInterval time.Duration, firstLoginExit bool) {
|
||||
xl := xlog.FromContextSafe(svr.ctx)
|
||||
|
||||
loginFunc := func() (bool, error) {
|
||||
xl.Infof("try to connect to server...")
|
||||
conn, connector, err := svr.login()
|
||||
dialer := &controlSessionDialer{
|
||||
ctx: svr.ctx,
|
||||
common: svr.common,
|
||||
auth: svr.auth,
|
||||
clientSpec: svr.clientSpec,
|
||||
vnetController: svr.vnetController,
|
||||
connectorCreator: svr.connectorCreator,
|
||||
}
|
||||
sessionCtx, err := dialer.Dial(svr.runID)
|
||||
if err != nil {
|
||||
xl.Warnf("connect to server error: %v", err)
|
||||
if firstLoginExit {
|
||||
@@ -385,25 +323,19 @@ func (svr *Service) loopLoginUntilSuccess(maxInterval time.Duration, firstLoginE
|
||||
return false, err
|
||||
}
|
||||
|
||||
svr.runID = sessionCtx.RunID
|
||||
xl.AddPrefix(xlog.LogPrefix{Name: "runID", Value: svr.runID})
|
||||
xl.Infof("login to server success, get run id [%s]", svr.runID)
|
||||
|
||||
svr.cfgMu.RLock()
|
||||
proxyCfgs := svr.proxyCfgs
|
||||
visitorCfgs := svr.visitorCfgs
|
||||
svr.cfgMu.RUnlock()
|
||||
|
||||
connEncrypted := svr.clientSpec == nil || svr.clientSpec.Type != "ssh-tunnel"
|
||||
|
||||
sessionCtx := &SessionContext{
|
||||
Common: svr.common,
|
||||
RunID: svr.runID,
|
||||
Conn: conn,
|
||||
ConnEncrypted: connEncrypted,
|
||||
Auth: svr.auth,
|
||||
Connector: connector,
|
||||
VnetController: svr.vnetController,
|
||||
}
|
||||
ctl, err := NewControl(svr.ctx, sessionCtx)
|
||||
if err != nil {
|
||||
conn.Close()
|
||||
sessionCtx.Conn.Close()
|
||||
sessionCtx.Connector.Close()
|
||||
xl.Errorf("new control error: %v", err)
|
||||
return false, err
|
||||
}
|
||||
|
||||
@@ -38,7 +38,7 @@ import (
|
||||
// Helper wraps some functions for visitor to use.
|
||||
type Helper interface {
|
||||
// ConnectServer directly connects to the frp server.
|
||||
ConnectServer() (net.Conn, error)
|
||||
ConnectServer() (*msg.Conn, error)
|
||||
// TransferConn transfers the connection to another visitor.
|
||||
TransferConn(string, net.Conn) error
|
||||
// MsgTransporter returns the message transporter that is used to send and receive messages
|
||||
@@ -167,15 +167,15 @@ func (v *BaseVisitor) dialRawVisitorConn(cfg *v1.VisitorBaseConfig) (net.Conn, e
|
||||
UseEncryption: cfg.Transport.UseEncryption,
|
||||
UseCompression: cfg.Transport.UseCompression,
|
||||
}
|
||||
err = msg.WriteMsg(visitorConn, newVisitorConnMsg)
|
||||
err = visitorConn.WriteMsg(newVisitorConnMsg)
|
||||
if err != nil {
|
||||
visitorConn.Close()
|
||||
return nil, fmt.Errorf("send newVisitorConnMsg to server error: %v", err)
|
||||
}
|
||||
|
||||
var newVisitorConnRespMsg msg.NewVisitorConnResp
|
||||
_ = visitorConn.SetReadDeadline(time.Now().Add(10 * time.Second))
|
||||
err = msg.ReadMsgInto(visitorConn, &newVisitorConnRespMsg)
|
||||
var newVisitorConnRespMsg msg.NewVisitorConnResp
|
||||
err = visitorConn.ReadMsgInto(&newVisitorConnRespMsg)
|
||||
if err != nil {
|
||||
visitorConn.Close()
|
||||
return nil, fmt.Errorf("read newVisitorConnRespMsg error: %v", err)
|
||||
|
||||
@@ -25,6 +25,7 @@ import (
|
||||
"github.com/samber/lo"
|
||||
|
||||
v1 "github.com/fatedier/frp/pkg/config/v1"
|
||||
"github.com/fatedier/frp/pkg/msg"
|
||||
"github.com/fatedier/frp/pkg/transport"
|
||||
"github.com/fatedier/frp/pkg/util/xlog"
|
||||
"github.com/fatedier/frp/pkg/vnet"
|
||||
@@ -49,7 +50,7 @@ func NewManager(
|
||||
ctx context.Context,
|
||||
runID string,
|
||||
clientCfg *v1.ClientCommonConfig,
|
||||
connectServer func() (net.Conn, error),
|
||||
connectServer func() (*msg.Conn, error),
|
||||
msgTransporter transport.MessageTransporter,
|
||||
vnetController *vnet.Controller,
|
||||
) *Manager {
|
||||
@@ -199,14 +200,14 @@ func (vm *Manager) GetVisitorCfg(name string) (v1.VisitorConfigurer, bool) {
|
||||
}
|
||||
|
||||
type visitorHelperImpl struct {
|
||||
connectServerFn func() (net.Conn, error)
|
||||
connectServerFn func() (*msg.Conn, error)
|
||||
msgTransporter transport.MessageTransporter
|
||||
vnetController *vnet.Controller
|
||||
transferConnFn func(name string, conn net.Conn) error
|
||||
runID string
|
||||
}
|
||||
|
||||
func (v *visitorHelperImpl) ConnectServer() (net.Conn, error) {
|
||||
func (v *visitorHelperImpl) ConnectServer() (*msg.Conn, error) {
|
||||
return v.connectServerFn()
|
||||
}
|
||||
|
||||
|
||||
@@ -103,6 +103,10 @@ transport.poolCount = 5
|
||||
# supports tcp, kcp, quic, websocket and wss now, default is tcp
|
||||
transport.protocol = "tcp"
|
||||
|
||||
# FRP wire protocol used inside the selected transport.
|
||||
# supports v1 and v2, default is v1. v2 requires frps support and must be enabled explicitly.
|
||||
# transport.wireProtocol = "v1"
|
||||
|
||||
# set client binding ip when connect server, default is empty.
|
||||
# only when protocol = tcp or websocket, the value will be used.
|
||||
transport.connectServerLocalIP = "0.0.0.0"
|
||||
|
||||
@@ -33,7 +33,51 @@ git commit -m "bump version to vX.Y.Z"
|
||||
git push origin dev
|
||||
```
|
||||
|
||||
## 3. Merge dev → master
|
||||
## 3. Pre-release Validation
|
||||
|
||||
Run the standard e2e suite locally:
|
||||
|
||||
```bash
|
||||
make e2e
|
||||
```
|
||||
|
||||
For releases that touch compatibility-sensitive areas such as login, control
|
||||
connections, work connections, visitors, transport, or wire protocol handling,
|
||||
also run the manual compatibility e2e suite:
|
||||
|
||||
```bash
|
||||
make e2e-compatibility
|
||||
make e2e-compatibility-floor
|
||||
```
|
||||
|
||||
`make e2e-compatibility` builds the current `frps` and `frpc`, resolves the
|
||||
recent stable release baselines from GitHub, downloads or reuses their binaries,
|
||||
and tests current binaries against those historical releases. The default number
|
||||
of recent baselines is controlled by `FRP_COMPAT_BASELINE_COUNT` in the
|
||||
`Makefile`.
|
||||
|
||||
Downloaded release binaries are cached under:
|
||||
|
||||
```text
|
||||
.cache/e2e-compat/<version>/<os>_<arch>/
|
||||
```
|
||||
|
||||
For a release validation run that must be exactly reproducible, pass an explicit
|
||||
baseline matrix instead of using the floating recent-release list:
|
||||
|
||||
```bash
|
||||
FRP_COMPAT_BASELINE_VERSIONS="0.X.0 0.Y.0" make e2e-compatibility
|
||||
```
|
||||
|
||||
Use `make e2e-compatibility-smoke` for a quick single-baseline check while
|
||||
iterating locally. If GitHub release metadata requests are rate-limited, set
|
||||
`GITHUB_TOKEN` or use `FRP_COMPAT_BASELINE_VERSIONS`.
|
||||
|
||||
The compatibility floor is a support-policy decision, not a value that should
|
||||
change every release. Update `FRP_COMPAT_FLOOR_VERSION` only when the declared
|
||||
compatibility window changes.
|
||||
|
||||
## 4. Merge dev → master
|
||||
|
||||
Create a PR from `dev` to `master`:
|
||||
|
||||
@@ -43,7 +87,7 @@ gh pr create --base master --head dev --title "bump version"
|
||||
|
||||
Wait for CI to pass, then merge using **merge commit** (not squash).
|
||||
|
||||
## 4. Tag the Release
|
||||
## 5. Tag the Release
|
||||
|
||||
```bash
|
||||
git checkout master
|
||||
@@ -52,7 +96,7 @@ git tag -a vX.Y.Z -m "bump version"
|
||||
git push origin vX.Y.Z
|
||||
```
|
||||
|
||||
## 5. Trigger GoReleaser
|
||||
## 6. Trigger GoReleaser
|
||||
|
||||
Manually trigger the `goreleaser` workflow in GitHub Actions:
|
||||
|
||||
|
||||
38
doc/deprecations.md
Normal file
38
doc/deprecations.md
Normal file
@@ -0,0 +1,38 @@
|
||||
# Deprecations
|
||||
|
||||
This document tracks deprecated features and APIs that are still shipped but scheduled for removal. Maintainers should review this list before each release to decide whether any items are due for removal.
|
||||
|
||||
For the version compatibility policy that bounds these support windows, see the latest `Release.md`.
|
||||
|
||||
## Active
|
||||
|
||||
### Wire protocol v1
|
||||
|
||||
- **Deprecated since:** v0.70.0 (planned, when v2 becomes the default).
|
||||
- **Removal target:** v0.78.0 or later. v0.69.0 (the last release where v1 is the default) is supported until v0.78.0 is released, so v0.77.0 is the last release that must keep v1 support.
|
||||
- **Replacement:** wire protocol v2 (`transport.wireProtocol = "v2"` in frpc).
|
||||
- **Code references:** v1 message types and codec under `pkg/msg/` and the protocol negotiation path in `client/` and `server/`.
|
||||
- **Notes:** Removing v1 will also drop compatibility with any frpc/frps that does not negotiate v2.
|
||||
|
||||
### INI configuration format
|
||||
|
||||
- **Deprecated since:** predates this document; startup warning has been in place for several releases.
|
||||
- **Removal target:** TBD.
|
||||
- **Replacement:** YAML / JSON / TOML.
|
||||
- **Code references:**
|
||||
- `cmd/frpc/sub/root.go` — frpc startup warning.
|
||||
- `cmd/frps/root.go` — frps startup warning.
|
||||
- `pkg/config/legacy/` — legacy INI parser; remove together with the warnings.
|
||||
|
||||
### Visitor connections without `runID`
|
||||
|
||||
- **Deprecated since:** v0.50.0 (when `runID` was introduced).
|
||||
- **Removal target:** TBD.
|
||||
- **Replacement:** require `runID` on every visitor connection.
|
||||
- **Code references:**
|
||||
- `server/service.go` — `RegisterVisitorConn` still accepts empty `runID` for backward compatibility.
|
||||
- **Notes:** Removal will break frpc clients released before v0.50.0. Schedule for a release where dropping pre-v0.50.0 frpc is acceptable.
|
||||
|
||||
## Removed
|
||||
|
||||
_None yet._
|
||||
6
go.mod
6
go.mod
@@ -5,7 +5,7 @@ go 1.25.0
|
||||
require (
|
||||
github.com/armon/go-socks5 v0.0.0-20160902184237-e75332964ef5
|
||||
github.com/coreos/go-oidc/v3 v3.14.1
|
||||
github.com/fatedier/golib v0.6.0
|
||||
github.com/fatedier/golib v0.7.0
|
||||
github.com/google/uuid v1.6.0
|
||||
github.com/gorilla/mux v1.8.1
|
||||
github.com/gorilla/websocket v1.5.0
|
||||
@@ -30,11 +30,13 @@ require (
|
||||
golang.org/x/net v0.52.0
|
||||
golang.org/x/oauth2 v0.28.0
|
||||
golang.org/x/sync v0.20.0
|
||||
golang.org/x/sys v0.42.0
|
||||
golang.org/x/time v0.10.0
|
||||
golang.zx2c4.com/wireguard v0.0.0-20231211153847-12269c276173
|
||||
gopkg.in/ini.v1 v1.67.0
|
||||
k8s.io/apimachinery v0.28.8
|
||||
k8s.io/client-go v0.28.8
|
||||
k8s.io/utils v0.0.0-20230406110748-d93618cff8a2
|
||||
)
|
||||
|
||||
require (
|
||||
@@ -68,14 +70,12 @@ require (
|
||||
github.com/wlynxg/anet v0.0.5 // indirect
|
||||
go.uber.org/automaxprocs v1.6.0 // indirect
|
||||
golang.org/x/mod v0.33.0 // indirect
|
||||
golang.org/x/sys v0.42.0 // indirect
|
||||
golang.org/x/text v0.35.0 // indirect
|
||||
golang.org/x/tools v0.42.0 // indirect
|
||||
golang.zx2c4.com/wintun v0.0.0-20230126152724-0fa3db229ce2 // indirect
|
||||
google.golang.org/protobuf v1.36.5 // indirect
|
||||
gopkg.in/yaml.v2 v2.4.0 // indirect
|
||||
gopkg.in/yaml.v3 v3.0.1 // indirect
|
||||
k8s.io/utils v0.0.0-20230406110748-d93618cff8a2 // indirect
|
||||
sigs.k8s.io/json v0.0.0-20221116044647-bc3834ca7abd // indirect
|
||||
sigs.k8s.io/yaml v1.3.0 // indirect
|
||||
)
|
||||
|
||||
4
go.sum
4
go.sum
@@ -20,8 +20,8 @@ github.com/davecgh/go-spew v1.1.1/go.mod h1:J7Y8YcW2NihsgmVo/mv3lAwl/skON4iLHjSs
|
||||
github.com/envoyproxy/go-control-plane v0.9.0/go.mod h1:YTl/9mNaCwkRvm6d1a2C3ymFceY/DCBVvsKhRF0iEA4=
|
||||
github.com/envoyproxy/go-control-plane v0.9.4/go.mod h1:6rpuAdCZL397s3pYoYcLgu1mIlRU8Am5FuJP05cCM98=
|
||||
github.com/envoyproxy/protoc-gen-validate v0.1.0/go.mod h1:iSmxcyjqTsJpI2R4NaDN7+kN2VEUnK/pcBlmesArF7c=
|
||||
github.com/fatedier/golib v0.6.0 h1:/mgBZZbkbMhIEZoXf7nV8knpUDzas/b+2ruYKxx1lww=
|
||||
github.com/fatedier/golib v0.6.0/go.mod h1:ArUGvPg2cOw/py2RAuBt46nNZH2VQ5Z70p109MAZpJw=
|
||||
github.com/fatedier/golib v0.7.0 h1:tMDF9ObcwVt59VUHroJOzHQjVFPLymZVMpGm9WAVwhY=
|
||||
github.com/fatedier/golib v0.7.0/go.mod h1:ArUGvPg2cOw/py2RAuBt46nNZH2VQ5Z70p109MAZpJw=
|
||||
github.com/fatedier/yamux v0.0.0-20250825093530-d0154be01cd6 h1:u92UUy6FURPmNsMBUuongRWC0rBqN6gd01Dzu+D21NE=
|
||||
github.com/fatedier/yamux v0.0.0-20250825093530-d0154be01cd6/go.mod h1:c5/tk6G0dSpXGzJN7Wk1OEie8grdSJAmeawId9Zvd34=
|
||||
github.com/go-jose/go-jose/v4 v4.0.5 h1:M6T8+mKZl/+fNNuFHvGIzDz7BTLQPIounk/b9dw3AaE=
|
||||
|
||||
162
hack/run-e2e-compatibility.sh
Executable file
162
hack/run-e2e-compatibility.sh
Executable file
@@ -0,0 +1,162 @@
|
||||
#!/bin/sh
|
||||
|
||||
set -eu
|
||||
|
||||
SCRIPT=$(readlink -f "$0")
|
||||
ROOT=$(unset CDPATH && cd "$(dirname "$SCRIPT")/.." && pwd)
|
||||
|
||||
if ! command -v ginkgo >/dev/null 2>&1; then
|
||||
echo "ginkgo not found, try to install..."
|
||||
go install github.com/onsi/ginkgo/v2/ginkgo@v2.23.4
|
||||
fi
|
||||
|
||||
debug=false
|
||||
if [ "x${DEBUG:-}" = "xtrue" ]; then
|
||||
debug=true
|
||||
fi
|
||||
logLevel=debug
|
||||
if [ "${LOG_LEVEL:-}" ]; then
|
||||
logLevel="${LOG_LEVEL}"
|
||||
fi
|
||||
|
||||
currentFrpsPath=${CURRENT_FRPS_PATH:-${ROOT}/bin/frps}
|
||||
currentFrpcPath=${CURRENT_FRPC_PATH:-${ROOT}/bin/frpc}
|
||||
baselineCount=${FRP_COMPAT_BASELINE_COUNT:-8}
|
||||
targetOS=${TARGET_OS:-$(go env GOOS)}
|
||||
targetArch=${TARGET_ARCH:-$(go env GOARCH)}
|
||||
targetPlatform="${targetOS}_${targetArch}"
|
||||
cacheRoot=${FRP_COMPAT_CACHE_DIR:-${ROOT}/.cache/e2e-compat}
|
||||
|
||||
check_file() {
|
||||
if [ ! -f "$2" ]; then
|
||||
echo "$1 not found: $2"
|
||||
exit 1
|
||||
fi
|
||||
}
|
||||
|
||||
check_file "current frps" "${currentFrpsPath}"
|
||||
check_file "current frpc" "${currentFrpcPath}"
|
||||
|
||||
run_current_current=true
|
||||
|
||||
run_compatibility() {
|
||||
baselineVersion=$1
|
||||
baselineFrpsPath=$2
|
||||
baselineFrpcPath=$3
|
||||
|
||||
check_file "baseline frps" "${baselineFrpsPath}"
|
||||
check_file "baseline frpc" "${baselineFrpcPath}"
|
||||
|
||||
echo "Running compatibility e2e with baseline ${baselineVersion}"
|
||||
ginkgo -nodes=1 --poll-progress-after=60s "${ROOT}/test/e2e/compatibility" -- \
|
||||
-current-frps-path="${currentFrpsPath}" \
|
||||
-current-frpc-path="${currentFrpcPath}" \
|
||||
-baseline-frps-path="${baselineFrpsPath}" \
|
||||
-baseline-frpc-path="${baselineFrpcPath}" \
|
||||
-baseline-version="${baselineVersion}" \
|
||||
-run-current-current="${run_current_current}" \
|
||||
-log-level="${logLevel}" \
|
||||
-debug="${debug}"
|
||||
run_current_current=false
|
||||
}
|
||||
|
||||
github_api_curl() {
|
||||
if [ "${GITHUB_TOKEN:-}" ]; then
|
||||
curl -fsSL \
|
||||
-H "Accept: application/vnd.github+json" \
|
||||
-H "Authorization: Bearer ${GITHUB_TOKEN}" \
|
||||
"$1"
|
||||
else
|
||||
curl -fsSL "$1"
|
||||
fi
|
||||
}
|
||||
|
||||
resolve_versions() {
|
||||
if [ "${FRP_COMPAT_BASELINE_VERSIONS:-}" ]; then
|
||||
printf "%s\n" "${FRP_COMPAT_BASELINE_VERSIONS}"
|
||||
return
|
||||
fi
|
||||
|
||||
case "${baselineCount}" in
|
||||
'' | *[!0-9]*)
|
||||
echo "FRP_COMPAT_BASELINE_COUNT must be a positive integer: ${baselineCount}" >&2
|
||||
exit 1
|
||||
;;
|
||||
esac
|
||||
if [ "${baselineCount}" -eq 0 ]; then
|
||||
echo "FRP_COMPAT_BASELINE_COUNT must be greater than 0" >&2
|
||||
exit 1
|
||||
fi
|
||||
if [ "${baselineCount}" -gt 100 ]; then
|
||||
echo "FRP_COMPAT_BASELINE_COUNT must be less than or equal to 100" >&2
|
||||
exit 1
|
||||
fi
|
||||
|
||||
releaseURL="https://api.github.com/repos/fatedier/frp/releases?per_page=100"
|
||||
resolvedVersions=""
|
||||
if releases=$(github_api_curl "${releaseURL}" 2>/dev/null); then
|
||||
resolvedVersions=$(printf "%s\n" "${releases}" |
|
||||
sed -n 's/.*"tag_name":[[:space:]]*"v\([0-9][0-9]*\.[0-9][0-9]*\.[0-9][0-9]*\)".*/\1/p' |
|
||||
awk '!seen[$0]++' |
|
||||
head -n "${baselineCount}" |
|
||||
tr '\n' ' ' |
|
||||
sed 's/[[:space:]]*$//')
|
||||
else
|
||||
echo "Failed to fetch release metadata from GitHub API, falling back to GitHub releases page." >&2
|
||||
fi
|
||||
|
||||
if [ -z "${resolvedVersions}" ]; then
|
||||
releasesPageURL="https://github.com/fatedier/frp/releases"
|
||||
if ! releases=$(curl -fsSL "${releasesPageURL}"); then
|
||||
echo "Failed to fetch release metadata from GitHub: ${releasesPageURL}" >&2
|
||||
echo "Set FRP_COMPAT_BASELINE_VERSIONS to run with explicit baseline versions." >&2
|
||||
exit 1
|
||||
fi
|
||||
resolvedVersions=$(printf "%s\n" "${releases}" |
|
||||
grep -o 'releases/tag/v[0-9][0-9]*\.[0-9][0-9]*\.[0-9][0-9]*"' |
|
||||
sed 's#.*/v##; s/"$//' |
|
||||
awk '!seen[$0]++' |
|
||||
head -n "${baselineCount}" |
|
||||
tr '\n' ' ' |
|
||||
sed 's/[[:space:]]*$//')
|
||||
fi
|
||||
|
||||
set -- ${resolvedVersions}
|
||||
if [ "$#" -lt "${baselineCount}" ]; then
|
||||
echo "Only resolved $# stable release versions from GitHub, expected ${baselineCount}." >&2
|
||||
echo "Set FRP_COMPAT_BASELINE_VERSIONS to run with explicit baseline versions." >&2
|
||||
exit 1
|
||||
fi
|
||||
|
||||
printf "%s\n" "${resolvedVersions}"
|
||||
}
|
||||
|
||||
if [ "${BASELINE_FRPS_PATH:-}" ] || [ "${BASELINE_FRPC_PATH:-}" ]; then
|
||||
if [ -z "${BASELINE_FRPS_PATH:-}" ] || [ -z "${BASELINE_FRPC_PATH:-}" ]; then
|
||||
echo "BASELINE_FRPS_PATH and BASELINE_FRPC_PATH must be set together"
|
||||
exit 1
|
||||
fi
|
||||
run_compatibility "${FRP_COMPAT_BASELINE_VERSION:-custom}" "${BASELINE_FRPS_PATH}" "${BASELINE_FRPC_PATH}"
|
||||
exit 0
|
||||
fi
|
||||
|
||||
versions=$(resolve_versions)
|
||||
echo "Compatibility baseline versions: ${versions}"
|
||||
|
||||
mkdir -p "${cacheRoot}"
|
||||
for version in ${versions}; do
|
||||
baselineDir="${cacheRoot}/${version}/${targetPlatform}"
|
||||
if [ ! -f "${baselineDir}/frps" ] || [ ! -f "${baselineDir}/frpc" ]; then
|
||||
tmpDir="${cacheRoot}/.download-${version}-${targetPlatform}"
|
||||
rm -rf "${tmpDir}"
|
||||
(
|
||||
cd "${cacheRoot}"
|
||||
FRP_VERSION="${version}" TARGET_DIRNAME="$(basename "${tmpDir}")" "${ROOT}/hack/download.sh"
|
||||
)
|
||||
mkdir -p "$(dirname "${baselineDir}")"
|
||||
rm -rf "${baselineDir}"
|
||||
mv "${tmpDir}" "${baselineDir}"
|
||||
fi
|
||||
|
||||
run_compatibility "${version}" "${baselineDir}/frps" "${baselineDir}/frpc"
|
||||
done
|
||||
@@ -14,11 +14,7 @@
|
||||
|
||||
package source
|
||||
|
||||
import (
|
||||
"fmt"
|
||||
|
||||
v1 "github.com/fatedier/frp/pkg/config/v1"
|
||||
)
|
||||
import v1 "github.com/fatedier/frp/pkg/config/v1"
|
||||
|
||||
// ConfigSource implements Source for in-memory configuration.
|
||||
// All operations are thread-safe.
|
||||
@@ -39,23 +35,17 @@ func (s *ConfigSource) ReplaceAll(proxies []v1.ProxyConfigurer, visitors []v1.Vi
|
||||
|
||||
nextProxies := make(map[string]v1.ProxyConfigurer, len(proxies))
|
||||
for _, p := range proxies {
|
||||
if p == nil {
|
||||
return fmt.Errorf("proxy cannot be nil")
|
||||
}
|
||||
name := p.GetBaseConfig().Name
|
||||
if name == "" {
|
||||
return fmt.Errorf("proxy name cannot be empty")
|
||||
name, err := validateProxyName(p)
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
nextProxies[name] = p
|
||||
}
|
||||
nextVisitors := make(map[string]v1.VisitorConfigurer, len(visitors))
|
||||
for _, v := range visitors {
|
||||
if v == nil {
|
||||
return fmt.Errorf("visitor cannot be nil")
|
||||
}
|
||||
name := v.GetBaseConfig().Name
|
||||
if name == "" {
|
||||
return fmt.Errorf("visitor name cannot be empty")
|
||||
name, err := validateVisitorName(v)
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
nextVisitors[name] = v
|
||||
}
|
||||
|
||||
@@ -43,6 +43,11 @@ var (
|
||||
ErrNotFound = errors.New("not found")
|
||||
)
|
||||
|
||||
const (
|
||||
storeKindProxy = "proxy"
|
||||
storeKindVisitor = "visitor"
|
||||
)
|
||||
|
||||
func NewStoreSource(cfg StoreSourceConfig) (*StoreSource, error) {
|
||||
if cfg.Path == "" {
|
||||
return nil, fmt.Errorf("path is required")
|
||||
@@ -172,79 +177,111 @@ func (s *StoreSource) saveToFileUnlocked() error {
|
||||
return nil
|
||||
}
|
||||
|
||||
func (s *StoreSource) AddProxy(proxy v1.ProxyConfigurer) error {
|
||||
if proxy == nil {
|
||||
return fmt.Errorf("proxy cannot be nil")
|
||||
func (s *StoreSource) persistOrRollbackUnlocked(rollback func()) error {
|
||||
if err := s.saveToFileUnlocked(); err != nil {
|
||||
rollback()
|
||||
return fmt.Errorf("failed to persist: %w", err)
|
||||
}
|
||||
return nil
|
||||
}
|
||||
|
||||
// Store map selectors return the target map for generic helpers.
|
||||
func proxyStoreEntries(s *StoreSource) map[string]v1.ProxyConfigurer {
|
||||
return s.proxies
|
||||
}
|
||||
|
||||
func visitorStoreEntries(s *StoreSource) map[string]v1.VisitorConfigurer {
|
||||
return s.visitors
|
||||
}
|
||||
|
||||
// Store entry helpers share mutation, persistence, and rollback for proxy and visitor maps.
|
||||
// T is intentionally limited by callers to v1.ProxyConfigurer or v1.VisitorConfigurer.
|
||||
func addStoreEntry[T any](
|
||||
s *StoreSource,
|
||||
entriesFn func(*StoreSource) map[string]T,
|
||||
kind string,
|
||||
name string,
|
||||
value T,
|
||||
) error {
|
||||
s.mu.Lock()
|
||||
defer s.mu.Unlock()
|
||||
|
||||
entries := entriesFn(s)
|
||||
if _, exists := entries[name]; exists {
|
||||
return fmt.Errorf("%w: %s %q", ErrAlreadyExists, kind, name)
|
||||
}
|
||||
|
||||
name := proxy.GetBaseConfig().Name
|
||||
entries[name] = value
|
||||
return s.persistOrRollbackUnlocked(func() {
|
||||
delete(entries, name)
|
||||
})
|
||||
}
|
||||
|
||||
func updateStoreEntry[T any](
|
||||
s *StoreSource,
|
||||
entriesFn func(*StoreSource) map[string]T,
|
||||
kind string,
|
||||
name string,
|
||||
value T,
|
||||
) error {
|
||||
s.mu.Lock()
|
||||
defer s.mu.Unlock()
|
||||
|
||||
entries := entriesFn(s)
|
||||
old, exists := entries[name]
|
||||
if !exists {
|
||||
return fmt.Errorf("%w: %s %q", ErrNotFound, kind, name)
|
||||
}
|
||||
|
||||
entries[name] = value
|
||||
return s.persistOrRollbackUnlocked(func() {
|
||||
entries[name] = old
|
||||
})
|
||||
}
|
||||
|
||||
func removeStoreEntry[T any](
|
||||
s *StoreSource,
|
||||
entriesFn func(*StoreSource) map[string]T,
|
||||
kind string,
|
||||
name string,
|
||||
) error {
|
||||
if name == "" {
|
||||
return fmt.Errorf("proxy name cannot be empty")
|
||||
return fmt.Errorf("%s name cannot be empty", kind)
|
||||
}
|
||||
|
||||
s.mu.Lock()
|
||||
defer s.mu.Unlock()
|
||||
|
||||
if _, exists := s.proxies[name]; exists {
|
||||
return fmt.Errorf("%w: proxy %q", ErrAlreadyExists, name)
|
||||
entries := entriesFn(s)
|
||||
old, exists := entries[name]
|
||||
if !exists {
|
||||
return fmt.Errorf("%w: %s %q", ErrNotFound, kind, name)
|
||||
}
|
||||
|
||||
s.proxies[name] = proxy
|
||||
delete(entries, name)
|
||||
return s.persistOrRollbackUnlocked(func() {
|
||||
entries[name] = old
|
||||
})
|
||||
}
|
||||
|
||||
if err := s.saveToFileUnlocked(); err != nil {
|
||||
delete(s.proxies, name)
|
||||
return fmt.Errorf("failed to persist: %w", err)
|
||||
func (s *StoreSource) AddProxy(proxy v1.ProxyConfigurer) error {
|
||||
name, err := validateProxyName(proxy)
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
return nil
|
||||
return addStoreEntry(s, proxyStoreEntries, storeKindProxy, name, proxy)
|
||||
}
|
||||
|
||||
func (s *StoreSource) UpdateProxy(proxy v1.ProxyConfigurer) error {
|
||||
if proxy == nil {
|
||||
return fmt.Errorf("proxy cannot be nil")
|
||||
name, err := validateProxyName(proxy)
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
|
||||
name := proxy.GetBaseConfig().Name
|
||||
if name == "" {
|
||||
return fmt.Errorf("proxy name cannot be empty")
|
||||
}
|
||||
|
||||
s.mu.Lock()
|
||||
defer s.mu.Unlock()
|
||||
|
||||
oldProxy, exists := s.proxies[name]
|
||||
if !exists {
|
||||
return fmt.Errorf("%w: proxy %q", ErrNotFound, name)
|
||||
}
|
||||
|
||||
s.proxies[name] = proxy
|
||||
|
||||
if err := s.saveToFileUnlocked(); err != nil {
|
||||
s.proxies[name] = oldProxy
|
||||
return fmt.Errorf("failed to persist: %w", err)
|
||||
}
|
||||
return nil
|
||||
return updateStoreEntry(s, proxyStoreEntries, storeKindProxy, name, proxy)
|
||||
}
|
||||
|
||||
func (s *StoreSource) RemoveProxy(name string) error {
|
||||
if name == "" {
|
||||
return fmt.Errorf("proxy name cannot be empty")
|
||||
}
|
||||
|
||||
s.mu.Lock()
|
||||
defer s.mu.Unlock()
|
||||
|
||||
oldProxy, exists := s.proxies[name]
|
||||
if !exists {
|
||||
return fmt.Errorf("%w: proxy %q", ErrNotFound, name)
|
||||
}
|
||||
|
||||
delete(s.proxies, name)
|
||||
|
||||
if err := s.saveToFileUnlocked(); err != nil {
|
||||
s.proxies[name] = oldProxy
|
||||
return fmt.Errorf("failed to persist: %w", err)
|
||||
}
|
||||
return nil
|
||||
return removeStoreEntry(s, proxyStoreEntries, storeKindProxy, name)
|
||||
}
|
||||
|
||||
func (s *StoreSource) GetProxy(name string) v1.ProxyConfigurer {
|
||||
@@ -259,78 +296,23 @@ func (s *StoreSource) GetProxy(name string) v1.ProxyConfigurer {
|
||||
}
|
||||
|
||||
func (s *StoreSource) AddVisitor(visitor v1.VisitorConfigurer) error {
|
||||
if visitor == nil {
|
||||
return fmt.Errorf("visitor cannot be nil")
|
||||
name, err := validateVisitorName(visitor)
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
|
||||
name := visitor.GetBaseConfig().Name
|
||||
if name == "" {
|
||||
return fmt.Errorf("visitor name cannot be empty")
|
||||
}
|
||||
|
||||
s.mu.Lock()
|
||||
defer s.mu.Unlock()
|
||||
|
||||
if _, exists := s.visitors[name]; exists {
|
||||
return fmt.Errorf("%w: visitor %q", ErrAlreadyExists, name)
|
||||
}
|
||||
|
||||
s.visitors[name] = visitor
|
||||
|
||||
if err := s.saveToFileUnlocked(); err != nil {
|
||||
delete(s.visitors, name)
|
||||
return fmt.Errorf("failed to persist: %w", err)
|
||||
}
|
||||
return nil
|
||||
return addStoreEntry(s, visitorStoreEntries, storeKindVisitor, name, visitor)
|
||||
}
|
||||
|
||||
func (s *StoreSource) UpdateVisitor(visitor v1.VisitorConfigurer) error {
|
||||
if visitor == nil {
|
||||
return fmt.Errorf("visitor cannot be nil")
|
||||
name, err := validateVisitorName(visitor)
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
|
||||
name := visitor.GetBaseConfig().Name
|
||||
if name == "" {
|
||||
return fmt.Errorf("visitor name cannot be empty")
|
||||
}
|
||||
|
||||
s.mu.Lock()
|
||||
defer s.mu.Unlock()
|
||||
|
||||
oldVisitor, exists := s.visitors[name]
|
||||
if !exists {
|
||||
return fmt.Errorf("%w: visitor %q", ErrNotFound, name)
|
||||
}
|
||||
|
||||
s.visitors[name] = visitor
|
||||
|
||||
if err := s.saveToFileUnlocked(); err != nil {
|
||||
s.visitors[name] = oldVisitor
|
||||
return fmt.Errorf("failed to persist: %w", err)
|
||||
}
|
||||
return nil
|
||||
return updateStoreEntry(s, visitorStoreEntries, storeKindVisitor, name, visitor)
|
||||
}
|
||||
|
||||
func (s *StoreSource) RemoveVisitor(name string) error {
|
||||
if name == "" {
|
||||
return fmt.Errorf("visitor name cannot be empty")
|
||||
}
|
||||
|
||||
s.mu.Lock()
|
||||
defer s.mu.Unlock()
|
||||
|
||||
oldVisitor, exists := s.visitors[name]
|
||||
if !exists {
|
||||
return fmt.Errorf("%w: visitor %q", ErrNotFound, name)
|
||||
}
|
||||
|
||||
delete(s.visitors, name)
|
||||
|
||||
if err := s.saveToFileUnlocked(); err != nil {
|
||||
s.visitors[name] = oldVisitor
|
||||
return fmt.Errorf("failed to persist: %w", err)
|
||||
}
|
||||
return nil
|
||||
return removeStoreEntry(s, visitorStoreEntries, storeKindVisitor, name)
|
||||
}
|
||||
|
||||
func (s *StoreSource) GetVisitor(name string) v1.VisitorConfigurer {
|
||||
|
||||
@@ -17,6 +17,7 @@ package source
|
||||
import (
|
||||
"os"
|
||||
"path/filepath"
|
||||
"runtime"
|
||||
"testing"
|
||||
|
||||
"github.com/stretchr/testify/require"
|
||||
@@ -59,6 +60,101 @@ func TestStoreSource_AddProxyAndVisitor_DoesNotApplyRuntimeDefaults(t *testing.T
|
||||
require.Empty(gotVisitor.(*v1.XTCPVisitorConfig).Protocol)
|
||||
}
|
||||
|
||||
func TestStoreSource_UpdateAndRemoveProxyAndVisitor(t *testing.T) {
|
||||
require := require.New(t)
|
||||
|
||||
storeSource := newTestStoreSource(t)
|
||||
|
||||
proxyCfg := mockProxy("proxy1")
|
||||
visitorCfg := mockVisitor("visitor1")
|
||||
|
||||
require.NoError(storeSource.AddProxy(proxyCfg))
|
||||
require.NoError(storeSource.AddVisitor(visitorCfg))
|
||||
require.ErrorIs(storeSource.AddProxy(proxyCfg), ErrAlreadyExists)
|
||||
require.ErrorIs(storeSource.AddVisitor(visitorCfg), ErrAlreadyExists)
|
||||
require.ErrorContains(storeSource.RemoveProxy(""), "proxy name cannot be empty")
|
||||
require.ErrorContains(storeSource.RemoveVisitor(""), "visitor name cannot be empty")
|
||||
|
||||
updatedProxy := mockProxy("proxy1").(*v1.TCPProxyConfig)
|
||||
updatedProxy.RemotePort = 19090
|
||||
require.NoError(storeSource.UpdateProxy(updatedProxy))
|
||||
require.Equal(19090, storeSource.GetProxy("proxy1").(*v1.TCPProxyConfig).RemotePort)
|
||||
|
||||
updatedVisitor := mockVisitor("visitor1").(*v1.STCPVisitorConfig)
|
||||
updatedVisitor.ServerName = "updated-server"
|
||||
require.NoError(storeSource.UpdateVisitor(updatedVisitor))
|
||||
require.Equal("updated-server", storeSource.GetVisitor("visitor1").(*v1.STCPVisitorConfig).ServerName)
|
||||
|
||||
require.NoError(storeSource.RemoveProxy("proxy1"))
|
||||
require.Nil(storeSource.GetProxy("proxy1"))
|
||||
require.ErrorIs(storeSource.RemoveProxy("proxy1"), ErrNotFound)
|
||||
|
||||
require.NoError(storeSource.RemoveVisitor("visitor1"))
|
||||
require.Nil(storeSource.GetVisitor("visitor1"))
|
||||
require.ErrorIs(storeSource.RemoveVisitor("visitor1"), ErrNotFound)
|
||||
|
||||
require.ErrorIs(storeSource.UpdateProxy(updatedProxy), ErrNotFound)
|
||||
require.ErrorIs(storeSource.UpdateVisitor(updatedVisitor), ErrNotFound)
|
||||
}
|
||||
|
||||
func TestStoreSource_MutationRollsBackOnPersistFailure(t *testing.T) {
|
||||
if runtime.GOOS == "windows" {
|
||||
t.Skip("chmod does not make directories unwritable on Windows")
|
||||
}
|
||||
if os.Getuid() == 0 {
|
||||
t.Skip("chmod does not block writes for uid 0")
|
||||
}
|
||||
|
||||
require := require.New(t)
|
||||
|
||||
dir := t.TempDir()
|
||||
path := filepath.Join(dir, "store.json")
|
||||
storeSource, err := NewStoreSource(StoreSourceConfig{Path: path})
|
||||
require.NoError(err)
|
||||
|
||||
proxyCfg := mockProxy("proxy1")
|
||||
visitorCfg := mockVisitor("visitor1")
|
||||
originalRemotePort := proxyCfg.(*v1.TCPProxyConfig).RemotePort
|
||||
originalServerName := visitorCfg.(*v1.STCPVisitorConfig).ServerName
|
||||
require.NoError(storeSource.AddProxy(proxyCfg))
|
||||
require.NoError(storeSource.AddVisitor(visitorCfg))
|
||||
|
||||
require.NoError(os.Chmod(dir, 0o500))
|
||||
t.Cleanup(func() {
|
||||
_ = os.Chmod(dir, 0o700)
|
||||
})
|
||||
|
||||
requirePersistError := func(err error) {
|
||||
t.Helper()
|
||||
require.Error(err)
|
||||
require.ErrorContains(err, "failed to persist")
|
||||
require.NotErrorIs(err, ErrAlreadyExists)
|
||||
require.NotErrorIs(err, ErrNotFound)
|
||||
}
|
||||
|
||||
requirePersistError(storeSource.AddProxy(mockProxy("proxy2")))
|
||||
require.Nil(storeSource.GetProxy("proxy2"))
|
||||
|
||||
updatedProxy := mockProxy("proxy1").(*v1.TCPProxyConfig)
|
||||
updatedProxy.RemotePort = 19090
|
||||
requirePersistError(storeSource.UpdateProxy(updatedProxy))
|
||||
require.Equal(originalRemotePort, storeSource.GetProxy("proxy1").(*v1.TCPProxyConfig).RemotePort)
|
||||
|
||||
requirePersistError(storeSource.RemoveProxy("proxy1"))
|
||||
require.NotNil(storeSource.GetProxy("proxy1"))
|
||||
|
||||
requirePersistError(storeSource.AddVisitor(mockVisitor("visitor2")))
|
||||
require.Nil(storeSource.GetVisitor("visitor2"))
|
||||
|
||||
updatedVisitor := mockVisitor("visitor1").(*v1.STCPVisitorConfig)
|
||||
updatedVisitor.ServerName = "updated-server"
|
||||
requirePersistError(storeSource.UpdateVisitor(updatedVisitor))
|
||||
require.Equal(originalServerName, storeSource.GetVisitor("visitor1").(*v1.STCPVisitorConfig).ServerName)
|
||||
|
||||
requirePersistError(storeSource.RemoveVisitor("visitor1"))
|
||||
require.NotNil(storeSource.GetVisitor("visitor1"))
|
||||
}
|
||||
|
||||
func TestStoreSource_LoadFromFile_DoesNotApplyRuntimeDefaults(t *testing.T) {
|
||||
require := require.New(t)
|
||||
|
||||
|
||||
43
pkg/config/source/validation.go
Normal file
43
pkg/config/source/validation.go
Normal file
@@ -0,0 +1,43 @@
|
||||
// Copyright 2026 The frp Authors
|
||||
//
|
||||
// Licensed under the Apache License, Version 2.0 (the "License");
|
||||
// you may not use this file except in compliance with the License.
|
||||
// You may obtain a copy of the License at
|
||||
//
|
||||
// http://www.apache.org/licenses/LICENSE-2.0
|
||||
//
|
||||
// Unless required by applicable law or agreed to in writing, software
|
||||
// distributed under the License is distributed on an "AS IS" BASIS,
|
||||
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
// See the License for the specific language governing permissions and
|
||||
// limitations under the License.
|
||||
|
||||
package source
|
||||
|
||||
import (
|
||||
"fmt"
|
||||
|
||||
v1 "github.com/fatedier/frp/pkg/config/v1"
|
||||
)
|
||||
|
||||
func validateProxyName(proxy v1.ProxyConfigurer) (string, error) {
|
||||
if proxy == nil {
|
||||
return "", fmt.Errorf("proxy cannot be nil")
|
||||
}
|
||||
name := proxy.GetBaseConfig().Name
|
||||
if name == "" {
|
||||
return "", fmt.Errorf("proxy name cannot be empty")
|
||||
}
|
||||
return name, nil
|
||||
}
|
||||
|
||||
func validateVisitorName(visitor v1.VisitorConfigurer) (string, error) {
|
||||
if visitor == nil {
|
||||
return "", fmt.Errorf("visitor cannot be nil")
|
||||
}
|
||||
name := visitor.GetBaseConfig().Name
|
||||
if name == "" {
|
||||
return "", fmt.Errorf("visitor name cannot be empty")
|
||||
}
|
||||
return name, nil
|
||||
}
|
||||
@@ -104,6 +104,9 @@ type ClientTransportConfig struct {
|
||||
// Valid values are "tcp", "kcp", "quic", "websocket" and "wss". By default, this value
|
||||
// is "tcp".
|
||||
Protocol string `json:"protocol,omitempty"`
|
||||
// WireProtocol specifies the frpc/frps internal wire protocol version.
|
||||
// Valid values are "v1" and "v2". By default, this value is "v1".
|
||||
WireProtocol string `json:"wireProtocol,omitempty"`
|
||||
// The maximum amount of time a dial to server will wait for a connect to complete.
|
||||
DialServerTimeout int64 `json:"dialServerTimeout,omitempty"`
|
||||
// DialServerKeepAlive specifies the interval between keep-alive probes for an active network connection between frpc and frps.
|
||||
@@ -143,6 +146,7 @@ type ClientTransportConfig struct {
|
||||
|
||||
func (c *ClientTransportConfig) Complete() {
|
||||
c.Protocol = util.EmptyOr(c.Protocol, "tcp")
|
||||
c.WireProtocol = util.EmptyOr(c.WireProtocol, "v1")
|
||||
c.DialServerTimeout = util.EmptyOr(c.DialServerTimeout, 10)
|
||||
c.DialServerKeepAlive = util.EmptyOr(c.DialServerKeepAlive, 7200)
|
||||
c.ProxyURL = util.EmptyOr(c.ProxyURL, os.Getenv("http_proxy"))
|
||||
|
||||
@@ -29,6 +29,7 @@ func TestClientConfigComplete(t *testing.T) {
|
||||
|
||||
require.EqualValues("token", c.Auth.Method)
|
||||
require.Equal(true, lo.FromPtr(c.Transport.TCPMux))
|
||||
require.Equal("v1", c.Transport.WireProtocol)
|
||||
require.Equal(true, lo.FromPtr(c.LoginFailExit))
|
||||
require.Equal(true, lo.FromPtr(c.Transport.TLS.Enable))
|
||||
require.Equal(true, lo.FromPtr(c.Transport.TLS.DisableCustomTLSFirstByte))
|
||||
|
||||
43
pkg/config/v1/validation/auth.go
Normal file
43
pkg/config/v1/validation/auth.go
Normal file
@@ -0,0 +1,43 @@
|
||||
// Copyright 2026 The frp Authors
|
||||
//
|
||||
// Licensed under the Apache License, Version 2.0 (the "License");
|
||||
// you may not use this file except in compliance with the License.
|
||||
// You may obtain a copy of the License at
|
||||
//
|
||||
// http://www.apache.org/licenses/LICENSE-2.0
|
||||
//
|
||||
// Unless required by applicable law or agreed to in writing, software
|
||||
// distributed under the License is distributed on an "AS IS" BASIS,
|
||||
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
// See the License for the specific language governing permissions and
|
||||
// limitations under the License.
|
||||
|
||||
package validation
|
||||
|
||||
import (
|
||||
"fmt"
|
||||
|
||||
v1 "github.com/fatedier/frp/pkg/config/v1"
|
||||
"github.com/fatedier/frp/pkg/policy/security"
|
||||
)
|
||||
|
||||
func (v *ConfigValidator) validateAuthTokenSource(token string, tokenSource *v1.ValueSource) error {
|
||||
var errs error
|
||||
// Preserve the previous client/server validation order for joined errors.
|
||||
if token != "" && tokenSource != nil {
|
||||
errs = AppendError(errs, fmt.Errorf("cannot specify both auth.token and auth.tokenSource"))
|
||||
}
|
||||
if tokenSource == nil {
|
||||
return errs
|
||||
}
|
||||
|
||||
if tokenSource.Type == "exec" {
|
||||
if err := v.ValidateUnsafeFeature(security.TokenSourceExec); err != nil {
|
||||
errs = AppendError(errs, err)
|
||||
}
|
||||
}
|
||||
if err := tokenSource.Validate(); err != nil {
|
||||
errs = AppendError(errs, fmt.Errorf("invalid auth.tokenSource: %v", err))
|
||||
}
|
||||
return errs
|
||||
}
|
||||
228
pkg/config/v1/validation/auth_test.go
Normal file
228
pkg/config/v1/validation/auth_test.go
Normal file
@@ -0,0 +1,228 @@
|
||||
// Copyright 2026 The frp Authors
|
||||
//
|
||||
// Licensed under the Apache License, Version 2.0 (the "License");
|
||||
// you may not use this file except in compliance with the License.
|
||||
// You may obtain a copy of the License at
|
||||
//
|
||||
// http://www.apache.org/licenses/LICENSE-2.0
|
||||
//
|
||||
// Unless required by applicable law or agreed to in writing, software
|
||||
// distributed under the License is distributed on an "AS IS" BASIS,
|
||||
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
// See the License for the specific language governing permissions and
|
||||
// limitations under the License.
|
||||
|
||||
package validation
|
||||
|
||||
import (
|
||||
"testing"
|
||||
|
||||
"github.com/stretchr/testify/require"
|
||||
|
||||
v1 "github.com/fatedier/frp/pkg/config/v1"
|
||||
"github.com/fatedier/frp/pkg/policy/security"
|
||||
)
|
||||
|
||||
const (
|
||||
tokenSourceConflictErr = "cannot specify both auth.token and auth.tokenSource"
|
||||
tokenSourceExecErr = "unsafe feature \"TokenSourceExec\" is not enabled. To enable it, ensure it is allowed in the configuration or command line flags"
|
||||
invalidFileSourceErr = "invalid auth.tokenSource: file configuration is required when type is 'file'"
|
||||
unsupportedSourceErr = "invalid auth.tokenSource: unsupported value source type: env (only 'file' and 'exec' are supported)"
|
||||
)
|
||||
|
||||
func TestValidateAuthTokenSource(t *testing.T) {
|
||||
for _, tc := range authTokenSourceTestCases() {
|
||||
t.Run(tc.name, func(t *testing.T) {
|
||||
validator := newAuthTokenSourceValidator(tc.unsafeAllowed)
|
||||
err := validator.validateAuthTokenSource(tc.token, tc.tokenSource())
|
||||
requireValidationErrors(t, err, tc.wantErrs)
|
||||
})
|
||||
}
|
||||
}
|
||||
|
||||
func TestValidateClientAuthTokenSource(t *testing.T) {
|
||||
for _, tc := range authTokenSourceTestCases() {
|
||||
t.Run(tc.name, func(t *testing.T) {
|
||||
auth := v1.AuthClientConfig{
|
||||
Method: v1.AuthMethodToken,
|
||||
Token: tc.token,
|
||||
TokenSource: tc.tokenSource(),
|
||||
}
|
||||
validator := newAuthTokenSourceValidator(tc.unsafeAllowed)
|
||||
_, err := validator.ValidateClientCommonConfig(validClientConfigWithAuth(auth))
|
||||
requireValidationErrors(t, err, tc.wantErrs)
|
||||
})
|
||||
}
|
||||
}
|
||||
|
||||
func TestValidateServerAuthTokenSource(t *testing.T) {
|
||||
for _, tc := range authTokenSourceTestCases() {
|
||||
t.Run(tc.name, func(t *testing.T) {
|
||||
auth := v1.AuthServerConfig{
|
||||
Method: v1.AuthMethodToken,
|
||||
Token: tc.token,
|
||||
TokenSource: tc.tokenSource(),
|
||||
}
|
||||
validator := newAuthTokenSourceValidator(tc.unsafeAllowed)
|
||||
_, err := validator.ValidateServerConfig(validServerConfigWithAuth(auth))
|
||||
requireValidationErrors(t, err, tc.wantErrs)
|
||||
})
|
||||
}
|
||||
}
|
||||
|
||||
type authTokenSourceTestCase struct {
|
||||
name string
|
||||
token string
|
||||
tokenSource func() *v1.ValueSource
|
||||
unsafeAllowed bool
|
||||
wantErrs []string
|
||||
}
|
||||
|
||||
func authTokenSourceTestCases() []authTokenSourceTestCase {
|
||||
return []authTokenSourceTestCase{
|
||||
{
|
||||
name: "empty token config",
|
||||
tokenSource: nilTokenSource,
|
||||
},
|
||||
{
|
||||
name: "valid file tokenSource",
|
||||
tokenSource: validFileTokenSource,
|
||||
},
|
||||
{
|
||||
name: "literal token without tokenSource",
|
||||
token: "token",
|
||||
tokenSource: nilTokenSource,
|
||||
},
|
||||
{
|
||||
name: "literal token conflicts with file tokenSource",
|
||||
token: "token",
|
||||
tokenSource: validFileTokenSource,
|
||||
wantErrs: []string{tokenSourceConflictErr},
|
||||
},
|
||||
{
|
||||
name: "exec tokenSource requires unsafe feature",
|
||||
tokenSource: validExecTokenSource,
|
||||
wantErrs: []string{tokenSourceExecErr},
|
||||
},
|
||||
{
|
||||
name: "exec tokenSource with unsafe feature allowed",
|
||||
tokenSource: validExecTokenSource,
|
||||
unsafeAllowed: true,
|
||||
},
|
||||
{
|
||||
name: "literal token conflicts with exec tokenSource and unsafe feature disabled",
|
||||
token: "token",
|
||||
tokenSource: validExecTokenSource,
|
||||
wantErrs: []string{
|
||||
tokenSourceConflictErr,
|
||||
tokenSourceExecErr,
|
||||
},
|
||||
},
|
||||
{
|
||||
name: "literal token conflicts with exec tokenSource and unsafe feature allowed",
|
||||
token: "token",
|
||||
tokenSource: validExecTokenSource,
|
||||
unsafeAllowed: true,
|
||||
wantErrs: []string{tokenSourceConflictErr},
|
||||
},
|
||||
{
|
||||
name: "invalid file tokenSource is wrapped",
|
||||
tokenSource: invalidFileTokenSource,
|
||||
wantErrs: []string{invalidFileSourceErr},
|
||||
},
|
||||
{
|
||||
name: "unsupported tokenSource type is wrapped",
|
||||
tokenSource: unsupportedTokenSource,
|
||||
wantErrs: []string{unsupportedSourceErr},
|
||||
},
|
||||
}
|
||||
}
|
||||
|
||||
func newAuthTokenSourceValidator(unsafeAllowed bool) *ConfigValidator {
|
||||
if !unsafeAllowed {
|
||||
return NewConfigValidator(nil)
|
||||
}
|
||||
return NewConfigValidator(security.NewUnsafeFeatures([]string{security.TokenSourceExec}))
|
||||
}
|
||||
|
||||
func requireValidationErrors(t *testing.T, err error, wantErrs []string) {
|
||||
t.Helper()
|
||||
if len(wantErrs) == 0 {
|
||||
require.NoError(t, err)
|
||||
return
|
||||
}
|
||||
require.Error(t, err)
|
||||
// Client/server validators may wrap joined errors in another join layer; compare leaf errors.
|
||||
gotErrs := unwrapValidationErrors(err)
|
||||
require.Len(t, gotErrs, len(wantErrs))
|
||||
for i, wantErr := range wantErrs {
|
||||
require.EqualError(t, gotErrs[i], wantErr)
|
||||
}
|
||||
}
|
||||
|
||||
func unwrapValidationErrors(err error) []error {
|
||||
type joinedError interface {
|
||||
Unwrap() []error
|
||||
}
|
||||
joined, ok := err.(joinedError)
|
||||
if !ok {
|
||||
return []error{err}
|
||||
}
|
||||
|
||||
var errs []error
|
||||
for _, err := range joined.Unwrap() {
|
||||
errs = append(errs, unwrapValidationErrors(err)...)
|
||||
}
|
||||
return errs
|
||||
}
|
||||
|
||||
// nilTokenSource keeps the shared table shape uniform for cases without a tokenSource.
|
||||
func nilTokenSource() *v1.ValueSource {
|
||||
return nil
|
||||
}
|
||||
|
||||
func validFileTokenSource() *v1.ValueSource {
|
||||
return &v1.ValueSource{
|
||||
Type: "file",
|
||||
File: &v1.FileSource{Path: "token.txt"},
|
||||
}
|
||||
}
|
||||
|
||||
func validExecTokenSource() *v1.ValueSource {
|
||||
return &v1.ValueSource{
|
||||
Type: "exec",
|
||||
Exec: &v1.ExecSource{Command: "print-token"},
|
||||
}
|
||||
}
|
||||
|
||||
func invalidFileTokenSource() *v1.ValueSource {
|
||||
return &v1.ValueSource{
|
||||
Type: "file",
|
||||
}
|
||||
}
|
||||
|
||||
func unsupportedTokenSource() *v1.ValueSource {
|
||||
return &v1.ValueSource{Type: "env"}
|
||||
}
|
||||
|
||||
func validClientConfigWithAuth(auth v1.AuthClientConfig) *v1.ClientCommonConfig {
|
||||
return &v1.ClientCommonConfig{
|
||||
Auth: auth,
|
||||
Log: v1.LogConfig{
|
||||
Level: "info",
|
||||
},
|
||||
Transport: v1.ClientTransportConfig{
|
||||
Protocol: "tcp",
|
||||
WireProtocol: "v1",
|
||||
},
|
||||
}
|
||||
}
|
||||
|
||||
func validServerConfigWithAuth(auth v1.AuthServerConfig) *v1.ServerConfig {
|
||||
return &v1.ServerConfig{
|
||||
Auth: auth,
|
||||
Log: v1.LogConfig{
|
||||
Level: "info",
|
||||
},
|
||||
}
|
||||
}
|
||||
@@ -68,22 +68,7 @@ func (v *ConfigValidator) validateAuthConfig(c *v1.AuthClientConfig) (Warning, e
|
||||
errs = AppendError(errs, fmt.Errorf("invalid auth additional scopes, optional values are %v", SupportedAuthAdditionalScopes))
|
||||
}
|
||||
|
||||
// Validate token/tokenSource mutual exclusivity
|
||||
if c.Token != "" && c.TokenSource != nil {
|
||||
errs = AppendError(errs, fmt.Errorf("cannot specify both auth.token and auth.tokenSource"))
|
||||
}
|
||||
|
||||
// Validate tokenSource if specified
|
||||
if c.TokenSource != nil {
|
||||
if c.TokenSource.Type == "exec" {
|
||||
if err := v.ValidateUnsafeFeature(security.TokenSourceExec); err != nil {
|
||||
errs = AppendError(errs, err)
|
||||
}
|
||||
}
|
||||
if err := c.TokenSource.Validate(); err != nil {
|
||||
errs = AppendError(errs, fmt.Errorf("invalid auth.tokenSource: %v", err))
|
||||
}
|
||||
}
|
||||
errs = AppendError(errs, v.validateAuthTokenSource(c.Token, c.TokenSource))
|
||||
|
||||
if err := v.validateOIDCConfig(&c.OIDC); err != nil {
|
||||
errs = AppendError(errs, err)
|
||||
@@ -146,6 +131,9 @@ func validateTransportConfig(c *v1.ClientTransportConfig) (Warning, error) {
|
||||
if !slices.Contains(SupportedTransportProtocols, c.Protocol) {
|
||||
errs = AppendError(errs, fmt.Errorf("invalid transport.protocol, optional values are %v", SupportedTransportProtocols))
|
||||
}
|
||||
if !slices.Contains(SupportedWireProtocols, c.WireProtocol) {
|
||||
errs = AppendError(errs, fmt.Errorf("invalid transport.wireProtocol, optional values are %v", SupportedWireProtocols))
|
||||
}
|
||||
return warnings, errs
|
||||
}
|
||||
|
||||
|
||||
@@ -21,7 +21,6 @@ import (
|
||||
"github.com/samber/lo"
|
||||
|
||||
v1 "github.com/fatedier/frp/pkg/config/v1"
|
||||
"github.com/fatedier/frp/pkg/policy/security"
|
||||
)
|
||||
|
||||
func (v *ConfigValidator) ValidateServerConfig(c *v1.ServerConfig) (Warning, error) {
|
||||
@@ -36,22 +35,7 @@ func (v *ConfigValidator) ValidateServerConfig(c *v1.ServerConfig) (Warning, err
|
||||
errs = AppendError(errs, fmt.Errorf("invalid auth additional scopes, optional values are %v", SupportedAuthAdditionalScopes))
|
||||
}
|
||||
|
||||
// Validate token/tokenSource mutual exclusivity
|
||||
if c.Auth.Token != "" && c.Auth.TokenSource != nil {
|
||||
errs = AppendError(errs, fmt.Errorf("cannot specify both auth.token and auth.tokenSource"))
|
||||
}
|
||||
|
||||
// Validate tokenSource if specified
|
||||
if c.Auth.TokenSource != nil {
|
||||
if c.Auth.TokenSource.Type == "exec" {
|
||||
if err := v.ValidateUnsafeFeature(security.TokenSourceExec); err != nil {
|
||||
errs = AppendError(errs, err)
|
||||
}
|
||||
}
|
||||
if err := c.Auth.TokenSource.Validate(); err != nil {
|
||||
errs = AppendError(errs, fmt.Errorf("invalid auth.tokenSource: %v", err))
|
||||
}
|
||||
}
|
||||
errs = AppendError(errs, v.validateAuthTokenSource(c.Auth.Token, c.Auth.TokenSource))
|
||||
|
||||
if err := validateLogConfig(&c.Log); err != nil {
|
||||
errs = AppendError(errs, err)
|
||||
|
||||
@@ -29,6 +29,10 @@ var (
|
||||
"websocket",
|
||||
"wss",
|
||||
}
|
||||
SupportedWireProtocols = []string{
|
||||
"v1",
|
||||
"v2",
|
||||
}
|
||||
|
||||
SupportedAuthMethods = []v1.AuthMethod{
|
||||
"token",
|
||||
|
||||
@@ -18,6 +18,8 @@ import (
|
||||
"sync"
|
||||
"time"
|
||||
|
||||
"k8s.io/utils/clock"
|
||||
|
||||
"github.com/fatedier/frp/pkg/util/log"
|
||||
"github.com/fatedier/frp/pkg/util/metric"
|
||||
server "github.com/fatedier/frp/server/metrics"
|
||||
@@ -37,12 +39,21 @@ func init() {
|
||||
}
|
||||
|
||||
type serverMetrics struct {
|
||||
info *ServerStatistics
|
||||
mu sync.Mutex
|
||||
info *ServerStatistics
|
||||
clock clock.WithTicker
|
||||
mu sync.Mutex
|
||||
}
|
||||
|
||||
func newServerMetrics() *serverMetrics {
|
||||
return newServerMetricsWithClock(clock.RealClock{})
|
||||
}
|
||||
|
||||
func newServerMetricsWithClock(clk clock.WithTicker) *serverMetrics {
|
||||
if clk == nil {
|
||||
clk = clock.RealClock{}
|
||||
}
|
||||
return &serverMetrics{
|
||||
clock: clk,
|
||||
info: &ServerStatistics{
|
||||
TotalTrafficIn: metric.NewDateCounter(ReserveDays),
|
||||
TotalTrafficOut: metric.NewDateCounter(ReserveDays),
|
||||
@@ -57,14 +68,23 @@ func newServerMetrics() *serverMetrics {
|
||||
}
|
||||
|
||||
func (m *serverMetrics) run() {
|
||||
go func() {
|
||||
for {
|
||||
time.Sleep(12 * time.Hour)
|
||||
start := time.Now()
|
||||
go m.runUntil(nil)
|
||||
}
|
||||
|
||||
func (m *serverMetrics) runUntil(stopCh <-chan struct{}) {
|
||||
ticker := m.clock.NewTicker(12 * time.Hour)
|
||||
defer ticker.Stop()
|
||||
|
||||
for {
|
||||
select {
|
||||
case <-ticker.C():
|
||||
start := m.clock.Now()
|
||||
count, total := m.clearUselessInfo(time.Duration(7*24) * time.Hour)
|
||||
log.Debugf("clear useless proxy statistics data count %d/%d, cost %v", count, total, time.Since(start))
|
||||
log.Debugf("clear useless proxy statistics data count %d/%d, cost %v", count, total, m.clock.Since(start))
|
||||
case <-stopCh:
|
||||
return
|
||||
}
|
||||
}()
|
||||
}
|
||||
}
|
||||
|
||||
func (m *serverMetrics) clearUselessInfo(continuousOfflineDuration time.Duration) (int, int) {
|
||||
@@ -77,7 +97,7 @@ func (m *serverMetrics) clearUselessInfo(continuousOfflineDuration time.Duration
|
||||
for name, data := range m.info.ProxyStatistics {
|
||||
if !data.LastCloseTime.IsZero() &&
|
||||
data.LastStartTime.Before(data.LastCloseTime) &&
|
||||
time.Since(data.LastCloseTime) > continuousOfflineDuration {
|
||||
m.clock.Since(data.LastCloseTime) > continuousOfflineDuration {
|
||||
delete(m.info.ProxyStatistics, name)
|
||||
count++
|
||||
log.Tracef("clear proxy [%s]'s statistics data, lastCloseTime: [%s]", name, data.LastCloseTime.String())
|
||||
@@ -121,7 +141,7 @@ func (m *serverMetrics) NewProxy(name string, proxyType string, user string, cli
|
||||
}
|
||||
proxyStats.User = user
|
||||
proxyStats.ClientID = clientID
|
||||
proxyStats.LastStartTime = time.Now()
|
||||
proxyStats.LastStartTime = m.clock.Now()
|
||||
}
|
||||
|
||||
func (m *serverMetrics) CloseProxy(name string, proxyType string) {
|
||||
@@ -131,7 +151,7 @@ func (m *serverMetrics) CloseProxy(name string, proxyType string) {
|
||||
counter.Dec(1)
|
||||
}
|
||||
if proxyStats, ok := m.info.ProxyStatistics[name]; ok {
|
||||
proxyStats.LastCloseTime = time.Now()
|
||||
proxyStats.LastCloseTime = m.clock.Now()
|
||||
}
|
||||
}
|
||||
|
||||
|
||||
83
pkg/metrics/mem/server_test.go
Normal file
83
pkg/metrics/mem/server_test.go
Normal file
@@ -0,0 +1,83 @@
|
||||
package mem
|
||||
|
||||
import (
|
||||
"testing"
|
||||
"time"
|
||||
|
||||
"github.com/stretchr/testify/require"
|
||||
clocktesting "k8s.io/utils/clock/testing"
|
||||
)
|
||||
|
||||
func TestServerMetricsUsesClockForProxyTimestamps(t *testing.T) {
|
||||
require := require.New(t)
|
||||
|
||||
start := time.Date(2026, time.May, 8, 12, 30, 0, 0, time.UTC)
|
||||
clk := clocktesting.NewFakeClock(start)
|
||||
metrics := newServerMetricsWithClock(clk)
|
||||
|
||||
metrics.NewProxy("proxy", "tcp", "user", "client-id")
|
||||
require.Equal(start, metrics.info.ProxyStatistics["proxy"].LastStartTime)
|
||||
|
||||
closedAt := start.Add(time.Minute)
|
||||
clk.SetTime(closedAt)
|
||||
metrics.CloseProxy("proxy", "tcp")
|
||||
require.Equal(closedAt, metrics.info.ProxyStatistics["proxy"].LastCloseTime)
|
||||
}
|
||||
|
||||
func TestServerMetricsClearUselessInfoUsesClock(t *testing.T) {
|
||||
require := require.New(t)
|
||||
|
||||
start := time.Date(2026, time.May, 8, 12, 30, 0, 0, time.UTC)
|
||||
clk := clocktesting.NewFakeClock(start.Add(25 * time.Hour))
|
||||
metrics := newServerMetricsWithClock(clk)
|
||||
metrics.info.ProxyStatistics["proxy"] = &ProxyStatistics{
|
||||
Name: "proxy",
|
||||
LastStartTime: start.Add(-time.Hour),
|
||||
LastCloseTime: start,
|
||||
}
|
||||
|
||||
count, total := metrics.clearUselessInfo(24 * time.Hour)
|
||||
|
||||
require.Equal(1, count)
|
||||
require.Equal(1, total)
|
||||
require.Empty(metrics.info.ProxyStatistics)
|
||||
}
|
||||
|
||||
func TestServerMetricsRunUsesClockTicker(t *testing.T) {
|
||||
require := require.New(t)
|
||||
|
||||
start := time.Date(2026, time.May, 8, 12, 30, 0, 0, time.UTC)
|
||||
clk := clocktesting.NewFakeClock(start)
|
||||
metrics := newServerMetricsWithClock(clk)
|
||||
metrics.info.ProxyStatistics["proxy"] = &ProxyStatistics{
|
||||
Name: "proxy",
|
||||
LastStartTime: start.Add(-time.Hour),
|
||||
LastCloseTime: start,
|
||||
}
|
||||
|
||||
stopCh := make(chan struct{})
|
||||
done := make(chan struct{})
|
||||
go func() {
|
||||
defer close(done)
|
||||
metrics.runUntil(stopCh)
|
||||
}()
|
||||
t.Cleanup(func() {
|
||||
close(stopCh)
|
||||
<-done
|
||||
})
|
||||
|
||||
require.Eventually(clk.HasWaiters, time.Second, time.Millisecond)
|
||||
clk.Step(8 * 24 * time.Hour)
|
||||
|
||||
require.Eventually(func() bool {
|
||||
return !metrics.hasProxyStatistics("proxy")
|
||||
}, time.Second, time.Millisecond)
|
||||
}
|
||||
|
||||
func (m *serverMetrics) hasProxyStatistics(name string) bool {
|
||||
m.mu.Lock()
|
||||
defer m.mu.Unlock()
|
||||
|
||||
_, ok := m.info.ProxyStatistics[name]
|
||||
return ok
|
||||
}
|
||||
56
pkg/msg/conn_test.go
Normal file
56
pkg/msg/conn_test.go
Normal file
@@ -0,0 +1,56 @@
|
||||
// Copyright 2026 The frp Authors
|
||||
//
|
||||
// Licensed under the Apache License, Version 2.0 (the "License");
|
||||
// you may not use this file except in compliance with the License.
|
||||
// You may obtain a copy of the License at
|
||||
//
|
||||
// http://www.apache.org/licenses/LICENSE-2.0
|
||||
//
|
||||
// Unless required by applicable law or agreed to in writing, software
|
||||
// distributed under the License is distributed on an "AS IS" BASIS,
|
||||
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
// See the License for the specific language governing permissions and
|
||||
// limitations under the License.
|
||||
|
||||
package msg
|
||||
|
||||
import (
|
||||
"net"
|
||||
"testing"
|
||||
|
||||
"github.com/stretchr/testify/require"
|
||||
|
||||
"github.com/fatedier/frp/pkg/proto/wire"
|
||||
)
|
||||
|
||||
func TestConnReadWriteMsg(t *testing.T) {
|
||||
tests := []struct {
|
||||
name string
|
||||
protocol string
|
||||
}{
|
||||
{name: "v1", protocol: wire.ProtocolV1},
|
||||
{name: "v2", protocol: wire.ProtocolV2},
|
||||
}
|
||||
|
||||
for _, tt := range tests {
|
||||
t.Run(tt.name, func(t *testing.T) {
|
||||
client, server := net.Pipe()
|
||||
defer client.Close()
|
||||
defer server.Close()
|
||||
|
||||
clientConn := NewConn(client, NewReadWriter(client, tt.protocol))
|
||||
serverConn := NewConn(server, NewReadWriter(server, tt.protocol))
|
||||
|
||||
in := &Ping{PrivilegeKey: "key", Timestamp: 123}
|
||||
errCh := make(chan error, 1)
|
||||
go func() {
|
||||
errCh <- clientConn.WriteMsg(in)
|
||||
}()
|
||||
|
||||
out, err := serverConn.ReadMsg()
|
||||
require.NoError(t, err)
|
||||
require.Equal(t, in, out)
|
||||
require.NoError(t, <-errCh)
|
||||
})
|
||||
}
|
||||
}
|
||||
@@ -15,10 +15,90 @@
|
||||
package msg
|
||||
|
||||
import (
|
||||
"context"
|
||||
"io"
|
||||
"net"
|
||||
"reflect"
|
||||
|
||||
"github.com/fatedier/frp/pkg/proto/wire"
|
||||
)
|
||||
|
||||
type ReadWriter interface {
|
||||
ReadMsg() (Message, error)
|
||||
ReadMsgInto(Message) error
|
||||
WriteMsg(Message) error
|
||||
}
|
||||
|
||||
type Conn struct {
|
||||
net.Conn
|
||||
rw ReadWriter
|
||||
}
|
||||
|
||||
func NewConn(conn net.Conn, rw ReadWriter) *Conn {
|
||||
return &Conn{
|
||||
Conn: conn,
|
||||
rw: rw,
|
||||
}
|
||||
}
|
||||
|
||||
func (c *Conn) ReadMsg() (Message, error) {
|
||||
return c.rw.ReadMsg()
|
||||
}
|
||||
|
||||
func (c *Conn) ReadMsgInto(m Message) error {
|
||||
return c.rw.ReadMsgInto(m)
|
||||
}
|
||||
|
||||
func (c *Conn) WriteMsg(m Message) error {
|
||||
return c.rw.WriteMsg(m)
|
||||
}
|
||||
|
||||
func (c *Conn) Context() context.Context {
|
||||
if getter, ok := c.Conn.(interface{ Context() context.Context }); ok {
|
||||
return getter.Context()
|
||||
}
|
||||
return context.Background()
|
||||
}
|
||||
|
||||
func (c *Conn) WithContext(ctx context.Context) {
|
||||
if setter, ok := c.Conn.(interface{ WithContext(context.Context) }); ok {
|
||||
setter.WithContext(ctx)
|
||||
}
|
||||
}
|
||||
|
||||
type V1ReadWriter struct {
|
||||
rw io.ReadWriter
|
||||
}
|
||||
|
||||
func NewV1ReadWriter(rw io.ReadWriter) ReadWriter {
|
||||
return &V1ReadWriter{rw: rw}
|
||||
}
|
||||
|
||||
// NewReadWriter wraps rw with the message codec for the selected wire protocol.
|
||||
// An empty protocol keeps the historical v1 behavior for tests and older call sites.
|
||||
func NewReadWriter(rw io.ReadWriter, wireProtocol string) ReadWriter {
|
||||
switch wireProtocol {
|
||||
case wire.ProtocolV2:
|
||||
return NewV2ReadWriter(rw)
|
||||
case "", wire.ProtocolV1:
|
||||
return NewV1ReadWriter(rw)
|
||||
default:
|
||||
return NewV1ReadWriter(rw)
|
||||
}
|
||||
}
|
||||
|
||||
func (rw *V1ReadWriter) ReadMsg() (Message, error) {
|
||||
return ReadMsg(rw.rw)
|
||||
}
|
||||
|
||||
func (rw *V1ReadWriter) ReadMsgInto(m Message) error {
|
||||
return ReadMsgInto(rw.rw, m)
|
||||
}
|
||||
|
||||
func (rw *V1ReadWriter) WriteMsg(m Message) error {
|
||||
return WriteMsg(rw.rw, m)
|
||||
}
|
||||
|
||||
func AsyncHandler(f func(Message)) func(Message) {
|
||||
return func(m Message) {
|
||||
go f(m)
|
||||
@@ -27,15 +107,14 @@ func AsyncHandler(f func(Message)) func(Message) {
|
||||
|
||||
// Dispatcher is used to send messages to net.Conn or register handlers for messages read from net.Conn.
|
||||
type Dispatcher struct {
|
||||
rw io.ReadWriter
|
||||
rw ReadWriter
|
||||
|
||||
sendCh chan Message
|
||||
doneCh chan struct{}
|
||||
msgHandlers map[reflect.Type]func(Message)
|
||||
defaultHandler func(Message)
|
||||
sendCh chan Message
|
||||
doneCh chan struct{}
|
||||
msgHandlers map[reflect.Type]func(Message)
|
||||
}
|
||||
|
||||
func NewDispatcher(rw io.ReadWriter) *Dispatcher {
|
||||
func NewDispatcher(rw ReadWriter) *Dispatcher {
|
||||
return &Dispatcher{
|
||||
rw: rw,
|
||||
sendCh: make(chan Message, 100),
|
||||
@@ -56,14 +135,14 @@ func (d *Dispatcher) sendLoop() {
|
||||
case <-d.doneCh:
|
||||
return
|
||||
case m := <-d.sendCh:
|
||||
_ = WriteMsg(d.rw, m)
|
||||
_ = d.rw.WriteMsg(m)
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
func (d *Dispatcher) readLoop() {
|
||||
for {
|
||||
m, err := ReadMsg(d.rw)
|
||||
m, err := d.rw.ReadMsg()
|
||||
if err != nil {
|
||||
close(d.doneCh)
|
||||
return
|
||||
@@ -71,8 +150,6 @@ func (d *Dispatcher) readLoop() {
|
||||
|
||||
if handler, ok := d.msgHandlers[reflect.TypeOf(m)]; ok {
|
||||
handler(m)
|
||||
} else if d.defaultHandler != nil {
|
||||
d.defaultHandler(m)
|
||||
}
|
||||
}
|
||||
}
|
||||
@@ -90,10 +167,6 @@ func (d *Dispatcher) RegisterHandler(msg Message, handler func(Message)) {
|
||||
d.msgHandlers[reflect.TypeOf(msg)] = handler
|
||||
}
|
||||
|
||||
func (d *Dispatcher) RegisterDefaultHandler(handler func(Message)) {
|
||||
d.defaultHandler = handler
|
||||
}
|
||||
|
||||
func (d *Dispatcher) Done() chan struct{} {
|
||||
return d.doneCh
|
||||
}
|
||||
|
||||
@@ -20,24 +20,24 @@ import (
|
||||
)
|
||||
|
||||
const (
|
||||
TypeLogin = 'o'
|
||||
TypeLoginResp = '1'
|
||||
TypeNewProxy = 'p'
|
||||
TypeNewProxyResp = '2'
|
||||
TypeCloseProxy = 'c'
|
||||
TypeNewWorkConn = 'w'
|
||||
TypeReqWorkConn = 'r'
|
||||
TypeStartWorkConn = 's'
|
||||
TypeNewVisitorConn = 'v'
|
||||
TypeNewVisitorConnResp = '3'
|
||||
TypePing = 'h'
|
||||
TypePong = '4'
|
||||
TypeUDPPacket = 'u'
|
||||
TypeNatHoleVisitor = 'i'
|
||||
TypeNatHoleClient = 'n'
|
||||
TypeNatHoleResp = 'm'
|
||||
TypeNatHoleSid = '5'
|
||||
TypeNatHoleReport = '6'
|
||||
TypeLogin byte = 'o'
|
||||
TypeLoginResp byte = '1'
|
||||
TypeNewProxy byte = 'p'
|
||||
TypeNewProxyResp byte = '2'
|
||||
TypeCloseProxy byte = 'c'
|
||||
TypeNewWorkConn byte = 'w'
|
||||
TypeReqWorkConn byte = 'r'
|
||||
TypeStartWorkConn byte = 's'
|
||||
TypeNewVisitorConn byte = 'v'
|
||||
TypeNewVisitorConnResp byte = '3'
|
||||
TypePing byte = 'h'
|
||||
TypePong byte = '4'
|
||||
TypeUDPPacket byte = 'u'
|
||||
TypeNatHoleVisitor byte = 'i'
|
||||
TypeNatHoleClient byte = 'n'
|
||||
TypeNatHoleResp byte = 'm'
|
||||
TypeNatHoleSid byte = '5'
|
||||
TypeNatHoleReport byte = '6'
|
||||
)
|
||||
|
||||
var msgTypeMap = map[byte]any{
|
||||
|
||||
55
pkg/msg/msg_test.go
Normal file
55
pkg/msg/msg_test.go
Normal file
@@ -0,0 +1,55 @@
|
||||
// Copyright 2026 The frp Authors
|
||||
//
|
||||
// Licensed under the Apache License, Version 2.0 (the "License");
|
||||
// you may not use this file except in compliance with the License.
|
||||
// You may obtain a copy of the License at
|
||||
//
|
||||
// http://www.apache.org/licenses/LICENSE-2.0
|
||||
//
|
||||
// Unless required by applicable law or agreed to in writing, software
|
||||
// distributed under the License is distributed on an "AS IS" BASIS,
|
||||
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
// See the License for the specific language governing permissions and
|
||||
// limitations under the License.
|
||||
|
||||
package msg
|
||||
|
||||
import (
|
||||
"reflect"
|
||||
"testing"
|
||||
|
||||
"github.com/stretchr/testify/require"
|
||||
)
|
||||
|
||||
func TestV1MessageTypeIDsAreStable(t *testing.T) {
|
||||
require.Equal(t, byte('o'), TypeLogin)
|
||||
require.Equal(t, byte('1'), TypeLoginResp)
|
||||
require.Equal(t, byte('p'), TypeNewProxy)
|
||||
require.Equal(t, byte('2'), TypeNewProxyResp)
|
||||
require.Equal(t, byte('c'), TypeCloseProxy)
|
||||
require.Equal(t, byte('w'), TypeNewWorkConn)
|
||||
require.Equal(t, byte('r'), TypeReqWorkConn)
|
||||
require.Equal(t, byte('s'), TypeStartWorkConn)
|
||||
require.Equal(t, byte('v'), TypeNewVisitorConn)
|
||||
require.Equal(t, byte('3'), TypeNewVisitorConnResp)
|
||||
require.Equal(t, byte('h'), TypePing)
|
||||
require.Equal(t, byte('4'), TypePong)
|
||||
require.Equal(t, byte('u'), TypeUDPPacket)
|
||||
require.Equal(t, byte('i'), TypeNatHoleVisitor)
|
||||
require.Equal(t, byte('n'), TypeNatHoleClient)
|
||||
require.Equal(t, byte('m'), TypeNatHoleResp)
|
||||
require.Equal(t, byte('5'), TypeNatHoleSid)
|
||||
require.Equal(t, byte('6'), TypeNatHoleReport)
|
||||
}
|
||||
|
||||
func TestMessageTypeMapIsCompleteAndUnique(t *testing.T) {
|
||||
require.Len(t, msgTypeMap, 18)
|
||||
|
||||
msgTypes := make(map[reflect.Type]struct{}, len(msgTypeMap))
|
||||
|
||||
for _, m := range msgTypeMap {
|
||||
msgType := reflect.TypeOf(m)
|
||||
require.NotContains(t, msgTypes, msgType)
|
||||
msgTypes[msgType] = struct{}{}
|
||||
}
|
||||
}
|
||||
192
pkg/msg/wire_v2.go
Normal file
192
pkg/msg/wire_v2.go
Normal file
@@ -0,0 +1,192 @@
|
||||
// Copyright 2026 The frp Authors
|
||||
//
|
||||
// Licensed under the Apache License, Version 2.0 (the "License");
|
||||
// you may not use this file except in compliance with the License.
|
||||
// You may obtain a copy of the License at
|
||||
//
|
||||
// http://www.apache.org/licenses/LICENSE-2.0
|
||||
//
|
||||
// Unless required by applicable law or agreed to in writing, software
|
||||
// distributed under the License is distributed on an "AS IS" BASIS,
|
||||
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
// See the License for the specific language governing permissions and
|
||||
// limitations under the License.
|
||||
|
||||
package msg
|
||||
|
||||
import (
|
||||
"encoding/binary"
|
||||
"encoding/json"
|
||||
"fmt"
|
||||
"io"
|
||||
"reflect"
|
||||
|
||||
"github.com/fatedier/frp/pkg/proto/wire"
|
||||
)
|
||||
|
||||
const (
|
||||
V2TypeLogin uint16 = 1
|
||||
V2TypeLoginResp uint16 = 2
|
||||
V2TypeNewProxy uint16 = 3
|
||||
V2TypeNewProxyResp uint16 = 4
|
||||
V2TypeCloseProxy uint16 = 5
|
||||
V2TypeNewWorkConn uint16 = 6
|
||||
V2TypeReqWorkConn uint16 = 7
|
||||
V2TypeStartWorkConn uint16 = 8
|
||||
V2TypeNewVisitorConn uint16 = 9
|
||||
V2TypeNewVisitorConnResp uint16 = 10
|
||||
V2TypePing uint16 = 11
|
||||
V2TypePong uint16 = 12
|
||||
V2TypeUDPPacket uint16 = 13
|
||||
V2TypeNatHoleVisitor uint16 = 14
|
||||
V2TypeNatHoleClient uint16 = 15
|
||||
V2TypeNatHoleResp uint16 = 16
|
||||
V2TypeNatHoleSid uint16 = 17
|
||||
V2TypeNatHoleReport uint16 = 18
|
||||
)
|
||||
|
||||
var v2MsgTypeMap = map[uint16]any{
|
||||
V2TypeLogin: Login{},
|
||||
V2TypeLoginResp: LoginResp{},
|
||||
V2TypeNewProxy: NewProxy{},
|
||||
V2TypeNewProxyResp: NewProxyResp{},
|
||||
V2TypeCloseProxy: CloseProxy{},
|
||||
V2TypeNewWorkConn: NewWorkConn{},
|
||||
V2TypeReqWorkConn: ReqWorkConn{},
|
||||
V2TypeStartWorkConn: StartWorkConn{},
|
||||
V2TypeNewVisitorConn: NewVisitorConn{},
|
||||
V2TypeNewVisitorConnResp: NewVisitorConnResp{},
|
||||
V2TypePing: Ping{},
|
||||
V2TypePong: Pong{},
|
||||
V2TypeUDPPacket: UDPPacket{},
|
||||
V2TypeNatHoleVisitor: NatHoleVisitor{},
|
||||
V2TypeNatHoleClient: NatHoleClient{},
|
||||
V2TypeNatHoleResp: NatHoleResp{},
|
||||
V2TypeNatHoleSid: NatHoleSid{},
|
||||
V2TypeNatHoleReport: NatHoleReport{},
|
||||
}
|
||||
|
||||
var v2MsgReflectTypeMap, v2MsgTypeIDMap = buildV2MsgTypeMaps()
|
||||
|
||||
func buildV2MsgTypeMaps() (map[uint16]reflect.Type, map[reflect.Type]uint16) {
|
||||
reflectTypeMap := make(map[uint16]reflect.Type, len(v2MsgTypeMap))
|
||||
typeIDMap := make(map[reflect.Type]uint16, len(v2MsgTypeMap))
|
||||
for typeID, m := range v2MsgTypeMap {
|
||||
t := reflect.TypeOf(m)
|
||||
reflectTypeMap[typeID] = t
|
||||
typeIDMap[t] = typeID
|
||||
}
|
||||
return reflectTypeMap, typeIDMap
|
||||
}
|
||||
|
||||
type V2ReadWriter struct {
|
||||
conn *wire.Conn
|
||||
}
|
||||
|
||||
func NewV2ReadWriter(rw io.ReadWriter) *V2ReadWriter {
|
||||
return NewV2ReadWriterWithConn(wire.NewConn(rw))
|
||||
}
|
||||
|
||||
func NewV2ReadWriterWithConn(conn *wire.Conn) *V2ReadWriter {
|
||||
return &V2ReadWriter{conn: conn}
|
||||
}
|
||||
|
||||
func (rw *V2ReadWriter) WireConn() *wire.Conn {
|
||||
return rw.conn
|
||||
}
|
||||
|
||||
func (rw *V2ReadWriter) ReadMsg() (Message, error) {
|
||||
f, err := rw.conn.ReadFrame()
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
return DecodeV2MessageFrame(f)
|
||||
}
|
||||
|
||||
func (rw *V2ReadWriter) ReadMsgInto(m Message) error {
|
||||
f, err := rw.conn.ReadFrame()
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
return DecodeV2MessageFrameInto(f, m)
|
||||
}
|
||||
|
||||
func (rw *V2ReadWriter) WriteMsg(m Message) error {
|
||||
f, err := EncodeV2MessageFrame(m)
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
return rw.conn.WriteFrame(f)
|
||||
}
|
||||
|
||||
func DecodeV2MessageFrame(f *wire.Frame) (Message, error) {
|
||||
if f.Type != wire.FrameTypeMessage {
|
||||
return nil, fmt.Errorf("unexpected frame type %d, want %d", f.Type, wire.FrameTypeMessage)
|
||||
}
|
||||
if len(f.Payload) < 2 {
|
||||
return nil, fmt.Errorf("message frame payload too short")
|
||||
}
|
||||
typeID := binary.BigEndian.Uint16(f.Payload[:2])
|
||||
t, ok := v2MsgReflectTypeMap[typeID]
|
||||
if !ok {
|
||||
return nil, fmt.Errorf("unknown v2 message type %d", typeID)
|
||||
}
|
||||
m := reflect.New(t).Interface()
|
||||
if err := json.Unmarshal(f.Payload[2:], m); err != nil {
|
||||
return nil, err
|
||||
}
|
||||
return m, nil
|
||||
}
|
||||
|
||||
func DecodeV2MessageFrameInto(f *wire.Frame, out Message) error {
|
||||
if f.Type != wire.FrameTypeMessage {
|
||||
return fmt.Errorf("unexpected frame type %d, want %d", f.Type, wire.FrameTypeMessage)
|
||||
}
|
||||
if len(f.Payload) < 2 {
|
||||
return fmt.Errorf("message frame payload too short")
|
||||
}
|
||||
|
||||
typeID := binary.BigEndian.Uint16(f.Payload[:2])
|
||||
outType := reflect.TypeOf(out)
|
||||
if outType == nil || outType.Kind() != reflect.Pointer {
|
||||
return fmt.Errorf("message target must be a pointer")
|
||||
}
|
||||
elemType := outType.Elem()
|
||||
expectedTypeID, ok := v2MsgTypeIDMap[elemType]
|
||||
if !ok {
|
||||
return fmt.Errorf("unknown v2 message type %s", elemType.String())
|
||||
}
|
||||
if typeID != expectedTypeID {
|
||||
actualType, ok := v2MsgReflectTypeMap[typeID]
|
||||
if !ok {
|
||||
return fmt.Errorf("unknown v2 message type %d", typeID)
|
||||
}
|
||||
return fmt.Errorf("unexpected message type %s, want %s", actualType.String(), elemType.String())
|
||||
}
|
||||
return json.Unmarshal(f.Payload[2:], out)
|
||||
}
|
||||
|
||||
func EncodeV2MessageFrame(m Message) (*wire.Frame, error) {
|
||||
t := reflect.TypeOf(m)
|
||||
if t == nil {
|
||||
return nil, fmt.Errorf("nil message")
|
||||
}
|
||||
if t.Kind() == reflect.Pointer {
|
||||
t = t.Elem()
|
||||
}
|
||||
typeID, ok := v2MsgTypeIDMap[t]
|
||||
if !ok {
|
||||
return nil, fmt.Errorf("unknown v2 message type %s", t.String())
|
||||
}
|
||||
content, err := json.Marshal(m)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
payload := make([]byte, 2+len(content))
|
||||
binary.BigEndian.PutUint16(payload[:2], typeID)
|
||||
copy(payload[2:], content)
|
||||
return &wire.Frame{
|
||||
Type: wire.FrameTypeMessage,
|
||||
Payload: payload,
|
||||
}, nil
|
||||
}
|
||||
121
pkg/msg/wire_v2_test.go
Normal file
121
pkg/msg/wire_v2_test.go
Normal file
@@ -0,0 +1,121 @@
|
||||
// Copyright 2026 The frp Authors
|
||||
//
|
||||
// Licensed under the Apache License, Version 2.0 (the "License");
|
||||
// you may not use this file except in compliance with the License.
|
||||
// You may obtain a copy of the License at
|
||||
//
|
||||
// http://www.apache.org/licenses/LICENSE-2.0
|
||||
//
|
||||
// Unless required by applicable law or agreed to in writing, software
|
||||
// distributed under the License is distributed on an "AS IS" BASIS,
|
||||
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
// See the License for the specific language governing permissions and
|
||||
// limitations under the License.
|
||||
|
||||
package msg
|
||||
|
||||
import (
|
||||
"bytes"
|
||||
"encoding/binary"
|
||||
"testing"
|
||||
|
||||
"github.com/stretchr/testify/require"
|
||||
|
||||
"github.com/fatedier/frp/pkg/proto/wire"
|
||||
)
|
||||
|
||||
func TestV2ReadWriterRoundTrip(t *testing.T) {
|
||||
var buf bytes.Buffer
|
||||
rw := NewV2ReadWriter(&buf)
|
||||
|
||||
in := &Login{
|
||||
Version: "test-version",
|
||||
RunID: "run-id",
|
||||
User: "user",
|
||||
}
|
||||
require.NoError(t, rw.WriteMsg(in))
|
||||
|
||||
out, err := rw.ReadMsg()
|
||||
require.NoError(t, err)
|
||||
require.Equal(t, in, out)
|
||||
}
|
||||
|
||||
func TestNewReadWriter(t *testing.T) {
|
||||
require.IsType(t, &V1ReadWriter{}, NewReadWriter(&bytes.Buffer{}, ""))
|
||||
require.IsType(t, &V1ReadWriter{}, NewReadWriter(&bytes.Buffer{}, wire.ProtocolV1))
|
||||
require.IsType(t, &V2ReadWriter{}, NewReadWriter(&bytes.Buffer{}, wire.ProtocolV2))
|
||||
}
|
||||
|
||||
func TestV2MessageTypeIDsAreStable(t *testing.T) {
|
||||
require.Equal(t, uint16(1), V2TypeLogin)
|
||||
require.Equal(t, uint16(2), V2TypeLoginResp)
|
||||
require.Equal(t, uint16(3), V2TypeNewProxy)
|
||||
require.Equal(t, uint16(4), V2TypeNewProxyResp)
|
||||
require.Equal(t, uint16(5), V2TypeCloseProxy)
|
||||
require.Equal(t, uint16(6), V2TypeNewWorkConn)
|
||||
require.Equal(t, uint16(7), V2TypeReqWorkConn)
|
||||
require.Equal(t, uint16(8), V2TypeStartWorkConn)
|
||||
require.Equal(t, uint16(9), V2TypeNewVisitorConn)
|
||||
require.Equal(t, uint16(10), V2TypeNewVisitorConnResp)
|
||||
require.Equal(t, uint16(11), V2TypePing)
|
||||
require.Equal(t, uint16(12), V2TypePong)
|
||||
require.Equal(t, uint16(13), V2TypeUDPPacket)
|
||||
require.Equal(t, uint16(14), V2TypeNatHoleVisitor)
|
||||
require.Equal(t, uint16(15), V2TypeNatHoleClient)
|
||||
require.Equal(t, uint16(16), V2TypeNatHoleResp)
|
||||
require.Equal(t, uint16(17), V2TypeNatHoleSid)
|
||||
require.Equal(t, uint16(18), V2TypeNatHoleReport)
|
||||
}
|
||||
|
||||
func TestV2MessageFrameEncoding(t *testing.T) {
|
||||
frame, err := EncodeV2MessageFrame(&ReqWorkConn{})
|
||||
require.NoError(t, err)
|
||||
require.Equal(t, wire.FrameTypeMessage, frame.Type)
|
||||
require.Len(t, frame.Payload, 4)
|
||||
require.Equal(t, V2TypeReqWorkConn, binary.BigEndian.Uint16(frame.Payload[:2]))
|
||||
|
||||
out, err := DecodeV2MessageFrame(frame)
|
||||
require.NoError(t, err)
|
||||
require.IsType(t, &ReqWorkConn{}, out)
|
||||
}
|
||||
|
||||
func TestDecodeV2MessageFrameInto(t *testing.T) {
|
||||
in := &StartWorkConn{ProxyName: "tcp", SrcAddr: "127.0.0.1", SrcPort: 1234}
|
||||
frame, err := EncodeV2MessageFrame(in)
|
||||
require.NoError(t, err)
|
||||
|
||||
var out StartWorkConn
|
||||
require.NoError(t, DecodeV2MessageFrameInto(frame, &out))
|
||||
require.Equal(t, *in, out)
|
||||
}
|
||||
|
||||
func TestDecodeV2MessageFrameRejectsInvalidFrame(t *testing.T) {
|
||||
_, err := DecodeV2MessageFrame(&wire.Frame{Type: wire.FrameTypeClientHello})
|
||||
require.ErrorContains(t, err, "unexpected frame type")
|
||||
|
||||
_, err = DecodeV2MessageFrame(&wire.Frame{Type: wire.FrameTypeMessage, Payload: []byte{0}})
|
||||
require.ErrorContains(t, err, "payload too short")
|
||||
|
||||
payload := make([]byte, 4)
|
||||
binary.BigEndian.PutUint16(payload[:2], 65535)
|
||||
copy(payload[2:], []byte("{}"))
|
||||
_, err = DecodeV2MessageFrame(&wire.Frame{Type: wire.FrameTypeMessage, Payload: payload})
|
||||
require.ErrorContains(t, err, "unknown v2 message type")
|
||||
}
|
||||
|
||||
func TestDecodeV2MessageFrameIntoRejectsWrongTarget(t *testing.T) {
|
||||
frame, err := EncodeV2MessageFrame(&ReqWorkConn{})
|
||||
require.NoError(t, err)
|
||||
|
||||
var out StartWorkConn
|
||||
err = DecodeV2MessageFrameInto(frame, &out)
|
||||
require.ErrorContains(t, err, "unexpected message type")
|
||||
|
||||
err = DecodeV2MessageFrameInto(frame, StartWorkConn{})
|
||||
require.ErrorContains(t, err, "must be a pointer")
|
||||
}
|
||||
|
||||
func TestEncodeV2MessageFrameRejectsUnknownMessage(t *testing.T) {
|
||||
_, err := EncodeV2MessageFrame(struct{}{})
|
||||
require.ErrorContains(t, err, "unknown v2 message type")
|
||||
}
|
||||
@@ -21,6 +21,7 @@ import (
|
||||
"time"
|
||||
|
||||
"github.com/samber/lo"
|
||||
"k8s.io/utils/clock"
|
||||
)
|
||||
|
||||
var (
|
||||
@@ -144,19 +145,19 @@ func getBehaviorByModeAndIndex(mode int, index int) (RecommandBehavior, Recomman
|
||||
return behaviors[index].A, behaviors[index].B
|
||||
}
|
||||
|
||||
func getBehaviorScoresByMode(mode int, defaultScore int) []*BehaviorScore {
|
||||
func getBehaviorScoresByMode(mode int, defaultScore int) []*behaviorScore {
|
||||
return getBehaviorScoresByMode2(mode, defaultScore, defaultScore)
|
||||
}
|
||||
|
||||
func getBehaviorScoresByMode2(mode int, senderScore, receiverScore int) []*BehaviorScore {
|
||||
func getBehaviorScoresByMode2(mode int, senderScore, receiverScore int) []*behaviorScore {
|
||||
behaviors := getBehaviorByMode(mode)
|
||||
scores := make([]*BehaviorScore, 0, len(behaviors))
|
||||
scores := make([]*behaviorScore, 0, len(behaviors))
|
||||
for i := range behaviors {
|
||||
score := receiverScore
|
||||
if behaviors[i].A.Role == DetectRoleSender {
|
||||
score = senderScore
|
||||
}
|
||||
scores = append(scores, &BehaviorScore{Mode: mode, Index: i, Score: score})
|
||||
scores = append(scores, &behaviorScore{Mode: mode, Index: i, Score: score})
|
||||
}
|
||||
return scores
|
||||
}
|
||||
@@ -170,14 +171,18 @@ type RecommandBehavior struct {
|
||||
ListenRandomPorts int
|
||||
}
|
||||
|
||||
type MakeHoleRecords struct {
|
||||
type makeHoleRecords struct {
|
||||
mu sync.Mutex
|
||||
scores []*BehaviorScore
|
||||
LastUpdateTime time.Time
|
||||
scores []*behaviorScore
|
||||
clock clock.PassiveClock
|
||||
lastUpdateTime time.Time
|
||||
}
|
||||
|
||||
func NewMakeHoleRecords(c, v *NatFeature) *MakeHoleRecords {
|
||||
scores := []*BehaviorScore{}
|
||||
func newMakeHoleRecordsWithClock(c, v *NatFeature, clk clock.PassiveClock) *makeHoleRecords {
|
||||
if clk == nil {
|
||||
clk = clock.RealClock{}
|
||||
}
|
||||
scores := []*behaviorScore{}
|
||||
easyCount, hardCount, portsChangedRegularCount := ClassifyFeatureCount([]*NatFeature{c, v})
|
||||
appendMode0 := func() {
|
||||
switch {
|
||||
@@ -212,13 +217,17 @@ func NewMakeHoleRecords(c, v *NatFeature) *MakeHoleRecords {
|
||||
scores = append(scores, getBehaviorScoresByMode(DetectMode1, 1)...)
|
||||
scores = append(scores, getBehaviorScoresByMode(DetectMode3, 1)...)
|
||||
}
|
||||
return &MakeHoleRecords{scores: scores, LastUpdateTime: time.Now()}
|
||||
return &makeHoleRecords{
|
||||
scores: scores,
|
||||
clock: clk,
|
||||
lastUpdateTime: clk.Now(),
|
||||
}
|
||||
}
|
||||
|
||||
func (mhr *MakeHoleRecords) ReportSuccess(mode int, index int) {
|
||||
func (mhr *makeHoleRecords) reportSuccess(mode int, index int) {
|
||||
mhr.mu.Lock()
|
||||
defer mhr.mu.Unlock()
|
||||
mhr.LastUpdateTime = time.Now()
|
||||
mhr.lastUpdateTime = mhr.clock.Now()
|
||||
for i := range mhr.scores {
|
||||
score := mhr.scores[i]
|
||||
if score.Mode != mode || score.Index != index {
|
||||
@@ -231,22 +240,22 @@ func (mhr *MakeHoleRecords) ReportSuccess(mode int, index int) {
|
||||
}
|
||||
}
|
||||
|
||||
func (mhr *MakeHoleRecords) Recommand() (mode, index int) {
|
||||
func (mhr *makeHoleRecords) recommand() (mode, index int) {
|
||||
mhr.mu.Lock()
|
||||
defer mhr.mu.Unlock()
|
||||
|
||||
if len(mhr.scores) == 0 {
|
||||
return 0, 0
|
||||
}
|
||||
maxScore := slices.MaxFunc(mhr.scores, func(a, b *BehaviorScore) int {
|
||||
maxScore := slices.MaxFunc(mhr.scores, func(a, b *behaviorScore) int {
|
||||
return cmp.Compare(a.Score, b.Score)
|
||||
})
|
||||
maxScore.Score--
|
||||
mhr.LastUpdateTime = time.Now()
|
||||
mhr.lastUpdateTime = mhr.clock.Now()
|
||||
return maxScore.Mode, maxScore.Index
|
||||
}
|
||||
|
||||
type BehaviorScore struct {
|
||||
type behaviorScore struct {
|
||||
Mode int
|
||||
Index int
|
||||
// between -10 and 10
|
||||
@@ -255,16 +264,25 @@ type BehaviorScore struct {
|
||||
|
||||
type Analyzer struct {
|
||||
// key is client ip + visitor ip
|
||||
records map[string]*MakeHoleRecords
|
||||
records map[string]*makeHoleRecords
|
||||
dataReserveDuration time.Duration
|
||||
clock clock.PassiveClock
|
||||
|
||||
mu sync.Mutex
|
||||
}
|
||||
|
||||
func NewAnalyzer(dataReserveDuration time.Duration) *Analyzer {
|
||||
return newAnalyzerWithClock(dataReserveDuration, clock.RealClock{})
|
||||
}
|
||||
|
||||
func newAnalyzerWithClock(dataReserveDuration time.Duration, clk clock.PassiveClock) *Analyzer {
|
||||
if clk == nil {
|
||||
clk = clock.RealClock{}
|
||||
}
|
||||
return &Analyzer{
|
||||
records: make(map[string]*MakeHoleRecords),
|
||||
records: make(map[string]*makeHoleRecords),
|
||||
dataReserveDuration: dataReserveDuration,
|
||||
clock: clk,
|
||||
}
|
||||
}
|
||||
|
||||
@@ -272,12 +290,12 @@ func (a *Analyzer) GetRecommandBehaviors(key string, c, v *NatFeature) (mode, in
|
||||
a.mu.Lock()
|
||||
records, ok := a.records[key]
|
||||
if !ok {
|
||||
records = NewMakeHoleRecords(c, v)
|
||||
records = newMakeHoleRecordsWithClock(c, v, a.clock)
|
||||
a.records[key] = records
|
||||
}
|
||||
a.mu.Unlock()
|
||||
|
||||
mode, index = records.Recommand()
|
||||
mode, index = records.recommand()
|
||||
cBehavior, vBehavior := getBehaviorByModeAndIndex(mode, index)
|
||||
|
||||
switch mode {
|
||||
@@ -307,11 +325,11 @@ func (a *Analyzer) ReportSuccess(key string, mode, index int) {
|
||||
if !ok {
|
||||
return
|
||||
}
|
||||
records.ReportSuccess(mode, index)
|
||||
records.reportSuccess(mode, index)
|
||||
}
|
||||
|
||||
func (a *Analyzer) Clean() (int, int) {
|
||||
now := time.Now()
|
||||
now := a.clock.Now()
|
||||
total := 0
|
||||
count := 0
|
||||
|
||||
@@ -321,7 +339,7 @@ func (a *Analyzer) Clean() (int, int) {
|
||||
total = len(a.records)
|
||||
// clean up records that have not been used for a period of time.
|
||||
for key, records := range a.records {
|
||||
if now.Sub(records.LastUpdateTime) > a.dataReserveDuration {
|
||||
if now.Sub(records.lastUpdateTime) > a.dataReserveDuration {
|
||||
delete(a.records, key)
|
||||
count++
|
||||
}
|
||||
|
||||
33
pkg/nathole/analysis_test.go
Normal file
33
pkg/nathole/analysis_test.go
Normal file
@@ -0,0 +1,33 @@
|
||||
package nathole
|
||||
|
||||
import (
|
||||
"testing"
|
||||
"time"
|
||||
|
||||
"github.com/stretchr/testify/require"
|
||||
clocktesting "k8s.io/utils/clock/testing"
|
||||
)
|
||||
|
||||
func TestAnalyzerUsesClockForRecordTimestamps(t *testing.T) {
|
||||
require := require.New(t)
|
||||
|
||||
start := time.Date(2026, time.May, 8, 12, 30, 0, 0, time.UTC)
|
||||
clk := clocktesting.NewFakeClock(start)
|
||||
analyzer := newAnalyzerWithClock(time.Hour, clk)
|
||||
clientFeature := &NatFeature{NatType: EasyNAT, Behavior: BehaviorNoChange}
|
||||
visitorFeature := &NatFeature{NatType: EasyNAT, Behavior: BehaviorNoChange}
|
||||
|
||||
mode, index, _, _ := analyzer.GetRecommandBehaviors("key", clientFeature, visitorFeature)
|
||||
require.Equal(start, analyzer.records["key"].lastUpdateTime)
|
||||
|
||||
updatedAt := start.Add(time.Minute)
|
||||
clk.SetTime(updatedAt)
|
||||
analyzer.ReportSuccess("key", mode, index)
|
||||
require.Equal(updatedAt, analyzer.records["key"].lastUpdateTime)
|
||||
|
||||
clk.SetTime(start.Add(2 * time.Hour))
|
||||
count, total := analyzer.Clean()
|
||||
require.Equal(1, count)
|
||||
require.Equal(1, total)
|
||||
require.Empty(analyzer.records)
|
||||
}
|
||||
@@ -326,40 +326,16 @@ func (c *Controller) analysis(session *Session) (*msg.NatHoleResp, *msg.NatHoleR
|
||||
}
|
||||
|
||||
protocol := vm.Protocol
|
||||
vResp := &msg.NatHoleResp{
|
||||
TransactionID: vm.TransactionID,
|
||||
Sid: session.sid,
|
||||
Protocol: protocol,
|
||||
CandidateAddrs: slices.Compact(cm.MappedAddrs),
|
||||
AssistedAddrs: slices.Compact(cm.AssistedAddrs),
|
||||
DetectBehavior: msg.NatHoleDetectBehavior{
|
||||
Mode: mode,
|
||||
Role: vBehavior.Role,
|
||||
TTL: vBehavior.TTL,
|
||||
SendDelayMs: vBehavior.SendDelayMs,
|
||||
ReadTimeoutMs: timeoutMs - vBehavior.SendDelayMs,
|
||||
SendRandomPorts: vBehavior.PortsRandomNumber,
|
||||
ListenRandomPorts: vBehavior.ListenRandomPorts,
|
||||
CandidatePorts: getRangePorts(cm.MappedAddrs, cNatFeature.PortsDifference, vBehavior.PortsRangeNumber),
|
||||
},
|
||||
}
|
||||
cResp := &msg.NatHoleResp{
|
||||
TransactionID: cm.TransactionID,
|
||||
Sid: session.sid,
|
||||
Protocol: protocol,
|
||||
CandidateAddrs: slices.Compact(vm.MappedAddrs),
|
||||
AssistedAddrs: slices.Compact(vm.AssistedAddrs),
|
||||
DetectBehavior: msg.NatHoleDetectBehavior{
|
||||
Mode: mode,
|
||||
Role: cBehavior.Role,
|
||||
TTL: cBehavior.TTL,
|
||||
SendDelayMs: cBehavior.SendDelayMs,
|
||||
ReadTimeoutMs: timeoutMs - cBehavior.SendDelayMs,
|
||||
SendRandomPorts: cBehavior.PortsRandomNumber,
|
||||
ListenRandomPorts: cBehavior.ListenRandomPorts,
|
||||
CandidatePorts: getRangePorts(vm.MappedAddrs, vNatFeature.PortsDifference, cBehavior.PortsRangeNumber),
|
||||
},
|
||||
}
|
||||
vResp := newNatHoleResponse(
|
||||
vm.TransactionID, session.sid, protocol, mode,
|
||||
cm.MappedAddrs, cm.AssistedAddrs, vBehavior,
|
||||
timeoutMs-vBehavior.SendDelayMs, cNatFeature.PortsDifference,
|
||||
)
|
||||
cResp := newNatHoleResponse(
|
||||
cm.TransactionID, session.sid, protocol, mode,
|
||||
vm.MappedAddrs, vm.AssistedAddrs, cBehavior,
|
||||
timeoutMs-cBehavior.SendDelayMs, vNatFeature.PortsDifference,
|
||||
)
|
||||
|
||||
log.Debugf("sid [%s] visitor nat: %+v, candidateAddrs: %v; client nat: %+v, candidateAddrs: %v, protocol: %s",
|
||||
session.sid, *vNatFeature, vm.MappedAddrs, *cNatFeature, cm.MappedAddrs, protocol)
|
||||
@@ -368,6 +344,38 @@ func (c *Controller) analysis(session *Session) (*msg.NatHoleResp, *msg.NatHoleR
|
||||
return vResp, cResp, nil
|
||||
}
|
||||
|
||||
func newNatHoleResponse(
|
||||
transactionID string,
|
||||
sid string,
|
||||
protocol string,
|
||||
mode int,
|
||||
candidateAddrs []string,
|
||||
assistedAddrs []string,
|
||||
behavior RecommandBehavior,
|
||||
readTimeoutMs int,
|
||||
portsDifference int,
|
||||
) *msg.NatHoleResp {
|
||||
compactCandidateAddrs := slices.Compact(candidateAddrs)
|
||||
compactAssistedAddrs := slices.Compact(assistedAddrs)
|
||||
return &msg.NatHoleResp{
|
||||
TransactionID: transactionID,
|
||||
Sid: sid,
|
||||
Protocol: protocol,
|
||||
CandidateAddrs: compactCandidateAddrs,
|
||||
AssistedAddrs: compactAssistedAddrs,
|
||||
DetectBehavior: msg.NatHoleDetectBehavior{
|
||||
Mode: mode,
|
||||
Role: behavior.Role,
|
||||
TTL: behavior.TTL,
|
||||
SendDelayMs: behavior.SendDelayMs,
|
||||
ReadTimeoutMs: readTimeoutMs,
|
||||
SendRandomPorts: behavior.PortsRandomNumber,
|
||||
ListenRandomPorts: behavior.ListenRandomPorts,
|
||||
CandidatePorts: getRangePorts(candidateAddrs, portsDifference, behavior.PortsRangeNumber),
|
||||
},
|
||||
}
|
||||
}
|
||||
|
||||
func getRangePorts(addrs []string, difference, maxNumber int) []msg.PortsRange {
|
||||
if maxNumber <= 0 {
|
||||
return nil
|
||||
|
||||
@@ -17,17 +17,9 @@
|
||||
package client
|
||||
|
||||
import (
|
||||
"context"
|
||||
stdlog "log"
|
||||
"net/http"
|
||||
"net/http/httputil"
|
||||
"time"
|
||||
|
||||
"github.com/fatedier/golib/pool"
|
||||
|
||||
v1 "github.com/fatedier/frp/pkg/config/v1"
|
||||
"github.com/fatedier/frp/pkg/util/log"
|
||||
netpkg "github.com/fatedier/frp/pkg/util/net"
|
||||
)
|
||||
|
||||
func init() {
|
||||
@@ -37,57 +29,28 @@ func init() {
|
||||
type HTTP2HTTPPlugin struct {
|
||||
opts *v1.HTTP2HTTPPluginOptions
|
||||
|
||||
l *Listener
|
||||
s *http.Server
|
||||
*httpBridgePlugin
|
||||
}
|
||||
|
||||
func NewHTTP2HTTPPlugin(_ PluginContext, options v1.ClientPluginOptions) (Plugin, error) {
|
||||
opts := options.(*v1.HTTP2HTTPPluginOptions)
|
||||
|
||||
listener := NewProxyListener()
|
||||
|
||||
p := &HTTP2HTTPPlugin{
|
||||
opts: opts,
|
||||
l: listener,
|
||||
}
|
||||
|
||||
rp := &httputil.ReverseProxy{
|
||||
Rewrite: func(r *httputil.ProxyRequest) {
|
||||
rp := newHTTPBridgeReverseProxy(
|
||||
func(r *httputil.ProxyRequest) {
|
||||
req := r.Out
|
||||
req.URL.Scheme = "http"
|
||||
req.URL.Host = p.opts.LocalAddr
|
||||
if p.opts.HostHeaderRewrite != "" {
|
||||
req.Host = p.opts.HostHeaderRewrite
|
||||
}
|
||||
for k, v := range p.opts.RequestHeaders.Set {
|
||||
req.Header.Set(k, v)
|
||||
}
|
||||
rewriteHTTPPluginRequest(req, "http", p.opts.LocalAddr, p.opts.HostHeaderRewrite, p.opts.RequestHeaders)
|
||||
},
|
||||
BufferPool: pool.NewBuffer(32 * 1024),
|
||||
ErrorLog: stdlog.New(log.NewWriteLogger(log.WarnLevel, 2), "", 0),
|
||||
}
|
||||
|
||||
p.s = &http.Server{
|
||||
Handler: rp,
|
||||
ReadHeaderTimeout: 60 * time.Second,
|
||||
}
|
||||
|
||||
go func() {
|
||||
_ = p.s.Serve(listener)
|
||||
}()
|
||||
nil,
|
||||
)
|
||||
p.httpBridgePlugin = newHTTPBridgePluginServer(rp, false)
|
||||
|
||||
return p, nil
|
||||
}
|
||||
|
||||
func (p *HTTP2HTTPPlugin) Handle(_ context.Context, connInfo *ConnectionInfo) {
|
||||
wrapConn := netpkg.WrapReadWriteCloserToConn(connInfo.Conn, connInfo.UnderlyingConn)
|
||||
_ = p.l.PutConn(wrapConn)
|
||||
}
|
||||
|
||||
func (p *HTTP2HTTPPlugin) Name() string {
|
||||
return v1.PluginHTTP2HTTP
|
||||
}
|
||||
|
||||
func (p *HTTP2HTTPPlugin) Close() error {
|
||||
return p.s.Close()
|
||||
}
|
||||
|
||||
@@ -17,18 +17,11 @@
|
||||
package client
|
||||
|
||||
import (
|
||||
"context"
|
||||
"crypto/tls"
|
||||
stdlog "log"
|
||||
"net/http"
|
||||
"net/http/httputil"
|
||||
"time"
|
||||
|
||||
"github.com/fatedier/golib/pool"
|
||||
|
||||
v1 "github.com/fatedier/frp/pkg/config/v1"
|
||||
"github.com/fatedier/frp/pkg/util/log"
|
||||
netpkg "github.com/fatedier/frp/pkg/util/net"
|
||||
)
|
||||
|
||||
func init() {
|
||||
@@ -38,65 +31,35 @@ func init() {
|
||||
type HTTP2HTTPSPlugin struct {
|
||||
opts *v1.HTTP2HTTPSPluginOptions
|
||||
|
||||
l *Listener
|
||||
s *http.Server
|
||||
*httpBridgePlugin
|
||||
}
|
||||
|
||||
func NewHTTP2HTTPSPlugin(_ PluginContext, options v1.ClientPluginOptions) (Plugin, error) {
|
||||
opts := options.(*v1.HTTP2HTTPSPluginOptions)
|
||||
|
||||
listener := NewProxyListener()
|
||||
|
||||
p := &HTTP2HTTPSPlugin{
|
||||
opts: opts,
|
||||
l: listener,
|
||||
}
|
||||
|
||||
tr := &http.Transport{
|
||||
TLSClientConfig: &tls.Config{InsecureSkipVerify: true},
|
||||
}
|
||||
|
||||
rp := &httputil.ReverseProxy{
|
||||
Rewrite: func(r *httputil.ProxyRequest) {
|
||||
rp := newHTTPBridgeReverseProxy(
|
||||
func(r *httputil.ProxyRequest) {
|
||||
r.Out.Header["X-Forwarded-For"] = r.In.Header["X-Forwarded-For"]
|
||||
r.Out.Header["X-Forwarded-Host"] = r.In.Header["X-Forwarded-Host"]
|
||||
r.Out.Header["X-Forwarded-Proto"] = r.In.Header["X-Forwarded-Proto"]
|
||||
req := r.Out
|
||||
req.URL.Scheme = "https"
|
||||
req.URL.Host = p.opts.LocalAddr
|
||||
if p.opts.HostHeaderRewrite != "" {
|
||||
req.Host = p.opts.HostHeaderRewrite
|
||||
}
|
||||
for k, v := range p.opts.RequestHeaders.Set {
|
||||
req.Header.Set(k, v)
|
||||
}
|
||||
rewriteHTTPPluginRequest(req, "https", p.opts.LocalAddr, p.opts.HostHeaderRewrite, p.opts.RequestHeaders)
|
||||
},
|
||||
Transport: tr,
|
||||
BufferPool: pool.NewBuffer(32 * 1024),
|
||||
ErrorLog: stdlog.New(log.NewWriteLogger(log.WarnLevel, 2), "", 0),
|
||||
}
|
||||
|
||||
p.s = &http.Server{
|
||||
Handler: rp,
|
||||
ReadHeaderTimeout: 60 * time.Second,
|
||||
}
|
||||
|
||||
go func() {
|
||||
_ = p.s.Serve(listener)
|
||||
}()
|
||||
tr,
|
||||
)
|
||||
p.httpBridgePlugin = newHTTPBridgePluginServer(rp, false)
|
||||
|
||||
return p, nil
|
||||
}
|
||||
|
||||
func (p *HTTP2HTTPSPlugin) Handle(_ context.Context, connInfo *ConnectionInfo) {
|
||||
wrapConn := netpkg.WrapReadWriteCloserToConn(connInfo.Conn, connInfo.UnderlyingConn)
|
||||
_ = p.l.PutConn(wrapConn)
|
||||
}
|
||||
|
||||
func (p *HTTP2HTTPSPlugin) Name() string {
|
||||
return v1.PluginHTTP2HTTPS
|
||||
}
|
||||
|
||||
func (p *HTTP2HTTPSPlugin) Close() error {
|
||||
return p.s.Close()
|
||||
}
|
||||
|
||||
126
pkg/plugin/client/http_common.go
Normal file
126
pkg/plugin/client/http_common.go
Normal file
@@ -0,0 +1,126 @@
|
||||
// Copyright 2026 The frp Authors
|
||||
//
|
||||
// Licensed under the Apache License, Version 2.0 (the "License");
|
||||
// you may not use this file except in compliance with the License.
|
||||
// You may obtain a copy of the License at
|
||||
//
|
||||
// http://www.apache.org/licenses/LICENSE-2.0
|
||||
//
|
||||
// Unless required by applicable law or agreed to in writing, software
|
||||
// distributed under the License is distributed on an "AS IS" BASIS,
|
||||
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
// See the License for the specific language governing permissions and
|
||||
// limitations under the License.
|
||||
|
||||
//go:build !frps
|
||||
|
||||
package client
|
||||
|
||||
import (
|
||||
"context"
|
||||
stdlog "log"
|
||||
"net/http"
|
||||
"net/http/httputil"
|
||||
"time"
|
||||
|
||||
"github.com/fatedier/golib/pool"
|
||||
|
||||
v1 "github.com/fatedier/frp/pkg/config/v1"
|
||||
"github.com/fatedier/frp/pkg/plugin/client/internal/httpsserver"
|
||||
"github.com/fatedier/frp/pkg/util/log"
|
||||
netpkg "github.com/fatedier/frp/pkg/util/net"
|
||||
)
|
||||
|
||||
const httpBridgeReadHeaderTimeout = 60 * time.Second
|
||||
|
||||
func rewriteHTTPPluginRequest(
|
||||
req *http.Request,
|
||||
scheme string,
|
||||
localAddr string,
|
||||
hostHeaderRewrite string,
|
||||
requestHeaders v1.HeaderOperations,
|
||||
) {
|
||||
req.URL.Scheme = scheme
|
||||
req.URL.Host = localAddr
|
||||
if hostHeaderRewrite != "" {
|
||||
req.Host = hostHeaderRewrite
|
||||
}
|
||||
for k, v := range requestHeaders.Set {
|
||||
req.Header.Set(k, v)
|
||||
}
|
||||
}
|
||||
|
||||
type httpBridgePlugin struct {
|
||||
l *Listener
|
||||
s *http.Server
|
||||
|
||||
useSourceRemoteAddr bool
|
||||
}
|
||||
|
||||
func newHTTPBridgePluginServer(handler http.Handler, useSourceRemoteAddr bool) *httpBridgePlugin {
|
||||
listener := NewProxyListener()
|
||||
p := &httpBridgePlugin{
|
||||
l: listener,
|
||||
useSourceRemoteAddr: useSourceRemoteAddr,
|
||||
}
|
||||
p.s = &http.Server{
|
||||
Handler: handler,
|
||||
ReadHeaderTimeout: httpBridgeReadHeaderTimeout,
|
||||
}
|
||||
go func() {
|
||||
_ = p.s.Serve(listener)
|
||||
}()
|
||||
return p
|
||||
}
|
||||
|
||||
func newHTTPSBridgePluginServer(
|
||||
handler http.Handler,
|
||||
crtPath string,
|
||||
keyPath string,
|
||||
enableHTTP2 *bool,
|
||||
useSourceRemoteAddr bool,
|
||||
) (*httpBridgePlugin, error) {
|
||||
listener := NewProxyListener()
|
||||
server, err := httpsserver.New(handler, crtPath, keyPath, enableHTTP2)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
p := &httpBridgePlugin{
|
||||
l: listener,
|
||||
s: server,
|
||||
useSourceRemoteAddr: useSourceRemoteAddr,
|
||||
}
|
||||
go func() {
|
||||
_ = p.s.ServeTLS(listener, "", "")
|
||||
}()
|
||||
return p, nil
|
||||
}
|
||||
|
||||
func newHTTPBridgeReverseProxy(
|
||||
rewrite func(*httputil.ProxyRequest),
|
||||
transport http.RoundTripper,
|
||||
) *httputil.ReverseProxy {
|
||||
rp := &httputil.ReverseProxy{
|
||||
Rewrite: rewrite,
|
||||
BufferPool: pool.NewBuffer(32 * 1024),
|
||||
ErrorLog: stdlog.New(log.NewWriteLogger(log.WarnLevel, 2), "", 0),
|
||||
}
|
||||
if transport != nil {
|
||||
rp.Transport = transport
|
||||
}
|
||||
return rp
|
||||
}
|
||||
|
||||
func (p *httpBridgePlugin) Handle(_ context.Context, connInfo *ConnectionInfo) {
|
||||
wrapConn := netpkg.WrapReadWriteCloserToConn(connInfo.Conn, connInfo.UnderlyingConn)
|
||||
if p.useSourceRemoteAddr && connInfo.SrcAddr != nil {
|
||||
wrapConn.SetRemoteAddr(connInfo.SrcAddr)
|
||||
}
|
||||
_ = p.l.PutConn(wrapConn)
|
||||
}
|
||||
|
||||
func (p *httpBridgePlugin) Close() error {
|
||||
err := p.s.Close()
|
||||
_ = p.l.Close()
|
||||
return err
|
||||
}
|
||||
@@ -45,6 +45,8 @@ type HTTPProxy struct {
|
||||
s *http.Server
|
||||
}
|
||||
|
||||
const httpProxyReadHeaderTimeout = 60 * time.Second
|
||||
|
||||
func NewHTTPProxyPlugin(_ PluginContext, options v1.ClientPluginOptions) (Plugin, error) {
|
||||
opts := options.(*v1.HTTPProxyPluginOptions)
|
||||
listener := NewProxyListener()
|
||||
@@ -56,7 +58,7 @@ func NewHTTPProxyPlugin(_ PluginContext, options v1.ClientPluginOptions) (Plugin
|
||||
|
||||
hp.s = &http.Server{
|
||||
Handler: hp,
|
||||
ReadHeaderTimeout: 60 * time.Second,
|
||||
ReadHeaderTimeout: httpProxyReadHeaderTimeout,
|
||||
}
|
||||
|
||||
go func() {
|
||||
@@ -73,16 +75,19 @@ func (hp *HTTPProxy) Handle(_ context.Context, connInfo *ConnectionInfo) {
|
||||
wrapConn := netpkg.WrapReadWriteCloserToConn(connInfo.Conn, connInfo.UnderlyingConn)
|
||||
|
||||
sc, rd := libnet.NewSharedConn(wrapConn)
|
||||
firstBytes := make([]byte, 7)
|
||||
_, err := rd.Read(firstBytes)
|
||||
firstBytes := make([]byte, len(http.MethodConnect))
|
||||
_ = wrapConn.SetReadDeadline(time.Now().Add(httpProxyReadHeaderTimeout))
|
||||
_, err := io.ReadFull(rd, firstBytes)
|
||||
if err != nil {
|
||||
_ = wrapConn.SetReadDeadline(time.Time{})
|
||||
wrapConn.Close()
|
||||
return
|
||||
}
|
||||
|
||||
if strings.ToUpper(string(firstBytes)) == "CONNECT" {
|
||||
if strings.EqualFold(string(firstBytes), http.MethodConnect) {
|
||||
bufRd := bufio.NewReader(sc)
|
||||
request, err := http.ReadRequest(bufRd)
|
||||
_ = wrapConn.SetReadDeadline(time.Time{})
|
||||
if err != nil {
|
||||
wrapConn.Close()
|
||||
return
|
||||
@@ -91,6 +96,7 @@ func (hp *HTTPProxy) Handle(_ context.Context, connInfo *ConnectionInfo) {
|
||||
return
|
||||
}
|
||||
|
||||
_ = wrapConn.SetReadDeadline(time.Time{})
|
||||
_ = hp.l.PutConn(sc)
|
||||
}
|
||||
|
||||
@@ -107,13 +113,7 @@ func (hp *HTTPProxy) ServeHTTP(rw http.ResponseWriter, req *http.Request) {
|
||||
return
|
||||
}
|
||||
|
||||
if req.Method == http.MethodConnect {
|
||||
// deprecated
|
||||
// Connect request is handled in Handle function.
|
||||
hp.ConnectHandler(rw, req)
|
||||
} else {
|
||||
hp.HTTPHandler(rw, req)
|
||||
}
|
||||
hp.HTTPHandler(rw, req)
|
||||
}
|
||||
|
||||
func (hp *HTTPProxy) HTTPHandler(rw http.ResponseWriter, req *http.Request) {
|
||||
@@ -135,33 +135,6 @@ func (hp *HTTPProxy) HTTPHandler(rw http.ResponseWriter, req *http.Request) {
|
||||
}
|
||||
}
|
||||
|
||||
// deprecated
|
||||
// Hijack needs to SetReadDeadline on the Conn of the request, but if we use stream compression here,
|
||||
// we may always get i/o timeout error.
|
||||
func (hp *HTTPProxy) ConnectHandler(rw http.ResponseWriter, req *http.Request) {
|
||||
hj, ok := rw.(http.Hijacker)
|
||||
if !ok {
|
||||
rw.WriteHeader(http.StatusInternalServerError)
|
||||
return
|
||||
}
|
||||
|
||||
client, _, err := hj.Hijack()
|
||||
if err != nil {
|
||||
rw.WriteHeader(http.StatusInternalServerError)
|
||||
return
|
||||
}
|
||||
|
||||
remote, err := net.Dial("tcp", req.URL.Host)
|
||||
if err != nil {
|
||||
http.Error(rw, "Failed", http.StatusBadRequest)
|
||||
client.Close()
|
||||
return
|
||||
}
|
||||
_, _ = client.Write([]byte("HTTP/1.1 200 OK\r\n\r\n"))
|
||||
|
||||
go libio.Join(remote, client)
|
||||
}
|
||||
|
||||
func (hp *HTTPProxy) Auth(req *http.Request) bool {
|
||||
if hp.opts.HTTPUser == "" && hp.opts.HTTPPassword == "" {
|
||||
return true
|
||||
|
||||
107
pkg/plugin/client/http_proxy_test.go
Normal file
107
pkg/plugin/client/http_proxy_test.go
Normal file
@@ -0,0 +1,107 @@
|
||||
// Copyright 2026 The frp Authors
|
||||
//
|
||||
// Licensed under the Apache License, Version 2.0 (the "License");
|
||||
// you may not use this file except in compliance with the License.
|
||||
// You may obtain a copy of the License at
|
||||
//
|
||||
// http://www.apache.org/licenses/LICENSE-2.0
|
||||
//
|
||||
// Unless required by applicable law or agreed to in writing, software
|
||||
// distributed under the License is distributed on an "AS IS" BASIS,
|
||||
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
// See the License for the specific language governing permissions and
|
||||
// limitations under the License.
|
||||
|
||||
//go:build !frps
|
||||
|
||||
package client
|
||||
|
||||
import (
|
||||
"bufio"
|
||||
"context"
|
||||
"fmt"
|
||||
"io"
|
||||
"net"
|
||||
"testing"
|
||||
"time"
|
||||
|
||||
"github.com/stretchr/testify/require"
|
||||
|
||||
v1 "github.com/fatedier/frp/pkg/config/v1"
|
||||
)
|
||||
|
||||
func TestHTTPProxyHandleFragmentedConnectMethod(t *testing.T) {
|
||||
require := require.New(t)
|
||||
|
||||
ln, err := net.Listen("tcp", "127.0.0.1:0")
|
||||
require.NoError(err)
|
||||
defer ln.Close()
|
||||
|
||||
const payload = "ping"
|
||||
echoErr := make(chan error, 1)
|
||||
go func() {
|
||||
conn, err := ln.Accept()
|
||||
if err != nil {
|
||||
echoErr <- err
|
||||
return
|
||||
}
|
||||
defer conn.Close()
|
||||
|
||||
buf := make([]byte, len(payload))
|
||||
if _, err = io.ReadFull(conn, buf); err != nil {
|
||||
echoErr <- err
|
||||
return
|
||||
}
|
||||
if string(buf) != payload {
|
||||
echoErr <- fmt.Errorf("unexpected payload %q", string(buf))
|
||||
return
|
||||
}
|
||||
_, err = conn.Write([]byte("echo:" + payload))
|
||||
echoErr <- err
|
||||
}()
|
||||
|
||||
hp := &HTTPProxy{
|
||||
opts: &v1.HTTPProxyPluginOptions{},
|
||||
l: NewProxyListener(),
|
||||
}
|
||||
|
||||
clientConn, serverConn := net.Pipe()
|
||||
defer clientConn.Close()
|
||||
|
||||
go hp.Handle(context.Background(), &ConnectionInfo{
|
||||
Conn: serverConn,
|
||||
UnderlyingConn: serverConn,
|
||||
})
|
||||
|
||||
require.NoError(clientConn.SetDeadline(time.Now().Add(5 * time.Second)))
|
||||
|
||||
targetAddr := ln.Addr().String()
|
||||
req := "CONNECT " + targetAddr + " HTTP/1.1\r\nHost: " + targetAddr + "\r\n\r\n"
|
||||
_, err = clientConn.Write([]byte("CON"))
|
||||
require.NoError(err)
|
||||
_, err = clientConn.Write([]byte(req[len("CON"):]))
|
||||
require.NoError(err)
|
||||
|
||||
rd := bufio.NewReader(clientConn)
|
||||
status, err := rd.ReadString('\n')
|
||||
require.NoError(err)
|
||||
require.Equal("HTTP/1.1 200 OK\r\n", status)
|
||||
line, err := rd.ReadString('\n')
|
||||
require.NoError(err)
|
||||
require.Equal("\r\n", line)
|
||||
|
||||
_, err = clientConn.Write([]byte(payload))
|
||||
require.NoError(err)
|
||||
|
||||
got := make([]byte, len("echo:"+payload))
|
||||
_, err = io.ReadFull(rd, got)
|
||||
require.NoError(err)
|
||||
require.Equal("echo:"+payload, string(got))
|
||||
|
||||
select {
|
||||
case err := <-echoErr:
|
||||
require.NoError(err)
|
||||
case <-time.After(time.Second):
|
||||
t.Fatal("timed out waiting for echo server")
|
||||
}
|
||||
}
|
||||
@@ -17,22 +17,9 @@
|
||||
package client
|
||||
|
||||
import (
|
||||
"context"
|
||||
"crypto/tls"
|
||||
"fmt"
|
||||
stdlog "log"
|
||||
"net/http"
|
||||
"net/http/httputil"
|
||||
"time"
|
||||
|
||||
"github.com/fatedier/golib/pool"
|
||||
"github.com/samber/lo"
|
||||
|
||||
v1 "github.com/fatedier/frp/pkg/config/v1"
|
||||
"github.com/fatedier/frp/pkg/transport"
|
||||
httppkg "github.com/fatedier/frp/pkg/util/http"
|
||||
"github.com/fatedier/frp/pkg/util/log"
|
||||
netpkg "github.com/fatedier/frp/pkg/util/net"
|
||||
)
|
||||
|
||||
func init() {
|
||||
@@ -42,80 +29,35 @@ func init() {
|
||||
type HTTPS2HTTPPlugin struct {
|
||||
opts *v1.HTTPS2HTTPPluginOptions
|
||||
|
||||
l *Listener
|
||||
s *http.Server
|
||||
*httpBridgePlugin
|
||||
}
|
||||
|
||||
func NewHTTPS2HTTPPlugin(_ PluginContext, options v1.ClientPluginOptions) (Plugin, error) {
|
||||
opts := options.(*v1.HTTPS2HTTPPluginOptions)
|
||||
listener := NewProxyListener()
|
||||
|
||||
p := &HTTPS2HTTPPlugin{
|
||||
opts: opts,
|
||||
l: listener,
|
||||
}
|
||||
|
||||
rp := &httputil.ReverseProxy{
|
||||
Rewrite: func(r *httputil.ProxyRequest) {
|
||||
rp := newHTTPBridgeReverseProxy(
|
||||
func(r *httputil.ProxyRequest) {
|
||||
r.Out.Header["X-Forwarded-For"] = r.In.Header["X-Forwarded-For"]
|
||||
r.SetXForwarded()
|
||||
req := r.Out
|
||||
req.URL.Scheme = "http"
|
||||
req.URL.Host = p.opts.LocalAddr
|
||||
if p.opts.HostHeaderRewrite != "" {
|
||||
req.Host = p.opts.HostHeaderRewrite
|
||||
}
|
||||
for k, v := range p.opts.RequestHeaders.Set {
|
||||
req.Header.Set(k, v)
|
||||
}
|
||||
rewriteHTTPPluginRequest(req, "http", p.opts.LocalAddr, p.opts.HostHeaderRewrite, p.opts.RequestHeaders)
|
||||
},
|
||||
BufferPool: pool.NewBuffer(32 * 1024),
|
||||
ErrorLog: stdlog.New(log.NewWriteLogger(log.WarnLevel, 2), "", 0),
|
||||
}
|
||||
handler := http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
|
||||
if r.TLS != nil {
|
||||
tlsServerName, _ := httppkg.CanonicalHost(r.TLS.ServerName)
|
||||
host, _ := httppkg.CanonicalHost(r.Host)
|
||||
if tlsServerName != "" && tlsServerName != host {
|
||||
w.WriteHeader(http.StatusMisdirectedRequest)
|
||||
return
|
||||
}
|
||||
}
|
||||
rp.ServeHTTP(w, r)
|
||||
})
|
||||
nil,
|
||||
)
|
||||
|
||||
tlsConfig, err := transport.NewServerTLSConfig(p.opts.CrtPath, p.opts.KeyPath, "")
|
||||
server, err := newHTTPSBridgePluginServer(rp, p.opts.CrtPath, p.opts.KeyPath, opts.EnableHTTP2, true)
|
||||
if err != nil {
|
||||
return nil, fmt.Errorf("gen TLS config error: %v", err)
|
||||
return nil, err
|
||||
}
|
||||
p.httpBridgePlugin = server
|
||||
|
||||
p.s = &http.Server{
|
||||
Handler: handler,
|
||||
ReadHeaderTimeout: 60 * time.Second,
|
||||
TLSConfig: tlsConfig,
|
||||
}
|
||||
if !lo.FromPtr(opts.EnableHTTP2) {
|
||||
p.s.TLSNextProto = make(map[string]func(*http.Server, *tls.Conn, http.Handler))
|
||||
}
|
||||
|
||||
go func() {
|
||||
_ = p.s.ServeTLS(listener, "", "")
|
||||
}()
|
||||
return p, nil
|
||||
}
|
||||
|
||||
func (p *HTTPS2HTTPPlugin) Handle(_ context.Context, connInfo *ConnectionInfo) {
|
||||
wrapConn := netpkg.WrapReadWriteCloserToConn(connInfo.Conn, connInfo.UnderlyingConn)
|
||||
if connInfo.SrcAddr != nil {
|
||||
wrapConn.SetRemoteAddr(connInfo.SrcAddr)
|
||||
}
|
||||
_ = p.l.PutConn(wrapConn)
|
||||
}
|
||||
|
||||
func (p *HTTPS2HTTPPlugin) Name() string {
|
||||
return v1.PluginHTTPS2HTTP
|
||||
}
|
||||
|
||||
func (p *HTTPS2HTTPPlugin) Close() error {
|
||||
return p.s.Close()
|
||||
}
|
||||
|
||||
@@ -17,22 +17,11 @@
|
||||
package client
|
||||
|
||||
import (
|
||||
"context"
|
||||
"crypto/tls"
|
||||
"fmt"
|
||||
stdlog "log"
|
||||
"net/http"
|
||||
"net/http/httputil"
|
||||
"time"
|
||||
|
||||
"github.com/fatedier/golib/pool"
|
||||
"github.com/samber/lo"
|
||||
|
||||
v1 "github.com/fatedier/frp/pkg/config/v1"
|
||||
"github.com/fatedier/frp/pkg/transport"
|
||||
httppkg "github.com/fatedier/frp/pkg/util/http"
|
||||
"github.com/fatedier/frp/pkg/util/log"
|
||||
netpkg "github.com/fatedier/frp/pkg/util/net"
|
||||
)
|
||||
|
||||
func init() {
|
||||
@@ -42,86 +31,39 @@ func init() {
|
||||
type HTTPS2HTTPSPlugin struct {
|
||||
opts *v1.HTTPS2HTTPSPluginOptions
|
||||
|
||||
l *Listener
|
||||
s *http.Server
|
||||
*httpBridgePlugin
|
||||
}
|
||||
|
||||
func NewHTTPS2HTTPSPlugin(_ PluginContext, options v1.ClientPluginOptions) (Plugin, error) {
|
||||
opts := options.(*v1.HTTPS2HTTPSPluginOptions)
|
||||
|
||||
listener := NewProxyListener()
|
||||
|
||||
p := &HTTPS2HTTPSPlugin{
|
||||
opts: opts,
|
||||
l: listener,
|
||||
}
|
||||
|
||||
tr := &http.Transport{
|
||||
TLSClientConfig: &tls.Config{InsecureSkipVerify: true},
|
||||
}
|
||||
|
||||
rp := &httputil.ReverseProxy{
|
||||
Rewrite: func(r *httputil.ProxyRequest) {
|
||||
rp := newHTTPBridgeReverseProxy(
|
||||
func(r *httputil.ProxyRequest) {
|
||||
r.Out.Header["X-Forwarded-For"] = r.In.Header["X-Forwarded-For"]
|
||||
r.SetXForwarded()
|
||||
req := r.Out
|
||||
req.URL.Scheme = "https"
|
||||
req.URL.Host = p.opts.LocalAddr
|
||||
if p.opts.HostHeaderRewrite != "" {
|
||||
req.Host = p.opts.HostHeaderRewrite
|
||||
}
|
||||
for k, v := range p.opts.RequestHeaders.Set {
|
||||
req.Header.Set(k, v)
|
||||
}
|
||||
rewriteHTTPPluginRequest(req, "https", p.opts.LocalAddr, p.opts.HostHeaderRewrite, p.opts.RequestHeaders)
|
||||
},
|
||||
Transport: tr,
|
||||
BufferPool: pool.NewBuffer(32 * 1024),
|
||||
ErrorLog: stdlog.New(log.NewWriteLogger(log.WarnLevel, 2), "", 0),
|
||||
}
|
||||
handler := http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
|
||||
if r.TLS != nil {
|
||||
tlsServerName, _ := httppkg.CanonicalHost(r.TLS.ServerName)
|
||||
host, _ := httppkg.CanonicalHost(r.Host)
|
||||
if tlsServerName != "" && tlsServerName != host {
|
||||
w.WriteHeader(http.StatusMisdirectedRequest)
|
||||
return
|
||||
}
|
||||
}
|
||||
rp.ServeHTTP(w, r)
|
||||
})
|
||||
tr,
|
||||
)
|
||||
|
||||
tlsConfig, err := transport.NewServerTLSConfig(p.opts.CrtPath, p.opts.KeyPath, "")
|
||||
server, err := newHTTPSBridgePluginServer(rp, p.opts.CrtPath, p.opts.KeyPath, opts.EnableHTTP2, true)
|
||||
if err != nil {
|
||||
return nil, fmt.Errorf("gen TLS config error: %v", err)
|
||||
return nil, err
|
||||
}
|
||||
p.httpBridgePlugin = server
|
||||
|
||||
p.s = &http.Server{
|
||||
Handler: handler,
|
||||
ReadHeaderTimeout: 60 * time.Second,
|
||||
TLSConfig: tlsConfig,
|
||||
}
|
||||
if !lo.FromPtr(opts.EnableHTTP2) {
|
||||
p.s.TLSNextProto = make(map[string]func(*http.Server, *tls.Conn, http.Handler))
|
||||
}
|
||||
|
||||
go func() {
|
||||
_ = p.s.ServeTLS(listener, "", "")
|
||||
}()
|
||||
return p, nil
|
||||
}
|
||||
|
||||
func (p *HTTPS2HTTPSPlugin) Handle(_ context.Context, connInfo *ConnectionInfo) {
|
||||
wrapConn := netpkg.WrapReadWriteCloserToConn(connInfo.Conn, connInfo.UnderlyingConn)
|
||||
if connInfo.SrcAddr != nil {
|
||||
wrapConn.SetRemoteAddr(connInfo.SrcAddr)
|
||||
}
|
||||
_ = p.l.PutConn(wrapConn)
|
||||
}
|
||||
|
||||
func (p *HTTPS2HTTPSPlugin) Name() string {
|
||||
return v1.PluginHTTPS2HTTPS
|
||||
}
|
||||
|
||||
func (p *HTTPS2HTTPSPlugin) Close() error {
|
||||
return p.s.Close()
|
||||
}
|
||||
|
||||
60
pkg/plugin/client/internal/httpsserver/server.go
Normal file
60
pkg/plugin/client/internal/httpsserver/server.go
Normal file
@@ -0,0 +1,60 @@
|
||||
// Copyright 2026 The frp Authors
|
||||
//
|
||||
// Licensed under the Apache License, Version 2.0 (the "License");
|
||||
// you may not use this file except in compliance with the License.
|
||||
// You may obtain a copy of the License at
|
||||
//
|
||||
// http://www.apache.org/licenses/LICENSE-2.0
|
||||
//
|
||||
// Unless required by applicable law or agreed to in writing, software
|
||||
// distributed under the License is distributed on an "AS IS" BASIS,
|
||||
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
// See the License for the specific language governing permissions and
|
||||
// limitations under the License.
|
||||
|
||||
//go:build !frps
|
||||
|
||||
package httpsserver
|
||||
|
||||
import (
|
||||
"crypto/tls"
|
||||
"fmt"
|
||||
"net/http"
|
||||
"time"
|
||||
|
||||
"github.com/samber/lo"
|
||||
|
||||
"github.com/fatedier/frp/pkg/transport"
|
||||
httppkg "github.com/fatedier/frp/pkg/util/http"
|
||||
)
|
||||
|
||||
func New(handler http.Handler, crtPath, keyPath string, enableHTTP2 *bool) (*http.Server, error) {
|
||||
tlsConfig, err := transport.NewServerTLSConfig(crtPath, keyPath, "")
|
||||
if err != nil {
|
||||
return nil, fmt.Errorf("gen TLS config error: %v", err)
|
||||
}
|
||||
|
||||
server := &http.Server{
|
||||
Handler: withMisdirectedRequestCheck(handler),
|
||||
ReadHeaderTimeout: 60 * time.Second,
|
||||
TLSConfig: tlsConfig,
|
||||
}
|
||||
if !lo.FromPtr(enableHTTP2) {
|
||||
server.TLSNextProto = make(map[string]func(*http.Server, *tls.Conn, http.Handler))
|
||||
}
|
||||
return server, nil
|
||||
}
|
||||
|
||||
func withMisdirectedRequestCheck(handler http.Handler) http.Handler {
|
||||
return http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
|
||||
if r.TLS != nil {
|
||||
tlsServerName, _ := httppkg.CanonicalHost(r.TLS.ServerName)
|
||||
host, _ := httppkg.CanonicalHost(r.Host)
|
||||
if tlsServerName != "" && tlsServerName != host {
|
||||
w.WriteHeader(http.StatusMisdirectedRequest)
|
||||
return
|
||||
}
|
||||
}
|
||||
handler.ServeHTTP(w, r)
|
||||
})
|
||||
}
|
||||
@@ -44,6 +44,67 @@ func NewManager() *Manager {
|
||||
}
|
||||
}
|
||||
|
||||
func newPluginRequestContext() (context.Context, *xlog.Logger) {
|
||||
reqid, _ := util.RandID()
|
||||
xl := xlog.New().AppendPrefix("reqid: " + reqid)
|
||||
ctx := xlog.NewContext(context.Background(), xl)
|
||||
return NewReqidContext(ctx, reqid), xl
|
||||
}
|
||||
|
||||
type pluginErrorLogMode bool
|
||||
|
||||
const (
|
||||
// Warn is the zero value because it is the default for mutable plugin operations.
|
||||
pluginErrorLogWarn pluginErrorLogMode = false
|
||||
pluginErrorLogInfo pluginErrorLogMode = true
|
||||
)
|
||||
|
||||
func logPluginError(xl *xlog.Logger, p Plugin, op string, err error, mode pluginErrorLogMode) {
|
||||
if mode == pluginErrorLogInfo {
|
||||
xl.Infof("send %s request to plugin [%s] error: %v", op, p.Name(), err)
|
||||
return
|
||||
}
|
||||
xl.Warnf("send %s request to plugin [%s] error: %v", op, p.Name(), err)
|
||||
}
|
||||
|
||||
func handleMutableContent[T any](
|
||||
plugins []Plugin,
|
||||
op string,
|
||||
content *T,
|
||||
logMode pluginErrorLogMode,
|
||||
) (*T, error) {
|
||||
if len(plugins) == 0 {
|
||||
return content, nil
|
||||
}
|
||||
|
||||
var (
|
||||
res = &Response{
|
||||
Reject: false,
|
||||
Unchange: true,
|
||||
}
|
||||
retContent any
|
||||
err error
|
||||
)
|
||||
ctx, xl := newPluginRequestContext()
|
||||
|
||||
for _, p := range plugins {
|
||||
res, retContent, err = p.Handle(ctx, op, *content)
|
||||
if err != nil {
|
||||
logPluginError(xl, p, op, err, logMode)
|
||||
return nil, errors.New("send " + op + " request to plugin error")
|
||||
}
|
||||
if res.Reject {
|
||||
return nil, fmt.Errorf("%s", res.RejectReason)
|
||||
}
|
||||
if !res.Unchange {
|
||||
// Preserve the existing Plugin contract: changed content must be *T.
|
||||
// Buggy Plugin implementations still panic here, by design.
|
||||
content = retContent.(*T)
|
||||
}
|
||||
}
|
||||
return content, nil
|
||||
}
|
||||
|
||||
func (m *Manager) Register(p Plugin) {
|
||||
if p.IsSupport(OpLogin) {
|
||||
m.loginPlugins = append(m.loginPlugins, p)
|
||||
@@ -66,71 +127,11 @@ func (m *Manager) Register(p Plugin) {
|
||||
}
|
||||
|
||||
func (m *Manager) Login(content *LoginContent) (*LoginContent, error) {
|
||||
if len(m.loginPlugins) == 0 {
|
||||
return content, nil
|
||||
}
|
||||
|
||||
var (
|
||||
res = &Response{
|
||||
Reject: false,
|
||||
Unchange: true,
|
||||
}
|
||||
retContent any
|
||||
err error
|
||||
)
|
||||
reqid, _ := util.RandID()
|
||||
xl := xlog.New().AppendPrefix("reqid: " + reqid)
|
||||
ctx := xlog.NewContext(context.Background(), xl)
|
||||
ctx = NewReqidContext(ctx, reqid)
|
||||
|
||||
for _, p := range m.loginPlugins {
|
||||
res, retContent, err = p.Handle(ctx, OpLogin, *content)
|
||||
if err != nil {
|
||||
xl.Warnf("send Login request to plugin [%s] error: %v", p.Name(), err)
|
||||
return nil, errors.New("send Login request to plugin error")
|
||||
}
|
||||
if res.Reject {
|
||||
return nil, fmt.Errorf("%s", res.RejectReason)
|
||||
}
|
||||
if !res.Unchange {
|
||||
content = retContent.(*LoginContent)
|
||||
}
|
||||
}
|
||||
return content, nil
|
||||
return handleMutableContent(m.loginPlugins, OpLogin, content, pluginErrorLogWarn)
|
||||
}
|
||||
|
||||
func (m *Manager) NewProxy(content *NewProxyContent) (*NewProxyContent, error) {
|
||||
if len(m.newProxyPlugins) == 0 {
|
||||
return content, nil
|
||||
}
|
||||
|
||||
var (
|
||||
res = &Response{
|
||||
Reject: false,
|
||||
Unchange: true,
|
||||
}
|
||||
retContent any
|
||||
err error
|
||||
)
|
||||
reqid, _ := util.RandID()
|
||||
xl := xlog.New().AppendPrefix("reqid: " + reqid)
|
||||
ctx := xlog.NewContext(context.Background(), xl)
|
||||
ctx = NewReqidContext(ctx, reqid)
|
||||
|
||||
for _, p := range m.newProxyPlugins {
|
||||
res, retContent, err = p.Handle(ctx, OpNewProxy, *content)
|
||||
if err != nil {
|
||||
xl.Warnf("send NewProxy request to plugin [%s] error: %v", p.Name(), err)
|
||||
return nil, errors.New("send NewProxy request to plugin error")
|
||||
}
|
||||
if res.Reject {
|
||||
return nil, fmt.Errorf("%s", res.RejectReason)
|
||||
}
|
||||
if !res.Unchange {
|
||||
content = retContent.(*NewProxyContent)
|
||||
}
|
||||
}
|
||||
return content, nil
|
||||
return handleMutableContent(m.newProxyPlugins, OpNewProxy, content, pluginErrorLogWarn)
|
||||
}
|
||||
|
||||
func (m *Manager) CloseProxy(content *CloseProxyContent) error {
|
||||
@@ -139,10 +140,7 @@ func (m *Manager) CloseProxy(content *CloseProxyContent) error {
|
||||
}
|
||||
|
||||
errs := make([]string, 0)
|
||||
reqid, _ := util.RandID()
|
||||
xl := xlog.New().AppendPrefix("reqid: " + reqid)
|
||||
ctx := xlog.NewContext(context.Background(), xl)
|
||||
ctx = NewReqidContext(ctx, reqid)
|
||||
ctx, xl := newPluginRequestContext()
|
||||
|
||||
for _, p := range m.closeProxyPlugins {
|
||||
_, _, err := p.Handle(ctx, OpCloseProxy, *content)
|
||||
@@ -159,103 +157,14 @@ func (m *Manager) CloseProxy(content *CloseProxyContent) error {
|
||||
}
|
||||
|
||||
func (m *Manager) Ping(content *PingContent) (*PingContent, error) {
|
||||
if len(m.pingPlugins) == 0 {
|
||||
return content, nil
|
||||
}
|
||||
|
||||
var (
|
||||
res = &Response{
|
||||
Reject: false,
|
||||
Unchange: true,
|
||||
}
|
||||
retContent any
|
||||
err error
|
||||
)
|
||||
reqid, _ := util.RandID()
|
||||
xl := xlog.New().AppendPrefix("reqid: " + reqid)
|
||||
ctx := xlog.NewContext(context.Background(), xl)
|
||||
ctx = NewReqidContext(ctx, reqid)
|
||||
|
||||
for _, p := range m.pingPlugins {
|
||||
res, retContent, err = p.Handle(ctx, OpPing, *content)
|
||||
if err != nil {
|
||||
xl.Warnf("send Ping request to plugin [%s] error: %v", p.Name(), err)
|
||||
return nil, errors.New("send Ping request to plugin error")
|
||||
}
|
||||
if res.Reject {
|
||||
return nil, fmt.Errorf("%s", res.RejectReason)
|
||||
}
|
||||
if !res.Unchange {
|
||||
content = retContent.(*PingContent)
|
||||
}
|
||||
}
|
||||
return content, nil
|
||||
return handleMutableContent(m.pingPlugins, OpPing, content, pluginErrorLogWarn)
|
||||
}
|
||||
|
||||
func (m *Manager) NewWorkConn(content *NewWorkConnContent) (*NewWorkConnContent, error) {
|
||||
if len(m.newWorkConnPlugins) == 0 {
|
||||
return content, nil
|
||||
}
|
||||
|
||||
var (
|
||||
res = &Response{
|
||||
Reject: false,
|
||||
Unchange: true,
|
||||
}
|
||||
retContent any
|
||||
err error
|
||||
)
|
||||
reqid, _ := util.RandID()
|
||||
xl := xlog.New().AppendPrefix("reqid: " + reqid)
|
||||
ctx := xlog.NewContext(context.Background(), xl)
|
||||
ctx = NewReqidContext(ctx, reqid)
|
||||
|
||||
for _, p := range m.newWorkConnPlugins {
|
||||
res, retContent, err = p.Handle(ctx, OpNewWorkConn, *content)
|
||||
if err != nil {
|
||||
xl.Warnf("send NewWorkConn request to plugin [%s] error: %v", p.Name(), err)
|
||||
return nil, errors.New("send NewWorkConn request to plugin error")
|
||||
}
|
||||
if res.Reject {
|
||||
return nil, fmt.Errorf("%s", res.RejectReason)
|
||||
}
|
||||
if !res.Unchange {
|
||||
content = retContent.(*NewWorkConnContent)
|
||||
}
|
||||
}
|
||||
return content, nil
|
||||
return handleMutableContent(m.newWorkConnPlugins, OpNewWorkConn, content, pluginErrorLogWarn)
|
||||
}
|
||||
|
||||
func (m *Manager) NewUserConn(content *NewUserConnContent) (*NewUserConnContent, error) {
|
||||
if len(m.newUserConnPlugins) == 0 {
|
||||
return content, nil
|
||||
}
|
||||
|
||||
var (
|
||||
res = &Response{
|
||||
Reject: false,
|
||||
Unchange: true,
|
||||
}
|
||||
retContent any
|
||||
err error
|
||||
)
|
||||
reqid, _ := util.RandID()
|
||||
xl := xlog.New().AppendPrefix("reqid: " + reqid)
|
||||
ctx := xlog.NewContext(context.Background(), xl)
|
||||
ctx = NewReqidContext(ctx, reqid)
|
||||
|
||||
for _, p := range m.newUserConnPlugins {
|
||||
res, retContent, err = p.Handle(ctx, OpNewUserConn, *content)
|
||||
if err != nil {
|
||||
xl.Infof("send NewUserConn request to plugin [%s] error: %v", p.Name(), err)
|
||||
return nil, errors.New("send NewUserConn request to plugin error")
|
||||
}
|
||||
if res.Reject {
|
||||
return nil, fmt.Errorf("%s", res.RejectReason)
|
||||
}
|
||||
if !res.Unchange {
|
||||
content = retContent.(*NewUserConnContent)
|
||||
}
|
||||
}
|
||||
return content, nil
|
||||
// Preserve the pre-refactor log level for NewUserConn plugin errors.
|
||||
return handleMutableContent(m.newUserConnPlugins, OpNewUserConn, content, pluginErrorLogInfo)
|
||||
}
|
||||
|
||||
336
pkg/plugin/server/manager_test.go
Normal file
336
pkg/plugin/server/manager_test.go
Normal file
@@ -0,0 +1,336 @@
|
||||
// Copyright 2019 fatedier, fatedier@gmail.com
|
||||
//
|
||||
// Licensed under the Apache License, Version 2.0 (the "License");
|
||||
// you may not use this file except in compliance with the License.
|
||||
// You may obtain a copy of the License at
|
||||
//
|
||||
// http://www.apache.org/licenses/LICENSE-2.0
|
||||
//
|
||||
// Unless required by applicable law or agreed to in writing, software
|
||||
// distributed under the License is distributed on an "AS IS" BASIS,
|
||||
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
// See the License for the specific language governing permissions and
|
||||
// limitations under the License.
|
||||
|
||||
package server
|
||||
|
||||
import (
|
||||
"bytes"
|
||||
"context"
|
||||
"errors"
|
||||
"strings"
|
||||
"sync"
|
||||
"testing"
|
||||
"time"
|
||||
|
||||
goliblog "github.com/fatedier/golib/log"
|
||||
|
||||
"github.com/fatedier/frp/pkg/msg"
|
||||
frplog "github.com/fatedier/frp/pkg/util/log"
|
||||
)
|
||||
|
||||
type testPlugin struct {
|
||||
name string
|
||||
ops map[string]bool
|
||||
handler func(context.Context, string, any) (*Response, any, error)
|
||||
}
|
||||
|
||||
// Log-capturing subtests serialize global logger swaps; do not use t.Parallel.
|
||||
var logCaptureMu sync.Mutex
|
||||
|
||||
type logCapture struct {
|
||||
bytes.Buffer
|
||||
levels []goliblog.Level
|
||||
}
|
||||
|
||||
func (p testPlugin) Name() string {
|
||||
return p.name
|
||||
}
|
||||
|
||||
func (p testPlugin) IsSupport(op string) bool {
|
||||
return p.ops[op]
|
||||
}
|
||||
|
||||
func (p testPlugin) Handle(ctx context.Context, op string, content any) (*Response, any, error) {
|
||||
return p.handler(ctx, op, content)
|
||||
}
|
||||
|
||||
func (w *logCapture) WriteLog(p []byte, level goliblog.Level, _ time.Time) (int, error) {
|
||||
w.levels = append(w.levels, level)
|
||||
return w.Write(p)
|
||||
}
|
||||
|
||||
func captureLogOutput(t *testing.T) *logCapture {
|
||||
t.Helper()
|
||||
|
||||
logCaptureMu.Lock()
|
||||
logOutput := &logCapture{}
|
||||
oldLogger := frplog.Logger
|
||||
frplog.Logger = goliblog.New(
|
||||
goliblog.WithOutput(logOutput),
|
||||
goliblog.WithLevel(goliblog.TraceLevel),
|
||||
goliblog.WithCaller(false),
|
||||
)
|
||||
t.Cleanup(func() {
|
||||
frplog.Logger = oldLogger
|
||||
logCaptureMu.Unlock()
|
||||
})
|
||||
return logOutput
|
||||
}
|
||||
|
||||
var mutablePluginOps = []struct {
|
||||
name string
|
||||
op string
|
||||
}{
|
||||
{name: "login", op: OpLogin},
|
||||
{name: "new proxy", op: OpNewProxy},
|
||||
{name: "ping", op: OpPing},
|
||||
{name: "new work conn", op: OpNewWorkConn},
|
||||
{name: "new user conn", op: OpNewUserConn},
|
||||
}
|
||||
|
||||
func callMutableWithUser(m *Manager, op string, user string) (string, error) {
|
||||
switch op {
|
||||
case OpLogin:
|
||||
got, err := m.Login(&LoginContent{Login: msg.Login{User: user}})
|
||||
if got == nil {
|
||||
return "", err
|
||||
}
|
||||
return got.User, err
|
||||
case OpNewProxy:
|
||||
got, err := m.NewProxy(&NewProxyContent{User: UserInfo{User: user}})
|
||||
if got == nil {
|
||||
return "", err
|
||||
}
|
||||
return got.User.User, err
|
||||
case OpPing:
|
||||
got, err := m.Ping(&PingContent{User: UserInfo{User: user}})
|
||||
if got == nil {
|
||||
return "", err
|
||||
}
|
||||
return got.User.User, err
|
||||
case OpNewWorkConn:
|
||||
got, err := m.NewWorkConn(&NewWorkConnContent{User: UserInfo{User: user}})
|
||||
if got == nil {
|
||||
return "", err
|
||||
}
|
||||
return got.User.User, err
|
||||
case OpNewUserConn:
|
||||
got, err := m.NewUserConn(&NewUserConnContent{User: UserInfo{User: user}})
|
||||
if got == nil {
|
||||
return "", err
|
||||
}
|
||||
return got.User.User, err
|
||||
default:
|
||||
panic("unsupported mutable op: " + op)
|
||||
}
|
||||
}
|
||||
|
||||
func mutableUser(t *testing.T, op string, content any) string {
|
||||
t.Helper()
|
||||
|
||||
switch op {
|
||||
case OpLogin:
|
||||
return content.(LoginContent).User
|
||||
case OpNewProxy:
|
||||
return content.(NewProxyContent).User.User
|
||||
case OpPing:
|
||||
return content.(PingContent).User.User
|
||||
case OpNewWorkConn:
|
||||
return content.(NewWorkConnContent).User.User
|
||||
case OpNewUserConn:
|
||||
return content.(NewUserConnContent).User.User
|
||||
default:
|
||||
t.Fatalf("unsupported mutable op: %s", op)
|
||||
return ""
|
||||
}
|
||||
}
|
||||
|
||||
func mutateMutableContent(t *testing.T, op string, content any, user string) any {
|
||||
t.Helper()
|
||||
|
||||
switch op {
|
||||
case OpLogin:
|
||||
got := content.(LoginContent)
|
||||
got.User = user
|
||||
return &got
|
||||
case OpNewProxy:
|
||||
got := content.(NewProxyContent)
|
||||
got.User.User = user
|
||||
return &got
|
||||
case OpPing:
|
||||
got := content.(PingContent)
|
||||
got.User.User = user
|
||||
return &got
|
||||
case OpNewWorkConn:
|
||||
got := content.(NewWorkConnContent)
|
||||
got.User.User = user
|
||||
return &got
|
||||
case OpNewUserConn:
|
||||
got := content.(NewUserConnContent)
|
||||
got.User.User = user
|
||||
return &got
|
||||
default:
|
||||
t.Fatalf("unsupported mutable op: %s", op)
|
||||
return nil
|
||||
}
|
||||
}
|
||||
|
||||
func TestManagerMutableContentAcrossOps(t *testing.T) {
|
||||
for _, tt := range mutablePluginOps {
|
||||
t.Run(tt.name, func(t *testing.T) {
|
||||
m := NewManager()
|
||||
m.Register(testPlugin{
|
||||
name: "mutate",
|
||||
ops: map[string]bool{tt.op: true},
|
||||
handler: func(ctx context.Context, op string, content any) (*Response, any, error) {
|
||||
if op != tt.op {
|
||||
t.Fatalf("unexpected op: %s", op)
|
||||
}
|
||||
if GetReqidFromContext(ctx) == "" {
|
||||
t.Fatal("expected request id in context")
|
||||
}
|
||||
if got := mutableUser(t, tt.op, content); got != "initial" {
|
||||
t.Fatalf("expected initial user, got %q", got)
|
||||
}
|
||||
return &Response{Unchange: false}, mutateMutableContent(t, tt.op, content, "mutated"), nil
|
||||
},
|
||||
})
|
||||
m.Register(testPlugin{
|
||||
name: "observe",
|
||||
ops: map[string]bool{tt.op: true},
|
||||
handler: func(ctx context.Context, op string, content any) (*Response, any, error) {
|
||||
if op != tt.op {
|
||||
t.Fatalf("unexpected op: %s", op)
|
||||
}
|
||||
if GetReqidFromContext(ctx) == "" {
|
||||
t.Fatal("expected request id in context")
|
||||
}
|
||||
if got := mutableUser(t, tt.op, content); got != "mutated" {
|
||||
t.Fatalf("expected mutated user, got %q", got)
|
||||
}
|
||||
return &Response{Unchange: true}, mutateMutableContent(t, tt.op, content, "ignored"), nil
|
||||
},
|
||||
})
|
||||
|
||||
got, err := callMutableWithUser(m, tt.op, "initial")
|
||||
if err != nil {
|
||||
t.Fatalf("mutable op failed: %v", err)
|
||||
}
|
||||
if got != "mutated" {
|
||||
t.Fatalf("expected mutated user, got %q", got)
|
||||
}
|
||||
})
|
||||
}
|
||||
}
|
||||
|
||||
func TestManagerMutableContentRejectStopsChain(t *testing.T) {
|
||||
m := NewManager()
|
||||
|
||||
var called bool
|
||||
m.Register(testPlugin{
|
||||
name: "reject",
|
||||
ops: map[string]bool{OpPing: true},
|
||||
handler: func(context.Context, string, any) (*Response, any, error) {
|
||||
return &Response{Reject: true, RejectReason: "blocked"}, nil, nil
|
||||
},
|
||||
})
|
||||
m.Register(testPlugin{
|
||||
name: "unused",
|
||||
ops: map[string]bool{OpPing: true},
|
||||
handler: func(context.Context, string, any) (*Response, any, error) {
|
||||
called = true
|
||||
return &Response{Unchange: true}, nil, nil
|
||||
},
|
||||
})
|
||||
|
||||
got, err := m.Ping(&PingContent{})
|
||||
if err == nil {
|
||||
t.Fatal("expected reject error")
|
||||
}
|
||||
if got != nil {
|
||||
t.Fatalf("expected no returned content, got %#v", got)
|
||||
}
|
||||
if err.Error() != "blocked" {
|
||||
t.Fatalf("unexpected error: %v", err)
|
||||
}
|
||||
if called {
|
||||
t.Fatal("expected plugin chain to stop after reject")
|
||||
}
|
||||
}
|
||||
|
||||
func TestManagerMutableContentPluginErrorLogLevel(t *testing.T) {
|
||||
tests := []struct {
|
||||
name string
|
||||
op string
|
||||
level goliblog.Level
|
||||
}{
|
||||
{name: "default warning", op: OpLogin, level: goliblog.WarnLevel},
|
||||
{name: "new user conn info", op: OpNewUserConn, level: goliblog.InfoLevel},
|
||||
}
|
||||
|
||||
for _, tt := range tests {
|
||||
t.Run(tt.name, func(t *testing.T) {
|
||||
logOutput := captureLogOutput(t)
|
||||
m := NewManager()
|
||||
m.Register(testPlugin{
|
||||
name: "error",
|
||||
ops: map[string]bool{tt.op: true},
|
||||
handler: func(context.Context, string, any) (*Response, any, error) {
|
||||
return nil, nil, errors.New("boom")
|
||||
},
|
||||
})
|
||||
|
||||
_, err := callMutableWithUser(m, tt.op, "initial")
|
||||
if err == nil {
|
||||
t.Fatal("expected plugin error")
|
||||
}
|
||||
if want := "send " + tt.op + " request to plugin error"; err.Error() != want {
|
||||
t.Fatalf("unexpected error: %v", err)
|
||||
}
|
||||
if len(logOutput.levels) != 1 || logOutput.levels[0] != tt.level {
|
||||
t.Fatalf("expected log level %v, got %v in %q", tt.level, logOutput.levels, logOutput.String())
|
||||
}
|
||||
})
|
||||
}
|
||||
}
|
||||
|
||||
func TestManagerCloseProxyAggregatesErrors(t *testing.T) {
|
||||
logOutput := captureLogOutput(t)
|
||||
m := NewManager()
|
||||
|
||||
for _, name := range []string{"first", "second"} {
|
||||
m.Register(testPlugin{
|
||||
name: name,
|
||||
ops: map[string]bool{OpCloseProxy: true},
|
||||
handler: func(ctx context.Context, op string, content any) (*Response, any, error) {
|
||||
if GetReqidFromContext(ctx) == "" {
|
||||
t.Fatal("expected request id in context")
|
||||
}
|
||||
if op != OpCloseProxy {
|
||||
t.Fatalf("unexpected op: %s", op)
|
||||
}
|
||||
return nil, nil, errors.New(name + " error")
|
||||
},
|
||||
})
|
||||
}
|
||||
|
||||
err := m.CloseProxy(&CloseProxyContent{})
|
||||
if err == nil {
|
||||
t.Fatal("expected close proxy error")
|
||||
}
|
||||
if !strings.HasPrefix(err.Error(), "send CloseProxy request to plugin errors: ") {
|
||||
t.Fatalf("unexpected close proxy error prefix: %v", err)
|
||||
}
|
||||
if !strings.Contains(err.Error(), "[first]: first error") || !strings.Contains(err.Error(), "[second]: second error") {
|
||||
t.Fatalf("missing aggregated errors: %v", err)
|
||||
}
|
||||
if len(logOutput.levels) != 2 {
|
||||
t.Fatalf("expected two warning logs, got %v", logOutput.levels)
|
||||
}
|
||||
for _, level := range logOutput.levels {
|
||||
if level != goliblog.WarnLevel {
|
||||
t.Fatalf("expected warning log level, got %v", logOutput.levels)
|
||||
}
|
||||
}
|
||||
}
|
||||
@@ -140,6 +140,8 @@ func Forwarder(dstAddr *net.UDPAddr, readCh <-chan *msg.UDPPacket, sendCh chan<-
|
||||
_, err = udpConn.Write(buf)
|
||||
if err != nil {
|
||||
udpConn.Close()
|
||||
} else {
|
||||
_ = udpConn.SetReadDeadline(time.Now().Add(30 * time.Second))
|
||||
}
|
||||
|
||||
if !ok {
|
||||
|
||||
197
pkg/proto/wire/crypto.go
Normal file
197
pkg/proto/wire/crypto.go
Normal file
@@ -0,0 +1,197 @@
|
||||
// Copyright 2026 The frp Authors
|
||||
//
|
||||
// Licensed under the Apache License, Version 2.0 (the "License");
|
||||
// you may not use this file except in compliance with the License.
|
||||
// You may obtain a copy of the License at
|
||||
//
|
||||
// http://www.apache.org/licenses/LICENSE-2.0
|
||||
//
|
||||
// Unless required by applicable law or agreed to in writing, software
|
||||
// distributed under the License is distributed on an "AS IS" BASIS,
|
||||
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
// See the License for the specific language governing permissions and
|
||||
// limitations under the License.
|
||||
|
||||
package wire
|
||||
|
||||
import (
|
||||
"crypto/rand"
|
||||
"crypto/sha256"
|
||||
"encoding/binary"
|
||||
"encoding/json"
|
||||
"fmt"
|
||||
"hash"
|
||||
"runtime"
|
||||
|
||||
"golang.org/x/sys/cpu"
|
||||
)
|
||||
|
||||
const (
|
||||
AEADAlgorithmAES256GCM = "aes-256-gcm"
|
||||
AEADAlgorithmXChaCha20Poly1305 = "xchacha20-poly1305"
|
||||
|
||||
CryptoRandomSize = 32
|
||||
|
||||
cryptoTranscriptLabel = "frp wire v2 crypto transcript"
|
||||
)
|
||||
|
||||
var supportedAEADAlgorithms = []string{
|
||||
AEADAlgorithmAES256GCM,
|
||||
AEADAlgorithmXChaCha20Poly1305,
|
||||
}
|
||||
|
||||
type CryptoContext struct {
|
||||
Algorithm string
|
||||
TranscriptHash []byte
|
||||
}
|
||||
|
||||
func NewClientHello(bootstrap BootstrapInfo) (ClientHello, error) {
|
||||
clientRandom, err := newCryptoRandom()
|
||||
if err != nil {
|
||||
return ClientHello{}, err
|
||||
}
|
||||
return clientHelloWithCryptoRandom(bootstrap, clientRandom), nil
|
||||
}
|
||||
|
||||
func NewServerHello(clientHello ClientHello) (ServerHello, error) {
|
||||
if err := ValidateClientHello(clientHello); err != nil {
|
||||
return ServerHello{}, err
|
||||
}
|
||||
algorithm, ok := SelectAEADAlgorithm(clientHello.Capabilities.Crypto.Algorithms)
|
||||
if !ok {
|
||||
return ServerHello{}, fmt.Errorf("no supported crypto algorithm")
|
||||
}
|
||||
serverRandom, err := newCryptoRandom()
|
||||
if err != nil {
|
||||
return ServerHello{}, err
|
||||
}
|
||||
return ServerHello{
|
||||
Selected: ServerSelection{
|
||||
Message: MessageSelection{
|
||||
Codec: MessageCodecJSON,
|
||||
},
|
||||
Crypto: CryptoSelection{
|
||||
Algorithm: algorithm,
|
||||
ServerRandom: serverRandom,
|
||||
},
|
||||
},
|
||||
}, nil
|
||||
}
|
||||
|
||||
func ValidateCryptoCapabilities(c CryptoCapabilities) error {
|
||||
if len(c.ClientRandom) != CryptoRandomSize {
|
||||
return fmt.Errorf("invalid crypto client random length %d, want %d", len(c.ClientRandom), CryptoRandomSize)
|
||||
}
|
||||
if _, ok := SelectAEADAlgorithm(c.Algorithms); !ok {
|
||||
return fmt.Errorf("no supported crypto algorithm")
|
||||
}
|
||||
return nil
|
||||
}
|
||||
|
||||
func ValidateServerHelloForClient(clientHello ClientHello, serverHello ServerHello) error {
|
||||
if serverHello.Selected.Message.Codec != MessageCodecJSON {
|
||||
return fmt.Errorf("unsupported selected message codec: %s", serverHello.Selected.Message.Codec)
|
||||
}
|
||||
cryptoSelection := serverHello.Selected.Crypto
|
||||
if !IsSupportedAEADAlgorithm(cryptoSelection.Algorithm) {
|
||||
return fmt.Errorf("unknown selected crypto algorithm: %s", cryptoSelection.Algorithm)
|
||||
}
|
||||
if !Supports(clientHello.Capabilities.Crypto.Algorithms, cryptoSelection.Algorithm) {
|
||||
return fmt.Errorf("selected crypto algorithm was not advertised by client: %s", cryptoSelection.Algorithm)
|
||||
}
|
||||
if len(cryptoSelection.ServerRandom) != CryptoRandomSize {
|
||||
return fmt.Errorf("invalid crypto server random length %d, want %d", len(cryptoSelection.ServerRandom), CryptoRandomSize)
|
||||
}
|
||||
return nil
|
||||
}
|
||||
|
||||
func NewCryptoContext(algorithm string, clientHelloPayload, serverHelloPayload []byte) *CryptoContext {
|
||||
return &CryptoContext{
|
||||
Algorithm: algorithm,
|
||||
TranscriptHash: HashCryptoTranscript(clientHelloPayload, serverHelloPayload),
|
||||
}
|
||||
}
|
||||
|
||||
func NewClientCryptoContext(clientHelloPayload, serverHelloPayload []byte) (*CryptoContext, error) {
|
||||
var clientHello ClientHello
|
||||
if err := json.Unmarshal(clientHelloPayload, &clientHello); err != nil {
|
||||
return nil, fmt.Errorf("decode ClientHello transcript: %w", err)
|
||||
}
|
||||
var serverHello ServerHello
|
||||
if err := json.Unmarshal(serverHelloPayload, &serverHello); err != nil {
|
||||
return nil, fmt.Errorf("decode ServerHello transcript: %w", err)
|
||||
}
|
||||
if err := ValidateServerHelloForClient(clientHello, serverHello); err != nil {
|
||||
return nil, err
|
||||
}
|
||||
|
||||
return NewCryptoContext(serverHello.Selected.Crypto.Algorithm, clientHelloPayload, serverHelloPayload), nil
|
||||
}
|
||||
|
||||
func HashCryptoTranscript(clientHelloPayload, serverHelloPayload []byte) []byte {
|
||||
h := sha256.New()
|
||||
_, _ = h.Write([]byte(cryptoTranscriptLabel))
|
||||
writeCryptoTranscriptPart(h, "client hello", clientHelloPayload)
|
||||
writeCryptoTranscriptPart(h, "server hello", serverHelloPayload)
|
||||
return h.Sum(nil)
|
||||
}
|
||||
|
||||
func writeCryptoTranscriptPart(h hash.Hash, label string, payload []byte) {
|
||||
var length [8]byte
|
||||
binary.BigEndian.PutUint64(length[:], uint64(len(payload)))
|
||||
_, _ = h.Write([]byte{0})
|
||||
_, _ = h.Write([]byte(label))
|
||||
_, _ = h.Write([]byte{0})
|
||||
_, _ = h.Write(length[:])
|
||||
_, _ = h.Write(payload)
|
||||
}
|
||||
|
||||
func PreferredAEADAlgorithms() []string {
|
||||
if hasFastAESGCM() {
|
||||
return []string{AEADAlgorithmAES256GCM, AEADAlgorithmXChaCha20Poly1305}
|
||||
}
|
||||
return []string{AEADAlgorithmXChaCha20Poly1305, AEADAlgorithmAES256GCM}
|
||||
}
|
||||
|
||||
func SelectAEADAlgorithm(clientAlgorithms []string) (string, bool) {
|
||||
for _, algorithm := range clientAlgorithms {
|
||||
if IsSupportedAEADAlgorithm(algorithm) {
|
||||
return algorithm, true
|
||||
}
|
||||
}
|
||||
return "", false
|
||||
}
|
||||
|
||||
func IsSupportedAEADAlgorithm(algorithm string) bool {
|
||||
return Supports(supportedAEADAlgorithms, algorithm)
|
||||
}
|
||||
|
||||
func newCryptoRandom() ([]byte, error) {
|
||||
b := make([]byte, CryptoRandomSize)
|
||||
if _, err := rand.Read(b); err != nil {
|
||||
return nil, fmt.Errorf("generate crypto random: %w", err)
|
||||
}
|
||||
return b, nil
|
||||
}
|
||||
|
||||
func hasFastAESGCM() bool {
|
||||
switch runtime.GOARCH {
|
||||
case "amd64":
|
||||
return cpu.X86.HasAES &&
|
||||
cpu.X86.HasPCLMULQDQ &&
|
||||
cpu.X86.HasSSE41 &&
|
||||
cpu.X86.HasSSSE3
|
||||
case "arm64":
|
||||
return cpu.ARM64.HasAES && cpu.ARM64.HasPMULL
|
||||
case "s390x":
|
||||
return cpu.S390X.HasAES &&
|
||||
cpu.S390X.HasAESCTR &&
|
||||
cpu.S390X.HasGHASH
|
||||
case "ppc64", "ppc64le":
|
||||
// Go's ppc64/ppc64le port targets POWER8+, which has AES instructions;
|
||||
// x/sys/cpu does not expose a PPC64 AES feature flag.
|
||||
return true
|
||||
default:
|
||||
return false
|
||||
}
|
||||
}
|
||||
246
pkg/proto/wire/wire.go
Normal file
246
pkg/proto/wire/wire.go
Normal file
@@ -0,0 +1,246 @@
|
||||
// Copyright 2026 The frp Authors
|
||||
//
|
||||
// Licensed under the Apache License, Version 2.0 (the "License");
|
||||
// you may not use this file except in compliance with the License.
|
||||
// You may obtain a copy of the License at
|
||||
//
|
||||
// http://www.apache.org/licenses/LICENSE-2.0
|
||||
//
|
||||
// Unless required by applicable law or agreed to in writing, software
|
||||
// distributed under the License is distributed on an "AS IS" BASIS,
|
||||
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
// See the License for the specific language governing permissions and
|
||||
// limitations under the License.
|
||||
|
||||
package wire
|
||||
|
||||
import (
|
||||
"encoding/binary"
|
||||
"encoding/json"
|
||||
"fmt"
|
||||
"io"
|
||||
"net"
|
||||
"slices"
|
||||
|
||||
libnet "github.com/fatedier/golib/net"
|
||||
)
|
||||
|
||||
const (
|
||||
ProtocolV1 = "v1"
|
||||
ProtocolV2 = "v2"
|
||||
|
||||
WireVersionV2 = 2
|
||||
|
||||
FrameTypeClientHello uint16 = 1
|
||||
FrameTypeServerHello uint16 = 2
|
||||
FrameTypeMessage uint16 = 16
|
||||
|
||||
MessageCodecJSON = "json"
|
||||
DefaultMaxFramePayloadSize = 64 * 1024
|
||||
|
||||
MagicV2 = "FRP\x00\x02\r\n"
|
||||
)
|
||||
|
||||
type Frame struct {
|
||||
Type uint16
|
||||
Flags uint16
|
||||
Payload []byte
|
||||
}
|
||||
|
||||
type Conn struct {
|
||||
rw io.ReadWriter
|
||||
maxFramePayloadSize uint32
|
||||
}
|
||||
|
||||
func NewConn(rw io.ReadWriter) *Conn {
|
||||
return &Conn{
|
||||
rw: rw,
|
||||
maxFramePayloadSize: DefaultMaxFramePayloadSize,
|
||||
}
|
||||
}
|
||||
|
||||
func (c *Conn) ReadFrame() (*Frame, error) {
|
||||
header := make([]byte, 8)
|
||||
if _, err := io.ReadFull(c.rw, header); err != nil {
|
||||
return nil, err
|
||||
}
|
||||
|
||||
frameType := binary.BigEndian.Uint16(header[0:2])
|
||||
flags := binary.BigEndian.Uint16(header[2:4])
|
||||
length := binary.BigEndian.Uint32(header[4:8])
|
||||
if flags != 0 {
|
||||
return nil, fmt.Errorf("unsupported frame flags: %d", flags)
|
||||
}
|
||||
if length > c.maxFramePayloadSize {
|
||||
return nil, fmt.Errorf("frame payload length %d exceeds limit %d", length, c.maxFramePayloadSize)
|
||||
}
|
||||
|
||||
payload := make([]byte, length)
|
||||
if _, err := io.ReadFull(c.rw, payload); err != nil {
|
||||
return nil, err
|
||||
}
|
||||
return &Frame{
|
||||
Type: frameType,
|
||||
Flags: flags,
|
||||
Payload: payload,
|
||||
}, nil
|
||||
}
|
||||
|
||||
func (c *Conn) WriteFrame(f *Frame) error {
|
||||
if f.Flags != 0 {
|
||||
return fmt.Errorf("unsupported frame flags: %d", f.Flags)
|
||||
}
|
||||
if len(f.Payload) > int(c.maxFramePayloadSize) {
|
||||
return fmt.Errorf("frame payload length %d exceeds limit %d", len(f.Payload), c.maxFramePayloadSize)
|
||||
}
|
||||
|
||||
header := make([]byte, 8)
|
||||
binary.BigEndian.PutUint16(header[0:2], f.Type)
|
||||
binary.BigEndian.PutUint16(header[2:4], f.Flags)
|
||||
binary.BigEndian.PutUint32(header[4:8], uint32(len(f.Payload)))
|
||||
if _, err := c.rw.Write(header); err != nil {
|
||||
return err
|
||||
}
|
||||
_, err := c.rw.Write(f.Payload)
|
||||
return err
|
||||
}
|
||||
|
||||
func (c *Conn) ReadJSONFrame(frameType uint16, out any) error {
|
||||
f, err := c.ReadFrame()
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
if f.Type != frameType {
|
||||
return fmt.Errorf("unexpected frame type %d, want %d", f.Type, frameType)
|
||||
}
|
||||
return c.UnmarshalFrame(f, out)
|
||||
}
|
||||
|
||||
func (c *Conn) UnmarshalFrame(f *Frame, out any) error {
|
||||
return json.Unmarshal(f.Payload, out)
|
||||
}
|
||||
|
||||
func NewJSONFrame(frameType uint16, in any) (*Frame, error) {
|
||||
payload, err := json.Marshal(in)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
return &Frame{
|
||||
Type: frameType,
|
||||
Payload: payload,
|
||||
}, nil
|
||||
}
|
||||
|
||||
func (c *Conn) WriteJSONFrame(frameType uint16, in any) error {
|
||||
f, err := NewJSONFrame(frameType, in)
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
return c.WriteFrame(f)
|
||||
}
|
||||
|
||||
func WriteMagic(w io.Writer) error {
|
||||
_, err := io.WriteString(w, MagicV2)
|
||||
return err
|
||||
}
|
||||
|
||||
func WriteMagicIfV2(w io.Writer, wireProtocol string) error {
|
||||
if wireProtocol != ProtocolV2 {
|
||||
return nil
|
||||
}
|
||||
return WriteMagic(w)
|
||||
}
|
||||
|
||||
func CheckMagic(conn net.Conn) (out net.Conn, isV2 bool, err error) {
|
||||
sharedConn, r := libnet.NewSharedConnSize(conn, len(MagicV2))
|
||||
buf := make([]byte, len(MagicV2))
|
||||
if _, err = io.ReadFull(r, buf); err != nil {
|
||||
return nil, false, err
|
||||
}
|
||||
for i := range MagicV2 {
|
||||
if buf[i] != MagicV2[i] {
|
||||
return sharedConn, false, nil
|
||||
}
|
||||
}
|
||||
return conn, true, nil
|
||||
}
|
||||
|
||||
type BootstrapInfo struct {
|
||||
Transport string `json:"transport,omitempty"`
|
||||
TLS bool `json:"tls,omitempty"`
|
||||
TCPMux bool `json:"tcpMux,omitempty"`
|
||||
}
|
||||
|
||||
type ClientHello struct {
|
||||
Bootstrap BootstrapInfo `json:"bootstrap,omitempty"`
|
||||
Capabilities ClientCapabilities `json:"capabilities,omitempty"`
|
||||
}
|
||||
|
||||
type ClientCapabilities struct {
|
||||
Message MessageCapabilities `json:"message,omitempty"`
|
||||
Crypto CryptoCapabilities `json:"crypto,omitempty"`
|
||||
}
|
||||
|
||||
type MessageCapabilities struct {
|
||||
Codecs []string `json:"codecs,omitempty"`
|
||||
}
|
||||
|
||||
type CryptoCapabilities struct {
|
||||
Algorithms []string `json:"algorithms,omitempty"`
|
||||
ClientRandom []byte `json:"clientRandom,omitempty"`
|
||||
}
|
||||
|
||||
type ServerHello struct {
|
||||
Selected ServerSelection `json:"selected,omitempty"`
|
||||
Error string `json:"error,omitempty"`
|
||||
}
|
||||
|
||||
type ServerSelection struct {
|
||||
Message MessageSelection `json:"message,omitempty"`
|
||||
Crypto CryptoSelection `json:"crypto,omitempty"`
|
||||
}
|
||||
|
||||
type MessageSelection struct {
|
||||
Codec string `json:"codec,omitempty"`
|
||||
}
|
||||
|
||||
type CryptoSelection struct {
|
||||
Algorithm string `json:"algorithm,omitempty"`
|
||||
ServerRandom []byte `json:"serverRandom,omitempty"`
|
||||
}
|
||||
|
||||
func clientHelloWithCryptoRandom(bootstrap BootstrapInfo, clientRandom []byte) ClientHello {
|
||||
return ClientHello{
|
||||
Bootstrap: bootstrap,
|
||||
Capabilities: ClientCapabilities{
|
||||
Message: MessageCapabilities{
|
||||
Codecs: []string{MessageCodecJSON},
|
||||
},
|
||||
Crypto: CryptoCapabilities{
|
||||
Algorithms: PreferredAEADAlgorithms(),
|
||||
ClientRandom: clientRandom,
|
||||
},
|
||||
},
|
||||
}
|
||||
}
|
||||
|
||||
func DefaultServerHello() ServerHello {
|
||||
return ServerHello{
|
||||
Selected: ServerSelection{
|
||||
Message: MessageSelection{
|
||||
Codec: MessageCodecJSON,
|
||||
},
|
||||
},
|
||||
}
|
||||
}
|
||||
|
||||
func Supports(list []string, value string) bool {
|
||||
return slices.Contains(list, value)
|
||||
}
|
||||
|
||||
func ValidateClientHello(h ClientHello) error {
|
||||
if !Supports(h.Capabilities.Message.Codecs, MessageCodecJSON) {
|
||||
return fmt.Errorf("unsupported message codec")
|
||||
}
|
||||
return ValidateCryptoCapabilities(h.Capabilities.Crypto)
|
||||
}
|
||||
233
pkg/proto/wire/wire_test.go
Normal file
233
pkg/proto/wire/wire_test.go
Normal file
@@ -0,0 +1,233 @@
|
||||
// Copyright 2026 The frp Authors
|
||||
//
|
||||
// Licensed under the Apache License, Version 2.0 (the "License");
|
||||
// you may not use this file except in compliance with the License.
|
||||
// You may obtain a copy of the License at
|
||||
//
|
||||
// http://www.apache.org/licenses/LICENSE-2.0
|
||||
//
|
||||
// Unless required by applicable law or agreed to in writing, software
|
||||
// distributed under the License is distributed on an "AS IS" BASIS,
|
||||
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
// See the License for the specific language governing permissions and
|
||||
// limitations under the License.
|
||||
|
||||
package wire
|
||||
|
||||
import (
|
||||
"bytes"
|
||||
"encoding/binary"
|
||||
"io"
|
||||
"net"
|
||||
"testing"
|
||||
|
||||
"github.com/stretchr/testify/require"
|
||||
)
|
||||
|
||||
func TestFrameRoundTrip(t *testing.T) {
|
||||
var buf bytes.Buffer
|
||||
conn := NewConn(&buf)
|
||||
|
||||
in := mustClientHello(t, BootstrapInfo{
|
||||
Transport: "tcp",
|
||||
TLS: true,
|
||||
TCPMux: true,
|
||||
})
|
||||
require.NoError(t, conn.WriteJSONFrame(FrameTypeClientHello, in))
|
||||
|
||||
var out ClientHello
|
||||
require.NoError(t, conn.ReadJSONFrame(FrameTypeClientHello, &out))
|
||||
require.Equal(t, in, out)
|
||||
}
|
||||
|
||||
func TestReadFrameRejectsUnsupportedFlags(t *testing.T) {
|
||||
var buf bytes.Buffer
|
||||
header := make([]byte, 8)
|
||||
binary.BigEndian.PutUint16(header[0:2], FrameTypeMessage)
|
||||
binary.BigEndian.PutUint16(header[2:4], 1)
|
||||
binary.BigEndian.PutUint32(header[4:8], 0)
|
||||
buf.Write(header)
|
||||
|
||||
_, err := NewConn(&buf).ReadFrame()
|
||||
require.ErrorContains(t, err, "unsupported frame flags")
|
||||
}
|
||||
|
||||
func TestReadFrameRejectsOversizedPayload(t *testing.T) {
|
||||
var buf bytes.Buffer
|
||||
header := make([]byte, 8)
|
||||
binary.BigEndian.PutUint16(header[0:2], FrameTypeMessage)
|
||||
binary.BigEndian.PutUint32(header[4:8], DefaultMaxFramePayloadSize+1)
|
||||
buf.Write(header)
|
||||
|
||||
_, err := NewConn(&buf).ReadFrame()
|
||||
require.ErrorContains(t, err, "exceeds limit")
|
||||
}
|
||||
|
||||
func TestCheckMagicV2ConsumesMagic(t *testing.T) {
|
||||
client, server := net.Pipe()
|
||||
defer server.Close()
|
||||
|
||||
want := []byte("payload")
|
||||
go func() {
|
||||
defer client.Close()
|
||||
_, _ = client.Write(append([]byte(MagicV2), want...))
|
||||
}()
|
||||
|
||||
out, isV2, err := CheckMagic(server)
|
||||
require.NoError(t, err)
|
||||
require.True(t, isV2)
|
||||
|
||||
got := make([]byte, len(want))
|
||||
_, err = io.ReadFull(out, got)
|
||||
require.NoError(t, err)
|
||||
require.Equal(t, want, got)
|
||||
}
|
||||
|
||||
func TestWriteMagicIfV2(t *testing.T) {
|
||||
var buf bytes.Buffer
|
||||
require.NoError(t, WriteMagicIfV2(&buf, ProtocolV1))
|
||||
require.Empty(t, buf.Bytes())
|
||||
|
||||
require.NoError(t, WriteMagicIfV2(&buf, ProtocolV2))
|
||||
require.Equal(t, []byte(MagicV2), buf.Bytes())
|
||||
}
|
||||
|
||||
func TestCheckMagicV1PreservesReadBytes(t *testing.T) {
|
||||
client, server := net.Pipe()
|
||||
defer server.Close()
|
||||
|
||||
want := []byte("legacy payload")
|
||||
go func() {
|
||||
defer client.Close()
|
||||
_, _ = client.Write(want)
|
||||
}()
|
||||
|
||||
out, isV2, err := CheckMagic(server)
|
||||
require.NoError(t, err)
|
||||
require.False(t, isV2)
|
||||
|
||||
got, err := io.ReadAll(out)
|
||||
require.NoError(t, err)
|
||||
require.Equal(t, want, got)
|
||||
}
|
||||
|
||||
func TestValidateClientHello(t *testing.T) {
|
||||
hello := mustClientHello(t, BootstrapInfo{})
|
||||
require.NoError(t, ValidateClientHello(hello))
|
||||
require.Len(t, hello.Capabilities.Crypto.ClientRandom, CryptoRandomSize)
|
||||
require.ElementsMatch(t, []string{
|
||||
AEADAlgorithmAES256GCM,
|
||||
AEADAlgorithmXChaCha20Poly1305,
|
||||
}, hello.Capabilities.Crypto.Algorithms)
|
||||
|
||||
hello.Capabilities.Message.Codecs = []string{"unknown"}
|
||||
require.ErrorContains(t, ValidateClientHello(hello), "unsupported message codec")
|
||||
}
|
||||
|
||||
func TestValidateClientHelloRejectsInvalidCrypto(t *testing.T) {
|
||||
hello := mustClientHello(t, BootstrapInfo{})
|
||||
hello.Capabilities.Crypto.ClientRandom = hello.Capabilities.Crypto.ClientRandom[:CryptoRandomSize-1]
|
||||
require.ErrorContains(t, ValidateClientHello(hello), "invalid crypto client random length")
|
||||
|
||||
hello = mustClientHello(t, BootstrapInfo{})
|
||||
hello.Capabilities.Crypto.Algorithms = []string{"unknown"}
|
||||
require.ErrorContains(t, ValidateClientHello(hello), "no supported crypto algorithm")
|
||||
}
|
||||
|
||||
func TestPreferredAEADAlgorithms(t *testing.T) {
|
||||
require.ElementsMatch(t, []string{
|
||||
AEADAlgorithmAES256GCM,
|
||||
AEADAlgorithmXChaCha20Poly1305,
|
||||
}, PreferredAEADAlgorithms())
|
||||
}
|
||||
|
||||
func TestNewServerHelloSelectsFirstSupportedAEADAlgorithm(t *testing.T) {
|
||||
hello := mustClientHello(t, BootstrapInfo{})
|
||||
hello.Capabilities.Crypto.Algorithms = []string{"future-aead", AEADAlgorithmXChaCha20Poly1305, AEADAlgorithmAES256GCM}
|
||||
|
||||
serverHello, err := NewServerHello(hello)
|
||||
require.NoError(t, err)
|
||||
require.Equal(t, MessageCodecJSON, serverHello.Selected.Message.Codec)
|
||||
require.Equal(t, AEADAlgorithmXChaCha20Poly1305, serverHello.Selected.Crypto.Algorithm)
|
||||
require.Len(t, serverHello.Selected.Crypto.ServerRandom, CryptoRandomSize)
|
||||
}
|
||||
|
||||
func TestNewClientCryptoContextValidatesServerHello(t *testing.T) {
|
||||
hello := mustClientHello(t, BootstrapInfo{})
|
||||
serverHello, err := NewServerHello(hello)
|
||||
require.NoError(t, err)
|
||||
clientHelloPayload, serverHelloPayload := mustCryptoTranscriptPayloads(t, hello, serverHello)
|
||||
|
||||
ctx, err := NewClientCryptoContext(clientHelloPayload, serverHelloPayload)
|
||||
require.NoError(t, err)
|
||||
require.Equal(t, serverHello.Selected.Crypto.Algorithm, ctx.Algorithm)
|
||||
require.Len(t, ctx.TranscriptHash, 32)
|
||||
|
||||
tampered := serverHello
|
||||
tampered.Selected.Crypto.ServerRandom = append([]byte(nil), serverHello.Selected.Crypto.ServerRandom...)
|
||||
tampered.Selected.Crypto.ServerRandom[0] ^= 0xff
|
||||
_, tamperedServerHelloPayload := mustCryptoTranscriptPayloads(t, hello, tampered)
|
||||
tamperedCtx, err := NewClientCryptoContext(clientHelloPayload, tamperedServerHelloPayload)
|
||||
require.NoError(t, err)
|
||||
require.NotEqual(t, ctx.TranscriptHash, tamperedCtx.TranscriptHash)
|
||||
}
|
||||
|
||||
func TestNewCryptoContextBindsFullClientHelloPayload(t *testing.T) {
|
||||
hello := mustClientHello(t, BootstrapInfo{
|
||||
Transport: "tcp",
|
||||
TLS: true,
|
||||
TCPMux: true,
|
||||
})
|
||||
serverHello, err := NewServerHello(hello)
|
||||
require.NoError(t, err)
|
||||
clientHelloPayload, serverHelloPayload := mustCryptoTranscriptPayloads(t, hello, serverHello)
|
||||
|
||||
ctx := NewCryptoContext(serverHello.Selected.Crypto.Algorithm, clientHelloPayload, serverHelloPayload)
|
||||
|
||||
tamperedHello := hello
|
||||
tamperedHello.Bootstrap.TLS = false
|
||||
tamperedClientHelloPayload, _ := mustCryptoTranscriptPayloads(t, tamperedHello, serverHello)
|
||||
tamperedCtx := NewCryptoContext(serverHello.Selected.Crypto.Algorithm, tamperedClientHelloPayload, serverHelloPayload)
|
||||
require.NotEqual(t, ctx.TranscriptHash, tamperedCtx.TranscriptHash)
|
||||
}
|
||||
|
||||
func TestNewClientCryptoContextRejectsUnknownServerSelection(t *testing.T) {
|
||||
hello := mustClientHello(t, BootstrapInfo{})
|
||||
serverHello, err := NewServerHello(hello)
|
||||
require.NoError(t, err)
|
||||
|
||||
serverHello.Selected.Crypto.Algorithm = "unknown"
|
||||
clientHelloPayload, serverHelloPayload := mustCryptoTranscriptPayloads(t, hello, serverHello)
|
||||
_, err = NewClientCryptoContext(clientHelloPayload, serverHelloPayload)
|
||||
require.ErrorContains(t, err, "unknown selected crypto algorithm")
|
||||
}
|
||||
|
||||
func TestNewClientCryptoContextRejectsUnadvertisedServerSelection(t *testing.T) {
|
||||
hello := mustClientHello(t, BootstrapInfo{})
|
||||
hello.Capabilities.Crypto.Algorithms = []string{AEADAlgorithmAES256GCM}
|
||||
serverHello, err := NewServerHello(hello)
|
||||
require.NoError(t, err)
|
||||
|
||||
serverHello.Selected.Crypto.Algorithm = AEADAlgorithmXChaCha20Poly1305
|
||||
clientHelloPayload, serverHelloPayload := mustCryptoTranscriptPayloads(t, hello, serverHello)
|
||||
_, err = NewClientCryptoContext(clientHelloPayload, serverHelloPayload)
|
||||
require.ErrorContains(t, err, "selected crypto algorithm was not advertised by client")
|
||||
}
|
||||
|
||||
func mustClientHello(t *testing.T, bootstrap BootstrapInfo) ClientHello {
|
||||
t.Helper()
|
||||
|
||||
hello, err := NewClientHello(bootstrap)
|
||||
require.NoError(t, err)
|
||||
return hello
|
||||
}
|
||||
|
||||
func mustCryptoTranscriptPayloads(t *testing.T, hello ClientHello, serverHello ServerHello) ([]byte, []byte) {
|
||||
t.Helper()
|
||||
|
||||
clientHelloFrame, err := NewJSONFrame(FrameTypeClientHello, hello)
|
||||
require.NoError(t, err)
|
||||
serverHelloFrame, err := NewJSONFrame(FrameTypeServerHello, serverHello)
|
||||
require.NoError(t, err)
|
||||
return clientHelloFrame.Payload, serverHelloFrame.Payload
|
||||
}
|
||||
@@ -17,6 +17,8 @@ package metric
|
||||
import (
|
||||
"sync"
|
||||
"time"
|
||||
|
||||
"k8s.io/utils/clock"
|
||||
)
|
||||
|
||||
type DateCounter interface {
|
||||
@@ -38,27 +40,33 @@ func NewDateCounter(reserveDays int64) DateCounter {
|
||||
type StandardDateCounter struct {
|
||||
reserveDays int64
|
||||
counts []int64
|
||||
clock clock.PassiveClock
|
||||
|
||||
lastUpdateDate time.Time
|
||||
mu sync.Mutex
|
||||
}
|
||||
|
||||
func newStandardDateCounter(reserveDays int64) *StandardDateCounter {
|
||||
now := time.Now()
|
||||
now = time.Date(now.Year(), now.Month(), now.Day(), 0, 0, 0, 0, now.Location())
|
||||
s := &StandardDateCounter{
|
||||
return newStandardDateCounterWithClock(reserveDays, clock.RealClock{})
|
||||
}
|
||||
|
||||
func newStandardDateCounterWithClock(reserveDays int64, clk clock.PassiveClock) *StandardDateCounter {
|
||||
if clk == nil {
|
||||
clk = clock.RealClock{}
|
||||
}
|
||||
return &StandardDateCounter{
|
||||
reserveDays: reserveDays,
|
||||
counts: make([]int64, reserveDays),
|
||||
lastUpdateDate: now,
|
||||
clock: clk,
|
||||
lastUpdateDate: startOfDay(clk.Now()),
|
||||
}
|
||||
return s
|
||||
}
|
||||
|
||||
func (c *StandardDateCounter) TodayCount() int64 {
|
||||
c.mu.Lock()
|
||||
defer c.mu.Unlock()
|
||||
|
||||
c.rotate(time.Now())
|
||||
c.rotate(c.clock.Now())
|
||||
return c.counts[0]
|
||||
}
|
||||
|
||||
@@ -70,65 +78,61 @@ func (c *StandardDateCounter) GetLastDaysCount(lastdays int64) []int64 {
|
||||
|
||||
c.mu.Lock()
|
||||
defer c.mu.Unlock()
|
||||
c.rotate(time.Now())
|
||||
for i := 0; i < int(lastdays); i++ {
|
||||
counts[i] = c.counts[i]
|
||||
}
|
||||
c.rotate(c.clock.Now())
|
||||
copy(counts, c.counts)
|
||||
return counts
|
||||
}
|
||||
|
||||
func (c *StandardDateCounter) Inc(count int64) {
|
||||
c.mu.Lock()
|
||||
defer c.mu.Unlock()
|
||||
c.rotate(time.Now())
|
||||
c.rotate(c.clock.Now())
|
||||
c.counts[0] += count
|
||||
}
|
||||
|
||||
func (c *StandardDateCounter) Dec(count int64) {
|
||||
c.mu.Lock()
|
||||
defer c.mu.Unlock()
|
||||
c.rotate(time.Now())
|
||||
c.rotate(c.clock.Now())
|
||||
c.counts[0] -= count
|
||||
}
|
||||
|
||||
func (c *StandardDateCounter) Snapshot() DateCounter {
|
||||
c.mu.Lock()
|
||||
defer c.mu.Unlock()
|
||||
tmp := newStandardDateCounter(c.reserveDays)
|
||||
for i := 0; i < int(c.reserveDays); i++ {
|
||||
tmp.counts[i] = c.counts[i]
|
||||
}
|
||||
tmp := newStandardDateCounterWithClock(c.reserveDays, c.clock)
|
||||
copy(tmp.counts, c.counts)
|
||||
return tmp
|
||||
}
|
||||
|
||||
func (c *StandardDateCounter) Clear() {
|
||||
c.mu.Lock()
|
||||
defer c.mu.Unlock()
|
||||
for i := 0; i < int(c.reserveDays); i++ {
|
||||
c.counts[i] = 0
|
||||
}
|
||||
clear(c.counts)
|
||||
}
|
||||
|
||||
// rotate
|
||||
// Must hold the lock before calling this function.
|
||||
func (c *StandardDateCounter) rotate(now time.Time) {
|
||||
now = time.Date(now.Year(), now.Month(), now.Day(), 0, 0, 0, 0, now.Location())
|
||||
now = startOfDay(now)
|
||||
days := int(now.Sub(c.lastUpdateDate).Hours() / 24)
|
||||
|
||||
defer func() {
|
||||
c.lastUpdateDate = now
|
||||
}()
|
||||
reserveDays := int(c.reserveDays)
|
||||
|
||||
if days <= 0 {
|
||||
return
|
||||
} else if days >= int(c.reserveDays) {
|
||||
} else if days >= reserveDays {
|
||||
c.counts = make([]int64, c.reserveDays)
|
||||
c.lastUpdateDate = now
|
||||
return
|
||||
}
|
||||
newCounts := make([]int64, c.reserveDays)
|
||||
|
||||
for i := days; i < int(c.reserveDays); i++ {
|
||||
newCounts[i] = c.counts[i-days]
|
||||
}
|
||||
copy(newCounts[days:], c.counts[:reserveDays-days])
|
||||
c.counts = newCounts
|
||||
c.lastUpdateDate = now
|
||||
}
|
||||
|
||||
// startOfDay returns midnight in t's location.
|
||||
func startOfDay(t time.Time) time.Time {
|
||||
return time.Date(t.Year(), t.Month(), t.Day(), 0, 0, 0, 0, t.Location())
|
||||
}
|
||||
|
||||
@@ -1,9 +1,12 @@
|
||||
package metric
|
||||
|
||||
import (
|
||||
"sync"
|
||||
"testing"
|
||||
"time"
|
||||
|
||||
"github.com/stretchr/testify/require"
|
||||
clocktesting "k8s.io/utils/clock/testing"
|
||||
)
|
||||
|
||||
func TestDateCounter(t *testing.T) {
|
||||
@@ -25,3 +28,107 @@ func TestDateCounter(t *testing.T) {
|
||||
dcTmp := dc.Snapshot()
|
||||
require.EqualValues(5, dcTmp.TodayCount())
|
||||
}
|
||||
|
||||
func TestDateCounterRotate(t *testing.T) {
|
||||
loc := time.FixedZone("test", 8*60*60)
|
||||
lastUpdateDate := time.Date(2026, time.May, 8, 0, 0, 0, 0, loc)
|
||||
|
||||
tests := []struct {
|
||||
name string
|
||||
now time.Time
|
||||
want []int64
|
||||
wantLastUpdateDate time.Time
|
||||
}{
|
||||
{
|
||||
name: "same day",
|
||||
now: time.Date(2026, time.May, 8, 12, 30, 0, 0, loc),
|
||||
want: []int64{10, 7, 3},
|
||||
wantLastUpdateDate: lastUpdateDate,
|
||||
},
|
||||
{
|
||||
name: "clock skew",
|
||||
now: time.Date(2026, time.May, 7, 12, 30, 0, 0, loc),
|
||||
want: []int64{10, 7, 3},
|
||||
wantLastUpdateDate: lastUpdateDate,
|
||||
},
|
||||
{
|
||||
name: "one day",
|
||||
now: time.Date(2026, time.May, 9, 12, 30, 0, 0, loc),
|
||||
want: []int64{0, 10, 7},
|
||||
wantLastUpdateDate: time.Date(2026, time.May, 9, 0, 0, 0, 0, loc),
|
||||
},
|
||||
{
|
||||
name: "two days",
|
||||
now: time.Date(2026, time.May, 10, 12, 30, 0, 0, loc),
|
||||
want: []int64{0, 0, 10},
|
||||
wantLastUpdateDate: time.Date(2026, time.May, 10, 0, 0, 0, 0, loc),
|
||||
},
|
||||
{
|
||||
name: "all reserved days elapsed",
|
||||
now: time.Date(2026, time.May, 11, 12, 30, 0, 0, loc),
|
||||
want: []int64{0, 0, 0},
|
||||
wantLastUpdateDate: time.Date(2026, time.May, 11, 0, 0, 0, 0, loc),
|
||||
},
|
||||
}
|
||||
|
||||
for _, tt := range tests {
|
||||
t.Run(tt.name, func(t *testing.T) {
|
||||
require := require.New(t)
|
||||
|
||||
dc := newStandardDateCounter(3)
|
||||
dc.counts = []int64{10, 7, 3}
|
||||
dc.lastUpdateDate = lastUpdateDate
|
||||
|
||||
dc.mu.Lock()
|
||||
dc.rotate(tt.now)
|
||||
dc.mu.Unlock()
|
||||
|
||||
require.Equal(tt.want, dc.counts)
|
||||
require.Equal(tt.wantLastUpdateDate, dc.lastUpdateDate)
|
||||
})
|
||||
}
|
||||
}
|
||||
|
||||
func TestDateCounterGetLastDaysCountReturnsCopy(t *testing.T) {
|
||||
require := require.New(t)
|
||||
|
||||
clk := clocktesting.NewFakeClock(time.Date(2026, time.May, 8, 12, 30, 0, 0, time.Local))
|
||||
dc := newStandardDateCounterWithClock(3, clk)
|
||||
dc.counts = []int64{10, 7, 3}
|
||||
|
||||
counts := dc.GetLastDaysCount(2)
|
||||
require.Equal([]int64{10, 7}, counts)
|
||||
|
||||
counts[0] = 100
|
||||
require.Equal([]int64{10, 7}, dc.GetLastDaysCount(2))
|
||||
}
|
||||
|
||||
func TestDateCounterClear(t *testing.T) {
|
||||
require := require.New(t)
|
||||
|
||||
dc := newStandardDateCounter(3)
|
||||
dc.counts = []int64{10, 7, 3}
|
||||
|
||||
dc.Clear()
|
||||
|
||||
require.Equal([]int64{0, 0, 0}, dc.counts)
|
||||
}
|
||||
|
||||
func TestDateCounterConcurrentAccess(t *testing.T) {
|
||||
clk := clocktesting.NewFakeClock(time.Date(2026, time.May, 8, 12, 30, 0, 0, time.Local))
|
||||
dc := newStandardDateCounterWithClock(3, clk)
|
||||
|
||||
var wg sync.WaitGroup
|
||||
for range 8 {
|
||||
wg.Go(func() {
|
||||
for range 100 {
|
||||
dc.Inc(1)
|
||||
dc.Dec(1)
|
||||
_ = dc.TodayCount()
|
||||
_ = dc.GetLastDaysCount(3)
|
||||
_ = dc.Snapshot()
|
||||
}
|
||||
})
|
||||
}
|
||||
wg.Wait()
|
||||
}
|
||||
|
||||
@@ -16,14 +16,16 @@ package net
|
||||
|
||||
import (
|
||||
"context"
|
||||
"crypto/sha256"
|
||||
"errors"
|
||||
"io"
|
||||
"net"
|
||||
"sync/atomic"
|
||||
"time"
|
||||
|
||||
"github.com/fatedier/golib/crypto"
|
||||
libcrypto "github.com/fatedier/golib/crypto"
|
||||
quic "github.com/quic-go/quic-go"
|
||||
"golang.org/x/crypto/hkdf"
|
||||
|
||||
"github.com/fatedier/frp/pkg/util/xlog"
|
||||
)
|
||||
@@ -133,7 +135,7 @@ type CloseNotifyConn struct {
|
||||
net.Conn
|
||||
|
||||
// 1 means closed
|
||||
closeFlag int32
|
||||
closeFlag atomic.Int32
|
||||
|
||||
closeFn func(error)
|
||||
}
|
||||
@@ -147,7 +149,7 @@ func WrapCloseNotifyConn(c net.Conn, closeFn func(error)) *CloseNotifyConn {
|
||||
}
|
||||
|
||||
func (cc *CloseNotifyConn) Close() (err error) {
|
||||
pflag := atomic.SwapInt32(&cc.closeFlag, 1)
|
||||
pflag := cc.closeFlag.Swap(1)
|
||||
if pflag == 0 {
|
||||
err = cc.Conn.Close()
|
||||
if cc.closeFn != nil {
|
||||
@@ -159,7 +161,7 @@ func (cc *CloseNotifyConn) Close() (err error) {
|
||||
|
||||
// CloseWithError closes the connection and passes the error to the close callback.
|
||||
func (cc *CloseNotifyConn) CloseWithError(err error) error {
|
||||
pflag := atomic.SwapInt32(&cc.closeFlag, 1)
|
||||
pflag := cc.closeFlag.Swap(1)
|
||||
if pflag == 0 {
|
||||
closeErr := cc.Conn.Close()
|
||||
if cc.closeFn != nil {
|
||||
@@ -173,7 +175,7 @@ func (cc *CloseNotifyConn) CloseWithError(err error) error {
|
||||
type StatsConn struct {
|
||||
net.Conn
|
||||
|
||||
closed int64 // 1 means closed
|
||||
closed atomic.Int64 // 1 means closed
|
||||
totalRead int64
|
||||
totalWrite int64
|
||||
statsFunc func(totalRead, totalWrite int64)
|
||||
@@ -199,7 +201,7 @@ func (statsConn *StatsConn) Write(p []byte) (n int, err error) {
|
||||
}
|
||||
|
||||
func (statsConn *StatsConn) Close() (err error) {
|
||||
old := atomic.SwapInt64(&statsConn.closed, 1)
|
||||
old := statsConn.closed.Swap(1)
|
||||
if old != 1 {
|
||||
err = statsConn.Conn.Close()
|
||||
if statsConn.statsFunc != nil {
|
||||
@@ -241,8 +243,8 @@ func (conn *wrapQuicStream) Close() error {
|
||||
}
|
||||
|
||||
func NewCryptoReadWriter(rw io.ReadWriter, key []byte) (io.ReadWriter, error) {
|
||||
encReader := crypto.NewReader(rw, key)
|
||||
encWriter, err := crypto.NewWriter(rw, key)
|
||||
encReader := libcrypto.NewReader(rw, key)
|
||||
encWriter, err := libcrypto.NewWriter(rw, key)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
@@ -254,3 +256,90 @@ func NewCryptoReadWriter(rw io.ReadWriter, key []byte) (io.ReadWriter, error) {
|
||||
Writer: encWriter,
|
||||
}, nil
|
||||
}
|
||||
|
||||
type AEADCryptoRole int
|
||||
|
||||
const (
|
||||
AEADCryptoRoleClient AEADCryptoRole = iota + 1
|
||||
AEADCryptoRoleServer
|
||||
)
|
||||
|
||||
const (
|
||||
aeadControlHKDFInfoPrefix = "frp wire v2 control aead"
|
||||
aeadDirectionClientToServer = "client-to-server"
|
||||
aeadDirectionServerToClient = "server-to-client"
|
||||
)
|
||||
|
||||
// NewAEADCryptoReadWriter wraps rw with framed AEAD encryption for the v2
|
||||
// control channel. Frames and their order are authenticated, but end-of-stream
|
||||
// is not: a clean EOF at a frame boundary is returned as normal EOF by the
|
||||
// underlying AEAD stream. Protocols that need truncation detection for finite
|
||||
// objects must add their own authenticated final message.
|
||||
func NewAEADCryptoReadWriter(
|
||||
rw io.ReadWriter,
|
||||
key []byte,
|
||||
role AEADCryptoRole,
|
||||
algorithm string,
|
||||
transcriptHash []byte,
|
||||
) (io.ReadWriter, error) {
|
||||
clientToServerKey, serverToClientKey, err := deriveAEADControlKeys(key, algorithm, transcriptHash)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
|
||||
var readKey, writeKey []byte
|
||||
switch role {
|
||||
case AEADCryptoRoleClient:
|
||||
readKey = serverToClientKey
|
||||
writeKey = clientToServerKey
|
||||
case AEADCryptoRoleServer:
|
||||
readKey = clientToServerKey
|
||||
writeKey = serverToClientKey
|
||||
default:
|
||||
return nil, errors.New("invalid aead crypto role")
|
||||
}
|
||||
|
||||
encReader, err := libcrypto.NewAEADStreamReader(rw, libcrypto.AEADStreamOptions{
|
||||
Algorithm: libcrypto.AEADAlgorithm(algorithm),
|
||||
Key: readKey,
|
||||
})
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
encWriter, err := libcrypto.NewAEADStreamWriter(rw, libcrypto.AEADStreamOptions{
|
||||
Algorithm: libcrypto.AEADAlgorithm(algorithm),
|
||||
Key: writeKey,
|
||||
})
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
return struct {
|
||||
io.Reader
|
||||
io.Writer
|
||||
}{
|
||||
Reader: encReader,
|
||||
Writer: encWriter,
|
||||
}, nil
|
||||
}
|
||||
|
||||
func deriveAEADControlKeys(key []byte, algorithm string, transcriptHash []byte) (clientToServerKey, serverToClientKey []byte, err error) {
|
||||
clientToServerKey, err = deriveAEADControlKey(key, algorithm, transcriptHash, aeadDirectionClientToServer)
|
||||
if err != nil {
|
||||
return nil, nil, err
|
||||
}
|
||||
serverToClientKey, err = deriveAEADControlKey(key, algorithm, transcriptHash, aeadDirectionServerToClient)
|
||||
if err != nil {
|
||||
return nil, nil, err
|
||||
}
|
||||
return clientToServerKey, serverToClientKey, nil
|
||||
}
|
||||
|
||||
func deriveAEADControlKey(key []byte, algorithm string, transcriptHash []byte, direction string) ([]byte, error) {
|
||||
info := []byte(aeadControlHKDFInfoPrefix + " " + algorithm + " " + direction)
|
||||
reader := hkdf.New(sha256.New, key, transcriptHash, info)
|
||||
out := make([]byte, libcrypto.AEADKeySize)
|
||||
if _, err := io.ReadFull(reader, out); err != nil {
|
||||
return nil, err
|
||||
}
|
||||
return out, nil
|
||||
}
|
||||
|
||||
118
pkg/util/net/conn_test.go
Normal file
118
pkg/util/net/conn_test.go
Normal file
@@ -0,0 +1,118 @@
|
||||
// Copyright 2026 The frp Authors
|
||||
//
|
||||
// Licensed under the Apache License, Version 2.0 (the "License");
|
||||
// you may not use this file except in compliance with the License.
|
||||
// You may obtain a copy of the License at
|
||||
//
|
||||
// http://www.apache.org/licenses/LICENSE-2.0
|
||||
//
|
||||
// Unless required by applicable law or agreed to in writing, software
|
||||
// distributed under the License is distributed on an "AS IS" BASIS,
|
||||
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
// See the License for the specific language governing permissions and
|
||||
// limitations under the License.
|
||||
|
||||
package net
|
||||
|
||||
import (
|
||||
"bytes"
|
||||
"io"
|
||||
stdnet "net"
|
||||
"testing"
|
||||
"time"
|
||||
|
||||
"github.com/stretchr/testify/require"
|
||||
|
||||
"github.com/fatedier/frp/pkg/proto/wire"
|
||||
)
|
||||
|
||||
func TestNewAEADCryptoReadWriterRoundTrip(t *testing.T) {
|
||||
clientConn, serverConn := stdnet.Pipe()
|
||||
defer clientConn.Close()
|
||||
defer serverConn.Close()
|
||||
|
||||
key := []byte("token")
|
||||
transcriptHash := bytes.Repeat([]byte{0x11}, 32)
|
||||
clientRW, err := NewAEADCryptoReadWriter(
|
||||
clientConn,
|
||||
key,
|
||||
AEADCryptoRoleClient,
|
||||
wire.AEADAlgorithmXChaCha20Poly1305,
|
||||
transcriptHash,
|
||||
)
|
||||
require.NoError(t, err)
|
||||
serverRW, err := NewAEADCryptoReadWriter(
|
||||
serverConn,
|
||||
key,
|
||||
AEADCryptoRoleServer,
|
||||
wire.AEADAlgorithmXChaCha20Poly1305,
|
||||
transcriptHash,
|
||||
)
|
||||
require.NoError(t, err)
|
||||
|
||||
clientErrCh := make(chan error, 1)
|
||||
go func() {
|
||||
if _, err := clientRW.Write([]byte("ping")); err != nil {
|
||||
clientErrCh <- err
|
||||
return
|
||||
}
|
||||
buf := make([]byte, len("pong"))
|
||||
_, err := io.ReadFull(clientRW, buf)
|
||||
clientErrCh <- err
|
||||
}()
|
||||
|
||||
buf := make([]byte, len("ping"))
|
||||
_, err = io.ReadFull(serverRW, buf)
|
||||
require.NoError(t, err)
|
||||
require.Equal(t, "ping", string(buf))
|
||||
_, err = serverRW.Write([]byte("pong"))
|
||||
require.NoError(t, err)
|
||||
require.NoError(t, <-clientErrCh)
|
||||
}
|
||||
|
||||
func TestNewAEADCryptoReadWriterRejectsDifferentTranscript(t *testing.T) {
|
||||
clientConn, serverConn := stdnet.Pipe()
|
||||
defer clientConn.Close()
|
||||
defer serverConn.Close()
|
||||
require.NoError(t, clientConn.SetDeadline(time.Now().Add(time.Second)))
|
||||
require.NoError(t, serverConn.SetDeadline(time.Now().Add(time.Second)))
|
||||
|
||||
key := []byte("token")
|
||||
clientRW, err := NewAEADCryptoReadWriter(
|
||||
clientConn,
|
||||
key,
|
||||
AEADCryptoRoleClient,
|
||||
wire.AEADAlgorithmAES256GCM,
|
||||
bytes.Repeat([]byte{0x22}, 32),
|
||||
)
|
||||
require.NoError(t, err)
|
||||
serverRW, err := NewAEADCryptoReadWriter(
|
||||
serverConn,
|
||||
key,
|
||||
AEADCryptoRoleServer,
|
||||
wire.AEADAlgorithmAES256GCM,
|
||||
bytes.Repeat([]byte{0x33}, 32),
|
||||
)
|
||||
require.NoError(t, err)
|
||||
|
||||
writeErrCh := make(chan error, 1)
|
||||
go func() {
|
||||
_, err := clientRW.Write([]byte("ping"))
|
||||
writeErrCh <- err
|
||||
}()
|
||||
|
||||
buf := make([]byte, len("ping"))
|
||||
_, err = io.ReadFull(serverRW, buf)
|
||||
require.Error(t, err)
|
||||
require.NoError(t, <-writeErrCh)
|
||||
}
|
||||
|
||||
func TestDeriveAEADControlKeysUsesDistinctDirections(t *testing.T) {
|
||||
clientToServerKey, serverToClientKey, err := deriveAEADControlKeys(
|
||||
[]byte("token"),
|
||||
wire.AEADAlgorithmXChaCha20Poly1305,
|
||||
bytes.Repeat([]byte{0x44}, 32),
|
||||
)
|
||||
require.NoError(t, err)
|
||||
require.NotEqual(t, clientToServerKey, serverToClientKey)
|
||||
}
|
||||
@@ -39,11 +39,14 @@ type HTTPConnectTCPMuxer struct {
|
||||
func NewHTTPConnectTCPMuxer(listener net.Listener, passthrough bool, timeout time.Duration) (*HTTPConnectTCPMuxer, error) {
|
||||
ret := &HTTPConnectTCPMuxer{passthrough: passthrough}
|
||||
mux, err := vhost.NewMuxer(listener, ret.getHostFromHTTPConnect, timeout)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
mux.SetCheckAuthFunc(ret.auth).
|
||||
SetSuccessHookFunc(ret.sendConnectResponse).
|
||||
SetFailHookFunc(vhostFailed)
|
||||
ret.Muxer = mux
|
||||
return ret, err
|
||||
return ret, nil
|
||||
}
|
||||
|
||||
func (muxer *HTTPConnectTCPMuxer) readHTTPConnectRequest(rd io.Reader) (host, httpUser, httpPwd string, err error) {
|
||||
|
||||
@@ -14,7 +14,7 @@
|
||||
|
||||
package version
|
||||
|
||||
var version = "0.68.1"
|
||||
var version = "0.69.0"
|
||||
|
||||
func Full() string {
|
||||
return version
|
||||
|
||||
@@ -24,7 +24,6 @@ import (
|
||||
"net/http"
|
||||
"net/http/httputil"
|
||||
"net/url"
|
||||
"strings"
|
||||
"time"
|
||||
|
||||
libio "github.com/fatedier/golib/io"
|
||||
@@ -160,7 +159,7 @@ func (rp *HTTPReverseProxy) UnRegister(routeCfg RouteConfig) {
|
||||
}
|
||||
|
||||
func (rp *HTTPReverseProxy) GetRouteConfig(domain, location, routeByHTTPUser string) *RouteConfig {
|
||||
vr, ok := rp.getVhost(domain, location, routeByHTTPUser)
|
||||
vr, ok := rp.vhostRouter.getByRoute(domain, location, routeByHTTPUser)
|
||||
if ok {
|
||||
log.Debugf("get new http request host [%s] path [%s] httpuser [%s]", domain, location, routeByHTTPUser)
|
||||
return vr.payload.(*RouteConfig)
|
||||
@@ -171,7 +170,7 @@ func (rp *HTTPReverseProxy) GetRouteConfig(domain, location, routeByHTTPUser str
|
||||
// CreateConnection create a new connection by route config
|
||||
func (rp *HTTPReverseProxy) CreateConnection(reqRouteInfo *RequestRouteInfo, byEndpoint bool) (net.Conn, error) {
|
||||
host, _ := httppkg.CanonicalHost(reqRouteInfo.Host)
|
||||
vr, ok := rp.getVhost(host, reqRouteInfo.URL, reqRouteInfo.HTTPUser)
|
||||
vr, ok := rp.vhostRouter.getByRoute(host, reqRouteInfo.URL, reqRouteInfo.HTTPUser)
|
||||
if ok {
|
||||
if byEndpoint {
|
||||
fn := vr.payload.(*RouteConfig).CreateConnByEndpointFn
|
||||
@@ -208,50 +207,6 @@ func checkRouteAuthByRequest(req *http.Request, rc *RouteConfig) bool {
|
||||
return ok && user == rc.Username && passwd == rc.Password
|
||||
}
|
||||
|
||||
// getVhost tries to get vhost router by route policy.
|
||||
func (rp *HTTPReverseProxy) getVhost(domain, location, routeByHTTPUser string) (*Router, bool) {
|
||||
findRouter := func(inDomain, inLocation, inRouteByHTTPUser string) (*Router, bool) {
|
||||
vr, ok := rp.vhostRouter.Get(inDomain, inLocation, inRouteByHTTPUser)
|
||||
if ok {
|
||||
return vr, ok
|
||||
}
|
||||
// Try to check if there is one proxy that doesn't specify routerByHTTPUser, it means match all.
|
||||
vr, ok = rp.vhostRouter.Get(inDomain, inLocation, "")
|
||||
if ok {
|
||||
return vr, ok
|
||||
}
|
||||
return nil, false
|
||||
}
|
||||
|
||||
// First we check the full hostname
|
||||
// if not exist, then check the wildcard_domain such as *.example.com
|
||||
vr, ok := findRouter(domain, location, routeByHTTPUser)
|
||||
if ok {
|
||||
return vr, ok
|
||||
}
|
||||
|
||||
// e.g. domain = test.example.com, try to match wildcard domains.
|
||||
// *.example.com
|
||||
// *.com
|
||||
domainSplit := strings.Split(domain, ".")
|
||||
for len(domainSplit) >= 3 {
|
||||
domainSplit[0] = "*"
|
||||
domain = strings.Join(domainSplit, ".")
|
||||
vr, ok = findRouter(domain, location, routeByHTTPUser)
|
||||
if ok {
|
||||
return vr, true
|
||||
}
|
||||
domainSplit = domainSplit[1:]
|
||||
}
|
||||
|
||||
// Finally, try to check if there is one proxy that domain is "*" means match all domains.
|
||||
vr, ok = findRouter("*", location, routeByHTTPUser)
|
||||
if ok {
|
||||
return vr, true
|
||||
}
|
||||
return nil, false
|
||||
}
|
||||
|
||||
func (rp *HTTPReverseProxy) connectHandler(rw http.ResponseWriter, req *http.Request) {
|
||||
hj, ok := rw.(http.Hijacker)
|
||||
if !ok {
|
||||
|
||||
@@ -29,11 +29,11 @@ type HTTPSMuxer struct {
|
||||
|
||||
func NewHTTPSMuxer(listener net.Listener, timeout time.Duration) (*HTTPSMuxer, error) {
|
||||
mux, err := NewMuxer(listener, GetHTTPSHostname, timeout)
|
||||
mux.SetFailHookFunc(vhostFailed)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
return &HTTPSMuxer{mux}, err
|
||||
mux.SetFailHookFunc(vhostFailed)
|
||||
return &HTTPSMuxer{mux}, nil
|
||||
}
|
||||
|
||||
func GetHTTPSHostname(c net.Conn) (_ net.Conn, _ map[string]string, err error) {
|
||||
|
||||
@@ -16,23 +16,53 @@ func TestGetHTTPSHostname(t *testing.T) {
|
||||
require.NoError(err)
|
||||
defer l.Close()
|
||||
|
||||
var conn net.Conn
|
||||
connCh := make(chan net.Conn, 1)
|
||||
acceptErrCh := make(chan error, 1)
|
||||
go func() {
|
||||
conn, _ = l.Accept()
|
||||
require.NotNil(conn)
|
||||
conn, err := l.Accept()
|
||||
if err != nil {
|
||||
acceptErrCh <- err
|
||||
return
|
||||
}
|
||||
connCh <- conn
|
||||
}()
|
||||
|
||||
clientErrCh := make(chan error, 1)
|
||||
go func() {
|
||||
time.Sleep(100 * time.Millisecond)
|
||||
tls.Dial("tcp", l.Addr().String(), &tls.Config{
|
||||
conn, err := tls.Dial("tcp", l.Addr().String(), &tls.Config{
|
||||
InsecureSkipVerify: true,
|
||||
ServerName: "example.com",
|
||||
})
|
||||
if conn != nil {
|
||||
_ = conn.Close()
|
||||
}
|
||||
clientErrCh <- err
|
||||
}()
|
||||
|
||||
time.Sleep(200 * time.Millisecond)
|
||||
_, infos, err := GetHTTPSHostname(conn)
|
||||
var conn net.Conn
|
||||
select {
|
||||
case conn = <-connCh:
|
||||
case err := <-acceptErrCh:
|
||||
require.NoError(err)
|
||||
case <-time.After(time.Second):
|
||||
t.Fatal("timed out waiting for accepted connection")
|
||||
}
|
||||
require.NotNil(conn)
|
||||
|
||||
serverConn, infos, err := GetHTTPSHostname(conn)
|
||||
if serverConn != nil {
|
||||
_ = serverConn.Close()
|
||||
} else {
|
||||
_ = conn.Close()
|
||||
}
|
||||
require.NoError(err)
|
||||
require.Equal("example.com", infos["Host"])
|
||||
require.Equal("https", infos["Scheme"])
|
||||
|
||||
select {
|
||||
case <-clientErrCh:
|
||||
case <-time.After(time.Second):
|
||||
t.Fatal("timed out waiting for TLS client")
|
||||
}
|
||||
}
|
||||
|
||||
@@ -93,12 +93,19 @@ func (r *Routers) Del(domain, location, httpUser string) {
|
||||
routersByHTTPUser[httpUser] = newVrs
|
||||
}
|
||||
|
||||
// Get returns the best location match for an exact host and exact HTTP user.
|
||||
// It does not apply all-users, wildcard-domain, or catch-all-domain fallback.
|
||||
func (r *Routers) Get(host, path, httpUser string) (vr *Router, exist bool) {
|
||||
host = strings.ToLower(host)
|
||||
|
||||
r.mutex.RLock()
|
||||
defer r.mutex.RUnlock()
|
||||
|
||||
return r.getLocked(host, path, httpUser)
|
||||
}
|
||||
|
||||
// getLocked performs an exact-host lookup; host must already be lower-cased.
|
||||
func (r *Routers) getLocked(host, path, httpUser string) (vr *Router, exist bool) {
|
||||
routersByHTTPUser, found := r.indexByDomain[host]
|
||||
if !found {
|
||||
return
|
||||
@@ -117,6 +124,49 @@ func (r *Routers) Get(host, path, httpUser string) (vr *Router, exist bool) {
|
||||
return
|
||||
}
|
||||
|
||||
func (r *Routers) getByRoute(host, path, httpUser string) (*Router, bool) {
|
||||
host = strings.ToLower(host)
|
||||
|
||||
r.mutex.RLock()
|
||||
defer r.mutex.RUnlock()
|
||||
|
||||
// First we check the full hostname; if it doesn't exist, then check wildcard domains.
|
||||
// For example, test.example.com checks *.example.com before falling back to "*".
|
||||
vr, ok := r.getExactOrAllUsersLocked(host, path, httpUser)
|
||||
if ok {
|
||||
return vr, true
|
||||
}
|
||||
|
||||
hostSplit := strings.Split(host, ".")
|
||||
// Keep two-label hosts out of the wildcard walk, so example.com does not match *.com.
|
||||
for len(hostSplit) >= 3 {
|
||||
// Replace the leftmost remaining label with the wildcard marker.
|
||||
hostSplit[0] = "*"
|
||||
host = strings.Join(hostSplit, ".")
|
||||
vr, ok = r.getExactOrAllUsersLocked(host, path, httpUser)
|
||||
if ok {
|
||||
return vr, true
|
||||
}
|
||||
hostSplit = hostSplit[1:]
|
||||
}
|
||||
|
||||
// Finally, try to check if there is one proxy whose domain is "*", which means match all domains.
|
||||
return r.getExactOrAllUsersLocked("*", path, httpUser)
|
||||
}
|
||||
|
||||
func (r *Routers) getExactOrAllUsersLocked(host, path, httpUser string) (*Router, bool) {
|
||||
vr, ok := r.getLocked(host, path, httpUser)
|
||||
if ok {
|
||||
return vr, true
|
||||
}
|
||||
// Try to check if there is one proxy that doesn't specify routeByHTTPUser, it means match all.
|
||||
vr, ok = r.getLocked(host, path, "")
|
||||
if ok {
|
||||
return vr, true
|
||||
}
|
||||
return nil, false
|
||||
}
|
||||
|
||||
func (r *Routers) exist(host, path, httpUser string) (route *Router, exist bool) {
|
||||
routersByHTTPUser, found := r.indexByDomain[host]
|
||||
if !found {
|
||||
|
||||
257
pkg/util/vhost/router_test.go
Normal file
257
pkg/util/vhost/router_test.go
Normal file
@@ -0,0 +1,257 @@
|
||||
// Copyright 2017 fatedier, fatedier@gmail.com
|
||||
//
|
||||
// Licensed under the Apache License, Version 2.0 (the "License");
|
||||
// you may not use this file except in compliance with the License.
|
||||
// You may obtain a copy of the License at
|
||||
//
|
||||
// http://www.apache.org/licenses/LICENSE-2.0
|
||||
//
|
||||
// Unless required by applicable law or agreed to in writing, software
|
||||
// distributed under the License is distributed on an "AS IS" BASIS,
|
||||
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
// See the License for the specific language governing permissions and
|
||||
// limitations under the License.
|
||||
|
||||
package vhost
|
||||
|
||||
import (
|
||||
"context"
|
||||
"fmt"
|
||||
"sync"
|
||||
"testing"
|
||||
"time"
|
||||
|
||||
"github.com/stretchr/testify/require"
|
||||
)
|
||||
|
||||
func TestRoutersGet(t *testing.T) {
|
||||
routers := NewRouters()
|
||||
require.NoError(t, routers.Add("example.com", "/api", "alice", "exact-user"))
|
||||
require.NoError(t, routers.Add("example.com", "/public", "", "exact-all-users"))
|
||||
require.NoError(t, routers.Add("*.example.com", "/api", "", "wildcard-all-users"))
|
||||
require.NoError(t, routers.Add("*", "/api", "", "all-domains"))
|
||||
|
||||
t.Run("exact host and user match with normalized domain", func(t *testing.T) {
|
||||
router, ok := routers.Get("EXAMPLE.COM", "/api/users", "alice")
|
||||
require.True(t, ok)
|
||||
require.Equal(t, "exact-user", router.payload)
|
||||
})
|
||||
|
||||
t.Run("exact host and empty user match", func(t *testing.T) {
|
||||
router, ok := routers.Get("EXAMPLE.COM", "/public/docs", "")
|
||||
require.True(t, ok)
|
||||
require.Equal(t, "exact-all-users", router.payload)
|
||||
})
|
||||
|
||||
t.Run("does not fall back from named user to empty user", func(t *testing.T) {
|
||||
// Get intentionally requires an exact HTTP user; route-level fallbacks live in getByRoute.
|
||||
_, ok := routers.Get("EXAMPLE.COM", "/public/docs", "alice")
|
||||
require.False(t, ok)
|
||||
})
|
||||
|
||||
t.Run("does not fall back to wildcard domains", func(t *testing.T) {
|
||||
_, ok := routers.Get("foo.example.com", "/api/users", "")
|
||||
require.False(t, ok)
|
||||
})
|
||||
|
||||
t.Run("does not fall back to catch-all domain", func(t *testing.T) {
|
||||
_, ok := routers.Get("missing.test", "/api/users", "")
|
||||
require.False(t, ok)
|
||||
})
|
||||
}
|
||||
|
||||
func TestRoutersGetByRoute(t *testing.T) {
|
||||
routers := NewRouters()
|
||||
require.NoError(t, routers.Add("example.com", "/api", "alice", "exact-user"))
|
||||
require.NoError(t, routers.Add("example.com", "/api", "", "exact-all-users"))
|
||||
require.NoError(t, routers.Add("exact.example.com", "/api", "", "exact-subdomain"))
|
||||
require.NoError(t, routers.Add("*.example.com", "/api", "", "wildcard-all-users"))
|
||||
require.NoError(t, routers.Add("*.foo.example.com", "/api", "", "specific-wildcard"))
|
||||
require.NoError(t, routers.Add("*.bar.com", "/api", "", "wildcard-parent-domain"))
|
||||
require.NoError(t, routers.Add("*", "/admin", "root", "all-domains-user"))
|
||||
require.NoError(t, routers.Add("*", "/", "", "all-domains"))
|
||||
|
||||
tests := []struct {
|
||||
name string
|
||||
domain string
|
||||
location string
|
||||
httpUser string
|
||||
want string
|
||||
}{
|
||||
{
|
||||
name: "exact domain and http user",
|
||||
domain: "example.com",
|
||||
location: "/api/users",
|
||||
httpUser: "alice",
|
||||
want: "exact-user",
|
||||
},
|
||||
{
|
||||
name: "exact domain falls back to all users",
|
||||
domain: "example.com",
|
||||
location: "/api/users",
|
||||
httpUser: "bob",
|
||||
want: "exact-all-users",
|
||||
},
|
||||
{
|
||||
name: "wildcard domain uses all users fallback",
|
||||
domain: "foo.example.com",
|
||||
location: "/api/users",
|
||||
httpUser: "bob",
|
||||
want: "wildcard-all-users",
|
||||
},
|
||||
{
|
||||
name: "mixed-case domain is normalized",
|
||||
domain: "Foo.Example.Com",
|
||||
location: "/api/users",
|
||||
httpUser: "bob",
|
||||
want: "wildcard-all-users",
|
||||
},
|
||||
{
|
||||
name: "exact domain wins over wildcard domain",
|
||||
domain: "exact.example.com",
|
||||
location: "/api/users",
|
||||
httpUser: "bob",
|
||||
want: "exact-subdomain",
|
||||
},
|
||||
{
|
||||
name: "more specific wildcard wins over broader wildcard",
|
||||
domain: "bar.foo.example.com",
|
||||
location: "/api/users",
|
||||
httpUser: "bob",
|
||||
want: "specific-wildcard",
|
||||
},
|
||||
{
|
||||
name: "wildcard walk checks parent domains",
|
||||
domain: "a.b.bar.com",
|
||||
location: "/api/users",
|
||||
httpUser: "bob",
|
||||
want: "wildcard-parent-domain",
|
||||
},
|
||||
{
|
||||
name: "catch-all domain fallback",
|
||||
domain: "foo.test.com",
|
||||
location: "/other",
|
||||
httpUser: "bob",
|
||||
want: "all-domains",
|
||||
},
|
||||
{
|
||||
name: "catch-all domain honors http user",
|
||||
domain: "foo.test.com",
|
||||
location: "/admin/panel",
|
||||
httpUser: "root",
|
||||
want: "all-domains-user",
|
||||
},
|
||||
}
|
||||
|
||||
for _, tt := range tests {
|
||||
t.Run(tt.name, func(t *testing.T) {
|
||||
router, ok := routers.getByRoute(tt.domain, tt.location, tt.httpUser)
|
||||
require.True(t, ok)
|
||||
require.Equal(t, tt.want, router.payload)
|
||||
})
|
||||
}
|
||||
}
|
||||
|
||||
func TestRoutersGetByRouteNoMatch(t *testing.T) {
|
||||
routers := NewRouters()
|
||||
require.NoError(t, routers.Add("*.example.com", "/api", "", "wildcard-all-users"))
|
||||
require.NoError(t, routers.Add("*.com", "/api", "", "top-level-wildcard"))
|
||||
|
||||
tests := []struct {
|
||||
name string
|
||||
domain string
|
||||
location string
|
||||
}{
|
||||
{
|
||||
name: "two-label domain does not enter wildcard walk",
|
||||
domain: "example.com",
|
||||
location: "/api/users",
|
||||
},
|
||||
{
|
||||
name: "missing catch-all remains no match",
|
||||
domain: "foo.test.com",
|
||||
location: "/api/users",
|
||||
},
|
||||
{
|
||||
name: "wrong path remains no match",
|
||||
domain: "foo.example.com",
|
||||
location: "/other",
|
||||
},
|
||||
}
|
||||
|
||||
for _, tt := range tests {
|
||||
t.Run(tt.name, func(t *testing.T) {
|
||||
router, ok := routers.getByRoute(tt.domain, tt.location, "bob")
|
||||
require.False(t, ok)
|
||||
require.Nil(t, router)
|
||||
})
|
||||
}
|
||||
}
|
||||
|
||||
func TestRoutersConcurrentGetByRouteAndAdd(t *testing.T) {
|
||||
routers := NewRouters()
|
||||
require.NoError(t, routers.Add("*.example.com", "/api", "", "wildcard"))
|
||||
|
||||
const readers = 8
|
||||
const iterations = 200
|
||||
|
||||
ctx, cancel := context.WithCancel(context.Background())
|
||||
defer cancel()
|
||||
start := make(chan struct{})
|
||||
errCh := make(chan error, readers+1)
|
||||
var wg sync.WaitGroup
|
||||
|
||||
for id := range readers {
|
||||
wg.Go(func() {
|
||||
<-start
|
||||
for j := range iterations {
|
||||
select {
|
||||
case <-ctx.Done():
|
||||
return
|
||||
default:
|
||||
}
|
||||
router, ok := routers.getByRoute("foo.example.com", "/api/users", "")
|
||||
if !ok || router == nil || router.payload != "wildcard" {
|
||||
errCh <- fmt.Errorf("reader %d iteration %d got router=%v ok=%v", id, j, router, ok)
|
||||
return
|
||||
}
|
||||
}
|
||||
})
|
||||
}
|
||||
|
||||
wg.Go(func() {
|
||||
<-start
|
||||
for i := range iterations {
|
||||
select {
|
||||
case <-ctx.Done():
|
||||
return
|
||||
default:
|
||||
}
|
||||
err := routers.Add(fmt.Sprintf("host-%d.example.com", i), "/api", "", i)
|
||||
if err != nil {
|
||||
errCh <- err
|
||||
return
|
||||
}
|
||||
}
|
||||
})
|
||||
|
||||
close(start)
|
||||
|
||||
done := make(chan struct{})
|
||||
go func() {
|
||||
wg.Wait()
|
||||
close(done)
|
||||
}()
|
||||
|
||||
select {
|
||||
case <-done:
|
||||
case <-time.After(5 * time.Second):
|
||||
cancel()
|
||||
t.Fatal("concurrent route lookup and add timed out")
|
||||
}
|
||||
|
||||
close(errCh)
|
||||
for err := range errCh {
|
||||
require.NoError(t, err)
|
||||
}
|
||||
}
|
||||
@@ -148,41 +148,9 @@ func (v *Muxer) Listen(ctx context.Context, cfg *RouteConfig) (l *Listener, err
|
||||
}
|
||||
|
||||
func (v *Muxer) getListener(name, path, httpUser string) (*Listener, bool) {
|
||||
findRouter := func(inName, inPath, inHTTPUser string) (*Listener, bool) {
|
||||
vr, ok := v.registryRouter.Get(inName, inPath, inHTTPUser)
|
||||
if ok {
|
||||
return vr.payload.(*Listener), true
|
||||
}
|
||||
// Try to check if there is one proxy that doesn't specify routerByHTTPUser, it means match all.
|
||||
vr, ok = v.registryRouter.Get(inName, inPath, "")
|
||||
if ok {
|
||||
return vr.payload.(*Listener), true
|
||||
}
|
||||
return nil, false
|
||||
}
|
||||
|
||||
// first we check the full hostname
|
||||
// if not exist, then check the wildcard_domain such as *.example.com
|
||||
l, ok := findRouter(name, path, httpUser)
|
||||
vr, ok := v.registryRouter.getByRoute(name, path, httpUser)
|
||||
if ok {
|
||||
return l, true
|
||||
}
|
||||
|
||||
domainSplit := strings.Split(name, ".")
|
||||
for len(domainSplit) >= 3 {
|
||||
domainSplit[0] = "*"
|
||||
name = strings.Join(domainSplit, ".")
|
||||
|
||||
l, ok = findRouter(name, path, httpUser)
|
||||
if ok {
|
||||
return l, true
|
||||
}
|
||||
domainSplit = domainSplit[1:]
|
||||
}
|
||||
// Finally, try to check if there is one proxy that domain is "*" means match all domains.
|
||||
l, ok = findRouter("*", path, httpUser)
|
||||
if ok {
|
||||
return l, true
|
||||
return vr.payload.(*Listener), true
|
||||
}
|
||||
return nil, false
|
||||
}
|
||||
|
||||
@@ -18,6 +18,8 @@ import (
|
||||
"math/rand/v2"
|
||||
"time"
|
||||
|
||||
"k8s.io/utils/clock"
|
||||
|
||||
"github.com/fatedier/frp/pkg/util/util"
|
||||
)
|
||||
|
||||
@@ -48,6 +50,7 @@ type FastBackoffOptions struct {
|
||||
|
||||
type fastBackoffImpl struct {
|
||||
options FastBackoffOptions
|
||||
clock clock.PassiveClock
|
||||
|
||||
lastCalledTime time.Time
|
||||
consecutiveErrCount int
|
||||
@@ -57,18 +60,26 @@ type fastBackoffImpl struct {
|
||||
}
|
||||
|
||||
func NewFastBackoffManager(options FastBackoffOptions) BackoffManager {
|
||||
return newFastBackoffManagerWithClock(options, clock.RealClock{})
|
||||
}
|
||||
|
||||
func newFastBackoffManagerWithClock(options FastBackoffOptions, clk clock.PassiveClock) BackoffManager {
|
||||
if clk == nil {
|
||||
clk = clock.RealClock{}
|
||||
}
|
||||
return &fastBackoffImpl{
|
||||
options: options,
|
||||
clock: clk,
|
||||
countsInFastRetryWindow: 1,
|
||||
}
|
||||
}
|
||||
|
||||
func (f *fastBackoffImpl) Backoff(previousDuration time.Duration, previousConditionError bool) time.Duration {
|
||||
if f.lastCalledTime.IsZero() {
|
||||
f.lastCalledTime = time.Now()
|
||||
f.lastCalledTime = f.clock.Now()
|
||||
return f.options.Duration
|
||||
}
|
||||
now := time.Now()
|
||||
now := f.clock.Now()
|
||||
f.lastCalledTime = now
|
||||
|
||||
if previousConditionError {
|
||||
|
||||
27
pkg/util/wait/backoff_test.go
Normal file
27
pkg/util/wait/backoff_test.go
Normal file
@@ -0,0 +1,27 @@
|
||||
package wait
|
||||
|
||||
import (
|
||||
"testing"
|
||||
"time"
|
||||
|
||||
"github.com/stretchr/testify/require"
|
||||
clocktesting "k8s.io/utils/clock/testing"
|
||||
)
|
||||
|
||||
func TestFastBackoffManagerUsesClock(t *testing.T) {
|
||||
require := require.New(t)
|
||||
|
||||
start := time.Date(2026, time.May, 8, 12, 30, 0, 0, time.UTC)
|
||||
clk := clocktesting.NewFakeClock(start)
|
||||
backoff := newFastBackoffManagerWithClock(FastBackoffOptions{
|
||||
Duration: time.Second,
|
||||
}, clk).(*fastBackoffImpl)
|
||||
|
||||
require.Equal(time.Second, backoff.Backoff(0, false))
|
||||
require.Equal(start, backoff.lastCalledTime)
|
||||
|
||||
next := start.Add(time.Minute)
|
||||
clk.SetTime(next)
|
||||
require.Equal(time.Second, backoff.Backoff(time.Second, false))
|
||||
require.Equal(next, backoff.lastCalledTime)
|
||||
}
|
||||
@@ -19,7 +19,6 @@ import "strings"
|
||||
// LogWriter forwards writes to frp's logger at configurable level.
|
||||
// It is safe for concurrent use as long as the underlying Logger is thread-safe.
|
||||
type LogWriter struct {
|
||||
xl *Logger
|
||||
logFunc func(string)
|
||||
}
|
||||
|
||||
@@ -31,35 +30,30 @@ func (w LogWriter) Write(p []byte) (n int, err error) {
|
||||
|
||||
func NewTraceWriter(xl *Logger) LogWriter {
|
||||
return LogWriter{
|
||||
xl: xl,
|
||||
logFunc: func(msg string) { xl.Tracef("%s", msg) },
|
||||
}
|
||||
}
|
||||
|
||||
func NewDebugWriter(xl *Logger) LogWriter {
|
||||
return LogWriter{
|
||||
xl: xl,
|
||||
logFunc: func(msg string) { xl.Debugf("%s", msg) },
|
||||
}
|
||||
}
|
||||
|
||||
func NewInfoWriter(xl *Logger) LogWriter {
|
||||
return LogWriter{
|
||||
xl: xl,
|
||||
logFunc: func(msg string) { xl.Infof("%s", msg) },
|
||||
}
|
||||
}
|
||||
|
||||
func NewWarnWriter(xl *Logger) LogWriter {
|
||||
return LogWriter{
|
||||
xl: xl,
|
||||
logFunc: func(msg string) { xl.Warnf("%s", msg) },
|
||||
}
|
||||
}
|
||||
|
||||
func NewErrorWriter(xl *Logger) LogWriter {
|
||||
return LogWriter{
|
||||
xl: xl,
|
||||
logFunc: func(msg string) { xl.Errorf("%s", msg) },
|
||||
}
|
||||
}
|
||||
|
||||
@@ -286,7 +286,6 @@ func (r *clientRouter) addRoute(name string, routes []net.IPNet, conn io.ReadWri
|
||||
r.mu.Lock()
|
||||
defer r.mu.Unlock()
|
||||
r.routes[name] = &routeElement{
|
||||
name: name,
|
||||
routes: routes,
|
||||
conn: conn,
|
||||
}
|
||||
@@ -383,7 +382,6 @@ func (r *serverRouter) cleanupConnIPs(conn io.Writer) {
|
||||
}
|
||||
|
||||
type routeElement struct {
|
||||
name string
|
||||
routes []net.IPNet
|
||||
conn io.ReadWriteCloser
|
||||
}
|
||||
|
||||
@@ -17,7 +17,6 @@ package server
|
||||
import (
|
||||
"context"
|
||||
"fmt"
|
||||
"net"
|
||||
"runtime/debug"
|
||||
"sync"
|
||||
"sync/atomic"
|
||||
@@ -32,9 +31,7 @@ import (
|
||||
"github.com/fatedier/frp/pkg/msg"
|
||||
plugin "github.com/fatedier/frp/pkg/plugin/server"
|
||||
"github.com/fatedier/frp/pkg/transport"
|
||||
netpkg "github.com/fatedier/frp/pkg/util/net"
|
||||
"github.com/fatedier/frp/pkg/util/util"
|
||||
"github.com/fatedier/frp/pkg/util/version"
|
||||
"github.com/fatedier/frp/pkg/util/wait"
|
||||
"github.com/fatedier/frp/pkg/util/xlog"
|
||||
"github.com/fatedier/frp/server/controller"
|
||||
@@ -108,9 +105,7 @@ type SessionContext struct {
|
||||
// key used for connection encryption
|
||||
EncryptionKey []byte
|
||||
// control connection
|
||||
Conn net.Conn
|
||||
// indicates whether the connection is encrypted
|
||||
ConnEncrypted bool
|
||||
Conn *msg.Conn
|
||||
// login message
|
||||
LoginMsg *msg.Login
|
||||
// server configuration
|
||||
@@ -131,7 +126,7 @@ type Control struct {
|
||||
msgDispatcher *msg.Dispatcher
|
||||
|
||||
// work connections
|
||||
workConnCh chan net.Conn
|
||||
workConnCh chan *proxy.WorkConn
|
||||
|
||||
// proxies in one client
|
||||
proxies map[string]proxy.Proxy
|
||||
@@ -161,7 +156,7 @@ func NewControl(ctx context.Context, sessionCtx *SessionContext) (*Control, erro
|
||||
poolCount := min(sessionCtx.LoginMsg.PoolCount, int(sessionCtx.ServerCfg.Transport.MaxPoolCount))
|
||||
ctl := &Control{
|
||||
sessionCtx: sessionCtx,
|
||||
workConnCh: make(chan net.Conn, poolCount+10),
|
||||
workConnCh: make(chan *proxy.WorkConn, poolCount+10),
|
||||
proxies: make(map[string]proxy.Proxy),
|
||||
poolCount: poolCount,
|
||||
portsUsedNum: 0,
|
||||
@@ -172,29 +167,14 @@ func NewControl(ctx context.Context, sessionCtx *SessionContext) (*Control, erro
|
||||
}
|
||||
ctl.lastPing.Store(time.Now())
|
||||
|
||||
if sessionCtx.ConnEncrypted {
|
||||
cryptoRW, err := netpkg.NewCryptoReadWriter(sessionCtx.Conn, sessionCtx.EncryptionKey)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
ctl.msgDispatcher = msg.NewDispatcher(cryptoRW)
|
||||
} else {
|
||||
ctl.msgDispatcher = msg.NewDispatcher(sessionCtx.Conn)
|
||||
}
|
||||
ctl.msgDispatcher = msg.NewDispatcher(sessionCtx.Conn)
|
||||
ctl.registerMsgHandlers()
|
||||
ctl.msgTransporter = transport.NewMessageTransporter(ctl.msgDispatcher)
|
||||
return ctl, nil
|
||||
}
|
||||
|
||||
// Start send a login success message to client and start working.
|
||||
// Start starts the control session workers after login succeeds.
|
||||
func (ctl *Control) Start() {
|
||||
loginRespMsg := &msg.LoginResp{
|
||||
Version: version.Full(),
|
||||
RunID: ctl.runID,
|
||||
Error: "",
|
||||
}
|
||||
_ = msg.WriteMsg(ctl.sessionCtx.Conn, loginRespMsg)
|
||||
|
||||
go func() {
|
||||
for i := 0; i < ctl.poolCount; i++ {
|
||||
// ignore error here, that means that this control is closed
|
||||
@@ -216,7 +196,7 @@ func (ctl *Control) Replaced(newCtl *Control) {
|
||||
ctl.sessionCtx.Conn.Close()
|
||||
}
|
||||
|
||||
func (ctl *Control) RegisterWorkConn(conn net.Conn) error {
|
||||
func (ctl *Control) RegisterWorkConn(conn *proxy.WorkConn) error {
|
||||
xl := ctl.xl
|
||||
defer func() {
|
||||
if err := recover(); err != nil {
|
||||
@@ -239,7 +219,7 @@ func (ctl *Control) RegisterWorkConn(conn net.Conn) error {
|
||||
// If no workConn available in the pool, send message to frpc to get one or more
|
||||
// and wait until it is available.
|
||||
// return an error if wait timeout
|
||||
func (ctl *Control) GetWorkConn() (workConn net.Conn, err error) {
|
||||
func (ctl *Control) GetWorkConn() (workConn *proxy.WorkConn, err error) {
|
||||
xl := ctl.xl
|
||||
defer func() {
|
||||
if err := recover(); err != nil {
|
||||
|
||||
@@ -57,7 +57,7 @@ type HTTPGroup struct {
|
||||
// CreateConnFuncs indexed by proxy name
|
||||
createFuncs map[string]vhost.CreateConnFunc
|
||||
pxyNames []string
|
||||
index uint64
|
||||
index atomic.Uint64
|
||||
ctl *HTTPGroupController
|
||||
mu sync.RWMutex
|
||||
}
|
||||
@@ -136,7 +136,7 @@ func (g *HTTPGroup) UnRegister(proxyName string) {
|
||||
|
||||
func (g *HTTPGroup) createConn(remoteAddr string) (net.Conn, error) {
|
||||
var f vhost.CreateConnFunc
|
||||
newIndex := atomic.AddUint64(&g.index, 1)
|
||||
newIndex := g.index.Add(1)
|
||||
|
||||
g.mu.RLock()
|
||||
group := g.group
|
||||
@@ -158,7 +158,7 @@ func (g *HTTPGroup) createConn(remoteAddr string) (net.Conn, error) {
|
||||
}
|
||||
|
||||
func (g *HTTPGroup) chooseEndpoint() (string, error) {
|
||||
newIndex := atomic.AddUint64(&g.index, 1)
|
||||
newIndex := g.index.Add(1)
|
||||
name := ""
|
||||
|
||||
g.mu.RLock()
|
||||
|
||||
@@ -287,6 +287,7 @@ func buildClientInfoResp(info registry.ClientInfo) model.ClientInfoResp {
|
||||
ClientID: info.ClientID(),
|
||||
RunID: info.RunID,
|
||||
Version: info.Version,
|
||||
WireProtocol: info.WireProtocol,
|
||||
Hostname: info.Hostname,
|
||||
ClientIP: info.IP,
|
||||
FirstConnectedAt: toUnix(info.FirstConnectedAt),
|
||||
|
||||
@@ -17,8 +17,11 @@ package http
|
||||
import (
|
||||
"encoding/json"
|
||||
"testing"
|
||||
"time"
|
||||
|
||||
v1 "github.com/fatedier/frp/pkg/config/v1"
|
||||
"github.com/fatedier/frp/pkg/proto/wire"
|
||||
"github.com/fatedier/frp/server/registry"
|
||||
)
|
||||
|
||||
func TestGetConfFromConfigurerKeepsPluginFields(t *testing.T) {
|
||||
@@ -69,3 +72,24 @@ func TestGetConfFromConfigurerKeepsPluginFields(t *testing.T) {
|
||||
t.Fatalf("plugin httpPassword mismatch, want %q got %#v", "password", got)
|
||||
}
|
||||
}
|
||||
|
||||
func TestBuildClientInfoRespIncludesWireProtocol(t *testing.T) {
|
||||
info := registry.ClientInfo{
|
||||
Key: "user.client",
|
||||
User: "user",
|
||||
RawClientID: "client",
|
||||
RunID: "run-id",
|
||||
Version: "1.0.0",
|
||||
WireProtocol: wire.ProtocolV2,
|
||||
Hostname: "host",
|
||||
IP: "127.0.0.1",
|
||||
FirstConnectedAt: time.Unix(1, 0),
|
||||
LastConnectedAt: time.Unix(2, 0),
|
||||
Online: true,
|
||||
}
|
||||
|
||||
resp := buildClientInfoResp(info)
|
||||
if resp.WireProtocol != wire.ProtocolV2 {
|
||||
t.Fatalf("wire protocol mismatch, want %q got %q", wire.ProtocolV2, resp.WireProtocol)
|
||||
}
|
||||
}
|
||||
|
||||
@@ -46,6 +46,7 @@ type ClientInfoResp struct {
|
||||
ClientID string `json:"clientID"`
|
||||
RunID string `json:"runID"`
|
||||
Version string `json:"version,omitempty"`
|
||||
WireProtocol string `json:"wireProtocol,omitempty"`
|
||||
Hostname string `json:"hostname"`
|
||||
ClientIP string `json:"clientIP,omitempty"`
|
||||
FirstConnectedAt int64 `json:"firstConnectedAt"`
|
||||
|
||||
@@ -7,6 +7,8 @@ import (
|
||||
"sync"
|
||||
"time"
|
||||
|
||||
"k8s.io/utils/clock"
|
||||
|
||||
"github.com/fatedier/frp/pkg/config/types"
|
||||
)
|
||||
|
||||
@@ -38,16 +40,25 @@ type Manager struct {
|
||||
|
||||
bindAddr string
|
||||
netType string
|
||||
clock clock.WithTicker
|
||||
mu sync.Mutex
|
||||
}
|
||||
|
||||
func NewManager(netType string, bindAddr string, allowPorts []types.PortsRange) *Manager {
|
||||
return newManagerWithClock(netType, bindAddr, allowPorts, clock.RealClock{})
|
||||
}
|
||||
|
||||
func newManagerWithClock(netType string, bindAddr string, allowPorts []types.PortsRange, clk clock.WithTicker) *Manager {
|
||||
if clk == nil {
|
||||
clk = clock.RealClock{}
|
||||
}
|
||||
pm := &Manager{
|
||||
reservedPorts: make(map[string]*PortCtx),
|
||||
usedPorts: make(map[int]*PortCtx),
|
||||
freePorts: make(map[int]struct{}),
|
||||
bindAddr: bindAddr,
|
||||
netType: netType,
|
||||
clock: clk,
|
||||
}
|
||||
if len(allowPorts) > 0 {
|
||||
for _, pair := range allowPorts {
|
||||
@@ -72,7 +83,7 @@ func (pm *Manager) Acquire(name string, port int) (realPort int, err error) {
|
||||
portCtx := &PortCtx{
|
||||
ProxyName: name,
|
||||
Closed: false,
|
||||
UpdateTime: time.Now(),
|
||||
UpdateTime: pm.clock.Now(),
|
||||
}
|
||||
|
||||
var ok bool
|
||||
@@ -90,9 +101,7 @@ func (pm *Manager) Acquire(name string, port int) (realPort int, err error) {
|
||||
if ctx, ok := pm.reservedPorts[name]; ok {
|
||||
if pm.isPortAvailable(ctx.Port) {
|
||||
realPort = ctx.Port
|
||||
pm.usedPorts[realPort] = portCtx
|
||||
pm.reservedPorts[name] = portCtx
|
||||
delete(pm.freePorts, realPort)
|
||||
pm.markPortAcquiredLocked(name, realPort, portCtx)
|
||||
return
|
||||
}
|
||||
}
|
||||
@@ -109,9 +118,7 @@ func (pm *Manager) Acquire(name string, port int) (realPort int, err error) {
|
||||
}
|
||||
if pm.isPortAvailable(k) {
|
||||
realPort = k
|
||||
pm.usedPorts[realPort] = portCtx
|
||||
pm.reservedPorts[name] = portCtx
|
||||
delete(pm.freePorts, realPort)
|
||||
pm.markPortAcquiredLocked(name, realPort, portCtx)
|
||||
break
|
||||
}
|
||||
}
|
||||
@@ -123,9 +130,7 @@ func (pm *Manager) Acquire(name string, port int) (realPort int, err error) {
|
||||
if _, ok = pm.freePorts[port]; ok {
|
||||
if pm.isPortAvailable(port) {
|
||||
realPort = port
|
||||
pm.usedPorts[realPort] = portCtx
|
||||
pm.reservedPorts[name] = portCtx
|
||||
delete(pm.freePorts, realPort)
|
||||
pm.markPortAcquiredLocked(name, realPort, portCtx)
|
||||
} else {
|
||||
err = ErrPortUnAvailable
|
||||
}
|
||||
@@ -140,6 +145,13 @@ func (pm *Manager) Acquire(name string, port int) (realPort int, err error) {
|
||||
return
|
||||
}
|
||||
|
||||
// markPortAcquiredLocked records a successful acquisition. pm.mu must be held.
|
||||
func (pm *Manager) markPortAcquiredLocked(name string, port int, portCtx *PortCtx) {
|
||||
pm.usedPorts[port] = portCtx
|
||||
pm.reservedPorts[name] = portCtx
|
||||
delete(pm.freePorts, port)
|
||||
}
|
||||
|
||||
func (pm *Manager) isPortAvailable(port int) bool {
|
||||
if pm.netType == "udp" {
|
||||
addr, err := net.ResolveUDPAddr("udp", net.JoinHostPort(pm.bindAddr, strconv.Itoa(port)))
|
||||
@@ -169,20 +181,36 @@ func (pm *Manager) Release(port int) {
|
||||
pm.freePorts[port] = struct{}{}
|
||||
delete(pm.usedPorts, port)
|
||||
ctx.Closed = true
|
||||
ctx.UpdateTime = time.Now()
|
||||
ctx.UpdateTime = pm.clock.Now()
|
||||
}
|
||||
}
|
||||
|
||||
// Release reserved port if it isn't used in last 24 hours.
|
||||
func (pm *Manager) cleanReservedPortsWorker() {
|
||||
pm.cleanReservedPortsWorkerUntil(nil)
|
||||
}
|
||||
|
||||
func (pm *Manager) cleanReservedPortsWorkerUntil(stopCh <-chan struct{}) {
|
||||
ticker := pm.clock.NewTicker(CleanReservedPortsInterval)
|
||||
defer ticker.Stop()
|
||||
|
||||
for {
|
||||
time.Sleep(CleanReservedPortsInterval)
|
||||
pm.mu.Lock()
|
||||
for name, ctx := range pm.reservedPorts {
|
||||
if ctx.Closed && time.Since(ctx.UpdateTime) > MaxPortReservedDuration {
|
||||
delete(pm.reservedPorts, name)
|
||||
}
|
||||
select {
|
||||
case <-ticker.C():
|
||||
pm.cleanReservedPortsOnce()
|
||||
case <-stopCh:
|
||||
return
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
func (pm *Manager) cleanReservedPortsOnce() {
|
||||
pm.mu.Lock()
|
||||
defer pm.mu.Unlock()
|
||||
|
||||
for name, ctx := range pm.reservedPorts {
|
||||
if ctx.Closed && pm.clock.Since(ctx.UpdateTime) > MaxPortReservedDuration {
|
||||
delete(pm.reservedPorts, name)
|
||||
}
|
||||
pm.mu.Unlock()
|
||||
}
|
||||
}
|
||||
|
||||
72
server/ports/ports_test.go
Normal file
72
server/ports/ports_test.go
Normal file
@@ -0,0 +1,72 @@
|
||||
package ports
|
||||
|
||||
import (
|
||||
"net"
|
||||
"testing"
|
||||
"time"
|
||||
|
||||
"github.com/stretchr/testify/require"
|
||||
clocktesting "k8s.io/utils/clock/testing"
|
||||
|
||||
"github.com/fatedier/frp/pkg/config/types"
|
||||
)
|
||||
|
||||
func TestManagerUsesClockForPortTimestamps(t *testing.T) {
|
||||
require := require.New(t)
|
||||
|
||||
port := freeTCPPort(t)
|
||||
start := time.Date(2026, time.May, 8, 12, 30, 0, 0, time.UTC)
|
||||
clk := clocktesting.NewFakeClock(start)
|
||||
pm := newManagerWithClock("tcp", "127.0.0.1", []types.PortsRange{{Single: port}}, clk)
|
||||
|
||||
realPort, err := pm.Acquire("proxy", port)
|
||||
require.NoError(err)
|
||||
require.Equal(port, realPort)
|
||||
require.Equal(start, pm.usedPorts[port].UpdateTime)
|
||||
|
||||
releasedAt := start.Add(time.Minute)
|
||||
clk.SetTime(releasedAt)
|
||||
pm.Release(port)
|
||||
|
||||
require.Equal(releasedAt, pm.reservedPorts["proxy"].UpdateTime)
|
||||
}
|
||||
|
||||
func TestManagerCleanReservedPortsWorkerUsesClockTicker(t *testing.T) {
|
||||
require := require.New(t)
|
||||
|
||||
port := freeTCPPort(t)
|
||||
start := time.Date(2026, time.May, 8, 12, 30, 0, 0, time.UTC)
|
||||
clk := clocktesting.NewFakeClock(start)
|
||||
pm := newManagerWithClock("tcp", "127.0.0.1", []types.PortsRange{{Single: port}}, clk)
|
||||
|
||||
realPort, err := pm.Acquire("proxy", port)
|
||||
require.NoError(err)
|
||||
require.Equal(port, realPort)
|
||||
pm.Release(port)
|
||||
require.True(pm.hasReservedPort("proxy"))
|
||||
|
||||
require.Eventually(clk.HasWaiters, time.Second, time.Millisecond)
|
||||
clk.Step(MaxPortReservedDuration + CleanReservedPortsInterval + time.Minute)
|
||||
|
||||
require.Eventually(func() bool {
|
||||
return !pm.hasReservedPort("proxy")
|
||||
}, time.Second, time.Millisecond)
|
||||
}
|
||||
|
||||
func (pm *Manager) hasReservedPort(name string) bool {
|
||||
pm.mu.Lock()
|
||||
defer pm.mu.Unlock()
|
||||
|
||||
_, ok := pm.reservedPorts[name]
|
||||
return ok
|
||||
}
|
||||
|
||||
func freeTCPPort(t *testing.T) int {
|
||||
t.Helper()
|
||||
|
||||
listener, err := net.Listen("tcp", "127.0.0.1:0")
|
||||
require.NoError(t, err)
|
||||
defer listener.Close()
|
||||
|
||||
return listener.Addr().(*net.TCPAddr).Port
|
||||
}
|
||||
@@ -44,7 +44,26 @@ func RegisterProxyFactory(proxyConfType reflect.Type, factory func(*BaseProxy) P
|
||||
proxyFactoryRegistry[proxyConfType] = factory
|
||||
}
|
||||
|
||||
type GetWorkConnFn func() (net.Conn, error)
|
||||
type WorkConn struct {
|
||||
conn *msg.Conn
|
||||
}
|
||||
|
||||
func NewWorkConn(conn *msg.Conn) *WorkConn {
|
||||
return &WorkConn{conn: conn}
|
||||
}
|
||||
|
||||
func (c *WorkConn) Start(m *msg.StartWorkConn) (net.Conn, error) {
|
||||
if err := c.conn.WriteMsg(m); err != nil {
|
||||
return nil, err
|
||||
}
|
||||
return c.conn, nil
|
||||
}
|
||||
|
||||
func (c *WorkConn) Close() error {
|
||||
return c.conn.Close()
|
||||
}
|
||||
|
||||
type GetWorkConnFn func() (*WorkConn, error)
|
||||
|
||||
type Proxy interface {
|
||||
Context() context.Context
|
||||
@@ -125,13 +144,13 @@ func (pxy *BaseProxy) GetWorkConnFromPool(src, dst net.Addr) (workConn net.Conn,
|
||||
xl := xlog.FromContextSafe(pxy.ctx)
|
||||
// try all connections from the pool
|
||||
for i := 0; i < pxy.poolCount+1; i++ {
|
||||
if workConn, err = pxy.getWorkConnFn(); err != nil {
|
||||
var pxyWorkConn *WorkConn
|
||||
if pxyWorkConn, err = pxy.getWorkConnFn(); err != nil {
|
||||
xl.Warnf("failed to get work connection: %v", err)
|
||||
return
|
||||
}
|
||||
xl.Debugf("get a new work connection: [%s]", workConn.RemoteAddr().String())
|
||||
xl.Debugf("get a new work connection: [%s]", pxyWorkConn.conn.RemoteAddr().String())
|
||||
xl.Spawn().AppendPrefix(pxy.GetName())
|
||||
workConn = netpkg.NewContextConn(pxy.ctx, workConn)
|
||||
|
||||
var (
|
||||
srcAddr string
|
||||
@@ -150,7 +169,7 @@ func (pxy *BaseProxy) GetWorkConnFromPool(src, dst net.Addr) (workConn net.Conn,
|
||||
dstAddr, dstPortStr, _ = net.SplitHostPort(dst.String())
|
||||
dstPort, _ = strconv.ParseUint(dstPortStr, 10, 16)
|
||||
}
|
||||
err = msg.WriteMsg(workConn, &msg.StartWorkConn{
|
||||
workConn, err = pxyWorkConn.Start(&msg.StartWorkConn{
|
||||
ProxyName: pxy.GetName(),
|
||||
SrcAddr: srcAddr,
|
||||
SrcPort: uint16(srcPort),
|
||||
@@ -160,9 +179,10 @@ func (pxy *BaseProxy) GetWorkConnFromPool(src, dst net.Addr) (workConn net.Conn,
|
||||
})
|
||||
if err != nil {
|
||||
xl.Warnf("failed to send message to work connection from pool: %v, times: %d", err, i)
|
||||
workConn.Close()
|
||||
pxyWorkConn.Close()
|
||||
workConn = nil
|
||||
} else {
|
||||
workConn = netpkg.NewContextConn(pxy.ctx, workConn)
|
||||
break
|
||||
}
|
||||
}
|
||||
|
||||
53
server/proxy/proxy_test.go
Normal file
53
server/proxy/proxy_test.go
Normal file
@@ -0,0 +1,53 @@
|
||||
// Copyright 2026 The frp Authors
|
||||
//
|
||||
// Licensed under the Apache License, Version 2.0 (the "License");
|
||||
// you may not use this file except in compliance with the License.
|
||||
// You may obtain a copy of the License at
|
||||
//
|
||||
// http://www.apache.org/licenses/LICENSE-2.0
|
||||
//
|
||||
// Unless required by applicable law or agreed to in writing, software
|
||||
// distributed under the License is distributed on an "AS IS" BASIS,
|
||||
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
// See the License for the specific language governing permissions and
|
||||
// limitations under the License.
|
||||
|
||||
package proxy
|
||||
|
||||
import (
|
||||
"net"
|
||||
"testing"
|
||||
|
||||
"github.com/stretchr/testify/require"
|
||||
|
||||
"github.com/fatedier/frp/pkg/msg"
|
||||
)
|
||||
|
||||
func TestWorkConnStartWritesStartWorkConn(t *testing.T) {
|
||||
client, server := net.Pipe()
|
||||
defer client.Close()
|
||||
defer server.Close()
|
||||
|
||||
serverMsgConn := msg.NewConn(server, msg.NewV2ReadWriter(server))
|
||||
clientMsgConn := msg.NewConn(client, msg.NewV2ReadWriter(client))
|
||||
workConn := NewWorkConn(serverMsgConn)
|
||||
|
||||
in := &msg.StartWorkConn{ProxyName: "tcp", SrcAddr: "127.0.0.1", SrcPort: 1234}
|
||||
type startResult struct {
|
||||
conn net.Conn
|
||||
err error
|
||||
}
|
||||
resultCh := make(chan startResult, 1)
|
||||
go func() {
|
||||
conn, err := workConn.Start(in)
|
||||
resultCh <- startResult{conn: conn, err: err}
|
||||
}()
|
||||
|
||||
out, err := clientMsgConn.ReadMsg()
|
||||
require.NoError(t, err)
|
||||
require.Equal(t, in, out)
|
||||
|
||||
result := <-resultCh
|
||||
require.NoError(t, result.err)
|
||||
require.Same(t, serverMsgConn, result.conn)
|
||||
}
|
||||
@@ -18,6 +18,8 @@ import (
|
||||
"fmt"
|
||||
"sync"
|
||||
"time"
|
||||
|
||||
"k8s.io/utils/clock"
|
||||
)
|
||||
|
||||
// ClientInfo captures metadata about a connected frpc instance.
|
||||
@@ -29,6 +31,7 @@ type ClientInfo struct {
|
||||
Hostname string
|
||||
IP string
|
||||
Version string
|
||||
WireProtocol string
|
||||
FirstConnectedAt time.Time
|
||||
LastConnectedAt time.Time
|
||||
DisconnectedAt time.Time
|
||||
@@ -41,17 +44,26 @@ type ClientRegistry struct {
|
||||
mu sync.RWMutex
|
||||
clients map[string]*ClientInfo
|
||||
runIndex map[string]string
|
||||
clock clock.PassiveClock
|
||||
}
|
||||
|
||||
func NewClientRegistry() *ClientRegistry {
|
||||
return newClientRegistryWithClock(clock.RealClock{})
|
||||
}
|
||||
|
||||
func newClientRegistryWithClock(clk clock.PassiveClock) *ClientRegistry {
|
||||
if clk == nil {
|
||||
clk = clock.RealClock{}
|
||||
}
|
||||
return &ClientRegistry{
|
||||
clients: make(map[string]*ClientInfo),
|
||||
runIndex: make(map[string]string),
|
||||
clock: clk,
|
||||
}
|
||||
}
|
||||
|
||||
// Register stores/updates metadata for a client and returns the registry key plus whether it conflicts with an online client.
|
||||
func (cr *ClientRegistry) Register(user, rawClientID, runID, hostname, version, remoteAddr string) (key string, conflict bool) {
|
||||
func (cr *ClientRegistry) Register(user, rawClientID, runID, hostname, version, remoteAddr, wireProtocol string) (key string, conflict bool) {
|
||||
if runID == "" {
|
||||
return "", false
|
||||
}
|
||||
@@ -63,7 +75,7 @@ func (cr *ClientRegistry) Register(user, rawClientID, runID, hostname, version,
|
||||
key = cr.composeClientKey(user, effectiveID)
|
||||
enforceUnique := rawClientID != ""
|
||||
|
||||
now := time.Now()
|
||||
now := cr.clock.Now()
|
||||
cr.mu.Lock()
|
||||
defer cr.mu.Unlock()
|
||||
|
||||
@@ -88,6 +100,7 @@ func (cr *ClientRegistry) Register(user, rawClientID, runID, hostname, version,
|
||||
info.Hostname = hostname
|
||||
info.IP = remoteAddr
|
||||
info.Version = version
|
||||
info.WireProtocol = wireProtocol
|
||||
if info.FirstConnectedAt.IsZero() {
|
||||
info.FirstConnectedAt = now
|
||||
}
|
||||
@@ -114,7 +127,7 @@ func (cr *ClientRegistry) MarkOfflineByRunID(runID string) {
|
||||
} else {
|
||||
info.RunID = ""
|
||||
info.Online = false
|
||||
now := time.Now()
|
||||
now := cr.clock.Now()
|
||||
info.DisconnectedAt = now
|
||||
}
|
||||
}
|
||||
|
||||
74
server/registry/registry_test.go
Normal file
74
server/registry/registry_test.go
Normal file
@@ -0,0 +1,74 @@
|
||||
// Copyright 2026 The frp Authors
|
||||
//
|
||||
// Licensed under the Apache License, Version 2.0 (the "License");
|
||||
// you may not use this file except in compliance with the License.
|
||||
// You may obtain a copy of the License at
|
||||
//
|
||||
// http://www.apache.org/licenses/LICENSE-2.0
|
||||
//
|
||||
// Unless required by applicable law or agreed to in writing, software
|
||||
// distributed under the License is distributed on an "AS IS" BASIS,
|
||||
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
// See the License for the specific language governing permissions and
|
||||
// limitations under the License.
|
||||
|
||||
package registry
|
||||
|
||||
import (
|
||||
"testing"
|
||||
"time"
|
||||
|
||||
clocktesting "k8s.io/utils/clock/testing"
|
||||
|
||||
"github.com/fatedier/frp/pkg/proto/wire"
|
||||
)
|
||||
|
||||
func TestClientRegistryRegisterStoresWireProtocol(t *testing.T) {
|
||||
registry := NewClientRegistry()
|
||||
key, conflict := registry.Register("user", "client-id", "run-id", "host", "1.0.0", "127.0.0.1", wire.ProtocolV2)
|
||||
if conflict {
|
||||
t.Fatal("unexpected client conflict")
|
||||
}
|
||||
|
||||
info, ok := registry.GetByKey(key)
|
||||
if !ok {
|
||||
t.Fatalf("client %q not found", key)
|
||||
}
|
||||
if info.WireProtocol != wire.ProtocolV2 {
|
||||
t.Fatalf("wire protocol mismatch, want %q got %q", wire.ProtocolV2, info.WireProtocol)
|
||||
}
|
||||
}
|
||||
|
||||
func TestClientRegistryUsesClockForTimestamps(t *testing.T) {
|
||||
start := time.Date(2026, time.May, 8, 12, 30, 0, 0, time.UTC)
|
||||
clk := clocktesting.NewFakeClock(start)
|
||||
registry := newClientRegistryWithClock(clk)
|
||||
|
||||
key, conflict := registry.Register("user", "client-id", "run-id", "host", "1.0.0", "127.0.0.1", wire.ProtocolV2)
|
||||
if conflict {
|
||||
t.Fatal("unexpected client conflict")
|
||||
}
|
||||
|
||||
info, ok := registry.GetByKey(key)
|
||||
if !ok {
|
||||
t.Fatalf("client %q not found", key)
|
||||
}
|
||||
if !info.FirstConnectedAt.Equal(start) {
|
||||
t.Fatalf("first connected time mismatch, want %s got %s", start, info.FirstConnectedAt)
|
||||
}
|
||||
if !info.LastConnectedAt.Equal(start) {
|
||||
t.Fatalf("last connected time mismatch, want %s got %s", start, info.LastConnectedAt)
|
||||
}
|
||||
|
||||
disconnectedAt := start.Add(time.Minute)
|
||||
clk.SetTime(disconnectedAt)
|
||||
registry.MarkOfflineByRunID("run-id")
|
||||
|
||||
info, ok = registry.GetByKey(key)
|
||||
if !ok {
|
||||
t.Fatalf("client %q not found after disconnect", key)
|
||||
}
|
||||
if !info.DisconnectedAt.Equal(disconnectedAt) {
|
||||
t.Fatalf("disconnected time mismatch, want %s got %s", disconnectedAt, info.DisconnectedAt)
|
||||
}
|
||||
}
|
||||
@@ -19,6 +19,7 @@ import (
|
||||
"context"
|
||||
"crypto/tls"
|
||||
"fmt"
|
||||
"io"
|
||||
"net"
|
||||
"net/http"
|
||||
"os"
|
||||
@@ -37,6 +38,7 @@ import (
|
||||
"github.com/fatedier/frp/pkg/msg"
|
||||
"github.com/fatedier/frp/pkg/nathole"
|
||||
plugin "github.com/fatedier/frp/pkg/plugin/server"
|
||||
"github.com/fatedier/frp/pkg/proto/wire"
|
||||
"github.com/fatedier/frp/pkg/ssh"
|
||||
"github.com/fatedier/frp/pkg/transport"
|
||||
httppkg "github.com/fatedier/frp/pkg/util/http"
|
||||
@@ -58,6 +60,7 @@ import (
|
||||
|
||||
const (
|
||||
connReadTimeout time.Duration = 10 * time.Second
|
||||
connWriteTimeout time.Duration = 5 * time.Second
|
||||
vhostReadWriteTimeout time.Duration = 30 * time.Second
|
||||
)
|
||||
|
||||
@@ -432,20 +435,15 @@ func (svr *Service) Close() error {
|
||||
func (svr *Service) handleConnection(ctx context.Context, conn net.Conn, internal bool) {
|
||||
xl := xlog.FromContextSafe(ctx)
|
||||
|
||||
var (
|
||||
rawMsg msg.Message
|
||||
err error
|
||||
)
|
||||
|
||||
_ = conn.SetReadDeadline(time.Now().Add(connReadTimeout))
|
||||
if rawMsg, err = msg.ReadMsg(conn); err != nil {
|
||||
log.Tracef("failed to read message: %v", err)
|
||||
acceptedConn, err := svr.acceptConnection(ctx, conn)
|
||||
if err != nil {
|
||||
log.Tracef("failed to accept frp connection: %v", err)
|
||||
conn.Close()
|
||||
return
|
||||
}
|
||||
_ = conn.SetReadDeadline(time.Time{})
|
||||
conn = acceptedConn.conn
|
||||
|
||||
switch m := rawMsg.(type) {
|
||||
switch m := acceptedConn.firstMsg.(type) {
|
||||
case *msg.Login:
|
||||
// server plugin hook
|
||||
content := &plugin.LoginContent{
|
||||
@@ -453,35 +451,72 @@ func (svr *Service) handleConnection(ctx context.Context, conn net.Conn, interna
|
||||
ClientAddress: conn.RemoteAddr().String(),
|
||||
}
|
||||
retContent, err := svr.pluginManager.Login(content)
|
||||
var ctl *Control
|
||||
if err == nil {
|
||||
m = &retContent.Login
|
||||
err = svr.RegisterControl(conn, m, internal)
|
||||
controlConn := acceptedConn.conn
|
||||
if !internal {
|
||||
var controlRW io.ReadWriter
|
||||
controlRW, err = acceptedConn.newControlReadWriter(conn, svr.auth.EncryptionKey())
|
||||
if err == nil {
|
||||
controlConn = acceptedConn.messageConnFor(controlRW)
|
||||
}
|
||||
}
|
||||
if err == nil {
|
||||
ctl, err = svr.RegisterControl(controlConn, m, internal, acceptedConn.wireProtocol)
|
||||
}
|
||||
}
|
||||
|
||||
// If login failed, send error message there.
|
||||
// Otherwise send success message in control's work goroutine.
|
||||
if err != nil {
|
||||
xl.Warnf("register control error: %v", err)
|
||||
_ = msg.WriteMsg(conn, &msg.LoginResp{
|
||||
Version: version.Full(),
|
||||
Error: util.GenerateResponseErrorString("register control error", err, lo.FromPtr(svr.cfg.DetailedErrorsToClient)),
|
||||
})
|
||||
if writeErr := writeWithDeadline(conn, connWriteTimeout, func() error {
|
||||
return acceptedConn.conn.WriteMsg(&msg.LoginResp{
|
||||
Version: version.Full(),
|
||||
Error: util.GenerateResponseErrorString("register control error", err, lo.FromPtr(svr.cfg.DetailedErrorsToClient)),
|
||||
})
|
||||
}); writeErr != nil {
|
||||
xl.Warnf("write login error response error: %v", writeErr)
|
||||
}
|
||||
conn.Close()
|
||||
return
|
||||
}
|
||||
if err = writeWithDeadline(conn, connWriteTimeout, func() error {
|
||||
return acceptedConn.conn.WriteMsg(&msg.LoginResp{
|
||||
Version: version.Full(),
|
||||
RunID: ctl.runID,
|
||||
Error: "",
|
||||
})
|
||||
}); err != nil {
|
||||
xl.Warnf("write login response error: %v", err)
|
||||
svr.ctlManager.Del(m.RunID, ctl)
|
||||
svr.clientRegistry.MarkOfflineByRunID(m.RunID)
|
||||
conn.Close()
|
||||
return
|
||||
}
|
||||
ctl.Start()
|
||||
metrics.Server.NewClient()
|
||||
go func() {
|
||||
// block until control closed
|
||||
ctl.WaitClosed()
|
||||
svr.ctlManager.Del(m.RunID, ctl)
|
||||
}()
|
||||
case *msg.NewWorkConn:
|
||||
if err := svr.RegisterWorkConn(conn, m); err != nil {
|
||||
if err := svr.RegisterWorkConn(acceptedConn.conn, m); err != nil {
|
||||
_ = acceptedConn.conn.WriteMsg(&msg.StartWorkConn{
|
||||
Error: util.GenerateResponseErrorString("invalid NewWorkConn", err, lo.FromPtr(svr.cfg.DetailedErrorsToClient)),
|
||||
})
|
||||
conn.Close()
|
||||
}
|
||||
case *msg.NewVisitorConn:
|
||||
if err = svr.RegisterVisitorConn(conn, m); err != nil {
|
||||
xl.Warnf("register visitor conn error: %v", err)
|
||||
_ = msg.WriteMsg(conn, &msg.NewVisitorConnResp{
|
||||
_ = acceptedConn.conn.WriteMsg(&msg.NewVisitorConnResp{
|
||||
ProxyName: m.ProxyName,
|
||||
Error: util.GenerateResponseErrorString("register visitor conn error", err, lo.FromPtr(svr.cfg.DetailedErrorsToClient)),
|
||||
})
|
||||
conn.Close()
|
||||
} else {
|
||||
_ = msg.WriteMsg(conn, &msg.NewVisitorConnResp{
|
||||
_ = acceptedConn.conn.WriteMsg(&msg.NewVisitorConnResp{
|
||||
ProxyName: m.ProxyName,
|
||||
Error: "",
|
||||
})
|
||||
@@ -492,6 +527,129 @@ func (svr *Service) handleConnection(ctx context.Context, conn net.Conn, interna
|
||||
}
|
||||
}
|
||||
|
||||
type acceptedConnection struct {
|
||||
conn *msg.Conn
|
||||
wireProtocol string
|
||||
cryptoContext *wire.CryptoContext
|
||||
firstMsg msg.Message
|
||||
}
|
||||
|
||||
func (svr *Service) acceptConnection(ctx context.Context, conn net.Conn) (*acceptedConnection, error) {
|
||||
_ = conn.SetReadDeadline(time.Now().Add(connReadTimeout))
|
||||
checkedConn, isV2, err := wire.CheckMagic(conn)
|
||||
if err != nil {
|
||||
return nil, fmt.Errorf("read wire protocol magic: %w", err)
|
||||
}
|
||||
|
||||
wireProtocol := wire.ProtocolV1
|
||||
if isV2 {
|
||||
wireProtocol = wire.ProtocolV2
|
||||
}
|
||||
|
||||
conn = netpkg.NewContextConn(ctx, checkedConn)
|
||||
acceptedConn := &acceptedConnection{wireProtocol: wireProtocol}
|
||||
if isV2 {
|
||||
wireConn := wire.NewConn(conn)
|
||||
rw := msg.NewV2ReadWriterWithConn(wireConn)
|
||||
acceptedConn.conn = msg.NewConn(conn, rw)
|
||||
acceptedConn.firstMsg, err = acceptedConn.readFirstV2Msg(conn, wireConn)
|
||||
} else {
|
||||
rw := msg.NewV1ReadWriter(conn)
|
||||
acceptedConn.conn = msg.NewConn(conn, rw)
|
||||
acceptedConn.firstMsg, err = acceptedConn.conn.ReadMsg()
|
||||
}
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
_ = conn.SetReadDeadline(time.Time{})
|
||||
return acceptedConn, nil
|
||||
}
|
||||
|
||||
func writeWithDeadline(conn net.Conn, timeout time.Duration, writeFn func() error) error {
|
||||
_ = conn.SetWriteDeadline(time.Now().Add(timeout))
|
||||
defer func() {
|
||||
_ = conn.SetWriteDeadline(time.Time{})
|
||||
}()
|
||||
return writeFn()
|
||||
}
|
||||
|
||||
func (ac *acceptedConnection) messageConnFor(rw io.ReadWriter) *msg.Conn {
|
||||
return msg.NewConn(ac.conn, msg.NewReadWriter(rw, ac.wireProtocol))
|
||||
}
|
||||
|
||||
func (ac *acceptedConnection) newControlReadWriter(rw io.ReadWriter, key []byte) (io.ReadWriter, error) {
|
||||
if ac.wireProtocol == wire.ProtocolV2 {
|
||||
if ac.cryptoContext == nil {
|
||||
return nil, fmt.Errorf("missing v2 crypto negotiation")
|
||||
}
|
||||
return netpkg.NewAEADCryptoReadWriter(
|
||||
rw,
|
||||
key,
|
||||
netpkg.AEADCryptoRoleServer,
|
||||
ac.cryptoContext.Algorithm,
|
||||
ac.cryptoContext.TranscriptHash,
|
||||
)
|
||||
}
|
||||
return netpkg.NewCryptoReadWriter(rw, key)
|
||||
}
|
||||
|
||||
func (ac *acceptedConnection) readFirstV2Msg(conn net.Conn, wireConn *wire.Conn) (msg.Message, error) {
|
||||
frame, err := wireConn.ReadFrame()
|
||||
if err != nil {
|
||||
return nil, fmt.Errorf("read v2 frame: %w", err)
|
||||
}
|
||||
if frame.Type == wire.FrameTypeClientHello {
|
||||
if err := ac.handleClientHello(conn, wireConn, frame); err != nil {
|
||||
return nil, err
|
||||
}
|
||||
frame, err = wireConn.ReadFrame()
|
||||
if err != nil {
|
||||
return nil, fmt.Errorf("read first v2 message frame: %w", err)
|
||||
}
|
||||
}
|
||||
|
||||
m, err := msg.DecodeV2MessageFrame(frame)
|
||||
if err != nil {
|
||||
return nil, fmt.Errorf("decode v2 message: %w", err)
|
||||
}
|
||||
return m, nil
|
||||
}
|
||||
|
||||
func (ac *acceptedConnection) handleClientHello(conn net.Conn, wireConn *wire.Conn, frame *wire.Frame) error {
|
||||
var hello wire.ClientHello
|
||||
if err := wireConn.UnmarshalFrame(frame, &hello); err != nil {
|
||||
return fmt.Errorf("decode ClientHello: %w", err)
|
||||
}
|
||||
|
||||
serverHello, err := wire.NewServerHello(hello)
|
||||
if err != nil {
|
||||
serverHello = wire.DefaultServerHello()
|
||||
serverHello.Error = err.Error()
|
||||
if writeErr := writeWithDeadline(conn, connWriteTimeout, func() error {
|
||||
return wireConn.WriteJSONFrame(wire.FrameTypeServerHello, serverHello)
|
||||
}); writeErr != nil {
|
||||
return fmt.Errorf("%w; write ServerHello error: %v", err, writeErr)
|
||||
}
|
||||
return err
|
||||
}
|
||||
serverHelloFrame, err := wire.NewJSONFrame(wire.FrameTypeServerHello, serverHello)
|
||||
if err != nil {
|
||||
return fmt.Errorf("encode ServerHello: %w", err)
|
||||
}
|
||||
cryptoContext := wire.NewCryptoContext(
|
||||
serverHello.Selected.Crypto.Algorithm,
|
||||
frame.Payload,
|
||||
serverHelloFrame.Payload,
|
||||
)
|
||||
if err := writeWithDeadline(conn, connWriteTimeout, func() error {
|
||||
return wireConn.WriteFrame(serverHelloFrame)
|
||||
}); err != nil {
|
||||
return fmt.Errorf("write ServerHello: %w", err)
|
||||
}
|
||||
ac.cryptoContext = cryptoContext
|
||||
return nil
|
||||
}
|
||||
|
||||
// HandleListener accepts connections from client and call handleConnection to handle them.
|
||||
// If internal is true, it means that this listener is used for internal communication like ssh tunnel gateway.
|
||||
// TODO(fatedier): Pass some parameters of listener/connection through context to avoid passing too many parameters.
|
||||
@@ -577,14 +735,19 @@ func (svr *Service) HandleQUICListener(l *quic.Listener) {
|
||||
}
|
||||
}
|
||||
|
||||
func (svr *Service) RegisterControl(ctlConn net.Conn, loginMsg *msg.Login, internal bool) error {
|
||||
func (svr *Service) RegisterControl(
|
||||
ctlConn *msg.Conn,
|
||||
loginMsg *msg.Login,
|
||||
internal bool,
|
||||
wireProtocol string,
|
||||
) (*Control, error) {
|
||||
// If client's RunID is empty, it's a new client, we just create a new controller.
|
||||
// Otherwise, we check if there is one controller has the same run id. If so, we release previous controller and start new one.
|
||||
var err error
|
||||
if loginMsg.RunID == "" {
|
||||
loginMsg.RunID, err = util.RandID()
|
||||
if err != nil {
|
||||
return err
|
||||
return nil, err
|
||||
}
|
||||
}
|
||||
|
||||
@@ -601,7 +764,7 @@ func (svr *Service) RegisterControl(ctlConn net.Conn, loginMsg *msg.Login, inter
|
||||
authVerifier = auth.AlwaysPassVerifier
|
||||
}
|
||||
if err := authVerifier.VerifyLogin(loginMsg); err != nil {
|
||||
return err
|
||||
return nil, err
|
||||
}
|
||||
|
||||
ctl, err := NewControl(ctx, &SessionContext{
|
||||
@@ -611,7 +774,6 @@ func (svr *Service) RegisterControl(ctlConn net.Conn, loginMsg *msg.Login, inter
|
||||
AuthVerifier: authVerifier,
|
||||
EncryptionKey: svr.auth.EncryptionKey(),
|
||||
Conn: ctlConn,
|
||||
ConnEncrypted: !internal,
|
||||
LoginMsg: loginMsg,
|
||||
ServerCfg: svr.cfg,
|
||||
ClientRegistry: svr.clientRegistry,
|
||||
@@ -619,7 +781,7 @@ func (svr *Service) RegisterControl(ctlConn net.Conn, loginMsg *msg.Login, inter
|
||||
if err != nil {
|
||||
xl.Warnf("create new controller error: %v", err)
|
||||
// don't return detailed errors to client
|
||||
return fmt.Errorf("unexpected error when creating new controller")
|
||||
return nil, fmt.Errorf("unexpected error when creating new controller")
|
||||
}
|
||||
|
||||
if oldCtl := svr.ctlManager.Add(loginMsg.RunID, ctl); oldCtl != nil {
|
||||
@@ -630,34 +792,24 @@ func (svr *Service) RegisterControl(ctlConn net.Conn, loginMsg *msg.Login, inter
|
||||
if host, _, err := net.SplitHostPort(remoteAddr); err == nil {
|
||||
remoteAddr = host
|
||||
}
|
||||
_, conflict := svr.clientRegistry.Register(loginMsg.User, loginMsg.ClientID, loginMsg.RunID, loginMsg.Hostname, loginMsg.Version, remoteAddr)
|
||||
_, conflict := svr.clientRegistry.Register(loginMsg.User, loginMsg.ClientID, loginMsg.RunID, loginMsg.Hostname, loginMsg.Version, remoteAddr, wireProtocol)
|
||||
if conflict {
|
||||
svr.ctlManager.Del(loginMsg.RunID, ctl)
|
||||
ctl.Close()
|
||||
return fmt.Errorf("client_id [%s] for user [%s] is already online", loginMsg.ClientID, loginMsg.User)
|
||||
return nil, fmt.Errorf("client_id [%s] for user [%s] is already online", loginMsg.ClientID, loginMsg.User)
|
||||
}
|
||||
|
||||
ctl.Start()
|
||||
|
||||
// for statistics
|
||||
metrics.Server.NewClient()
|
||||
|
||||
go func() {
|
||||
// block until control closed
|
||||
ctl.WaitClosed()
|
||||
svr.ctlManager.Del(loginMsg.RunID, ctl)
|
||||
}()
|
||||
return nil
|
||||
return ctl, nil
|
||||
}
|
||||
|
||||
// RegisterWorkConn register a new work connection to control and proxies need it.
|
||||
func (svr *Service) RegisterWorkConn(workConn net.Conn, newMsg *msg.NewWorkConn) error {
|
||||
func (svr *Service) RegisterWorkConn(workConn *msg.Conn, newMsg *msg.NewWorkConn) error {
|
||||
xl := netpkg.NewLogFromConn(workConn)
|
||||
ctl, exist := svr.ctlManager.GetByID(newMsg.RunID)
|
||||
if !exist {
|
||||
xl.Warnf("no client control found for run id [%s]", newMsg.RunID)
|
||||
return fmt.Errorf("no client control found for run id [%s]", newMsg.RunID)
|
||||
}
|
||||
|
||||
// server plugin hook
|
||||
content := &plugin.NewWorkConnContent{
|
||||
User: plugin.UserInfo{
|
||||
@@ -675,12 +827,9 @@ func (svr *Service) RegisterWorkConn(workConn net.Conn, newMsg *msg.NewWorkConn)
|
||||
}
|
||||
if err != nil {
|
||||
xl.Warnf("invalid NewWorkConn with run id [%s]", newMsg.RunID)
|
||||
_ = msg.WriteMsg(workConn, &msg.StartWorkConn{
|
||||
Error: util.GenerateResponseErrorString("invalid NewWorkConn", err, lo.FromPtr(svr.cfg.DetailedErrorsToClient)),
|
||||
})
|
||||
return fmt.Errorf("invalid NewWorkConn with run id [%s]", newMsg.RunID)
|
||||
return err
|
||||
}
|
||||
return ctl.RegisterWorkConn(workConn)
|
||||
return ctl.RegisterWorkConn(proxy.NewWorkConn(workConn))
|
||||
}
|
||||
|
||||
func (svr *Service) RegisterVisitorConn(visitorConn net.Conn, newMsg *msg.NewVisitorConn) error {
|
||||
|
||||
63
server/service_test.go
Normal file
63
server/service_test.go
Normal file
@@ -0,0 +1,63 @@
|
||||
// Copyright 2026 The frp Authors
|
||||
//
|
||||
// Licensed under the Apache License, Version 2.0 (the "License");
|
||||
// you may not use this file except in compliance with the License.
|
||||
// You may obtain a copy of the License at
|
||||
//
|
||||
// http://www.apache.org/licenses/LICENSE-2.0
|
||||
//
|
||||
// Unless required by applicable law or agreed to in writing, software
|
||||
// distributed under the License is distributed on an "AS IS" BASIS,
|
||||
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
// See the License for the specific language governing permissions and
|
||||
// limitations under the License.
|
||||
|
||||
package server
|
||||
|
||||
import (
|
||||
"errors"
|
||||
"net"
|
||||
"testing"
|
||||
"time"
|
||||
|
||||
"github.com/stretchr/testify/require"
|
||||
)
|
||||
|
||||
func TestWriteWithDeadlineTimesOutAndClearsDeadline(t *testing.T) {
|
||||
serverConn, clientConn := net.Pipe()
|
||||
defer serverConn.Close()
|
||||
defer clientConn.Close()
|
||||
|
||||
err := writeWithDeadline(serverConn, 50*time.Millisecond, func() error {
|
||||
_, writeErr := serverConn.Write([]byte("x"))
|
||||
return writeErr
|
||||
})
|
||||
require.Error(t, err)
|
||||
|
||||
var netErr net.Error
|
||||
require.True(t, errors.As(err, &netErr))
|
||||
require.True(t, netErr.Timeout())
|
||||
|
||||
readCh := make(chan byte, 1)
|
||||
errCh := make(chan error, 1)
|
||||
go func() {
|
||||
buf := make([]byte, 1)
|
||||
if _, readErr := clientConn.Read(buf); readErr != nil {
|
||||
errCh <- readErr
|
||||
return
|
||||
}
|
||||
readCh <- buf[0]
|
||||
}()
|
||||
|
||||
_, err = serverConn.Write([]byte("y"))
|
||||
require.NoError(t, err)
|
||||
|
||||
select {
|
||||
case b := <-readCh:
|
||||
require.Equal(t, byte('y'), b)
|
||||
case err := <-errCh:
|
||||
require.NoError(t, err)
|
||||
case <-time.After(time.Second):
|
||||
t.Fatal("timed out waiting for write after deadline reset")
|
||||
}
|
||||
}
|
||||
236
test/e2e/compatibility/compatibility_test.go
Normal file
236
test/e2e/compatibility/compatibility_test.go
Normal file
@@ -0,0 +1,236 @@
|
||||
package compatibility
|
||||
|
||||
import (
|
||||
"encoding/json"
|
||||
"flag"
|
||||
"fmt"
|
||||
"io"
|
||||
"net/http"
|
||||
"os"
|
||||
"testing"
|
||||
"time"
|
||||
|
||||
"github.com/onsi/ginkgo/v2"
|
||||
"github.com/onsi/gomega"
|
||||
|
||||
"github.com/fatedier/frp/pkg/util/log"
|
||||
"github.com/fatedier/frp/test/e2e/framework"
|
||||
"github.com/fatedier/frp/test/e2e/framework/consts"
|
||||
"github.com/fatedier/frp/test/e2e/pkg/port"
|
||||
"github.com/fatedier/frp/test/e2e/pkg/process"
|
||||
)
|
||||
|
||||
type compatTestContext struct {
|
||||
CurrentFRPSPath string
|
||||
CurrentFRPCPath string
|
||||
BaselineFRPSPath string
|
||||
BaselineFRPCPath string
|
||||
BaselineVersion string
|
||||
LogLevel string
|
||||
Debug bool
|
||||
RunCurrentCurrent bool
|
||||
}
|
||||
|
||||
var compatCtx compatTestContext
|
||||
|
||||
func registerFlags(flags *flag.FlagSet) {
|
||||
flags.StringVar(&compatCtx.CurrentFRPSPath, "current-frps-path", "../../bin/frps", "The current frps binary to use.")
|
||||
flags.StringVar(&compatCtx.CurrentFRPCPath, "current-frpc-path", "../../bin/frpc", "The current frpc binary to use.")
|
||||
flags.StringVar(&compatCtx.BaselineFRPSPath, "baseline-frps-path", "", "The baseline frps binary to use.")
|
||||
flags.StringVar(&compatCtx.BaselineFRPCPath, "baseline-frpc-path", "", "The baseline frpc binary to use.")
|
||||
flags.StringVar(&compatCtx.BaselineVersion, "baseline-version", "custom", "The baseline version label for reporting.")
|
||||
flags.StringVar(&compatCtx.LogLevel, "log-level", "debug", "Log level.")
|
||||
flags.BoolVar(&compatCtx.Debug, "debug", false, "Enable debug mode to print detailed info.")
|
||||
flags.BoolVar(&compatCtx.RunCurrentCurrent, "run-current-current", true, "Run current frps/current frpc sanity checks.")
|
||||
}
|
||||
|
||||
func validateCompatContext(t *compatTestContext) error {
|
||||
paths := map[string]string{
|
||||
"current-frps-path": t.CurrentFRPSPath,
|
||||
"current-frpc-path": t.CurrentFRPCPath,
|
||||
"baseline-frps-path": t.BaselineFRPSPath,
|
||||
"baseline-frpc-path": t.BaselineFRPCPath,
|
||||
}
|
||||
for name, path := range paths {
|
||||
if path == "" {
|
||||
return fmt.Errorf("%s can't be empty", name)
|
||||
}
|
||||
if _, err := os.Stat(path); err != nil {
|
||||
return fmt.Errorf("load %s error: %v", name, err)
|
||||
}
|
||||
}
|
||||
return nil
|
||||
}
|
||||
|
||||
func TestMain(m *testing.M) {
|
||||
registerFlags(flag.CommandLine)
|
||||
flag.Parse()
|
||||
|
||||
if err := validateCompatContext(&compatCtx); err != nil {
|
||||
fmt.Println(err)
|
||||
os.Exit(1)
|
||||
}
|
||||
|
||||
framework.TestContext.Debug = compatCtx.Debug
|
||||
framework.TestContext.LogLevel = compatCtx.LogLevel
|
||||
log.InitLogger("console", compatCtx.LogLevel, 0, true)
|
||||
os.Exit(m.Run())
|
||||
}
|
||||
|
||||
func TestCompatibilityE2E(t *testing.T) {
|
||||
gomega.RegisterFailHandler(framework.Fail)
|
||||
|
||||
suiteConfig, reporterConfig := ginkgo.GinkgoConfiguration()
|
||||
suiteConfig.EmitSpecProgress = true
|
||||
suiteConfig.RandomizeAllSpecs = true
|
||||
ginkgo.RunSpecs(t, "frp compatibility e2e suite", suiteConfig, reporterConfig)
|
||||
}
|
||||
|
||||
var _ = ginkgo.Describe("[Compatibility: WireProtocol]", func() {
|
||||
f := framework.NewDefaultFramework()
|
||||
|
||||
ginkgo.It("current frps and current frpc support v1 and v2", func() {
|
||||
if !compatCtx.RunCurrentCurrent {
|
||||
ginkgo.Skip("current/current sanity checks already ran")
|
||||
}
|
||||
|
||||
webPort := f.AllocPort()
|
||||
serverConf := consts.DefaultServerConfig + fmt.Sprintf(`
|
||||
webServer.port = %d
|
||||
`, webPort)
|
||||
|
||||
v1PortName := port.GenName("CompatCurrentV1")
|
||||
v1ClientConf := tcpClientConfig("compat-current-v1", v1PortName, `
|
||||
clientID = "compat-current-v1"
|
||||
transport.wireProtocol = "v1"
|
||||
`)
|
||||
|
||||
v2PortName := port.GenName("CompatCurrentV2")
|
||||
v2ClientConf := tcpClientConfig("compat-current-v2", v2PortName, `
|
||||
clientID = "compat-current-v2"
|
||||
transport.wireProtocol = "v2"
|
||||
`)
|
||||
|
||||
f.RunProcessesWithBinaries(
|
||||
compatCtx.CurrentFRPSPath,
|
||||
compatCtx.CurrentFRPCPath,
|
||||
serverConf,
|
||||
[]string{v1ClientConf, v2ClientConf},
|
||||
)
|
||||
|
||||
framework.NewRequestExpect(f).PortName(v1PortName).Ensure()
|
||||
framework.NewRequestExpect(f).PortName(v2PortName).Ensure()
|
||||
expectClientWireProtocol(webPort, "compat-current-v1", "v1")
|
||||
expectClientWireProtocol(webPort, "compat-current-v2", "v2")
|
||||
})
|
||||
|
||||
ginkgo.It("current frps accepts baseline frpc using v1", func() {
|
||||
webPort := f.AllocPort()
|
||||
serverConf := consts.DefaultServerConfig + fmt.Sprintf(`
|
||||
webServer.port = %d
|
||||
`, webPort)
|
||||
|
||||
portName := port.GenName("CompatBaselineFRPC")
|
||||
clientConf := tcpClientConfig("tcp", portName, "")
|
||||
|
||||
f.RunProcessesWithBinaries(
|
||||
compatCtx.CurrentFRPSPath,
|
||||
compatCtx.BaselineFRPCPath,
|
||||
serverConf,
|
||||
[]string{clientConf},
|
||||
)
|
||||
|
||||
framework.NewRequestExpect(f).PortName(portName).Ensure()
|
||||
expectSingleClientWireProtocol(webPort, "v1")
|
||||
})
|
||||
|
||||
ginkgo.It("baseline frps accepts current frpc defaulting to v1", func() {
|
||||
portName := port.GenName("CompatBaselineFRPS")
|
||||
clientConf := tcpClientConfig("tcp", portName, "")
|
||||
|
||||
f.RunProcessesWithBinaries(
|
||||
compatCtx.BaselineFRPSPath,
|
||||
compatCtx.CurrentFRPCPath,
|
||||
consts.DefaultServerConfig,
|
||||
[]string{clientConf},
|
||||
)
|
||||
|
||||
framework.NewRequestExpect(f).PortName(portName).Ensure()
|
||||
})
|
||||
|
||||
ginkgo.It("baseline frps rejects current frpc forced to v2", func() {
|
||||
portName := port.GenName("CompatBaselineFRPSForcedV2")
|
||||
clientConf := tcpClientConfig("tcp", portName, `
|
||||
transport.wireProtocol = "v2"
|
||||
`)
|
||||
|
||||
_, clientProcesses := f.RunProcessesWithBinaries(
|
||||
compatCtx.BaselineFRPSPath,
|
||||
compatCtx.CurrentFRPCPath,
|
||||
consts.DefaultServerConfig,
|
||||
[]string{clientConf},
|
||||
)
|
||||
expectProcessExit(clientProcesses[0], 5*time.Second)
|
||||
framework.NewRequestExpect(f).PortName(portName).ExpectError(true).Ensure()
|
||||
})
|
||||
})
|
||||
|
||||
func tcpClientConfig(proxyName string, remotePortName string, extra string) string {
|
||||
return fmt.Sprintf(`
|
||||
serverAddr = "127.0.0.1"
|
||||
serverPort = {{ .%s }}
|
||||
loginFailExit = true
|
||||
log.level = "trace"
|
||||
%s
|
||||
|
||||
[[proxies]]
|
||||
name = "%s"
|
||||
type = "tcp"
|
||||
localPort = {{ .%s }}
|
||||
remotePort = {{ .%s }}
|
||||
`, consts.PortServerName, extra, proxyName, framework.TCPEchoServerPort, remotePortName)
|
||||
}
|
||||
|
||||
func expectProcessExit(p *process.Process, timeout time.Duration) {
|
||||
select {
|
||||
case <-p.Done():
|
||||
case <-time.After(timeout):
|
||||
framework.Failf("process did not exit within %s; output:\n%s", timeout, p.Output())
|
||||
}
|
||||
}
|
||||
|
||||
type wireClientInfo struct {
|
||||
ClientID string `json:"clientID"`
|
||||
WireProtocol string `json:"wireProtocol"`
|
||||
}
|
||||
|
||||
func expectSingleClientWireProtocol(webPort int, wireProtocol string) {
|
||||
clients := getWireClientInfos(webPort)
|
||||
framework.ExpectEqual(len(clients), 1)
|
||||
framework.ExpectEqual(clients[0].WireProtocol, wireProtocol)
|
||||
}
|
||||
|
||||
func expectClientWireProtocol(webPort int, clientID string, wireProtocol string) {
|
||||
for _, client := range getWireClientInfos(webPort) {
|
||||
if client.ClientID == clientID {
|
||||
framework.ExpectEqual(client.WireProtocol, wireProtocol)
|
||||
return
|
||||
}
|
||||
}
|
||||
framework.Failf("client %q not found in /api/clients response", clientID)
|
||||
}
|
||||
|
||||
func getWireClientInfos(webPort int) []wireClientInfo {
|
||||
client := http.Client{Timeout: consts.DefaultTimeout}
|
||||
resp, err := client.Get(fmt.Sprintf("http://127.0.0.1:%d/api/clients", webPort))
|
||||
framework.ExpectNoError(err)
|
||||
defer resp.Body.Close()
|
||||
framework.ExpectEqual(resp.StatusCode, http.StatusOK)
|
||||
|
||||
content, err := io.ReadAll(resp.Body)
|
||||
framework.ExpectNoError(err)
|
||||
|
||||
var clients []wireClientInfo
|
||||
framework.ExpectNoError(json.Unmarshal(content, &clients))
|
||||
return clients
|
||||
}
|
||||
@@ -1,62 +0,0 @@
|
||||
package framework
|
||||
|
||||
import (
|
||||
"sync"
|
||||
)
|
||||
|
||||
// CleanupActionHandle is an integer pointer type for handling cleanup action
|
||||
type CleanupActionHandle *int
|
||||
|
||||
type cleanupFuncHandle struct {
|
||||
actionHandle CleanupActionHandle
|
||||
actionHook func()
|
||||
}
|
||||
|
||||
var (
|
||||
cleanupActionsLock sync.Mutex
|
||||
cleanupHookList = []cleanupFuncHandle{}
|
||||
)
|
||||
|
||||
// AddCleanupAction installs a function that will be called in the event of the
|
||||
// whole test being terminated. This allows arbitrary pieces of the overall
|
||||
// test to hook into SynchronizedAfterSuite().
|
||||
// The hooks are called in last-in-first-out order.
|
||||
func AddCleanupAction(fn func()) CleanupActionHandle {
|
||||
p := CleanupActionHandle(new(int))
|
||||
cleanupActionsLock.Lock()
|
||||
defer cleanupActionsLock.Unlock()
|
||||
c := cleanupFuncHandle{actionHandle: p, actionHook: fn}
|
||||
cleanupHookList = append([]cleanupFuncHandle{c}, cleanupHookList...)
|
||||
return p
|
||||
}
|
||||
|
||||
// RemoveCleanupAction removes a function that was installed by
|
||||
// AddCleanupAction.
|
||||
func RemoveCleanupAction(p CleanupActionHandle) {
|
||||
cleanupActionsLock.Lock()
|
||||
defer cleanupActionsLock.Unlock()
|
||||
for i, item := range cleanupHookList {
|
||||
if item.actionHandle == p {
|
||||
cleanupHookList = append(cleanupHookList[:i], cleanupHookList[i+1:]...)
|
||||
break
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
// RunCleanupActions runs all functions installed by AddCleanupAction. It does
|
||||
// not remove them (see RemoveCleanupAction) but it does run unlocked, so they
|
||||
// may remove themselves.
|
||||
func RunCleanupActions() {
|
||||
list := []func(){}
|
||||
func() {
|
||||
cleanupActionsLock.Lock()
|
||||
defer cleanupActionsLock.Unlock()
|
||||
for _, p := range cleanupHookList {
|
||||
list = append(list, p.actionHook)
|
||||
}
|
||||
}()
|
||||
// Run unlocked.
|
||||
for _, fn := range list {
|
||||
fn()
|
||||
}
|
||||
}
|
||||
@@ -23,11 +23,6 @@ func ExpectNotEqual(actual any, extra any, explain ...any) {
|
||||
gomega.ExpectWithOffset(1, actual).NotTo(gomega.Equal(extra), explain...)
|
||||
}
|
||||
|
||||
// ExpectError expects an error happens, otherwise an exception raises
|
||||
func ExpectError(err error, explain ...any) {
|
||||
gomega.ExpectWithOffset(1, err).To(gomega.HaveOccurred(), explain...)
|
||||
}
|
||||
|
||||
func ExpectErrorWithOffset(offset int, err error, explain ...any) {
|
||||
gomega.ExpectWithOffset(1+offset, err).To(gomega.HaveOccurred(), explain...)
|
||||
}
|
||||
@@ -47,11 +42,6 @@ func ExpectContainSubstring(actual, substr string, explain ...any) {
|
||||
gomega.ExpectWithOffset(1, actual).To(gomega.ContainSubstring(substr), explain...)
|
||||
}
|
||||
|
||||
// ExpectConsistOf expects actual contains precisely the extra elements. The ordering of the elements does not matter.
|
||||
func ExpectConsistOf(actual any, extra any, explain ...any) {
|
||||
gomega.ExpectWithOffset(1, actual).To(gomega.ConsistOf(extra), explain...)
|
||||
}
|
||||
|
||||
func ExpectContainElements(actual any, extra any, explain ...any) {
|
||||
gomega.ExpectWithOffset(1, actual).To(gomega.ContainElements(extra), explain...)
|
||||
}
|
||||
@@ -60,16 +50,6 @@ func ExpectNotContainElements(actual any, extra any, explain ...any) {
|
||||
gomega.ExpectWithOffset(1, actual).NotTo(gomega.ContainElements(extra), explain...)
|
||||
}
|
||||
|
||||
// ExpectHaveKey expects the actual map has the key in the keyset
|
||||
func ExpectHaveKey(actual any, key any, explain ...any) {
|
||||
gomega.ExpectWithOffset(1, actual).To(gomega.HaveKey(key), explain...)
|
||||
}
|
||||
|
||||
// ExpectEmpty expects actual is empty
|
||||
func ExpectEmpty(actual any, explain ...any) {
|
||||
gomega.ExpectWithOffset(1, actual).To(gomega.BeEmpty(), explain...)
|
||||
}
|
||||
|
||||
func ExpectTrue(actual any, explain ...any) {
|
||||
gomega.ExpectWithOffset(1, actual).Should(gomega.BeTrue(), explain...)
|
||||
}
|
||||
|
||||
@@ -38,11 +38,6 @@ type Framework struct {
|
||||
// Multiple default mock servers used for e2e testing.
|
||||
mockServers *MockServers
|
||||
|
||||
// To make sure that this framework cleans up after itself, no matter what,
|
||||
// we install a Cleanup action before each test and clear it after. If we
|
||||
// should abort, the AfterSuite hook should run all Cleanup actions.
|
||||
cleanupHandle CleanupActionHandle
|
||||
|
||||
// beforeEachStarted indicates that BeforeEach has started
|
||||
beforeEachStarted bool
|
||||
|
||||
@@ -87,8 +82,6 @@ func NewFramework(opt Options) *Framework {
|
||||
func (f *Framework) BeforeEach() {
|
||||
f.beforeEachStarted = true
|
||||
|
||||
f.cleanupHandle = AddCleanupAction(f.AfterEach)
|
||||
|
||||
dir, err := os.MkdirTemp(os.TempDir(), "frp-e2e-test-*")
|
||||
ExpectNoError(err)
|
||||
f.TempDirectory = dir
|
||||
@@ -113,8 +106,6 @@ func (f *Framework) AfterEach() {
|
||||
return
|
||||
}
|
||||
|
||||
RemoveCleanupAction(f.cleanupHandle)
|
||||
|
||||
// stop processor
|
||||
for _, p := range f.serverProcesses {
|
||||
_ = p.Stop()
|
||||
@@ -266,10 +257,6 @@ func (f *Framework) AllocPortExcludingRanges(ranges ...[2]int) int {
|
||||
return 0
|
||||
}
|
||||
|
||||
func (f *Framework) ReleasePort(port int) {
|
||||
f.portAllocator.Release(port)
|
||||
}
|
||||
|
||||
func (f *Framework) RunServer(portName string, s server.Server) {
|
||||
f.servers = append(f.servers, s)
|
||||
if s.BindPort() > 0 && portName != "" {
|
||||
|
||||
@@ -75,11 +75,3 @@ func (m *MockServers) GetTemplateParams() map[string]any {
|
||||
ret[HTTPSimpleServerPort] = m.httpSimpleServer.BindPort()
|
||||
return ret
|
||||
}
|
||||
|
||||
func (m *MockServers) GetParam(key string) any {
|
||||
params := m.GetTemplateParams()
|
||||
if v, ok := params[key]; ok {
|
||||
return v
|
||||
}
|
||||
return nil
|
||||
}
|
||||
|
||||
@@ -17,6 +17,16 @@ import (
|
||||
|
||||
// RunProcesses starts one frps and zero or more frpc processes from templates.
|
||||
func (f *Framework) RunProcesses(serverTemplate string, clientTemplates []string) (*process.Process, []*process.Process) {
|
||||
return f.RunProcessesWithBinaries(TestContext.FRPServerPath, TestContext.FRPClientPath, serverTemplate, clientTemplates)
|
||||
}
|
||||
|
||||
// RunProcessesWithBinaries starts one frps and zero or more frpc processes with explicit binary paths.
|
||||
func (f *Framework) RunProcessesWithBinaries(
|
||||
serverBinaryPath string,
|
||||
clientBinaryPath string,
|
||||
serverTemplate string,
|
||||
clientTemplates []string,
|
||||
) (*process.Process, []*process.Process) {
|
||||
templates := append([]string{serverTemplate}, clientTemplates...)
|
||||
outs, ports, err := f.RenderTemplates(templates)
|
||||
ExpectNoError(err)
|
||||
@@ -32,7 +42,7 @@ func (f *Framework) RunProcesses(serverTemplate string, clientTemplates []string
|
||||
flog.Debugf("[%s] %s", serverPath, outs[0])
|
||||
}
|
||||
|
||||
serverProcess := process.NewWithEnvs(TestContext.FRPServerPath, []string{"-c", serverPath}, f.osEnvs)
|
||||
serverProcess := process.NewWithEnvs(serverBinaryPath, []string{"-c", serverPath}, f.osEnvs)
|
||||
f.serverConfPaths = append(f.serverConfPaths, serverPath)
|
||||
f.serverProcesses = append(f.serverProcesses, serverProcess)
|
||||
err = serverProcess.Start()
|
||||
@@ -55,7 +65,7 @@ func (f *Framework) RunProcesses(serverTemplate string, clientTemplates []string
|
||||
flog.Debugf("[%s] %s", path, outs[1+i])
|
||||
}
|
||||
|
||||
p := process.NewWithEnvs(TestContext.FRPClientPath, []string{"-c", path}, f.osEnvs)
|
||||
p := process.NewWithEnvs(clientBinaryPath, []string{"-c", path}, f.osEnvs)
|
||||
f.clientConfPaths = append(f.clientConfPaths, path)
|
||||
f.clientProcesses = append(f.clientProcesses, p)
|
||||
clientProcesses = append(clientProcesses, p)
|
||||
|
||||
@@ -124,7 +124,3 @@ func (e *RequestExpect) Ensure(fns ...EnsureFunc) {
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
func (e *RequestExpect) Do() (*request.Response, error) {
|
||||
return e.req.Do()
|
||||
}
|
||||
|
||||
@@ -31,13 +31,6 @@ func New(options ...Option) *Server {
|
||||
return s
|
||||
}
|
||||
|
||||
func WithBindAddr(addr string) Option {
|
||||
return func(s *Server) *Server {
|
||||
s.bindAddr = addr
|
||||
return s
|
||||
}
|
||||
}
|
||||
|
||||
func WithBindPort(port int) Option {
|
||||
return func(s *Server) *Server {
|
||||
s.bindPort = port
|
||||
|
||||
@@ -53,27 +53,14 @@ type Server struct {
|
||||
tokenRequestCount atomic.Int64
|
||||
}
|
||||
|
||||
const maxTokenRequestBodySize = 1 << 20
|
||||
|
||||
type Option func(*Server)
|
||||
|
||||
func WithBindPort(port int) Option {
|
||||
return func(s *Server) { s.bindPort = port }
|
||||
}
|
||||
|
||||
func WithClientCredentials(id, secret string) Option {
|
||||
return func(s *Server) {
|
||||
s.clientID = id
|
||||
s.clientSecret = secret
|
||||
}
|
||||
}
|
||||
|
||||
func WithAudience(aud string) Option {
|
||||
return func(s *Server) { s.audience = aud }
|
||||
}
|
||||
|
||||
func WithSubject(sub string) Option {
|
||||
return func(s *Server) { s.subject = sub }
|
||||
}
|
||||
|
||||
func WithExpiresIn(seconds int) Option {
|
||||
return func(s *Server) { s.expiresIn = seconds }
|
||||
}
|
||||
@@ -178,6 +165,7 @@ func (s *Server) handleToken(w http.ResponseWriter, r *http.Request) {
|
||||
http.Error(w, "method not allowed", http.StatusMethodNotAllowed)
|
||||
return
|
||||
}
|
||||
r.Body = http.MaxBytesReader(w, r.Body, maxTokenRequestBodySize)
|
||||
if err := r.ParseForm(); err != nil {
|
||||
w.Header().Set("Content-Type", "application/json")
|
||||
w.WriteHeader(http.StatusBadRequest)
|
||||
@@ -187,7 +175,7 @@ func (s *Server) handleToken(w http.ResponseWriter, r *http.Request) {
|
||||
return
|
||||
}
|
||||
|
||||
if r.FormValue("grant_type") != "client_credentials" {
|
||||
if r.Form.Get("grant_type") != "client_credentials" {
|
||||
w.Header().Set("Content-Type", "application/json")
|
||||
w.WriteHeader(http.StatusBadRequest)
|
||||
_ = json.NewEncoder(w).Encode(map[string]any{
|
||||
@@ -199,8 +187,8 @@ func (s *Server) handleToken(w http.ResponseWriter, r *http.Request) {
|
||||
// Accept credentials from Basic Auth or form body.
|
||||
clientID, clientSecret, ok := r.BasicAuth()
|
||||
if !ok {
|
||||
clientID = r.FormValue("client_id")
|
||||
clientSecret = r.FormValue("client_secret")
|
||||
clientID = r.Form.Get("client_id")
|
||||
clientSecret = r.Form.Get("client_secret")
|
||||
}
|
||||
if clientID != s.clientID || clientSecret != s.clientSecret {
|
||||
w.Header().Set("Content-Type", "application/json")
|
||||
|
||||
@@ -40,13 +40,8 @@ type Process struct {
|
||||
closeOne sync.Once
|
||||
waitErr error
|
||||
|
||||
started bool
|
||||
beforeStopHandler func()
|
||||
stopped bool
|
||||
}
|
||||
|
||||
func New(path string, params []string) *Process {
|
||||
return NewWithEnvs(path, params, nil)
|
||||
started bool
|
||||
stopped bool
|
||||
}
|
||||
|
||||
func NewWithEnvs(path string, params []string, envs []string) *Process {
|
||||
@@ -100,9 +95,6 @@ func (p *Process) Stop() error {
|
||||
defer func() {
|
||||
p.stopped = true
|
||||
}()
|
||||
if p.beforeStopHandler != nil {
|
||||
p.beforeStopHandler()
|
||||
}
|
||||
p.cancel()
|
||||
<-p.done
|
||||
return p.waitErr
|
||||
@@ -125,10 +117,6 @@ func (p *Process) CountOutput(pattern string) int {
|
||||
return strings.Count(p.Output(), pattern)
|
||||
}
|
||||
|
||||
func (p *Process) SetBeforeStopHandler(fn func()) {
|
||||
p.beforeStopHandler = fn
|
||||
}
|
||||
|
||||
// WaitForOutput polls the combined process output until the pattern is found
|
||||
// count time(s) or the timeout is reached. It also returns early if the process exits.
|
||||
func (p *Process) WaitForOutput(pattern string, count int, timeout time.Duration) error {
|
||||
|
||||
@@ -22,11 +22,10 @@ type Request struct {
|
||||
protocol string
|
||||
|
||||
// for all protocol
|
||||
addr string
|
||||
port int
|
||||
body []byte
|
||||
timeout time.Duration
|
||||
resolver *net.Resolver
|
||||
addr string
|
||||
port int
|
||||
body []byte
|
||||
timeout time.Duration
|
||||
|
||||
// for http or https
|
||||
method string
|
||||
@@ -134,11 +133,6 @@ func (r *Request) Body(content []byte) *Request {
|
||||
return r
|
||||
}
|
||||
|
||||
func (r *Request) Resolver(resolver *net.Resolver) *Request {
|
||||
r.resolver = resolver
|
||||
return r
|
||||
}
|
||||
|
||||
func (r *Request) Do() (*Response, error) {
|
||||
var (
|
||||
conn net.Conn
|
||||
@@ -169,7 +163,7 @@ func (r *Request) Do() (*Response, error) {
|
||||
return nil, err
|
||||
}
|
||||
} else {
|
||||
dialer := &net.Dialer{Resolver: r.resolver}
|
||||
dialer := &net.Dialer{}
|
||||
switch r.protocol {
|
||||
case "tcp":
|
||||
conn, err = dialer.Dial("tcp", addr)
|
||||
@@ -225,7 +219,6 @@ func (r *Request) sendHTTPRequest(method, urlstr string, host string, headers ma
|
||||
Timeout: time.Second,
|
||||
KeepAlive: 30 * time.Second,
|
||||
DualStack: true,
|
||||
Resolver: r.resolver,
|
||||
}).DialContext,
|
||||
MaxIdleConns: 100,
|
||||
IdleConnTimeout: 90 * time.Second,
|
||||
|
||||
163
test/e2e/v1/basic/wire.go
Normal file
163
test/e2e/v1/basic/wire.go
Normal file
@@ -0,0 +1,163 @@
|
||||
package basic
|
||||
|
||||
import (
|
||||
"encoding/json"
|
||||
"fmt"
|
||||
"io"
|
||||
"net/http"
|
||||
|
||||
"github.com/onsi/ginkgo/v2"
|
||||
|
||||
"github.com/fatedier/frp/test/e2e/framework"
|
||||
"github.com/fatedier/frp/test/e2e/framework/consts"
|
||||
"github.com/fatedier/frp/test/e2e/pkg/port"
|
||||
)
|
||||
|
||||
var _ = ginkgo.Describe("[Feature: WireProtocol]", func() {
|
||||
f := framework.NewDefaultFramework()
|
||||
|
||||
ginkgo.It("v1 tcp and udp proxies", func() {
|
||||
serverConf := consts.DefaultServerConfig
|
||||
tcpPortName := port.GenName("WireV1TCP")
|
||||
udpPortName := port.GenName("WireV1UDP")
|
||||
clientConf := consts.DefaultClientConfig + fmt.Sprintf(`
|
||||
transport.wireProtocol = "v1"
|
||||
|
||||
[[proxies]]
|
||||
name = "tcp"
|
||||
type = "tcp"
|
||||
localPort = {{ .%s }}
|
||||
remotePort = {{ .%s }}
|
||||
|
||||
[[proxies]]
|
||||
name = "udp"
|
||||
type = "udp"
|
||||
localPort = {{ .%s }}
|
||||
remotePort = {{ .%s }}
|
||||
`, framework.TCPEchoServerPort, tcpPortName, framework.UDPEchoServerPort, udpPortName)
|
||||
|
||||
f.RunProcesses(serverConf, []string{clientConf})
|
||||
|
||||
framework.NewRequestExpect(f).PortName(tcpPortName).Ensure()
|
||||
framework.NewRequestExpect(f).Protocol("udp").PortName(udpPortName).Ensure()
|
||||
})
|
||||
|
||||
ginkgo.It("v2 tcp and udp proxies", func() {
|
||||
serverConf := consts.DefaultServerConfig
|
||||
tcpPortName := port.GenName("WireV2TCP")
|
||||
udpPortName := port.GenName("WireV2UDP")
|
||||
clientConf := consts.DefaultClientConfig + fmt.Sprintf(`
|
||||
transport.wireProtocol = "v2"
|
||||
|
||||
[[proxies]]
|
||||
name = "tcp"
|
||||
type = "tcp"
|
||||
localPort = {{ .%s }}
|
||||
remotePort = {{ .%s }}
|
||||
|
||||
[[proxies]]
|
||||
name = "udp"
|
||||
type = "udp"
|
||||
localPort = {{ .%s }}
|
||||
remotePort = {{ .%s }}
|
||||
`, framework.TCPEchoServerPort, tcpPortName, framework.UDPEchoServerPort, udpPortName)
|
||||
|
||||
f.RunProcesses(serverConf, []string{clientConf})
|
||||
|
||||
framework.NewRequestExpect(f).PortName(tcpPortName).Ensure()
|
||||
framework.NewRequestExpect(f).Protocol("udp").PortName(udpPortName).Ensure()
|
||||
})
|
||||
|
||||
ginkgo.It("v2 stcp visitor", func() {
|
||||
serverConf := consts.DefaultServerConfig
|
||||
bindPortName := port.GenName("WireV2STCP")
|
||||
clientServerConf := consts.DefaultClientConfig + fmt.Sprintf(`
|
||||
user = "user1"
|
||||
transport.wireProtocol = "v2"
|
||||
|
||||
[[proxies]]
|
||||
name = "stcp"
|
||||
type = "stcp"
|
||||
secretKey = "abc"
|
||||
localPort = {{ .%s }}
|
||||
`, framework.TCPEchoServerPort)
|
||||
clientVisitorConf := consts.DefaultClientConfig + fmt.Sprintf(`
|
||||
user = "user1"
|
||||
transport.wireProtocol = "v2"
|
||||
|
||||
[[visitors]]
|
||||
name = "stcp-visitor"
|
||||
type = "stcp"
|
||||
serverName = "stcp"
|
||||
secretKey = "abc"
|
||||
bindPort = {{ .%s }}
|
||||
`, bindPortName)
|
||||
|
||||
f.RunProcesses(serverConf, []string{clientServerConf, clientVisitorConf})
|
||||
|
||||
framework.NewRequestExpect(f).PortName(bindPortName).Ensure()
|
||||
})
|
||||
|
||||
ginkgo.It("reports client wire protocol", func() {
|
||||
webPort := f.AllocPort()
|
||||
serverConf := consts.DefaultServerConfig + fmt.Sprintf(`
|
||||
webServer.port = %d
|
||||
`, webPort)
|
||||
|
||||
v1PortName := port.GenName("WireReportV1")
|
||||
v1ClientConf := consts.DefaultClientConfig + fmt.Sprintf(`
|
||||
clientID = "wire-v1"
|
||||
transport.wireProtocol = "v1"
|
||||
|
||||
[[proxies]]
|
||||
name = "v1"
|
||||
type = "tcp"
|
||||
localPort = {{ .%s }}
|
||||
remotePort = {{ .%s }}
|
||||
`, framework.TCPEchoServerPort, v1PortName)
|
||||
|
||||
v2PortName := port.GenName("WireReportV2")
|
||||
v2ClientConf := consts.DefaultClientConfig + fmt.Sprintf(`
|
||||
clientID = "wire-v2"
|
||||
transport.wireProtocol = "v2"
|
||||
|
||||
[[proxies]]
|
||||
name = "v2"
|
||||
type = "tcp"
|
||||
localPort = {{ .%s }}
|
||||
remotePort = {{ .%s }}
|
||||
`, framework.TCPEchoServerPort, v2PortName)
|
||||
|
||||
f.RunProcesses(serverConf, []string{v1ClientConf, v2ClientConf})
|
||||
|
||||
framework.NewRequestExpect(f).PortName(v1PortName).Ensure()
|
||||
framework.NewRequestExpect(f).PortName(v2PortName).Ensure()
|
||||
expectClientWireProtocol(webPort, "wire-v1", "v1")
|
||||
expectClientWireProtocol(webPort, "wire-v2", "v2")
|
||||
})
|
||||
})
|
||||
|
||||
type wireClientInfo struct {
|
||||
ClientID string `json:"clientID"`
|
||||
WireProtocol string `json:"wireProtocol"`
|
||||
}
|
||||
|
||||
func expectClientWireProtocol(webPort int, clientID string, wireProtocol string) {
|
||||
resp, err := http.Get(fmt.Sprintf("http://127.0.0.1:%d/api/clients", webPort))
|
||||
framework.ExpectNoError(err)
|
||||
defer resp.Body.Close()
|
||||
framework.ExpectEqual(resp.StatusCode, 200)
|
||||
|
||||
content, err := io.ReadAll(resp.Body)
|
||||
framework.ExpectNoError(err)
|
||||
|
||||
var clients []wireClientInfo
|
||||
framework.ExpectNoError(json.Unmarshal(content, &clients))
|
||||
for _, client := range clients {
|
||||
if client.ClientID == clientID {
|
||||
framework.ExpectEqual(client.WireProtocol, wireProtocol)
|
||||
return
|
||||
}
|
||||
}
|
||||
framework.Failf("client %q not found in /api/clients response: %s", clientID, string(content))
|
||||
}
|
||||
@@ -16,6 +16,9 @@
|
||||
<el-tag v-if="client.version" size="small" type="success"
|
||||
>v{{ client.version }}</el-tag
|
||||
>
|
||||
<el-tag v-if="client.wireProtocolLabel" size="small" type="info">
|
||||
{{ client.wireProtocolLabel }}
|
||||
</el-tag>
|
||||
</div>
|
||||
|
||||
<div class="card-meta">
|
||||
|
||||
@@ -4,6 +4,7 @@ export interface ClientInfoData {
|
||||
clientID: string
|
||||
runID: string
|
||||
version?: string
|
||||
wireProtocol?: string
|
||||
hostname: string
|
||||
clientIP?: string
|
||||
metas?: Record<string, string>
|
||||
|
||||
@@ -7,6 +7,7 @@ export class Client {
|
||||
clientID: string
|
||||
runID: string
|
||||
version: string
|
||||
wireProtocol: string
|
||||
hostname: string
|
||||
ip: string
|
||||
metas: Map<string, string>
|
||||
@@ -21,6 +22,7 @@ export class Client {
|
||||
this.clientID = data.clientID
|
||||
this.runID = data.runID
|
||||
this.version = data.version || ''
|
||||
this.wireProtocol = data.wireProtocol || ''
|
||||
this.hostname = data.hostname
|
||||
this.ip = data.clientIP || ''
|
||||
this.metas = new Map<string, string>()
|
||||
@@ -48,6 +50,11 @@ export class Client {
|
||||
return this.runID.substring(0, 8)
|
||||
}
|
||||
|
||||
get wireProtocolLabel(): string {
|
||||
if (!this.wireProtocol) return ''
|
||||
return `Protocol ${this.wireProtocol}`
|
||||
}
|
||||
|
||||
get firstConnectedAgo(): string {
|
||||
return formatDistanceToNow(this.firstConnectedAt)
|
||||
}
|
||||
@@ -80,6 +87,7 @@ export class Client {
|
||||
this.user.toLowerCase().includes(search) ||
|
||||
this.clientID.toLowerCase().includes(search) ||
|
||||
this.runID.toLowerCase().includes(search) ||
|
||||
this.wireProtocol.toLowerCase().includes(search) ||
|
||||
this.hostname.toLowerCase().includes(search)
|
||||
)
|
||||
}
|
||||
|
||||
Some files were not shown because too many files have changed in this diff Show More
Reference in New Issue
Block a user