From d6c5dc606b4f0931f6ae88c13a34e6eeab7f168c Mon Sep 17 00:00:00 2001 From: Steven Kreitzer Date: Fri, 24 May 2024 16:24:53 -0500 Subject: [PATCH] fix: delete and recreate records --- internal/unifi/client.go | 68 ++++---------------------------------- internal/unifi/provider.go | 26 ++++++--------- 2 files changed, 16 insertions(+), 78 deletions(-) diff --git a/internal/unifi/client.go b/internal/unifi/client.go index a62f073..d7e0584 100644 --- a/internal/unifi/client.go +++ b/internal/unifi/client.go @@ -144,29 +144,6 @@ func (c *Client) ShipData(url string, body []byte) ([]byte, error) { return byteArray, nil } -func (c *Client) PutData(url string, body []byte) ([]byte, error) { - req, _ := http.NewRequest(http.MethodPut, url, bytes.NewBuffer(body)) - c.setHeaders(req) - - resp, err := c.HTTPClient.Do(req) - if err != nil { - return nil, err - } - - if csrf := resp.Header.Get("x-csrf-token"); csrf != "" { - c.csrf = resp.Header.Get("x-csrf-token") - } - - defer resp.Body.Close() - - byteArray, err := io.ReadAll(resp.Body) - if err != nil { - return nil, err - } - - return byteArray, nil -} - func (c *Client) DeleteData(url string) ([]byte, error) { req, _ := http.NewRequest(http.MethodDelete, url, nil) @@ -236,47 +213,14 @@ func (c *Client) CreateEndpoint(endpoint *endpoint.Endpoint) (*DNSRecord, error) return &newRecord, nil } -// UpdateEndpoint updates an existing DNS record. -func (c *Client) UpdateEndpoint(endpoint *endpoint.Endpoint) (*DNSRecord, error) { - id, err := c.LookupIdentifier(endpoint.DNSName, endpoint.RecordType) - if err != nil { - return nil, err - } - - record := DNSRecord{ - Key: endpoint.DNSName, - RecordType: endpoint.RecordType, - TTL: endpoint.RecordTTL, - Value: endpoint.Targets[0], - } - - body, err := json.Marshal(record) - if err != nil { - return nil, err - } - - resp, err := c.PutData(fmt.Sprintf(UnifiDNSSelectRecord, c.BaseURL, id), body) - if err != nil { - return nil, err - } - - var updatedRecord DNSRecord - err = json.Unmarshal(resp, &updatedRecord) - if err != nil { - return nil, err - } - - return &updatedRecord, nil -} - // DeleteEndpoint deletes a DNS record. func (c *Client) DeleteEndpoint(endpoint *endpoint.Endpoint) error { - id, err := c.LookupIdentifier(endpoint.DNSName, endpoint.RecordType) + lookup, err := c.LookupIdentifier(endpoint.DNSName, endpoint.RecordType) if err != nil { return err } - _, err = c.DeleteData(fmt.Sprintf(UnifiDNSSelectRecord, c.BaseURL, id)) + _, err = c.DeleteData(fmt.Sprintf(UnifiDNSSelectRecord, c.BaseURL, lookup.ID)) if err != nil { return err } @@ -285,17 +229,17 @@ func (c *Client) DeleteEndpoint(endpoint *endpoint.Endpoint) error { } // LookupIdentifier finds the ID of a DNS record. -func (c *Client) LookupIdentifier(Key string, RecordType string) (string, error) { +func (c *Client) LookupIdentifier(Key string, RecordType string) (*DNSRecord, error) { records, err := c.ListRecords() if err != nil { - return "", err + return nil, err } for _, r := range records { if r.Key == Key && r.RecordType == RecordType { - return r.ID, nil + return &r, nil } } - return "", fmt.Errorf("record not found") + return nil, fmt.Errorf("record not found") } diff --git a/internal/unifi/provider.go b/internal/unifi/provider.go index 0c24325..228bb2e 100644 --- a/internal/unifi/provider.go +++ b/internal/unifi/provider.go @@ -42,12 +42,12 @@ func (p *Provider) Records(ctx context.Context) ([]*endpoint.Endpoint, error) { var endpoints []*endpoint.Endpoint for _, record := range records { - ep := endpoint.NewEndpointWithTTL( - record.Key, - record.RecordType, - record.TTL, - record.Value, - ) + ep := &endpoint.Endpoint{ + DNSName: record.Key, + RecordType: record.RecordType, + RecordTTL: record.TTL, + Targets: endpoint.NewTargets(record.Value), + } if !p.domainFilter.Match(ep.DNSName) { continue @@ -61,20 +61,14 @@ func (p *Provider) Records(ctx context.Context) ([]*endpoint.Endpoint, error) { // ApplyChanges applies a given set of changes in the DNS provider. func (p *Provider) ApplyChanges(ctx context.Context, changes *plan.Changes) error { - for _, ep := range changes.Delete { - if err := p.client.DeleteEndpoint(ep); err != nil { - return err - } - } - - for _, ep := range changes.UpdateNew { - if _, err := p.client.UpdateEndpoint(ep); err != nil { + for _, endpoint := range append(changes.UpdateOld, changes.Delete...) { + if err := p.client.DeleteEndpoint(endpoint); err != nil { return err } } - for _, ep := range changes.Create { - if _, err := p.client.CreateEndpoint(ep); err != nil { + for _, endpoint := range append(changes.Create, changes.UpdateNew...) { + if _, err := p.client.CreateEndpoint(endpoint); err != nil { return err } }