diff --git a/LICENSE b/LICENSE index bc1194c..e69acd3 100644 --- a/LICENSE +++ b/LICENSE @@ -1,6 +1,6 @@ MIT License -Copyright (c) +Copyright (c) 2025 philip bergman Permission is hereby granted, free of charge, to any person obtaining a copy of this software and associated documentation files (the "Software"), to deal diff --git a/README.md b/README.md index 871645f..d6f7cd2 100644 --- a/README.md +++ b/README.md @@ -1,29 +1,114 @@ -DEVELOPER INSTRUCTIONS: -======================= +## Abstract Provider for `libdns` -This repo is a template for developers to use when creating new [libdns](https://github.com/libdns/libdns) provider implementations. +This package helps reduce duplicated and fragmented code across different DNS providers by implementing the core logic for the main `libdns` interfaces: -Be sure to update: +* `RecordGetter` +* `RecordAppender` +* `RecordSetter` +* `RecordDeleter` +* `ZoneLister` -- The package name -- The Go module name in go.mod -- The latest `libdns/libdns` version in go.mod -- All comments and documentation, including README below and godocs -- License (must be compatible with Apache/MIT) -- All "TODO:"s is in the code -- All methods that currently do nothing +As defined in the [libdns contracts](https://github.com/libdns/libdns/blob/master/libdns.go). -**Please be sure to conform to the semantics described at the [libdns godoc](https://github.com/libdns/libdns).** +It works on the principle that this *provider helper* fetches all records for a zone, generates a change list, and passes that list to the `client` to apply. -_Remove this section from the readme before publishing._ +By doing so, the only thing you need to implement is a [`client`](client.go). +This approach allows faster development of new providers and ensures more consistent behavior, since all contract logic is handled and maintained in one central place. + + +--- + +## Client + +A client implementation should follow this interface signature: + +```go +GetDNSList(ctx context.Context, domain string) ([]libdns.Record, error) +SetDNSList(ctx context.Context, domain string, change ChangeList) ([]libdns.Record, error) +``` + +A simple implementation could look like this: + +```go +func (c *client) create(ctx context.Context, domain string, record *libdns.RR) error { + // ... + return nil +} + +func (c *client) remove(ctx context.Context, domain string, record *libdns.RR) error { + // ... + return nil +} + +func (c *client) SetDNSList(ctx context.Context, domain string, change ChangeList) ([]libdns.Record, error) { + + for record := range change.Iterate(provider.Delete) { + if err := c.remove(ctx, domain, record); err != nil { + return nil, err + } + } + + for record := range change.Iterate(provider.Create) { + if err := c.create(ctx, domain, record); err != nil { + return nil, err + } + } + + return nil, nil +} +``` --- -\ for [`libdns`](https://github.com/libdns/libdns) -======================= +## Provider + +Because Go doesn’t support class-level abstraction, this package provides helper functions that your provider can call directly: + +```go +type Provider struct { + client Client + mutex sync.RWMutex +} + +func (p *Provider) getClient() Client { + // initialize client... + return p.client +} + +func (p *Provider) GetRecords(ctx context.Context, zone string) ([]libdns.Record, error) { + return GetRecords(ctx, &p.mutex, p.getClient(), zone) +} + +func (p *Provider) AppendRecords(ctx context.Context, zone string, recs []libdns.Record) ([]libdns.Record, error) { + return AppendRecords(ctx, &p.mutex, p.getClient(), zone, recs) +} -[![Go Reference](https://pkg.go.dev/badge/test.svg)](https://pkg.go.dev/github.com/libdns/TODO:PROVIDER_NAME) +func (p *Provider) SetRecords(ctx context.Context, zone string, recs []libdns.Record) ([]libdns.Record, error) { + return SetRecords(ctx, &p.mutex, p.getClient(), zone, recs) +} -This package implements the [libdns interfaces](https://github.com/libdns/libdns) for \, allowing you to manage DNS records. +func (p *Provider) DeleteRecords(ctx context.Context, zone string, recs []libdns.Record) ([]libdns.Record, error) { + return DeleteRecords(ctx, &p.mutex, p.getClient(), zone, recs) +} -TODO: Show how to configure and use. Explain any caveats. +var ( + _ libdns.RecordGetter = (*Provider)(nil) + _ libdns.RecordAppender = (*Provider)(nil) + _ libdns.RecordSetter = (*Provider)(nil) + _ libdns.RecordDeleter = (*Provider)(nil) +) +``` + +--- + +### Implemented Interfaces + +| Interface | Implementation Function | +| ----------------------- | --------------------------------- | +| `libdns.RecordGetter` | [GetRecords](record_get.go) | +| `libdns.RecordAppender` | [AppendRecords](record_get.go) | +| `libdns.RecordSetter` | [SetRecords](record_set.go) | +| `libdns.RecordDeleter` | [DeleteRecords](record_delete.go) | +| `libdns.ZoneLister` | [ListZones](zone_list.go) | + +--- diff --git a/change_list.go b/change_list.go new file mode 100644 index 0000000..73e2f86 --- /dev/null +++ b/change_list.go @@ -0,0 +1,113 @@ +package provider + +import ( + "iter" + + "github.com/libdns/libdns" +) + +type ChangeState uint8 + +const ( + NoChange ChangeState = 1 << iota + Delete + Create +) + +type ChangeRecord struct { + record *libdns.RR + state ChangeState +} + +type ChangeList interface { + // Iterate wil return an iterator that returns records that + // match the given state. For example, when called like + // `Iterate(Delete)` will only return records marked for + // removal. The ChangeState can be combined to iterate + // multiple states like `Iterate(Delete|Create)` which + // will return all records that are marked delete or + // as created. + Iterate(state ChangeState) iter.Seq[*libdns.RR] + // Creates will return a slice of records that are + // marked for creating + Creates() []*libdns.RR + // Deletes will return a slice of records that are + // marked for deleting + Deletes() []*libdns.RR + // GetList will return a slice of records that + // represents the new dns list which can be used + // to update the whole set for a zone + GetList() []*libdns.RR + // Has wil check if this list has records for + // given state + Has(state ChangeState) bool + // addRecord is not exported because the record + // list is immutable + addRecord(record *libdns.RR, state ChangeState) +} + +type changes struct { + records []*ChangeRecord + state ChangeState +} + +func NewChangeList(size ...int) ChangeList { + + var records []*ChangeRecord + + switch len(size) { + case 1: + records = make([]*ChangeRecord, size[0]) + case 2: + records = make([]*ChangeRecord, size[0], size[1]) + default: + records = make([]*ChangeRecord, 0) + } + + return &changes{ + records: records, + } +} + +func (c *changes) addRecord(record *libdns.RR, state ChangeState) { + c.records = append(c.records, &ChangeRecord{record: record, state: state}) + c.state |= state +} + +func (c *changes) Has(state ChangeState) bool { + return 0 != (c.state & state) +} + +func (c *changes) Iterate(state ChangeState) iter.Seq[*libdns.RR] { + return func(yield func(*libdns.RR) bool) { + for i, x := 0, len(c.records); i < x; i++ { + if c.records[i].state == (c.records[i].state & state) { + if false == yield(c.records[i].record) { + return + } + } + } + } +} + +func (c *changes) Creates() []*libdns.RR { + return c.list(Create) +} + +func (c *changes) Deletes() []*libdns.RR { + return c.list(Delete) +} + +func (c *changes) GetList() []*libdns.RR { + return c.list(Create | NoChange) +} + +func (c *changes) list(state ChangeState) []*libdns.RR { + var items = make([]*libdns.RR, 0) + + for record := range c.Iterate(state) { + items = append(items, record) + } + + return items +} diff --git a/client.go b/client.go new file mode 100644 index 0000000..14e9953 --- /dev/null +++ b/client.go @@ -0,0 +1,57 @@ +package provider + +import ( + "context" + + "github.com/libdns/libdns" +) + +type Domain interface { + Name() string +} + +type Client interface { + // GetDNSList returns all DNS records available for the given zone. + // + // The returned records can be of the opaque RR type. If the provider supports + // parsing, the records will be automatically parsed before being returned. + GetDNSList(ctx context.Context, domain string) ([]libdns.Record, error) + + // SetDNSList processes a ChangeList and updates DNS records based on their state. + // + // This allows the client to focus only on handling the changes, while the provider + // logic for appending, setting, and deleting records is centralized. + // + // Example: iterating through individual changes + // + // // Remove records marked for deletion + // for remove := range change.Iterate(Delete) { + // // remove record + // } + // + // // Create records marked for creation + // for create := range change.Iterate(Create) { + // // create record + // } + // + // Example: updating the whole zone at once + // + // // Generate a filtered list of all changes + // updatedRecords := change.GetList() + // + // // Use this list to update the entire zone file in a single call + // client.UpdateZone(ctx, domain, updatedRecords) + // + // Notes: + // - If the client API supports full-zone updates and returns the new record set, + // this can be returned. The provider uses this to validate records and skip + // extra API calls. + // - For clients that do not support full-zone updates or handle records individually, + // returning nil is fine. + SetDNSList(ctx context.Context, domain string, change ChangeList) ([]libdns.Record, error) +} + +type ZoneAwareClient interface { + Client + Domains(ctx context.Context) ([]Domain, error) +} diff --git a/debug.go b/debug.go new file mode 100644 index 0000000..293e86f --- /dev/null +++ b/debug.go @@ -0,0 +1,150 @@ +package provider + +import ( + "bufio" + "bytes" + "fmt" + "io" + "net/http" + "net/http/httputil" + "os" + "time" +) + +type OutputLevel uint8 + +// DebugTransport is an HTTP transport wrapper that logs outgoing requests +// and incoming responses for debugging purposes. +// +// It implements the http.RoundTripper interface and can be used to wrap +// an existing transport (such as http.DefaultTransport) to add debug output. +// +// Example: +// +// client := &http.Client{ +// Transport: &DebugTransport{ +// RoundTripper: http.DefaultTransport, +// config: ... +// }, +// } +type DebugTransport struct { + http.RoundTripper + Config DebugConfig +} + +type DebugConfig interface { + DebugOutputLevel() OutputLevel + DebugOutput() io.Writer +} + +// DebugAware is an interface implemented by types that support +// configurable debug logging of client communication. +// +// Implementations typically allow controlling the debug output level +// and destination writer used for HTTP or API requests. +// +// Example: +// +// type Provider struct { +// DebugLevel OutputLevel +// DebugOutput io.Writer +// client *http.Client +// } +// +// func (p *Provider) DebugOutputLevel() OutputLevel { +// return p.DebugLevel +// } +// +// func (p *Provider) DebugOutput() io.Writer { +// return p.DebugOutput +// } +// +// func (p *Provider) SetDebug(level OutputLevel, writer io.Writer) { +// p.DebugLevel = level +// p.DebugOutput = writer +// } +// +// func (p *Provider) getClient() *http.Client { +// if p.client == nil { +// p.client = &http.Client{ +// Transport: &DebugTransport{ +// RoundTripper: http.DefaultTransport, +// config: p, +// }, +// } +// } +// return p.client +// } +type DebugAware interface { + SetDebug(level OutputLevel, writer io.Writer) +} + +const ( + OutputNone OutputLevel = 0x00 + OutputVerbose = 0x01 + OutputVeryVerbose = 0x02 + OutputDebug = 0x03 +) + +func (t *DebugTransport) RoundTrip(req *http.Request) (*http.Response, error) { + var now time.Time + var out io.Writer + + if t.Config.DebugOutputLevel() >= OutputVerbose { + out = t.Config.DebugOutput() + + if nil == out { + out = os.Stdout + } + } + if t.Config.DebugOutputLevel() == OutputVerbose { + now = time.Now() + } + + if nil != out && t.Config.DebugOutputLevel() >= OutputVeryVerbose { + dumpWire(req, httputil.DumpRequest, "c", out, t.Config.DebugOutputLevel() == OutputDebug) + } + + response, err := t.RoundTripper.RoundTrip(req) + + if out != nil && nil != response { + + if t.Config.DebugOutputLevel() >= OutputVeryVerbose { + dumpWire(response, httputil.DumpResponse, "s", out, t.Config.DebugOutputLevel() == OutputDebug) + } else { + dumpLine(response, out, now) + } + } + + return response, err +} + +func dumpLine(response *http.Response, write io.Writer, start time.Time) { + var req = response.Request + var uri string + + if uri = req.RequestURI; "" == uri { + uri = req.URL.RequestURI() + } + + _, _ = fmt.Fprintf( + write, + "[%d] %s \"%s HTTP/%d.%d\" %d (%s)\r\n", + start.UnixMilli(), + req.Method, + uri, + req.ProtoMajor, + req.ProtoMinor, + response.StatusCode, + time.Now().Sub(start).Round(time.Millisecond), + ) +} + +func dumpWire[T *http.Request | *http.Response](x T, d func(T, bool) ([]byte, error), p string, o io.Writer, z bool) { + if out, err := d(x, z); err == nil { + scanner := bufio.NewScanner(bytes.NewReader(out)) + for scanner.Scan() { + _, _ = fmt.Fprintf(o, "[%s] %s\n", p, scanner.Text()) + } + } +} diff --git a/go.mod b/go.mod index fb0abbb..d415cd1 100644 --- a/go.mod +++ b/go.mod @@ -1,5 +1,5 @@ -module github.com/libdns/template +module github.com/pbergman/provider -go 1.18 +go 1.24.4 -require github.com/libdns/libdns v1.0.0 +require github.com/libdns/libdns v1.1.1 diff --git a/go.sum b/go.sum index feb101c..573f219 100644 --- a/go.sum +++ b/go.sum @@ -1,2 +1,2 @@ -github.com/libdns/libdns v1.0.0 h1:IvYaz07JNz6jUQ4h/fv2R4sVnRnm77J/aOuC9B+TQTA= -github.com/libdns/libdns v1.0.0/go.mod h1:4Bj9+5CQiNMVGf87wjX4CY3HQJypUHRuLvlsfsZqLWQ= +github.com/libdns/libdns v1.1.1 h1:wPrHrXILoSHKWJKGd0EiAVmiJbFShguILTg9leS/P/U= +github.com/libdns/libdns v1.1.1/go.mod h1:4Bj9+5CQiNMVGf87wjX4CY3HQJypUHRuLvlsfsZqLWQ= diff --git a/iterators.go b/iterators.go new file mode 100644 index 0000000..d19e0fb --- /dev/null +++ b/iterators.go @@ -0,0 +1,54 @@ +package provider + +import ( + "iter" + "strings" + + "github.com/libdns/libdns" +) + +func RecordIterator(records *[]libdns.Record) iter.Seq2[*libdns.Record, libdns.RR] { + return func(yield func(*libdns.Record, libdns.RR) bool) { + for _, record := range *records { + if false == yield(&record, record.RR()) { + return + } + } + } +} + +func lookup(item *libdns.RR, records *[]libdns.Record, lookup func(a, b *libdns.RR) bool) *libdns.Record { + next, stop := iter.Pull2(RecordIterator(records)) + + defer stop() + + for { + origin, check, ok := next() + + if !ok { + return nil + } + + if lookup(item, &check) { + return origin + } + } +} + +func lookupByNameAndType(item *libdns.RR, records *[]libdns.Record) *libdns.Record { + return lookup(item, records, func(a, b *libdns.RR) bool { + return strings.EqualFold(a.Name, b.Name) && a.Type == b.Type + }) +} + +func IsInList(item *libdns.RR, records *[]libdns.Record, ttl bool) bool { + return nil != lookup(item, records, func(a, b *libdns.RR) bool { + return strings.EqualFold(a.Name, b.Name) && a.Type == b.Type && a.Data == b.Data && (false == ttl || a.TTL == b.TTL) + }) +} + +func isEligibleForRemoval(item *libdns.RR, records *[]libdns.Record) bool { + return nil != lookup(item, records, func(a, b *libdns.RR) bool { + return strings.EqualFold(a.Name, b.Name) && (b.Type == "" || a.Type == b.Type) && (b.Data == "" || a.Data == b.Data) && (b.TTL == 0 || a.TTL == b.TTL) + }) +} diff --git a/locker.go b/locker.go new file mode 100644 index 0000000..37f4905 --- /dev/null +++ b/locker.go @@ -0,0 +1,39 @@ +package provider + +import ( + "sync" +) + +func rlock(mutex sync.Locker) func() { + + if nil == mutex { + return nil + } + + type rlock interface { + RUnlock() + RLock() + } + + // fallback to normal mutex + var lock, unlock = mutex.Lock, mutex.Unlock + + if v, o := mutex.(rlock); o { + lock, unlock = v.RLock, v.RUnlock + } + + lock() + + return sync.OnceFunc(unlock) +} + +func lock(mutex sync.Locker) func() { + + if nil == mutex { + return nil + } + + mutex.Lock() + + return sync.OnceFunc(mutex.Unlock) +} diff --git a/provider.go b/provider.go deleted file mode 100644 index 17a4e89..0000000 --- a/provider.go +++ /dev/null @@ -1,58 +0,0 @@ -// Package libdnstemplate implements a DNS record management client compatible -// with the libdns interfaces for . TODO: This package is a -// template only. Customize all godocs for actual implementation. -package libdnstemplate - -import ( - "context" - "fmt" - - "github.com/libdns/libdns" -) - -// TODO: Providers must not require additional provisioning steps by the callers; it -// should work simply by populating a struct and calling methods on it. If your DNS -// service requires long-lived state or some extra provisioning step, do it implicitly -// when methods are called; sync.Once can help with this, and/or you can use a -// sync.(RW)Mutex in your Provider struct to synchronize implicit provisioning. - -// Provider facilitates DNS record manipulation with . -type Provider struct { - // TODO: Put config fields here (with snake_case json struct tags on exported fields), for example: - APIToken string `json:"api_token,omitempty"` - - // Exported config fields should be JSON-serializable or omitted (`json:"-"`) -} - -// GetRecords lists all the records in the zone. -func (p *Provider) GetRecords(ctx context.Context, zone string) ([]libdns.Record, error) { - // Make sure to return RR-type-specific structs, not libdns.RR structs. - return nil, fmt.Errorf("TODO: not implemented") -} - -// AppendRecords adds records to the zone. It returns the records that were added. -func (p *Provider) AppendRecords(ctx context.Context, zone string, records []libdns.Record) ([]libdns.Record, error) { - // Make sure to return RR-type-specific structs, not libdns.RR structs. - return nil, fmt.Errorf("TODO: not implemented") -} - -// SetRecords sets the records in the zone, either by updating existing records or creating new ones. -// It returns the updated records. -func (p *Provider) SetRecords(ctx context.Context, zone string, records []libdns.Record) ([]libdns.Record, error) { - // Make sure to return RR-type-specific structs, not libdns.RR structs. - return nil, fmt.Errorf("TODO: not implemented") -} - -// DeleteRecords deletes the specified records from the zone. It returns the records that were deleted. -func (p *Provider) DeleteRecords(ctx context.Context, zone string, records []libdns.Record) ([]libdns.Record, error) { - // Make sure to return RR-type-specific structs, not libdns.RR structs. - return nil, fmt.Errorf("TODO: not implemented") -} - -// Interface guards -var ( - _ libdns.RecordGetter = (*Provider)(nil) - _ libdns.RecordAppender = (*Provider)(nil) - _ libdns.RecordSetter = (*Provider)(nil) - _ libdns.RecordDeleter = (*Provider)(nil) -) diff --git a/record_append.go b/record_append.go new file mode 100644 index 0000000..b3ae89b --- /dev/null +++ b/record_append.go @@ -0,0 +1,60 @@ +package provider + +import ( + "context" + "sync" + + "github.com/libdns/libdns" +) + +// AppendRecords appends new records to the change list performing validation. +// +// The assumption is that when the full list is returned to the provider, the +// provider will handle any necessary validation and return an error if any +// issues are found. +func AppendRecords(ctx context.Context, mutex sync.Locker, client Client, zone string, records []libdns.Record) ([]libdns.Record, error) { + + if unlock := lock(mutex); unlock != nil { + defer unlock() + } + + existing, err := GetRecords(ctx, nil, client, zone) + + if err != nil { + return nil, err + } + + var change = NewChangeList(0, len(existing)+len(records)) + + for _, record := range RecordIterator(&existing) { + change.addRecord(&record, NoChange) + } + + for _, record := range RecordIterator(&records) { + change.addRecord(&record, Create) + } + + items, err := client.SetDNSList(ctx, zone, change) + + if err != nil { + return nil, err + } + + if nil == items { + items, err = GetRecords(ctx, nil, client, zone) + + if err != nil { + return nil, err + } + } + + var ret = make([]libdns.Record, 0) + + for origin, record := range RecordIterator(&items) { + if false == IsInList(&record, &existing, false) { + ret = append(ret, *origin) + } + } + + return ret, nil +} diff --git a/record_delete.go b/record_delete.go new file mode 100644 index 0000000..276dbbb --- /dev/null +++ b/record_delete.go @@ -0,0 +1,77 @@ +package provider + +import ( + "context" + "sync" + + "github.com/libdns/libdns" +) + +// DeleteRecords marks the input records for deletion when they exactly match +// or partially match existing records, following the rules defined in the +// libdns contract. +// +// For more details, see: +// https://github.com/libdns/libdns/blob/master/libdns.go#L228C1-L237C43 +func DeleteRecords(ctx context.Context, mutex sync.Locker, client Client, zone string, deletes []libdns.Record) ([]libdns.Record, error) { + + var unlock = lock(mutex) + + if nil != unlock { + defer unlock() + } + + records, err := GetRecords(ctx, nil, client, zone) + + if err != nil { + return nil, err + } + + var change = NewChangeList(len(records)) + + for _, record := range RecordIterator(&records) { + var state = NoChange + + if isEligibleForRemoval(&record, &deletes) { + state = Delete + } + + change.addRecord(&record, state) + } + + if false == change.Has(Delete) { + return []libdns.Record{}, nil + } + + curr, err := client.SetDNSList(ctx, zone, change) + + if err != nil { + return nil, err + } + + if nil != unlock { + unlock() + } + + if unlock := rlock(mutex); nil != unlock { + defer unlock() + } + + if nil == curr { + curr, err = GetRecords(ctx, nil, client, zone) + + if err != nil { + return nil, err + } + } + + var removed = make([]libdns.Record, 0) + + for origin, record := range RecordIterator(&records) { + if false == IsInList(&record, &curr, false) && isEligibleForRemoval(&record, &deletes) { + removed = append(removed, *origin) + } + } + + return removed, nil +} diff --git a/record_get.go b/record_get.go new file mode 100644 index 0000000..94de8fb --- /dev/null +++ b/record_get.go @@ -0,0 +1,41 @@ +package provider + +import ( + "context" + "sync" + + "github.com/libdns/libdns" +) + +// GetRecords retrieves all records for the given zone from the client and ensures +// that the returned records are properly typed according to their specific RR type. +func GetRecords(ctx context.Context, mutex sync.Locker, client Client, zone string) ([]libdns.Record, error) { + + if unlock := rlock(mutex); nil != unlock { + defer unlock() + } + + list, err := client.GetDNSList(ctx, zone) + + if err != nil { + return nil, err + } + + type recordParser interface { + Parse() (libdns.Record, error) + } + + for i, c := 0, len(list); i < c; i++ { + if v, ok := list[i].(recordParser); ok { + x, err := v.Parse() + + if err != nil { + return nil, err + } + + list[i] = x + } + } + + return list, nil +} diff --git a/record_set.go b/record_set.go new file mode 100644 index 0000000..db5f924 --- /dev/null +++ b/record_set.go @@ -0,0 +1,92 @@ +package provider + +import ( + "context" + "sync" + + "github.com/libdns/libdns" +) + +// SetRecords updates existing records by marking them as either NoChange or Delete +// based on the given input, and appends the input records with state Create. +// This ensures compliance with the libdns contract and produces the expected results. +// +// Example provided by the contract can be found here: +// https://github.com/libdns/libdns/blob/master/libdns.go#L182-L216 +func SetRecords(ctx context.Context, mutex sync.Locker, client Client, zone string, records []libdns.Record) ([]libdns.Record, error) { + + var unlock = lock(mutex) + var ret = make([]libdns.Record, 0) + + if nil != unlock { + defer unlock() + } + + existing, err := GetRecords(ctx, nil, client, zone) + + if err != nil { + return nil, err + } + + var change = NewChangeList(0, len(existing)+len(records)) + + for _, record := range RecordIterator(&existing) { + + var state = NoChange + + if found := lookupByNameAndType(&record, &records); found != nil { + + var rr = (*found).RR() + + // only mark as delete when differs + if rr.Data != record.Data || rr.TTL != record.TTL { + state = Delete + } + } + + change.addRecord(&record, state) + } + + for _, item := range RecordIterator(&records) { + + if IsInList(&item, &existing, true) { + continue + } + + change.addRecord(&item, Create) + } + + if false == change.Has(Delete|Create) { + return ret, nil + } + + curr, err := client.SetDNSList(ctx, zone, change) + + if err != nil { + return nil, err + } + + if nil != unlock { + unlock() + } + + if unlock := rlock(mutex); nil != unlock { + defer unlock() + } + + if nil == curr { + curr, err = GetRecords(ctx, nil, client, zone) + + if err != nil { + return nil, err + } + } + + for x, record := range RecordIterator(&curr) { + if false == IsInList(&record, &existing, true) && nil != lookupByNameAndType(&record, &records) { + ret = append(ret, *x) + } + } + + return ret, nil +} diff --git a/test/README.md b/test/README.md new file mode 100644 index 0000000..e5ac119 --- /dev/null +++ b/test/README.md @@ -0,0 +1,45 @@ +## Testing + +This test helper is included to help you verify that your provider correctly implements the [libdns contract](https://github.com/libdns/libdns/blob/master/libdns.go). It can partially test your implementation or do a full test depending on `TestMode` parameter. + +To use this, create as `provider_test.go` in the same directory as you provider, initialize your provider and run the test. + +## Example + +```go + +import ( + ... + "github.com/pbergman/provider/test" +) + +func TestProvider(t *testing.T) { + + var provider = &Provider{ + ApiKey: os.Getenv("API_KEY"), + ... + } + + test.RunProviderTests(t, provider, test.TestAll) +} + +``` + +After that, you should be able to run your tests like this: + +```bash + +API_KEY=.... go test -v +``` + +## TestMode + +to partially test parts of the interface, you could do something like + +```go + +test.RunProviderTests(t, provider, test.TestAll^(test.TestDeleter|test.TestAppender)) + +``` + +to skip the delete and append test and can be useful when implementing a new provider or debugging. \ No newline at end of file diff --git a/test/provider.go b/test/provider.go new file mode 100644 index 0000000..305f640 --- /dev/null +++ b/test/provider.go @@ -0,0 +1,539 @@ +package test + +import ( + "bufio" + "bytes" + "context" + "fmt" + "net/netip" + "os" + "strings" + "sync" + "testing" + "text/tabwriter" + "time" + + "github.com/libdns/libdns" + helper "github.com/pbergman/provider" +) + +type TestMode uint64 + +const ( + TestAppender TestMode = 1 << iota + TestDeleter + TestGetter + TestSetter + TestZones + TestAll = TestAppender | TestDeleter | TestGetter | TestSetter | TestZones +) + +type Provider interface { + libdns.RecordAppender + libdns.RecordDeleter + libdns.RecordGetter + libdns.RecordSetter +} + +func RunProviderTests(t *testing.T, provider Provider, mode TestMode) { + + var wg sync.WaitGroup + + if zoneListener, ok := provider.(libdns.ZoneLister); ok { + if TestZones == (TestZones & mode) { + wg.Add(1) + t.Run("ListZones", func(t *testing.T) { + defer wg.Done() + testListZones(t, zoneListener) + }) + } + } else { + t.Skipf("ListZones not implemented.") + } + + var zones = getZonesForTesting(t, provider) + + if TestGetter == (TestGetter & mode) { + t.Run("RecordGetter", func(t *testing.T) { + wg.Add(1) + defer wg.Done() + testRecordGetter(t, provider, zones) + }) + } + + wg.Wait() + + if TestAppender == (TestAppender & mode) { + t.Run("RecordAppender", func(t *testing.T) { + testRecordAppender(t, provider, zones) + }) + } + + if TestSetter == (TestSetter & mode) { + t.Run("RecordSetter - Example 1", func(t *testing.T) { + testRecordsSetExample1(t, provider, zones) + }) + + t.Run("RecordSetter - Example 2", func(t *testing.T) { + testRecordsSetExample2(t, provider, zones) + }) + } + + if TestDeleter == (TestDeleter & mode) { + t.Run("RecordDeleter", func(t *testing.T) { + testDeleteRecords(t, provider, zones) + }) + } +} + +func printRecords(t *testing.T, records []libdns.Record, invalid libdns.Record, prefix string) { + + var buf = new(bytes.Buffer) + var writer = tabwriter.NewWriter(buf, 0, 4, 2, ' ', tabwriter.Debug) + var isWritten = false + var write = func(prefix string, record libdns.RR) { + _, _ = fmt.Fprintf(writer, "%s%s\t %s\t %s\t %s\n", prefix, record.Name, record.TTL, record.Type, record.Data) + } + + for _, record := range records { + var rr = record.RR() + + if invalid != nil { + prefix = "✓ " + + if record.RR().Type == invalid.RR().Type && record.RR().Data == invalid.RR().Data && strings.EqualFold(record.RR().Name, invalid.RR().Name) { + prefix = "× " + isWritten = true + } + } + + write(prefix, rr) + } + + if false == isWritten && nil != invalid { + write("× ", invalid.RR()) + } + + _ = writer.Flush() + + scanner := bufio.NewScanner(buf) + + for scanner.Scan() { + t.Log(scanner.Text()) + } +} + +func getZonesForTesting(t *testing.T, p Provider) []string { + + if v, ok := os.LookupEnv("ZONE"); ok { + return strings.Split(v, ",") + } + + if o, ok := p.(libdns.ZoneLister); ok { + zones, err := o.ListZones(context.Background()) + + if err != nil { + t.Fatalf("ListZones failed: %v", err) + } + + var ret = make([]string, len(zones)) + + for idx, zone := range zones { + ret[idx] = zone.Name + } + + return ret + } + + t.Fatal("No valid zones found, either implement libdns.ZoneLister or use ZONE environment variable") + + return nil +} + +func testListZones(t *testing.T, provider libdns.ZoneLister) { + + zones, err := provider.ListZones(context.Background()) + + if err != nil { + t.Fatalf("ListZones failed: %v", err) + } + + t.Log("checking if the zone includes trailing dot") + + for _, zone := range zones { + if strings.HasSuffix(zone.Name, ".") { + t.Logf("✓ %s", zone.Name) + } else { + t.Fatalf("missing trailing dot: %s", zone.Name) + } + } + +} + +func testReturnTypes(t *testing.T, records []libdns.Record) { + var buf = new(bytes.Buffer) + var writer = tabwriter.NewWriter(buf, 0, 4, 2, ' ', tabwriter.Debug) + + for _, record := range records { + switch record.(type) { + case *libdns.RR, libdns.RR: + t.Fatalf("expecting specific RR-type instead of the opaque RR struct (%#+v)", record) + default: + _, _ = fmt.Fprintf(writer, "✓ %s\t%s\t%T\n", record.RR().Name, record.RR().Type, record) + } + } + + _ = writer.Flush() + + scanner := bufio.NewScanner(buf) + + for scanner.Scan() { + t.Log(scanner.Text()) + } +} + +func testRecordGetter(t *testing.T, provider Provider, zones []string) { + + t.Log("not much specials to test except for errors from client and return types") + + for _, zone := range zones { + + records, err := provider.GetRecords(context.Background(), zone) + + if err != nil { + t.Fatalf("GetRecords failed: %v", err) + } + + t.Logf("records in zone: \"%s\"", zone) + printRecords(t, records, nil, " ") + + t.Logf("testing return record types are not of type libdns.RR") + testReturnTypes(t, records) + } + +} + +func testRecordAppender(t *testing.T, provider Provider, zones []string) { + + var records = []libdns.Record{ + libdns.TXT{ + Name: "LibDNS_test_append_records", + Text: "Proin nec metus in mauris malesuada aliquet", + TTL: 24 * time.Hour, + }, + libdns.TXT{ + Name: "LibDNS_test_append_records", + Text: "Lorem ipsum dolor sit amet, consectetur adipiscing elit.", + TTL: 24 * time.Hour, + }, + libdns.TXT{ + Name: "LibDNS_test_append_records", + Text: "Praesent molestie mi a lorem aliquam maximus.", + TTL: 24 * time.Hour, + }, + libdns.TXT{ + Name: "LibDNS_test_append_records", + Text: "Nulla ultricies eros quis velit tincidunt, in molestie lorem molestie.", + TTL: 24 * time.Hour, + }, + } + + t.Log("the contract states it should create records and never change existing records and") + t.Log("return the records that were created (specific RR-type that correspond to the type)") + + for _, zone := range zones { + + out, err := provider.AppendRecords(context.Background(), zone, records) + + defer provider.DeleteRecords(context.Background(), zone, records) + + if err != nil { + t.Fatalf("AppendRecords failed: %v", err) + } + + t.Logf("successfully added %d records to zone %s", len(records), zone) + t.Log("testing return records in record lists") + + for _, record := range helper.RecordIterator(&records) { + if false == helper.IsInList(&record, &out, false) { + printRecords(t, records, record, " ") + t.Fatal("returned unexpected records") + } + } + + printRecords(t, records, nil, "✓ ") + + t.Logf("testing return record types are not of type libdns.RR") + testReturnTypes(t, out) + + t.Logf("testing for error while updating records") + + if _, err := provider.AppendRecords(context.Background(), zone, records); err == nil { + t.Fatalf("expecting failed but didn't") + } + } +} + +func testRecordsSetExample1(t *testing.T, provider Provider, zones []string) { + var name = "LibDNS.test.set.records" + var original = []libdns.Record{ + libdns.Address{Name: name, IP: netip.MustParseAddr("192.0.2.1"), TTL: 1 * time.Hour}, + libdns.Address{Name: name, IP: netip.MustParseAddr("192.0.2.2"), TTL: 1 * time.Hour}, + libdns.TXT{Name: name, Text: "hello world", TTL: 1 * time.Hour}, + } + + var input = []libdns.Record{ + libdns.Address{Name: name, IP: netip.MustParseAddr("192.0.2.3"), TTL: 1 * time.Hour}, + } + + t.Log("will test the following example:") + t.Log("") + t.Log("// Example 1:") + t.Log("//") + t.Log("// ;; Original zone") + t.Log("// example.com. 3600 IN A 192.0.2.1") + t.Log("// example.com. 3600 IN A 192.0.2.2") + t.Log("// example.com. 3600 IN TXT \"hello world\"") + t.Log("//") + t.Log("// ;; Input") + t.Log("// example.com. 3600 IN A 192.0.2.3") + t.Log("//") + t.Log("// ;; Resultant zone") + t.Log("// example.com. 3600 IN A 192.0.2.3") + t.Log("// example.com. 3600 IN TXT \"hello world\"") + + for _, zone := range zones { + out, err := provider.AppendRecords(context.Background(), zone, original) + + if err != nil { + t.Fatalf("AppendRecords failed: %v", err) + } + + t.Logf("records appended to zone \"%s\":", zone) + printRecords(t, out, nil, "✓ ") + + defer provider.DeleteRecords(context.Background(), zone, []libdns.Record{ + libdns.Address{Name: name}, + libdns.TXT{Name: name}, + }) + + t.Logf("set record \"%s %s %s %s\":", input[0].RR().Name, input[0].RR().TTL, input[0].RR().Type, input[0].RR().Data) + + ret, err := provider.SetRecords(context.Background(), zone, input) + + if err != nil { + t.Fatalf("SetRecords failed: %v", err) + } + + if len(ret) != 1 { + t.Fatalf("should have returned 1 record got %d", len(ret)) + } + + t.Logf("testing return record types are not of type libdns.RR") + testReturnTypes(t, ret) + + curr, err := provider.GetRecords(context.Background(), zone) + + if err != nil { + t.Fatalf("GetRecords failed: %v", err) + } + + t.Log("current records in zone:") + printRecords(t, curr, nil, " ") + + var shouldNotExist = original[:2] + + t.Log("testing if following records are removed") + printRecords(t, shouldNotExist, nil, " ") + + for invalid, record := range helper.RecordIterator(&shouldNotExist) { + if helper.IsInList(&record, &curr, false) { + t.Log("") + printRecords(t, curr, *invalid, " ") + t.Fatal("invalid records returned") + } + } + + var shouldExist = append(original[2:], input[0]) + t.Log("testing if following records are present") + printRecords(t, shouldExist, nil, " ") + + for invalid, record := range helper.RecordIterator(&shouldExist) { + if false == helper.IsInList(&record, &curr, false) { + t.Log("") + printRecords(t, curr, *invalid, " ") + t.Fatal("invalid records returned") + } + } + } +} + +func testRecordsSetExample2(t *testing.T, provider Provider, zones []string) { + var name = "LibDNS.test.set.records" + + var original = []libdns.Record{ + libdns.Address{Name: "alpha." + name, IP: netip.MustParseAddr("2001:db8::1")}, + libdns.Address{Name: "alpha." + name, IP: netip.MustParseAddr("2001:db8::2")}, + libdns.Address{Name: "beta." + name, IP: netip.MustParseAddr("2001:db8::3")}, + libdns.Address{Name: "beta." + name, IP: netip.MustParseAddr("2001:db8::4")}, + } + + var input = []libdns.Record{ + libdns.Address{Name: "alpha." + name, IP: netip.MustParseAddr("2001:db8::1")}, + libdns.Address{Name: "alpha." + name, IP: netip.MustParseAddr("2001:db8::2")}, + libdns.Address{Name: "alpha." + name, IP: netip.MustParseAddr("2001:db8::5")}, + } + + t.Log("will test the following example:") + t.Log("") + t.Log("// ;; Original zone") + t.Log("// alpha.example.com. 3600 IN AAAA 2001:db8::1") + t.Log("// alpha.example.com. 3600 IN AAAA 2001:db8::2") + t.Log("// beta.example.com. 3600 IN AAAA 2001:db8::3") + t.Log("// beta.example.com. 3600 IN AAAA 2001:db8::4") + t.Log("//") + t.Log("// ;; Input") + t.Log("// alpha.example.com. 3600 IN AAAA 2001:db8::1") + t.Log("// alpha.example.com. 3600 IN AAAA 2001:db8::2") + t.Log("// alpha.example.com. 3600 IN AAAA 2001:db8::5") + t.Log("//") + t.Log("// ;; Resultant zone") + t.Log("// alpha.example.com. 3600 IN AAAA 2001:db8::1") + t.Log("// alpha.example.com. 3600 IN AAAA 2001:db8::2") + t.Log("// alpha.example.com. 3600 IN AAAA 2001:db8::5") + t.Log("// beta.example.com. 3600 IN AAAA 2001:db8::3") + t.Log("// beta.example.com. 3600 IN AAAA 2001:db8::4") + + for _, zone := range zones { + out, err := provider.AppendRecords(context.Background(), zone, original) + + if err != nil { + t.Fatalf("AppendRecords failed: %v", err) + } + + t.Logf("records appended to zone \"%s\":", zone) + printRecords(t, out, nil, "✓ ") + + // make sure we delete all records even on failure + defer provider.DeleteRecords(context.Background(), zone, []libdns.Record{ + libdns.RR{Name: "alpha." + name, Type: "AAAA"}, + libdns.RR{Name: "beta." + name, Type: "AAAA"}, + }) + + t.Logf("set record \"%s %s %s %s\":", input[0].RR().Name, input[0].RR().TTL, input[0].RR().Type, input[0].RR().Data) + ret, err := provider.SetRecords(context.Background(), zone, input) + + if err != nil { + t.Fatalf("SetRecords failed: %v", err) + } + + if len(ret) != 1 { + t.Fatalf("should have returned 1 record got %d", len(ret)) + } + + t.Logf("testing return record types are not of type libdns.RR") + testReturnTypes(t, ret) + + curr, err := provider.GetRecords(context.Background(), zone) + + if err != nil { + t.Fatalf("GetRecords failed: %v", err) + } + + t.Log("current records in zone:") + printRecords(t, curr, nil, " ") + + var shouldExist = append(original, input[2]) + + t.Log("testing if following records are present") + printRecords(t, shouldExist, nil, " ") + + for invalid, record := range helper.RecordIterator(&shouldExist) { + if false == helper.IsInList(&record, &curr, false) { + t.Log("") + printRecords(t, curr, *invalid, " ") + t.Fatal("AppendRecords returned unexpected records") + } + } + } +} + +func testDeleteRecords(t *testing.T, provider Provider, zones []string) { + + var name = "LibDNS.test.rm.records" + + for _, zone := range zones { + + var records = make([]libdns.Record, 0) + + for i := 1; i <= 10; i++ { + records = append(records, + libdns.Address{Name: name, IP: netip.MustParseAddr(fmt.Sprintf("2001:db8::%d", i))}, + libdns.Address{Name: name, IP: netip.MustParseAddr(fmt.Sprintf("127.0.0.%d", i))}, + ) + } + + out, err := provider.SetRecords(context.Background(), zone, records) + + if err != nil { + t.Fatalf("SetRecords failed: %v", err) + } + + t.Logf("set test records for zone \"%s\":", zone) + printRecords(t, out, nil, "✓ ") + + var toRemove = records[:5] + + removed, err := provider.DeleteRecords(context.Background(), zone, toRemove) + + if err != nil { + t.Fatalf("DeleteRecords failed: %v", err) + } + + t.Logf("testing return record types are not of type libdns.RR") + testReturnTypes(t, removed) + + for _, x := range helper.RecordIterator(&removed) { + if false == helper.IsInList(&x, &toRemove, false) { + t.Log("") + printRecords(t, toRemove, x, " ") + t.Fatal("returned unexpected records") + } + } + + t.Log("deleted records:") + printRecords(t, removed, nil, "✓ ") + + t.Log("checking removed records against records in zone") + curr, err := provider.GetRecords(context.Background(), zone) + + if err != nil { + t.Fatalf("GetRecords failed: %v", err) + } + + for _, x := range helper.RecordIterator(&toRemove) { + if helper.IsInList(&x, &curr, false) { + t.Log("") + printRecords(t, curr, x, " ") + t.Fatal("returned unexpected records") + } + } + + t.Log("try to delete records based name only") + + removed, err = provider.DeleteRecords(context.Background(), zone, []libdns.Record{ + libdns.RR{Name: name}, + }) + + if err != nil { + t.Fatalf("DeleteRecords failed: %v", err) + } + + t.Log("deleted records:") + printRecords(t, removed, nil, "✓ ") + + if len(removed) != 15 { + t.Fatalf("returned invalid count of records: expecting 15 got %d", len(removed)) + } + } +} diff --git a/zone_list.go b/zone_list.go new file mode 100644 index 0000000..6e66221 --- /dev/null +++ b/zone_list.go @@ -0,0 +1,43 @@ +package provider + +import ( + "context" + "sync" + + "github.com/libdns/libdns" +) + +// ListZones returns all available zones. Most APIs support listing of all managed +// domains, which can be used as zones. +// +// This function ensures that the returned domain names include a trailing dot +// to indicate the root zone. +func ListZones(ctx context.Context, mutex sync.Locker, client ZoneAwareClient) ([]libdns.Zone, error) { + + if unlock := lock(mutex); nil != unlock { + defer unlock() + } + + domains, err := client.Domains(ctx) + + if err != nil { + return nil, err + } + + var zones = make([]libdns.Zone, len(domains)) + + for i, c := 0, len(domains); i < c; i++ { + + var name = domains[i].Name() + + if name[len(name)-1] != '.' { + name += "." + } + + zones[i] = libdns.Zone{ + Name: name, + } + } + + return zones, nil +}