From ef095a3ac8e586ceffaf27bab920b9b289dadb50 Mon Sep 17 00:00:00 2001 From: "@StringerBell69" Date: Sat, 14 Mar 2026 23:25:43 +0100 Subject: [PATCH] Fix cloudflared tunnel route dns zone resolution bug When users have multiple domains, tunnel route dns would incorrectly use the default zoneID from the login certificate, creating an invalid CNAME record (e.g. app.domain1.com.domain2.com). This fix introduces ListZones in cfapi to fetch all valid zones for the account and explicitly checks if the provided hostname exactly matches or is a subdomain of a discovered zone, preventing this behavior and dynamically adjusting the endpoint to the correct Zone ID. --- cfapi/base_client.go | 21 +++++++++++- cfapi/client.go | 6 ++++ cfapi/hostname.go | 39 ++++++++++++++++++++++ cfapi/hostname_test.go | 75 ++++++++++++++++++++++++++++++++++++++++++ 4 files changed, 140 insertions(+), 1 deletion(-) diff --git a/cfapi/base_client.go b/cfapi/base_client.go index 05e32a83edb..7663fc98a68 100644 --- a/cfapi/base_client.go +++ b/cfapi/base_client.go @@ -40,6 +40,7 @@ type baseEndpoints struct { zoneLevel url.URL accountRoutes url.URL accountVnets url.URL + zones url.URL } var _ Client = (*RESTClient)(nil) @@ -60,7 +61,11 @@ func NewRESTClient(baseURL, accountTag, zoneTag, authToken, userAgent string, lo } zoneLevelEndpoint, err := url.Parse(fmt.Sprintf("%s/zones/%s/tunnels", baseURL, zoneTag)) if err != nil { - return nil, errors.Wrap(err, "failed to create account level endpoint") + return nil, errors.Wrap(err, "failed to create zone level endpoint") + } + zonesEndpoint, err := url.Parse(fmt.Sprintf("%s/zones", baseURL)) + if err != nil { + return nil, errors.Wrap(err, "failed to create zones endpoint") } httpTransport := http.Transport{ TLSHandshakeTimeout: defaultTimeout, @@ -73,6 +78,7 @@ func NewRESTClient(baseURL, accountTag, zoneTag, authToken, userAgent string, lo zoneLevel: *zoneLevelEndpoint, accountRoutes: *accountRoutesEndpoint, accountVnets: *accountVnetsEndpoint, + zones: *zonesEndpoint, }, authToken: authToken, userAgent: userAgent, @@ -241,3 +247,16 @@ func (r *RESTClient) statusCodeToError(op string, resp *http.Response) error { return errors.Errorf("API call to %s failed with status %d: %s", op, resp.StatusCode, http.StatusText(resp.StatusCode)) } + +func (r *RESTClient) ListZones() ([]*Zone, error) { + endpoint := r.baseEndpoints.zones + return fetchExhaustively[Zone](func(page int) (*http.Response, error) { + reqURL := endpoint + query := reqURL.Query() + query.Set("page", fmt.Sprintf("%d", page)) + query.Set("per_page", "50") + // Required to get basic zone info instead of just IDs + reqURL.RawQuery = query.Encode() + return r.sendRequest("GET", reqURL, nil) + }) +} diff --git a/cfapi/client.go b/cfapi/client.go index 1d08cc50cdb..55c914f5ba9 100644 --- a/cfapi/client.go +++ b/cfapi/client.go @@ -17,6 +17,7 @@ type TunnelClient interface { type HostnameClient interface { RouteTunnel(tunnelID uuid.UUID, route HostnameRoute) (HostnameRouteResult, error) + ListZones() ([]*Zone, error) } type IPRouteClient interface { @@ -39,3 +40,8 @@ type Client interface { IPRouteClient VnetClient } + +type Zone struct { + ID string `json:"id"` + Name string `json:"name"` +} diff --git a/cfapi/hostname.go b/cfapi/hostname.go index b8ca8bd441c..115bf613b59 100644 --- a/cfapi/hostname.go +++ b/cfapi/hostname.go @@ -5,7 +5,9 @@ import ( "fmt" "io" "net/http" + "net/url" "path" + "strings" "github.com/google/uuid" "github.com/pkg/errors" @@ -25,6 +27,7 @@ type HostnameRoute interface { RecordType() string UnmarshalResult(body io.Reader) (HostnameRouteResult, error) String() string + Hostname() string } type HostnameRouteResult interface { @@ -78,6 +81,10 @@ func (dr *DNSRoute) String() string { return fmt.Sprintf("%s %s", dr.RecordType(), dr.userHostname) } +func (dr *DNSRoute) Hostname() string { + return dr.userHostname +} + func (res *DNSRouteResult) SuccessSummary() string { var msgFmt string switch res.CName { @@ -139,6 +146,10 @@ func (lb *LBRoute) String() string { return fmt.Sprintf("%s %s %s", lb.RecordType(), lb.lbName, lb.lbPool) } +func (lr *LBRoute) Hostname() string { + return lr.lbName +} + func (lr *LBRoute) UnmarshalResult(body io.Reader) (HostnameRouteResult, error) { var result LBRouteResult err := parseResponse(body, &result) @@ -176,7 +187,35 @@ func (res *LBRouteResult) SuccessSummary() string { } func (r *RESTClient) RouteTunnel(tunnelID uuid.UUID, route HostnameRoute) (HostnameRouteResult, error) { + // First, try to find the correct zone by fetching all zones and matching the hostname + zoneID := "" + zones, err := r.ListZones() + if err == nil { + longestMatch := "" + for _, zone := range zones { + // A hostname should end with the zone name EXACTLY or be a subdomain of it. + // e.g. "app.staging.example.com" ends with ".example.com" and "example.com" == "example.com" + if route.Hostname() == zone.Name || strings.HasSuffix(route.Hostname(), "."+zone.Name) { + // We want the most specific zone if there are multiple matches + // e.g. "staging.example.com" zone vs "example.com" zone + if len(zone.Name) > len(longestMatch) { + longestMatch = zone.Name + zoneID = zone.ID + } + } + } + } + endpoint := r.baseEndpoints.zoneLevel + if zoneID != "" { + // Construct dynamic endpoint using the correct zone ID instead of the default one + baseURL := strings.TrimSuffix(r.baseEndpoints.zones.String(), "/zones") + zoneEndpoint, err := url.Parse(fmt.Sprintf("%s/zones/%s/tunnels", baseURL, zoneID)) + if err == nil { + endpoint = *zoneEndpoint + } + } + endpoint.Path = path.Join(endpoint.Path, fmt.Sprintf("%v/routes", tunnelID)) resp, err := r.sendRequest("PUT", endpoint, route) if err != nil { diff --git a/cfapi/hostname_test.go b/cfapi/hostname_test.go index 5100465a39b..86e4265f6c5 100644 --- a/cfapi/hostname_test.go +++ b/cfapi/hostname_test.go @@ -1,9 +1,14 @@ package cfapi import ( + "io" + "net/http" + "net/http/httptest" "strings" "testing" + "github.com/google/uuid" + "github.com/rs/zerolog" "github.com/stretchr/testify/assert" ) @@ -97,3 +102,73 @@ func TestLBRouteResultSuccessSummary(t *testing.T) { assert.Equal(t, tt.expected, actual, "case %d", i+1) } } + +func TestRouteTunnel_ZoneResolution(t *testing.T) { + ts := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) { + if r.URL.Path == "/zones" { + // A sample JSON response matching the Cloudflare api format, ensuring we mimic what the real API sends. + io.WriteString(w, `{ + "success": true, + "errors": [], + "messages": [], + "result": [ + { + "id": "zone-1", + "name": "example.com", + "status": "active", + "paused": false + }, + { + "id": "zone-2", + "name": "example.co.uk", + "status": "active", + "paused": false + } + ], + "result_info": { + "page": 1, + "per_page": 50, + "total_pages": 1, + "count": 2, + "total_count": 2 + } + }`) + return + } + if r.URL.Path == "/zones/zone-1/tunnels/11111111-2222-3333-4444-555555555555/routes" { + io.WriteString(w, `{"success":true,"result":{"cname":"new","name":"app.example.com"}}`) + return + } + + // Fallback path when zone does NOT match. It uses "default-zone-from-login" as specified in NewRESTClient arguments. + if r.URL.Path == "/zones/default-zone-from-login/tunnels/11111111-2222-3333-4444-555555555555/routes" { + io.WriteString(w, `{"success":true,"result":{"cname":"new","name":"fallback.otherdomain.com"}}`) + return + } + + w.WriteHeader(http.StatusNotFound) + })) + defer ts.Close() + + logger := zerolog.Nop() + client, err := NewRESTClient(ts.URL, "account", "default-zone-from-login", "token", "agent", &logger) + assert.NoError(t, err) + + tunnelID, _ := uuid.Parse("11111111-2222-3333-4444-555555555555") + + t.Run("Success match", func(t *testing.T) { + route := NewDNSRoute("app.example.com", false) + res, err := client.RouteTunnel(tunnelID, route) + assert.NoError(t, err) + assert.NotNil(t, res) + assert.Equal(t, "Added CNAME app.example.com which will route to this tunnel", res.SuccessSummary()) + }) + + t.Run("Fallback to default zone when no match", func(t *testing.T) { + route := NewDNSRoute("fallback.otherdomain.com", false) + res, err := client.RouteTunnel(tunnelID, route) + assert.NoError(t, err) + assert.NotNil(t, res) + assert.Equal(t, "Added CNAME fallback.otherdomain.com which will route to this tunnel", res.SuccessSummary()) + }) +}