-
Notifications
You must be signed in to change notification settings - Fork 1
/
aws-dns.go
161 lines (134 loc) · 4.21 KB
/
aws-dns.go
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
150
151
152
153
154
155
156
157
158
159
160
161
package main
import (
"flag"
"fmt"
"net"
"os"
"path/filepath"
"regexp"
"strings"
"sync"
"time"
"github.com/aws/aws-sdk-go/aws"
"github.com/aws/aws-sdk-go/aws/credentials"
"github.com/aws/aws-sdk-go/aws/session"
"github.com/aws/aws-sdk-go/service/ec2"
"github.com/go-ini/ini"
"github.com/miekg/dns"
)
var addresses = make(map[string]string)
const dnsSuffix string = "aws."
var awsRegion = flag.String("region", "ap-southeast-2", "AWS region for API access")
var port = flag.Int("port", 10053, "UDP Port to listen for DNS requests on")
var refreshInterval = flag.Int("refresh", 5, "Number of minutes between refreshing hosts")
// main parses flags, starts the background address refresher, registers the
// DNS handler for the dnsSuffix zone and serves until the DNS server exits.
func main() {
	flag.Parse()

	go updateAddresses()

	// Register the handler before the server starts reading packets so that
	// no early query reaches a mux without a route for dnsSuffix.
	dns.HandleFunc(dnsSuffix, awsDNSServer)

	_, addr, fin, err := setupServer()
	if err != nil {
		fmt.Fprintf(os.Stderr, "Unable to start DNS server: %s\n", err)
		os.Exit(1)
	}

	fmt.Println("Serving on ", addr)

	// Block until the server goroutine exits. Nothing in this program calls
	// Shutdown, so reaching the lines below means the server stopped on its
	// own; report it and exit non-zero.
	<-fin
	fmt.Fprintln(os.Stderr, "DNS server exited")
	os.Exit(1)
}

// updateAddresses loads the address table once, then reloads it every
// *refreshInterval minutes for the lifetime of the process. A non-positive
// interval disables periodic refreshing, because time.NewTicker panics on a
// non-positive duration.
func updateAddresses() {
	populateAddresses()

	interval := time.Duration(*refreshInterval) * time.Minute
	if interval <= 0 {
		fmt.Println("Periodic address refresh disabled.")
		return
	}

	ticker := time.NewTicker(interval)
	defer ticker.Stop()

	// Ranging over the channel refreshes on every tick; a single-case select
	// would only ever receive the first one.
	for range ticker.C {
		fmt.Println("Updating addresses.")
		populateAddresses()
	}
}

// setupServer binds a UDP socket on 127.0.0.1:*port, starts a dns.Server on
// it in a background goroutine and waits until the server reports that it
// has started.
//
// It returns the running server, the bound local address, and a channel
// that is closed once the server goroutine has exited and the socket has
// been closed. It returns an error if the socket cannot be opened or if the
// server exits before it ever starts.
func setupServer() (*dns.Server, string, chan struct{}, error) {
	pc, err := net.ListenPacket("udp", fmt.Sprintf("127.0.0.1:%d", *port))
	if err != nil {
		return nil, "", nil, fmt.Errorf("listening on UDP port %d: %w", *port, err)
	}
	server := &dns.Server{PacketConn: pc, ReadTimeout: time.Hour, WriteTimeout: time.Hour}

	// started is closed by the server's start notification. A channel (not a
	// pre-locked mutex) lets us also observe the server exiting before it
	// starts instead of blocking forever. The Once guarantees a single close.
	started := make(chan struct{})
	var startOnce sync.Once
	server.NotifyStartedFunc = func() {
		startOnce.Do(func() { close(started) })
	}

	fin := make(chan struct{})
	go func() {
		// Deferred in this order so the socket is closed before fin signals.
		defer close(fin)
		defer pc.Close()
		if err := server.ActivateAndServe(); err != nil {
			fmt.Printf("DNS server stopped: %s\n", err)
		}
	}()

	select {
	case <-started:
		return server, pc.LocalAddr().String(), fin, nil
	case <-fin:
		return nil, "", nil, fmt.Errorf("DNS server on %s exited before starting", pc.LocalAddr())
	}
}
// addressesMu guards the package-level addresses map. populateAddresses
// writes it from background goroutines while awsDNSServer reads it from DNS
// request handlers, and Go maps are not safe for concurrent read/write.
var addressesMu sync.RWMutex

