diff --git a/.gitignore b/.gitignore index a5c8936..44e07de 100644 --- a/.gitignore +++ b/.gitignore @@ -1,2 +1,3 @@ .vscode -.env \ No newline at end of file +*.env +.private \ No newline at end of file diff --git a/client.go b/client.go index 1d4869c..d98a5ac 100644 --- a/client.go +++ b/client.go @@ -2,7 +2,6 @@ package loopia import ( "context" - "errors" "fmt" "strings" "sync" @@ -10,6 +9,7 @@ import ( "github.com/kolo/xmlrpc" "github.com/libdns/libdns" + "github.com/libdns/loopia/internal/cache" ) const ( @@ -19,12 +19,17 @@ const ( type client struct { rpc *xmlrpc.Client mutex sync.Mutex + cache cache.Cache // Cache for storing zone records } type libdnsKey string var libdnsKeyTrace libdnsKey = "libdns.loopia.trace" +func params(args ...any) []any { + return args +} + func writeTrace(ctx context.Context, trace string) context.Context { if ctx == nil { ctx = context.Background() @@ -95,32 +100,55 @@ func (p *Provider) getRPC() *xmlrpc.Client { return p.rpc } -func (p *Provider) call(serviceMethod string, args []interface{}, reply interface{}) error { +func (p *Provider) call(serviceMethod string, args []interface{}, reply interface{}) (err error) { params := []interface{}{ p.Username, p.Password, } + if p.logging { + t := time.Now() + defer func() { + Log().Debugw("called rpc", "method", serviceMethod, "params", args, "error", err, "duration", time.Since(t).String()) + }() + } if p.Customer != "" { params = append(params, p.Customer) } params = append(params, args...) - err := p.getRPC().Call( + hashKey := fmt.Sprintf("%s:%v", serviceMethod, args) + + if p.cache != nil && serviceMethod == "getSubdomains" { + if p.cache.Get(hashKey, reply) { + if p.logging { + Log().Debugw("cache hit", "method", serviceMethod, "params", args) + } + return nil + } + } + + err = p.getRPC().Call( serviceMethod, params, reply, ) - if p.logging { - Log().Debugw("called rpc", "method", serviceMethod, "params", args, "error", err) + if err == nil && serviceMethod == "getSubdomains" { + if p.cache == nil { + p.cache = cache.NewInMemoryCache() + } + if cacheErr := p.cache.Set(hashKey, reply, 5*time.Second); cacheErr != nil && p.logging { + Log().Warnw("cache set failed", "method", serviceMethod, "params", args, "error", cacheErr) + } } return err } +// getLoopiaRecords retrieves the records for a given zone and name from Loopia. It populates the provided records slice with the results. func (p *Provider) getLoopiaRecords(ctx context.Context, zone, name string, records *[]loopiaRecord) error { if !validZone(zone) { return fmt.Errorf("invalid zone '%s'", zone) } if name == "" { - return fmt.Errorf("invalide name '%s'", name) + return fmt.Errorf("invalid name '%s'", name) } if p.logging { Log().Debugw("getLoopiaRecords", "zone", zone, "name", name, "trace", getTrace(ctx)) @@ -165,14 +193,23 @@ func (p *Provider) getRecords(ctx context.Context, zone, name string) ([]libdns. for _, r := range records { rr, err := r.libdnsRecord(name) if err != nil { - return nil, fmt.Errorf("unexpected error converting record: %w", err) + return nil, fmt.Errorf("unexpected error in client converting record for getRecords: %w", err) } result = append(result, rr) } return result, nil } +func (p *Provider) clearCache() { + if p.cache != nil { + if err := p.cache.Clear(); err != nil && p.logging { + Log().Debugw("cache clear failed", "method", "clearCache", "error", err) + } + } +} + func (p *Provider) addRecord(ctx context.Context, zone string, record libdns.Record, withSubdomain bool) (out libdns.Record, id int64, err error) { + defer p.clearCache() // Clear cache after adding a record to ensure consistency for subsequent reads if p.logging { Log().Debugw("addRecord", "zone", zone, @@ -184,7 +221,7 @@ func (p *Provider) addRecord(ctx context.Context, zone string, record libdns.Rec name := record.RR().Name loopiaToAdd, err := toLoopiaRecord(record, 0) if err != nil { - return nil, 0, fmt.Errorf("unexpected error converting record: %w", err) + return nil, 0, fmt.Errorf("unexpected error converting record for addRecord: %w", err) } if withSubdomain { var response string @@ -209,7 +246,7 @@ func (p *Provider) addRecord(ctx context.Context, zone string, record libdns.Rec for _, r := range records { out, err = r.libdnsRecord(name) if err != nil { - return nil, 0, fmt.Errorf("unexpected error converting record: %w", err) + return nil, 0, fmt.Errorf("unexpected error converting record for addRecord: %w", err) } if libdnsRecordEqual(record, out) { return out, r.ID, nil @@ -219,10 +256,6 @@ func (p *Provider) addRecord(ctx context.Context, zone string, record libdns.Rec return nil, 0, fmt.Errorf("unable to retreive new record to get it's ID") } -func params(args ...interface{}) []interface{} { - return args -} - func (p *Provider) getZoneRecords(ctx context.Context, zone string) ([]libdns.Record, error) { if p.logging { Log().Debugw("getZoneRecords", "zone", zone) @@ -254,6 +287,7 @@ myloop: } func (p *Provider) addDNSEntries(ctx context.Context, zone string, records []libdns.Record) ([]libdns.Record, error) { + defer p.clearCache() // Clear cache after adding a record to ensure consistency for subsequent reads if p.logging { Log().Debugw("addDNSEntries", "zone", zone, @@ -327,16 +361,94 @@ OUTER: // records in the output zone with that (name, type) pair are those that were // provided in the input. func (p *Provider) setRecords(ctx context.Context, zone string, records []libdns.Record) ([]libdns.Record, error) { + defer p.clearCache() // Clear cache after adding a record to ensure consistency for subsequent reads ctx = addTrace(ctx, "setRecords") + var err error + type nameTypeEntry struct { + name string + rtype string + wanted []libdns.Record + existing []libdns.Record + } + + dnsNameRecords := map[string][]libdns.Record{} + entries := map[string]nameTypeEntry{} + foundOrAdded := []libdns.Record{} for _, r := range records { n, z := loopify(r.RR().Name, zone) - existing := []loopiaRecord{} - err := p.getLoopiaRecords(ctx, z, n, &existing) - if err != nil { - return nil, fmt.Errorf("unexpected error getting zone records: %w", err) + nameTypeKey := fmt.Sprintf("%s:%s", n, r.RR().Type) + entry, hasName := entries[nameTypeKey] + if !hasName { + entry = nameTypeEntry{ + name: n, + rtype: r.RR().Type, + wanted: []libdns.Record{}, + existing: []libdns.Record{}, + } + + existing, ok := dnsNameRecords[n] + // Fetch existing records for the name if not already fetched + if !ok { + existing, err = p.getRecords(ctx, z, n) + if err != nil { + return nil, fmt.Errorf("unexpected error getting existing records for setRecords: %w", err) + } + dnsNameRecords[n] = existing + + } + for _, er := range existing { + if er.RR().Type == r.RR().Type { + entry.existing = append(entry.existing, er) + } + } + entry.wanted = append(entry.wanted, r) + entries[nameTypeKey] = entry + } + } + for _, entry := range entries { + // Delete records that are in existing but not in wanted + for _, er := range entry.existing { + found := false + for _, wr := range entry.wanted { + if libdnsEqual(er, wr) { + found = true + break + } + } + if !found { + if p.logging { + Log().Debugw("deleting record", "zone", zone, "name", entry.name, "type", entry.rtype, "record", er) + } + _, err := p.deleteRecords(ctx, zone, []libdns.Record{er}) + if err != nil { + return nil, fmt.Errorf("unexpected error deleting record for setRecords: %w", err) + } + } + } + + // Add records that are in wanted but not in existing + for _, wr := range entry.wanted { + found := false + for _, er := range entry.existing { + if libdnsEqual(er, wr) { + found = true + foundOrAdded = append(foundOrAdded, er) + break + } + } + if !found { + if p.logging { + Log().Debugw("adding record", "zone", zone, "name", entry.name, "type", entry.rtype, "record", wr) + } + added, err := p.addDNSEntries(ctx, zone, []libdns.Record{wr}) + if err != nil { + return nil, fmt.Errorf("unexpected error adding record for setRecords: %w", err) + } + foundOrAdded = append(foundOrAdded, added...) + } } } - return nil, errors.New("not implemented") + return foundOrAdded, nil } func (p *Provider) updateZoneRecord(ctx context.Context, zone string, record libdns.Record, id int64) (*loopiaRecord, error) { @@ -346,7 +458,7 @@ func (p *Provider) updateZoneRecord(ctx context.Context, zone string, record lib if id == 0 { return nil, fmt.Errorf("invalid ID") } - + defer p.clearCache() // Clear cache after adding a record to ensure consistency for subsequent reads zone = cleanZone(zone) updated := mustToLoopiaRecord(record, id) @@ -364,6 +476,7 @@ func (p *Provider) updateZoneRecord(ctx context.Context, zone string, record lib } func (p *Provider) deleteRecords(ctx context.Context, zone string, records []libdns.Record) ([]libdns.Record, error) { + defer p.clearCache() // Clear cache after adding a record to ensure consistency for subsequent reads if p.logging { Log().Debugw("deleteRecords", "zone", zone, "records", len(records), "trace", getTrace(ctx)) } @@ -453,6 +566,7 @@ func (p *Provider) removeDNSEntry(ctx context.Context, zone, name string, id int if id == 0 { return fmt.Errorf("invalid ID") } + defer p.clearCache() // Clear cache after adding a record to ensure consistency for subsequent reads ctx = addTrace(ctx, "removeDNSEntry") zone = cleanZone(zone) var response string diff --git a/internal/cache/cache.go b/internal/cache/cache.go new file mode 100644 index 0000000..ab728ac --- /dev/null +++ b/internal/cache/cache.go @@ -0,0 +1,97 @@ +package cache + +import ( + "encoding/json" + "sync" + "time" +) + +type Cache interface { + // Get retrieves a value from the cache. It returns the value and a boolean indicating whether the key was found. + Get(key string, value any) (ok bool) + Set(key string, value any, ttl time.Duration) error + // Clear removes all cache entries, regardless of their expiry status. This is useful for testing or when you want to clear the cache manually. + Clear() error +} + +type noopCache struct{} + +func (c *noopCache) Get(key string, value any) (ok bool) { + return false +} + +func (c *noopCache) Set(key string, value any, ttl time.Duration) error { + // no-op + return nil +} + +func (c *noopCache) Clear() error { + // no-op + return nil +} + +func NewNoopCache() Cache { + return &noopCache{} +} + +func NewInMemoryCache() Cache { + return &inMemoryCache{ + store: make(map[string]cacheEntry), + } +} + +type cacheEntry struct { + value json.RawMessage + expiry time.Time +} + +type inMemoryCache struct { + mu sync.RWMutex + store map[string]cacheEntry +} + +func (c *inMemoryCache) Get(key string, value any) (ok bool) { + c.mu.RLock() + entry, found := c.store[key] + c.mu.RUnlock() + if !found || time.Now().After(entry.expiry) { + if found { + c.mu.Lock() + entry, found = c.store[key] + if found && time.Now().After(entry.expiry) { + delete(c.store, key) + } + c.mu.Unlock() + } + return false + } + if value != nil { + // We need to copy the cached value to the provided pointer. + // This is a bit hacky, but it works for our use case. + if err := json.Unmarshal(entry.value, value); err != nil { + return false + } + } + return true +} + +func (c *inMemoryCache) Set(key string, value any, ttl time.Duration) error { + data, err := json.Marshal(value) + if err != nil { + return err + } + c.mu.Lock() + c.store[key] = cacheEntry{ + value: data, + expiry: time.Now().Add(ttl), + } + c.mu.Unlock() + return nil +} + +func (c *inMemoryCache) Clear() error { + c.mu.Lock() + c.store = make(map[string]cacheEntry) + c.mu.Unlock() + return nil +} diff --git a/internal/cache/cache_test.go b/internal/cache/cache_test.go new file mode 100644 index 0000000..4cb0951 --- /dev/null +++ b/internal/cache/cache_test.go @@ -0,0 +1,109 @@ +package cache + +import ( + "sync" + "testing" + "time" +) + +type cacheTestValue struct { + Name string `json:"name"` + ID int `json:"id"` +} + +func TestNoopCache(t *testing.T) { + c := NewNoopCache() + + if err := c.Set("k", cacheTestValue{Name: "x", ID: 1}, time.Second); err != nil { + t.Fatalf("Set() returned error: %v", err) + } + + var out cacheTestValue + if ok := c.Get("k", &out); ok { + t.Fatalf("Get() = true, want false") + } +} + +func TestInMemoryCache_SetGet(t *testing.T) { + c := NewInMemoryCache() + want := cacheTestValue{Name: "alice", ID: 42} + + if err := c.Set("user:1", want, time.Minute); err != nil { + t.Fatalf("Set() returned error: %v", err) + } + + var got cacheTestValue + if ok := c.Get("user:1", &got); !ok { + t.Fatalf("Get() = false, want true") + } + if got != want { + t.Fatalf("Get() value = %+v, want %+v", got, want) + } +} + +func TestInMemoryCache_Expires(t *testing.T) { + c := NewInMemoryCache() + + if err := c.Set("ephemeral", cacheTestValue{Name: "tmp", ID: 1}, 10*time.Millisecond); err != nil { + t.Fatalf("Set() returned error: %v", err) + } + + time.Sleep(25 * time.Millisecond) + + var got cacheTestValue + if ok := c.Get("ephemeral", &got); ok { + t.Fatalf("Get() = true for expired key, want false") + } +} + +func TestInMemoryCache_GetTypeMismatchReturnsFalse(t *testing.T) { + c := NewInMemoryCache() + + if err := c.Set("value", cacheTestValue{Name: "alice", ID: 10}, time.Minute); err != nil { + t.Fatalf("Set() returned error: %v", err) + } + + var wrongType int + if ok := c.Get("value", &wrongType); ok { + t.Fatalf("Get() = true with unmarshal type mismatch, want false") + } +} + +func TestInMemoryCache_ConcurrentAccess(t *testing.T) { + c := NewInMemoryCache() + + const workers = 24 + const iterations = 250 + + var wg sync.WaitGroup + for w := 0; w < workers; w++ { + w := w + wg.Add(1) + go func() { + defer wg.Done() + for i := 0; i < iterations; i++ { + key := "key:shared" + if i%3 == 0 { + key = "key:alt" + } + + val := cacheTestValue{Name: "worker", ID: w*iterations + i} + if err := c.Set(key, val, time.Second); err != nil { + t.Errorf("Set() returned error: %v", err) + return + } + + var out cacheTestValue + _ = c.Get(key, &out) + } + }() + } + + wg.Wait() + + // Ensure cache still responds after concurrent load. + var got cacheTestValue + if ok := c.Get("key:shared", &got); !ok { + t.Fatalf("Get() after concurrent access = false, want true") + } +} diff --git a/loopify.go b/loopify.go index 8b0715c..9b9723c 100644 --- a/loopify.go +++ b/loopify.go @@ -5,7 +5,7 @@ import ( "strings" ) -// loopia does not have support for propper subdomains so +// loopify modifies the components because loopia does not have support for propper subdomains so // we need so that zone only contains . func loopify(name, zone string) (string, string) { components := strings.Split(zone, ".") diff --git a/models.go b/models.go index dbfa9ff..55f37be 100644 --- a/models.go +++ b/models.go @@ -1,6 +1,7 @@ package loopia import ( + "fmt" "strings" "time" @@ -15,13 +16,30 @@ type loopiaRecord struct { Priority int `xmlrpc:"priority"` } +func (r *loopiaRecord) MarshalXMLRPC() (string, error) { + return fmt.Sprintf("record_id%dttl%dtype%srdata%spriority%d", r.ID, r.TTL, r.Type, r.RData, r.Priority), nil +} + func (r *loopiaRecord) libdnsRecord(subDomain string) (libdns.Record, error) { - return libdns.RR{ - Name: subDomain, - Type: r.Type, - Data: strings.Trim(r.RData, "\""), - TTL: time.Duration(r.TTL) * time.Second, - }.Parse() + switch r.Type { + case "MX": + return libdns.RR{ + Name: subDomain, + Type: r.Type, + Data: fmt.Sprintf("%d %s", r.Priority, strings.Trim(r.RData, "\"")), + TTL: time.Duration(r.TTL) * time.Second, + }.Parse() + default: + if r.Priority != 0 { + fmt.Printf("unsupported record type with priority: %s", r.Type) + } + return libdns.RR{ + Name: subDomain, + Type: r.Type, + Data: strings.Trim(r.RData, "\""), + TTL: time.Duration(r.TTL) * time.Second, + }.Parse() + } } func (r *loopiaRecord) mustLibdnsRecord(subDomain string) libdns.Record { @@ -67,3 +85,23 @@ func libdnsEqualLoopia(r1 libdns.Record, r2 loopiaRecord) bool { } return libdnsRecordEqual(r1, r2libdns) } + +// libdnsEqual compares two libdns records for equality, including TTL values. +// It returns true if the records are equal, and false otherwise. +func libdnsEqual(a, b libdns.Record) bool { + ra := a.RR() + rb := b.RR() + if ra.Name != rb.Name { + return false + } + if ra.Type != rb.Type { + return false + } + if ra.Data != rb.Data { + return false + } + if ra.TTL != rb.TTL { + return false + } + return true +} diff --git a/provider_test.go b/provider_test.go index 510f2a7..d2baceb 100644 --- a/provider_test.go +++ b/provider_test.go @@ -1,6 +1,6 @@ // Package libdns-loopia implements a DNS record management client compatible // with the libdns interfaces for Loopia. -//go:build integration +// ----asdfasdfgo:build integration package loopia @@ -12,6 +12,7 @@ import ( "time" "github.com/libdns/libdns" + "github.com/stretchr/testify/assert" ) func getRecords() []libdns.Record { @@ -125,12 +126,16 @@ func TestProvider_SetRecords(t *testing.T) { want []libdns.Record wantErr bool }{ - {"nil records", tc.getProvider(), args{context.TODO(), "test.local", nil}, nil, true}, - {"empty records", tc.getProvider(), args{context.TODO(), "test.local", []libdns.Record{}}, nil, true}, - {"invalid record", tc.getProvider(), args{context.TODO(), "test.local", []libdns.Record{libdns.Address{Name: "www"}}}, nil, true}, - {"invalid ID", tc.getProvider(), args{context.TODO(), "test.local", []libdns.Record{libdns.Address{Name: "www", IP: netip.MustParseAddr("127.0.0.1"), TTL: 5 * time.Minute}}}, nil, true}, - {"valid record", tc.getProvider(), args{context.TODO(), "test.local", []libdns.Record{libdns.Address{Name: "www", IP: netip.MustParseAddr("127.0.0.1"), TTL: 5 * time.Minute}}}, - []libdns.Record{libdns.Address{Name: "www", IP: netip.MustParseAddr("127.0.0.1"), TTL: 5 * time.Minute}}, false}, + // {"nil records", tc.getProvider(), args{context.TODO(), "test.local", nil}, []libdns.Record{}, false}, + // {"empty records", tc.getProvider(), args{context.TODO(), "test.local", []libdns.Record{}}, []libdns.Record{}, false}, + // {"invalid record", tc.getProvider(), args{context.TODO(), "test.local", []libdns.Record{libdns.Address{Name: "www"}}}, nil, true}, + // {"invalid ID", tc.getProvider(), args{context.TODO(), "test.local", []libdns.Record{libdns.Address{Name: "www", IP: netip.MustParseAddr("127.0.0.1"), TTL: 5 * time.Minute}}}, nil, true}, + {"valid record", tc.getProvider(), args{context.TODO(), "test.local", []libdns.Record{ + libdns.Address{Name: "www", IP: netip.MustParseAddr("127.0.0.1"), TTL: 5 * time.Minute}}, + }, + []libdns.Record{libdns.Address{Name: "www", IP: netip.MustParseAddr("127.0.0.1"), TTL: 5 * time.Minute}}, + false, + }, // TODO: Add test cases. } for _, tt := range tests { @@ -141,9 +146,11 @@ func TestProvider_SetRecords(t *testing.T) { t.Errorf("Provider.SetRecords() error = %v, wantErr %v", err, tt.wantErr) return } - if !reflect.DeepEqual(got, tt.want) { - t.Errorf("Provider.SetRecords() = %v, want %v", got, tt.want) - } + + assert.Equal(t, tt.want, got) + // if !reflect.DeepEqual(got, tt.want) { + // t.Errorf("Provider.SetRecords() = %v, want %v", got, tt.want) + // } }) } } diff --git a/server_test.go b/server_test.go index e371def..bf36aaa 100644 --- a/server_test.go +++ b/server_test.go @@ -3,10 +3,10 @@ package loopia import ( "fmt" "io" - "io/ioutil" "net/http" "net/http/httptest" "os" + "sync" "testing" "github.com/kolo/xmlrpc" @@ -14,28 +14,25 @@ import ( "github.com/subchen/go-xmldom" ) -var ( - handlers map[string]methodHandler -) - type methodHandler func(t *testing.T, w http.ResponseWriter, params []string) -func init() { - handlers = make(map[string]methodHandler) - handlers["getZoneRecords"] = getZoneRecordsHandler - handlers["getSubdomains"] = getSubdomainsHandler - handlers["addSubdomain"] = addSubdomainHandler - handlers["addZoneRecord"] = addZoneRecordHandler - handlers["updateZoneRecord"] = updateZoneRecordHandler - handlers["removeZoneRecord"] = returnOkHandler - handlers["removeSubdomain"] = returnOkHandler +type mochRecords map[string][]loopiaRecord + +func (m mochRecords) getSubdomains() []string { + subdomains := []string{} + for subdomain := range m { + subdomains = append(subdomains, subdomain) + } + return subdomains } type testContext struct { - mux *http.ServeMux - - rpc *xmlrpc.Client - server *httptest.Server + mux *http.ServeMux + rpc *xmlrpc.Client + server *httptest.Server + records mochRecords + id int + mu sync.Mutex } func (tc *testContext) getProvider() *Provider { @@ -45,11 +42,39 @@ func (tc *testContext) getProvider() *Provider { } func setupTest(t *testing.T) *testContext { - tc := &testContext{} + tc := &testContext{ + records: map[string][]loopiaRecord{ + "*": { + loopiaRecord{ID: 14096733, TTL: 300, Type: "A", RData: "*"}, + }, + "@": { + loopiaRecord{ID: 14096730, TTL: 300, Type: "NS", RData: "ns1.test.local."}, + loopiaRecord{ID: 14096731, TTL: 600, Type: "NS", RData: "ns2.test.local."}, + }, + "www": { + loopiaRecord{ID: 14096733, TTL: 300, Type: "A", RData: "www"}, + }, + "cdn": { + loopiaRecord{ID: 14096734, TTL: 300, Type: "A", RData: "cdn"}, + }, + "_challenge.test": { + loopiaRecord{ID: 14096735, TTL: 0, Type: "TXT", RData: "foo"}, + }, + }, + } + handlers := make(map[string]methodHandler) + handlers["getZoneRecords"] = tc.getZoneRecordsHandler + handlers["getSubdomains"] = tc.getSubdomainsHandler + handlers["addSubdomain"] = tc.addSubdomainHandler + handlers["addZoneRecord"] = tc.addZoneRecordHandler + handlers["updateZoneRecord"] = tc.updateZoneRecordHandler + handlers["removeZoneRecord"] = tc.returnOkHandler + handlers["removeSubdomain"] = tc.returnOkHandler + tc.mux = http.NewServeMux() tc.server = httptest.NewServer(tc.mux) tc.rpc, _ = xmlrpc.NewClient(tc.server.URL, nil) - tc.mux.HandleFunc("/", apiHandler(t)) + tc.mux.HandleFunc("/", apiHandler(t, handlers)) return tc } @@ -59,7 +84,7 @@ func teardownTest(tc *testContext) { } } -func apiHandler(t *testing.T) func(w http.ResponseWriter, r *http.Request) { +func apiHandler(t *testing.T, handlers map[string]methodHandler) func(w http.ResponseWriter, r *http.Request) { return func(w http.ResponseWriter, r *http.Request) { assert.Equal(t, r.Method, "POST") @@ -89,12 +114,14 @@ func apiHandler(t *testing.T) func(w http.ResponseWriter, r *http.Request) { } } -func getSubdomainsHandler(t *testing.T, w http.ResponseWriter, params []string) { +func (tc *testContext) getSubdomainsHandler(t *testing.T, w http.ResponseWriter, params []string) { + // get subdomains from records map keys + byteArray, _ := os.ReadFile("testdata/subdomains.xml") fmt.Fprint(w, string(byteArray[:])) } -func getZoneRecordsHandler(t *testing.T, w http.ResponseWriter, params []string) { +func (tc *testContext) getZoneRecordsHandler(t *testing.T, w http.ResponseWriter, params []string) { recname := params[len(params)-1] //last parameter if recname == "*" { @@ -110,7 +137,7 @@ func getZoneRecordsHandler(t *testing.T, w http.ResponseWriter, params []string) fmt.Fprint(w, string(byteArray[:])) } -func addSubdomainHandler(t *testing.T, w http.ResponseWriter, params []string) { +func (tc *testContext) addSubdomainHandler(t *testing.T, w http.ResponseWriter, params []string) { fmt.Printf(" > addSubdomainHandler(%s, %s)\n", params[3], params[4]) assert.Len(t, params, 5) lastp := params[len(params)-1] @@ -119,17 +146,17 @@ func addSubdomainHandler(t *testing.T, w http.ResponseWriter, params []string) { fmt.Fprint(w, string(byteArray[:])) } -func addZoneRecordHandler(t *testing.T, w http.ResponseWriter, params []string) { +func (tc *testContext) addZoneRecordHandler(t *testing.T, w http.ResponseWriter, params []string) { byteArray, _ := os.ReadFile("testdata/ok.xml") fmt.Fprint(w, string(byteArray[:])) } -func updateZoneRecordHandler(t *testing.T, w http.ResponseWriter, params []string) { - byteArray, _ := ioutil.ReadFile("testdata/ok.xml") +func (tc *testContext) updateZoneRecordHandler(t *testing.T, w http.ResponseWriter, params []string) { + byteArray, _ := os.ReadFile("testdata/ok.xml") fmt.Fprint(w, string(byteArray[:])) } -func returnOkHandler(t *testing.T, w http.ResponseWriter, params []string) { +func (tc *testContext) returnOkHandler(t *testing.T, w http.ResponseWriter, params []string) { byteArray, _ := os.ReadFile("testdata/ok.xml") fmt.Fprint(w, string(byteArray[:])) }