diff --git a/client.py b/client.py index 06fc573..a03ac3a 100644 --- a/client.py +++ b/client.py @@ -163,7 +163,7 @@ def start(self): raise Fatal('%r expected STARTED, got %r' % (self.argv, line)) def sethostip(self, hostname, ip): - assert(not re.search(r'[^-\w]', hostname)) + assert(not re.search(r'[^-\w\.]', hostname)) assert(not re.search(r'[^0-9.]', ip)) self.pfile.write('HOST %s,%s\n' % (hostname, ip)) self.pfile.flush() diff --git a/firewall.py b/firewall.py index 53d866e..cbc3312 100644 --- a/firewall.py +++ b/firewall.py @@ -430,6 +430,19 @@ def restore_etc_hosts(port): rewrite_etc_hosts(port) +def _mask(ip, width): + nip = struct.unpack('!I', socket.inet_aton(ip))[0] + masked = nip & shl(shl(1, width) - 1, 32-width) + return socket.inet_ntoa(struct.pack('!I', masked)) + + +def ip_in_subnets(ip, subnets): + for swidth,sexclude,snet in sorted(subnets, reverse=True): + if _mask(snet, swidth) == _mask(ip, swidth): + return not sexclude + return False + + # This is some voodoo for setting up the kernel's transparent # proxying stuff. If subnets is empty, we just delete our sshuttle rules; # otherwise we delete it, then make them from scratch. @@ -521,8 +534,9 @@ def main(port, dnsport, syslog): line = sys.stdin.readline(128) if line.startswith('HOST '): (name,ip) = line[5:].strip().split(',', 1) - hostmap[name] = ip - rewrite_etc_hosts(port) + if ip_in_subnets(ip, subnets): + hostmap[name] = ip + rewrite_etc_hosts(port) elif line: raise Fatal('expected EOF, got %r' % line) else: diff --git a/helpers.py b/helpers.py index 45a028b..d8de08d 100644 --- a/helpers.py +++ b/helpers.py @@ -78,3 +78,10 @@ def islocal(ip): return True # it's a local IP, or there would have been an error +def shl(n, bits): + # we use our own implementation of left-shift because + # results may be different between older and newer versions + # of python for numbers like 1<<32. We use long() because + # int(2**32) doesn't work in older python, which has limited + # int sizes. + return n * long(2**bits) diff --git a/hostwatch.py b/hostwatch.py index 66e7461..e2bdb2b 100644 --- a/hostwatch.py +++ b/hostwatch.py @@ -51,15 +51,20 @@ def read_host_cache(): words = line.strip().split(',') if len(words) == 2: (name,ip) = words - name = re.sub(r'[^-\w]', '-', name).strip() + name = re.sub(r'[^-\w\.]', '-', name).strip() ip = re.sub(r'[^0-9.]', '', ip).strip() if name and ip: found_host(name, ip) - -def found_host(hostname, ip): - hostname = re.sub(r'\..*', '', hostname) - hostname = re.sub(r'[^-\w]', '_', hostname) + +def found_host(full_hostname, ip): + full_hostname = re.sub(r'[^-\w\.]', '_', full_hostname) + hostname = re.sub(r'\..*', '', full_hostname) + _insert_host(full_hostname, ip) + _insert_host(hostname, ip) + + +def _insert_host(hostname, ip): if (ip.startswith('127.') or ip.startswith('255.') or hostname == 'localhost'): return diff --git a/server.py b/server.py index 5f2e5e4..1032d4c 100644 --- a/server.py +++ b/server.py @@ -37,20 +37,11 @@ def _maskbits(netmask): if not netmask: return 32 for i in range(32): - if netmask[0] & _shl(1, i): + if netmask[0] & shl(1, i): return 32-i return 0 -def _shl(n, bits): - # we use our own implementation of left-shift because - # results may be different between older and newer versions - # of python for numbers like 1<<32. We use long() because - # int(2**32) doesn't work in older python, which has limited - # int sizes. - return n * long(2**bits) - - def _list_routes(): argv = ['netstat', '-rn'] p = ssubprocess.Popen(argv, stdout=ssubprocess.PIPE) @@ -63,7 +54,7 @@ def _list_routes(): maskw = _ipmatch(cols[2]) # linux only mask = _maskbits(maskw) # returns 32 if maskw is null width = min(ipw[1], mask) - ip = ipw[0] & _shl(_shl(1, width) - 1, 32-width) + ip = ipw[0] & shl(shl(1, width) - 1, 32-width) routes.append((socket.inet_ntoa(struct.pack('!I', ip)), width)) rv = p.wait() if rv != 0: