#!/usr/bin/python
try:
    import http.client as client
except ImportError:
    import httplib as client
import os
import socket
import subprocess
import ssl
import sys

def get_apikey(nodename, mgr):
    """Return this node's confluent API key, obtaining and caching it if needed.

    If /etc/confluent/confluent.apikey already exists, its stripped contents
    are returned immediately.  Otherwise /opt/confluent/bin/clortho is run
    with ``nodename`` and ``mgr`` and its output is used as the key.  Output
    prefixed with ``SEALED:`` is saved to /etc/confluent/confluent.sealedapikey
    and passed through /usr/bin/clevis-decrypt-tpm2 to recover the key.

    The key is cached in /etc/confluent/confluent.apikey (mode 0600).  When
    the key did not arrive sealed and /usr/bin/clevis-encrypt-tpm2 exists, a
    best-effort attempt is made to seal it and post the sealed form to
    /confluent-api/self/saveapikey.

    :param nodename: node name handed to clortho
    :param mgr: manager address handed to clortho
    :returns: the API key as a stripped string
    :raises subprocess.CalledProcessError: if clortho exits non-zero
    """
    keyfile = '/etc/confluent/confluent.apikey'
    sealedfile = '/etc/confluent/confluent.sealedapikey'
    if os.path.exists(keyfile):
        with open(keyfile) as keyin:
            return keyin.read().strip()
    sealnew = True
    apikey = subprocess.check_output(['/opt/confluent/bin/clortho', nodename, mgr])
    if not isinstance(apikey, str):
        apikey = apikey.decode('utf8')
    if apikey.startswith('SEALED:'):
        # The key arrived already sealed, so there is nothing new to seal.
        sealnew = False
        with open(sealedfile, 'w+') as apiout:
            apiout.write(apikey[7:])
        with open(sealedfile) as inp:
            sp = subprocess.Popen(['/usr/bin/clevis-decrypt-tpm2'],
                stdin=inp, stdout=subprocess.PIPE)
            # NOTE(review): the decrypt exit status is not checked; a failed
            # decrypt would cache whatever (possibly empty) output was produced.
            apikey = sp.communicate()[0]
            if not isinstance(apikey, str):
                apikey = apikey.decode('utf8')
    # Create the key file owner-only from the start so the secret is never
    # readable by other users, not even briefly.
    keyfd = os.open(keyfile, os.O_WRONLY | os.O_CREAT | os.O_TRUNC, 0o600)
    with os.fdopen(keyfd, 'w') as apiout:
        apiout.write(apikey)
    # The mode given to os.open only applies on creation and is subject to
    # the umask, so enforce the final mode explicitly.
    os.chmod(keyfile, 0o600)
    if sealnew and os.path.exists('/usr/bin/clevis-encrypt-tpm2'):
        try:
            with open(keyfile) as apin:
                sealed = subprocess.check_output(
                    ['/usr/bin/clevis-encrypt-tpm2', '{}'], stdin=apin)
            # HTTPSClient() calls get_apikey() again; the key file exists by
            # now, so that call returns through the early path above.
            print(HTTPSClient().grab_url('/confluent-api/self/saveapikey', sealed).decode())
        except Exception:
            # Sealing is optional: the plain key file remains usable.
            sys.stderr.write('Unable to persist API key through TPM2 sealing\n')
    return apikey.strip()

