Skip to content
Draft
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
3 changes: 2 additions & 1 deletion .gitignore
Original file line number Diff line number Diff line change
@@ -1,2 +1,3 @@
.vscode
.env
*.env
.private
152 changes: 133 additions & 19 deletions client.go
Original file line number Diff line number Diff line change
Expand Up @@ -2,14 +2,14 @@ package loopia

import (
"context"
"errors"
"fmt"
"strings"
"sync"
"time"

"github.com/kolo/xmlrpc"
"github.com/libdns/libdns"
"github.com/libdns/loopia/internal/cache"
)

const (
Expand All @@ -19,12 +19,17 @@ const (
type client struct {
rpc *xmlrpc.Client
mutex sync.Mutex
cache cache.Cache // Cache for storing zone records
}

type libdnsKey string

var libdnsKeyTrace libdnsKey = "libdns.loopia.trace"

func params(args ...any) []any {
return args
}

func writeTrace(ctx context.Context, trace string) context.Context {
if ctx == nil {
ctx = context.Background()
Expand Down Expand Up @@ -95,32 +100,55 @@ func (p *Provider) getRPC() *xmlrpc.Client {
return p.rpc
}

func (p *Provider) call(serviceMethod string, args []interface{}, reply interface{}) error {
func (p *Provider) call(serviceMethod string, args []interface{}, reply interface{}) (err error) {
params := []interface{}{
p.Username,
p.Password,
}
if p.logging {
t := time.Now()
defer func() {
Log().Debugw("called rpc", "method", serviceMethod, "params", args, "error", err, "duration", time.Since(t).String())
}()
}
if p.Customer != "" {
params = append(params, p.Customer)
}
params = append(params, args...)
err := p.getRPC().Call(
hashKey := fmt.Sprintf("%s:%v", serviceMethod, args)

if p.cache != nil && serviceMethod == "getSubdomains" {
if p.cache.Get(hashKey, reply) {
if p.logging {
Log().Debugw("cache hit", "method", serviceMethod, "params", args)
}
return nil
}
}

err = p.getRPC().Call(
serviceMethod,
params,
reply,
)
if p.logging {
Log().Debugw("called rpc", "method", serviceMethod, "params", args, "error", err)
if err == nil && serviceMethod == "getSubdomains" {
if p.cache == nil {
p.cache = cache.NewInMemoryCache()
}
if cacheErr := p.cache.Set(hashKey, reply, 5*time.Second); cacheErr != nil && p.logging {
Log().Warnw("cache set failed", "method", serviceMethod, "params", args, "error", cacheErr)
}
}
return err
}

// getLoopiaRecords retrieves the records for a given zone and name from Loopia. It populates the provided records slice with the results.
func (p *Provider) getLoopiaRecords(ctx context.Context, zone, name string, records *[]loopiaRecord) error {
if !validZone(zone) {
return fmt.Errorf("invalid zone '%s'", zone)
}
if name == "" {
return fmt.Errorf("invalide name '%s'", name)
return fmt.Errorf("invalid name '%s'", name)
}
if p.logging {
Log().Debugw("getLoopiaRecords", "zone", zone, "name", name, "trace", getTrace(ctx))
Expand Down Expand Up @@ -165,14 +193,23 @@ func (p *Provider) getRecords(ctx context.Context, zone, name string) ([]libdns.
for _, r := range records {
rr, err := r.libdnsRecord(name)
if err != nil {
return nil, fmt.Errorf("unexpected error converting record: %w", err)
return nil, fmt.Errorf("unexpected error in client converting record for getRecords: %w", err)
}
result = append(result, rr)
}
return result, nil
}

func (p *Provider) clearCache() {
if p.cache != nil {
if err := p.cache.Clear(); err != nil && p.logging {
Log().Debugw("cache clear failed", "method", "clearCache", "error", err)
}
}
}

func (p *Provider) addRecord(ctx context.Context, zone string, record libdns.Record, withSubdomain bool) (out libdns.Record, id int64, err error) {
defer p.clearCache() // Clear cache after adding a record to ensure consistency for subsequent reads
if p.logging {
Log().Debugw("addRecord",
"zone", zone,
Expand All @@ -184,7 +221,7 @@ func (p *Provider) addRecord(ctx context.Context, zone string, record libdns.Rec
name := record.RR().Name
loopiaToAdd, err := toLoopiaRecord(record, 0)
if err != nil {
return nil, 0, fmt.Errorf("unexpected error converting record: %w", err)
return nil, 0, fmt.Errorf("unexpected error converting record for addRecord: %w", err)
}
if withSubdomain {
var response string
Expand All @@ -209,7 +246,7 @@ func (p *Provider) addRecord(ctx context.Context, zone string, record libdns.Rec
for _, r := range records {
out, err = r.libdnsRecord(name)
if err != nil {
return nil, 0, fmt.Errorf("unexpected error converting record: %w", err)
return nil, 0, fmt.Errorf("unexpected error converting record for addRecord: %w", err)
}
if libdnsRecordEqual(record, out) {
return out, r.ID, nil
Expand All @@ -219,10 +256,6 @@ func (p *Provider) addRecord(ctx context.Context, zone string, record libdns.Rec
return nil, 0, fmt.Errorf("unable to retreive new record to get it's ID")
}

func params(args ...interface{}) []interface{} {
return args
}

func (p *Provider) getZoneRecords(ctx context.Context, zone string) ([]libdns.Record, error) {
if p.logging {
Log().Debugw("getZoneRecords", "zone", zone)
Expand Down Expand Up @@ -254,6 +287,7 @@ myloop:
}

func (p *Provider) addDNSEntries(ctx context.Context, zone string, records []libdns.Record) ([]libdns.Record, error) {
defer p.clearCache() // Clear cache after adding a record to ensure consistency for subsequent reads
if p.logging {
Log().Debugw("addDNSEntries",
"zone", zone,
Expand Down Expand Up @@ -327,16 +361,94 @@ OUTER:
// records in the output zone with that (name, type) pair are those that were
// provided in the input.
func (p *Provider) setRecords(ctx context.Context, zone string, records []libdns.Record) ([]libdns.Record, error) {
defer p.clearCache() // Clear cache after adding a record to ensure consistency for subsequent reads
ctx = addTrace(ctx, "setRecords")
var err error
type nameTypeEntry struct {
name string
rtype string
wanted []libdns.Record
existing []libdns.Record
}

dnsNameRecords := map[string][]libdns.Record{}
entries := map[string]nameTypeEntry{}
foundOrAdded := []libdns.Record{}
for _, r := range records {
n, z := loopify(r.RR().Name, zone)
existing := []loopiaRecord{}
err := p.getLoopiaRecords(ctx, z, n, &existing)
if err != nil {
return nil, fmt.Errorf("unexpected error getting zone records: %w", err)
nameTypeKey := fmt.Sprintf("%s:%s", n, r.RR().Type)
entry, hasName := entries[nameTypeKey]
if !hasName {
entry = nameTypeEntry{
name: n,
rtype: r.RR().Type,
wanted: []libdns.Record{},
existing: []libdns.Record{},
}

existing, ok := dnsNameRecords[n]
// Fetch existing records for the name if not already fetched
if !ok {
existing, err = p.getRecords(ctx, z, n)
if err != nil {
return nil, fmt.Errorf("unexpected error getting existing records for setRecords: %w", err)
}
dnsNameRecords[n] = existing

}
for _, er := range existing {
if er.RR().Type == r.RR().Type {
entry.existing = append(entry.existing, er)
}
}
entry.wanted = append(entry.wanted, r)
entries[nameTypeKey] = entry
}
}
for _, entry := range entries {
// Delete records that are in existing but not in wanted
for _, er := range entry.existing {
found := false
for _, wr := range entry.wanted {
if libdnsEqual(er, wr) {
found = true
break
}
}
if !found {
if p.logging {
Log().Debugw("deleting record", "zone", zone, "name", entry.name, "type", entry.rtype, "record", er)
}
_, err := p.deleteRecords(ctx, zone, []libdns.Record{er})
if err != nil {
return nil, fmt.Errorf("unexpected error deleting record for setRecords: %w", err)
}
}
}

// Add records that are in wanted but not in existing
for _, wr := range entry.wanted {
found := false
for _, er := range entry.existing {
if libdnsEqual(er, wr) {
found = true
foundOrAdded = append(foundOrAdded, er)
break
}
}
if !found {
if p.logging {
Log().Debugw("adding record", "zone", zone, "name", entry.name, "type", entry.rtype, "record", wr)
}
added, err := p.addDNSEntries(ctx, zone, []libdns.Record{wr})
if err != nil {
return nil, fmt.Errorf("unexpected error adding record for setRecords: %w", err)
}
foundOrAdded = append(foundOrAdded, added...)
}
}
}
return nil, errors.New("not implemented")
return foundOrAdded, nil
}

func (p *Provider) updateZoneRecord(ctx context.Context, zone string, record libdns.Record, id int64) (*loopiaRecord, error) {
Expand All @@ -346,7 +458,7 @@ func (p *Provider) updateZoneRecord(ctx context.Context, zone string, record lib
if id == 0 {
return nil, fmt.Errorf("invalid ID")
}

defer p.clearCache() // Clear cache after adding a record to ensure consistency for subsequent reads
zone = cleanZone(zone)
updated := mustToLoopiaRecord(record, id)

Expand All @@ -364,6 +476,7 @@ func (p *Provider) updateZoneRecord(ctx context.Context, zone string, record lib
}

func (p *Provider) deleteRecords(ctx context.Context, zone string, records []libdns.Record) ([]libdns.Record, error) {
defer p.clearCache() // Clear cache after adding a record to ensure consistency for subsequent reads
if p.logging {
Log().Debugw("deleteRecords", "zone", zone, "records", len(records), "trace", getTrace(ctx))
}
Expand Down Expand Up @@ -453,6 +566,7 @@ func (p *Provider) removeDNSEntry(ctx context.Context, zone, name string, id int
if id == 0 {
return fmt.Errorf("invalid ID")
}
defer p.clearCache() // Clear cache after adding a record to ensure consistency for subsequent reads
ctx = addTrace(ctx, "removeDNSEntry")
zone = cleanZone(zone)
var response string
Expand Down
97 changes: 97 additions & 0 deletions internal/cache/cache.go
Original file line number Diff line number Diff line change
@@ -0,0 +1,97 @@
package cache

import (
"encoding/json"
"sync"
"time"
)

type Cache interface {
// Get retrieves a value from the cache. It returns the value and a boolean indicating whether the key was found.
Get(key string, value any) (ok bool)
Set(key string, value any, ttl time.Duration) error
// Clear removes all cache entries, regardless of their expiry status. This is useful for testing or when you want to clear the cache manually.
Clear() error
}

type noopCache struct{}

func (c *noopCache) Get(key string, value any) (ok bool) {
return false
}

func (c *noopCache) Set(key string, value any, ttl time.Duration) error {
// no-op
return nil
}

func (c *noopCache) Clear() error {
// no-op
return nil
}

func NewNoopCache() Cache {
return &noopCache{}
}

func NewInMemoryCache() Cache {
return &inMemoryCache{
store: make(map[string]cacheEntry),
}
}

type cacheEntry struct {
value json.RawMessage
expiry time.Time
}

type inMemoryCache struct {
mu sync.RWMutex
store map[string]cacheEntry
}

func (c *inMemoryCache) Get(key string, value any) (ok bool) {
c.mu.RLock()
entry, found := c.store[key]
c.mu.RUnlock()
if !found || time.Now().After(entry.expiry) {
if found {
c.mu.Lock()
entry, found = c.store[key]
if found && time.Now().After(entry.expiry) {
delete(c.store, key)
}
c.mu.Unlock()
}
return false
}
if value != nil {
// We need to copy the cached value to the provided pointer.
// This is a bit hacky, but it works for our use case.
if err := json.Unmarshal(entry.value, value); err != nil {
return false
}
}
return true
}

func (c *inMemoryCache) Set(key string, value any, ttl time.Duration) error {
data, err := json.Marshal(value)
if err != nil {
return err
}
c.mu.Lock()
c.store[key] = cacheEntry{
value: data,
expiry: time.Now().Add(ttl),
}
c.mu.Unlock()
return nil
}

func (c *inMemoryCache) Clear() error {
c.mu.Lock()
c.store = make(map[string]cacheEntry)
c.mu.Unlock()
return nil
}
Loading
Loading