Skip to content

Commit 284cb5c

Browse files
dyhkwongnekohasekai
authored andcommitted
Fix socks5 packet conn
1 parent 8fb1634 commit 284cb5c

5 files changed

Lines changed: 47 additions & 59 deletions

File tree

common/bufio/nat.go

Lines changed: 8 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -63,6 +63,10 @@ func (c *unidirectionalNATPacketConn) UpdateDestination(destinationAddress netip
6363
c.destination = M.SocksaddrFrom(destinationAddress, c.destination.Port)
6464
}
6565

66+
func (c *unidirectionalNATPacketConn) RemoteAddr() net.Addr {
67+
return c.destination.UDPAddr()
68+
}
69+
6670
func (c *unidirectionalNATPacketConn) Upstream() any {
6771
return c.NetPacketConn
6872
}
@@ -136,6 +140,10 @@ func (c *bidirectionalNATPacketConn) Upstream() any {
136140
return c.NetPacketConn
137141
}
138142

143+
func (c *bidirectionalNATPacketConn) RemoteAddr() net.Addr {
144+
return c.destination.UDPAddr()
145+
}
146+
139147
func socksaddrWithoutPort(destination M.Socksaddr) M.Socksaddr {
140148
destination.Port = 0
141149
return destination

protocol/socks/client.go

Lines changed: 1 addition & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -7,7 +7,6 @@ import (
77
"os"
88
"strings"
99

10-
"github.com/sagernet/sing/common/bufio"
1110
E "github.com/sagernet/sing/common/exceptions"
1211
M "github.com/sagernet/sing/common/metadata"
1312
N "github.com/sagernet/sing/common/network"
@@ -148,7 +147,7 @@ func (c *Client) DialContext(ctx context.Context, network string, address M.Sock
148147
tcpConn.Close()
149148
return nil, err
150149
}
151-
return NewAssociatePacketConn(bufio.NewUnbindPacketConn(udpConn), address, tcpConn), nil
150+
return NewAssociatePacketConn(udpConn, address, tcpConn), nil
152151
}
153152
return nil, os.ErrInvalid
154153
}

protocol/socks/packet.go

