2
0
mirror of https://github.com/xcat2/confluent.git synced 2025-01-15 20:27:50 +00:00
Jarrod Johnson abec8c498c Break netlink address fetch on invalid rta_len
It is considered valid for the kernel to return a null rta_len
in the midst of data and expect the caller to stop parsing there.
2021-12-20 12:28:35 -05:00

396 lines
14 KiB
Python

#!/usr/bin/python
try:
import http.client as client
except ImportError:
import httplib as client
import ctypes
import ctypes.util
import glob
import os
import select
import socket
import subprocess
import ssl
import sys
import struct
import time
class InvalidApiKey(Exception):
    """Exception type for an unusable confluent API key.

    Plain ``Exception`` subclass with no extra behavior; nothing in this
    module raises it directly.
    """
# Locate libcrypt and expose its crypt(3) entry point as c_crypt.
cryptname = ctypes.util.find_library('crypt')
if not cryptname:
    # find_library() can come up empty; probe the known sonames under
    # /usr/lib64 in order of preference and load the first one present.
    for _cryptpath, _cryptso in (('/usr/lib64/libcrypt.so.1', 'libcrypt.so.1'),
                                 ('/usr/lib64/libcrypt.so.2', 'libcrypt.so.2')):
        if os.path.exists(_cryptpath):
            cryptname = _cryptso
            break
    del _cryptpath, _cryptso
c_libcrypt = ctypes.CDLL(cryptname)
c_crypt = c_libcrypt.crypt
# char *crypt(const char *key, const char *salt)
c_crypt.argtypes = (ctypes.c_char_p, ctypes.c_char_p)
c_crypt.restype = ctypes.c_char_p
def get_my_addresses():
    """Return this host's universe- and link-scope addresses via rtnetlink.

    Sends an RTM_GETADDR dump request over a NETLINK_ROUTE socket and parses
    the RTM_NEWADDR replies until NLMSG_DONE (or NLMSG_ERROR) is seen.

    Returns:
        A list of ``(family, address, prefixlen, ifindex)`` tuples, one per
        IFA_ADDRESS attribute of an address whose scope is 0 (universe) or
        253 (link).  ``address`` is a memoryview over the raw address bytes.
    """
    nlhdrsz = struct.calcsize('IHHII')
    ifaddrsz = struct.calcsize('BBBBI')
    # RTM_GETADDR = 22
    # nlmsghdr struct: u32 len, u16 type, u16 flags, u32 seq, u32 pid
    # flags 0x301 = NLM_F_REQUEST | NLM_F_ROOT | NLM_F_MATCH (a dump request)
    nlhdr = struct.pack('IHHII', nlhdrsz + ifaddrsz, 22, 0x301, 0, 0)
    # ifaddrmsg struct: u8 family, u8 prefixlen, u8 flags, u8 scope, u32 index
    ifaddrmsg = struct.pack('BBBBI', 0, 0, 0, 0, 0)
    addrs = []
    s = socket.socket(socket.AF_NETLINK, socket.SOCK_RAW, socket.NETLINK_ROUTE)
    try:
        s.bind((0, 0))
        s.sendall(nlhdr + ifaddrmsg)
        done = False
        while not done:
            pdata = s.recv(65536)
            if not pdata:
                break
            v = memoryview(pdata)
            # A single recv may hold several messages, and NLMSG_DONE may
            # follow data in the same buffer, so check every header.
            while len(v) >= nlhdrsz:
                length, typ = struct.unpack('IH', v[:6])
                if length < nlhdrsz or length > len(v):
                    # Malformed length: stop instead of spinning or hanging.
                    done = True
                    break
                if typ in (2, 3):  # NLMSG_ERROR / NLMSG_DONE end the dump
                    done = True
                    break
                if typ == 20 and length >= nlhdrsz + ifaddrsz:  # RTM_NEWADDR
                    fam, plen, _, scope, ridx = struct.unpack(
                        'BBBBI', v[nlhdrsz:nlhdrsz + ifaddrsz])
                    if scope in (253, 0):
                        rta = v[nlhdrsz + ifaddrsz:length]
                        while len(rta) >= 4:
                            rtalen, rtatyp = struct.unpack('HH', rta[:4])
                            # A null/short/oversized rta_len terminates the
                            # attribute list.
                            if rtalen < 4 or rtalen > len(rta):
                                break
                            if rtatyp == 1:  # IFA_ADDRESS
                                addrs.append((fam, rta[4:rtalen], plen, ridx))
                            # Attributes are padded to 4-byte boundaries
                            # (RTA_ALIGN); rta_len excludes that padding.
                            rta = rta[(rtalen + 3) & ~3:]
                # Messages are likewise 4-byte aligned (NLMSG_ALIGN).
                v = v[(length + 3) & ~3:]
    finally:
        s.close()
    return addrs
