66 "errors"
77 "io"
88 "net/netip"
9+ "os"
910 "syscall"
1011
1112 "github.com/sagernet/sing/common/buf"
@@ -25,24 +26,21 @@ func copyWaitWithPool(originDestination io.Writer, destination N.ExtendedWriter,
2526 bufferSize = buf .BufferSize
2627 }
2728 var (
28- buffer * buf.Buffer
29- readBuffer * buf.Buffer
29+ buffer * buf.Buffer
30+ readBuffer * buf.Buffer
31+ notFirstTime bool
3032 )
31- newBuffer := func () * buf.Buffer {
32- if buffer != nil {
33- buffer .Release ()
34- }
33+ source .InitializeReadWaiter (func () * buf.Buffer {
3534 buffer = buf .NewSize (bufferSize )
3635 readBufferRaw := buffer .Slice ()
3736 readBuffer = buf .With (readBufferRaw [:len (readBufferRaw )- rearHeadroom ])
3837 readBuffer .Resize (frontHeadroom , 0 )
3938 return readBuffer
40- }
41- var notFirstTime bool
39+ })
40+ defer source . InitializeReadWaiter ( nil )
4241 for {
43- err = source .WaitReadBuffer (newBuffer )
42+ err = source .WaitReadBuffer ()
4443 if err != nil {
45- buffer .Release ()
4644 if errors .Is (err , io .EOF ) {
4745 err = nil
4846 return
@@ -56,9 +54,7 @@ func copyWaitWithPool(originDestination io.Writer, destination N.ExtendedWriter,
5654 buffer .Resize (readBuffer .Start (), dataLen )
5755 err = destination .WriteBuffer (buffer )
5856 if err != nil {
59- if buffer != nil {
60- buffer .Release ()
61- }
57+ buffer .Release ()
6258 return
6359 }
6460 n += int64 (dataLen )
@@ -83,25 +79,22 @@ func copyPacketWaitWithPool(destinationConn N.PacketWriter, source N.PacketReadW
8379 bufferSize = buf .UDPBufferSize
8480 }
8581 var (
86- buffer * buf.Buffer
87- readBuffer * buf.Buffer
82+ buffer * buf.Buffer
83+ readBuffer * buf.Buffer
84+ destination M.Socksaddr
85+ notFirstTime bool
8886 )
89- newBuffer := func () * buf.Buffer {
90- if buffer != nil {
91- buffer .Release ()
92- }
87+ source .InitializeReadWaiter (func () * buf.Buffer {
9388 buffer = buf .NewSize (bufferSize )
9489 readBufferRaw := buffer .Slice ()
9590 readBuffer = buf .With (readBufferRaw [:len (readBufferRaw )- rearHeadroom ])
9691 readBuffer .Resize (frontHeadroom , 0 )
9792 return readBuffer
98- }
99- var destination M.Socksaddr
100- var notFirstTime bool
93+ })
94+ defer source .InitializeReadWaiter (nil )
10195 for {
102- destination , err = source .WaitReadPacket (newBuffer )
96+ destination , err = source .WaitReadPacket ()
10397 if err != nil {
104- buffer .Release ()
10598 if ! notFirstTime {
10699 err = N .HandshakeFailure (destinationConn , err )
107100 }
@@ -113,8 +106,6 @@ func copyPacketWaitWithPool(destinationConn N.PacketWriter, source N.PacketReadW
113106 if err != nil {
114107 buffer .Release ()
115108 return
116- } else {
117- buffer = nil
118109 }
119110 n += int64 (dataLen )
120111 for _ , counter := range readCounters {
@@ -127,6 +118,8 @@ func copyPacketWaitWithPool(destinationConn N.PacketWriter, source N.PacketReadW
127118 }
128119}
129120
121+ var _ N.ReadWaiter = (* syscallReadWaiter )(nil )
122+
130123type syscallReadWaiter struct {
131124 rawConn syscall.RawConn
132125 readErr error
@@ -143,8 +136,11 @@ func createSyscallReadWaiter(reader any) (*syscallReadWaiter, bool) {
143136 return nil , false
144137}
145138
146- func (w * syscallReadWaiter ) WaitReadBuffer (newBuffer func () * buf.Buffer ) error {
147- if w .readFunc == nil {
139+ func (w * syscallReadWaiter ) InitializeReadWaiter (newBuffer func () * buf.Buffer ) {
140+ w .readErr = nil
141+ if newBuffer == nil {
142+ w .readFunc = nil
143+ } else {
148144 w .readFunc = func (fd uintptr ) (done bool ) {
149145 buffer := newBuffer ()
150146 var readN int
@@ -164,16 +160,27 @@ func (w *syscallReadWaiter) WaitReadBuffer(newBuffer func() *buf.Buffer) error {
164160 return true
165161 }
166162 }
163+ }
164+
165+ func (w * syscallReadWaiter ) WaitReadBuffer () error {
166+ if w .readFunc == nil {
167+ return os .ErrInvalid
168+ }
167169 err := w .rawConn .Read (w .readFunc )
168170 if err != nil {
169171 return err
170172 }
171173 if w .readErr != nil {
174+ if w .readErr == io .EOF {
175+ return io .EOF
176+ }
172177 return E .Cause (w .readErr , "raw read" )
173178 }
174179 return nil
175180}
176181
182+ var _ N.PacketReadWaiter = (* syscallPacketReadWaiter )(nil )
183+
177184type syscallPacketReadWaiter struct {
178185 rawConn syscall.RawConn
179186 readErr error
@@ -191,8 +198,12 @@ func createSyscallPacketReadWaiter(reader any) (*syscallPacketReadWaiter, bool)
191198 return nil , false
192199}
193200
194- func (w * syscallPacketReadWaiter ) WaitReadPacket (newBuffer func () * buf.Buffer ) (destination M.Socksaddr , err error ) {
195- if w .readFunc == nil {
201+ func (w * syscallPacketReadWaiter ) InitializeReadWaiter (newBuffer func () * buf.Buffer ) {
202+ w .readErr = nil
203+ w .readFrom = M.Socksaddr {}
204+ if newBuffer == nil {
205+ w .readFunc = nil
206+ } else {
196207 w .readFunc = func (fd uintptr ) (done bool ) {
197208 buffer := newBuffer ()
198209 var readN int
@@ -221,6 +232,12 @@ func (w *syscallPacketReadWaiter) WaitReadPacket(newBuffer func() *buf.Buffer) (
221232 return true
222233 }
223234 }
235+ }
236+
237+ func (w * syscallPacketReadWaiter ) WaitReadPacket () (destination M.Socksaddr , err error ) {
238+ if w .readFunc == nil {
239+ return M.Socksaddr {}, os .ErrInvalid
240+ }
224241 err = w .rawConn .Read (w .readFunc )
225242 if err != nil {
226243 return
0 commit comments