Skip to content

Commit

Permalink
Merge pull request #16 from buroa/main
Browse files Browse the repository at this point in the history
fix: delete and recreate records
  • Loading branch information
kashalls authored May 24, 2024
2 parents 2d35841 + d6c5dc6 commit a25aa51
Show file tree
Hide file tree
Showing 2 changed files with 16 additions and 78 deletions.
68 changes: 6 additions & 62 deletions internal/unifi/client.go
Original file line number Diff line number Diff line change
Expand Up @@ -144,29 +144,6 @@ func (c *Client) ShipData(url string, body []byte) ([]byte, error) {
return byteArray, nil
}

func (c *Client) PutData(url string, body []byte) ([]byte, error) {
req, _ := http.NewRequest(http.MethodPut, url, bytes.NewBuffer(body))
c.setHeaders(req)

resp, err := c.HTTPClient.Do(req)
if err != nil {
return nil, err
}

if csrf := resp.Header.Get("x-csrf-token"); csrf != "" {
c.csrf = resp.Header.Get("x-csrf-token")
}

defer resp.Body.Close()

byteArray, err := io.ReadAll(resp.Body)
if err != nil {
return nil, err
}

return byteArray, nil
}

func (c *Client) DeleteData(url string) ([]byte, error) {
req, _ := http.NewRequest(http.MethodDelete, url, nil)

Expand Down Expand Up @@ -236,47 +213,14 @@ func (c *Client) CreateEndpoint(endpoint *endpoint.Endpoint) (*DNSRecord, error)
return &newRecord, nil
}

// UpdateEndpoint updates an existing DNS record.
func (c *Client) UpdateEndpoint(endpoint *endpoint.Endpoint) (*DNSRecord, error) {
id, err := c.LookupIdentifier(endpoint.DNSName, endpoint.RecordType)
if err != nil {
return nil, err
}

record := DNSRecord{
Key: endpoint.DNSName,
RecordType: endpoint.RecordType,
TTL: endpoint.RecordTTL,
Value: endpoint.Targets[0],
}

body, err := json.Marshal(record)
if err != nil {
return nil, err
}

resp, err := c.PutData(fmt.Sprintf(UnifiDNSSelectRecord, c.BaseURL, id), body)
if err != nil {
return nil, err
}

var updatedRecord DNSRecord
err = json.Unmarshal(resp, &updatedRecord)
if err != nil {
return nil, err
}

return &updatedRecord, nil
}

// DeleteEndpoint deletes a DNS record.
func (c *Client) DeleteEndpoint(endpoint *endpoint.Endpoint) error {
id, err := c.LookupIdentifier(endpoint.DNSName, endpoint.RecordType)
lookup, err := c.LookupIdentifier(endpoint.DNSName, endpoint.RecordType)
if err != nil {
return err
}

_, err = c.DeleteData(fmt.Sprintf(UnifiDNSSelectRecord, c.BaseURL, id))
_, err = c.DeleteData(fmt.Sprintf(UnifiDNSSelectRecord, c.BaseURL, lookup.ID))
if err != nil {
return err
}
Expand All @@ -285,17 +229,17 @@ func (c *Client) DeleteEndpoint(endpoint *endpoint.Endpoint) error {
}

// LookupIdentifier finds the ID of a DNS record.
func (c *Client) LookupIdentifier(Key string, RecordType string) (string, error) {
func (c *Client) LookupIdentifier(Key string, RecordType string) (*DNSRecord, error) {
records, err := c.ListRecords()
if err != nil {
return "", err
return nil, err
}

for _, r := range records {
if r.Key == Key && r.RecordType == RecordType {
return r.ID, nil
return &r, nil
}
}

return "", fmt.Errorf("record not found")
return nil, fmt.Errorf("record not found")
}
26 changes: 10 additions & 16 deletions internal/unifi/provider.go
Original file line number Diff line number Diff line change
Expand Up @@ -42,12 +42,12 @@ func (p *Provider) Records(ctx context.Context) ([]*endpoint.Endpoint, error) {

var endpoints []*endpoint.Endpoint
for _, record := range records {
ep := endpoint.NewEndpointWithTTL(
record.Key,
record.RecordType,
record.TTL,
record.Value,
)
ep := &endpoint.Endpoint{
DNSName: record.Key,
RecordType: record.RecordType,
RecordTTL: record.TTL,
Targets: endpoint.NewTargets(record.Value),
}

if !p.domainFilter.Match(ep.DNSName) {
continue
Expand All @@ -61,20 +61,14 @@ func (p *Provider) Records(ctx context.Context) ([]*endpoint.Endpoint, error) {

// ApplyChanges applies a given set of changes in the DNS provider.
func (p *Provider) ApplyChanges(ctx context.Context, changes *plan.Changes) error {
for _, ep := range changes.Delete {
if err := p.client.DeleteEndpoint(ep); err != nil {
return err
}
}

for _, ep := range changes.UpdateNew {
if _, err := p.client.UpdateEndpoint(ep); err != nil {
for _, endpoint := range append(changes.UpdateOld, changes.Delete...) {
if err := p.client.DeleteEndpoint(endpoint); err != nil {
return err
}
}

for _, ep := range changes.Create {
if _, err := p.client.CreateEndpoint(ep); err != nil {
for _, endpoint := range append(changes.Create, changes.UpdateNew...) {
if _, err := p.client.CreateEndpoint(endpoint); err != nil {
return err
}
}
Expand Down

0 comments on commit a25aa51

Please sign in to comment.