def scan_confluents():
    """Discover confluent servers with an SSDP-style multicast M-SEARCH.

    The search request carries this system's confluent_uuid (from
    /etc/confluent/confluent.deploycfg), its DMI product UUID and the MAC
    address of every interface under /sys/class/net.  It is sent once per
    IPv6 interface index to ff02::c and once per IPv4 address to
    239.255.255.250 (both port 1900).  Replies are then collected until no
    socket becomes readable within 4 seconds (first wait) or 2 seconds
    (subsequent waits).

    Returns:
        ``(srvlist, srvs)``: ``srvlist`` lists responding peer addresses in
        arrival order; ``srvs`` maps each of them to a dict that may hold
        ``'isdefault'`` (True when the reply had a ``DEFAULTNET: 1`` line),
        ``'mgtiface'`` (str) and, for peers whose address tuple has more than
        two elements (IPv6), ``'myidx'`` (its last element).  Replies without
        a ``NODENAME:`` line are ignored.
    """
    srvs = {}
    srvlist = []
    socks = []
    try:
        s6 = socket.socket(socket.AF_INET6, socket.SOCK_DGRAM)
        socks.append(s6)
        s6.setsockopt(socket.IPPROTO_IPV6, socket.IPV6_V6ONLY, 1)
        s6.setsockopt(socket.SOL_SOCKET, socket.SO_REUSEPORT, 1)
        s6.bind(('::', 1900))
        s4 = socket.socket(socket.AF_INET, socket.SOCK_DGRAM)
        socks.append(s4)
        s4.setsockopt(socket.SOL_SOCKET, socket.SO_BROADCAST, 1)
        s4.setsockopt(socket.SOL_SOCKET, socket.SO_REUSEPORT, 1)
        s4.bind(('0.0.0.0', 1900))
        msg = 'M-SEARCH * HTTP/1.1\r\nST: urn:xcat.org:service:confluent:'
        with open('/etc/confluent/confluent.deploycfg') as dcfg:
            for line in dcfg.read().split('\n'):
                if line.startswith('confluent_uuid:'):
                    confluentuuid = line.split(': ')[1]
                    msg += '/confluentuuid=' + confluentuuid
                    break
        with open('/sys/devices/virtual/dmi/id/product_uuid') as uuidin:
            msg += '/uuid=' + uuidin.read().strip()
        for addrf in glob.glob('/sys/class/net/*/address'):
            with open(addrf) as addrin:
                hwaddr = addrin.read().strip()
                msg += '/mac=' + hwaddr
        msg = msg.encode('utf8')
        doneidxs = set()
        for addr in get_my_addresses():
            if addr[0] == socket.AF_INET6:
                # One IPv6 multicast per interface index is enough.
                if addr[-1] in doneidxs:
                    continue
                s6.setsockopt(socket.IPPROTO_IPV6, socket.IPV6_MULTICAST_IF, addr[-1])
                try:
                    s6.sendto(msg, ('ff02::c', 1900))
                except OSError:
                    pass  # best effort: an interface may refuse multicast
                doneidxs.add(addr[-1])
            elif addr[0] == socket.AF_INET:
                # IP_MULTICAST_IF takes the raw 4-byte local address.
                s4.setsockopt(socket.IPPROTO_IP, socket.IP_MULTICAST_IF, addr[1])
                try:
                    s4.sendto(msg, ('239.255.255.250', 1900))
                except OSError:
                    pass  # best effort: an interface may refuse multicast
        ready = select.select((s4, s6), (), (), 4)[0]
        while ready:
            for sock in ready:
                rsp, peer = sock.recvfrom(9000)
                current = None
                for line in rsp.split(b'\r\n'):
                    if line.startswith(b'NODENAME: '):
                        current = {}
                    elif current is None:
                        # Headers before NODENAME have nowhere to go.
                        continue
                    elif line.startswith(b'DEFAULTNET: 1'):
                        current['isdefault'] = True
                    elif line.startswith(b'MGTIFACE: '):
                        current['mgtiface'] = line.replace(b'MGTIFACE: ', b'').strip().decode('utf8')
                if current is None:
                    # Not a confluent reply; recording it would put None in
                    # srvs and break callers that call .get() on entries.
                    continue
                if len(peer) > 2:
                    current['myidx'] = peer[-1]
                srvs[peer[0]] = current
                srvlist.append(peer[0])
            ready = select.select((s4, s6), (), (), 2)[0]
    finally:
        for sock in socks:
            sock.close()
    return srvlist, srvs
