diff --git a/cfapi/base_client.go b/cfapi/base_client.go index 05e32a83edb..7663fc98a68 100644 --- a/cfapi/base_client.go +++ b/cfapi/base_client.go @@ -40,6 +40,7 @@ type baseEndpoints struct { zoneLevel url.URL accountRoutes url.URL accountVnets url.URL + zones url.URL } var _ Client = (*RESTClient)(nil) @@ -60,7 +61,11 @@ func NewRESTClient(baseURL, accountTag, zoneTag, authToken, userAgent string, lo } zoneLevelEndpoint, err := url.Parse(fmt.Sprintf("%s/zones/%s/tunnels", baseURL, zoneTag)) if err != nil { - return nil, errors.Wrap(err, "failed to create account level endpoint") + return nil, errors.Wrap(err, "failed to create zone level endpoint") + } + zonesEndpoint, err := url.Parse(fmt.Sprintf("%s/zones", baseURL)) + if err != nil { + return nil, errors.Wrap(err, "failed to create zones endpoint") } httpTransport := http.Transport{ TLSHandshakeTimeout: defaultTimeout, @@ -73,6 +78,7 @@ func NewRESTClient(baseURL, accountTag, zoneTag, authToken, userAgent string, lo zoneLevel: *zoneLevelEndpoint, accountRoutes: *accountRoutesEndpoint, accountVnets: *accountVnetsEndpoint, + zones: *zonesEndpoint, }, authToken: authToken, userAgent: userAgent, @@ -241,3 +247,16 @@ func (r *RESTClient) statusCodeToError(op string, resp *http.Response) error { return errors.Errorf("API call to %s failed with status %d: %s", op, resp.StatusCode, http.StatusText(resp.StatusCode)) } + +func (r *RESTClient) ListZones() ([]*Zone, error) { + endpoint := r.baseEndpoints.zones + return fetchExhaustively[Zone](func(page int) (*http.Response, error) { + reqURL := endpoint + query := reqURL.Query() + query.Set("page", fmt.Sprintf("%d", page)) + query.Set("per_page", "50") + // Required to get basic zone info instead of just IDs + reqURL.RawQuery = query.Encode() + return r.sendRequest("GET", reqURL, nil) + }) +} diff --git a/cfapi/client.go b/cfapi/client.go index 1d08cc50cdb..55c914f5ba9 100644 --- a/cfapi/client.go +++ b/cfapi/client.go @@ -17,6 +17,7 @@ type TunnelClient interface { type HostnameClient interface { RouteTunnel(tunnelID uuid.UUID, route HostnameRoute) (HostnameRouteResult, error) + ListZones() ([]*Zone, error) } type IPRouteClient interface { @@ -39,3 +40,8 @@ type Client interface { IPRouteClient VnetClient } + +type Zone struct { + ID string `json:"id"` + Name string `json:"name"` +} diff --git a/cfapi/hostname.go b/cfapi/hostname.go index b8ca8bd441c..115bf613b59 100644 --- a/cfapi/hostname.go +++ b/cfapi/hostname.go @@ -5,7 +5,9 @@ import ( "fmt" "io" "net/http" + "net/url" "path" + "strings" "github.com/google/uuid" "github.com/pkg/errors" @@ -25,6 +27,7 @@ type HostnameRoute interface { RecordType() string UnmarshalResult(body io.Reader) (HostnameRouteResult, error) String() string + Hostname() string } type HostnameRouteResult interface { @@ -78,6 +81,10 @@ func (dr *DNSRoute) String() string { return fmt.Sprintf("%s %s", dr.RecordType(), dr.userHostname) } +func (dr *DNSRoute) Hostname() string { + return dr.userHostname +} + func (res *DNSRouteResult) SuccessSummary() string { var msgFmt string switch res.CName { @@ -139,6 +146,10 @@ func (lb *LBRoute) String() string { return fmt.Sprintf("%s %s %s", lb.RecordType(), lb.lbName, lb.lbPool) } +func (lr *LBRoute) Hostname() string { + return lr.lbName +} + func (lr *LBRoute) UnmarshalResult(body io.Reader) (HostnameRouteResult, error) { var result LBRouteResult err := parseResponse(body, &result) @@ -176,7 +187,35 @@ func (res *LBRouteResult) SuccessSummary() string { } func (r *RESTClient) RouteTunnel(tunnelID uuid.UUID, route HostnameRoute) (HostnameRouteResult, error) { + // First, try to find the correct zone by fetching all zones and matching the hostname + zoneID := "" + zones, err := r.ListZones() + if err == nil { + longestMatch := "" + for _, zone := range zones { + // A hostname should end with the zone name EXACTLY or be a subdomain of it. + // e.g. "app.staging.example.com" ends with ".example.com" and "example.com" == "example.com" + if route.Hostname() == zone.Name || strings.HasSuffix(route.Hostname(), "."+zone.Name) { + // We want the most specific zone if there are multiple matches + // e.g. "staging.example.com" zone vs "example.com" zone + if len(zone.Name) > len(longestMatch) { + longestMatch = zone.Name + zoneID = zone.ID + } + } + } + } + endpoint := r.baseEndpoints.zoneLevel + if zoneID != "" { + // Construct dynamic endpoint using the correct zone ID instead of the default one + baseURL := strings.TrimSuffix(r.baseEndpoints.zones.String(), "/zones") + zoneEndpoint, err := url.Parse(fmt.Sprintf("%s/zones/%s/tunnels", baseURL, zoneID)) + if err == nil { + endpoint = *zoneEndpoint + } + } + endpoint.Path = path.Join(endpoint.Path, fmt.Sprintf("%v/routes", tunnelID)) resp, err := r.sendRequest("PUT", endpoint, route) if err != nil { diff --git a/cfapi/hostname_test.go b/cfapi/hostname_test.go index 5100465a39b..86e4265f6c5 100644 --- a/cfapi/hostname_test.go +++ b/cfapi/hostname_test.go @@ -1,9 +1,14 @@ package cfapi import ( + "io" + "net/http" + "net/http/httptest" "strings" "testing" + "github.com/google/uuid" + "github.com/rs/zerolog" "github.com/stretchr/testify/assert" ) @@ -97,3 +102,73 @@ func TestLBRouteResultSuccessSummary(t *testing.T) { assert.Equal(t, tt.expected, actual, "case %d", i+1) } } + +func TestRouteTunnel_ZoneResolution(t *testing.T) { + ts := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) { + if r.URL.Path == "/zones" { + // A sample JSON response matching the Cloudflare api format, ensuring we mimic what the real API sends. + io.WriteString(w, `{ + "success": true, + "errors": [], + "messages": [], + "result": [ + { + "id": "zone-1", + "name": "example.com", + "status": "active", + "paused": false + }, + { + "id": "zone-2", + "name": "example.co.uk", + "status": "active", + "paused": false + } + ], + "result_info": { + "page": 1, + "per_page": 50, + "total_pages": 1, + "count": 2, + "total_count": 2 + } + }`) + return + } + if r.URL.Path == "/zones/zone-1/tunnels/11111111-2222-3333-4444-555555555555/routes" { + io.WriteString(w, `{"success":true,"result":{"cname":"new","name":"app.example.com"}}`) + return + } + + // Fallback path when zone does NOT match. It uses "default-zone-from-login" as specified in NewRESTClient arguments. + if r.URL.Path == "/zones/default-zone-from-login/tunnels/11111111-2222-3333-4444-555555555555/routes" { + io.WriteString(w, `{"success":true,"result":{"cname":"new","name":"fallback.otherdomain.com"}}`) + return + } + + w.WriteHeader(http.StatusNotFound) + })) + defer ts.Close() + + logger := zerolog.Nop() + client, err := NewRESTClient(ts.URL, "account", "default-zone-from-login", "token", "agent", &logger) + assert.NoError(t, err) + + tunnelID, _ := uuid.Parse("11111111-2222-3333-4444-555555555555") + + t.Run("Success match", func(t *testing.T) { + route := NewDNSRoute("app.example.com", false) + res, err := client.RouteTunnel(tunnelID, route) + assert.NoError(t, err) + assert.NotNil(t, res) + assert.Equal(t, "Added CNAME app.example.com which will route to this tunnel", res.SuccessSummary()) + }) + + t.Run("Fallback to default zone when no match", func(t *testing.T) { + route := NewDNSRoute("fallback.otherdomain.com", false) + res, err := client.RouteTunnel(tunnelID, route) + assert.NoError(t, err) + assert.NotNil(t, res) + assert.Equal(t, "Added CNAME fallback.otherdomain.com which will route to this tunnel", res.SuccessSummary()) + }) +}