// awsDNSServer answers a query for a name under dnsSuffix from the
// addresses table.
//
//   - A query with no question gets FORMERR.
//   - "reload-me.aws." starts an asynchronous populateAddresses and replies
//     with a TXT "Reloaded OK" record in the answer section.
//   - A known name queried for A or ANY gets an A record in the answer
//     section. A known name queried for any other type gets an empty
//     NOERROR reply.
//   - An unknown name gets an authoritative NXDOMAIN.
func awsDNSServer(w dns.ResponseWriter, req *dns.Msg) {
	m := new(dns.Msg)
	m.SetReply(req)

	// Indexing Question[0] on a malformed, question-less packet would panic.
	if len(req.Question) == 0 {
		m.SetRcode(req, dns.RcodeFormatError)
		writeReply(w, m)
		return
	}
	q := req.Question[0]
	record := strings.ToLower(q.Name)
	fmt.Println("DNS Request: ", record)

	if record == "reload-me.aws." {
		go populateAddresses()
		m.Answer = []dns.RR{&dns.TXT{
			Hdr: dns.RR_Header{Name: q.Name, Rrtype: dns.TypeTXT, Class: dns.ClassINET, Ttl: 0},
			Txt: []string{"Reloaded OK"},
		}}
		writeReply(w, m)
		// Return here; falling through would write a second reply.
		return
	}

	addressesMu.RLock()
	ip, ok := addresses[record]
	addressesMu.RUnlock()

	m.Authoritative = true
	if !ok || ip == "" {
		m.SetRcode(req, dns.RcodeNameError)
		writeReply(w, m)
		return
	}

	// Only an A (or ANY) question can be answered with an A record. Any
	// other type for a known name is left as an empty NOERROR reply.
	if q.Qtype == dns.TypeA || q.Qtype == dns.TypeANY {
		m.Answer = []dns.RR{&dns.A{
			Hdr: dns.RR_Header{Name: q.Name, Rrtype: dns.TypeA, Class: dns.ClassINET, Ttl: 0},
			A:   net.ParseIP(ip),
		}}
	}
	writeReply(w, m)
}

// writeReply sends m on w and logs, rather than drops, any write error.
func writeReply(w dns.ResponseWriter, m *dns.Msg) {
	if err := w.WriteMsg(m); err != nil {
		fmt.Printf("Unable to write DNS response: %s\n", err)
	}
}

// populateAddresses loads the instances visible to every profile returned by
// getAvailableAwsProfiles (profile names are lower-cased) and merges their
// records into addresses under addressesMu. Existing entries are never
// removed. A profile whose lookup fails is logged and skipped.
func populateAddresses() {
	seen := make(map[string]bool)
	for _, p := range getAvailableAwsProfiles() {
		profile := strings.ToLower(p)
		// Section names that differ only by case collapse to one profile
		// here; query it only once per refresh.
		if seen[profile] {
			continue
		}
		seen[profile] = true

		fmt.Printf("Loading hosts for AWS profile %s\n", profile)
		found, err := loadProfileAddresses(profile)
		if err != nil {
			fmt.Printf("Unable to load instance details for profile %s: %s\n", p, err)
			continue
		}

		// Hold the write lock only for the merge, not the network calls.
		addressesMu.Lock()
		for name, ip := range found {
			addresses[name] = ip
		}
		addressesMu.Unlock()
	}
}