def _recv_exact(sock, size):
    """Receive ``size`` bytes from ``sock``, returning early only at EOF.

    ``recv`` on a stream socket may return fewer bytes than requested; this
    keeps reading until ``size`` bytes arrived or the peer closed, and
    returns whatever was received (possibly shorter than ``size``).
    """
    chunks = []
    remaining = size
    while remaining > 0:
        chunk = sock.recv(remaining)
        if not chunk:
            break
        chunks.append(chunk)
        remaining -= len(chunk)
    return b''.join(chunks)


def get_net_apikey(nodename, mgr):
    """Register a freshly generated API key with server ``mgr`` (TCP 13001).

    A random 32-character key is generated and hashed with crypt(3) using a
    random ``$5$`` salt; the hash is sent to the server in the credential
    exchange below.  The client socket is bound to local port 302.

    Args:
        nodename: this node's name, sent in the hello message.
        mgr: server host name or address passed to ``socket.getaddrinfo``.

    Returns:
        The plaintext key (str) once the server answers with code 5;
        otherwise ``''`` after every resolved address has been tried.

    Raises:
        OSError: on resolution or socket errors.
        Exception: when the server's 8-byte banner is not recognized.
    """
    alpha = 'ABCDEFGHIJKLMNOPQRSTUVWXYZabcdefghijklmnopqrstuvwxyz0123456789./'
    # Each random byte (0-255) shifted right by 2 indexes the 64-char alphabet.
    newpass = ''.join([alpha[x >> 2] for x in bytearray(os.urandom(32))])
    salt = '$5$' + ''.join([alpha[x >> 2] for x in bytearray(os.urandom(8))])
    newpass = newpass.encode('utf8')
    salt = salt.encode('utf8')
    crypted = c_crypt(newpass, salt)
    for addrinfo in socket.getaddrinfo(mgr, 13001, type=socket.SOCK_STREAM):
        # Created before the try so the finally clause never sees an unbound
        # or previously closed socket.
        clisock = socket.socket(addrinfo[0], addrinfo[1])
        try:
            clisock.setsockopt(socket.SOL_SOCKET, socket.SO_REUSEADDR, 1)
            if addrinfo[0] == socket.AF_INET:
                cliaddr = ('0.0.0.0', 302)
            else:
                cliaddr = ('::', 302)
            clisock.bind(cliaddr)
            clisock.connect(addrinfo[-1])
            rsp = _recv_exact(clisock, 8)
            if rsp != b'\xc2\xd1-\xa8\x80\xd8j\xba':
                raise Exception('Unrecognized credential banner')
            # The length prefix counts bytes, so measure the encoded name.
            nodebytes = nodename.encode('utf8')
            hellostr = (bytearray([1, len(nodebytes)]) + bytearray(nodebytes)
                        + bytearray(b'\x00\x00'))
            clisock.sendall(hellostr)
            rsp = bytearray(_recv_exact(clisock, 2))
            if len(rsp) < 2:
                continue
            if rsp[0] == 128:
                continue
            if rsp[0] == 2:
                # Echo the server's token back, then send the key hash.
                echotoken = _recv_exact(clisock, rsp[1])
                _recv_exact(clisock, 2)  # drain \x00\x00
                clisock.sendall(bytes(bytearray([3, rsp[1]])))
                clisock.sendall(echotoken)
                clisock.sendall(bytes(bytearray([4, len(crypted)])))
                clisock.sendall(crypted)
                clisock.sendall(b'\x00\x00')
                rsp = bytearray(_recv_exact(clisock, 2))
                if rsp and rsp[0] == 5:
                    return newpass.decode('utf8')
        finally:
            clisock.close()
    return ''
