Skip to content

Commit bc5a1f8

Browse files
committed
Add mixed stack
1 parent aa8760b commit bc5a1f8

10 files changed

Lines changed: 275 additions & 3 deletions

monitor_linux_default.go

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -4,6 +4,7 @@ package tun
44

55
import (
66
"github.com/sagernet/netlink"
7+
78
"golang.org/x/sys/unix"
89
)
910

monitor_other.go

Lines changed: 2 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -3,8 +3,9 @@
33
package tun
44

55
import (
6-
"github.com/sagernet/sing/common/logger"
76
"os"
7+
8+
"github.com/sagernet/sing/common/logger"
89
)
910

1011
func NewNetworkUpdateMonitor(logger logger.Logger) (NetworkUpdateMonitor, error) {

stack.go

Lines changed: 7 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -35,9 +35,15 @@ func NewStack(
3535
) (Stack, error) {
3636
switch stack {
3737
case "":
38-
return NewSystem(options)
38+
if WithGVisor {
39+
return NewMixed(options)
40+
} else {
41+
return NewSystem(options)
42+
}
3943
case "gvisor":
4044
return NewGVisor(options)
45+
case "mixed":
46+
return NewMixed(options)
4147
case "system":
4248
return NewSystem(options)
4349
case "lwip":

stack_gvisor_stub.go

Lines changed: 6 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -13,3 +13,9 @@ func NewGVisor(
1313
) (Stack, error) {
1414
return nil, ErrGVisorNotIncluded
1515
}
16+
17+
func NewMixed(
18+
options StackOptions,
19+
) (Stack, error) {
20+
return nil, ErrGVisorNotIncluded
21+
}

stack_gvisor_udp.go

Lines changed: 23 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -8,6 +8,7 @@ import (
88
"math"
99
"net/netip"
1010
"os"
11+
"sync"
1112
"syscall"
1213

1314
"github.com/sagernet/gvisor/pkg/buffer"
@@ -74,19 +75,34 @@ func (f *UDPForwarder) newUDPConn(natConn N.PacketConn) N.PacketWriter {
7475
source: f.cacheID.RemoteAddress,
7576
sourcePort: f.cacheID.RemotePort,
7677
sourceNetwork: f.cacheProto,
78+
packet: f.cachePacket.IncRef(),
7779
}
7880
}
7981

8082
type UDPBackWriter struct {
83+
access sync.Mutex
8184
stack *stack.Stack
8285
source tcpip.Address
8386
sourcePort uint16
8487
sourceNetwork tcpip.NetworkProtocolNumber
8588
packet stack.PacketBufferPtr
8689
}
8790

91+
func (w *UDPBackWriter) Close() error {
92+
w.access.Lock()
93+
defer w.access.Unlock()
94+
if w.packet == nil {
95+
return os.ErrClosed
96+
}
97+
w.packet.DecRef()
98+
w.packet = nil
99+
return nil
100+
}
101+
88102
func (w *UDPBackWriter) WritePacket(packetBuffer *buf.Buffer, destination M.Socksaddr) error {
89-
if destination.IsIPv4() && w.sourceNetwork == header.IPv6ProtocolNumber {
103+
if !destination.IsIP() {
104+
return E.Cause(os.ErrInvalid, "invalid destination")
105+
} else if destination.IsIPv4() && w.sourceNetwork == header.IPv6ProtocolNumber {
90106
destination = M.SocksaddrFrom(netip.AddrFrom16(destination.Addr.As16()), destination.Port)
91107
} else if destination.IsIPv6() && (w.sourceNetwork == header.IPv4AddressSizeBits) {
92108
return E.New("send IPv6 packet to IPv4 connection")
@@ -165,6 +181,7 @@ type gRequest struct {
165181

166182
type gUDPConn struct {
167183
*gonet.UDPConn
184+
access sync.Mutex
168185
stack *stack.Stack
169186
packet stack.PacketBufferPtr
170187
}
@@ -188,6 +205,11 @@ func (c *gUDPConn) Write(b []byte) (n int, err error) {
188205
}
189206

190207
func (c *gUDPConn) Close() error {
208+
c.access.Lock()
209+
defer c.access.Unlock()
210+
if c.packet == nil {
211+
return os.ErrClosed
212+
}
191213
c.packet.DecRef()
192214
c.packet = nil
193215
return c.UDPConn.Close()

stack_mixed.go

Lines changed: 202 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,202 @@
1+
//go:build with_gvisor
2+
3+
package tun
4+
5+
import (
6+
"time"
7+
"unsafe"
8+
9+
"github.com/sagernet/gvisor/pkg/buffer"
10+
"github.com/sagernet/gvisor/pkg/tcpip/adapters/gonet"
11+
"github.com/sagernet/gvisor/pkg/tcpip/header"
12+
"github.com/sagernet/gvisor/pkg/tcpip/link/channel"
13+
"github.com/sagernet/gvisor/pkg/tcpip/stack"
14+
"github.com/sagernet/gvisor/pkg/tcpip/transport/udp"
15+
"github.com/sagernet/gvisor/pkg/waiter"
16+
"github.com/sagernet/sing-tun/internal/clashtcpip"
17+
"github.com/sagernet/sing/common"
18+
"github.com/sagernet/sing/common/bufio"
19+
"github.com/sagernet/sing/common/canceler"
20+
E "github.com/sagernet/sing/common/exceptions"
21+
M "github.com/sagernet/sing/common/metadata"
22+
N "github.com/sagernet/sing/common/network"
23+
)
24+
25+
type Mixed struct {
26+
*System
27+
writer N.VectorisedWriter
28+
endpointIndependentNat bool
29+
stack *stack.Stack
30+
endpoint *channel.Endpoint
31+
}
32+
33+
func NewMixed(
34+
options StackOptions,
35+
) (Stack, error) {
36+
system, err := NewSystem(options)
37+
if err != nil {
38+
return nil, err
39+
}
40+
return &Mixed{
41+
System: system.(*System),
42+
writer: options.Tun.CreateVectorisedWriter(),
43+
endpointIndependentNat: options.EndpointIndependentNat,
44+
}, nil
45+
}
46+
47+
func (m *Mixed) Start() error {
48+
err := m.System.start()
49+
if err != nil {
50+
return err
51+
}
52+
endpoint := channel.New(1024, m.mtu, "")
53+
ipStack, err := newGVisorStack(endpoint)
54+
if err != nil {
55+
return err
56+
}
57+
if !m.endpointIndependentNat {
58+
udpForwarder := udp.NewForwarder(ipStack, func(request *udp.ForwarderRequest) {
59+
var wq waiter.Queue
60+
endpoint, err := request.CreateEndpoint(&wq)
61+
if err != nil {
62+
return
63+
}
64+
udpConn := gonet.NewUDPConn(ipStack, &wq, endpoint)
65+
lAddr := udpConn.RemoteAddr()
66+
rAddr := udpConn.LocalAddr()
67+
if lAddr == nil || rAddr == nil {
68+
endpoint.Abort()
69+
return
70+
}
71+
gConn := &gUDPConn{udpConn, ipStack, (*gRequest)(unsafe.Pointer(request)).pkt.IncRef()}
72+
go func() {
73+
var metadata M.Metadata
74+
metadata.Source = M.SocksaddrFromNet(lAddr)
75+
metadata.Destination = M.SocksaddrFromNet(rAddr)
76+
ctx, conn := canceler.NewPacketConn(m.ctx, bufio.NewPacketConn(&bufio.UnbindPacketConn{ExtendedConn: bufio.NewExtendedConn(gConn), Addr: M.SocksaddrFromNet(rAddr)}), time.Duration(m.udpTimeout)*time.Second)
77+
hErr := m.handler.NewPacketConnection(ctx, conn, metadata)
78+
if hErr != nil {
79+
endpoint.Abort()
80+
}
81+
}()
82+
})
83+
ipStack.SetTransportProtocolHandler(udp.ProtocolNumber, udpForwarder.HandlePacket)
84+
} else {
85+
ipStack.SetTransportProtocolHandler(udp.ProtocolNumber, NewUDPForwarder(m.ctx, ipStack, m.handler, m.udpTimeout).HandlePacket)
86+
}
87+
m.stack = ipStack
88+
m.endpoint = endpoint
89+
go m.tunLoop()
90+
go m.packetLoop()
91+
return nil
92+
}
93+
94+
func (m *Mixed) tunLoop() {
95+
if winTun, isWinTun := m.tun.(WinTun); isWinTun {
96+
m.wintunLoop(winTun)
97+
return
98+
}
99+
packetBuffer := make([]byte, m.mtu+PacketOffset)
100+
for {
101+
n, err := m.tun.Read(packetBuffer)
102+
if err != nil {
103+
return
104+
}
105+
if n < clashtcpip.IPv4PacketMinLength {
106+
continue
107+
}
108+
packet := packetBuffer[PacketOffset:n]
109+
switch ipVersion := packet[0] >> 4; ipVersion {
110+
case 4:
111+
err = m.processIPv4(packet)
112+
case 6:
113+
err = m.processIPv6(packet)
114+
default:
115+
err = E.New("ip: unknown version: ", ipVersion)
116+
}
117+
if err != nil {
118+
m.logger.Trace(err)
119+
}
120+
}
121+
}
122+
123+
func (m *Mixed) wintunLoop(winTun WinTun) {
124+
for {
125+
packet, release, err := winTun.ReadPacket()
126+
if err != nil {
127+
return
128+
}
129+
if len(packet) < clashtcpip.IPv4PacketMinLength {
130+
release()
131+
continue
132+
}
133+
switch ipVersion := packet[0] >> 4; ipVersion {
134+
case 4:
135+
err = m.processIPv4(packet)
136+
case 6:
137+
err = m.processIPv6(packet)
138+
default:
139+
err = E.New("ip: unknown version: ", ipVersion)
140+
}
141+
if err != nil {
142+
m.logger.Trace(err)
143+
}
144+
release()
145+
}
146+
}
147+
148+
func (m *Mixed) processIPv4(packet clashtcpip.IPv4Packet) error {
149+
switch packet.Protocol() {
150+
case clashtcpip.TCP:
151+
return m.processIPv4TCP(packet, packet.Payload())
152+
case clashtcpip.UDP:
153+
pkt := stack.NewPacketBuffer(stack.PacketBufferOptions{
154+
Payload: buffer.MakeWithData(packet),
155+
})
156+
m.endpoint.InjectInbound(header.IPv4ProtocolNumber, pkt)
157+
pkt.DecRef()
158+
return nil
159+
case clashtcpip.ICMP:
160+
return m.processIPv4ICMP(packet, packet.Payload())
161+
default:
162+
return common.Error(m.tun.Write(packet))
163+
}
164+
}
165+
166+
func (m *Mixed) processIPv6(packet clashtcpip.IPv6Packet) error {
167+
switch packet.Protocol() {
168+
case clashtcpip.TCP:
169+
return m.processIPv6TCP(packet, packet.Payload())
170+
case clashtcpip.UDP:
171+
pkt := stack.NewPacketBuffer(stack.PacketBufferOptions{
172+
Payload: buffer.MakeWithData(packet),
173+
})
174+
m.endpoint.InjectInbound(header.IPv6ProtocolNumber, pkt)
175+
pkt.DecRef()
176+
return nil
177+
case clashtcpip.ICMPv6:
178+
return m.processIPv6ICMP(packet, packet.Payload())
179+
default:
180+
return common.Error(m.tun.Write(packet))
181+
}
182+
}
183+
184+
func (m *Mixed) packetLoop() {
185+
for {
186+
packet := m.endpoint.ReadContext(m.ctx)
187+
if packet == nil {
188+
break
189+
}
190+
bufio.WriteVectorised(m.writer, packet.AsSlices())
191+
packet.DecRef()
192+
}
193+
}
194+
195+
func (m *Mixed) Close() error {
196+
m.endpoint.Attach(nil)
197+
m.stack.Close()
198+
for _, endpoint := range m.stack.CleanupEndpoints() {
199+
endpoint.Abort()
200+
}
201+
return m.System.Close()
202+
}

tun.go

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -23,6 +23,7 @@ type Handler interface {
2323

2424
type Tun interface {
2525
io.ReadWriter
26+
CreateVectorisedWriter() N.VectorisedWriter
2627
Close() error
2728
}
2829

tun_darwin.go

Lines changed: 15 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -10,6 +10,7 @@ import (
1010
"unsafe"
1111

1212
"github.com/sagernet/sing/common"
13+
"github.com/sagernet/sing/common/buf"
1314
"github.com/sagernet/sing/common/bufio"
1415
E "github.com/sagernet/sing/common/exceptions"
1516
N "github.com/sagernet/sing/common/network"
@@ -101,6 +102,20 @@ func (t *NativeTun) Write(p []byte) (n int, err error) {
101102
return
102103
}
103104

105+
func (t *NativeTun) CreateVectorisedWriter() N.VectorisedWriter {
106+
return t
107+
}
108+
109+
func (t *NativeTun) WriteVectorised(buffers []*buf.Buffer) error {
110+
var packetHeader []byte
111+
if buffers[0].Byte(0)>>4 == 4 {
112+
packetHeader = packetHeader4[:]
113+
} else {
114+
packetHeader = packetHeader6[:]
115+
}
116+
return t.tunWriter.WriteVectorised(append([]*buf.Buffer{buf.As(packetHeader)}, buffers...))
117+
}
118+
104119
func (t *NativeTun) Close() error {
105120
flushDNSCache()
106121
return t.tunFile.Close()

tun_linux.go

Lines changed: 6 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -12,7 +12,9 @@ import (
1212

1313
"github.com/sagernet/netlink"
1414
"github.com/sagernet/sing/common"
15+
"github.com/sagernet/sing/common/bufio"
1516
E "github.com/sagernet/sing/common/exceptions"
17+
N "github.com/sagernet/sing/common/network"
1618
"github.com/sagernet/sing/common/rw"
1719
"github.com/sagernet/sing/common/shell"
1820
"github.com/sagernet/sing/common/x/list"
@@ -68,6 +70,10 @@ func (t *NativeTun) Write(p []byte) (n int, err error) {
6870
return t.tunFile.Write(p)
6971
}
7072

73+
func (t *NativeTun) CreateVectorisedWriter() N.VectorisedWriter {
74+
return bufio.NewVectorisedWriter(t.tunFile)
75+
}
76+
7177
var controlPath string
7278

7379
func init() {

tun_windows.go

Lines changed: 12 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -16,7 +16,10 @@ import (
1616
"github.com/sagernet/sing-tun/internal/winipcfg"
1717
"github.com/sagernet/sing-tun/internal/winsys"
1818
"github.com/sagernet/sing-tun/internal/wintun"
19+
"github.com/sagernet/sing/common"
20+
"github.com/sagernet/sing/common/buf"
1921
E "github.com/sagernet/sing/common/exceptions"
22+
N "github.com/sagernet/sing/common/network"
2023
"github.com/sagernet/sing/common/windnsapi"
2124

2225
"golang.org/x/sys/windows"
@@ -467,6 +470,15 @@ func (t *NativeTun) write(packetElementList [][]byte) (n int, err error) {
467470
return 0, fmt.Errorf("write failed: %w", err)
468471
}
469472

473+
func (t *NativeTun) CreateVectorisedWriter() N.VectorisedWriter {
474+
return t
475+
}
476+
477+
func (t *NativeTun) WriteVectorised(buffers []*buf.Buffer) error {
478+
defer buf.ReleaseMulti(buffers)
479+
return common.Error(t.write(buf.ToSliceMulti(buffers)))
480+
}
481+
470482
func (t *NativeTun) Close() error {
471483
var err error
472484
t.closeOnce.Do(func() {

0 commit comments

Comments
 (0)