Skip to content
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
21 changes: 20 additions & 1 deletion cfapi/base_client.go
Original file line number Diff line number Diff line change
Expand Up @@ -40,6 +40,7 @@ type baseEndpoints struct {
zoneLevel url.URL
accountRoutes url.URL
accountVnets url.URL
zones url.URL
}

var _ Client = (*RESTClient)(nil)
Expand All @@ -60,7 +61,11 @@ func NewRESTClient(baseURL, accountTag, zoneTag, authToken, userAgent string, lo
}
zoneLevelEndpoint, err := url.Parse(fmt.Sprintf("%s/zones/%s/tunnels", baseURL, zoneTag))
if err != nil {
return nil, errors.Wrap(err, "failed to create account level endpoint")
return nil, errors.Wrap(err, "failed to create zone level endpoint")
}
zonesEndpoint, err := url.Parse(fmt.Sprintf("%s/zones", baseURL))
if err != nil {
return nil, errors.Wrap(err, "failed to create zones endpoint")
}
httpTransport := http.Transport{
TLSHandshakeTimeout: defaultTimeout,
Expand All @@ -73,6 +78,7 @@ func NewRESTClient(baseURL, accountTag, zoneTag, authToken, userAgent string, lo
zoneLevel: *zoneLevelEndpoint,
accountRoutes: *accountRoutesEndpoint,
accountVnets: *accountVnetsEndpoint,
zones: *zonesEndpoint,
},
authToken: authToken,
userAgent: userAgent,
Expand Down Expand Up @@ -241,3 +247,16 @@ func (r *RESTClient) statusCodeToError(op string, resp *http.Response) error {
return errors.Errorf("API call to %s failed with status %d: %s", op,
resp.StatusCode, http.StatusText(resp.StatusCode))
}

func (r *RESTClient) ListZones() ([]*Zone, error) {
endpoint := r.baseEndpoints.zones
return fetchExhaustively[Zone](func(page int) (*http.Response, error) {
reqURL := endpoint
query := reqURL.Query()
query.Set("page", fmt.Sprintf("%d", page))
query.Set("per_page", "50")
// Required to get basic zone info instead of just IDs
reqURL.RawQuery = query.Encode()
return r.sendRequest("GET", reqURL, nil)
})
}
6 changes: 6 additions & 0 deletions cfapi/client.go
Original file line number Diff line number Diff line change
Expand Up @@ -17,6 +17,7 @@ type TunnelClient interface {

type HostnameClient interface {
RouteTunnel(tunnelID uuid.UUID, route HostnameRoute) (HostnameRouteResult, error)
ListZones() ([]*Zone, error)
}

type IPRouteClient interface {
Expand All @@ -39,3 +40,8 @@ type Client interface {
IPRouteClient
VnetClient
}

type Zone struct {
ID string `json:"id"`
Name string `json:"name"`
}
39 changes: 39 additions & 0 deletions cfapi/hostname.go
Original file line number Diff line number Diff line change
Expand Up @@ -5,7 +5,9 @@ import (
"fmt"
"io"
"net/http"
"net/url"
"path"
"strings"

"github.com/google/uuid"
"github.com/pkg/errors"
Expand All @@ -25,6 +27,7 @@ type HostnameRoute interface {
RecordType() string
UnmarshalResult(body io.Reader) (HostnameRouteResult, error)
String() string
Hostname() string
}

type HostnameRouteResult interface {
Expand Down Expand Up @@ -78,6 +81,10 @@ func (dr *DNSRoute) String() string {
return fmt.Sprintf("%s %s", dr.RecordType(), dr.userHostname)
}

func (dr *DNSRoute) Hostname() string {
return dr.userHostname
}

func (res *DNSRouteResult) SuccessSummary() string {
var msgFmt string
switch res.CName {
Expand Down Expand Up @@ -139,6 +146,10 @@ func (lb *LBRoute) String() string {
return fmt.Sprintf("%s %s %s", lb.RecordType(), lb.lbName, lb.lbPool)
}

func (lr *LBRoute) Hostname() string {
return lr.lbName
}

func (lr *LBRoute) UnmarshalResult(body io.Reader) (HostnameRouteResult, error) {
var result LBRouteResult
err := parseResponse(body, &result)
Expand Down Expand Up @@ -176,7 +187,35 @@ func (res *LBRouteResult) SuccessSummary() string {
}

func (r *RESTClient) RouteTunnel(tunnelID uuid.UUID, route HostnameRoute) (HostnameRouteResult, error) {
// First, try to find the correct zone by fetching all zones and matching the hostname
zoneID := ""
zones, err := r.ListZones()
if err == nil {
longestMatch := ""
for _, zone := range zones {
// A hostname should end with the zone name EXACTLY or be a subdomain of it.
// e.g. "app.staging.example.com" ends with ".example.com" and "example.com" == "example.com"
if route.Hostname() == zone.Name || strings.HasSuffix(route.Hostname(), "."+zone.Name) {
// We want the most specific zone if there are multiple matches
// e.g. "staging.example.com" zone vs "example.com" zone
if len(zone.Name) > len(longestMatch) {
longestMatch = zone.Name
zoneID = zone.ID
}
}
}
}

endpoint := r.baseEndpoints.zoneLevel
if zoneID != "" {
// Construct dynamic endpoint using the correct zone ID instead of the default one
baseURL := strings.TrimSuffix(r.baseEndpoints.zones.String(), "/zones")
zoneEndpoint, err := url.Parse(fmt.Sprintf("%s/zones/%s/tunnels", baseURL, zoneID))
if err == nil {
endpoint = *zoneEndpoint
}
}

endpoint.Path = path.Join(endpoint.Path, fmt.Sprintf("%v/routes", tunnelID))
resp, err := r.sendRequest("PUT", endpoint, route)
if err != nil {
Expand Down
75 changes: 75 additions & 0 deletions cfapi/hostname_test.go
Original file line number Diff line number Diff line change
@@ -1,9 +1,14 @@
package cfapi

import (
"io"
"net/http"
"net/http/httptest"
"strings"
"testing"

"github.com/google/uuid"
"github.com/rs/zerolog"
"github.com/stretchr/testify/assert"
)

Expand Down Expand Up @@ -97,3 +102,73 @@ func TestLBRouteResultSuccessSummary(t *testing.T) {
assert.Equal(t, tt.expected, actual, "case %d", i+1)
}
}

func TestRouteTunnel_ZoneResolution(t *testing.T) {
ts := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
if r.URL.Path == "/zones" {
// A sample JSON response matching the Cloudflare api format, ensuring we mimic what the real API sends.
io.WriteString(w, `{
"success": true,
"errors": [],
"messages": [],
"result": [
{
"id": "zone-1",
"name": "example.com",
"status": "active",
"paused": false
},
{
"id": "zone-2",
"name": "example.co.uk",
"status": "active",
"paused": false
}
],
"result_info": {
"page": 1,
"per_page": 50,
"total_pages": 1,
"count": 2,
"total_count": 2
}
}`)
return
}
if r.URL.Path == "/zones/zone-1/tunnels/11111111-2222-3333-4444-555555555555/routes" {
io.WriteString(w, `{"success":true,"result":{"cname":"new","name":"app.example.com"}}`)
return
}

// Fallback path when zone does NOT match. It uses "default-zone-from-login" as specified in NewRESTClient arguments.
if r.URL.Path == "/zones/default-zone-from-login/tunnels/11111111-2222-3333-4444-555555555555/routes" {
io.WriteString(w, `{"success":true,"result":{"cname":"new","name":"fallback.otherdomain.com"}}`)
return
}

w.WriteHeader(http.StatusNotFound)
}))
defer ts.Close()

logger := zerolog.Nop()
client, err := NewRESTClient(ts.URL, "account", "default-zone-from-login", "token", "agent", &logger)
assert.NoError(t, err)

tunnelID, _ := uuid.Parse("11111111-2222-3333-4444-555555555555")

t.Run("Success match", func(t *testing.T) {
route := NewDNSRoute("app.example.com", false)
res, err := client.RouteTunnel(tunnelID, route)
assert.NoError(t, err)
assert.NotNil(t, res)
assert.Equal(t, "Added CNAME app.example.com which will route to this tunnel", res.SuccessSummary())
})

t.Run("Fallback to default zone when no match", func(t *testing.T) {
route := NewDNSRoute("fallback.otherdomain.com", false)
res, err := client.RouteTunnel(tunnelID, route)
assert.NoError(t, err)
assert.NotNil(t, res)
assert.Equal(t, "Added CNAME fallback.otherdomain.com which will route to this tunnel", res.SuccessSummary())
})
}