// loadProfileAddresses calls EC2 DescribeInstances (all pages) in *awsRegion
// using the given shared-credentials profile. It returns a map from record
// name to private IP: "<instance-id>.aws." for every instance and, when the
// instance has a Name tag, "<parameterized-name>.aws." as well. Instances
// without a private IP are skipped.
func loadProfileAddresses(profile string) (map[string]string, error) {
	sess, err := session.NewSession(&aws.Config{
		Credentials: credentials.NewSharedCredentials("", profile),
		Region:      aws.String(*awsRegion),
	})
	if err != nil {
		return nil, err
	}
	svc := ec2.New(sess)

	found := make(map[string]string)
	instances := 0
	err = svc.DescribeInstancesPages(&ec2.DescribeInstancesInput{},
		func(page *ec2.DescribeInstancesOutput, _ bool) bool {
			for _, res := range page.Reservations {
				for _, inst := range res.Instances {
					instances++
					ip := aws.StringValue(inst.PrivateIpAddress)
					if ip == "" {
						continue
					}
					idRecord := strings.ToLower(fmt.Sprintf("%s.%s", aws.StringValue(inst.InstanceId), dnsSuffix))
					found[idRecord] = ip
					fmt.Printf("Added address %s: %s\n", idRecord, ip)

					if name := getNameTagVal(inst.Tags); name != "" {
						nameRecord := strings.ToLower(fmt.Sprintf("%s.%s", parameterizeString(name), dnsSuffix))
						found[nameRecord] = ip
						fmt.Printf("Added address %s: %s\n", nameRecord, ip)
					}
				}
			}
			return true // keep paging
		})
	if err != nil {
		return nil, err
	}
	fmt.Println("> Number of instances: ", instances)
	return found, nil
}
// paramUnsafeChars matches runs of characters not allowed in a generated
// record label: anything other than a-z, 0-9, '.', '_' and '-'. It is
// compiled once here rather than on every parameterizeString call.
var paramUnsafeChars = regexp.MustCompile(`[^a-z0-9._-]+`)

// parameterizeString lower-cases input and replaces every run of characters
// outside [a-z0-9._-] with a single '-'.
// For example, "Web Server #1" becomes "web-server-1".
func parameterizeString(input string) string {
	return paramUnsafeChars.ReplaceAllString(strings.ToLower(input), "-")
}
// getAvailableAwsProfiles returns the section names of the AWS shared
// credentials file; each is treated as a profile to query.
//
// The file is $AWS_SHARED_CREDENTIALS_FILE when set, which is the same
// override credentials.NewSharedCredentials applies for an empty filename.
// Otherwise it is $HOME/.aws/credentials. If the file cannot be loaded, the
// single profile "default" is returned.
func getAvailableAwsProfiles() []string {
	file := os.Getenv("AWS_SHARED_CREDENTIALS_FILE")
	if file == "" {
		file = filepath.Join(os.Getenv("HOME"), ".aws", "credentials")
	}

	config, err := ini.Load(file)
	if err != nil {
		fmt.Printf("Error loading aws credentials file to discover profiles: %s\n", err)
		return []string{"default"}
	}

	var profiles []string
	for _, name := range config.SectionStrings() {
		// go-ini lists an implicit ini.DefaultSection holding any keys that
		// precede the first [section]; skip it when it holds nothing, since
		// it is not a real profile.
		if name == ini.DefaultSection && len(config.Section(name).Keys()) == 0 {
			continue
		}
		profiles = append(profiles, name)
	}
	return profiles
}
// getNameTagVal returns the value of the first tag whose key is "Name", or
// "" if there is none. Nil tag entries and nil Key/Value pointers are
// treated as absent/empty instead of being dereferenced.
func getNameTagVal(tags []*ec2.Tag) string {
	for _, tag := range tags {
		if tag != nil && aws.StringValue(tag.Key) == "Name" {
			return aws.StringValue(tag.Value)
		}
	}
	return ""
}