Lines changed: 29 additions & 38 deletions
Original file line numberDiff line numberDiff line change
@@ -21,54 +21,41 @@ import (
2121
var ErrInvalidPacket = E.New("socks5: invalid packet")
2222

2323
type AssociatePacketConn struct {
24-
N.NetPacketConn
24+
N.AbstractConn
25+
conn N.ExtendedConn
2526
remoteAddr M.Socksaddr
2627
underlying net.Conn
2728
}
2829

29-
func NewAssociatePacketConn(conn net.PacketConn, remoteAddr M.Socksaddr, underlying net.Conn) *AssociatePacketConn {
30+
func NewAssociatePacketConn(conn net.Conn, remoteAddr M.Socksaddr, underlying net.Conn) *AssociatePacketConn {
3031
return &AssociatePacketConn{
31-
NetPacketConn: bufio.NewPacketConn(conn),
32-
remoteAddr: remoteAddr,
33-
underlying: underlying,
32+
AbstractConn: conn,
33+
conn: bufio.NewExtendedConn(conn),
34+
remoteAddr: remoteAddr,
35+
underlying: underlying,
3436
}
3537
}
3638

37-
// Deprecated: NewAssociatePacketConn(bufio.NewUnbindPacketConn(conn), remoteAddr, underlying) instead.
38-
func NewAssociateConn(conn net.Conn, remoteAddr M.Socksaddr, underlying net.Conn) *AssociatePacketConn {
39-
return &AssociatePacketConn{
40-
NetPacketConn: bufio.NewUnbindPacketConn(conn),
41-
remoteAddr: remoteAddr,
42-
underlying: underlying,
43-
}
44-
}
45-
46-
func (c *AssociatePacketConn) RemoteAddr() net.Addr {
47-
return c.remoteAddr.UDPAddr()
48-
}
49-
50-
//warn:unsafe
5139
func (c *AssociatePacketConn) ReadFrom(p []byte) (n int, addr net.Addr, err error) {
52-
n, addr, err = c.NetPacketConn.ReadFrom(p)
40+
n, err = c.conn.Read(p)
5341
if err != nil {
5442
return
5543
}
5644
if n < 3 {
5745
return 0, nil, ErrInvalidPacket
5846
}
59-
c.remoteAddr = M.SocksaddrFromNet(addr)
6047
reader := bytes.NewReader(p[3:n])
6148
destination, err := M.SocksaddrSerializer.ReadAddrPort(reader)
6249
if err != nil {
6350
return
6451
}
52+
c.remoteAddr = destination
6553
addr = destination.UDPAddr()
6654
index := 3 + int(reader.Size()) - reader.Len()
6755
n = copy(p, p[index:n])
6856
return
6957
}
7058

71-
//warn:unsafe
7259
func (c *AssociatePacketConn) WriteTo(p []byte, addr net.Addr) (n int, err error) {
7360
destination := M.SocksaddrFromNet(addr)
7461
buffer := buf.NewSize(3 + M.SocksaddrSerializer.AddrPortLen(destination) + len(p))
@@ -82,32 +69,23 @@ func (c *AssociatePacketConn) WriteTo(p []byte, addr net.Addr) (n int, err error
8269
if err != nil {
8370
return
8471
}
85-
return bufio.WritePacketBuffer(c.NetPacketConn, buffer, c.remoteAddr)
86-
}
87-
88-
func (c *AssociatePacketConn) Read(b []byte) (n int, err error) {
89-
n, _, err = c.ReadFrom(b)
90-
return
91-
}
92-
93-
func (c *AssociatePacketConn) Write(b []byte) (n int, err error) {
94-
return c.WriteTo(b, c.remoteAddr)
72+
return c.conn.Write(buffer.Bytes())
9573
}
9674

9775
func (c *AssociatePacketConn) ReadPacket(buffer *buf.Buffer) (destination M.Socksaddr, err error) {
98-
destination, err = c.NetPacketConn.ReadPacket(buffer)
76+
err = c.conn.ReadBuffer(buffer)
9977
if err != nil {
100-
return M.Socksaddr{}, err
78+
return
10179
}
10280
if buffer.Len() < 3 {
10381
return M.Socksaddr{}, ErrInvalidPacket
10482
}
105-
c.remoteAddr = destination
10683
buffer.Advance(3)
10784
destination, err = M.SocksaddrSerializer.ReadAddrPort(buffer)
10885
if err != nil {
10986
return
11087
}
88+
c.remoteAddr = destination
11189
return destination.Unwrap(), nil
11290
}
11391

@@ -118,11 +96,24 @@ func (c *AssociatePacketConn) WritePacket(buffer *buf.Buffer, destination M.Sock
11896
if err != nil {
11997
return err
12098
}
121-
return common.Error(bufio.WritePacketBuffer(c.NetPacketConn, buffer, c.remoteAddr))
99+
return c.conn.WriteBuffer(buffer)
100+
}
101+
102+
func (c *AssociatePacketConn) Read(b []byte) (n int, err error) {
103+
n, _, err = c.ReadFrom(b)
104+
return
105+
}
106+
107+
func (c *AssociatePacketConn) Write(b []byte) (n int, err error) {
108+
return c.WriteTo(b, c.remoteAddr)
109+
}
110+
111+
func (c *AssociatePacketConn) RemoteAddr() net.Addr {
112+
return c.remoteAddr.UDPAddr()
122113
}
123114

124115
func (c *AssociatePacketConn) Upstream() any {
125-
return c.NetPacketConn
116+
return c.conn
126117
}
127118

128119
func (c *AssociatePacketConn) FrontHeadroom() int {
@@ -131,7 +122,7 @@ func (c *AssociatePacketConn) FrontHeadroom() int {
131122

132123
func (c *AssociatePacketConn) Close() error {
133124
return common.Close(
134-
c.NetPacketConn,
125+
c.conn,
135126
c.underlying,
136127
)
137128
}

protocol/socks/packet_vectorised.go

Lines changed: 4 additions & 14 deletions
Original file line numberDiff line numberDiff line change
@@ -17,23 +17,13 @@ type VectorisedAssociatePacketConn struct {
1717
N.VectorisedPacketWriter
1818
}
1919

20-
func NewVectorisedAssociatePacketConn(conn net.PacketConn, writer N.VectorisedPacketWriter, remoteAddr M.Socksaddr, underlying net.Conn) *VectorisedAssociatePacketConn {
21-
return &VectorisedAssociatePacketConn{
22-
AssociatePacketConn{
23-
NetPacketConn: bufio.NewPacketConn(conn),
24-
remoteAddr: remoteAddr,
25-
underlying: underlying,
26-
},
27-
writer,
28-
}
29-
}
30-
3120
func NewVectorisedAssociateConn(conn net.Conn, writer N.VectorisedWriter, remoteAddr M.Socksaddr, underlying net.Conn) *VectorisedAssociatePacketConn {
3221
return &VectorisedAssociatePacketConn{
3322
AssociatePacketConn{
34-
NetPacketConn: bufio.NewUnbindPacketConn(conn),
35-
remoteAddr: remoteAddr,
36-
underlying: underlying,
23+
AbstractConn: conn,
24+
conn: bufio.NewExtendedConn(conn),
25+
remoteAddr: remoteAddr,
26+
underlying: underlying,
3727
},
3828
&bufio.UnbindVectorisedPacketWriter{VectorisedWriter: writer},
3929
}

protocol/socks/packet_wait.go

Lines changed: 5 additions & 5 deletions
Original file line numberDiff line numberDiff line change
@@ -10,7 +10,7 @@ import (
1010
var _ N.PacketReadWaitCreator = (*AssociatePacketConn)(nil)
1111

1212
func (c *AssociatePacketConn) CreateReadWaiter() (N.PacketReadWaiter, bool) {
13-
readWaiter, isReadWaiter := bufio.CreatePacketReadWaiter(c.NetPacketConn)
13+
readWaiter, isReadWaiter := bufio.CreateReadWaiter(c.conn)
1414
if !isReadWaiter {
1515
return nil, false
1616
}
@@ -21,28 +21,28 @@ var _ N.PacketReadWaiter = (*AssociatePacketReadWaiter)(nil)
2121

2222
type AssociatePacketReadWaiter struct {
2323
conn *AssociatePacketConn
24-
readWaiter N.PacketReadWaiter
24+
readWaiter N.ReadWaiter
2525
}
2626

2727
func (w *AssociatePacketReadWaiter) InitializeReadWaiter(options N.ReadWaitOptions) (needCopy bool) {
2828
return w.readWaiter.InitializeReadWaiter(options)
2929
}
3030

3131
func (w *AssociatePacketReadWaiter) WaitReadPacket() (buffer *buf.Buffer, destination M.Socksaddr, err error) {
32-
buffer, destination, err = w.readWaiter.WaitReadPacket()
32+
buffer, err = w.readWaiter.WaitReadBuffer()
3333
if err != nil {
3434
return
3535
}
3636
if buffer.Len() < 3 {
3737
buffer.Release()
3838
return nil, M.Socksaddr{}, ErrInvalidPacket
3939
}
40-
w.conn.remoteAddr = destination
4140
buffer.Advance(3)
4241
destination, err = M.SocksaddrSerializer.ReadAddrPort(buffer)
4342
if err != nil {
4443
buffer.Release()
45-
return nil, M.Socksaddr{}, err
44+
return
4645
}
46+
w.conn.remoteAddr = destination
4747
return
4848
}

0 commit comments

Comments
 (0)