def get_apikey(nodename, hosts):
    """Return this node's API key, negotiating a new one if none is stored.

    A non-empty /etc/confluent/confluent.apikey is returned as-is (stripped).
    Otherwise ``get_net_apikey`` is tried against each of ``hosts`` and, if
    none succeeds, against servers found by ``scan_confluents``.  This repeats
    every 10 seconds (with a message on stderr) until a key is obtained.  The
    new key is then saved to the key file with mode 0600.

    Args:
        nodename: this node's name, passed to ``get_net_apikey``.
        hosts: candidate server addresses to try first.

    Returns:
        The API key as a stripped str.
    """
    apikeypath = '/etc/confluent/confluent.apikey'
    apikey = ""
    if os.path.exists(apikeypath):
        with open(apikeypath) as apiin:
            apikey = apiin.read().strip()
    if apikey:
        return apikey

    def first_key(candidates):
        # First non-empty key from the candidate servers; socket errors on
        # one server just move on to the next.
        for host in candidates:
            try:
                key = get_net_apikey(nodename, host)
            except OSError:
                key = None
            if key:
                return key
        return None

    while not apikey:
        apikey = first_key(hosts)
        if not apikey:
            srvlist, _ = scan_confluents()
            apikey = first_key(srvlist)
        if not apikey:
            sys.stderr.write(
                "Failed getting API token, check deployment.apiarmed attribute on {}\n".format(nodename))
            time.sleep(10)
    # Create the file owner-only from the start so the secret is never
    # briefly readable under a permissive umask; chmod still covers a
    # pre-existing file with looser permissions.
    fd = os.open(apikeypath, os.O_WRONLY | os.O_CREAT | os.O_TRUNC, 0o600)
    with os.fdopen(fd, 'w') as apiout:
        apiout.write(apikey)
    apikey = apikey.strip()
    os.chmod(apikeypath, 0o600)
    return apikey
