Skip to content
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
26 changes: 24 additions & 2 deletions common/bufio/bind.go
Original file line number Diff line number Diff line change
Expand Up @@ -42,6 +42,14 @@ func (c *bindPacketConn) CreateReadWaiter() (N.ReadWaiter, bool) {
return &bindPacketReadWaiter{readWaiter}, true
}

func (c *bindPacketConn) CreatePacketBatchReadWaiter() (N.PacketBatchReadWaiter, bool) {
return CreatePacketBatchReadWaiter(c.NetPacketConn)
}

func (c *bindPacketConn) CreatePacketBatchWriter() (N.PacketBatchWriter, bool) {
return CreatePacketBatchWriter(c.NetPacketConn)
}

func (c *bindPacketConn) RemoteAddr() net.Addr {
return c.addr
}
Expand All @@ -51,8 +59,10 @@ func (c *bindPacketConn) Upstream() any {
}

var (
_ N.NetPacketConn = (*UnbindPacketConn)(nil)
_ N.PacketReadWaitCreator = (*UnbindPacketConn)(nil)
_ N.NetPacketConn = (*UnbindPacketConn)(nil)
_ N.PacketReadWaitCreator = (*UnbindPacketConn)(nil)
_ N.ConnectedPacketBatchReadWaitCreator = (*UnbindPacketConn)(nil)
_ N.ConnectedPacketBatchWriteCreator = (*UnbindPacketConn)(nil)
)

type UnbindPacketConn struct {
Expand Down Expand Up @@ -107,6 +117,14 @@ func (c *UnbindPacketConn) CreateReadWaiter() (N.PacketReadWaiter, bool) {
return &unbindPacketReadWaiter{readWaiter, c.addr}, true
}

func (c *UnbindPacketConn) CreateConnectedPacketBatchReadWaiter() (N.ConnectedPacketBatchReadWaiter, bool) {
return createSyscallConnectedPacketBatchReadWaiter(c.ExtendedConn, c.addr)
}

func (c *UnbindPacketConn) CreateConnectedPacketBatchWriter() (N.ConnectedPacketBatchWriter, bool) {
return createSyscallConnectedPacketBatchWriter(c.ExtendedConn)
}

func (c *UnbindPacketConn) Upstream() any {
return c.ExtendedConn
}
Expand Down Expand Up @@ -163,3 +181,7 @@ func (c *serverPacketConn) CreateReadWaiter() (N.ReadWaiter, bool) {
}
return &serverPacketReadWaiter{c, readWaiter}, true
}

