diff --git a/dnsprovider/dnsprovider.go b/dnsprovider/dnsprovider.go index 99f2c2e..bcc6fd1 100644 --- a/dnsprovider/dnsprovider.go +++ b/dnsprovider/dnsprovider.go @@ -6,10 +6,11 @@ import ( "crypto/tls" "encoding/json" "fmt" - "io/ioutil" + "io" "net/http" "net/http/cookiejar" + log "github.com/sirupsen/logrus" "sigs.k8s.io/external-dns/endpoint" "sigs.k8s.io/external-dns/plan" "sigs.k8s.io/external-dns/provider" @@ -17,23 +18,30 @@ import ( // DNSRecord represents a DNS record in the API. type DNSRecord struct { - ID string `json:"_id"` - Enabled bool `json:"enabled"` + ID string `json:"_id,omitempty"` + Enabled bool `json:"enabled,omitempty"` Key string `json:"key"` - Port int `json:"port"` - Priority int `json:"priority"` + Port int `json:"port,omitempty"` + Priority int `json:"priority,omitempty"` RecordType string `json:"record_type"` - TTL endpoint.TTL `json:"ttl"` + TTL endpoint.TTL `json:"ttl,omitempty"` Value string `json:"value"` - Weight int `json:"weight"` + Weight int `json:"weight,omitempty"` } // Client is the DNS provider client. type Client struct { BaseURL string HTTPClient *http.Client + csrf string } +var ( + UnifiLogin = "%s/api/auth/login" + UnifiDNSRecords = "%s/proxy/network/v2/api/site/default/static-dns" + UnifiDNSSelectRecord = "%s/proxy/network/v2/api/site/default/static-dns/%s" +) + // NewClient creates a new DNS provider client and logs in to store cookies. func NewClient(baseURL, username, password string, skipTLSVerify bool) (*Client, error) { jar, err := cookiejar.New(nil) @@ -62,7 +70,7 @@ func NewClient(baseURL, username, password string, skipTLSVerify bool) (*Client, // login authenticates the client and stores the cookies. func (c *Client) login(username, password string) error { - loginURL := fmt.Sprintf("%s/api/auth/login", c.BaseURL) + loginURL := fmt.Sprintf(UnifiLogin, c.BaseURL) credentials := map[string]string{ "username": username, "password": password, @@ -85,18 +93,125 @@ func (c *Client) login(username, password string) error { return nil } +func (c *Client) setHeaders(req *http.Request) { + req.Header.Set("X-CSRF-Token", c.csrf) + req.Header.Add("Accept", "application/json") + req.Header.Add("Content-Type", "application/json; charset=utf-8") +} + +func (c *Client) GetData(url string) ([]byte, error) { + + req, _ := http.NewRequest(http.MethodGet, fmt.Sprintf(url, c.BaseURL), nil) + + c.setHeaders(req) + + resp, err := c.HTTPClient.Do(req) + if err != nil { + return nil, err + } + + if csrf := resp.Header.Get("x-csrf-token"); csrf != "" { + c.csrf = resp.Header.Get("x-csrf-token") + } + + defer resp.Body.Close() + + byteArray, err := io.ReadAll(resp.Body) + if err != nil { + return nil, err + } + + return byteArray, nil +} + +func (c *Client) ShipData(url string, body []byte) ([]byte, error) { + req, _ := http.NewRequest(http.MethodPost, fmt.Sprintf(url, c.BaseURL), bytes.NewBuffer(body)) + + c.setHeaders(req) + log.Debugf("is it a 403 part-1? (gone wrong): %v", req) + resp, err := c.HTTPClient.Do(req) + log.Debugf("is it a 403?: %v", resp) + if err != nil { + return nil, err + } + + if csrf := resp.Header.Get("x-csrf-token"); csrf != "" { + c.csrf = resp.Header.Get("x-csrf-token") + } + + defer resp.Body.Close() + + byteArray, err := io.ReadAll(resp.Body) + if err != nil { + return nil, err + } + + return byteArray, nil +} + +func (c *Client) PutData(url string, body []byte) ([]byte, error) { + req, _ := http.NewRequest(http.MethodPut, url, bytes.NewBuffer(body)) + + c.setHeaders(req) + log.Debugf("is it a 403 part-1? (gone wrong): %v", req) + resp, err := c.HTTPClient.Do(req) + log.Debugf("is it a 403?: %v", resp) + if err != nil { + return nil, err + } + + if csrf := resp.Header.Get("x-csrf-token"); csrf != "" { + c.csrf = resp.Header.Get("x-csrf-token") + } + + defer resp.Body.Close() + + byteArray, err := io.ReadAll(resp.Body) + if err != nil { + return nil, err + } + + return byteArray, nil +} + +func (c *Client) DeleteData(url string) ([]byte, error) { + req, _ := http.NewRequest(http.MethodPost, url, nil) + + c.setHeaders(req) + log.Debugf("is it a 403 part-1? (gone wrong): %v", req) + resp, err := c.HTTPClient.Do(req) + log.Debugf("is it a 403?: %v", resp) + if err != nil { + return nil, err + } + + if csrf := resp.Header.Get("x-csrf-token"); csrf != "" { + c.csrf = resp.Header.Get("x-csrf-token") + } + + defer resp.Body.Close() + + byteArray, err := io.ReadAll(resp.Body) + if err != nil { + return nil, err + } + + return byteArray, nil +} + // ListRecords retrieves all DNS records. func (c *Client) ListRecords() ([]DNSRecord, error) { - resp, err := c.HTTPClient.Get(fmt.Sprintf("%s/proxy/network/v2/api/site/default/static-dns", c.BaseURL)) + resp, err := c.GetData(UnifiDNSRecords) if err != nil { return nil, err } - defer resp.Body.Close() var records []DNSRecord - if err := json.NewDecoder(resp.Body).Decode(&records); err != nil { + err = json.Unmarshal(resp, &records) + if err != nil { return nil, err } + return records, nil } @@ -107,16 +222,22 @@ func (c *Client) CreateRecord(record DNSRecord) (*DNSRecord, error) { return nil, err } - resp, err := c.HTTPClient.Post(fmt.Sprintf("%s/proxy/network/v2/api/site/default/static-dns", c.BaseURL), "application/json", bytes.NewBuffer(body)) + log.Debugf("json marshal: %v", body) + + resp, err := c.ShipData(UnifiDNSRecords, body) if err != nil { return nil, err } - defer resp.Body.Close() + + log.Debugf("json marshal 2: %v", resp) var newRecord DNSRecord - if err := json.NewDecoder(resp.Body).Decode(&newRecord); err != nil { + err = json.Unmarshal(resp, &newRecord) + if err != nil { return nil, err } + + log.Debugf("json marshal 3: %v", newRecord) return &newRecord, nil } @@ -127,20 +248,14 @@ func (c *Client) UpdateRecord(id string, record DNSRecord) (*DNSRecord, error) { return nil, err } - req, err := http.NewRequest(http.MethodPut, fmt.Sprintf("%s/proxy/network/v2/api/site/default/static-dns/%s", c.BaseURL, id), bytes.NewBuffer(body)) + resp, err := c.PutData(fmt.Sprintf("%s/proxy/network/v2/api/site/default/static-dns/%s", c.BaseURL, id), body) if err != nil { return nil, err } - req.Header.Set("Content-Type", "application/json") - - resp, err := c.HTTPClient.Do(req) - if err != nil { - return nil, err - } - defer resp.Body.Close() var updatedRecord DNSRecord - if err := json.NewDecoder(resp.Body).Decode(&updatedRecord); err != nil { + err = json.Unmarshal(resp, &updatedRecord) + if err != nil { return nil, err } return &updatedRecord, nil @@ -148,21 +263,11 @@ func (c *Client) UpdateRecord(id string, record DNSRecord) (*DNSRecord, error) { // DeleteRecord deletes a DNS record. func (c *Client) DeleteRecord(id string) error { - req, err := http.NewRequest(http.MethodDelete, fmt.Sprintf("%s/proxy/network/v2/api/site/default/static-dns/%s", c.BaseURL, id), nil) + _, err := c.DeleteData(fmt.Sprintf(UnifiDNSSelectRecord, c.BaseURL, id)) if err != nil { return err } - resp, err := c.HTTPClient.Do(req) - if err != nil { - return err - } - defer resp.Body.Close() - - if resp.StatusCode != http.StatusOK { - bodyBytes, _ := ioutil.ReadAll(resp.Body) - return fmt.Errorf("failed to delete record: %s", bodyBytes) - } return nil } @@ -206,11 +311,10 @@ func (p *DNSProvider) Records(ctx context.Context) ([]*endpoint.Endpoint, error) func (p *DNSProvider) ApplyChanges(ctx context.Context, changes *plan.Changes) error { for _, ep := range changes.Create { record := DNSRecord{ - ID: ep.SetIdentifier, - Key: ep.DNSName, - Value: ep.Targets[0], + Key: ep.DNSName, + Value: ep.Targets[0], RecordType: ep.RecordType, - TTL: ep.RecordTTL, + TTL: ep.RecordTTL, } if _, err := p.client.CreateRecord(record); err != nil { return err @@ -219,11 +323,11 @@ func (p *DNSProvider) ApplyChanges(ctx context.Context, changes *plan.Changes) e for _, ep := range changes.UpdateNew { record := DNSRecord{ - ID: ep.SetIdentifier, - Key: ep.DNSName, - Value: ep.Targets[0], + ID: ep.SetIdentifier, + Key: ep.DNSName, + Value: ep.Targets[0], RecordType: ep.RecordType, - TTL: ep.RecordTTL, + TTL: ep.RecordTTL, } // Assuming ID can be obtained from DNS name id := ep.DNSName // This needs to be changed to actual ID fetching logic @@ -258,4 +362,4 @@ func (p *DNSProvider) GetDomainFilter() endpoint.DomainFilter { // GetDNSName returns the DNS provider's name. func (p *DNSProvider) GetDNSName() string { return "external-dns-unifi-webhook" -} \ No newline at end of file +}