class HTTPSClient(client.HTTPConnection, object):
    """HTTPS connection to a confluent server, authenticated by API key.

    On construction the node name and candidate servers are read from
    /etc/confluent/confluent.info (plus deploy_server entries of
    /etc/confluent/confluent.deploycfg when no host is given).  An API key
    is obtained with ``get_apikey``, the first reachable server is chosen by
    ``check_connections``, and a TLS connection is opened to it.
    """

    def __init__(self, usejson=False, port=443, host=None):
        """Build the standard headers, pick a server and connect.

        Args:
            usejson: when true, send ``ACCEPT: application/json``.
            port: TCP port of the HTTPS service.
            host: explicit server to use; when omitted, the server comes from
                the MANAGER/EXTMGRINFO lines of confluent.info, with the
                deploy_server(_v6) entries of confluent.deploycfg as fallbacks.
        """
        self.stdheaders = {}
        mgtiface = None
        if usejson:
            self.stdheaders['ACCEPT'] = 'application/json'
        with open('/etc/confluent/confluent.info') as cinfo:
            info = cinfo.read().split('\n')
        if host:
            self.hosts = [host]
            for line in info:
                if line.startswith('NODENAME:'):
                    node = line.split(' ')[1]
                    self.stdheaders['CONFLUENT_NODENAME'] = node
        else:
            self.hosts = []
            havedefault = '0'
            for line in info:
                if line.startswith('NODENAME:'):
                    node = line.split(' ')[1]
                    self.stdheaders['CONFLUENT_NODENAME'] = node
                if line.startswith('MANAGER:') and not host:
                    host = line.split(' ')[1]
                if line.startswith('EXTMGRINFO:'):
                    # Fields: host|mgtiface|flag.  The first entry is taken;
                    # a later entry flagged '1' replaces a '0'-flagged choice.
                    extinfo = line.split(' ')[1].split('|')
                    if not mgtiface:
                        host, mgtiface, havedefault = extinfo[:3]
                    if havedefault == '0' and extinfo[2] == '1':
                        host, mgtiface, havedefault = extinfo[:3]
            # Without MANAGER/EXTMGRINFO lines host is still None; skip it
            # rather than crash on the '%' test and rely on the fallbacks.
            if host:
                if '%' in host:
                    # Save the text after '%' (the scope part of the address).
                    ifidx = host.split('%', 1)[1]
                    with open('/tmp/confluent.ifidx', 'w+') as ifout:
                        ifout.write(ifidx)
                self.hosts.append(host)
            try:
                with open('/etc/confluent/confluent.deploycfg') as dcfg:
                    info = dcfg.read().split('\n')
            except Exception:
                info = None
            if info:
                for line in info:
                    if line.startswith('deploy_server: ') or line.startswith('deploy_server_v6: '):
                        self.hosts.append(line.split(': ', 1)[1])
        self.stdheaders['CONFLUENT_APIKEY'] = get_apikey(node, self.hosts)
        if mgtiface:
            self.stdheaders['CONFLUENT_MGTIFACE'] = mgtiface
        self.port = port
        self.host = None
        self.node = node
        host = self.check_connections()
        client.HTTPConnection.__init__(self, host, port)
        self.connect()

    def set_header(self, key, val):
        """Set (or replace) a header sent with every request."""
        self.stdheaders[key] = val

    def _probe(self, host, timeout):
        """Return True if a TCP connect to ``host``:``self.port`` succeeds.

        The probe socket is always closed, including when connect fails.
        """
        try:
            addrinf = socket.getaddrinfo(host, self.port)[0]
            psock = socket.socket(addrinf[0])
        except OSError:
            return False
        try:
            psock.settimeout(timeout)
            psock.connect(addrinf[4])
            return True
        except OSError:
            return False
        finally:
            psock.close()

    def check_connections(self):
        """Return the first server that accepts a TCP connection on self.port.

        ``self.hosts`` is tried in order with a 0.1 s timeout and then again
        with 5 s.  Failing that, servers found by ``scan_confluents`` are
        tried with 5 s, entries flagged ``isdefault`` first.

        Raises:
            Exception: if no server can be reached.
        """
        timeouts = (0.1, 5)
        for timeo in timeouts:
            for host in self.hosts:
                if self._probe(host, timeo):
                    return host
        srvlist, srvs = scan_confluents()
        hosts = []
        for srv in srvlist:
            if srvs[srv].get('isdefault', False):
                hosts.insert(0, srv)
            else:
                hosts.append(srv)
        for host in hosts:
            if self._probe(host, timeouts[-1]):
                return host
        raise Exception('Unable to reach any hosts')

    def connect(self):
        """Open a TLS connection to self.host verified against /etc/confluent/ca.pem.

        Also records the matching ``Host`` header (bracketed for IPv6, scope
        suffix removed) in ``stdheaders``.  The TCP socket is closed if the
        TLS setup fails.
        """
        addrinf = socket.getaddrinfo(self.host, self.port)[0]
        psock = socket.socket(addrinf[0])
        try:
            psock.connect(addrinf[4])
            # PROTOCOL_SSLv23 is a deprecated alias; prefer the client-side
            # protocol where this Python provides it.
            protocol = getattr(ssl, 'PROTOCOL_TLS_CLIENT', None)
            if protocol is None:
                protocol = ssl.PROTOCOL_SSLv23
            ctx = ssl.SSLContext(protocol)
            ctx.load_verify_locations('/etc/confluent/ca.pem')
            host = self.host.split('%', 1)[0]
            if '[' not in host and ':' in host:
                self.stdheaders['Host'] = '[{0}]'.format(host)
            else:
                self.stdheaders['Host'] = '{0}'.format(host)
            # verify_mode must be set before enabling check_hostname.
            ctx.verify_mode = ssl.CERT_REQUIRED
            ctx.check_hostname = True
            self.sock = ctx.wrap_socket(psock, server_hostname=host)
        except Exception:
            psock.close()
            raise

    def grab_url(self, url, data=None, returnrsp=False):
        """Return only the body (or response object) of grab_url_with_status."""
        return self.grab_url_with_status(url, data, returnrsp)[1]

    def grab_url_with_status(self, url, data=None, returnrsp=False):
        """Request ``url`` (POST when ``data`` is given, else GET).

        Returns:
            ``(status, body_bytes)`` for a 2xx response, or
            ``(status, response)`` when ``returnrsp`` is true.

        On 401 the stored key file is emptied, a new key is obtained with
        ``get_apikey(self.node, [self.host])`` and the request is retried.

        Raises:
            Exception: carrying the response body for any other status.
        """
        if data:
            method = 'POST'
        else:
            method = 'GET'
        authed = False
        while not authed:
            authed = True
            self.request(method, url, data, headers=self.stdheaders)
            rsp = self.getresponse()
            if 200 <= rsp.status < 300:
                if returnrsp:
                    return rsp.status, rsp
                else:
                    return rsp.status, rsp.read()
            if rsp.status == 401:
                authed = False
                rsp.read()  # drain so the connection can be reused
                with open('/etc/confluent/confluent.apikey', 'w+') as akfile:
                    akfile.write('')
                self.stdheaders['CONFLUENT_APIKEY'] = get_apikey(
                    self.node, [self.host])
        raise Exception(rsp.read())