func (c *serverPacketConn) CreatePacketBatchReadWaiter() (N.PacketBatchReadWaiter, bool) {
return CreatePacketBatchReadWaiter(c.NetPacketConn)
}
195 changes: 175 additions & 20 deletions common/bufio/copy.go
Original file line number Diff line number Diff line change
Expand Up @@ -17,6 +17,7 @@ import (
const (
DefaultIncreaseBufferAfter = 512 * 1000
DefaultBatchSize = 8
DefaultPacketReadBatchSize = 64
)

func Copy(destination io.Writer, source io.Reader) (n int64, err error) {
Expand Down Expand Up @@ -259,53 +260,201 @@ func CopyConn(ctx context.Context, source net.Conn, destination net.Conn) error

func CopyPacket(destinationConn N.PacketWriter, source N.PacketReader) (n int64, err error) {
var readCounters, writeCounters []N.CountFunc
var cachedPackets []*N.PacketBuffer
originSource := source
for {
source, readCounters = N.UnwrapCountPacketReader(source, readCounters)
destinationConn, writeCounters = N.UnwrapCountPacketWriter(destinationConn, writeCounters)
if cachedReader, isCached := source.(N.CachedPacketReader); isCached {
packet := cachedReader.ReadCachedPacket()
if packet != nil {
cachedPackets = append(cachedPackets, packet)
var cachedN int64
cachedN, err = writePacketWithPool(originSource, destinationConn, []*N.PacketBuffer{packet}, readCounters, writeCounters, n > 0)
n += cachedN
if err != nil {
return
}
continue
}
}
break
}
if cachedPackets != nil {
n, err = WritePacketWithPool(originSource, destinationConn, cachedPackets, readCounters, writeCounters)
if err != nil {
return
}
}
copeN, err := CopyPacketWithCounters(destinationConn, source, originSource, readCounters, writeCounters)
copeN, err := copyPacketWithCounters(destinationConn, source, originSource, readCounters, writeCounters, n > 0)
n += copeN
return
}

type packetCopySession struct {
destination N.PacketWriter
source N.PacketReader
originSource N.PacketReader
readCounters []N.CountFunc
writeCounters []N.CountFunc
handshake N.HandshakeState
upgradable bool
}

func newPacketCopySession(destination N.PacketWriter, source N.PacketReader, originSource N.PacketReader, readCounters []N.CountFunc, writeCounters []N.CountFunc) *packetCopySession {
session := &packetCopySession{
destination: destination,
source: source,
originSource: originSource,
readCounters: readCounters,
writeCounters: writeCounters,
}
session.ResetHandshake()
return session
}

func (s *packetCopySession) ResetHandshake() {
s.handshake = N.NewPacketHandshakeState(s.source, s.destination)
s.upgradable = s.handshake.Upgradable()
}

func (s *packetCopySession) Transfer(n int64) error {
for _, counter := range s.readCounters {
counter(n)
}
for _, counter := range s.writeCounters {
counter(n)
}
if s.upgradable {
return s.handshake.Check()
}
return nil
}

func (s *packetCopySession) TransferBatch(dataLens []int) error {
for _, dataLen := range dataLens {
n := int64(dataLen)
for _, counter := range s.readCounters {
counter(n)
}
for _, counter := range s.writeCounters {
counter(n)
}
}
if s.upgradable {
return s.handshake.Check()
}
return nil
}

func CopyPacketWithCounters(destinationConn N.PacketWriter, source N.PacketReader, originSource N.PacketReader, readCounters []N.CountFunc, writeCounters []N.CountFunc) (n int64, err error) {
return copyPacketWithCounters(destinationConn, source, originSource, readCounters, writeCounters, false)
}

func copyPacketWithCounters(destinationConn N.PacketWriter, source N.PacketReader, originSource N.PacketReader, readCounters []N.CountFunc, writeCounters []N.CountFunc, notFirstTime bool) (n int64, err error) {
sourceReader := source
destinationWriter := destinationConn
session := newPacketCopySession(destinationWriter, sourceReader, originSource, readCounters, writeCounters)
refreshUnwrap := func() {
sourceReader, readCounters = N.UnwrapCountPacketReader(sourceReader, readCounters)
destinationWriter, writeCounters = N.UnwrapCountPacketWriter(destinationWriter, writeCounters)
session.source = sourceReader
session.destination = destinationWriter
session.readCounters = readCounters
session.writeCounters = writeCounters
session.ResetHandshake()
}
for {
var copyN int64
copyN, err = copyPacketWithCountersOnce(session, notFirstTime || n > 0)
n += copyN
if errors.Is(err, N.ErrHandshakeCompleted) {
refreshUnwrap()
continue
}
return
}
}

func copyPacketWithCountersOnce(session *packetCopySession, notFirstTime bool) (n int64, err error) {
var (
handled bool
copeN int64
)
source := session.source
destinationConn := session.destination
batchReadWaiter, isBatchReadWaiter := CreatePacketBatchReadWaiter(source)
if isBatchReadWaiter {
batchWriter, isBatchWriter := CreatePacketBatchWriter(destinationConn)
if isBatchWriter {
readWaitOptions := N.NewReadWaitOptions(source, destinationConn)
readWaitOptions.BatchSize = DefaultPacketReadBatchSize
needCopy := batchReadWaiter.InitializeReadWaiter(readWaitOptions)
if !needCopy || common.LowMemory {
handled, copeN, err = copyPacketBatchWaitWithPool(session, batchWriter, batchReadWaiter, notFirstTime)
if handled {
n += copeN
return
}
}
}
connectedBatchWriter, isConnectedBatchWriter := CreateConnectedPacketBatchWriter(destinationConn)
if isConnectedBatchWriter {
readWaitOptions := N.NewReadWaitOptions(source, destinationConn)
readWaitOptions.BatchSize = DefaultPacketReadBatchSize
needCopy := batchReadWaiter.InitializeReadWaiter(readWaitOptions)
if !needCopy || common.LowMemory {
handled, copeN, err = copyPacketBatchToConnectedWaitWithPool(session, connectedBatchWriter, batchReadWaiter, notFirstTime)
if handled {
n += copeN
return
}
}
}
}
connectedBatchReadWaiter, isConnectedBatchReadWaiter := CreateConnectedPacketBatchReadWaiter(source)
if isConnectedBatchReadWaiter {
batchWriter, isBatchWriter := CreatePacketBatchWriter(destinationConn)
if isBatchWriter {
readWaitOptions := N.NewReadWaitOptions(source, destinationConn)
readWaitOptions.BatchSize = DefaultPacketReadBatchSize
needCopy := connectedBatchReadWaiter.InitializeReadWaiter(readWaitOptions)
if !needCopy || common.LowMemory {
handled, copeN, err = copyConnectedPacketBatchWaitWithPool(session, batchWriter, connectedBatchReadWaiter, notFirstTime)
if handled {
n += copeN
return
}
}
}
connectedBatchWriter, isConnectedBatchWriter := CreateConnectedPacketBatchWriter(destinationConn)
if isConnectedBatchWriter {
readWaitOptions := N.NewReadWaitOptions(source, destinationConn)
readWaitOptions.BatchSize = DefaultPacketReadBatchSize
needCopy := connectedBatchReadWaiter.InitializeReadWaiter(readWaitOptions)
if !needCopy || common.LowMemory {
handled, copeN, err = copyConnectedPacketBatchToConnectedWaitWithPool(session, connectedBatchWriter, connectedBatchReadWaiter, notFirstTime)
if handled {
n += copeN
return
}
}
}
}
readWaiter, isReadWaiter := CreatePacketReadWaiter(source)
if isReadWaiter {
needCopy := readWaiter.InitializeReadWaiter(N.NewReadWaitOptions(source, destinationConn))
if !needCopy || common.LowMemory {
handled, copeN, err = copyPacketWaitWithPool(originSource, destinationConn, readWaiter, readCounters, writeCounters, n > 0)
handled, copeN, err = copyPacketWaitWithPool(session, destinationConn, readWaiter, notFirstTime)
if handled {
n += copeN
return
}
}
}
copeN, err = CopyPacketWithPool(originSource, destinationConn, source, readCounters, writeCounters, n > 0)
copeN, err = copyPacketWithPool(session, destinationConn, source, notFirstTime)
n += copeN
return
}

func CopyPacketWithPool(originSource N.PacketReader, destination N.PacketWriter, source N.PacketReader, readCounters []N.CountFunc, writeCounters []N.CountFunc, notFirstTime bool) (n int64, err error) {
session := newPacketCopySession(destination, source, originSource, readCounters, writeCounters)
return copyPacketWithPool(session, destination, source, notFirstTime)
}

func copyPacketWithPool(session *packetCopySession, destination N.PacketWriter, source N.PacketReader, notFirstTime bool) (n int64, err error) {
options := N.NewReadWaitOptions(source, destination)
var destinationAddress M.Socksaddr
for {
Expand All @@ -321,24 +470,27 @@ func CopyPacketWithPool(originSource N.PacketReader, destination N.PacketWriter,
if err != nil {
buffer.Leak()
if !notFirstTime {
err = N.ReportHandshakeFailure(originSource, err)
handshakeErr := N.ReportHandshakeFailure(session.originSource, err)
if handshakeErr != nil {
err = handshakeErr
}
}
return
}
for _, counter := range readCounters {
counter(int64(dataLen))
}
for _, counter := range writeCounters {
counter(int64(dataLen))
}
n += int64(dataLen)
if err = session.Transfer(int64(dataLen)); err != nil {
return
}
notFirstTime = true
}
}

func WritePacketWithPool(originSource N.PacketReader, destination N.PacketWriter, packetBuffers []*N.PacketBuffer, readCounters []N.CountFunc, writeCounters []N.CountFunc) (n int64, err error) {
return writePacketWithPool(originSource, destination, packetBuffers, readCounters, writeCounters, false)
}

func writePacketWithPool(originSource N.PacketReader, destination N.PacketWriter, packetBuffers []*N.PacketBuffer, readCounters []N.CountFunc, writeCounters []N.CountFunc, notFirstTime bool) (n int64, err error) {
options := N.NewReadWaitOptions(nil, destination)
var notFirstTime bool
for _, packetBuffer := range packetBuffers {
buffer := options.Copy(packetBuffer.Buffer)
dataLen := buffer.Len()
Expand All @@ -347,7 +499,10 @@ func WritePacketWithPool(originSource N.PacketReader, destination N.PacketWriter
if err != nil {
buffer.Leak()
if !notFirstTime {
err = N.ReportHandshakeFailure(originSource, err)
handshakeErr := N.ReportHandshakeFailure(originSource, err)
if handshakeErr != nil {
err = handshakeErr
}
}
return
}
Expand Down
Loading