Skip to content
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
4 changes: 4 additions & 0 deletions ingress/ingress.go
Original file line number Diff line number Diff line change
Expand Up @@ -255,6 +255,10 @@ func validateIngress(ingress []config.UnvalidatedIngressRule, defaults OriginReq
} else if prefix := "unix+tls:"; strings.HasPrefix(r.Service, prefix) {
path := strings.TrimPrefix(r.Service, prefix)
service = &unixSocketPath{path: path, scheme: "https"}
} else if prefix := "unix+tcp:"; strings.HasPrefix(r.Service, prefix) {
// Stream raw bytes (e.g. SSH, RDP protocol) directly into a unix socket without HTTP wrapping
path := strings.TrimPrefix(r.Service, prefix)
service = &unixSocketTCPService{path: path}
} else if prefix := "http_status:"; strings.HasPrefix(r.Service, prefix) {
statusCode, err := strconv.Atoi(strings.TrimPrefix(r.Service, prefix))
if err != nil {
Expand Down
26 changes: 26 additions & 0 deletions ingress/ingress_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -43,6 +43,19 @@ ingress:
require.Equal(t, "https", s.scheme)
}

func TestParseUnixSocketTCP(t *testing.T) {
rawYAML := `
ingress:
- service: unix+tcp:/run/sshd.sock
`
ing, err := ParseIngress(MustReadIngress(rawYAML))
require.NoError(t, err)
s, ok := ing.Rules[0].Service.(*unixSocketTCPService)
require.True(t, ok)
require.Equal(t, "/run/sshd.sock", s.path)
require.Equal(t, "unix+tcp:/run/sshd.sock", s.String())
}

func TestParseIngressNilConfig(t *testing.T) {
_, err := ParseIngress(nil)
require.Error(t, err)
Expand Down Expand Up @@ -322,6 +335,19 @@ ingress:
},
},
},
{
name: "Unix+TCP service",
args: args{rawYAML: `
ingress:
- service: unix+tcp:/run/sshd.sock
`},
want: []Rule{
{
Service: &unixSocketTCPService{path: "/run/sshd.sock"},
Config: defaultConfig,
},
},
},
{
name: "RDP services",
args: args{rawYAML: `
Expand Down
11 changes: 11 additions & 0 deletions ingress/origin_proxy.go
Original file line number Diff line number Diff line change
Expand Up @@ -119,3 +119,14 @@ func (o *tcpOverWSService) EstablishConnection(ctx context.Context, dest string,
func (o *socksProxyOverWSService) EstablishConnection(_ context.Context, _ string, _ *zerolog.Logger) (OriginConnection, error) {
return o.conn, nil
}

func (o *unixSocketTCPService) EstablishConnection(ctx context.Context, _ string, _ *zerolog.Logger) (OriginConnection, error) {
conn, err := o.dialer.DialContext(ctx, "unix", o.path)
if err != nil {
return nil, err
}
return &tcpOverWSConnection{
conn: conn,
streamHandler: o.streamHandler,
}, nil
}
30 changes: 30 additions & 0 deletions ingress/origin_proxy_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -8,6 +8,7 @@ import (
"net/http"
"net/http/httptest"
"net/url"
"os"
"testing"

"github.com/stretchr/testify/assert"
Expand Down Expand Up @@ -186,6 +187,35 @@ func TestHTTPServiceUsesIngressRuleScheme(t *testing.T) {
}
}

func TestUnixSocketTCPServiceEstablishConnection(t *testing.T) {
dir, err := os.MkdirTemp("/tmp", "cf-test-")
require.NoError(t, err)
defer os.RemoveAll(dir)

socketPath := dir + "/sshd.sock"
originListener, err := net.Listen("unix", socketPath)
require.NoError(t, err)

listenerClosed := make(chan struct{})
tcpListenRoutine(originListener, listenerClosed)

svc := &unixSocketTCPService{path: socketPath}
require.NoError(t, svc.start(TestLogger, make(chan struct{}), OriginRequestConfig{}))

// Successful connection to the unix socket
conn, err := svc.EstablishConnection(context.Background(), "", TestLogger)
require.NoError(t, err)
require.NotNil(t, conn)
conn.Close()

// Close the listener and verify that new connections fail
originListener.Close()
<-listenerClosed

_, err = svc.EstablishConnection(context.Background(), "", TestLogger)
require.Error(t, err)
}

func tcpListenRoutine(listener net.Listener, closeChan chan struct{}) {
go func() {
for {
Expand Down
23 changes: 23 additions & 0 deletions ingress/origin_service.go
Original file line number Diff line number Diff line change
Expand Up @@ -46,6 +46,14 @@ type unixSocketPath struct {
transport *http.Transport
}

// unixSocketTCPService is an OriginService that streams raw bytes (e.g. SSH, RDP) directly into a
// unix socket, bypassing HTTP entirely. It is the unix-socket analogue of tcpOverWSService.
type unixSocketTCPService struct {
path string
streamHandler streamHandlerFunc
dialer net.Dialer
}

func (o *unixSocketPath) String() string {
scheme := ""
if o.scheme == "https" {
Expand All @@ -67,6 +75,21 @@ func (o unixSocketPath) MarshalJSON() ([]byte, error) {
return json.Marshal(o.String())
}

func (o *unixSocketTCPService) String() string {
return "unix+tcp:" + o.path
}

func (o *unixSocketTCPService) start(_ *zerolog.Logger, _ <-chan struct{}, cfg OriginRequestConfig) error {
o.streamHandler = DefaultStreamHandler
o.dialer.Timeout = cfg.ConnectTimeout.Duration
o.dialer.KeepAlive = cfg.TCPKeepAlive.Duration
return nil
}

func (o unixSocketTCPService) MarshalJSON() ([]byte, error) {
return json.Marshal(o.String())
}

type httpService struct {
url *url.URL
hostHeader string
Expand Down