Skip to content

Commit 9be7806

Browse files
committed
Improve read waiter interface
1 parent ab3e469 commit 9be7806

2 files changed

Lines changed: 51 additions & 32 deletions

File tree

common/bufio/copy_direct_posix.go

Lines changed: 47 additions & 30 deletions
Original file line numberDiff line numberDiff line change
@@ -6,6 +6,7 @@ import (
66
"errors"
77
"io"
88
"net/netip"
9+
"os"
910
"syscall"
1011

1112
"github.com/sagernet/sing/common/buf"
@@ -25,24 +26,21 @@ func copyWaitWithPool(originDestination io.Writer, destination N.ExtendedWriter,
2526
bufferSize = buf.BufferSize
2627
}
2728
var (
28-
buffer *buf.Buffer
29-
readBuffer *buf.Buffer
29+
buffer *buf.Buffer
30+
readBuffer *buf.Buffer
31+
notFirstTime bool
3032
)
31-
newBuffer := func() *buf.Buffer {
32-
if buffer != nil {
33-
buffer.Release()
34-
}
33+
source.InitializeReadWaiter(func() *buf.Buffer {
3534
buffer = buf.NewSize(bufferSize)
3635
readBufferRaw := buffer.Slice()
3736
readBuffer = buf.With(readBufferRaw[:len(readBufferRaw)-rearHeadroom])
3837
readBuffer.Resize(frontHeadroom, 0)
3938
return readBuffer
40-
}
41-
var notFirstTime bool
39+
})
40+
defer source.InitializeReadWaiter(nil)
4241
for {
43-
err = source.WaitReadBuffer(newBuffer)
42+
err = source.WaitReadBuffer()
4443
if err != nil {
45-
buffer.Release()
4644
if errors.Is(err, io.EOF) {
4745
err = nil
4846
return
@@ -56,9 +54,7 @@ func copyWaitWithPool(originDestination io.Writer, destination N.ExtendedWriter,
5654
buffer.Resize(readBuffer.Start(), dataLen)
5755
err = destination.WriteBuffer(buffer)
5856
if err != nil {
59-
if buffer != nil {
60-
buffer.Release()
61-
}
57+
buffer.Release()
6258
return
6359
}
6460
n += int64(dataLen)
@@ -83,25 +79,22 @@ func copyPacketWaitWithPool(destinationConn N.PacketWriter, source N.PacketReadW
8379
bufferSize = buf.UDPBufferSize
8480
}
8581
var (
86-
buffer *buf.Buffer
87-
readBuffer *buf.Buffer
82+
buffer *buf.Buffer
83+
readBuffer *buf.Buffer
84+
destination M.Socksaddr
85+
notFirstTime bool
8886
)
89-
newBuffer := func() *buf.Buffer {
90-
if buffer != nil {
91-
buffer.Release()
92-
}
87+
source.InitializeReadWaiter(func() *buf.Buffer {
9388
buffer = buf.NewSize(bufferSize)
9489
readBufferRaw := buffer.Slice()
9590
readBuffer = buf.With(readBufferRaw[:len(readBufferRaw)-rearHeadroom])
9691
readBuffer.Resize(frontHeadroom, 0)
9792
return readBuffer
98-
}
99-
var destination M.Socksaddr
100-
var notFirstTime bool
93+
})
94+
defer source.InitializeReadWaiter(nil)
10195
for {
102-
destination, err = source.WaitReadPacket(newBuffer)
96+
destination, err = source.WaitReadPacket()
10397
if err != nil {
104-
buffer.Release()
10598
if !notFirstTime {
10699
err = N.HandshakeFailure(destinationConn, err)
107100
}
@@ -113,8 +106,6 @@ func copyPacketWaitWithPool(destinationConn N.PacketWriter, source N.PacketReadW
113106
if err != nil {
114107
buffer.Release()
115108
return
116-
} else {
117-
buffer = nil
118109
}
119110
n += int64(dataLen)
120111
for _, counter := range readCounters {
@@ -127,6 +118,8 @@ func copyPacketWaitWithPool(destinationConn N.PacketWriter, source N.PacketReadW
127118
}
128119
}
129120

121+
var _ N.ReadWaiter = (*syscallReadWaiter)(nil)
122+
130123
type syscallReadWaiter struct {
131124
rawConn syscall.RawConn
132125
readErr error
@@ -143,8 +136,11 @@ func createSyscallReadWaiter(reader any) (*syscallReadWaiter, bool) {
143136
return nil, false
144137
}
145138

146-
func (w *syscallReadWaiter) WaitReadBuffer(newBuffer func() *buf.Buffer) error {
147-
if w.readFunc == nil {
139+
func (w *syscallReadWaiter) InitializeReadWaiter(newBuffer func() *buf.Buffer) {
140+
w.readErr = nil
141+
if newBuffer == nil {
142+
w.readFunc = nil
143+
} else {
148144
w.readFunc = func(fd uintptr) (done bool) {
149145
buffer := newBuffer()
150146
var readN int
@@ -164,16 +160,27 @@ func (w *syscallReadWaiter) WaitReadBuffer(newBuffer func() *buf.Buffer) error {
164160
return true
165161
}
166162
}
163+
}
164+
165+
func (w *syscallReadWaiter) WaitReadBuffer() error {
166+
if w.readFunc == nil {
167+
return os.ErrInvalid
168+
}
167169
err := w.rawConn.Read(w.readFunc)
168170
if err != nil {
169171
return err
170172
}
171173
if w.readErr != nil {
174+
if w.readErr == io.EOF {
175+
return io.EOF
176+
}
172177
return E.Cause(w.readErr, "raw read")
173178
}
174179
return nil
175180
}
176181

182+
var _ N.PacketReadWaiter = (*syscallPacketReadWaiter)(nil)
183+
177184
type syscallPacketReadWaiter struct {
178185
rawConn syscall.RawConn
179186
readErr error
@@ -191,8 +198,12 @@ func createSyscallPacketReadWaiter(reader any) (*syscallPacketReadWaiter, bool)
191198
return nil, false
192199
}
193200

194-
func (w *syscallPacketReadWaiter) WaitReadPacket(newBuffer func() *buf.Buffer) (destination M.Socksaddr, err error) {
195-
if w.readFunc == nil {
201+
func (w *syscallPacketReadWaiter) InitializeReadWaiter(newBuffer func() *buf.Buffer) {
202+
w.readErr = nil
203+
w.readFrom = M.Socksaddr{}
204+
if newBuffer == nil {
205+
w.readFunc = nil
206+
} else {
196207
w.readFunc = func(fd uintptr) (done bool) {
197208
buffer := newBuffer()
198209
var readN int
@@ -221,6 +232,12 @@ func (w *syscallPacketReadWaiter) WaitReadPacket(newBuffer func() *buf.Buffer) (
221232
return true
222233
}
223234
}
235+
}
236+
237+
func (w *syscallPacketReadWaiter) WaitReadPacket() (destination M.Socksaddr, err error) {
238+
if w.readFunc == nil {
239+
return M.Socksaddr{}, os.ErrInvalid
240+
}
224241
err = w.rawConn.Read(w.readFunc)
225242
if err != nil {
226243
return

common/network/direct.go

Lines changed: 4 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -6,15 +6,17 @@ import (
66
)
77

88
type ReadWaiter interface {
9-
WaitReadBuffer(newBuffer func() *buf.Buffer) error
9+
InitializeReadWaiter(newBuffer func() *buf.Buffer)
10+
WaitReadBuffer() error
1011
}
1112

1213
type ReadWaitCreator interface {
1314
CreateReadWaiter() (ReadWaiter, bool)
1415
}
1516

1617
type PacketReadWaiter interface {
17-
WaitReadPacket(newBuffer func() *buf.Buffer) (destination M.Socksaddr, err error)
18+
InitializeReadWaiter(newBuffer func() *buf.Buffer)
19+
WaitReadPacket() (destination M.Socksaddr, err error)
1820
}
1921

2022
type PacketReadWaitCreator interface {

0 commit comments

Comments
 (0)