class HTTPSClient(client.HTTPConnection, object):
    """TLS-verified HTTP client for talking to the confluent deployment server.

    The target host, node name and management interface are read from
    /etc/confluent/confluent.info; a ``deploy_server`` entry in
    /etc/confluent/confluent.deploycfg, when present, overrides the host.
    Every request carries the headers collected in ``self.stdheaders``
    (node name, API key, optional management interface and Accept).
    """

    def __init__(self, json=False, port=443):
        """Read the node configuration, fetch the API key and connect.

        :param json: when true, send ``ACCEPT: application/json``
        :param port: TCP port to connect to
        :raises Exception: if no node name or no server host can be determined
        """
        self.stdheaders = {}
        with open('/etc/confluent/confluent.info') as infofile:
            info = infofile.read().split('\n')
        host = None
        node = None
        mgtiface = None
        havedefault = '0'
        for line in info:
            if line.startswith('NODENAME:'):
                node = line.split(' ')[1]
                self.stdheaders['CONFLUENT_NODENAME'] = node
            if line.startswith('MANAGER:') and not host:
                host = line.split(' ')[1]
            if line.startswith('EXTMGRINFO:'):
                # Format: host|interface|havedefault
                extinfo = line.split(' ')[1]
                extinfo = extinfo.split('|')
                if not mgtiface:
                    host, mgtiface, havedefault = extinfo[:3]
                # Replace an entry whose third field is '0' with a later one
                # whose third field is '1'.
                if havedefault == '0' and extinfo[2] == '1':
                    host, mgtiface, havedefault = extinfo[:3]
        if node is None:
            raise Exception(
                'No NODENAME entry found in /etc/confluent/confluent.info')
        if host and '%' in host:
            # Record the scope id of a scoped (host%scope) address.
            ifidx = host.split('%', 1)[1]
            with open('/tmp/confluent.ifidx', 'w+') as ifout:
                ifout.write(ifidx)
        if json:
            self.stdheaders['ACCEPT'] = 'application/json'
        try:
            with open('/etc/confluent/confluent.deploycfg') as cfgfile:
                info = cfgfile.read().split('\n')
        except Exception:
            # The deploycfg file is optional.
            info = None
        if info:
            for line in info:
                if line.startswith('deploy_server: '):
                    host = line.split(': ', 1)[1]
                    break
        if not host:
            raise Exception(
                'No deployment server found in /etc/confluent/confluent.info '
                'or /etc/confluent/confluent.deploycfg')
        self.stdheaders['CONFLUENT_APIKEY'] = get_apikey(node, host)
        if mgtiface:
            self.stdheaders['CONFLUENT_MGTIFACE'] = mgtiface
        client.HTTPConnection.__init__(self, host, port)
        self.connect()

    def set_header(self, key, val):
        """Set a header that is sent with every subsequent request."""
        self.stdheaders[key] = val

    def connect(self):
        """Open a TCP connection and wrap it in TLS.

        The server certificate is verified against /etc/confluent/ca.pem and
        the hostname is checked.  Any ``%scope`` suffix is stripped from the
        host before it is used for the Host header and the TLS server name.
        """
        addrinf = socket.getaddrinfo(self.host, self.port)[0]
        psock = socket.socket(addrinf[0])
        try:
            psock.connect(addrinf[4])
            ctx = ssl.SSLContext(ssl.PROTOCOL_SSLv23)
            ctx.load_verify_locations('/etc/confluent/ca.pem')
            host = self.host.split('%', 1)[0]
            if '[' not in host and ':' in host:
                # A bare host containing ':' is bracketed in the Host header.
                self.stdheaders['Host'] = '[{0}]'.format(host)
            else:
                self.stdheaders['Host'] = '{0}'.format(host)
            ctx.verify_mode = ssl.CERT_REQUIRED
            ctx.check_hostname = True
            self.sock = ctx.wrap_socket(psock, server_hostname=host)
        except Exception:
            # Do not leak the raw socket when connect or handshake fails.
            psock.close()
            raise

    def grab_url(self, url, data=None, returnrsp=False):
        """Request ``url`` and return only the body (or response object).

        Same arguments and errors as :meth:`grab_url_with_status`.
        """
        return self.grab_url_with_status(url, data, returnrsp)[1]

    def grab_url_with_status(self, url, data=None, returnrsp=False):
        """Request ``url`` and return ``(status, payload)``.

        :param url: request path
        :param data: request body; a truthy value makes the request a POST,
            otherwise a GET is sent
        :param returnrsp: when true, the payload is the unread response
            object rather than the body bytes
        :returns: ``(status, body)`` or ``(status, response)`` for 2xx replies
        :raises Exception: carrying the response body for any non-2xx status
        """
        if data:
            method = 'POST'
        else:
            method = 'GET'
        self.request(method, url, data, headers=self.stdheaders)
        rsp = self.getresponse()
        if rsp.status >= 200 and rsp.status < 300:
            if returnrsp:
                return rsp.status, rsp
            else:
                return rsp.status, rsp.read()
        raise Exception(rsp.read())

if __name__ == '__main__':
    # Usage: [-j] [-o OUTFILE] [-w STATUS] [-d DATA] URL [DATAFILE]
    #   -j          ask for a JSON response
    #   -o OUTFILE  stream the response body to OUTFILE in binary
    #   -w STATUS   repeat the request until it returns HTTP status STATUS
    #   -d DATA     send DATA as the request body
    #   DATAFILE    if the last argument is an existing path, its contents
    #               are sent as the request body
    # With no arguments at all, only the client is constructed.
    data = None
    json = False
    if '-j' in sys.argv:
        json = True
    if len(sys.argv) == 1:
        HTTPSClient()
        sys.exit(0)
    try:
        outbin = sys.argv.index('-o')
        sys.argv.pop(outbin)
        outbin = sys.argv.pop(outbin)
    except ValueError:
        outbin = None
    try:
        waitfor = sys.argv.index('-w')
        sys.argv.pop(waitfor)
        waitfor = int(sys.argv.pop(waitfor))
    except ValueError:
        waitfor = None
    try:
        data = sys.argv.index('-d')
        sys.argv.pop(data)
        data = sys.argv.pop(data)
    except ValueError:
        data = None
    # Drop the -j flag only after the value-taking options were consumed, so
    # it is neither mistaken for the URL (sys.argv[1]) nor for the trailing
    # data-file argument (sys.argv[-1]).
    while '-j' in sys.argv:
        sys.argv.remove('-j')
    if len(sys.argv) < 2:
        sys.stderr.write('No URL specified\n')
        sys.exit(1)
    if outbin:
        with open(outbin, 'wb') as outf:
            reader = HTTPSClient(json=json).grab_url(
                sys.argv[1], data, returnrsp=True)
            chunk = reader.read(16384)
            while chunk:
                outf.write(chunk)
                chunk = reader.read(16384)
        sys.exit(0)
    if len(sys.argv) > 2 and os.path.exists(sys.argv[-1]):
        with open(sys.argv[-1]) as datafile:
            data = datafile.read()
    if waitfor:
        # Named 'conn' rather than 'client' so the module-level 'client'
        # alias used by HTTPSClient.__init__ is not rebound.
        conn = HTTPSClient(json=json)
        status = 201
        while status != waitfor:
            status, rsp = conn.grab_url_with_status(sys.argv[1], data)
            sys.stdout.write(rsp.decode())
    else:
        sys.stdout.write(HTTPSClient(json=json).grab_url(sys.argv[1], data).decode())
