diff --git a/ingress/ingress.go b/ingress/ingress.go index a325271a7e7..f00ed231a7f 100644 --- a/ingress/ingress.go +++ b/ingress/ingress.go @@ -255,6 +255,10 @@ func validateIngress(ingress []config.UnvalidatedIngressRule, defaults OriginReq } else if prefix := "unix+tls:"; strings.HasPrefix(r.Service, prefix) { path := strings.TrimPrefix(r.Service, prefix) service = &unixSocketPath{path: path, scheme: "https"} + } else if prefix := "unix+tcp:"; strings.HasPrefix(r.Service, prefix) { + // Stream raw bytes (e.g. SSH, RDP protocol) directly into a unix socket without HTTP wrapping + path := strings.TrimPrefix(r.Service, prefix) + service = &unixSocketTCPService{path: path} } else if prefix := "http_status:"; strings.HasPrefix(r.Service, prefix) { statusCode, err := strconv.Atoi(strings.TrimPrefix(r.Service, prefix)) if err != nil { diff --git a/ingress/ingress_test.go b/ingress/ingress_test.go index 109cb3530e2..8dfd5416609 100644 --- a/ingress/ingress_test.go +++ b/ingress/ingress_test.go @@ -43,6 +43,19 @@ ingress: require.Equal(t, "https", s.scheme) } +func TestParseUnixSocketTCP(t *testing.T) { + rawYAML := ` +ingress: +- service: unix+tcp:/run/sshd.sock +` + ing, err := ParseIngress(MustReadIngress(rawYAML)) + require.NoError(t, err) + s, ok := ing.Rules[0].Service.(*unixSocketTCPService) + require.True(t, ok) + require.Equal(t, "/run/sshd.sock", s.path) + require.Equal(t, "unix+tcp:/run/sshd.sock", s.String()) +} + func TestParseIngressNilConfig(t *testing.T) { _, err := ParseIngress(nil) require.Error(t, err) @@ -322,6 +335,19 @@ ingress: }, }, }, + { + name: "Unix+TCP service", + args: args{rawYAML: ` +ingress: +- service: unix+tcp:/run/sshd.sock +`}, + want: []Rule{ + { + Service: &unixSocketTCPService{path: "/run/sshd.sock"}, + Config: defaultConfig, + }, + }, + }, { name: "RDP services", args: args{rawYAML: ` diff --git a/ingress/origin_proxy.go b/ingress/origin_proxy.go index 7371eac92ec..10f33156292 100644 --- a/ingress/origin_proxy.go +++ b/ingress/origin_proxy.go @@ -119,3 +119,14 @@ func (o *tcpOverWSService) EstablishConnection(ctx context.Context, dest string, func (o *socksProxyOverWSService) EstablishConnection(_ context.Context, _ string, _ *zerolog.Logger) (OriginConnection, error) { return o.conn, nil } + +func (o *unixSocketTCPService) EstablishConnection(ctx context.Context, _ string, _ *zerolog.Logger) (OriginConnection, error) { + conn, err := o.dialer.DialContext(ctx, "unix", o.path) + if err != nil { + return nil, err + } + return &tcpOverWSConnection{ + conn: conn, + streamHandler: o.streamHandler, + }, nil +} diff --git a/ingress/origin_proxy_test.go b/ingress/origin_proxy_test.go index 7a6170a2a68..aab4697e571 100644 --- a/ingress/origin_proxy_test.go +++ b/ingress/origin_proxy_test.go @@ -8,6 +8,7 @@ import ( "net/http" "net/http/httptest" "net/url" + "os" "testing" "github.com/stretchr/testify/assert" @@ -186,6 +187,35 @@ func TestHTTPServiceUsesIngressRuleScheme(t *testing.T) { } } +func TestUnixSocketTCPServiceEstablishConnection(t *testing.T) { + dir, err := os.MkdirTemp("/tmp", "cf-test-") + require.NoError(t, err) + defer os.RemoveAll(dir) + + socketPath := dir + "/sshd.sock" + originListener, err := net.Listen("unix", socketPath) + require.NoError(t, err) + + listenerClosed := make(chan struct{}) + tcpListenRoutine(originListener, listenerClosed) + + svc := &unixSocketTCPService{path: socketPath} + require.NoError(t, svc.start(TestLogger, make(chan struct{}), OriginRequestConfig{})) + + // Successful connection to the unix socket + conn, err := svc.EstablishConnection(context.Background(), "", TestLogger) + require.NoError(t, err) + require.NotNil(t, conn) + conn.Close() + + // Close the listener and verify that new connections fail + originListener.Close() + <-listenerClosed + + _, err = svc.EstablishConnection(context.Background(), "", TestLogger) + require.Error(t, err) +} + func tcpListenRoutine(listener net.Listener, closeChan chan struct{}) { go func() { for { diff --git a/ingress/origin_service.go b/ingress/origin_service.go index e13204c5789..0a5a2be2380 100644 --- a/ingress/origin_service.go +++ b/ingress/origin_service.go @@ -46,6 +46,14 @@ type unixSocketPath struct { transport *http.Transport } +// unixSocketTCPService is an OriginService that streams raw bytes (e.g. SSH, RDP) directly into a +// unix socket, bypassing HTTP entirely. It is the unix-socket analogue of tcpOverWSService. +type unixSocketTCPService struct { + path string + streamHandler streamHandlerFunc + dialer net.Dialer +} + func (o *unixSocketPath) String() string { scheme := "" if o.scheme == "https" { @@ -67,6 +75,21 @@ func (o unixSocketPath) MarshalJSON() ([]byte, error) { return json.Marshal(o.String()) } +func (o *unixSocketTCPService) String() string { + return "unix+tcp:" + o.path +} + +func (o *unixSocketTCPService) start(_ *zerolog.Logger, _ <-chan struct{}, cfg OriginRequestConfig) error { + o.streamHandler = DefaultStreamHandler + o.dialer.Timeout = cfg.ConnectTimeout.Duration + o.dialer.KeepAlive = cfg.TCPKeepAlive.Duration + return nil +} + +func (o unixSocketTCPService) MarshalJSON() ([]byte, error) { + return json.Marshal(o.String()) +} + type httpService struct { url *url.URL hostHeader string