From ca6d9a1b32984d27e32728a0781d1ef592cda0ba Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?=E4=B8=96=E7=95=8C?= Date: Sun, 5 Apr 2026 02:17:12 +0800 Subject: [PATCH 1/7] Refactor memory --- common/memory/memory.go | 21 ++++--- common/memory/memory_available_darwin.go | 36 ----------- common/memory/memory_available_stub.go | 11 ---- common/memory/memory_darwin.go | 57 ++++++++++++++--- ...ory_available_linux.go => memory_linux.go} | 50 ++++++++++++++- common/memory/memory_stub.go | 16 ++++- common/memory/memory_windows.go | 35 +++++++++++ common/memory/syscall_windows.go | 34 +++++++++++ common/memory/zsyscall_windows.go | 61 +++++++++++++++++++ 9 files changed, 251 insertions(+), 70 deletions(-) delete mode 100644 common/memory/memory_available_darwin.go delete mode 100644 common/memory/memory_available_stub.go rename common/memory/{memory_available_linux.go => memory_linux.go} (66%) create mode 100644 common/memory/memory_windows.go create mode 100644 common/memory/syscall_windows.go create mode 100644 common/memory/zsyscall_windows.go diff --git a/common/memory/memory.go b/common/memory/memory.go index 6fd5a25ec..54ca2492d 100644 --- a/common/memory/memory.go +++ b/common/memory/memory.go @@ -3,22 +3,23 @@ package memory import "runtime" func Total() uint64 { - if nativeAvailable { - return usageNative() - } - return Inuse() + return totalNative() } -func Inuse() uint64 { - var memStats runtime.MemStats - runtime.ReadMemStats(&memStats) - return memStats.StackInuse + memStats.HeapInuse + memStats.HeapIdle - memStats.HeapReleased +func TotalAvailable() bool { + return totalAvailable() } func Available() uint64 { return availableNative() } -func AvailableSupported() bool { - return availableNativeSupported() +func AvailableAvailable() bool { + return availableAvailable() +} + +func Inuse() uint64 { + var memStats runtime.MemStats + runtime.ReadMemStats(&memStats) + return memStats.StackInuse + memStats.HeapInuse + memStats.HeapIdle - memStats.HeapReleased } diff --git a/common/memory/memory_available_darwin.go b/common/memory/memory_available_darwin.go deleted file mode 100644 index 0e42c2f19..000000000 --- a/common/memory/memory_available_darwin.go +++ /dev/null @@ -1,36 +0,0 @@ -package memory - -// #include -// #include -// -// static size_t get_available_memory(int *supported) { -// typedef size_t (*proc_available_memory_func)(void); -// static int resolved = 0; -// static proc_available_memory_func fn = NULL; -// if (!resolved) { -// fn = (proc_available_memory_func)dlsym(RTLD_DEFAULT, "os_proc_available_memory"); -// resolved = 1; -// } -// if (fn) { -// *supported = 1; -// return fn(); -// } -// *supported = 0; -// return 0; -// } -import "C" - -func availableNative() uint64 { - var supported C.int - result := C.get_available_memory(&supported) - if supported == 0 { - return 0 - } - return uint64(result) -} - -func availableNativeSupported() bool { - var supported C.int - C.get_available_memory(&supported) - return supported != 0 -} diff --git a/common/memory/memory_available_stub.go b/common/memory/memory_available_stub.go deleted file mode 100644 index 27ff4a1c8..000000000 --- a/common/memory/memory_available_stub.go +++ /dev/null @@ -1,11 +0,0 @@ -//go:build !(darwin && cgo) && !linux - -package memory - -func availableNativeSupported() bool { - return false -} - -func availableNative() uint64 { - return 0 -} diff --git a/common/memory/memory_darwin.go b/common/memory/memory_darwin.go index 4fe2e50a3..86b0be2aa 100644 --- a/common/memory/memory_darwin.go +++ b/common/memory/memory_darwin.go @@ -1,18 +1,59 @@ package memory // #include +// #include +// #include +// +// typedef size_t (*proc_available_memory_func)(void); +// static int resolved = 0; +// static proc_available_memory_func fn = NULL; +// +// static void resolve_available_memory() { +// if (!resolved) { +// fn = (proc_available_memory_func)dlsym(RTLD_DEFAULT, "os_proc_available_memory"); +// resolved = 1; +// } +// } +// +// static size_t get_available_memory(int *supported) { +// resolve_available_memory(); +// if (fn) { +// *supported = 1; +// return fn(); +// } +// *supported = 0; +// return 0; +// } +// +// static int is_available_memory_supported() { +// resolve_available_memory(); +// return fn != NULL; +// } import "C" import "unsafe" -const nativeAvailable = true - -func usageNative() uint64 { - var memoryUsageInByte uint64 +func totalNative() uint64 { var vmInfo C.task_vm_info_data_t var count C.mach_msg_type_number_t = C.TASK_VM_INFO_COUNT - var kernelReturn C.kern_return_t = C.task_info(C.vm_map_t(C.mach_task_self_), C.TASK_VM_INFO, (*C.integer_t)(unsafe.Pointer(&vmInfo)), &count) - if kernelReturn == C.KERN_SUCCESS { - memoryUsageInByte = uint64(vmInfo.phys_footprint) + if C.task_info(C.vm_map_t(C.mach_task_self_), C.TASK_VM_INFO, (*C.integer_t)(unsafe.Pointer(&vmInfo)), &count) == C.KERN_SUCCESS { + return uint64(vmInfo.phys_footprint) + } + return 0 +} + +func totalAvailable() bool { + return true +} + +func availableNative() uint64 { + var supported C.int + result := C.get_available_memory(&supported) + if supported == 0 { + return 0 } - return memoryUsageInByte + return uint64(result) +} + +func availableAvailable() bool { + return C.is_available_memory_supported() != 0 } diff --git a/common/memory/memory_available_linux.go b/common/memory/memory_linux.go similarity index 66% rename from common/memory/memory_available_linux.go rename to common/memory/memory_linux.go index d4dd239aa..64e65861c 100644 --- a/common/memory/memory_available_linux.go +++ b/common/memory/memory_linux.go @@ -8,8 +8,41 @@ import ( "strings" ) -func availableNativeSupported() bool { - return true +var pageSize = uint64(os.Getpagesize()) + +func totalNative() uint64 { + fd, err := os.Open("/proc/self/statm") + if err != nil { + return 0 + } + defer fd.Close() + var buf [128]byte + n, _ := fd.Read(buf[:]) + if n == 0 { + return 0 + } + i := 0 + for i < n && buf[i] != ' ' { + i++ + } + i++ + var rss uint64 + for i < n && buf[i] >= '0' && buf[i] <= '9' { + rss = rss*10 + uint64(buf[i]-'0') + i++ + } + return rss * pageSize +} + +func totalAvailable() bool { + fd, err := os.Open("/proc/self/statm") + if err != nil { + return false + } + defer fd.Close() + var buf [1]byte + n, _ := fd.Read(buf[:]) + return n > 0 } func availableNative() uint64 { @@ -20,6 +53,19 @@ func availableNative() uint64 { return procMemAvailable() } +func availableAvailable() bool { + _, ok := cgroupAvailable() + if ok { + return true + } + fd, err := os.Open("/proc/meminfo") + if err != nil { + return false + } + fd.Close() + return true +} + func cgroupAvailable() (uint64, bool) { max, err := readCgroupUint("/sys/fs/cgroup/memory.max") if err == nil && max != math.MaxUint64 { diff --git a/common/memory/memory_stub.go b/common/memory/memory_stub.go index 3781f587c..24a4b0509 100644 --- a/common/memory/memory_stub.go +++ b/common/memory/memory_stub.go @@ -1,9 +1,19 @@ -//go:build (darwin && !cgo) || !darwin +//go:build (darwin && !cgo) || (!darwin && !linux && !windows) package memory -const nativeAvailable = false +func totalNative() uint64 { + return 0 +} + +func totalAvailable() bool { + return false +} -func usageNative() uint64 { +func availableNative() uint64 { return 0 } + +func availableAvailable() bool { + return false +} diff --git a/common/memory/memory_windows.go b/common/memory/memory_windows.go new file mode 100644 index 000000000..c0f3b0f00 --- /dev/null +++ b/common/memory/memory_windows.go @@ -0,0 +1,35 @@ +package memory + +import ( + "unsafe" + + "golang.org/x/sys/windows" +) + +func totalNative() uint64 { + var mem processMemoryCounters + mem.cb = uint32(unsafe.Sizeof(mem)) + err := getProcessMemoryInfo(windows.CurrentProcess(), &mem, mem.cb) + if err != nil { + return 0 + } + return uint64(mem.workingSetSize) +} + +func totalAvailable() bool { + return true +} + +func availableNative() uint64 { + var mem memoryStatusEx + mem.dwLength = uint32(unsafe.Sizeof(mem)) + err := globalMemoryStatusEx(&mem) + if err != nil { + return 0 + } + return mem.ullAvailPhys +} + +func availableAvailable() bool { + return true +} diff --git a/common/memory/syscall_windows.go b/common/memory/syscall_windows.go new file mode 100644 index 000000000..069ce0118 --- /dev/null +++ b/common/memory/syscall_windows.go @@ -0,0 +1,34 @@ +package memory + +//go:generate go run golang.org/x/sys/windows/mkwinsyscall -output zsyscall_windows.go syscall_windows.go + +type processMemoryCounters struct { + cb uint32 + pageFaultCount uint32 + peakWorkingSetSize uintptr + workingSetSize uintptr + quotaPeakPagedPoolUsage uintptr + quotaPagedPoolUsage uintptr + quotaPeakNonPagedPoolUsage uintptr + quotaNonPagedPoolUsage uintptr + pagefileUsage uintptr + peakPagefileUsage uintptr +} + +type memoryStatusEx struct { + dwLength uint32 + dwMemoryLoad uint32 + ullTotalPhys uint64 + ullAvailPhys uint64 + ullTotalPageFile uint64 + ullAvailPageFile uint64 + ullTotalVirtual uint64 + ullAvailVirtual uint64 + ullAvailExtendedVirtual uint64 +} + +// https://learn.microsoft.com/en-us/windows/win32/api/psapi/nf-psapi-getprocessmemoryinfo +//sys getProcessMemoryInfo(process windows.Handle, ppsmemCounters *processMemoryCounters, cb uint32) (err error) = kernel32.K32GetProcessMemoryInfo + +// https://learn.microsoft.com/en-us/windows/win32/api/sysinfoapi/nf-sysinfoapi-globalmemorystatusex +//sys globalMemoryStatusEx(lpBuffer *memoryStatusEx) (err error) = kernel32.GlobalMemoryStatusEx diff --git a/common/memory/zsyscall_windows.go b/common/memory/zsyscall_windows.go new file mode 100644 index 000000000..33334a374 --- /dev/null +++ b/common/memory/zsyscall_windows.go @@ -0,0 +1,61 @@ +// Code generated by 'go generate'; DO NOT EDIT. + +package memory + +import ( + "syscall" + "unsafe" + + "golang.org/x/sys/windows" +) + +var _ unsafe.Pointer + +// Do the interface allocations only once for common +// Errno values. +const ( + errnoERROR_IO_PENDING = 997 +) + +var ( + errERROR_IO_PENDING error = syscall.Errno(errnoERROR_IO_PENDING) + errERROR_EINVAL error = syscall.EINVAL +) + +// errnoErr returns common boxed Errno values, to prevent +// allocations at runtime. +func errnoErr(e syscall.Errno) error { + switch e { + case 0: + return errERROR_EINVAL + case errnoERROR_IO_PENDING: + return errERROR_IO_PENDING + } + // TODO: add more here, after collecting data on the common + // error values see on Windows. (perhaps when running + // all.bat?) + return e +} + +var ( + modkernel32 = windows.NewLazySystemDLL("kernel32.dll") + + procGlobalMemoryStatusEx = modkernel32.NewProc("GlobalMemoryStatusEx") + procK32GetProcessMemoryInfo = modkernel32.NewProc("K32GetProcessMemoryInfo") +) + +func globalMemoryStatusEx(lpBuffer *memoryStatusEx) (err error) { + r1, _, e1 := syscall.Syscall(procGlobalMemoryStatusEx.Addr(), 1, uintptr(unsafe.Pointer(lpBuffer)), 0, 0) + if r1 == 0 { + err = errnoErr(e1) + } + return +} + +func getProcessMemoryInfo(process windows.Handle, ppsmemCounters *processMemoryCounters, cb uint32) (err error) { + r1, _, e1 := syscall.Syscall(procK32GetProcessMemoryInfo.Addr(), 3, uintptr(process), uintptr(unsafe.Pointer(ppsmemCounters)), uintptr(cb)) + if r1 == 0 { + err = errnoErr(e1) + } + return +} From e8eacaf9abe395e5bab86592981f01dbf6282a37 Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?=E4=B8=96=E7=95=8C?= Date: Sat, 11 Apr 2026 20:37:32 +0800 Subject: [PATCH 2/7] tls: Add handshake timeout to interface --- common/tls/config.go | 13 +++++++++++++ 1 file changed, 13 insertions(+) diff --git a/common/tls/config.go b/common/tls/config.go index 3bb16416f..587b42c6c 100644 --- a/common/tls/config.go +++ b/common/tls/config.go @@ -4,6 +4,7 @@ import ( "context" "crypto/tls" "net" + "time" ) type ( @@ -17,6 +18,8 @@ type Config interface { SetServerName(serverName string) NextProtos() []string SetNextProtos(nextProto []string) + HandshakeTimeout() time.Duration + SetHandshakeTimeout(timeout time.Duration) STDConfig() (*STDConfig, error) Client(conn net.Conn) (Conn, error) Clone() Config @@ -51,6 +54,11 @@ type Conn interface { } func ClientHandshake(ctx context.Context, conn net.Conn, config Config) (Conn, error) { + if handshakeTimeout := config.HandshakeTimeout(); handshakeTimeout > 0 { + var cancel context.CancelFunc + ctx, cancel = context.WithTimeout(ctx, handshakeTimeout) + defer cancel() + } if compatServer, isCompat := config.(ConfigCompat); isCompat { return compatServer.ClientHandshake(ctx, conn) } @@ -66,6 +74,11 @@ func ClientHandshake(ctx context.Context, conn net.Conn, config Config) (Conn, e } func ServerHandshake(ctx context.Context, conn net.Conn, config ServerConfig) (Conn, error) { + if handshakeTimeout := config.HandshakeTimeout(); handshakeTimeout > 0 { + var cancel context.CancelFunc + ctx, cancel = context.WithTimeout(ctx, handshakeTimeout) + defer cancel() + } if compatServer, isCompat := config.(ServerConfigCompat); isCompat { return compatServer.ServerHandshake(ctx, conn) } From 1bd153a749b8678facf25a3612cba228ffee7bdf Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?=E4=B8=96=E7=95=8C?= Date: Fri, 24 Apr 2026 08:52:54 +0800 Subject: [PATCH 3/7] bufio: add UDP packet batch support --- common/bufio/bind.go | 26 +- common/bufio/copy.go | 195 ++++++++- common/bufio/copy_direct.go | 168 +++++++- common/bufio/counter_packet_conn.go | 137 ++++++ common/bufio/nat.go | 99 +++++ common/bufio/nat_wait.go | 98 +++++ common/bufio/packet_batch.go | 107 +++++ common/bufio/packet_batch_mmsg.go | 392 +++++++++++++++++ common/bufio/packet_batch_mmsg_generic.go | 7 + common/bufio/packet_batch_mmsg_linux32.go | 7 + common/bufio/packet_batch_msgx_darwin.go | 306 +++++++++++++ common/bufio/packet_batch_stub.go | 24 ++ common/bufio/packet_batch_syscall.go | 58 +++ common/bufio/packet_batch_test.go | 504 ++++++++++++++++++++++ common/bufio/wait.go | 55 ++- common/metadata/addr_unix.go | 22 +- common/metadata/addr_windows.go | 6 +- common/network/counter.go | 4 +- common/network/direct.go | 21 +- common/network/early.go | 29 +- common/network/vectorised.go | 16 + common/udpnat2/conn.go | 42 ++ common/udpnat2/conn_test.go | 98 +++++ common/udpnat2/service.go | 12 + 24 files changed, 2390 insertions(+), 43 deletions(-) create mode 100644 common/bufio/packet_batch.go create mode 100644 common/bufio/packet_batch_mmsg.go create mode 100644 common/bufio/packet_batch_mmsg_generic.go create mode 100644 common/bufio/packet_batch_mmsg_linux32.go create mode 100644 common/bufio/packet_batch_msgx_darwin.go create mode 100644 common/bufio/packet_batch_stub.go create mode 100644 common/bufio/packet_batch_syscall.go create mode 100644 common/bufio/packet_batch_test.go create mode 100644 common/udpnat2/conn_test.go diff --git a/common/bufio/bind.go b/common/bufio/bind.go index 9788b4d3a..dd46fb836 100644 --- a/common/bufio/bind.go +++ b/common/bufio/bind.go @@ -42,6 +42,14 @@ func (c *bindPacketConn) CreateReadWaiter() (N.ReadWaiter, bool) { return &bindPacketReadWaiter{readWaiter}, true } +func (c *bindPacketConn) CreatePacketBatchReadWaiter() (N.PacketBatchReadWaiter, bool) { + return CreatePacketBatchReadWaiter(c.NetPacketConn) +} + +func (c *bindPacketConn) CreatePacketBatchWriter() (N.PacketBatchWriter, bool) { + return CreatePacketBatchWriter(c.NetPacketConn) +} + func (c *bindPacketConn) RemoteAddr() net.Addr { return c.addr } @@ -51,8 +59,10 @@ func (c *bindPacketConn) Upstream() any { } var ( - _ N.NetPacketConn = (*UnbindPacketConn)(nil) - _ N.PacketReadWaitCreator = (*UnbindPacketConn)(nil) + _ N.NetPacketConn = (*UnbindPacketConn)(nil) + _ N.PacketReadWaitCreator = (*UnbindPacketConn)(nil) + _ N.ConnectedPacketBatchReadWaitCreator = (*UnbindPacketConn)(nil) + _ N.ConnectedPacketBatchWriteCreator = (*UnbindPacketConn)(nil) ) type UnbindPacketConn struct { @@ -107,6 +117,14 @@ func (c *UnbindPacketConn) CreateReadWaiter() (N.PacketReadWaiter, bool) { return &unbindPacketReadWaiter{readWaiter, c.addr}, true } +func (c *UnbindPacketConn) CreateConnectedPacketBatchReadWaiter() (N.ConnectedPacketBatchReadWaiter, bool) { + return createSyscallConnectedPacketBatchReadWaiter(c.ExtendedConn, c.addr) +} + +func (c *UnbindPacketConn) CreateConnectedPacketBatchWriter() (N.ConnectedPacketBatchWriter, bool) { + return createSyscallConnectedPacketBatchWriter(c.ExtendedConn) +} + func (c *UnbindPacketConn) Upstream() any { return c.ExtendedConn } @@ -163,3 +181,7 @@ func (c *serverPacketConn) CreateReadWaiter() (N.ReadWaiter, bool) { } return &serverPacketReadWaiter{c, readWaiter}, true } + +func (c *serverPacketConn) CreatePacketBatchReadWaiter() (N.PacketBatchReadWaiter, bool) { + return CreatePacketBatchReadWaiter(c.NetPacketConn) +} diff --git a/common/bufio/copy.go b/common/bufio/copy.go index 164ee296e..fe34fdbb7 100644 --- a/common/bufio/copy.go +++ b/common/bufio/copy.go @@ -17,6 +17,7 @@ import ( const ( DefaultIncreaseBufferAfter = 512 * 1000 DefaultBatchSize = 8 + DefaultPacketReadBatchSize = 64 ) func Copy(destination io.Writer, source io.Reader) (n int64, err error) { @@ -259,7 +260,6 @@ func CopyConn(ctx context.Context, source net.Conn, destination net.Conn) error func CopyPacket(destinationConn N.PacketWriter, source N.PacketReader) (n int64, err error) { var readCounters, writeCounters []N.CountFunc - var cachedPackets []*N.PacketBuffer originSource := source for { source, readCounters = N.UnwrapCountPacketReader(source, readCounters) @@ -267,45 +267,194 @@ func CopyPacket(destinationConn N.PacketWriter, source N.PacketReader) (n int64, if cachedReader, isCached := source.(N.CachedPacketReader); isCached { packet := cachedReader.ReadCachedPacket() if packet != nil { - cachedPackets = append(cachedPackets, packet) + var cachedN int64 + cachedN, err = writePacketWithPool(originSource, destinationConn, []*N.PacketBuffer{packet}, readCounters, writeCounters, n > 0) + n += cachedN + if err != nil { + return + } continue } } break } - if cachedPackets != nil { - n, err = WritePacketWithPool(originSource, destinationConn, cachedPackets, readCounters, writeCounters) - if err != nil { - return - } - } - copeN, err := CopyPacketWithCounters(destinationConn, source, originSource, readCounters, writeCounters) + copeN, err := copyPacketWithCounters(destinationConn, source, originSource, readCounters, writeCounters, n > 0) n += copeN return } +type packetCopySession struct { + destination N.PacketWriter + source N.PacketReader + originSource N.PacketReader + readCounters []N.CountFunc + writeCounters []N.CountFunc + handshake N.HandshakeState + upgradable bool +} + +func newPacketCopySession(destination N.PacketWriter, source N.PacketReader, originSource N.PacketReader, readCounters []N.CountFunc, writeCounters []N.CountFunc) *packetCopySession { + session := &packetCopySession{ + destination: destination, + source: source, + originSource: originSource, + readCounters: readCounters, + writeCounters: writeCounters, + } + session.ResetHandshake() + return session +} + +func (s *packetCopySession) ResetHandshake() { + s.handshake = N.NewPacketHandshakeState(s.source, s.destination) + s.upgradable = s.handshake.Upgradable() +} + +func (s *packetCopySession) Transfer(n int64) error { + for _, counter := range s.readCounters { + counter(n) + } + for _, counter := range s.writeCounters { + counter(n) + } + if s.upgradable { + return s.handshake.Check() + } + return nil +} + +func (s *packetCopySession) TransferBatch(dataLens []int) error { + for _, dataLen := range dataLens { + n := int64(dataLen) + for _, counter := range s.readCounters { + counter(n) + } + for _, counter := range s.writeCounters { + counter(n) + } + } + if s.upgradable { + return s.handshake.Check() + } + return nil +} + func CopyPacketWithCounters(destinationConn N.PacketWriter, source N.PacketReader, originSource N.PacketReader, readCounters []N.CountFunc, writeCounters []N.CountFunc) (n int64, err error) { + return copyPacketWithCounters(destinationConn, source, originSource, readCounters, writeCounters, false) +} + +func copyPacketWithCounters(destinationConn N.PacketWriter, source N.PacketReader, originSource N.PacketReader, readCounters []N.CountFunc, writeCounters []N.CountFunc, notFirstTime bool) (n int64, err error) { + sourceReader := source + destinationWriter := destinationConn + session := newPacketCopySession(destinationWriter, sourceReader, originSource, readCounters, writeCounters) + refreshUnwrap := func() { + sourceReader, readCounters = N.UnwrapCountPacketReader(sourceReader, readCounters) + destinationWriter, writeCounters = N.UnwrapCountPacketWriter(destinationWriter, writeCounters) + session.source = sourceReader + session.destination = destinationWriter + session.readCounters = readCounters + session.writeCounters = writeCounters + session.ResetHandshake() + } + for { + var copyN int64 + copyN, err = copyPacketWithCountersOnce(session, notFirstTime || n > 0) + n += copyN + if errors.Is(err, N.ErrHandshakeCompleted) { + refreshUnwrap() + continue + } + return + } +} + +func copyPacketWithCountersOnce(session *packetCopySession, notFirstTime bool) (n int64, err error) { var ( handled bool copeN int64 ) + source := session.source + destinationConn := session.destination + batchReadWaiter, isBatchReadWaiter := CreatePacketBatchReadWaiter(source) + if isBatchReadWaiter { + batchWriter, isBatchWriter := CreatePacketBatchWriter(destinationConn) + if isBatchWriter { + readWaitOptions := N.NewReadWaitOptions(source, destinationConn) + readWaitOptions.BatchSize = DefaultPacketReadBatchSize + needCopy := batchReadWaiter.InitializeReadWaiter(readWaitOptions) + if !needCopy || common.LowMemory { + handled, copeN, err = copyPacketBatchWaitWithPool(session, batchWriter, batchReadWaiter, notFirstTime) + if handled { + n += copeN + return + } + } + } + connectedBatchWriter, isConnectedBatchWriter := CreateConnectedPacketBatchWriter(destinationConn) + if isConnectedBatchWriter { + readWaitOptions := N.NewReadWaitOptions(source, destinationConn) + readWaitOptions.BatchSize = DefaultPacketReadBatchSize + needCopy := batchReadWaiter.InitializeReadWaiter(readWaitOptions) + if !needCopy || common.LowMemory { + handled, copeN, err = copyPacketBatchToConnectedWaitWithPool(session, connectedBatchWriter, batchReadWaiter, notFirstTime) + if handled { + n += copeN + return + } + } + } + } + connectedBatchReadWaiter, isConnectedBatchReadWaiter := CreateConnectedPacketBatchReadWaiter(source) + if isConnectedBatchReadWaiter { + batchWriter, isBatchWriter := CreatePacketBatchWriter(destinationConn) + if isBatchWriter { + readWaitOptions := N.NewReadWaitOptions(source, destinationConn) + readWaitOptions.BatchSize = DefaultPacketReadBatchSize + needCopy := connectedBatchReadWaiter.InitializeReadWaiter(readWaitOptions) + if !needCopy || common.LowMemory { + handled, copeN, err = copyConnectedPacketBatchWaitWithPool(session, batchWriter, connectedBatchReadWaiter, notFirstTime) + if handled { + n += copeN + return + } + } + } + connectedBatchWriter, isConnectedBatchWriter := CreateConnectedPacketBatchWriter(destinationConn) + if isConnectedBatchWriter { + readWaitOptions := N.NewReadWaitOptions(source, destinationConn) + readWaitOptions.BatchSize = DefaultPacketReadBatchSize + needCopy := connectedBatchReadWaiter.InitializeReadWaiter(readWaitOptions) + if !needCopy || common.LowMemory { + handled, copeN, err = copyConnectedPacketBatchToConnectedWaitWithPool(session, connectedBatchWriter, connectedBatchReadWaiter, notFirstTime) + if handled { + n += copeN + return + } + } + } + } readWaiter, isReadWaiter := CreatePacketReadWaiter(source) if isReadWaiter { needCopy := readWaiter.InitializeReadWaiter(N.NewReadWaitOptions(source, destinationConn)) if !needCopy || common.LowMemory { - handled, copeN, err = copyPacketWaitWithPool(originSource, destinationConn, readWaiter, readCounters, writeCounters, n > 0) + handled, copeN, err = copyPacketWaitWithPool(session, destinationConn, readWaiter, notFirstTime) if handled { n += copeN return } } } - copeN, err = CopyPacketWithPool(originSource, destinationConn, source, readCounters, writeCounters, n > 0) + copeN, err = copyPacketWithPool(session, destinationConn, source, notFirstTime) n += copeN return } func CopyPacketWithPool(originSource N.PacketReader, destination N.PacketWriter, source N.PacketReader, readCounters []N.CountFunc, writeCounters []N.CountFunc, notFirstTime bool) (n int64, err error) { + session := newPacketCopySession(destination, source, originSource, readCounters, writeCounters) + return copyPacketWithPool(session, destination, source, notFirstTime) +} + +func copyPacketWithPool(session *packetCopySession, destination N.PacketWriter, source N.PacketReader, notFirstTime bool) (n int64, err error) { options := N.NewReadWaitOptions(source, destination) var destinationAddress M.Socksaddr for { @@ -321,24 +470,27 @@ func CopyPacketWithPool(originSource N.PacketReader, destination N.PacketWriter, if err != nil { buffer.Leak() if !notFirstTime { - err = N.ReportHandshakeFailure(originSource, err) + handshakeErr := N.ReportHandshakeFailure(session.originSource, err) + if handshakeErr != nil { + err = handshakeErr + } } return } - for _, counter := range readCounters { - counter(int64(dataLen)) - } - for _, counter := range writeCounters { - counter(int64(dataLen)) - } n += int64(dataLen) + if err = session.Transfer(int64(dataLen)); err != nil { + return + } notFirstTime = true } } func WritePacketWithPool(originSource N.PacketReader, destination N.PacketWriter, packetBuffers []*N.PacketBuffer, readCounters []N.CountFunc, writeCounters []N.CountFunc) (n int64, err error) { + return writePacketWithPool(originSource, destination, packetBuffers, readCounters, writeCounters, false) +} + +func writePacketWithPool(originSource N.PacketReader, destination N.PacketWriter, packetBuffers []*N.PacketBuffer, readCounters []N.CountFunc, writeCounters []N.CountFunc, notFirstTime bool) (n int64, err error) { options := N.NewReadWaitOptions(nil, destination) - var notFirstTime bool for _, packetBuffer := range packetBuffers { buffer := options.Copy(packetBuffer.Buffer) dataLen := buffer.Len() @@ -347,7 +499,10 @@ func WritePacketWithPool(originSource N.PacketReader, destination N.PacketWriter if err != nil { buffer.Leak() if !notFirstTime { - err = N.ReportHandshakeFailure(originSource, err) + handshakeErr := N.ReportHandshakeFailure(originSource, err) + if handshakeErr != nil { + err = handshakeErr + } } return } diff --git a/common/bufio/copy_direct.go b/common/bufio/copy_direct.go index 8f601f101..a460c7edd 100644 --- a/common/bufio/copy_direct.go +++ b/common/bufio/copy_direct.go @@ -3,6 +3,7 @@ package bufio import ( "errors" "io" + "os" "github.com/sagernet/sing/common/buf" M "github.com/sagernet/sing/common/metadata" @@ -97,7 +98,7 @@ func copyWaitVectorisedWithPool(session *CopySession, vectorisedWriter N.Vectori } } -func copyPacketWaitWithPool(originSource N.PacketReader, destinationConn N.PacketWriter, source N.PacketReadWaiter, readCounters []N.CountFunc, writeCounters []N.CountFunc, notFirstTime bool) (handled bool, n int64, err error) { +func copyPacketWaitWithPool(session *packetCopySession, destinationConn N.PacketWriter, source N.PacketReadWaiter, notFirstTime bool) (handled bool, n int64, err error) { handled = true var ( buffer *buf.Buffer @@ -113,16 +114,171 @@ func copyPacketWaitWithPool(originSource N.PacketReader, destinationConn N.Packe if err != nil { buffer.Leak() if !notFirstTime { - err = N.ReportHandshakeFailure(originSource, err) + handshakeErr := N.ReportHandshakeFailure(session.originSource, err) + if handshakeErr != nil { + err = handshakeErr + } } return } n += int64(dataLen) - for _, counter := range readCounters { - counter(int64(dataLen)) + if err = session.Transfer(int64(dataLen)); err != nil { + return + } + notFirstTime = true + } +} + +func copyPacketBatchWaitWithPool(session *packetCopySession, destinationConn N.PacketBatchWriter, source N.PacketBatchReadWaiter, notFirstTime bool) (handled bool, n int64, err error) { + handled = true + for { + var ( + buffers []*buf.Buffer + destinations []M.Socksaddr + ) + buffers, destinations, err = source.WaitReadPackets() + if err != nil { + return handled, n, err } - for _, counter := range writeCounters { - counter(int64(dataLen)) + if len(buffers) == 0 || len(buffers) != len(destinations) { + buf.ReleaseMulti(buffers) + return handled, n, os.ErrInvalid + } + dataLens := make([]int, len(buffers)) + for index, buffer := range buffers { + dataLens[index] = buffer.Len() + } + err = destinationConn.WritePacketBatch(buffers, destinations) + if err != nil { + if !notFirstTime { + handshakeErr := N.ReportHandshakeFailure(session.originSource, err) + if handshakeErr != nil { + err = handshakeErr + } + } + return + } + for _, dataLen := range dataLens { + n += int64(dataLen) + } + if err = session.TransferBatch(dataLens); err != nil { + return + } + notFirstTime = true + } +} + +func copyPacketBatchToConnectedWaitWithPool(session *packetCopySession, destinationConn N.ConnectedPacketBatchWriter, source N.PacketBatchReadWaiter, notFirstTime bool) (handled bool, n int64, err error) { + handled = true + for { + var ( + buffers []*buf.Buffer + destinations []M.Socksaddr + ) + buffers, destinations, err = source.WaitReadPackets() + if err != nil { + return handled, n, err + } + if len(buffers) == 0 || len(buffers) != len(destinations) { + buf.ReleaseMulti(buffers) + return handled, n, os.ErrInvalid + } + dataLens := make([]int, len(buffers)) + for index, buffer := range buffers { + dataLens[index] = buffer.Len() + } + err = destinationConn.WriteConnectedPacketBatch(buffers) + if err != nil { + if !notFirstTime { + handshakeErr := N.ReportHandshakeFailure(session.originSource, err) + if handshakeErr != nil { + err = handshakeErr + } + } + return + } + for _, dataLen := range dataLens { + n += int64(dataLen) + } + if err = session.TransferBatch(dataLens); err != nil { + return + } + notFirstTime = true + } +} + +func copyConnectedPacketBatchWaitWithPool(session *packetCopySession, destinationConn N.PacketBatchWriter, source N.ConnectedPacketBatchReadWaiter, notFirstTime bool) (handled bool, n int64, err error) { + handled = true + for { + var ( + buffers []*buf.Buffer + destination M.Socksaddr + ) + buffers, destination, err = source.WaitReadConnectedPackets() + if err != nil { + return handled, n, err + } + if len(buffers) == 0 { + buf.ReleaseMulti(buffers) + return handled, n, os.ErrInvalid + } + destinations := make([]M.Socksaddr, len(buffers)) + dataLens := make([]int, len(buffers)) + for index, buffer := range buffers { + destinations[index] = destination + dataLens[index] = buffer.Len() + } + err = destinationConn.WritePacketBatch(buffers, destinations) + if err != nil { + if !notFirstTime { + handshakeErr := N.ReportHandshakeFailure(session.originSource, err) + if handshakeErr != nil { + err = handshakeErr + } + } + return + } + for _, dataLen := range dataLens { + n += int64(dataLen) + } + if err = session.TransferBatch(dataLens); err != nil { + return + } + notFirstTime = true + } +} + +func copyConnectedPacketBatchToConnectedWaitWithPool(session *packetCopySession, destinationConn N.ConnectedPacketBatchWriter, source N.ConnectedPacketBatchReadWaiter, notFirstTime bool) (handled bool, n int64, err error) { + handled = true + for { + var buffers []*buf.Buffer + buffers, _, err = source.WaitReadConnectedPackets() + if err != nil { + return handled, n, err + } + if len(buffers) == 0 { + buf.ReleaseMulti(buffers) + return handled, n, os.ErrInvalid + } + dataLens := make([]int, len(buffers)) + for index, buffer := range buffers { + dataLens[index] = buffer.Len() + } + err = destinationConn.WriteConnectedPacketBatch(buffers) + if err != nil { + if !notFirstTime { + handshakeErr := N.ReportHandshakeFailure(session.originSource, err) + if handshakeErr != nil { + err = handshakeErr + } + } + return + } + for _, dataLen := range dataLens { + n += int64(dataLen) + } + if err = session.TransferBatch(dataLens); err != nil { + return } notFirstTime = true } diff --git a/common/bufio/counter_packet_conn.go b/common/bufio/counter_packet_conn.go index c98d17e06..2208d1d03 100644 --- a/common/bufio/counter_packet_conn.go +++ b/common/bufio/counter_packet_conn.go @@ -1,6 +1,7 @@ package bufio import ( + "os" "sync/atomic" "github.com/sagernet/sing/common" @@ -69,6 +70,94 @@ func (c *CounterPacketConn) WritePacket(buffer *buf.Buffer, destination M.Socksa return nil } +func (c *CounterPacketConn) CreatePacketBatchWriter() (N.PacketBatchWriter, bool) { + writer, created := CreatePacketBatchWriter(c.PacketConn) + if !created { + return nil, false + } + return &counterPacketBatchWriter{writer, c.writeCounter}, true +} + +type counterPacketBatchWriter struct { + writer N.PacketBatchWriter + writeCounts []N.CountFunc +} + +func (w *counterPacketBatchWriter) WritePacketBatch(buffers []*buf.Buffer, destinations []M.Socksaddr) error { + if len(buffers) == 0 || len(buffers) != len(destinations) { + buf.ReleaseMulti(buffers) + return os.ErrInvalid + } + dataLens := make([]int64, len(buffers)) + for index, buffer := range buffers { + dataLens[index] = int64(buffer.Len()) + } + err := w.writer.WritePacketBatch(buffers, destinations) + if err != nil { + return err + } + for _, dataLen := range dataLens { + if dataLen > 0 { + for _, counter := range w.writeCounts { + counter(dataLen) + } + } + } + return nil +} + +func (c *CounterPacketConn) CreateConnectedPacketBatchWriter() (N.ConnectedPacketBatchWriter, bool) { + writer, created := CreateConnectedPacketBatchWriter(c.PacketConn) + if !created { + return nil, false + } + return &counterConnectedPacketBatchWriter{writer, c.writeCounter}, true +} + +type counterConnectedPacketBatchWriter struct { + writer N.ConnectedPacketBatchWriter + writeCounts []N.CountFunc +} + +func (w *counterConnectedPacketBatchWriter) WriteConnectedPacketBatch(buffers []*buf.Buffer) error { + if len(buffers) == 0 { + buf.ReleaseMulti(buffers) + return os.ErrInvalid + } + dataLens := make([]int64, len(buffers)) + for index, buffer := range buffers { + dataLens[index] = int64(buffer.Len()) + } + err := w.writer.WriteConnectedPacketBatch(buffers) + if err != nil { + return err + } + for _, dataLen := range dataLens { + if dataLen > 0 { + for _, counter := range w.writeCounts { + counter(dataLen) + } + } + } + return nil +} + +func (c *CounterPacketConn) CreatePacketBatchReadWaiter() (N.PacketBatchReadWaiter, bool) { + readWaiter, isReadWaiter := CreatePacketBatchReadWaiter(c.PacketConn) + if !isReadWaiter { + return nil, false + } + return &counterPacketBatchReadWaiter{readWaiter, c.readCounter}, true +} + +func (c *CounterPacketConn) CreateConnectedPacketBatchReadWaiter() (N.ConnectedPacketBatchReadWaiter, bool) { + readWaiter, isReadWaiter := CreateConnectedPacketBatchReadWaiter(c.PacketConn) + if !isReadWaiter { + return nil, false + } + return &counterConnectedPacketBatchReadWaiter{readWaiter, c.readCounter}, true +} + func (c *CounterPacketConn) UnwrapPacketReader() (N.PacketReader, []N.CountFunc) { return c.PacketConn, c.readCounter } @@ -80,3 +169,51 @@ func (c *CounterPacketConn) UnwrapPacketWriter() (N.PacketWriter, []N.CountFunc) func (c *CounterPacketConn) Upstream() any { return c.PacketConn } + +type counterPacketBatchReadWaiter struct { + readWaiter N.PacketBatchReadWaiter + readCounts []N.CountFunc +} + +func (w *counterPacketBatchReadWaiter) InitializeReadWaiter(options N.ReadWaitOptions) (needCopy bool) { + return w.readWaiter.InitializeReadWaiter(options) +} + +func (w *counterPacketBatchReadWaiter) WaitReadPackets() (buffers []*buf.Buffer, destinations []M.Socksaddr, err error) { + buffers, destinations, err = w.readWaiter.WaitReadPackets() + if err != nil { + return + } + for _, buffer := range buffers { + if buffer.Len() > 0 { + for _, counter := range w.readCounts { + counter(int64(buffer.Len())) + } + } + } + return +} + +type counterConnectedPacketBatchReadWaiter struct { + readWaiter N.ConnectedPacketBatchReadWaiter + readCounts []N.CountFunc +} + +func (w *counterConnectedPacketBatchReadWaiter) InitializeReadWaiter(options N.ReadWaitOptions) (needCopy bool) { + return w.readWaiter.InitializeReadWaiter(options) +} + +func (w *counterConnectedPacketBatchReadWaiter) WaitReadConnectedPackets() (buffers []*buf.Buffer, destination M.Socksaddr, err error) { + buffers, destination, err = w.readWaiter.WaitReadConnectedPackets() + if err != nil { + return + } + for _, buffer := range buffers { + if buffer.Len() > 0 { + for _, counter := range w.readCounts { + counter(int64(buffer.Len())) + } + } + } + return +} diff --git a/common/bufio/nat.go b/common/bufio/nat.go index 6e5ab6494..e825ef46b 100644 --- a/common/bufio/nat.go +++ b/common/bufio/nat.go @@ -3,6 +3,7 @@ package bufio import ( "net" "net/netip" + "os" "github.com/sagernet/sing/common/buf" M "github.com/sagernet/sing/common/metadata" @@ -67,6 +68,40 @@ func (c *unidirectionalNATPacketConn) WritePacket(buffer *buf.Buffer, destinatio return c.NetPacketConn.WritePacket(buffer, destination) } +func (c *unidirectionalNATPacketConn) CreatePacketBatchWriter() (N.PacketBatchWriter, bool) { + writer, created := CreatePacketBatchWriter(c.NetPacketConn) + if !created { + return nil, false + } + return &unidirectionalNATPacketBatchWriter{c, writer}, true +} + +func (c *unidirectionalNATPacketConn) CreateConnectedPacketBatchWriter() (N.ConnectedPacketBatchWriter, bool) { + return CreateConnectedPacketBatchWriter(c.NetPacketConn) +} + +type unidirectionalNATPacketBatchWriter struct { + *unidirectionalNATPacketConn + writer N.PacketBatchWriter +} + +func (w *unidirectionalNATPacketBatchWriter) WritePacketBatch(buffers []*buf.Buffer, destinations []M.Socksaddr) error { + if len(buffers) == 0 || len(buffers) != len(destinations) { + buf.ReleaseMulti(buffers) + return os.ErrInvalid + } + for index, destination := range destinations { + if socksaddrWithoutPort(destination) == w.destination { + destinations[index] = M.Socksaddr{ + Addr: w.origin.Addr, + Fqdn: w.origin.Fqdn, + Port: destination.Port, + } + } + } + return w.writer.WritePacketBatch(buffers, destinations) +} + func (c *unidirectionalNATPacketConn) UpdateDestination(destinationAddress netip.Addr) { c.destination = M.SocksaddrFrom(destinationAddress, c.destination.Port) } @@ -140,6 +175,40 @@ func (c *bidirectionalNATPacketConn) WritePacket(buffer *buf.Buffer, destination return c.NetPacketConn.WritePacket(buffer, destination) } +func (c *bidirectionalNATPacketConn) CreatePacketBatchWriter() (N.PacketBatchWriter, bool) { + writer, created := CreatePacketBatchWriter(c.NetPacketConn) + if !created { + return nil, false + } + return &bidirectionalNATPacketBatchWriter{c, writer}, true +} + +func (c *bidirectionalNATPacketConn) CreateConnectedPacketBatchWriter() (N.ConnectedPacketBatchWriter, bool) { + return CreateConnectedPacketBatchWriter(c.NetPacketConn) +} + +type bidirectionalNATPacketBatchWriter struct { + *bidirectionalNATPacketConn + writer N.PacketBatchWriter +} + +func (w *bidirectionalNATPacketBatchWriter) WritePacketBatch(buffers []*buf.Buffer, destinations []M.Socksaddr) error { + if len(buffers) == 0 || len(buffers) != len(destinations) { + buf.ReleaseMulti(buffers) + return os.ErrInvalid + } + for index, destination := range destinations { + if socksaddrWithoutPort(destination) == w.destination { + destinations[index] = M.Socksaddr{ + Addr: w.origin.Addr, + Fqdn: w.origin.Fqdn, + Port: destination.Port, + } + } + } + return w.writer.WritePacketBatch(buffers, destinations) +} + func (c *bidirectionalNATPacketConn) UpdateDestination(destinationAddress netip.Addr) { c.destination = M.SocksaddrFrom(destinationAddress, c.destination.Port) } @@ -194,6 +263,36 @@ func (c *destinationNATPacketConn) WritePacket(buffer *buf.Buffer, destination M return c.NetPacketConn.WritePacket(buffer, destination) } +func (c *destinationNATPacketConn) CreatePacketBatchWriter() (N.PacketBatchWriter, bool) { + writer, created := CreatePacketBatchWriter(c.NetPacketConn) + if !created { + return nil, false + } + return &destinationNATPacketBatchWriter{c, writer}, true +} + +func (c *destinationNATPacketConn) CreateConnectedPacketBatchWriter() (N.ConnectedPacketBatchWriter, bool) { + return CreateConnectedPacketBatchWriter(c.NetPacketConn) +} + +type destinationNATPacketBatchWriter struct { + *destinationNATPacketConn + writer N.PacketBatchWriter +} + +func (w *destinationNATPacketBatchWriter) WritePacketBatch(buffers []*buf.Buffer, destinations []M.Socksaddr) error { + if len(buffers) == 0 || len(buffers) != len(destinations) { + buf.ReleaseMulti(buffers) + return os.ErrInvalid + } + for index, destination := range destinations { + if destination == w.destination { + destinations[index] = w.origin + } + } + return w.writer.WritePacketBatch(buffers, destinations) +} + func (c *destinationNATPacketConn) UpdateDestination(destinationAddress netip.Addr) { c.destination = M.SocksaddrFrom(destinationAddress, c.destination.Port) } diff --git a/common/bufio/nat_wait.go b/common/bufio/nat_wait.go index dbb370a08..d6a6e2bd0 100644 --- a/common/bufio/nat_wait.go +++ b/common/bufio/nat_wait.go @@ -14,6 +14,34 @@ func (c *bidirectionalNATPacketConn) CreatePacketReadWaiter() (N.PacketReadWaite return &waitBidirectionalNATPacketConn{c, waiter}, true } +func (c *bidirectionalNATPacketConn) CreatePacketBatchReadWaiter() (N.PacketBatchReadWaiter, bool) { + waiter, created := CreatePacketBatchReadWaiter(c.NetPacketConn) + if !created { + return nil, false + } + return &batchWaitBidirectionalNATPacketConn{c, waiter}, true +} + +func (c *unidirectionalNATPacketConn) CreateConnectedPacketBatchReadWaiter() (N.ConnectedPacketBatchReadWaiter, bool) { + return CreateConnectedPacketBatchReadWaiter(c.NetPacketConn) +} + +func (c *bidirectionalNATPacketConn) CreateConnectedPacketBatchReadWaiter() (N.ConnectedPacketBatchReadWaiter, bool) { + waiter, created := CreateConnectedPacketBatchReadWaiter(c.NetPacketConn) + if !created { + return nil, false + } + return &connectedBatchWaitBidirectionalNATPacketConn{c, waiter}, true +} + +func (c *destinationNATPacketConn) CreateConnectedPacketBatchReadWaiter() (N.ConnectedPacketBatchReadWaiter, bool) { + waiter, created := CreateConnectedPacketBatchReadWaiter(c.NetPacketConn) + if !created { + return nil, false + } + return &connectedBatchWaitDestinationNATPacketConn{c, waiter}, true +} + type waitBidirectionalNATPacketConn struct { *bidirectionalNATPacketConn readWaiter N.PacketReadWaiter @@ -37,3 +65,73 @@ func (c *waitBidirectionalNATPacketConn) WaitReadPacket() (buffer *buf.Buffer, d } return } + +type connectedBatchWaitBidirectionalNATPacketConn struct { + *bidirectionalNATPacketConn + readWaiter N.ConnectedPacketBatchReadWaiter +} + +func (c *connectedBatchWaitBidirectionalNATPacketConn) InitializeReadWaiter(options N.ReadWaitOptions) (needCopy bool) { + return c.readWaiter.InitializeReadWaiter(options) +} + +func (c *connectedBatchWaitBidirectionalNATPacketConn) WaitReadConnectedPackets() (buffers []*buf.Buffer, destination M.Socksaddr, err error) { + buffers, destination, err = c.readWaiter.WaitReadConnectedPackets() + if err != nil { + return + } + if socksaddrWithoutPort(destination) == c.origin { + destination = M.Socksaddr{ + Addr: c.destination.Addr, + Fqdn: c.destination.Fqdn, + Port: destination.Port, + } + } + return +} + +type connectedBatchWaitDestinationNATPacketConn struct { + *destinationNATPacketConn + readWaiter N.ConnectedPacketBatchReadWaiter +} + +func (c *connectedBatchWaitDestinationNATPacketConn) InitializeReadWaiter(options N.ReadWaitOptions) (needCopy bool) { + return c.readWaiter.InitializeReadWaiter(options) +} + +func (c *connectedBatchWaitDestinationNATPacketConn) WaitReadConnectedPackets() (buffers []*buf.Buffer, destination M.Socksaddr, err error) { + buffers, destination, err = c.readWaiter.WaitReadConnectedPackets() + if err != nil { + return + } + if destination == c.origin { + destination = c.destination + } + return +} + +type batchWaitBidirectionalNATPacketConn struct { + *bidirectionalNATPacketConn + readWaiter N.PacketBatchReadWaiter +} + +func (c *batchWaitBidirectionalNATPacketConn) InitializeReadWaiter(options N.ReadWaitOptions) (needCopy bool) { + return c.readWaiter.InitializeReadWaiter(options) +} + +func (c *batchWaitBidirectionalNATPacketConn) WaitReadPackets() (buffers []*buf.Buffer, destinations []M.Socksaddr, err error) { + buffers, destinations, err = c.readWaiter.WaitReadPackets() + if err != nil { + return + } + for index, destination := range destinations { + if socksaddrWithoutPort(destination) == c.origin { + destinations[index] = M.Socksaddr{ + Addr: c.destination.Addr, + Fqdn: c.destination.Fqdn, + Port: destination.Port, + } + } + } + return +} diff --git a/common/bufio/packet_batch.go b/common/bufio/packet_batch.go new file mode 100644 index 000000000..a716e33d9 --- /dev/null +++ b/common/bufio/packet_batch.go @@ -0,0 +1,107 @@ +package bufio + +import ( + "os" + + "github.com/sagernet/sing/common" + "github.com/sagernet/sing/common/buf" + M "github.com/sagernet/sing/common/metadata" + N "github.com/sagernet/sing/common/network" +) + +func CreatePacketBatchWriter(writer N.PacketWriter) (N.PacketBatchWriter, bool) { + if batchCreator, isBatchCreator := writer.(N.PacketBatchWriteCreator); isBatchCreator { + return batchCreator.CreatePacketBatchWriter() + } + if batchWriter, isBatchWriter := writer.(N.PacketBatchWriter); isBatchWriter { + return batchWriter, true + } + if batchWriter, created := createSyscallPacketBatchWriter(writer); created { + return batchWriter, true + } + if u, ok := writer.(N.WriterWithUpstream); ok && u.WriterReplaceable() { + if u, ok := writer.(N.WithUpstreamWriter); ok { + return CreatePacketBatchWriter(u.UpstreamWriter().(N.PacketWriter)) + } + if u, ok := writer.(common.WithUpstream); ok { + return CreatePacketBatchWriter(u.Upstream().(N.PacketWriter)) + } + } + return nil, false +} + +func CreateConnectedPacketBatchWriter(writer N.PacketWriter) (N.ConnectedPacketBatchWriter, bool) { + if batchCreator, isBatchCreator := writer.(N.ConnectedPacketBatchWriteCreator); isBatchCreator { + return batchCreator.CreateConnectedPacketBatchWriter() + } + if batchWriter, isBatchWriter := writer.(N.ConnectedPacketBatchWriter); isBatchWriter { + return batchWriter, true + } + if batchWriter, created := createSyscallConnectedPacketBatchWriter(writer); created { + return batchWriter, true + } + if u, ok := writer.(N.WriterWithUpstream); !ok || !u.WriterReplaceable() { + return nil, false + } + if u, ok := writer.(N.WithUpstreamWriter); ok { + return CreateConnectedPacketBatchWriter(u.UpstreamWriter().(N.PacketWriter)) + } + if u, ok := writer.(common.WithUpstream); ok { + return CreateConnectedPacketBatchWriter(u.Upstream().(N.PacketWriter)) + } + return nil, false +} + +func NewPacketBatchWriter(writer N.PacketWriter) N.PacketBatchWriter { + if batchWriter, created := CreatePacketBatchWriter(writer); created { + return batchWriter + } + return &fallbackPacketBatchWriter{writer} +} + +func NewConnectedPacketBatchWriter(writer N.PacketWriter) N.ConnectedPacketBatchWriter { + if batchWriter, created := CreateConnectedPacketBatchWriter(writer); created { + return batchWriter + } + return &fallbackConnectedPacketBatchWriter{writer} +} + +type fallbackPacketBatchWriter struct { + writer N.PacketWriter +} + +func (w *fallbackPacketBatchWriter) WritePacketBatch(buffers []*buf.Buffer, destinations []M.Socksaddr) error { + if len(buffers) == 0 || len(buffers) != len(destinations) { + buf.ReleaseMulti(buffers) + return os.ErrInvalid + } + for index, buffer := range buffers { + err := w.writer.WritePacket(buffer, destinations[index]) + if err != nil { + buffer.Release() + buf.ReleaseMulti(buffers[index+1:]) + return err + } + } + return nil +} + +type fallbackConnectedPacketBatchWriter struct { + writer N.PacketWriter +} + +func (w *fallbackConnectedPacketBatchWriter) WriteConnectedPacketBatch(buffers []*buf.Buffer) error { + if len(buffers) == 0 { + buf.ReleaseMulti(buffers) + return os.ErrInvalid + } + for index, buffer := range buffers { + err := w.writer.WritePacket(buffer, M.Socksaddr{}) + if err != nil { + buffer.Release() + buf.ReleaseMulti(buffers[index+1:]) + return err + } + } + return nil +} diff --git a/common/bufio/packet_batch_mmsg.go b/common/bufio/packet_batch_mmsg.go new file mode 100644 index 000000000..9cedba7ff --- /dev/null +++ b/common/bufio/packet_batch_mmsg.go @@ -0,0 +1,392 @@ +//go:build linux || netbsd + +package bufio + +import ( + "io" + "net/netip" + "os" + "sync" + "syscall" + "unsafe" + + "github.com/sagernet/sing/common/buf" + "github.com/sagernet/sing/common/control" + E "github.com/sagernet/sing/common/exceptions" + M "github.com/sagernet/sing/common/metadata" + N "github.com/sagernet/sing/common/network" + + "golang.org/x/sys/unix" +) + +var ( + _ N.PacketBatchReadWaiter = (*syscallPacketBatchReadWaiter)(nil) + _ N.ConnectedPacketBatchReadWaiter = (*syscallPacketBatchReadWaiter)(nil) +) + +type mmsghdr struct { + msgHdr unix.Msghdr + msgLen uint32 +} + +type syscallPacketBatchReadWaiter struct { + rawConn syscall.RawConn + connected bool + destination M.Socksaddr + readErr error + readN int + readFunc func(fd uintptr) (done bool) + buffers []*buf.Buffer + destinations []M.Socksaddr + names []unix.RawSockaddrAny + iovecs []unix.Iovec + msgvec []mmsghdr + options N.ReadWaitOptions +} + +func createSyscallPacketBatchReadWaiter(reader any) (N.PacketBatchReadWaiter, bool) { + rawConn := syscallPacketBatchRawConnForRead(reader) + if rawConn == nil { + return nil, false + } + if _, isConnected := syscallPacketBatchPeerDestination(rawConn); isConnected { + return nil, false + } + return &syscallPacketBatchReadWaiter{rawConn: rawConn}, true +} + +func createSyscallConnectedPacketBatchReadWaiter(reader any, destination M.Socksaddr) (N.ConnectedPacketBatchReadWaiter, bool) { + rawConn := syscallPacketBatchRawConnForRead(reader) + if rawConn == nil { + return nil, false + } + peerDestination, isConnected := syscallPacketBatchPeerDestination(rawConn) + if !isConnected { + return nil, false + } + if !destination.IsValid() { + destination = peerDestination + } + return &syscallPacketBatchReadWaiter{rawConn: rawConn, connected: true, destination: destination}, true +} + +func (w *syscallPacketBatchReadWaiter) InitializeReadWaiter(options N.ReadWaitOptions) (needCopy bool) { + if options.BatchSize <= 0 { + options.BatchSize = DefaultPacketReadBatchSize + } + w.options = options + w.buffers = make([]*buf.Buffer, options.BatchSize) + if !w.connected { + w.destinations = make([]M.Socksaddr, options.BatchSize) + w.names = make([]unix.RawSockaddrAny, options.BatchSize) + } + w.iovecs = make([]unix.Iovec, options.BatchSize) + w.msgvec = make([]mmsghdr, options.BatchSize) + w.readFunc = func(fd uintptr) (done bool) { + for i := range w.msgvec { + buffer := w.buffers[i] + if buffer == nil { + buffer = w.options.NewPacketBuffer() + w.buffers[i] = buffer + } + w.iovecs[i] = buffer.Iovec(buffer.FreeLen()) + w.msgvec[i] = mmsghdr{} + if !w.connected { + w.names[i] = unix.RawSockaddrAny{} + w.msgvec[i].msgHdr.Name = (*byte)(unsafe.Pointer(&w.names[i])) + w.msgvec[i].msgHdr.Namelen = unix.SizeofSockaddrAny + } + w.msgvec[i].msgHdr.Iov = &w.iovecs[i] + w.msgvec[i].msgHdr.SetIovlen(1) + } + for { + var errno syscall.Errno + w.readN, errno = recvmmsg(int(fd), w.msgvec, 0) + switch errno { + case 0: + w.readErr = nil + case syscall.EINTR: + continue + case syscall.EAGAIN: + return false + default: + if errno == syscall.EWOULDBLOCK { + return false + } + w.readErr = os.NewSyscallError("recvmmsg", errno) + } + break + } + if w.readN == 0 && w.readErr == nil { + w.readErr = io.EOF + } + for i := 0; i < w.readN; i++ { + buffer := w.buffers[i] + buffer.Truncate(int(w.msgvec[i].msgLen)) + w.options.PostReturn(buffer) + if !w.connected { + w.destinations[i] = M.SocksaddrFromRawSockaddrAny(&w.names[i]) + } + } + return true + } + return false +} + +func (w *syscallPacketBatchReadWaiter) WaitReadPackets() (buffers []*buf.Buffer, destinations []M.Socksaddr, err error) { + if w.connected { + return nil, nil, os.ErrInvalid + } + if w.readFunc == nil { + return nil, nil, os.ErrInvalid + } + err = w.rawConn.Read(w.readFunc) + if err != nil { + return + } + if w.readErr != nil { + if w.readErr == io.EOF { + return nil, nil, io.EOF + } + return nil, nil, E.Cause(w.readErr, "raw read") + } + buffers = make([]*buf.Buffer, w.readN) + destinations = make([]M.Socksaddr, w.readN) + for i := 0; i < w.readN; i++ { + buffers[i] = w.buffers[i] + w.buffers[i] = nil + destinations[i] = w.destinations[i] + } + w.readN = 0 + return +} + +func (w *syscallPacketBatchReadWaiter) WaitReadConnectedPackets() (buffers []*buf.Buffer, destination M.Socksaddr, err error) { + if !w.connected { + return nil, M.Socksaddr{}, os.ErrInvalid + } + if w.readFunc == nil { + return nil, M.Socksaddr{}, os.ErrInvalid + } + err = w.rawConn.Read(w.readFunc) + if err != nil { + return + } + if w.readErr != nil { + if w.readErr == io.EOF { + return nil, M.Socksaddr{}, io.EOF + } + return nil, M.Socksaddr{}, E.Cause(w.readErr, "raw read") + } + buffers = make([]*buf.Buffer, w.readN) + for i := 0; i < w.readN; i++ { + buffers[i] = w.buffers[i] + w.buffers[i] = nil + } + w.readN = 0 + destination = w.destination + return +} + +var ( + _ N.PacketBatchWriter = (*syscallPacketBatchWriter)(nil) + _ N.ConnectedPacketBatchWriter = (*syscallPacketBatchWriter)(nil) +) + +type syscallPacketBatchWriter struct { + upstream any + rawConn syscall.RawConn + connected bool + access sync.Mutex + localAddr netip.AddrPort + names []unix.RawSockaddrAny + iovecs []unix.Iovec + msgvec []mmsghdr +} + +func createSyscallPacketBatchWriter(writer any) (N.PacketBatchWriter, bool) { + rawConn := syscallPacketBatchRawConnForWrite(writer) + if rawConn == nil { + return nil, false + } + if _, isConnected := syscallPacketBatchPeerDestination(rawConn); isConnected { + return nil, false + } + return &syscallPacketBatchWriter{upstream: writer, rawConn: rawConn}, true +} + +func createSyscallConnectedPacketBatchWriter(writer any) (N.ConnectedPacketBatchWriter, bool) { + rawConn := syscallPacketBatchRawConnForWrite(writer) + if rawConn == nil { + return nil, false + } + if _, isConnected := syscallPacketBatchPeerDestination(rawConn); !isConnected { + return nil, false + } + return &syscallPacketBatchWriter{upstream: writer, rawConn: rawConn, connected: true}, true +} + +func (w *syscallPacketBatchWriter) WritePacketBatch(buffers []*buf.Buffer, destinations []M.Socksaddr) error { + if w.connected { + buf.ReleaseMulti(buffers) + return os.ErrInvalid + } + w.access.Lock() + defer w.access.Unlock() + defer buf.ReleaseMulti(buffers) + if len(buffers) == 0 || len(buffers) != len(destinations) { + return os.ErrInvalid + } + if !w.localAddr.IsValid() { + err := control.Raw(w.rawConn, func(fd uintptr) error { + name, err := unix.Getsockname(int(fd)) + if err != nil { + return err + } + w.localAddr = M.AddrPortFromSockaddr(name) + return nil + }) + if err != nil { + return err + } + } + names := growSlice(w.names, len(buffers)) + iovecs := growSlice(w.iovecs, len(buffers)) + msgvec := growSlice(w.msgvec, len(buffers)) + defer func() { + clear(iovecs) + clear(msgvec) + w.names = names[:0] + w.iovecs = iovecs[:0] + w.msgvec = msgvec[:0] + }() + for i, buffer := range buffers { + names[i] = unix.RawSockaddrAny{} + iovecs[i] = unix.Iovec{} + msgvec[i] = mmsghdr{} + msgvec[i].msgHdr.Name = (*byte)(unsafe.Pointer(&names[i])) + msgvec[i].msgHdr.Namelen = M.AddrPortToRawSockaddrAny(&names[i], destinations[i].AddrPort(), w.localAddr.Addr().Is6()) + if !buffer.IsEmpty() { + iovecs[i] = buffer.Iovec(buffer.Len()) + msgvec[i].msgHdr.Iov = &iovecs[i] + msgvec[i].msgHdr.SetIovlen(1) + } + } + writeMsgvec := msgvec + var innerErr syscall.Errno + err := w.rawConn.Write(func(fd uintptr) (done bool) { + for len(writeMsgvec) > 0 { + n, errno := sendmmsg(int(fd), writeMsgvec, 0) + switch errno { + case 0: + case syscall.EINTR: + continue + case syscall.EAGAIN: + return false + default: + if errno == syscall.EWOULDBLOCK { + return false + } + innerErr = errno + return true + } + if n == 0 { + innerErr = syscall.EIO + return true + } + writeMsgvec = writeMsgvec[n:] + } + return true + }) + if innerErr != 0 { + err = os.NewSyscallError("sendmmsg", innerErr) + } + return err +} + +func (w *syscallPacketBatchWriter) WriteConnectedPacketBatch(buffers []*buf.Buffer) error { + if !w.connected { + buf.ReleaseMulti(buffers) + return os.ErrInvalid + } + w.access.Lock() + defer w.access.Unlock() + defer buf.ReleaseMulti(buffers) + if len(buffers) == 0 { + return os.ErrInvalid + } + iovecs := growSlice(w.iovecs, len(buffers)) + msgvec := growSlice(w.msgvec, len(buffers)) + defer func() { + clear(iovecs) + clear(msgvec) + w.iovecs = iovecs[:0] + w.msgvec = msgvec[:0] + }() + for i, buffer := range buffers { + iovecs[i] = unix.Iovec{} + msgvec[i] = mmsghdr{} + if !buffer.IsEmpty() { + iovecs[i] = buffer.Iovec(buffer.Len()) + msgvec[i].msgHdr.Iov = &iovecs[i] + msgvec[i].msgHdr.SetIovlen(1) + } + } + writeMsgvec := msgvec + var innerErr syscall.Errno + err := w.rawConn.Write(func(fd uintptr) (done bool) { + for len(writeMsgvec) > 0 { + n, errno := sendmmsg(int(fd), writeMsgvec, 0) + switch errno { + case 0: + case syscall.EINTR: + continue + case syscall.EAGAIN: + return false + default: + if errno == syscall.EWOULDBLOCK { + return false + } + innerErr = errno + return true + } + if n == 0 { + innerErr = syscall.EIO + return true + } + writeMsgvec = writeMsgvec[n:] + } + return true + }) + if innerErr != 0 { + err = os.NewSyscallError("sendmmsg", innerErr) + } + return err +} + +func (w *syscallPacketBatchWriter) Upstream() any { + return w.upstream +} + +func growSlice[T any](values []T, size int) []T { + if cap(values) < size { + return make([]T, size) + } + return values[:size] +} + +func recvmmsg(fd int, msgvec []mmsghdr, flags int) (int, syscall.Errno) { + return mmsgSyscall(sysRecvmmsg, fd, msgvec, flags) +} + +func sendmmsg(fd int, msgvec []mmsghdr, flags int) (int, syscall.Errno) { + return mmsgSyscall(unix.SYS_SENDMMSG, fd, msgvec, flags) +} + +func mmsgSyscall(trap uintptr, fd int, msgvec []mmsghdr, flags int) (int, syscall.Errno) { + r0, _, errno := unix.Syscall6(trap, uintptr(fd), uintptr(unsafe.Pointer(&msgvec[0])), uintptr(len(msgvec)), uintptr(flags), 0, 0) + if errno != 0 { + return 0, errno + } + return int(r0), 0 +} diff --git a/common/bufio/packet_batch_mmsg_generic.go b/common/bufio/packet_batch_mmsg_generic.go new file mode 100644 index 000000000..b5e66dcfa --- /dev/null +++ b/common/bufio/packet_batch_mmsg_generic.go @@ -0,0 +1,7 @@ +//go:build (linux && !386 && !arm && !mips && !mipsle && !ppc) || netbsd + +package bufio + +import "golang.org/x/sys/unix" + +const sysRecvmmsg = unix.SYS_RECVMMSG diff --git a/common/bufio/packet_batch_mmsg_linux32.go b/common/bufio/packet_batch_mmsg_linux32.go new file mode 100644 index 000000000..0faa2b0ce --- /dev/null +++ b/common/bufio/packet_batch_mmsg_linux32.go @@ -0,0 +1,7 @@ +//go:build linux && (386 || arm || mips || mipsle || ppc) + +package bufio + +import "golang.org/x/sys/unix" + +const sysRecvmmsg = unix.SYS_RECVMMSG_TIME64 diff --git a/common/bufio/packet_batch_msgx_darwin.go b/common/bufio/packet_batch_msgx_darwin.go new file mode 100644 index 000000000..0fb23e824 --- /dev/null +++ b/common/bufio/packet_batch_msgx_darwin.go @@ -0,0 +1,306 @@ +//go:build darwin + +package bufio + +import ( + "io" + "os" + "sync" + "syscall" + "unsafe" + + "github.com/sagernet/sing/common/buf" + E "github.com/sagernet/sing/common/exceptions" + M "github.com/sagernet/sing/common/metadata" + N "github.com/sagernet/sing/common/network" + + "golang.org/x/sys/unix" +) + +var ( + _ N.PacketBatchReadWaiter = (*syscallPacketBatchReadWaiter)(nil) + _ N.ConnectedPacketBatchReadWaiter = (*syscallPacketBatchReadWaiter)(nil) +) + +type msghdrX struct { + name *byte + namelen uint32 + iov *unix.Iovec + iovlen int32 + control *byte + controllen uint32 + flags int32 + datalen uint64 +} + +type syscallPacketBatchReadWaiter struct { + rawConn syscall.RawConn + connected bool + destination M.Socksaddr + readErr error + readN int + readFunc func(fd uintptr) (done bool) + buffers []*buf.Buffer + destinations []M.Socksaddr + names []unix.RawSockaddrAny + iovecs []unix.Iovec + msgvec []msghdrX + options N.ReadWaitOptions +} + +func createSyscallPacketBatchReadWaiter(reader any) (N.PacketBatchReadWaiter, bool) { + rawConn := syscallPacketBatchRawConnForRead(reader) + if rawConn == nil { + return nil, false + } + if _, isConnected := syscallPacketBatchPeerDestination(rawConn); isConnected { + return nil, false + } + return &syscallPacketBatchReadWaiter{rawConn: rawConn}, true +} + +func createSyscallConnectedPacketBatchReadWaiter(reader any, destination M.Socksaddr) (N.ConnectedPacketBatchReadWaiter, bool) { + rawConn := syscallPacketBatchRawConnForRead(reader) + if rawConn == nil { + return nil, false + } + peerDestination, isConnected := syscallPacketBatchPeerDestination(rawConn) + if !isConnected { + return nil, false + } + if !destination.IsValid() { + destination = peerDestination + } + return &syscallPacketBatchReadWaiter{rawConn: rawConn, connected: true, destination: destination}, true +} + +func (w *syscallPacketBatchReadWaiter) InitializeReadWaiter(options N.ReadWaitOptions) (needCopy bool) { + if options.BatchSize <= 0 { + options.BatchSize = DefaultPacketReadBatchSize + } + w.options = options + w.buffers = make([]*buf.Buffer, options.BatchSize) + if !w.connected { + w.destinations = make([]M.Socksaddr, options.BatchSize) + w.names = make([]unix.RawSockaddrAny, options.BatchSize) + } + w.iovecs = make([]unix.Iovec, options.BatchSize) + w.msgvec = make([]msghdrX, options.BatchSize) + w.readFunc = func(fd uintptr) (done bool) { + for i := range w.msgvec { + buffer := w.buffers[i] + if buffer == nil { + buffer = w.options.NewPacketBuffer() + w.buffers[i] = buffer + } + w.iovecs[i] = buffer.Iovec(buffer.FreeLen()) + w.msgvec[i] = msghdrX{ + iov: &w.iovecs[i], + iovlen: 1, + } + if !w.connected { + w.names[i] = unix.RawSockaddrAny{} + w.msgvec[i].name = (*byte)(unsafe.Pointer(&w.names[i])) + w.msgvec[i].namelen = unix.SizeofSockaddrAny + } + } + for { + var errno syscall.Errno + w.readN, errno = recvmsgX(int(fd), w.msgvec, 0) + switch errno { + case 0: + w.readErr = nil + case syscall.EINTR: + continue + case syscall.EAGAIN: + return false + default: + if errno == syscall.EWOULDBLOCK { + return false + } + w.readErr = os.NewSyscallError("recvmsg_x", errno) + } + break + } + if w.readN == 0 && w.readErr == nil { + w.readErr = io.EOF + } + for i := 0; i < w.readN; i++ { + buffer := w.buffers[i] + buffer.Truncate(int(w.msgvec[i].datalen)) + w.options.PostReturn(buffer) + if !w.connected { + w.destinations[i] = M.SocksaddrFromRawSockaddrAny(&w.names[i]) + } + } + return true + } + return false +} + +func (w *syscallPacketBatchReadWaiter) WaitReadPackets() (buffers []*buf.Buffer, destinations []M.Socksaddr, err error) { + if w.connected { + return nil, nil, os.ErrInvalid + } + if w.readFunc == nil { + return nil, nil, os.ErrInvalid + } + err = w.rawConn.Read(w.readFunc) + if err != nil { + return + } + if w.readErr != nil { + if w.readErr == io.EOF { + return nil, nil, io.EOF + } + return nil, nil, E.Cause(w.readErr, "raw read") + } + buffers = make([]*buf.Buffer, w.readN) + destinations = make([]M.Socksaddr, w.readN) + for i := 0; i < w.readN; i++ { + buffers[i] = w.buffers[i] + w.buffers[i] = nil + destinations[i] = w.destinations[i] + } + w.readN = 0 + return +} + +func (w *syscallPacketBatchReadWaiter) WaitReadConnectedPackets() (buffers []*buf.Buffer, destination M.Socksaddr, err error) { + if !w.connected { + return nil, M.Socksaddr{}, os.ErrInvalid + } + if w.readFunc == nil { + return nil, M.Socksaddr{}, os.ErrInvalid + } + err = w.rawConn.Read(w.readFunc) + if err != nil { + return + } + if w.readErr != nil { + if w.readErr == io.EOF { + return nil, M.Socksaddr{}, io.EOF + } + return nil, M.Socksaddr{}, E.Cause(w.readErr, "raw read") + } + buffers = make([]*buf.Buffer, w.readN) + for i := 0; i < w.readN; i++ { + buffers[i] = w.buffers[i] + w.buffers[i] = nil + } + w.readN = 0 + destination = w.destination + return +} + +var _ N.ConnectedPacketBatchWriter = (*syscallConnectedPacketBatchWriter)(nil) + +type syscallConnectedPacketBatchWriter struct { + upstream any + rawConn syscall.RawConn + access sync.Mutex + iovecs []unix.Iovec + msgvec []msghdrX +} + +func createSyscallPacketBatchWriter(writer any) (N.PacketBatchWriter, bool) { + return nil, false +} + +func createSyscallConnectedPacketBatchWriter(writer any) (N.ConnectedPacketBatchWriter, bool) { + rawConn := syscallPacketBatchRawConnForWrite(writer) + if rawConn == nil { + return nil, false + } + if _, isConnected := syscallPacketBatchPeerDestination(rawConn); !isConnected { + return nil, false + } + return &syscallConnectedPacketBatchWriter{upstream: writer, rawConn: rawConn}, true +} + +func (w *syscallConnectedPacketBatchWriter) WriteConnectedPacketBatch(buffers []*buf.Buffer) error { + w.access.Lock() + defer w.access.Unlock() + defer buf.ReleaseMulti(buffers) + if len(buffers) == 0 { + return os.ErrInvalid + } + iovecs := growSlice(w.iovecs, len(buffers)) + msgvec := growSlice(w.msgvec, len(buffers)) + defer func() { + clear(iovecs) + clear(msgvec) + w.iovecs = iovecs[:0] + w.msgvec = msgvec[:0] + }() + for i, buffer := range buffers { + iovecs[i] = unix.Iovec{} + msgvec[i] = msghdrX{} + if !buffer.IsEmpty() { + iovecs[i] = buffer.Iovec(buffer.Len()) + msgvec[i].iov = &iovecs[i] + msgvec[i].iovlen = 1 + } + } + writeMsgvec := msgvec + maxBatchSize := len(writeMsgvec) + var innerErr syscall.Errno + err := w.rawConn.Write(func(fd uintptr) (done bool) { + for len(writeMsgvec) > 0 { + batchSize := min(maxBatchSize, len(writeMsgvec)) + n, errno := sendmsgX(int(fd), writeMsgvec[:batchSize], 0) + switch { + case errno == 0: + case errno == syscall.EINTR: + continue + case errno == syscall.EMSGSIZE && batchSize > 1: + maxBatchSize = (batchSize + 1) / 2 + continue + case errno == syscall.EAGAIN || errno == syscall.EWOULDBLOCK: + return false + default: + innerErr = errno + return true + } + if n == 0 { + innerErr = syscall.EIO + return true + } + writeMsgvec = writeMsgvec[n:] + } + return true + }) + if innerErr != 0 { + err = os.NewSyscallError("sendmsg_x", innerErr) + } + return err +} + +func (w *syscallConnectedPacketBatchWriter) Upstream() any { + return w.upstream +} + +func growSlice[T any](values []T, size int) []T { + if cap(values) < size { + return make([]T, size) + } + return values[:size] +} + +func recvmsgX(fd int, msgvec []msghdrX, flags int) (int, syscall.Errno) { + //nolint:staticcheck + return msgxSyscall(unix.SYS_RECVMSG_X, fd, msgvec, flags) +} + +func sendmsgX(fd int, msgvec []msghdrX, flags int) (int, syscall.Errno) { + //nolint:staticcheck + return msgxSyscall(unix.SYS_SENDMSG_X, fd, msgvec, flags) +} + +func msgxSyscall(trap uintptr, fd int, msgvec []msghdrX, flags int) (int, syscall.Errno) { + r0, _, errno := unix.Syscall6(trap, uintptr(fd), uintptr(unsafe.Pointer(&msgvec[0])), uintptr(len(msgvec)), uintptr(flags), 0, 0) + if errno != 0 { + return 0, errno + } + return int(r0), 0 +} diff --git a/common/bufio/packet_batch_stub.go b/common/bufio/packet_batch_stub.go new file mode 100644 index 000000000..3c0021c4e --- /dev/null +++ b/common/bufio/packet_batch_stub.go @@ -0,0 +1,24 @@ +//go:build !linux && !netbsd && !darwin + +package bufio + +import ( + M "github.com/sagernet/sing/common/metadata" + N "github.com/sagernet/sing/common/network" +) + +func createSyscallPacketBatchReadWaiter(reader any) (N.PacketBatchReadWaiter, bool) { + return nil, false +} + +func createSyscallPacketBatchWriter(writer any) (N.PacketBatchWriter, bool) { + return nil, false +} + +func createSyscallConnectedPacketBatchReadWaiter(reader any, destination M.Socksaddr) (N.ConnectedPacketBatchReadWaiter, bool) { + return nil, false +} + +func createSyscallConnectedPacketBatchWriter(writer any) (N.ConnectedPacketBatchWriter, bool) { + return nil, false +} diff --git a/common/bufio/packet_batch_syscall.go b/common/bufio/packet_batch_syscall.go new file mode 100644 index 000000000..4eaf18c75 --- /dev/null +++ b/common/bufio/packet_batch_syscall.go @@ -0,0 +1,58 @@ +//go:build linux || netbsd || darwin + +package bufio + +import ( + "io" + "syscall" + + "github.com/sagernet/sing/common/control" + M "github.com/sagernet/sing/common/metadata" + N "github.com/sagernet/sing/common/network" + + "golang.org/x/sys/unix" +) + +func syscallPacketBatchRawConnForRead(reader any) syscall.RawConn { + if syscallConn, isSyscallConn := reader.(syscall.Conn); isSyscallConn { + rawConn, err := syscallConn.SyscallConn() + if err == nil { + return rawConn + } + } + if ioReader, isReader := reader.(io.Reader); isReader { + _, rawConn := N.SyscallConnForRead(ioReader) + return rawConn + } + return nil +} + +func syscallPacketBatchRawConnForWrite(writer any) syscall.RawConn { + if syscallConn, isSyscallConn := writer.(syscall.Conn); isSyscallConn { + rawConn, err := syscallConn.SyscallConn() + if err == nil { + return rawConn + } + } + if ioWriter, isWriter := writer.(io.Writer); isWriter { + _, rawConn := N.SyscallConnForWrite(ioWriter) + return rawConn + } + return nil +} + +func syscallPacketBatchPeerDestination(rawConn syscall.RawConn) (M.Socksaddr, bool) { + if rawConn == nil { + return M.Socksaddr{}, false + } + var destination M.Socksaddr + err := control.Raw(rawConn, func(fd uintptr) error { + peer, err := unix.Getpeername(int(fd)) + if err != nil { + return err + } + destination = M.SocksaddrFromNetIP(M.AddrPortFromSockaddr(peer)).Unwrap() + return nil + }) + return destination, err == nil && destination.IsValid() +} diff --git a/common/bufio/packet_batch_test.go b/common/bufio/packet_batch_test.go new file mode 100644 index 000000000..a645186d9 --- /dev/null +++ b/common/bufio/packet_batch_test.go @@ -0,0 +1,504 @@ +package bufio + +import ( + "errors" + "io" + "net" + "net/netip" + "strconv" + "sync/atomic" + "testing" + "time" + + "github.com/sagernet/sing/common" + "github.com/sagernet/sing/common/buf" + M "github.com/sagernet/sing/common/metadata" + N "github.com/sagernet/sing/common/network" + + "github.com/stretchr/testify/require" +) + +func TestCreatePacketVectorisedReadWaiterDeprecated(t *testing.T) { + t.Parallel() + reader := &testPacketBatchReader{} + batchReader, created := CreatePacketBatchReadWaiter(reader) + require.True(t, created) + require.NotNil(t, batchReader) + vectorisedReader, created := CreatePacketVectorisedReadWaiter(reader) + require.True(t, created) + require.Same(t, batchReader, vectorisedReader) +} + +func TestCopyPacketUsesBatchPath(t *testing.T) { + t.Parallel() + destinationA := M.SocksaddrFrom(netip.MustParseAddr("127.0.0.1"), 1000) + destinationB := M.SocksaddrFrom(netip.MustParseAddr("127.0.0.1"), 1001) + reader := &testPacketBatchReader{ + batches: []testPacketBatch{{ + payloads: [][]byte{[]byte("a"), []byte("bc")}, + destinations: []M.Socksaddr{destinationA, destinationB}, + }}, + } + writer := &testPacketBatchWriter{} + var readBytes, readPackets, writeBytes, writePackets atomic.Int64 + n, err := CopyPacketWithCounters(writer, reader, reader, []N.CountFunc{ + func(n int64) { readBytes.Add(n) }, + func(n int64) { readPackets.Add(1) }, + }, []N.CountFunc{ + func(n int64) { writeBytes.Add(n) }, + func(n int64) { writePackets.Add(1) }, + }) + require.ErrorIs(t, err, io.EOF) + require.Equal(t, int64(3), n) + require.True(t, writer.usedBatch) + require.Equal(t, [][]byte{[]byte("a"), []byte("bc")}, writer.payloads) + require.Equal(t, []M.Socksaddr{destinationA, destinationB}, writer.destinations) + require.Equal(t, int64(3), readBytes.Load()) + require.Equal(t, int64(2), readPackets.Load()) + require.Equal(t, int64(3), writeBytes.Load()) + require.Equal(t, int64(2), writePackets.Load()) +} + +func TestCopyPacketUsesConnectedBatchWriter(t *testing.T) { + t.Parallel() + destinationA := M.SocksaddrFrom(netip.MustParseAddr("127.0.0.1"), 1000) + destinationB := M.SocksaddrFrom(netip.MustParseAddr("127.0.0.1"), 1001) + reader := &testPacketBatchReader{ + batches: []testPacketBatch{{ + payloads: [][]byte{[]byte("a"), []byte("bc")}, + destinations: []M.Socksaddr{destinationA, destinationB}, + }}, + } + writer := &testConnectedPacketBatchWriter{} + var readBytes, readPackets, writeBytes, writePackets atomic.Int64 + n, err := CopyPacketWithCounters(writer, reader, reader, []N.CountFunc{ + func(n int64) { readBytes.Add(n) }, + func(n int64) { readPackets.Add(1) }, + }, []N.CountFunc{ + func(n int64) { writeBytes.Add(n) }, + func(n int64) { writePackets.Add(1) }, + }) + require.ErrorIs(t, err, io.EOF) + require.Equal(t, int64(3), n) + require.True(t, writer.usedConnectedBatch) + require.Equal(t, [][]byte{[]byte("a"), []byte("bc")}, writer.payloads) + require.Equal(t, int64(3), readBytes.Load()) + require.Equal(t, int64(2), readPackets.Load()) + require.Equal(t, int64(3), writeBytes.Load()) + require.Equal(t, int64(2), writePackets.Load()) +} + +func TestCopyPacketUsesConnectedBatchReader(t *testing.T) { + t.Parallel() + destination := M.SocksaddrFrom(netip.MustParseAddr("127.0.0.1"), 1000) + reader := &testConnectedPacketBatchReader{ + destination: destination, + batches: [][]byte{ + []byte("a"), + []byte("bc"), + }, + } + writer := &testPacketBatchWriter{} + var readBytes, readPackets, writeBytes, writePackets atomic.Int64 + n, err := CopyPacketWithCounters(writer, reader, reader, []N.CountFunc{ + func(n int64) { readBytes.Add(n) }, + func(n int64) { readPackets.Add(1) }, + }, []N.CountFunc{ + func(n int64) { writeBytes.Add(n) }, + func(n int64) { writePackets.Add(1) }, + }) + require.ErrorIs(t, err, io.EOF) + require.Equal(t, int64(3), n) + require.True(t, writer.usedBatch) + require.Equal(t, [][]byte{[]byte("a"), []byte("bc")}, writer.payloads) + require.Equal(t, []M.Socksaddr{destination, destination}, writer.destinations) + require.Equal(t, int64(3), readBytes.Load()) + require.Equal(t, int64(2), readPackets.Load()) + require.Equal(t, int64(3), writeBytes.Load()) + require.Equal(t, int64(2), writePackets.Load()) +} + +func TestCopyPacketCachedWriteErrorAfterSuccess(t *testing.T) { + t.Parallel() + destination := M.SocksaddrFrom(netip.MustParseAddr("127.0.0.1"), 1000) + writeErr := errors.New("cached write failed") + reader := &testCachedPacketReader{ + packets: []*N.PacketBuffer{ + testPacketBuffer("a", destination), + testPacketBuffer("bc", destination), + }, + readErr: io.EOF, + } + writer := &testFailAfterPacketWriter{ + failAt: 2, + err: writeErr, + } + n, err := CopyPacket(writer, reader) + require.ErrorIs(t, err, writeErr) + require.Equal(t, int64(1), n) +} + +func TestNewPacketBatchWriterFallback(t *testing.T) { + t.Parallel() + destinationA := M.SocksaddrFrom(netip.MustParseAddr("127.0.0.1"), 1000) + destinationB := M.SocksaddrFrom(netip.MustParseAddr("127.0.0.1"), 1001) + writer := &testPacketWriter{} + _, created := CreatePacketBatchWriter(writer) + require.False(t, created) + batchWriter := NewPacketBatchWriter(writer) + require.NoError(t, batchWriter.WritePacketBatch(testBuffers("a", "bc"), []M.Socksaddr{destinationA, destinationB})) + require.Equal(t, [][]byte{[]byte("a"), []byte("bc")}, writer.payloads) + require.Equal(t, []M.Socksaddr{destinationA, destinationB}, writer.destinations) +} + +func TestNATPacketBatchWriter(t *testing.T) { + t.Parallel() + origin := M.SocksaddrFrom(netip.MustParseAddr("10.0.0.1"), 1000) + destination := M.SocksaddrFrom(netip.MustParseAddr("20.0.0.1"), 2000) + other := M.SocksaddrFrom(netip.MustParseAddr("30.0.0.1"), 3000) + conn := &testNetPacketConn{} + natConn := NewNATPacketConn(conn, origin, destination) + writer, created := CreatePacketBatchWriter(natConn) + require.True(t, created) + require.NoError(t, writer.WritePacketBatch(testBuffers("a", "bc"), []M.Socksaddr{ + M.SocksaddrFrom(destination.Addr, 9999), + other, + })) + require.Equal(t, [][]byte{[]byte("a"), []byte("bc")}, conn.payloads) + require.Equal(t, []M.Socksaddr{ + M.SocksaddrFrom(origin.Addr, 9999), + other, + }, conn.destinations) +} + +func TestRemoteAddrDoesNotCreateConnectedPacketBatch(t *testing.T) { + t.Parallel() + conn := &testRemoteAddrPacketConn{} + _, readCreated := CreateConnectedPacketBatchReadWaiter(conn) + require.False(t, readCreated) + _, writeCreated := CreateConnectedPacketBatchWriter(conn) + require.False(t, writeCreated) +} + +func TestUnconnectedUnbindPacketConnDoesNotCreateConnectedPacketBatch(t *testing.T) { + t.Parallel() + inputConn, outputConn, outputAddr := UDPPipe(t) + defer inputConn.Close() + defer outputConn.Close() + packetConn := NewUnbindPacketConnWithAddr(inputConn.(*net.UDPConn), outputAddr) + _, readCreated := CreateConnectedPacketBatchReadWaiter(packetConn) + require.False(t, readCreated) + _, writeCreated := CreateConnectedPacketBatchWriter(packetConn) + require.False(t, writeCreated) +} + +func TestPacketBatchUDP(t *testing.T) { + t.Parallel() + for _, batchSize := range []int{1, 2, DefaultPacketReadBatchSize} { + t.Run(strconv.Itoa(batchSize), func(t *testing.T) { + t.Parallel() + inputConn, outputConn, outputAddr := UDPPipe(t) + defer inputConn.Close() + defer outputConn.Close() + require.NoError(t, inputConn.SetDeadline(time.Now().Add(time.Second))) + require.NoError(t, outputConn.SetDeadline(time.Now().Add(time.Second))) + packetInputConn := NewPacketConn(inputConn) + reader, readCreated := CreatePacketBatchReadWaiter(packetInputConn) + writer, writeCreated := CreatePacketBatchWriter(packetInputConn) + if !readCreated && !writeCreated { + t.Skip("packet batch syscall backend is not available on this platform") + } + if writeCreated { + require.NoError(t, writer.WritePacketBatch(testBuffers("x", "yz"), []M.Socksaddr{outputAddr, outputAddr})) + output := make([]byte, 2) + n, _, err := outputConn.ReadFrom(output) + require.NoError(t, err) + require.Equal(t, []byte("x"), output[:n]) + n, _, err = outputConn.ReadFrom(output) + require.NoError(t, err) + require.Equal(t, []byte("yz"), output[:n]) + } + if readCreated { + reader.InitializeReadWaiter(N.ReadWaitOptions{BatchSize: batchSize}) + _, err := outputConn.WriteTo([]byte("a"), inputConn.LocalAddr()) + require.NoError(t, err) + _, err = outputConn.WriteTo([]byte("bc"), inputConn.LocalAddr()) + require.NoError(t, err) + var payloads [][]byte + var destinations []M.Socksaddr + for len(payloads) < 2 { + buffers, destinationsN, err := reader.WaitReadPackets() + require.NoError(t, err) + require.NotEmpty(t, buffers) + require.Len(t, destinationsN, len(buffers)) + for index, buffer := range buffers { + payloads = append(payloads, append([]byte(nil), buffer.Bytes()...)) + destinations = append(destinations, destinationsN[index]) + buffer.Release() + } + } + require.Equal(t, [][]byte{[]byte("a"), []byte("bc")}, payloads) + require.Equal(t, []M.Socksaddr{outputAddr, outputAddr}, destinations) + } + }) + } +} + +func TestConnectedPacketBatchUDP(t *testing.T) { + t.Parallel() + for _, batchSize := range []int{1, 2, DefaultPacketReadBatchSize} { + t.Run(strconv.Itoa(batchSize), func(t *testing.T) { + t.Parallel() + serverConn, err := net.ListenUDP("udp", &net.UDPAddr{IP: net.ParseIP("127.0.0.1")}) + require.NoError(t, err) + defer serverConn.Close() + clientConn, err := net.DialUDP("udp", nil, serverConn.LocalAddr().(*net.UDPAddr)) + require.NoError(t, err) + defer clientConn.Close() + packetConn := NewUnbindPacketConn(clientConn) + _, ordinaryReadCreated := CreatePacketBatchReadWaiter(packetConn) + require.False(t, ordinaryReadCreated) + _, ordinaryWriteCreated := CreatePacketBatchWriter(packetConn) + require.False(t, ordinaryWriteCreated) + reader, readCreated := CreateConnectedPacketBatchReadWaiter(packetConn) + writer, writeCreated := CreateConnectedPacketBatchWriter(packetConn) + if !readCreated || !writeCreated { + t.Skip("connected packet batch syscall backend is not available on this platform") + } + require.NoError(t, serverConn.SetDeadline(time.Now().Add(time.Second))) + require.NoError(t, clientConn.SetDeadline(time.Now().Add(time.Second))) + reader.InitializeReadWaiter(N.ReadWaitOptions{BatchSize: batchSize}) + require.NoError(t, writer.WriteConnectedPacketBatch(testBuffers("x", "yz"))) + output := make([]byte, 2) + n, addr, err := serverConn.ReadFromUDP(output) + require.NoError(t, err) + require.Equal(t, clientConn.LocalAddr().String(), addr.String()) + require.Equal(t, []byte("x"), output[:n]) + n, addr, err = serverConn.ReadFromUDP(output) + require.NoError(t, err) + require.Equal(t, clientConn.LocalAddr().String(), addr.String()) + require.Equal(t, []byte("yz"), output[:n]) + _, err = serverConn.WriteToUDP([]byte("a"), clientConn.LocalAddr().(*net.UDPAddr)) + require.NoError(t, err) + _, err = serverConn.WriteToUDP([]byte("bc"), clientConn.LocalAddr().(*net.UDPAddr)) + require.NoError(t, err) + var payloads [][]byte + for len(payloads) < 2 { + buffers, destination, err := reader.WaitReadConnectedPackets() + require.NoError(t, err) + require.Equal(t, M.SocksaddrFromNet(clientConn.RemoteAddr()).Unwrap(), destination) + for _, buffer := range buffers { + payloads = append(payloads, append([]byte(nil), buffer.Bytes()...)) + buffer.Release() + } + } + require.Equal(t, [][]byte{[]byte("a"), []byte("bc")}, payloads) + }) + } +} + +type testPacketBatch struct { + payloads [][]byte + destinations []M.Socksaddr +} + +type testPacketBatchReader struct { + batches []testPacketBatch + index int +} + +func (r *testPacketBatchReader) ReadPacket(buffer *buf.Buffer) (M.Socksaddr, error) { + return M.Socksaddr{}, io.ErrUnexpectedEOF +} + +func (r *testPacketBatchReader) InitializeReadWaiter(options N.ReadWaitOptions) (needCopy bool) { + return false +} + +func (r *testPacketBatchReader) WaitReadPackets() ([]*buf.Buffer, []M.Socksaddr, error) { + if r.index >= len(r.batches) { + return nil, nil, io.EOF + } + batch := r.batches[r.index] + r.index++ + return testBuffersBytes(batch.payloads...), append([]M.Socksaddr(nil), batch.destinations...), nil +} + +type testPacketBatchWriter struct { + usedBatch bool + payloads [][]byte + destinations []M.Socksaddr +} + +func (w *testPacketBatchWriter) WritePacket(buffer *buf.Buffer, destination M.Socksaddr) error { + return io.ErrUnexpectedEOF +} + +func (w *testPacketBatchWriter) WritePacketBatch(buffers []*buf.Buffer, destinations []M.Socksaddr) error { + w.usedBatch = true + for index, buffer := range buffers { + w.payloads = append(w.payloads, append([]byte(nil), buffer.Bytes()...)) + w.destinations = append(w.destinations, destinations[index]) + } + buf.ReleaseMulti(buffers) + return nil +} + +type testConnectedPacketBatchWriter struct { + usedConnectedBatch bool + payloads [][]byte +} + +func (w *testConnectedPacketBatchWriter) WritePacket(buffer *buf.Buffer, destination M.Socksaddr) error { + return io.ErrUnexpectedEOF +} + +func (w *testConnectedPacketBatchWriter) WriteConnectedPacketBatch(buffers []*buf.Buffer) error { + w.usedConnectedBatch = true + for _, buffer := range buffers { + w.payloads = append(w.payloads, append([]byte(nil), buffer.Bytes()...)) + } + buf.ReleaseMulti(buffers) + return nil +} + +type testConnectedPacketBatchReader struct { + destination M.Socksaddr + batches [][]byte + index int +} + +func (r *testConnectedPacketBatchReader) ReadPacket(buffer *buf.Buffer) (M.Socksaddr, error) { + return M.Socksaddr{}, io.ErrUnexpectedEOF +} + +func (r *testConnectedPacketBatchReader) InitializeReadWaiter(options N.ReadWaitOptions) (needCopy bool) { + return false +} + +func (r *testConnectedPacketBatchReader) WaitReadConnectedPackets() ([]*buf.Buffer, M.Socksaddr, error) { + if r.index >= len(r.batches) { + return nil, M.Socksaddr{}, io.EOF + } + buffers := testBuffersBytes(r.batches[r.index:]...) + r.index = len(r.batches) + return buffers, r.destination, nil +} + +type testCachedPacketReader struct { + packets []*N.PacketBuffer + readErr error +} + +func (r *testCachedPacketReader) ReadPacket(buffer *buf.Buffer) (M.Socksaddr, error) { + return M.Socksaddr{}, r.readErr +} + +func (r *testCachedPacketReader) ReadCachedPacket() *N.PacketBuffer { + if len(r.packets) == 0 { + return nil + } + packet := r.packets[0] + r.packets = r.packets[1:] + return packet +} + +type testFailAfterPacketWriter struct { + count int + failAt int + err error +} + +func (w *testFailAfterPacketWriter) WritePacket(buffer *buf.Buffer, destination M.Socksaddr) error { + w.count++ + if w.count == w.failAt { + return w.err + } + buffer.Release() + return nil +} + +type testPacketWriter struct { + payloads [][]byte + destinations []M.Socksaddr +} + +func (w *testPacketWriter) WritePacket(buffer *buf.Buffer, destination M.Socksaddr) error { + w.payloads = append(w.payloads, append([]byte(nil), buffer.Bytes()...)) + w.destinations = append(w.destinations, destination) + buffer.Release() + return nil +} + +type testNetPacketConn struct { + testPacketBatchWriter +} + +func (c *testNetPacketConn) ReadPacket(buffer *buf.Buffer) (M.Socksaddr, error) { + return M.Socksaddr{}, io.ErrUnexpectedEOF +} + +func (c *testNetPacketConn) ReadFrom(p []byte) (n int, addr net.Addr, err error) { + return 0, nil, io.ErrUnexpectedEOF +} + +func (c *testNetPacketConn) WriteTo(p []byte, addr net.Addr) (n int, err error) { + return len(p), nil +} + +func (c *testNetPacketConn) Close() error { + return nil +} + +func (c *testNetPacketConn) LocalAddr() net.Addr { + return M.SocksaddrFrom(netip.MustParseAddr("127.0.0.1"), 0).UDPAddr() +} + +func (c *testNetPacketConn) SetDeadline(t time.Time) error { + return nil +} + +func (c *testNetPacketConn) SetReadDeadline(t time.Time) error { + return nil +} + +func (c *testNetPacketConn) SetWriteDeadline(t time.Time) error { + return nil +} + +type testRemoteAddrPacketConn struct { + testPacketWriter +} + +func (c *testRemoteAddrPacketConn) ReadPacket(buffer *buf.Buffer) (M.Socksaddr, error) { + return M.Socksaddr{}, io.ErrUnexpectedEOF +} + +func (c *testRemoteAddrPacketConn) RemoteAddr() net.Addr { + return M.SocksaddrFrom(netip.MustParseAddr("127.0.0.1"), 1000) +} + +func testPacketBuffer(data string, destination M.Socksaddr) *N.PacketBuffer { + packet := N.NewPacketBuffer() + packet.Buffer = buf.As([]byte(data)).ToOwned() + packet.Destination = destination + return packet +} + +func testBuffers(values ...string) []*buf.Buffer { + payloads := make([][]byte, len(values)) + for index, value := range values { + payloads[index] = []byte(value) + } + return testBuffersBytes(payloads...) +} + +func testBuffersBytes(values ...[]byte) []*buf.Buffer { + buffers := make([]*buf.Buffer, len(values)) + for index, value := range values { + buffer := buf.NewSize(len(value)) + common.Must1(buffer.Write(value)) + buffers[index] = buffer + } + return buffers +} diff --git a/common/bufio/wait.go b/common/bufio/wait.go index 6ac0821c7..9fe04d61f 100644 --- a/common/bufio/wait.go +++ b/common/bufio/wait.go @@ -4,9 +4,14 @@ import ( "io" "github.com/sagernet/sing/common" + M "github.com/sagernet/sing/common/metadata" N "github.com/sagernet/sing/common/network" ) +type vectorisedPacketReadWaitCreator interface { + CreateVectorisedPacketReadWaiter() (N.PacketBatchReadWaiter, bool) +} + func CreateReadWaiter(reader io.Reader) (N.ReadWaiter, bool) { if readWaiter, isReadWaiter := reader.(N.ReadWaiter); isReadWaiter { return readWaiter, true @@ -73,6 +78,54 @@ func CreatePacketReadWaiter(reader N.PacketReader) (N.PacketReadWaiter, bool) { return nil, false } +func CreatePacketBatchReadWaiter(reader N.PacketReader) (N.PacketBatchReadWaiter, bool) { + if readWaiter, isReadWaiter := reader.(N.PacketBatchReadWaiter); isReadWaiter { + return readWaiter, true + } + if readWaitCreator, isCreator := reader.(N.PacketBatchReadWaitCreator); isCreator { + return readWaitCreator.CreatePacketBatchReadWaiter() + } + if readWaitCreator, isCreator := reader.(vectorisedPacketReadWaitCreator); isCreator { + return readWaitCreator.CreateVectorisedPacketReadWaiter() + } + if readWaiter, created := createSyscallPacketBatchReadWaiter(reader); created { + return readWaiter, true + } + if u, ok := reader.(N.ReaderWithUpstream); !ok || !u.ReaderReplaceable() { + return nil, false + } + if u, ok := reader.(N.WithUpstreamReader); ok { + return CreatePacketBatchReadWaiter(u.UpstreamReader().(N.PacketReader)) + } + if u, ok := reader.(common.WithUpstream); ok { + return CreatePacketBatchReadWaiter(u.Upstream().(N.PacketReader)) + } + return nil, false +} + +func CreateConnectedPacketBatchReadWaiter(reader N.PacketReader) (N.ConnectedPacketBatchReadWaiter, bool) { + if readWaiter, isReadWaiter := reader.(N.ConnectedPacketBatchReadWaiter); isReadWaiter { + return readWaiter, true + } + if readWaitCreator, isCreator := reader.(N.ConnectedPacketBatchReadWaitCreator); isCreator { + return readWaitCreator.CreateConnectedPacketBatchReadWaiter() + } + if readWaiter, created := createSyscallConnectedPacketBatchReadWaiter(reader, M.Socksaddr{}); created { + return readWaiter, true + } + if u, ok := reader.(N.ReaderWithUpstream); !ok || !u.ReaderReplaceable() { + return nil, false + } + if u, ok := reader.(N.WithUpstreamReader); ok { + return CreateConnectedPacketBatchReadWaiter(u.UpstreamReader().(N.PacketReader)) + } + if u, ok := reader.(common.WithUpstream); ok { + return CreateConnectedPacketBatchReadWaiter(u.Upstream().(N.PacketReader)) + } + return nil, false +} + +// Deprecated: use CreatePacketBatchReadWaiter. func CreatePacketVectorisedReadWaiter(reader N.PacketReader) (N.VectorisedPacketReadWaiter, bool) { - panic("TODO") + return CreatePacketBatchReadWaiter(reader) } diff --git a/common/metadata/addr_unix.go b/common/metadata/addr_unix.go index 222a1c117..f8ab59c19 100644 --- a/common/metadata/addr_unix.go +++ b/common/metadata/addr_unix.go @@ -3,6 +3,7 @@ package metadata import ( + "encoding/binary" "net/netip" "unsafe" @@ -38,11 +39,28 @@ func AddrPortFromRawSockaddr(sa *unix.RawSockaddr) netip.AddrPort { switch sa.Family { case unix.AF_INET: sa4 := (*unix.RawSockaddrInet4)(unsafe.Pointer(sa)) - return netip.AddrPortFrom(netip.AddrFrom4(sa4.Addr), sa4.Port) + port := binary.BigEndian.Uint16((*[2]byte)(unsafe.Pointer(&sa4.Port))[:]) + return netip.AddrPortFrom(netip.AddrFrom4(sa4.Addr), port) case unix.AF_INET6: sa6 := (*unix.RawSockaddrInet6)(unsafe.Pointer(sa)) - return netip.AddrPortFrom(netip.AddrFrom16(sa6.Addr), sa6.Port) + port := binary.BigEndian.Uint16((*[2]byte)(unsafe.Pointer(&sa6.Port))[:]) + return netip.AddrPortFrom(netip.AddrFrom16(sa6.Addr), port) default: return netip.AddrPort{} } } + +func SocksaddrFromRawSockaddrAny(sa *unix.RawSockaddrAny) Socksaddr { + return SocksaddrFromNetIP(AddrPortFromRawSockaddr(&sa.Addr)).Unwrap() +} + +func AddrPortToRawSockaddrAny(name *unix.RawSockaddrAny, addrPort netip.AddrPort, forceInet6 bool) uint32 { + rawName, nameLen := AddrPortToRawSockaddr(addrPort, forceInet6) + *name = unix.RawSockaddrAny{} + if nameLen == unix.SizeofSockaddrInet4 { + *(*unix.RawSockaddrInet4)(unsafe.Pointer(name)) = *(*unix.RawSockaddrInet4)(rawName) + } else { + *(*unix.RawSockaddrInet6)(unsafe.Pointer(name)) = *(*unix.RawSockaddrInet6)(rawName) + } + return nameLen +} diff --git a/common/metadata/addr_windows.go b/common/metadata/addr_windows.go index cd8096107..7eaa148f1 100644 --- a/common/metadata/addr_windows.go +++ b/common/metadata/addr_windows.go @@ -37,10 +37,12 @@ func AddrPortFromRawSockaddr(sa *windows.RawSockaddr) netip.AddrPort { switch sa.Family { case windows.AF_INET: sa4 := (*windows.RawSockaddrInet4)(unsafe.Pointer(sa)) - return netip.AddrPortFrom(netip.AddrFrom4(sa4.Addr), sa4.Port) + port := binary.BigEndian.Uint16((*[2]byte)(unsafe.Pointer(&sa4.Port))[:]) + return netip.AddrPortFrom(netip.AddrFrom4(sa4.Addr), port) case windows.AF_INET6: sa6 := (*windows.RawSockaddrInet6)(unsafe.Pointer(sa)) - return netip.AddrPortFrom(netip.AddrFrom16(sa6.Addr), sa6.Port) + port := binary.BigEndian.Uint16((*[2]byte)(unsafe.Pointer(&sa6.Port))[:]) + return netip.AddrPortFrom(netip.AddrFrom16(sa6.Addr), port) default: return netip.AddrPort{} } diff --git a/common/network/counter.go b/common/network/counter.go index dc15f3204..d8d474d95 100644 --- a/common/network/counter.go +++ b/common/network/counter.go @@ -81,7 +81,7 @@ func UnwrapCountPacketReader(reader PacketReader, countFunc []CountFunc) (Packet return reader, countFunc } switch u := reader.(type) { - case PacketReadWaiter, PacketReadWaitCreator, syscall.Conn: + case PacketReadWaiter, PacketReadWaitCreator, PacketBatchReadWaiter, PacketBatchReadWaitCreator, ConnectedPacketBatchReadWaiter, ConnectedPacketBatchReadWaitCreator, VectorisedPacketReadWaitCreator, syscall.Conn: // In our use cases, counters is always at the top, so we stop when we encounter ReadWaiter return reader, countFunc case WithUpstreamReader: @@ -103,7 +103,7 @@ func UnwrapCountPacketWriter(writer PacketWriter, countFunc []CountFunc) (Packet return writer, countFunc } switch u := writer.(type) { - case syscall.Conn: + case PacketBatchWriter, ConnectedPacketBatchWriter, syscall.Conn: // In our use cases, counters is always at the top, so we stop when we encounter syscall conn return writer, countFunc case WithUpstreamWriter: diff --git a/common/network/direct.go b/common/network/direct.go index 58137b585..e2567c0b3 100644 --- a/common/network/direct.go +++ b/common/network/direct.go @@ -105,13 +105,30 @@ type PacketReadWaitCreator interface { CreateReadWaiter() (PacketReadWaiter, bool) } -type VectorisedPacketReadWaiter interface { +type PacketBatchReadWaiter interface { ReadWaitable WaitReadPackets() (buffers []*buf.Buffer, destinations []M.Socksaddr, err error) } +type PacketBatchReadWaitCreator interface { + CreatePacketBatchReadWaiter() (PacketBatchReadWaiter, bool) +} + +type ConnectedPacketBatchReadWaiter interface { + ReadWaitable + WaitReadConnectedPackets() (buffers []*buf.Buffer, destination M.Socksaddr, err error) +} + +type ConnectedPacketBatchReadWaitCreator interface { + CreateConnectedPacketBatchReadWaiter() (ConnectedPacketBatchReadWaiter, bool) +} + +// Deprecated: use PacketBatchReadWaiter. +type VectorisedPacketReadWaiter = PacketBatchReadWaiter + +// Deprecated: use PacketBatchReadWaitCreator. type VectorisedPacketReadWaitCreator interface { - CreateVectorisedPacketReadWaiter() (VectorisedPacketReadWaiter, bool) + CreateVectorisedPacketReadWaiter() (PacketBatchReadWaiter, bool) } type SyscallReader interface { diff --git a/common/network/early.go b/common/network/early.go index 3ce520ad5..f64475c55 100644 --- a/common/network/early.go +++ b/common/network/early.go @@ -19,6 +19,10 @@ type EarlyReader interface { } func NeedHandshakeForRead(reader io.Reader) bool { + return NeedHandshakeForReadAny(reader) +} + +func NeedHandshakeForReadAny(reader any) bool { if earlyReader, isEarlyReader := common.Cast[EarlyReader](reader); isEarlyReader && earlyReader.NeedHandshakeForRead() { return true } @@ -30,6 +34,10 @@ type EarlyWriter interface { } func NeedHandshakeForWrite(writer io.Writer) bool { + return NeedHandshakeForWriteAny(writer) +} + +func NeedHandshakeForWriteAny(writer any) bool { if //goland:noinspection GoDeprecation earlyConn, isEarlyConn := writer.(EarlyConn); isEarlyConn { return earlyConn.NeedHandshake() @@ -43,14 +51,23 @@ func NeedHandshakeForWrite(writer io.Writer) bool { type HandshakeState struct { readPending bool writePending bool - source io.Reader - destination io.Writer + source any + destination any } func NewHandshakeState(source io.Reader, destination io.Writer) HandshakeState { return HandshakeState{ - readPending: NeedHandshakeForRead(source), - writePending: NeedHandshakeForWrite(destination), + readPending: NeedHandshakeForReadAny(source), + writePending: NeedHandshakeForWriteAny(destination), + source: source, + destination: destination, + } +} + +func NewPacketHandshakeState(source PacketReader, destination PacketWriter) HandshakeState { + return HandshakeState{ + readPending: NeedHandshakeForReadAny(source), + writePending: NeedHandshakeForWriteAny(destination), source: source, destination: destination, } @@ -61,10 +78,10 @@ func (s HandshakeState) Upgradable() bool { } func (s HandshakeState) Check() error { - if s.readPending && !NeedHandshakeForRead(s.source) { + if s.readPending && !NeedHandshakeForReadAny(s.source) { return ErrHandshakeCompleted } - if s.writePending && !NeedHandshakeForWrite(s.destination) { + if s.writePending && !NeedHandshakeForWriteAny(s.destination) { return ErrHandshakeCompleted } return nil diff --git a/common/network/vectorised.go b/common/network/vectorised.go index d6a2354e3..add3ffb1f 100644 --- a/common/network/vectorised.go +++ b/common/network/vectorised.go @@ -12,3 +12,19 @@ type VectorisedWriter interface { type VectorisedPacketWriter interface { WriteVectorisedPacket(buffers []*buf.Buffer, destination M.Socksaddr) error } + +type PacketBatchWriter interface { + WritePacketBatch(buffers []*buf.Buffer, destinations []M.Socksaddr) error +} + +type PacketBatchWriteCreator interface { + CreatePacketBatchWriter() (PacketBatchWriter, bool) +} + +type ConnectedPacketBatchWriter interface { + WriteConnectedPacketBatch(buffers []*buf.Buffer) error +} + +type ConnectedPacketBatchWriteCreator interface { + CreateConnectedPacketBatchWriter() (ConnectedPacketBatchWriter, bool) +} diff --git a/common/udpnat2/conn.go b/common/udpnat2/conn.go index 3c0cda388..e0d9a5567 100644 --- a/common/udpnat2/conn.go +++ b/common/udpnat2/conn.go @@ -24,6 +24,10 @@ type Conn interface { } var _ Conn = (*natConn)(nil) +var ( + _ N.PacketBatchReadWaitCreator = (*natConn)(nil) + _ N.PacketBatchWriteCreator = (*natConn)(nil) +) type natConn struct { cache freelru.Cache[netip.AddrPort, *natConn] @@ -57,6 +61,16 @@ func (c *natConn) WritePacket(buffer *buf.Buffer, destination M.Socksaddr) error return c.writer.WritePacket(buffer, destination) } +func (c *natConn) CreatePacketBatchWriter() (N.PacketBatchWriter, bool) { + if writer, isWriter := c.writer.(N.PacketBatchWriter); isWriter { + return writer, true + } + if creator, isCreator := c.writer.(N.PacketBatchWriteCreator); isCreator { + return creator.CreatePacketBatchWriter() + } + return nil, false +} + func (c *natConn) InitializeReadWaiter(options N.ReadWaitOptions) (needCopy bool) { c.handlerAccess.Lock() defer c.handlerAccess.Unlock() @@ -78,6 +92,34 @@ func (c *natConn) WaitReadPacket() (buffer *buf.Buffer, destination M.Socksaddr, } } +func (c *natConn) CreatePacketBatchReadWaiter() (N.PacketBatchReadWaiter, bool) { + return c, true +} + +func (c *natConn) WaitReadPackets() (buffers []*buf.Buffer, destinations []M.Socksaddr, err error) { + buffer, destination, err := c.WaitReadPacket() + if err != nil { + return nil, nil, err + } + batchSize := c.readWaitOptions.BatchSize + if batchSize <= 0 { + batchSize = 1 + } + buffers = append(buffers, buffer) + destinations = append(destinations, destination) + for len(buffers) < batchSize { + select { + case packet := <-c.packetChan: + buffers = append(buffers, c.readWaitOptions.Copy(packet.Buffer)) + destinations = append(destinations, packet.Destination) + N.PutPacketBuffer(packet) + default: + return + } + } + return +} + func (c *natConn) SetHandler(handler N.UDPHandlerEx) { c.handlerAccess.Lock() c.handler = handler diff --git a/common/udpnat2/conn_test.go b/common/udpnat2/conn_test.go new file mode 100644 index 000000000..94d1fb609 --- /dev/null +++ b/common/udpnat2/conn_test.go @@ -0,0 +1,98 @@ +package udpnat + +import ( + "testing" + + "github.com/sagernet/sing/common/buf" + M "github.com/sagernet/sing/common/metadata" + N "github.com/sagernet/sing/common/network" + "github.com/sagernet/sing/common/pipe" +) + +func TestNatConnPacketBatchReadWaiter(t *testing.T) { + t.Parallel() + + conn := &natConn{ + writer: testPacketWriter{}, + packetChan: make(chan *N.PacketBuffer, 3), + doneChan: make(chan struct{}), + readDeadline: pipe.MakeDeadline(), + } + conn.InitializeReadWaiter(N.ReadWaitOptions{BatchSize: 2}) + conn.packetChan <- testPacketBuffer("a", M.ParseSocksaddr("1.1.1.1:53")) + conn.packetChan <- testPacketBuffer("bb", M.ParseSocksaddr("2.2.2.2:53")) + conn.packetChan <- testPacketBuffer("ccc", M.ParseSocksaddr("3.3.3.3:53")) + + readWaiter, created := conn.CreatePacketBatchReadWaiter() + if !created { + t.Fatal("CreatePacketBatchReadWaiter returned false") + } + buffers, destinations, err := readWaiter.WaitReadPackets() + if err != nil { + t.Fatal(err) + } + defer buf.ReleaseMulti(buffers) + if len(buffers) != 2 { + t.Fatalf("batch size mismatch: %d", len(buffers)) + } + if string(buffers[0].Bytes()) != "a" || string(buffers[1].Bytes()) != "bb" { + t.Fatalf("unexpected buffers: %q %q", buffers[0].Bytes(), buffers[1].Bytes()) + } + if destinations[0] != M.ParseSocksaddr("1.1.1.1:53") || destinations[1] != M.ParseSocksaddr("2.2.2.2:53") { + t.Fatalf("unexpected destinations: %v", destinations) + } +} + +func TestNatConnPacketBatchWriterCreator(t *testing.T) { + t.Parallel() + + writer := &testPacketBatchWriter{} + conn := &natConn{writer: writer} + batchWriter, created := conn.CreatePacketBatchWriter() + if !created { + t.Fatal("CreatePacketBatchWriter returned false") + } + err := batchWriter.WritePacketBatch([]*buf.Buffer{ + buf.As([]byte("a")).ToOwned(), + buf.As([]byte("bb")).ToOwned(), + }, []M.Socksaddr{ + M.ParseSocksaddr("1.1.1.1:53"), + M.ParseSocksaddr("2.2.2.2:53"), + }) + if err != nil { + t.Fatal(err) + } + if writer.count != 2 { + t.Fatalf("unexpected write count: %d", writer.count) + } +} + +func testPacketBuffer(data string, destination M.Socksaddr) *N.PacketBuffer { + packet := N.NewPacketBuffer() + packet.Buffer = buf.As([]byte(data)).ToOwned() + packet.Destination = destination + return packet +} + +type testPacketWriter struct{} + +func (testPacketWriter) WritePacket(buffer *buf.Buffer, destination M.Socksaddr) error { + buffer.Release() + return nil +} + +type testPacketBatchWriter struct { + count int +} + +func (w *testPacketBatchWriter) WritePacket(buffer *buf.Buffer, destination M.Socksaddr) error { + buffer.Release() + w.count++ + return nil +} + +func (w *testPacketBatchWriter) WritePacketBatch(buffers []*buf.Buffer, destinations []M.Socksaddr) error { + defer buf.ReleaseMulti(buffers) + w.count += len(buffers) + return nil +} diff --git a/common/udpnat2/service.go b/common/udpnat2/service.go index 3e3ce7d1f..1a207b02f 100644 --- a/common/udpnat2/service.go +++ b/common/udpnat2/service.go @@ -6,6 +6,7 @@ import ( "time" "github.com/sagernet/sing/common" + "github.com/sagernet/sing/common/buf" M "github.com/sagernet/sing/common/metadata" N "github.com/sagernet/sing/common/network" "github.com/sagernet/sing/common/pipe" @@ -95,6 +96,17 @@ func (s *Service) NewPacket(bufferSlices [][]byte, source M.Socksaddr, destinati } } +func (s *Service) NewPacketBatch(buffers []*buf.Buffer, sources []M.Socksaddr, destination M.Socksaddr, userData any) { + if len(buffers) != len(sources) { + buf.ReleaseMulti(buffers) + return + } + for index, buffer := range buffers { + s.NewPacket([][]byte{buffer.Bytes()}, sources[index], destination, userData) + buffer.Release() + } +} + func (s *Service) Purge() { s.cache.Purge() } From 8ec882b250fb999b8d698e5e0c8d7836c2d37799 Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?=E4=B8=96=E7=95=8C?= Date: Sat, 25 Apr 2026 20:15:46 +0800 Subject: [PATCH 4/7] contextjson: Add comment support --- common/json/comment.go | 1 + common/json/context.go | 2 - common/json/context_ext.go | 26 +- common/json/context_ext_std.go | 36 -- common/json/internal/contextjson/README.md | 2 + common/json/internal/contextjson/comment.go | 564 ++++++++++++++++++ .../json/internal/contextjson/context_test.go | 317 ++++++++++ common/json/internal/contextjson/decode.go | 52 +- common/json/internal/contextjson/encode.go | 60 +- common/json/internal/contextjson/indent.go | 321 +++++++++- common/json/internal/contextjson/scanner.go | 102 ++++ common/json/internal/contextjson/stream.go | 60 +- common/json/internal/contextjson/unmarshal.go | 14 +- common/json/std.go | 21 - common/json/unmarshal.go | 4 +- common/json/unmarshal_context.go | 2 - common/json/unmarshal_std.go | 13 - 17 files changed, 1495 insertions(+), 102 deletions(-) delete mode 100644 common/json/context_ext_std.go create mode 100644 common/json/internal/contextjson/comment.go delete mode 100644 common/json/std.go delete mode 100644 common/json/unmarshal_std.go diff --git a/common/json/comment.go b/common/json/comment.go index 704c01e74..c65809536 100644 --- a/common/json/comment.go +++ b/common/json/comment.go @@ -29,6 +29,7 @@ type CommentFilter struct { pendingN int } +// Deprecated: use new contextjson instead func NewCommentFilter(reader io.Reader) io.Reader { return &CommentFilter{br: bufio.NewReader(reader)} } diff --git a/common/json/context.go b/common/json/context.go index 1c69b4316..b1f3ce1be 100644 --- a/common/json/context.go +++ b/common/json/context.go @@ -1,5 +1,3 @@ -//go:build !without_contextjson - package json import ( diff --git a/common/json/context_ext.go b/common/json/context_ext.go index 96a8263b8..4f80dabfb 100644 --- a/common/json/context_ext.go +++ b/common/json/context_ext.go @@ -1,5 +1,3 @@ -//go:build !without_contextjson - package json import ( @@ -23,3 +21,27 @@ type ContextMarshaler interface { type ContextUnmarshaler interface { UnmarshalJSONContext(ctx context.Context, content []byte) error } + +type ( + CommentKind = json.CommentKind + CommentPlacement = json.CommentPlacement + CommentPathKind = json.CommentPathKind + CommentPathSegment = json.CommentPathSegment + CommentPath = json.CommentPath + CommentPosition = json.CommentPosition + Comment = json.Comment + CommentSet = json.CommentSet + CommentMarshaler = json.CommentMarshaler + CommentUnmarshaler = json.CommentUnmarshaler +) + +const ( + CommentKindLine = json.CommentKindLine + CommentKindHash = json.CommentKindHash + CommentKindBlock = json.CommentKindBlock + CommentPlacementLeading = json.CommentPlacementLeading + CommentPlacementTrailing = json.CommentPlacementTrailing + CommentPlacementInner = json.CommentPlacementInner + CommentPathKey = json.CommentPathKey + CommentPathIndex = json.CommentPathIndex +) diff --git a/common/json/context_ext_std.go b/common/json/context_ext_std.go deleted file mode 100644 index d7e12b1b5..000000000 --- a/common/json/context_ext_std.go +++ /dev/null @@ -1,36 +0,0 @@ -//go:build without_contextjson - -package json - -import ( - "context" - "io" -) - -func MarshalContext(ctx context.Context, value any) ([]byte, error) { - return Marshal(value) -} - -func UnmarshalContext(ctx context.Context, content []byte, value any) error { - return Unmarshal(content, value) -} - -func NewEncoderContext(ctx context.Context, writer io.Writer) *Encoder { - return NewEncoder(writer) -} - -func NewDecoderContext(ctx context.Context, reader io.Reader) *Decoder { - return NewDecoder(reader) -} - -func UnmarshalContextDisallowUnknownFields(ctx context.Context, content []byte, value any) error { - return UnmarshalDisallowUnknownFields(content, value) -} - -type ContextMarshaler interface { - MarshalJSONContext(ctx context.Context) ([]byte, error) -} - -type ContextUnmarshaler interface { - UnmarshalJSONContext(ctx context.Context, content []byte) error -} diff --git a/common/json/internal/contextjson/README.md b/common/json/internal/contextjson/README.md index 7a6688f48..be0951cb5 100644 --- a/common/json/internal/contextjson/README.md +++ b/common/json/internal/contextjson/README.md @@ -5,6 +5,8 @@ Forked from Go 1.25.9 encoding/json. Local changes: - context-aware marshal and unmarshal interfaces +- JSONC comments by default +- comment-aware context marshal and unmarshal interfaces - concrete error path reporting - trailing comma support for arrays and objects - ObjectKeys helper for decoded JSON object fields diff --git a/common/json/internal/contextjson/comment.go b/common/json/internal/contextjson/comment.go new file mode 100644 index 000000000..856512de0 --- /dev/null +++ b/common/json/internal/contextjson/comment.go @@ -0,0 +1,564 @@ +package json + +import ( + "bytes" + "sort" +) + +type CommentKind string + +const ( + CommentKindLine CommentKind = "//" + CommentKindHash CommentKind = "#" + CommentKindBlock CommentKind = "/*" +) + +type CommentPlacement string + +const ( + CommentPlacementLeading CommentPlacement = "leading" + CommentPlacementTrailing CommentPlacement = "trailing" + CommentPlacementInner CommentPlacement = "inner" +) + +type CommentPathKind byte + +const ( + CommentPathKey CommentPathKind = iota + CommentPathIndex +) + +type CommentPathSegment struct { + Kind CommentPathKind + Key string + Index int +} + +type CommentPath []CommentPathSegment + +type CommentPosition struct { + Offset int + Line int + Column int +} + +type Comment struct { + Kind CommentKind + Placement CommentPlacement + Path CommentPath + Text string + Start CommentPosition + End CommentPosition +} + +type CommentSet struct { + Comments []Comment +} + +func (s *CommentSet) Add(comment Comment) { + if s == nil { + return + } + s.Comments = append(s.Comments, comment) +} + +func (s *CommentSet) Empty() bool { + return s == nil || len(s.Comments) == 0 +} + +func (s *CommentSet) ForPath(path CommentPath) *CommentSet { + if s == nil { + return nil + } + var filtered CommentSet + for _, comment := range s.Comments { + if !commentPathHasPrefix(comment.Path, path) { + continue + } + comment.Path = cloneCommentPath(comment.Path[len(path):]) + filtered.Comments = append(filtered.Comments, comment) + } + if len(filtered.Comments) == 0 { + return nil + } + return &filtered +} + +type CommentMarshaler interface { + ContextMarshaler + Comments() *CommentSet +} + +type CommentUnmarshaler interface { + ContextUnmarshaler + Comments() *CommentSet + SetComments(*CommentSet) +} + +type commentNode struct { + path CommentPath + start int + end int +} + +func stripJSONComments(data []byte) ([]byte, *CommentSet, error) { + clean := append([]byte(nil), data...) + lineStarts := buildLineStarts(data) + var comments []Comment + for i := 0; i < len(data); { + switch data[i] { + case '"': + i = skipJSONString(data, i) + case '#': + start := i + i++ + textStart := i + for i < len(data) && data[i] != '\n' && data[i] != '\r' { + i++ + } + comments = append(comments, newComment(CommentKindHash, data[textStart:i], start, i, lineStarts)) + blankComment(clean, start, i) + case '/': + if i+1 >= len(data) { + i++ + continue + } + switch data[i+1] { + case '/': + start := i + i += 2 + textStart := i + for i < len(data) && data[i] != '\n' && data[i] != '\r' { + i++ + } + comments = append(comments, newComment(CommentKindLine, data[textStart:i], start, i, lineStarts)) + blankComment(clean, start, i) + case '*': + start := i + i += 2 + textStart := i + for i+1 < len(data) && (data[i] != '*' || data[i+1] != '/') { + i++ + } + if i+1 >= len(data) { + return nil, nil, &SyntaxError{"unexpected end of JSON comment", int64(start)} + } + textEnd := i + i += 2 + comments = append(comments, newComment(CommentKindBlock, data[textStart:textEnd], start, i, lineStarts)) + blankComment(clean, start, i) + default: + i++ + } + default: + i++ + } + } + if len(comments) == 0 { + return clean, nil, nil + } + commentSet := assignCommentPaths(data, clean, comments, lineStarts) + return clean, commentSet, nil +} + +func blankComment(data []byte, start, end int) { + for i := start; i < end; i++ { + if data[i] != '\n' && data[i] != '\r' { + data[i] = ' ' + } + } +} + +func newComment(kind CommentKind, text []byte, start, end int, lineStarts []int) Comment { + return Comment{ + Kind: kind, + Text: string(text), + Start: commentPosition(start, lineStarts), + End: commentPosition(end, lineStarts), + } +} + +func buildLineStarts(data []byte) []int { + lineStarts := []int{0} + for i, c := range data { + if c == '\n' { + lineStarts = append(lineStarts, i+1) + } + } + return lineStarts +} + +func commentPosition(offset int, lineStarts []int) CommentPosition { + line := sort.Search(len(lineStarts), func(i int) bool { + return lineStarts[i] > offset + }) + if line == 0 { + return CommentPosition{Offset: offset, Line: 1, Column: offset + 1} + } + lineStart := lineStarts[line-1] + return CommentPosition{Offset: offset, Line: line, Column: offset - lineStart + 1} +} + +func skipJSONString(data []byte, start int) int { + i := start + 1 + for i < len(data) { + switch data[i] { + case '\\': + i += 2 + case '"': + return i + 1 + default: + i++ + } + } + return i +} + +func assignCommentPaths(original, clean []byte, comments []Comment, lineStarts []int) *CommentSet { + nodes := collectCommentNodes(clean) + set := &CommentSet{Comments: comments} + for i := range set.Comments { + placement, path := classifyComment(original, nodes, set.Comments[i]) + set.Comments[i].Placement = placement + set.Comments[i].Path = path + set.Comments[i].Start = commentPosition(set.Comments[i].Start.Offset, lineStarts) + set.Comments[i].End = commentPosition(set.Comments[i].End.Offset, lineStarts) + } + return set +} + +func classifyComment(data []byte, nodes []commentNode, comment Comment) (CommentPlacement, CommentPath) { + var containing *commentNode + for i := range nodes { + node := &nodes[i] + if node.start <= comment.Start.Offset && comment.End.Offset <= node.end { + if containing == nil || len(node.path) > len(containing.path) { + containing = node + } + } + } + + var previous *commentNode + for i := range nodes { + node := &nodes[i] + if node.end > comment.Start.Offset { + continue + } + if containing != nil && !commentPathHasPrefix(node.path, containing.path) { + continue + } + if previous == nil || node.end > previous.end || node.end == previous.end && len(node.path) > len(previous.path) { + previous = node + } + } + if previous != nil && isInlineTrailing(data, previous.end, comment.Start.Offset) { + return CommentPlacementTrailing, cloneCommentPath(previous.path) + } + + var next *commentNode + for i := range nodes { + node := &nodes[i] + if node.start < comment.End.Offset { + continue + } + if containing != nil && !commentPathHasPrefix(node.path, containing.path) { + continue + } + if next == nil || node.start < next.start || node.start == next.start && len(node.path) > len(next.path) { + next = node + } + } + if next != nil { + return CommentPlacementLeading, cloneCommentPath(next.path) + } + if containing != nil { + return CommentPlacementInner, cloneCommentPath(containing.path) + } + return CommentPlacementLeading, nil +} + +func isInlineTrailing(data []byte, from, to int) bool { + if from > to || from < 0 || to > len(data) { + return false + } + seenComma := false + for _, c := range data[from:to] { + switch c { + case '\n', '\r': + return false + case ' ', '\t': + case ',': + if seenComma { + return false + } + seenComma = true + default: + return false + } + } + return true +} + +func collectCommentNodes(data []byte) []commentNode { + parser := commentNodeParser{data: data} + start := parser.skipSpaces(0) + if start < len(data) { + parser.parseValue(start, nil, start) + } + return parser.nodes +} + +type commentNodeParser struct { + data []byte + nodes []commentNode +} + +func (p *commentNodeParser) parseValue(i int, path CommentPath, anchor int) (commentNode, int, bool) { + i = p.skipSpaces(i) + if anchor < 0 { + anchor = i + } + if i >= len(p.data) { + return commentNode{}, i, false + } + switch p.data[i] { + case '{': + return p.parseObject(i, path, anchor) + case '[': + return p.parseArray(i, path, anchor) + case '"': + end := skipJSONString(p.data, i) + node := commentNode{path: cloneCommentPath(path), start: anchor, end: end} + p.nodes = append(p.nodes, node) + return node, end, true + default: + end := p.parseLiteralEnd(i) + node := commentNode{path: cloneCommentPath(path), start: anchor, end: end} + p.nodes = append(p.nodes, node) + return node, end, true + } +} + +func (p *commentNodeParser) parseObject(i int, path CommentPath, anchor int) (commentNode, int, bool) { + i++ + for { + i = p.skipSpaces(i) + if i >= len(p.data) { + return commentNode{}, i, false + } + if p.data[i] == '}' { + end := i + 1 + node := commentNode{path: cloneCommentPath(path), start: anchor, end: end} + p.nodes = append(p.nodes, node) + return node, end, true + } + keyStart := i + keyEnd := skipJSONString(p.data, i) + key, ok := unquote(p.data[keyStart:keyEnd]) + if !ok { + return commentNode{}, keyEnd, false + } + i = p.skipSpaces(keyEnd) + if i < len(p.data) && p.data[i] == ':' { + i++ + } + childPath := appendCommentPathKey(path, key) + _, i, _ = p.parseValue(i, childPath, keyStart) + i = p.skipSpaces(i) + if i < len(p.data) && p.data[i] == ',' { + i++ + continue + } + } +} + +func (p *commentNodeParser) parseArray(i int, path CommentPath, anchor int) (commentNode, int, bool) { + i++ + index := 0 + for { + i = p.skipSpaces(i) + if i >= len(p.data) { + return commentNode{}, i, false + } + if p.data[i] == ']' { + end := i + 1 + node := commentNode{path: cloneCommentPath(path), start: anchor, end: end} + p.nodes = append(p.nodes, node) + return node, end, true + } + childPath := appendCommentPathIndex(path, index) + _, i, _ = p.parseValue(i, childPath, i) + index++ + i = p.skipSpaces(i) + if i < len(p.data) && p.data[i] == ',' { + i++ + continue + } + } +} + +func (p *commentNodeParser) parseLiteralEnd(i int) int { + for i < len(p.data) { + c := p.data[i] + if isSpace(c) || c == ',' || c == '}' || c == ']' || c == ':' { + return i + } + i++ + } + return i +} + +func (p *commentNodeParser) skipSpaces(i int) int { + for i < len(p.data) && isSpace(p.data[i]) { + i++ + } + return i +} + +func appendCommentPathKey(path CommentPath, key string) CommentPath { + next := cloneCommentPath(path) + next = append(next, CommentPathSegment{Kind: CommentPathKey, Key: key}) + return next +} + +func appendCommentPathIndex(path CommentPath, index int) CommentPath { + next := cloneCommentPath(path) + next = append(next, CommentPathSegment{Kind: CommentPathIndex, Index: index}) + return next +} + +func cloneCommentPath(path CommentPath) CommentPath { + if len(path) == 0 { + return nil + } + next := make(CommentPath, len(path)) + copy(next, path) + return next +} + +func commentPathHasPrefix(path, prefix CommentPath) bool { + if len(prefix) > len(path) { + return false + } + for i := range prefix { + if path[i] != prefix[i] { + return false + } + } + return true +} + +func decodeContextPathToCommentPath(path []decodePathSegment) CommentPath { + if len(path) == 0 { + return nil + } + commentPath := make(CommentPath, 0, len(path)) + for _, segment := range path { + switch segment.kind { + case decodePathKey: + commentPath = append(commentPath, CommentPathSegment{Kind: CommentPathKey, Key: segment.key}) + case decodePathIndex: + commentPath = append(commentPath, CommentPathSegment{Kind: CommentPathIndex, Index: segment.index}) + } + } + return commentPath +} + +func (d *decodeState) commentsForCurrentValue() *CommentSet { + if d.comments == nil { + return nil + } + return d.comments.ForPath(decodeContextPathToCommentPath(d.context)) +} + +func insertJSONComments(data []byte, comments *CommentSet) ([]byte, error) { + if comments.Empty() { + return data, nil + } + nodes := collectCommentNodes(data) + insertions := make(map[int][][]byte) + for _, comment := range comments.Comments { + offset := commentInsertOffset(data, nodes, comment) + insertions[offset] = append(insertions[offset], formatComment(comment)) + } + var out bytes.Buffer + out.Grow(len(data) + len(comments.Comments)*16) + for i := 0; i <= len(data); i++ { + if values := insertions[i]; len(values) > 0 { + for _, value := range values { + out.Write(value) + } + } + if i < len(data) { + out.WriteByte(data[i]) + } + } + return out.Bytes(), nil +} + +func commentInsertOffset(data []byte, nodes []commentNode, comment Comment) int { + node, ok := findCommentNode(nodes, comment.Path) + if !ok { + if comment.Placement == CommentPlacementTrailing { + return len(data) + } + return 0 + } + switch comment.Placement { + case CommentPlacementTrailing: + return trailingCommentInsertOffset(data, node.end) + case CommentPlacementInner: + if node.end > node.start && (data[node.start] == '{' || data[node.start] == '[') { + return node.start + 1 + } + return node.start + default: + return node.start + } +} + +func trailingCommentInsertOffset(data []byte, offset int) int { + for i := offset; i < len(data); i++ { + switch data[i] { + case ' ', '\t': + continue + case ',': + return i + 1 + default: + return offset + } + } + return offset +} + +func findCommentNode(nodes []commentNode, path CommentPath) (commentNode, bool) { + for _, node := range nodes { + if len(node.path) != len(path) { + continue + } + if commentPathHasPrefix(node.path, path) { + return node, true + } + } + return commentNode{}, false +} + +func formatComment(comment Comment) []byte { + switch comment.Kind { + case CommentKindHash: + if comment.Placement == CommentPlacementTrailing { + return []byte(" #" + comment.Text + "\n") + } + return []byte("#" + comment.Text + "\n") + case CommentKindBlock: + if comment.Placement == CommentPlacementTrailing { + return []byte(" /*" + comment.Text + "*/") + } + return []byte("/*" + comment.Text + "*/\n") + default: + if comment.Placement == CommentPlacementTrailing { + return []byte(" //" + comment.Text + "\n") + } + return []byte("//" + comment.Text + "\n") + } +} diff --git a/common/json/internal/contextjson/context_test.go b/common/json/internal/contextjson/context_test.go index ab0abd603..81f7af5c5 100644 --- a/common/json/internal/contextjson/context_test.go +++ b/common/json/internal/contextjson/context_test.go @@ -328,6 +328,323 @@ func TestTrailingCommaDecoderDecode(t *testing.T) { } } +func TestUnmarshalAcceptsJSONComments(t *testing.T) { + t.Parallel() + var value map[string]int + if err := json.Unmarshal([]byte(`{ + // leading + "a": 1, + # hash + "b": 2, /* block */ + }`), &value); err != nil { + t.Fatal(err) + } + if value["a"] != 1 || value["b"] != 2 { + t.Fatalf("value = %#v", value) + } +} + +func TestDecoderAcceptsJSONComments(t *testing.T) { + t.Parallel() + decoder := json.NewDecoder(strings.NewReader(`// before + {"value": 1}`)) + var value map[string]int + if err := decoder.Decode(&value); err != nil { + t.Fatal(err) + } + if value["value"] != 1 { + t.Fatalf("value = %#v", value) + } +} + +func TestTokenAcceptsJSONComments(t *testing.T) { + t.Parallel() + decoder := json.NewDecoder(strings.NewReader(`[// before + 1]`)) + token, err := decoder.Token() + if err != nil { + t.Fatal(err) + } + if token != json.Delim('[') { + t.Fatalf("start token = %v", token) + } + token, err = decoder.Token() + if err != nil { + t.Fatal(err) + } + if token != float64(1) { + t.Fatalf("value token = %#v", token) + } + token, err = decoder.Token() + if err != nil { + t.Fatal(err) + } + if token != json.Delim(']') { + t.Fatalf("end token = %v", token) + } +} + +type commentContextValue struct { + Value int `json:"value"` + Raw string + CommentsSet *json.CommentSet +} + +func (v *commentContextValue) MarshalJSONContext(ctx context.Context) ([]byte, error) { + return json.MarshalContext(ctx, struct { + Value int `json:"value"` + }{Value: v.Value}) +} + +func (v *commentContextValue) UnmarshalJSONContext(ctx context.Context, content []byte) error { + v.Raw = string(content) + var wire struct { + Value int `json:"value"` + } + if err := json.UnmarshalContext(ctx, content, &wire); err != nil { + return err + } + v.Value = wire.Value + return nil +} + +func (v *commentContextValue) Comments() *json.CommentSet { + return v.CommentsSet +} + +func (v *commentContextValue) SetComments(comments *json.CommentSet) { + v.CommentsSet = comments +} + +func TestCommentUnmarshalerReceivesComments(t *testing.T) { + t.Parallel() + var value commentContextValue + if err := json.UnmarshalContext(context.Background(), []byte(`{ + // value leading + "value": 1 // value trailing + }`), &value); err != nil { + t.Fatal(err) + } + if strings.Contains(value.Raw, "//") { + t.Fatalf("raw content contains comments: %q", value.Raw) + } + if value.CommentsSet == nil || len(value.CommentsSet.Comments) != 2 { + t.Fatalf("comments = %#v", value.CommentsSet) + } + leading, trailing := value.CommentsSet.Comments[0], value.CommentsSet.Comments[1] + if leading.Placement != json.CommentPlacementLeading || len(leading.Path) != 1 || leading.Path[0].Key != "value" { + t.Fatalf("leading comment = %#v", leading) + } + if trailing.Placement != json.CommentPlacementTrailing || len(trailing.Path) != 1 || trailing.Path[0].Key != "value" { + t.Fatalf("trailing comment = %#v", trailing) + } +} + +func TestCommentMarshalerWritesComments(t *testing.T) { + t.Parallel() + value := &commentContextValue{ + Value: 1, + CommentsSet: &json.CommentSet{Comments: []json.Comment{ + { + Kind: json.CommentKindLine, + Placement: json.CommentPlacementLeading, + Path: json.CommentPath{{Kind: json.CommentPathKey, Key: "value"}}, + Text: " value leading", + }, + { + Kind: json.CommentKindBlock, + Placement: json.CommentPlacementTrailing, + Path: json.CommentPath{{Kind: json.CommentPathKey, Key: "value"}}, + Text: " value trailing ", + }, + }}, + } + content, err := json.MarshalContext(context.Background(), value) + if err != nil { + t.Fatal(err) + } + if !strings.Contains(string(content), "// value leading") || !strings.Contains(string(content), "/* value trailing */") { + t.Fatalf("content = %s", content) + } + var decoded map[string]int + if err := json.Unmarshal(content, &decoded); err != nil { + t.Fatal(err) + } + if decoded["value"] != 1 { + t.Fatalf("decoded = %#v", decoded) + } +} + +type commentFormatValue struct { + Listen string + Port int + Legacy bool + Route map[string]any + Items []int + CommentsSet *json.CommentSet +} + +func (v *commentFormatValue) MarshalJSONContext(ctx context.Context) ([]byte, error) { + type wire struct { + Listen string `json:"listen,omitempty"` + Port int `json:"port,omitempty"` + Legacy bool `json:"legacy,omitempty"` + Route *map[string]any `json:"route,omitempty"` + Items []int `json:"items,omitempty"` + } + var route *map[string]any + if v.Route != nil { + route = &v.Route + } + return json.MarshalContext(ctx, wire{ + Listen: v.Listen, + Port: v.Port, + Legacy: v.Legacy, + Route: route, + Items: v.Items, + }) +} + +func (v *commentFormatValue) UnmarshalJSONContext(ctx context.Context, content []byte) error { + type wire struct { + Listen string `json:"listen"` + Port int `json:"port"` + Legacy bool `json:"legacy"` + Route map[string]any `json:"route"` + Items []int `json:"items"` + } + var decoded wire + if err := json.UnmarshalContext(ctx, content, &decoded); err != nil { + return err + } + v.Listen = decoded.Listen + v.Port = decoded.Port + v.Legacy = decoded.Legacy + if decoded.Route != nil { + v.Route = decoded.Route + } + v.Items = decoded.Items + return nil +} + +func (v *commentFormatValue) Comments() *json.CommentSet { + return v.CommentsSet +} + +func (v *commentFormatValue) SetComments(comments *json.CommentSet) { + v.CommentsSet = comments +} + +func TestCommentMarshalIndentPreservesJSONCFormatting(t *testing.T) { + t.Parallel() + const input = `{ + // listen address + "listen": "127.0.0.1", + "port": 7890, // mixed port + # legacy option + "legacy": true, + /* + * block line 1 + * block line 2 + */ + "route": {} +}` + var value commentFormatValue + if err := json.UnmarshalContext(context.Background(), []byte(input), &value); err != nil { + t.Fatal(err) + } + content, err := json.MarshalIndent(&value, "", " ") + if err != nil { + t.Fatal(err) + } + if string(content) != input { + t.Fatalf("content mismatch\nwant:\n%s\n\ngot:\n%s", input, content) + } +} + +func TestCommentMarshalIndentFormatsArrayComments(t *testing.T) { + t.Parallel() + const input = `{ + "items": [ + // first + 1, + 2, // second + 3 + ] +}` + var value commentFormatValue + if err := json.UnmarshalContext(context.Background(), []byte(input), &value); err != nil { + t.Fatal(err) + } + content, err := json.MarshalIndent(&value, "", " ") + if err != nil { + t.Fatal(err) + } + if string(content) != input { + t.Fatalf("content mismatch\nwant:\n%s\n\ngot:\n%s", input, content) + } +} + +func TestIndentPreservesJSONComments(t *testing.T) { + t.Parallel() + const input = `{"a":1, // a +"b":[/* inner */2,3],"c":4, /* c */"d":5}` + const expected = `{ + "a": 1, // a + "b": [ + /* inner */ + 2, + 3 + ], + "c": 4, /* c */ + "d": 5 +}` + var out bytes.Buffer + if err := json.Indent(&out, []byte(input), "", " "); err != nil { + t.Fatal(err) + } + if out.String() != expected { + t.Fatalf("content mismatch\nwant:\n%s\n\ngot:\n%s", expected, out.String()) + } +} + +func TestEncoderSetIndentPreservesJSONComments(t *testing.T) { + t.Parallel() + value := &commentFormatValue{ + Listen: "127.0.0.1", + Port: 7890, + CommentsSet: &json.CommentSet{Comments: []json.Comment{ + { + Kind: json.CommentKindLine, + Placement: json.CommentPlacementLeading, + Path: json.CommentPath{{Kind: json.CommentPathKey, Key: "listen"}}, + Text: " listen address", + }, + { + Kind: json.CommentKindLine, + Placement: json.CommentPlacementTrailing, + Path: json.CommentPath{{Kind: json.CommentPathKey, Key: "port"}}, + Text: " mixed port", + }, + }}, + } + const expected = `{ + // listen address + "listen": "127.0.0.1", + "port": 7890 // mixed port +} +` + var out bytes.Buffer + encoder := json.NewEncoder(&out) + encoder.SetIndent("", " ") + if err := encoder.Encode(value); err != nil { + t.Fatal(err) + } + if out.String() != expected { + t.Fatalf("content mismatch\nwant:\n%s\n\ngot:\n%s", expected, out.String()) + } +} + func TestTrailingCommaDoesNotAllowMissingValues(t *testing.T) { t.Parallel() tests := []string{ diff --git a/common/json/internal/contextjson/decode.go b/common/json/internal/contextjson/decode.go index 14429d57c..e1d6c519a 100644 --- a/common/json/internal/contextjson/decode.go +++ b/common/json/internal/contextjson/decode.go @@ -103,17 +103,22 @@ func Unmarshal(data []byte, v any) error { } func UnmarshalContext(ctx context.Context, data []byte, v any) error { + data, comments, err := stripJSONComments(data) + if err != nil { + return err + } // Check for well-formedness. // Avoids filling out half a data structure // before discovering a JSON syntax error. var d decodeState d.ctx = ctx - err := checkValid(data, &d.scan) + err = checkValid(data, &d.scan) if err != nil { return err } d.init(data) + d.comments = comments return d.unmarshal(v) } @@ -225,6 +230,7 @@ type decodeState struct { useNumber bool disallowUnknownFields bool context []decodePathSegment + comments *CommentSet } // readIndex returns the position of the last byte read. @@ -440,7 +446,7 @@ func (d *decodeState) valueQuoted() any { // If it encounters an Unmarshaler, indirect stops and returns that. // If decodingNull is true, indirect stops at the first settable pointer so it // can be set to nil. -func indirect(v reflect.Value, decodingNull bool) (Unmarshaler, ContextUnmarshaler, encoding.TextUnmarshaler, reflect.Value) { +func indirect(v reflect.Value, decodingNull bool) (Unmarshaler, CommentUnmarshaler, ContextUnmarshaler, encoding.TextUnmarshaler, reflect.Value) { // Issue #24153 indicates that it is generally not a guaranteed property // that you may round-trip a reflect.Value by calling Value.Addr().Elem() // and expect the value to still be settable for values derived from @@ -494,14 +500,17 @@ func indirect(v reflect.Value, decodingNull bool) (Unmarshaler, ContextUnmarshal } if v.Type().NumMethod() > 0 && v.CanInterface() { if u, ok := v.Interface().(Unmarshaler); ok { - return u, nil, nil, reflect.Value{} + return u, nil, nil, nil, reflect.Value{} + } + if cu, ok := v.Interface().(CommentUnmarshaler); ok { + return nil, cu, nil, nil, reflect.Value{} } if cu, ok := v.Interface().(ContextUnmarshaler); ok { - return nil, cu, nil, reflect.Value{} + return nil, nil, cu, nil, reflect.Value{} } if !decodingNull { if u, ok := v.Interface().(encoding.TextUnmarshaler); ok { - return nil, nil, u, reflect.Value{} + return nil, nil, nil, u, reflect.Value{} } } } @@ -513,14 +522,14 @@ func indirect(v reflect.Value, decodingNull bool) (Unmarshaler, ContextUnmarshal v = v.Elem() } } - return nil, nil, nil, v + return nil, nil, nil, nil, v } // array consumes an array from d.data[d.off-1:], decoding into v. // The first byte of the array ('[') has been read already. func (d *decodeState) array(v reflect.Value) error { // Check for unmarshaler. - u, cu, ut, pv := indirect(v, false) + u, ccu, cu, ut, pv := indirect(v, false) if u != nil { start := d.readIndex() d.skip() @@ -529,6 +538,15 @@ func (d *decodeState) array(v reflect.Value) error { } return nil } + if ccu != nil { + start := d.readIndex() + d.skip() + ccu.SetComments(d.commentsForCurrentValue()) + if err := ccu.UnmarshalJSONContext(d.ctx, d.data[start:d.off]); err != nil { + d.saveError(err) + } + return nil + } if cu != nil { start := d.readIndex() d.skip() @@ -634,7 +652,7 @@ var textUnmarshalerType = reflect.TypeFor[encoding.TextUnmarshaler]() // The first byte ('{') of the object has been read already. func (d *decodeState) object(v reflect.Value) error { // Check for unmarshaler. - u, cu, ut, pv := indirect(v, false) + u, ccu, cu, ut, pv := indirect(v, false) if u != nil { start := d.readIndex() d.skip() @@ -643,6 +661,15 @@ func (d *decodeState) object(v reflect.Value) error { } return nil } + if ccu != nil { + start := d.readIndex() + d.skip() + ccu.SetComments(d.commentsForCurrentValue()) + if err := ccu.UnmarshalJSONContext(d.ctx, d.data[start:d.off]); err != nil { + d.saveError(err) + } + return nil + } if cu != nil { start := d.readIndex() d.skip() @@ -909,13 +936,20 @@ func (d *decodeState) literalStore(item []byte, v reflect.Value, fromQuoted bool return nil } isNull := item[0] == 'n' // null - u, cu, ut, pv := indirect(v, isNull) + u, ccu, cu, ut, pv := indirect(v, isNull) if u != nil { if err := u.UnmarshalJSON(item); err != nil { d.saveError(err) } return nil } + if ccu != nil { + ccu.SetComments(d.commentsForCurrentValue()) + if err := ccu.UnmarshalJSONContext(d.ctx, item); err != nil { + d.saveError(err) + } + return nil + } if cu != nil { if err := cu.UnmarshalJSONContext(d.ctx, item); err != nil { d.saveError(err) diff --git a/common/json/internal/contextjson/encode.go b/common/json/internal/contextjson/encode.go index 578676c1c..809bc1286 100644 --- a/common/json/internal/contextjson/encode.go +++ b/common/json/internal/contextjson/encode.go @@ -419,6 +419,7 @@ func typeEncoder(t reflect.Type) encoderFunc { var ( marshalerType = reflect.TypeFor[Marshaler]() + commentMarshalerType = reflect.TypeFor[CommentMarshaler]() contextMarshalerType = reflect.TypeFor[ContextMarshaler]() textMarshalerType = reflect.TypeFor[encoding.TextMarshaler]() ) @@ -436,6 +437,12 @@ func newTypeEncoder(t reflect.Type, allowAddr bool) encoderFunc { if t.Implements(marshalerType) { return marshalerEncoder } + if t.Kind() != reflect.Pointer && allowAddr && reflect.PointerTo(t).Implements(commentMarshalerType) { + return newCondAddrEncoder(addrCommentMarshalerEncoder, newTypeEncoder(t, false)) + } + if t.Implements(commentMarshalerType) { + return commentMarshalerEncoder + } if t.Kind() != reflect.Pointer && allowAddr && reflect.PointerTo(t).Implements(contextMarshalerType) { return newCondAddrEncoder(addrContextMarshalerEncoder, newTypeEncoder(t, false)) } @@ -546,6 +553,57 @@ func contextMarshalerEncoder(e *encodeState, v reflect.Value, opts encOpts) { } } +func commentMarshalerEncoder(e *encodeState, v reflect.Value, opts encOpts) { + if v.Kind() == reflect.Pointer && v.IsNil() { + e.WriteString("null") + return + } + m, ok := v.Interface().(CommentMarshaler) + if !ok { + e.WriteString("null") + return + } + b, err := m.MarshalJSONContext(e.ctx) + if err == nil { + var out []byte + out, err = appendCompact(out, b, opts.escapeHTML) + if err == nil { + out, err = insertJSONComments(out, m.Comments()) + } + if err == nil { + e.Grow(len(out)) + e.Buffer.Write(out) + } + } + if err != nil { + e.error(&MarshalerError{v.Type(), err, "MarshalJSONContext"}) + } +} + +func addrCommentMarshalerEncoder(e *encodeState, v reflect.Value, opts encOpts) { + va := v.Addr() + if va.IsNil() { + e.WriteString("null") + return + } + m := va.Interface().(CommentMarshaler) + b, err := m.MarshalJSONContext(e.ctx) + if err == nil { + var out []byte + out, err = appendCompact(out, b, opts.escapeHTML) + if err == nil { + out, err = insertJSONComments(out, m.Comments()) + } + if err == nil { + e.Grow(len(out)) + e.Buffer.Write(out) + } + } + if err != nil { + e.error(&MarshalerError{v.Type(), err, "MarshalJSONContext"}) + } +} + func addrContextMarshalerEncoder(e *encodeState, v reflect.Value, opts encOpts) { va := v.Addr() if va.IsNil() { @@ -933,7 +991,7 @@ func newSliceEncoder(t reflect.Type) encoderFunc { // Byte slices get special treatment; arrays don't. if t.Elem().Kind() == reflect.Uint8 { p := reflect.PointerTo(t.Elem()) - if !p.Implements(marshalerType) && !p.Implements(contextMarshalerType) && !p.Implements(textMarshalerType) { + if !p.Implements(marshalerType) && !p.Implements(commentMarshalerType) && !p.Implements(contextMarshalerType) && !p.Implements(textMarshalerType) { return encodeByteSlice } } diff --git a/common/json/internal/contextjson/indent.go b/common/json/internal/contextjson/indent.go index 01bfdf65e..c6a909867 100644 --- a/common/json/internal/contextjson/indent.go +++ b/common/json/internal/contextjson/indent.go @@ -87,8 +87,74 @@ func appendCompact(dst, src []byte, escape bool) ([]byte, error) { return dst, nil } -func appendNewline(dst []byte, prefix, indent string, depth int) []byte { - dst = append(dst, '\n') +type jsonCommentToken struct { + kind CommentKind + text []byte + end int + closed bool +} + +func readJSONComment(src []byte, start int) (jsonCommentToken, bool) { + if start >= len(src) { + return jsonCommentToken{}, false + } + switch src[start] { + case '#': + textStart := start + 1 + textEnd := textStart + for textEnd < len(src) && src[textEnd] != '\n' && src[textEnd] != '\r' { + textEnd++ + } + end := consumeLineEnd(src, textEnd) + return jsonCommentToken{kind: CommentKindHash, text: src[textStart:textEnd], end: end, closed: true}, true + case '/': + if start+1 >= len(src) { + return jsonCommentToken{}, false + } + switch src[start+1] { + case '/': + textStart := start + 2 + textEnd := textStart + for textEnd < len(src) && src[textEnd] != '\n' && src[textEnd] != '\r' { + textEnd++ + } + end := consumeLineEnd(src, textEnd) + return jsonCommentToken{kind: CommentKindLine, text: src[textStart:textEnd], end: end, closed: true}, true + case '*': + textStart := start + 2 + for i := textStart; i+1 < len(src); i++ { + if src[i] == '*' && src[i+1] == '/' { + return jsonCommentToken{kind: CommentKindBlock, text: src[textStart:i], end: i + 2, closed: true}, true + } + } + return jsonCommentToken{kind: CommentKindBlock, text: src[textStart:], end: len(src)}, true + } + } + return jsonCommentToken{}, false +} + +func consumeLineEnd(src []byte, i int) int { + if i >= len(src) { + return i + } + if src[i] == '\r' { + i++ + if i < len(src) && src[i] == '\n' { + i++ + } + return i + } + if src[i] == '\n' { + return i + 1 + } + return i +} + +func appendLineBreak(dst []byte) []byte { + return append(dst, '\n') +} + +func appendIndentPrefix(dst []byte, prefix, indent string, depth int) []byte { dst = append(dst, prefix...) for i := 0; i < depth; i++ { dst = append(dst, indent...) @@ -96,6 +162,187 @@ func appendNewline(dst []byte, prefix, indent string, depth int) []byte { return dst } +func nextCommentInline(src []byte, i int) bool { + for i < len(src) { + switch src[i] { + case ' ', '\t': + i++ + continue + case '\n', '\r': + return false + } + break + } + comment, ok := readJSONComment(src, i) + if !ok || !comment.closed { + return false + } + return comment.kind != CommentKindBlock || !containsLineBreak(comment.text) +} + +func containsLineBreak(value []byte) bool { + for _, c := range value { + if c == '\n' || c == '\r' { + return true + } + } + return false +} + +func appendJSONComment(dst []byte, comment jsonCommentToken, prefix, indent string, depth int, atLineStart, lineHasContent bool) ([]byte, bool, bool) { + if comment.kind == CommentKindBlock { + return appendBlockJSONComment(dst, comment.text, prefix, indent, depth, atLineStart, lineHasContent) + } + if lineHasContent { + dst = appendSpaceBeforeInlineComment(dst) + } else if atLineStart { + dst = appendIndentPrefix(dst, prefix, indent, depth) + atLineStart = false + } + dst = appendCommentMarker(dst, comment.kind) + dst = append(dst, comment.text...) + dst = appendLineBreak(dst) + return dst, true, false +} + +func appendBlockJSONComment(dst []byte, text []byte, prefix, indent string, depth int, atLineStart, lineHasContent bool) ([]byte, bool, bool) { + if !containsLineBreak(text) { + leading := !lineHasContent + inlineAfterComma := lineHasContent && lastNonSpaceByte(dst) == ',' + if lineHasContent { + dst = appendSpaceBeforeInlineComment(dst) + } else if atLineStart { + dst = appendIndentPrefix(dst, prefix, indent, depth) + atLineStart = false + } + dst = append(dst, '/', '*') + dst = append(dst, text...) + dst = append(dst, '*', '/') + if leading || inlineAfterComma { + dst = appendLineBreak(dst) + return dst, true, false + } + return dst, false, true + } + if lineHasContent { + dst = appendLineBreak(dst) + atLineStart = true + lineHasContent = false + } + if atLineStart { + dst = appendIndentPrefix(dst, prefix, indent, depth) + } + dst = appendMultilineBlockComment(dst, text, prefix, indent, depth) + dst = appendLineBreak(dst) + return dst, true, false +} + +func appendSpaceBeforeInlineComment(dst []byte) []byte { + if len(dst) == 0 { + return dst + } + switch dst[len(dst)-1] { + case ' ', '\t', '\n', '\r': + return dst + default: + return append(dst, ' ') + } +} + +func lastNonSpaceByte(value []byte) byte { + for i := len(value) - 1; i >= 0; i-- { + switch value[i] { + case ' ', '\t', '\n', '\r': + continue + default: + return value[i] + } + } + return 0 +} + +func appendCommentMarker(dst []byte, kind CommentKind) []byte { + switch kind { + case CommentKindHash: + return append(dst, '#') + case CommentKindBlock: + return append(dst, '/', '*') + default: + return append(dst, '/', '/') + } +} + +func appendMultilineBlockComment(dst []byte, text []byte, prefix, indent string, depth int) []byte { + lines := splitCommentLines(text) + trim := blockCommentTrim(lines) + dst = append(dst, '/', '*') + for i, line := range lines { + if i > 0 { + dst = appendLineBreak(dst) + dst = appendIndentPrefix(dst, prefix, indent, depth) + line = trimBlockCommentLine(line, trim) + } + dst = append(dst, line...) + } + dst = append(dst, '*', '/') + return dst +} + +func splitCommentLines(text []byte) [][]byte { + lines := make([][]byte, 0, bytes.Count(text, []byte{'\n'})+1) + start := 0 + for i := 0; i < len(text); i++ { + if text[i] != '\n' && text[i] != '\r' { + continue + } + lines = append(lines, text[start:i]) + if text[i] == '\r' && i+1 < len(text) && text[i+1] == '\n' { + i++ + } + start = i + 1 + } + lines = append(lines, text[start:]) + return lines +} + +func blockCommentTrim(lines [][]byte) int { + minIndent := -1 + for _, line := range lines[1:] { + if len(line) == 0 { + continue + } + indent := leadingCommentIndent(line) + if indent == len(line) { + continue + } + if minIndent < 0 || indent < minIndent { + minIndent = indent + } + } + if minIndent <= 1 { + return 0 + } + return minIndent - 1 +} + +func leadingCommentIndent(line []byte) int { + i := 0 + for i < len(line) && (line[i] == ' ' || line[i] == '\t') { + i++ + } + return i +} + +func trimBlockCommentLine(line []byte, trim int) []byte { + if trim == 0 { + return line + } + if indent := leadingCommentIndent(line); indent < trim { + return line[indent:] + } + return line[trim:] +} + // indentGrowthFactor specifies the growth factor of indenting JSON input. // Empirically, the growth factor was measured to be between 1.4x to 1.8x // for some set of compacted JSON with the indent being a single tab. @@ -129,10 +376,35 @@ func appendIndent(dst, src []byte, prefix, indent string) ([]byte, error) { defer freeScanner(scan) needIndent := false depth := 0 - for _, c := range src { + atLineStart := true + lineHasContent := false +Input: + for i := 0; i < len(src); i++ { + c := src[i] scan.bytes++ v := scan.step(scan, c) if v == scanSkipSpace { + if comment, ok := readJSONComment(src, i); ok { + for j := i + 1; j < comment.end; j++ { + scan.bytes++ + if scan.step(scan, src[j]) == scanError { + break Input + } + } + if !comment.closed { + i = comment.end - 1 + continue + } + if needIndent { + dst = appendLineBreak(dst) + atLineStart = true + lineHasContent = false + needIndent = false + } + dst, atLineStart, lineHasContent = appendJSONComment(dst, comment, prefix, indent, depth, atLineStart, lineHasContent) + i = comment.end - 1 + continue + } continue } if v == scanError { @@ -140,39 +412,70 @@ func appendIndent(dst, src []byte, prefix, indent string) ([]byte, error) { } if needIndent && v != scanEndObject && v != scanEndArray { needIndent = false - depth++ - dst = appendNewline(dst, prefix, indent, depth) + dst = appendLineBreak(dst) + atLineStart = true + lineHasContent = false } // Emit semantically uninteresting bytes // (in particular, punctuation in strings) unmodified. if v == scanContinue { + if atLineStart { + dst = appendIndentPrefix(dst, prefix, indent, depth) + atLineStart = false + } dst = append(dst, c) + lineHasContent = true continue } // Add spacing around real punctuation. switch c { case '{', '[': + if atLineStart { + dst = appendIndentPrefix(dst, prefix, indent, depth) + atLineStart = false + } // delay indent so that empty object and array are formatted as {} and []. needIndent = true dst = append(dst, c) + lineHasContent = true + depth++ case ',': dst = append(dst, c) - dst = appendNewline(dst, prefix, indent, depth) + lineHasContent = true + if !nextCommentInline(src, i+1) { + dst = appendLineBreak(dst) + atLineStart = true + lineHasContent = false + } case ':': dst = append(dst, c, ' ') + lineHasContent = true + atLineStart = false case '}', ']': + depth-- if needIndent { // suppress indent in empty object/array needIndent = false - } else { - depth-- - dst = appendNewline(dst, prefix, indent, depth) + } else if lineHasContent { + dst = appendLineBreak(dst) + atLineStart = true + lineHasContent = false + } + if atLineStart { + dst = appendIndentPrefix(dst, prefix, indent, depth) + atLineStart = false } dst = append(dst, c) + lineHasContent = true default: + if atLineStart { + dst = appendIndentPrefix(dst, prefix, indent, depth) + atLineStart = false + } dst = append(dst, c) + lineHasContent = true } } if scan.eof() == scanError { diff --git a/common/json/internal/contextjson/scanner.go b/common/json/internal/contextjson/scanner.go index 2b2600a3a..be451c013 100644 --- a/common/json/internal/contextjson/scanner.go +++ b/common/json/internal/contextjson/scanner.go @@ -82,6 +82,9 @@ type scanner struct { // total bytes consumed, updated by decoder.Decode (and deliberately // not set to zero by scan.reset) bytes int64 + + commentStep func(*scanner, byte) int + commentMode int } var scannerPool = sync.Pool{ @@ -152,6 +155,8 @@ func (s *scanner) reset() { s.parseState = s.parseState[0:0] s.err = nil s.endTop = false + s.commentStep = nil + s.commentMode = 0 } // eof tells the scanner that the end of input has been reached. @@ -160,6 +165,17 @@ func (s *scanner) eof() int { if s.err != nil { return scanError } + switch s.commentMode { + case scannerLineComment: + s.step = s.commentStep + s.commentStep = nil + s.commentMode = 0 + case scannerBlockComment: + if s.err == nil { + s.err = &SyntaxError{"unexpected end of JSON comment", s.bytes} + } + return scanError + } if s.endTop { return scanEnd } @@ -200,11 +216,79 @@ func isSpace(c byte) bool { return c <= ' ' && (c == ' ' || c == '\t' || c == '\r' || c == '\n') } +const ( + scannerNoComment = iota + scannerLineComment + scannerBlockComment +) + +func stateBeginComment(s *scanner, c byte, next func(*scanner, byte) int) int { + switch c { + case '#': + s.commentStep = next + s.commentMode = scannerLineComment + s.step = stateInLineComment + return scanSkipSpace + case '/': + s.commentStep = next + s.step = stateMaybeComment + return scanSkipSpace + } + return -1 +} + +func stateMaybeComment(s *scanner, c byte) int { + switch c { + case '/': + s.commentMode = scannerLineComment + s.step = stateInLineComment + return scanSkipSpace + case '*': + s.commentMode = scannerBlockComment + s.step = stateInBlockComment + return scanSkipSpace + default: + return s.error('/', "looking for beginning of comment") + } +} + +func stateInLineComment(s *scanner, c byte) int { + if c == '\n' || c == '\r' { + s.step = s.commentStep + s.commentStep = nil + s.commentMode = scannerNoComment + } + return scanSkipSpace +} + +func stateInBlockComment(s *scanner, c byte) int { + if c == '*' { + s.step = stateInBlockCommentStar + } + return scanSkipSpace +} + +func stateInBlockCommentStar(s *scanner, c byte) int { + switch c { + case '/': + s.step = s.commentStep + s.commentStep = nil + s.commentMode = scannerNoComment + case '*': + default: + s.step = stateInBlockComment + } + return scanSkipSpace +} + // stateBeginValueOrEmpty is the state after reading `[`. func stateBeginValueOrEmpty(s *scanner, c byte) int { if isSpace(c) { return scanSkipSpace } + if op := stateBeginComment(s, c, stateBeginValueOrEmpty); op >= 0 { + return op + } if c == ']' { return stateEndValue(s, c) } @@ -216,6 +300,9 @@ func stateBeginValue(s *scanner, c byte) int { if isSpace(c) { return scanSkipSpace } + if op := stateBeginComment(s, c, stateBeginValue); op >= 0 { + return op + } switch c { case '{': s.step = stateBeginStringOrEmpty @@ -254,6 +341,9 @@ func stateBeginStringOrEmpty(s *scanner, c byte) int { if isSpace(c) { return scanSkipSpace } + if op := stateBeginComment(s, c, stateBeginStringOrEmpty); op >= 0 { + return op + } if c == '}' { n := len(s.parseState) s.parseState[n-1] = parseObjectValue @@ -267,6 +357,9 @@ func stateBeginString(s *scanner, c byte) int { if isSpace(c) { return scanSkipSpace } + if op := stateBeginComment(s, c, stateBeginString); op >= 0 { + return op + } if c == '"' { s.step = stateInString return scanBeginLiteral @@ -288,6 +381,9 @@ func stateEndValue(s *scanner, c byte) int { s.step = stateEndValue return scanSkipSpace } + if op := stateBeginComment(s, c, stateEndValue); op >= 0 { + return op + } ps := s.parseState[n-1] switch ps { case parseObjectKey: @@ -326,6 +422,12 @@ func stateEndValue(s *scanner, c byte) int { // such as after reading `{}` or `[1,2,3]`. // Only space characters should be seen now. func stateEndTop(s *scanner, c byte) int { + if isSpace(c) { + return scanEnd + } + if op := stateBeginComment(s, c, stateEndTop); op >= 0 { + return op + } if !isSpace(c) { // Complain about non-space byte on next call. s.error(c, "after top-level value") diff --git a/common/json/internal/contextjson/stream.go b/common/json/internal/contextjson/stream.go index 59a299f69..793b31c5f 100644 --- a/common/json/internal/contextjson/stream.go +++ b/common/json/internal/contextjson/stream.go @@ -69,7 +69,12 @@ func (dec *Decoder) Decode(v any) error { if err != nil { return err } - dec.d.init(dec.buf[dec.scanp : dec.scanp+n]) + data, comments, err := stripJSONComments(dec.buf[dec.scanp : dec.scanp+n]) + if err != nil { + return err + } + dec.d.init(data) + dec.d.comments = comments dec.scanp += n // Don't save err from unmarshal into dec.err: @@ -516,11 +521,42 @@ func (dec *Decoder) More() bool { func (dec *Decoder) peek() (byte, error) { var err error for { - for i := dec.scanp; i < len(dec.buf); i++ { + Buffer: + for i := dec.scanp; i < len(dec.buf); { c := dec.buf[i] if isSpace(c) { + i++ continue } + if c == '#' { + next, ok := dec.skipLineCommentInBuffer(i + 1) + if !ok { + break + } + i = next + continue + } + if c == '/' { + if i+1 >= len(dec.buf) { + break Buffer + } + switch dec.buf[i+1] { + case '/': + next, ok := dec.skipLineCommentInBuffer(i + 2) + if !ok { + break Buffer + } + i = next + continue + case '*': + next, ok := dec.skipBlockCommentInBuffer(i + 2) + if !ok { + break Buffer + } + i = next + continue + } + } dec.scanp = i return c, nil } @@ -551,6 +587,26 @@ func (dec *Decoder) peekNoRefill() (byte, error) { } } +func (dec *Decoder) skipLineCommentInBuffer(i int) (int, bool) { + for i < len(dec.buf) { + if dec.buf[i] == '\n' || dec.buf[i] == '\r' { + return i + 1, true + } + i++ + } + return i, false +} + +func (dec *Decoder) skipBlockCommentInBuffer(i int) (int, bool) { + for i+1 < len(dec.buf) { + if dec.buf[i] == '*' && dec.buf[i+1] == '/' { + return i + 2, true + } + i++ + } + return i, false +} + // InputOffset returns the input stream byte offset of the current decoder position. // The offset gives the location of the end of the most recently returned token // and the beginning of the next token. diff --git a/common/json/internal/contextjson/unmarshal.go b/common/json/internal/contextjson/unmarshal.go index 04c13cbe9..dff33a9fa 100644 --- a/common/json/internal/contextjson/unmarshal.go +++ b/common/json/internal/contextjson/unmarshal.go @@ -3,24 +3,34 @@ package json import "context" func UnmarshalDisallowUnknownFields(data []byte, v any) error { + data, comments, err := stripJSONComments(data) + if err != nil { + return err + } var d decodeState d.disallowUnknownFields = true - err := checkValid(data, &d.scan) + err = checkValid(data, &d.scan) if err != nil { return err } d.init(data) + d.comments = comments return d.unmarshal(v) } func UnmarshalContextDisallowUnknownFields(ctx context.Context, data []byte, v any) error { + data, comments, err := stripJSONComments(data) + if err != nil { + return err + } var d decodeState d.ctx = ctx d.disallowUnknownFields = true - err := checkValid(data, &d.scan) + err = checkValid(data, &d.scan) if err != nil { return err } d.init(data) + d.comments = comments return d.unmarshal(v) } diff --git a/common/json/std.go b/common/json/std.go deleted file mode 100644 index adedde4d8..000000000 --- a/common/json/std.go +++ /dev/null @@ -1,21 +0,0 @@ -//go:build without_contextjson - -package json - -import "encoding/json" - -var ( - Marshal = json.Marshal - Unmarshal = json.Unmarshal - NewEncoder = json.NewEncoder - NewDecoder = json.NewDecoder -) - -type ( - Encoder = json.Encoder - Decoder = json.Decoder - Token = json.Token - Delim = json.Delim - SyntaxError = json.SyntaxError - RawMessage = json.RawMessage -) diff --git a/common/json/unmarshal.go b/common/json/unmarshal.go index 94a2d7649..d1d0837d5 100644 --- a/common/json/unmarshal.go +++ b/common/json/unmarshal.go @@ -1,7 +1,6 @@ package json import ( - "bytes" "context" "errors" "strings" @@ -15,9 +14,8 @@ func UnmarshalExtended[T any](content []byte) (T, error) { } func UnmarshalExtendedContext[T any](ctx context.Context, content []byte) (T, error) { - decoder := NewDecoderContext(ctx, NewCommentFilter(bytes.NewReader(content))) var value T - err := decoder.Decode(&value) + err := UnmarshalContext(ctx, content, &value) if err == nil { return value, err } diff --git a/common/json/unmarshal_context.go b/common/json/unmarshal_context.go index 3b9b7f774..821d4684a 100644 --- a/common/json/unmarshal_context.go +++ b/common/json/unmarshal_context.go @@ -1,5 +1,3 @@ -//go:build !without_contextjson - package json import ( diff --git a/common/json/unmarshal_std.go b/common/json/unmarshal_std.go deleted file mode 100644 index bcbae4e4a..000000000 --- a/common/json/unmarshal_std.go +++ /dev/null @@ -1,13 +0,0 @@ -//go:build without_contextjson - -package json - -import ( - "bytes" -) - -func UnmarshalDisallowUnknownFields(content []byte, value any) error { - decoder := NewDecoder(bytes.NewReader(content)) - decoder.DisallowUnknownFields() - return decoder.Decode(value) -} From e8427b7b7ce079213bcb3bafb47d5f769ef41546 Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?=E4=B8=96=E7=95=8C?= Date: Tue, 28 Apr 2026 07:09:26 +0800 Subject: [PATCH 5/7] Add package cleanup --- common/cleanup/cleanup.go | 94 +++++++++++++++++++ common/cleanup/cleanup_test.go | 166 +++++++++++++++++++++++++++++++++ 2 files changed, 260 insertions(+) create mode 100644 common/cleanup/cleanup.go create mode 100644 common/cleanup/cleanup_test.go diff --git a/common/cleanup/cleanup.go b/common/cleanup/cleanup.go new file mode 100644 index 000000000..294efab6f --- /dev/null +++ b/common/cleanup/cleanup.go @@ -0,0 +1,94 @@ +package cleanup + +import ( + "runtime" + "sync" + _ "unsafe" +) + +func init() { + registerPoolCleanup(myCleanup) +} + +//go:linkname registerPoolCleanup sync.runtime_registerPoolCleanup +func registerPoolCleanup(cleanup func()) + +//go:linkname poolCleanup sync.poolCleanup +func poolCleanup() + +var unsafeCleanupFuncs []func() + +func myCleanup() { + poolCleanup() + for _, cleanupFunc := range unsafeCleanupFuncs { + cleanupFunc() + } +} + +// AddUnsafe must be called only in init {} +// called in STW, must not allocate or hold a lock +func AddUnsafe(cleanup func()) { + unsafeCleanupFuncs = append(unsafeCleanupFuncs, cleanup) +} + +type Cleaner struct { + access sync.Mutex + state *cleanupState +} + +func Add(cleanup func()) *Cleaner { + state := &cleanupState{ + cleanupFunc: cleanup, + } + newObject(state) + return &Cleaner{ + state: state, + } +} + +func (c *Cleaner) Close() { + c.access.Lock() + defer c.access.Unlock() + if c.state == nil { + return + } + c.state.close() + c.state = nil +} + +type cleanupState struct { + access sync.Mutex + cleanupFunc func() + closed bool +} + +func (s *cleanupState) close() { + s.access.Lock() + defer s.access.Unlock() + s.closed = true +} + +type Object struct { + state *cleanupState +} + +func newObject(state *cleanupState) { + object := &Object{ + state: state, + } + runtime.SetFinalizer(object, (*Object).cleanup) +} + +func (o *Object) cleanup() { + state := o.state + state.access.Lock() + if state.closed { + state.access.Unlock() + return + } + + newObject(state) + cleanupFunc := state.cleanupFunc + state.access.Unlock() + cleanupFunc() +} diff --git a/common/cleanup/cleanup_test.go b/common/cleanup/cleanup_test.go new file mode 100644 index 000000000..38111da6d --- /dev/null +++ b/common/cleanup/cleanup_test.go @@ -0,0 +1,166 @@ +//nolint:paralleltest +package cleanup + +import ( + "runtime" + "runtime/debug" + "sync" + "sync/atomic" + "testing" + "time" + + "github.com/stretchr/testify/require" +) + +func TestCleanup(t *testing.T) { + var didCleanup atomic.Int32 + Add(func() { + didCleanup.Add(1) + }) + runCleanup(t, func() bool { + return didCleanup.Load() >= 1 + }) + runCleanup(t, func() bool { + return didCleanup.Load() >= 2 + }) +} + +func TestCleanupUnsafe(t *testing.T) { + var didCleanup atomic.Int32 + AddUnsafe(func() { + didCleanup.Add(1) + }) + runCleanup(t, func() bool { + return didCleanup.Load() >= 1 + }) + runCleanup(t, func() bool { + return didCleanup.Load() >= 2 + }) +} + +func TestCleanup1(t *testing.T) { + var didCleanup atomic.Bool + var didReset atomic.Bool + m := newMap(func() { + didCleanup.Store(true) + }) + Add(func() { + *m = sync.Map{} + didReset.Store(true) + }) + runtime.KeepAlive(&m) + runCleanup(t, func() bool { + return didReset.Load() + }) + require.False(t, didCleanup.Load()) + runCleanup(t, func() bool { + return didCleanup.Load() + }) +} + +func TestCleanup1Unsafe(t *testing.T) { + var didCleanup atomic.Bool + m := newMap(func() { + didCleanup.Store(true) + }) + AddUnsafe(func() { + *m = sync.Map{} + }) + runtime.KeepAlive(&m) + runCleanup(t, func() bool { + return didCleanup.Load() + }) +} + +type myObj struct { + _ string +} + +func newMap(cleanup func()) *sync.Map { + obj := &myObj{} + runtime.SetFinalizer(obj, func(*myObj) { + cleanup() + }) + var m sync.Map + m.Store("test", obj) + return &m +} + +func TestSafeCleanup(t *testing.T) { + var didCleanup atomic.Int32 + safeObject := Add(func() { + didCleanup.Add(1) + }) + runCleanup(t, func() bool { + return didCleanup.Load() >= 1 + }) + safeObject.Close() + didCleanupAfterClose := didCleanup.Load() + debug.FreeOSMemory() + require.Never(t, func() bool { + return didCleanup.Load() != didCleanupAfterClose + }, cleanupWait, cleanupTick) +} + +func TestCloseSkipsQueuedCleanup(t *testing.T) { + blockCleanup := make(chan struct{}) + cleanupBlocked := make(chan struct{}) + var blockOnce sync.Once + blocker := Add(func() { + blockOnce.Do(func() { + close(cleanupBlocked) + <-blockCleanup + }) + }) + runCleanup(t, func() bool { + select { + case <-cleanupBlocked: + return true + default: + return false + } + }) + + var didCleanup atomic.Bool + queued := Add(func() { + didCleanup.Store(true) + }) + debug.FreeOSMemory() + queued.Close() + + close(blockCleanup) + blocker.Close() + require.Never(t, didCleanup.Load, cleanupWait, cleanupTick) +} + +func TestCleanupCanCloseItself(t *testing.T) { + ready := make(chan struct{}) + done := make(chan struct{}) + var cleaner *Cleaner + cleaner = Add(func() { + <-ready + cleaner.Close() + close(done) + }) + close(ready) + + runCleanup(t, func() bool { + select { + case <-done: + return true + default: + return false + } + }) +} + +const ( + cleanupWait = time.Second + cleanupTick = time.Millisecond +) + +func runCleanup(t *testing.T, condition func() bool) { + t.Helper() + debug.FreeOSMemory() + require.Eventually(t, condition, cleanupWait, cleanupTick) +} From 2bc976d03e39099a21823978288f161c8af4f1f6 Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?=E4=B8=96=E7=95=8C?= Date: Tue, 28 Apr 2026 16:41:02 +0800 Subject: [PATCH 6/7] contextjson: Fix crash --- .../json/internal/contextjson/context_test.go | 96 +++++++++++++++++++ common/json/internal/contextjson/stream.go | 62 +++++------- 2 files changed, 121 insertions(+), 37 deletions(-) diff --git a/common/json/internal/contextjson/context_test.go b/common/json/internal/contextjson/context_test.go index 81f7af5c5..dc6e83b19 100644 --- a/common/json/internal/contextjson/context_test.go +++ b/common/json/internal/contextjson/context_test.go @@ -292,6 +292,102 @@ func TestTrailingCommaToken(t *testing.T) { } } +func TestTokenMoreAfterCommaAcrossRefill(t *testing.T) { + t.Parallel() + // Positions the comma at decoder buffer offset 500 and the next value + // beyond the initial 512-byte read. + value := strings.Repeat("a", 497) + decoder := json.NewDecoder(strings.NewReader(`["` + value + `",` + strings.Repeat(" ", 11) + `2]`)) + token, err := decoder.Token() + if err != nil { + t.Fatal(err) + } + if token != json.Delim('[') { + t.Fatalf("start token = %v", token) + } + if !decoder.More() { + t.Fatal("expected first array value") + } + token, err = decoder.Token() + if err != nil { + t.Fatal(err) + } + if token != value { + t.Fatalf("first value = %#v", token) + } + if !decoder.More() { + t.Fatal("expected second array value") + } + token, err = decoder.Token() + if err != nil { + t.Fatal(err) + } + if token != float64(2) { + t.Fatalf("second value = %#v", token) + } + if decoder.More() { + t.Fatal("expected end of array") + } + token, err = decoder.Token() + if err != nil { + t.Fatal(err) + } + if token != json.Delim(']') { + t.Fatalf("end token = %v", token) + } +} + +func TestTrailingCommaTokenAcrossRefill(t *testing.T) { + t.Parallel() + value := strings.Repeat("a", 497) + tests := []struct { + name string + suffix string + }{ + { + name: "spaces", + suffix: strings.Repeat(" ", 11) + `]`, + }, + { + name: "line comment", + suffix: strings.Repeat(" ", 11) + "// trailing\n]", + }, + } + for _, test := range tests { + t.Run(test.name, func(t *testing.T) { + t.Parallel() + decoder := json.NewDecoder(strings.NewReader(`["` + value + `",` + test.suffix)) + token, err := decoder.Token() + if err != nil { + t.Fatal(err) + } + if token != json.Delim('[') { + t.Fatalf("start token = %v", token) + } + if !decoder.More() { + t.Fatal("expected array value") + } + token, err = decoder.Token() + if err != nil { + t.Fatal(err) + } + if token != value { + t.Fatalf("value = %#v", token) + } + if decoder.More() { + t.Fatal("expected trailing comma to end array") + } + token, err = decoder.Token() + if err != nil { + t.Fatal(err) + } + if token != json.Delim(']') { + t.Fatalf("end token = %v", token) + } + }) + } +} + func TestTrailingComma(t *testing.T) { t.Parallel() var array []int diff --git a/common/json/internal/contextjson/stream.go b/common/json/internal/contextjson/stream.go index 793b31c5f..bd6da2cf6 100644 --- a/common/json/internal/contextjson/stream.go +++ b/common/json/internal/contextjson/stream.go @@ -504,10 +504,7 @@ func (dec *Decoder) More() bool { return false } if c == ',' { - scanp := dec.scanp - dec.scanp++ - c, err = dec.peekNoRefill() - dec.scanp = scanp + c, _, err = dec.peekFrom(dec.scanp + 1) if err != nil { return false } @@ -519,71 +516,62 @@ func (dec *Decoder) More() bool { } func (dec *Decoder) peek() (byte, error) { + c, scanp, err := dec.peekFrom(dec.scanp) + if err != nil { + return 0, err + } + dec.scanp = scanp + return c, nil +} + +func (dec *Decoder) peekFrom(scanp int) (byte, int, error) { var err error for { Buffer: - for i := dec.scanp; i < len(dec.buf); { - c := dec.buf[i] + for scanp < len(dec.buf) { + c := dec.buf[scanp] if isSpace(c) { - i++ + scanp++ continue } if c == '#' { - next, ok := dec.skipLineCommentInBuffer(i + 1) + next, ok := dec.skipLineCommentInBuffer(scanp + 1) if !ok { break } - i = next + scanp = next continue } if c == '/' { - if i+1 >= len(dec.buf) { + if scanp+1 >= len(dec.buf) { break Buffer } - switch dec.buf[i+1] { + switch dec.buf[scanp+1] { case '/': - next, ok := dec.skipLineCommentInBuffer(i + 2) + next, ok := dec.skipLineCommentInBuffer(scanp + 2) if !ok { break Buffer } - i = next + scanp = next continue case '*': - next, ok := dec.skipBlockCommentInBuffer(i + 2) + next, ok := dec.skipBlockCommentInBuffer(scanp + 2) if !ok { break Buffer } - i = next + scanp = next continue } } - dec.scanp = i - return c, nil + return c, scanp, nil } // buffer has been scanned, now report any error if err != nil { - return 0, err - } - err = dec.refill() - } -} - -func (dec *Decoder) peekNoRefill() (byte, error) { - var err error - for { - for i := dec.scanp; i < len(dec.buf); i++ { - c := dec.buf[i] - if isSpace(c) { - continue - } - dec.scanp = i - return c, nil - } - // buffer has been scanned, now report any error - if err != nil { - return 0, err + return 0, 0, err } + oldScanp := dec.scanp err = dec.refill() + scanp -= oldScanp } } From bea407e4972bdf72d08ce97c93989d870cf38ec0 Mon Sep 17 00:00:00 2001 From: Ken Date: Wed, 13 May 2026 15:20:10 +0800 Subject: [PATCH 7/7] http: don't keep alive close-delimited proxy responses --- protocol/http/handshake.go | 33 ++++++++++++++- protocol/http/handshake_test.go | 71 +++++++++++++++++++++++++++++++++ 2 files changed, 102 insertions(+), 2 deletions(-) create mode 100644 protocol/http/handshake_test.go diff --git a/protocol/http/handshake.go b/protocol/http/handshake.go index 4320061a4..113bb74b6 100644 --- a/protocol/http/handshake.go +++ b/protocol/http/handshake.go @@ -147,7 +147,7 @@ func handleHTTPConnection( conn net.Conn, request *http.Request, source M.Socksaddr, ) error { - keepAlive := !(request.ProtoMajor == 1 && request.ProtoMinor == 0) && strings.TrimSpace(strings.ToLower(request.Header.Get("Proxy-Connection"))) == "keep-alive" + keepAlive := requestWantsHTTPProxyKeepAlive(request) request.RequestURI = "" removeHopByHopHeaders(request.Header) @@ -189,8 +189,9 @@ func handleHTTPConnection( return E.Errors(innerErr.Load(), err, responseWith(request, http.StatusBadGateway).Write(conn)) } - removeHopByHopHeaders(response.Header) + keepAlive = keepAlive && canKeepAliveHTTPProxyResponse(request, response) + removeHopByHopHeaders(response.Header) if keepAlive { response.Header.Set("Proxy-Connection", "keep-alive") response.Header.Set("Connection", "keep-alive") @@ -212,6 +213,34 @@ func handleHTTPConnection( return nil } +func requestWantsHTTPProxyKeepAlive(request *http.Request) bool { + return !(request.ProtoMajor == 1 && request.ProtoMinor == 0) && + strings.TrimSpace(strings.ToLower(request.Header.Get("Proxy-Connection"))) == "keep-alive" +} + +func canKeepAliveHTTPProxyResponse(request *http.Request, response *http.Response) bool { + if response.Close { + return false + } + if responseHasCloseDelimitedBody(request, response) { + return false + } + return true +} + +func responseHasCloseDelimitedBody(request *http.Request, response *http.Response) bool { + if request.Method == http.MethodHead { + return false + } + if response.StatusCode >= 100 && response.StatusCode <= 199 { + return false + } + if response.StatusCode == http.StatusNoContent || response.StatusCode == http.StatusNotModified { + return false + } + return response.ContentLength < 0 && len(response.TransferEncoding) == 0 +} + func removeHopByHopHeaders(header http.Header) { // Strip hop-by-hop header based on RFC: // http://www.w3.org/Protocols/rfc2616/rfc2616-sec13.html#sec13.5.1 diff --git a/protocol/http/handshake_test.go b/protocol/http/handshake_test.go new file mode 100644 index 000000000..9c03572d2 --- /dev/null +++ b/protocol/http/handshake_test.go @@ -0,0 +1,71 @@ +package http + +import ( + std_http "net/http" + "testing" +) + +func TestCanKeepAliveHTTPProxyResponseRejectsCloseDelimitedHTTP10Body(t *testing.T) { + request := &std_http.Request{ + Method: std_http.MethodGet, + ProtoMajor: 1, + ProtoMinor: 1, + Header: std_http.Header{"Proxy-Connection": []string{"keep-alive"}}, + } + + response := &std_http.Response{ + StatusCode: std_http.StatusUnauthorized, + ProtoMajor: 1, + ProtoMinor: 0, + Header: std_http.Header{"WWW-Authenticate": []string{`Basic realm="K2P"`}}, + ContentLength: -1, + Close: true, + } + + if !requestWantsHTTPProxyKeepAlive(request) { + t.Fatal("request should ask for HTTP proxy keep-alive") + } + if canKeepAliveHTTPProxyResponse(request, response) { + t.Fatal("HTTP proxy must not keep alive a close-delimited HTTP/1.0 response") + } +} + +func TestCanKeepAliveHTTPProxyResponseAllowsLengthDelimitedBody(t *testing.T) { + request := &std_http.Request{ + Method: std_http.MethodGet, + ProtoMajor: 1, + ProtoMinor: 1, + Header: std_http.Header{"Proxy-Connection": []string{"keep-alive"}}, + } + + response := &std_http.Response{ + StatusCode: std_http.StatusOK, + ProtoMajor: 1, + ProtoMinor: 1, + ContentLength: 2, + Close: false, + } + + if !canKeepAliveHTTPProxyResponse(request, response) { + t.Fatal("HTTP proxy should keep alive a length-delimited reusable response") + } +} + +func TestResponseHasCloseDelimitedBodyIgnoresNoBodyResponses(t *testing.T) { + request := &std_http.Request{Method: std_http.MethodGet} + + for _, statusCode := range []int{ + std_http.StatusContinue, + std_http.StatusSwitchingProtocols, + std_http.StatusNoContent, + std_http.StatusNotModified, + } { + response := &std_http.Response{ + StatusCode: statusCode, + ContentLength: -1, + } + if responseHasCloseDelimitedBody(request, response) { + t.Fatalf("status %d must not be treated as close-delimited body", statusCode) + } + } +}