if __name__ == '__main__':
    # Usage: apiclient URL [-j] [-o OUTFILE] [-w STATUS] [-d DATA | DATAFILE]
    data = None
    usejson = False
    # '-j' is a bare flag: remove it so it can never be taken as the URL
    # (sys.argv[1]) or hide a trailing data-file argument.
    while '-j' in sys.argv:
        sys.argv.remove('-j')
        usejson = True
    if len(sys.argv) == 1:
        HTTPSClient()
        sys.exit(0)
    try:
        outbin = sys.argv.index('-o')
        sys.argv.pop(outbin)
        outbin = sys.argv.pop(outbin)
    except ValueError:
        outbin = None
    try:
        waitfor = sys.argv.index('-w')
        sys.argv.pop(waitfor)
        waitfor = int(sys.argv.pop(waitfor))
    except ValueError:
        waitfor = None
    try:
        data = sys.argv.index('-d')
        sys.argv.pop(data)
        data = sys.argv.pop(data)
    except ValueError:
        data = None
    if outbin:
        # Stream the response body into OUTFILE (opened in append mode).
        with open(outbin, 'ab+') as outf:
            reader = HTTPSClient(usejson=usejson).grab_url(
                sys.argv[1], data, returnrsp=True)
            chunk = reader.read(16384)
            while chunk:
                outf.write(chunk)
                chunk = reader.read(16384)
        sys.exit(0)
    if len(sys.argv) > 2 and os.path.exists(sys.argv[-1]):
        with open(sys.argv[-1]) as datain:
            data = datain.read()
    if waitfor:
        # Named apiclient so the module alias 'client' (http.client), which
        # HTTPSClient.__init__ dereferences, is not shadowed.
        apiclient = HTTPSClient(usejson)
        # Start from None so the request runs at least once even when the
        # awaited status is 201; otherwise rsp would be unbound below.
        status = None
        while status != waitfor:
            status, rsp = apiclient.grab_url_with_status(sys.argv[1], data)
        sys.stdout.write(rsp.decode())
    else:
        sys.stdout.write(HTTPSClient(usejson).grab_url(sys.argv[1], data).decode())