#!/usr/bin/python
import logging
logging.getLogger('libarchive').addHandler(logging.NullHandler())
import libarchive
import hashlib
import os
import shutil
import sys

# Import methods: how a recognized image is placed into its target directory.
COPY = 0
EXTRACT = 1

# Archive members whose contents scan_iso() keeps in addition to their size.
READFILES = {
    '.discinfo',
    'media.1/products',
    'media.2/products',
}

# SHA-256 digests of the first 32768 bytes of known non-ISO images; used by
# fingerprint() to decide whether hashing the whole file is worthwhile.
HEADERSUMS = {
    b'\x85\xeddW\x86\xc5\xbdhx\xbe\x81\x18X\x1e\xb4O\x14\x9d\x11\xb7C8\x9b\x97R\x0c-\xb8Ht\xcb\xb3',
}

# Whole-file SHA-256 hex digest -> identity of a known non-ISO image.
HASHPRINTS = {
    '69d5f1c5e4474d70b0fb5374bfcb29bf57ba828ff00a55237cd757e61ed71048': {
        'name': 'cumulus-broadcom-amd64-4.0.0',
        'method': COPY,
    },
}

from ctypes import byref, c_longlong, c_size_t, c_void_p

from libarchive.ffi import (
    write_disk_new, write_disk_set_options, write_free, write_header,
    read_data_block, write_data_block, write_finish_entry, ARCHIVE_EOF
)

def extract_entries(entries, flags=0, callback=None, totalsize=None):
    """Extract the given archive entries into the current directory.

    Entries whose name ends with 'TRANS.TBL' are skipped.

    entries: iterable of libarchive entries (e.g. an open archive reader).
    flags: options handed to libarchive's write-disk object.
    callback: optional callable; after every data block read it receives
        ``{'progress': fraction}`` where fraction is bytes read so far
        divided by ``totalsize``.
    totalsize: expected total number of bytes.  When it is None or 0 no
        fraction can be computed, so the callback is not invoked at all
        (previously this raised TypeError / ZeroDivisionError).
    """
    buff, size, offset = c_void_p(), c_size_t(), c_longlong()
    buff_p, size_p, offset_p = byref(buff), byref(size), byref(offset)
    # Only report progress when there is a usable denominator.
    report = callback if (callback and totalsize) else None
    sizedone = 0
    with libarchive.extract.new_archive_write_disk(flags) as write_p:
        for entry in entries:
            if str(entry).endswith('TRANS.TBL'):
                continue
            write_header(write_p, entry._entry_p)
            read_p = entry._archive_p
            while True:
                r = read_data_block(read_p, buff_p, size_p, offset_p)
                sizedone += size.value
                if report:
                    report({'progress': float(sizedone) / float(totalsize)})
                if r == ARCHIVE_EOF:
                    break
                write_data_block(write_p, buff, size, offset)
            write_finish_entry(write_p)

def extract_file(filepath, flags=0, callback=lambda x: None, imginfo=()):
    """Extract the archive at *filepath* into the current directory.

    imginfo maps member names to their sizes; the truthy sizes are summed
    and handed to extract_entries as the total used for progress reports
    delivered through *callback*.  *flags* is passed through unchanged.
    """
    totalsize = sum(imginfo[name] for name in imginfo if imginfo[name])
    with libarchive.file_reader(filepath) as archive:
        extract_entries(archive, flags, callback, totalsize)

def check_centos(isoinfo):
    """Recognize a CentOS 7 or 8 ISO from its file listing.

    isoinfo[0] is iterated for a name containing 'centos-release-7' or
    'centos-release-8'; version and architecture are parsed from the first
    such name.  Returns an identity dict with 'name' and 'method' keys, or
    None when no such name is present.
    """
    for path in isoinfo[0]:
        if 'centos-release-7' in path:
            # Version is what follows 'release-' in the text before the
            # first dot, with dashes turned into dots.
            pieces = path.split('.')
            arch = pieces[-2]
            ver = pieces[0].rpartition('release-')[2].replace('-', '.')
        elif 'centos-release-8' in path:
            # Version is the third dash-separated field of the name.
            ver = path.split('-')[2]
            arch = path.split('.')[-2]
        else:
            continue
        return {'name': 'centos-{0}-{1}'.format(ver, arch), 'method': EXTRACT}
    return None


def check_sles(isoinfo):
    """Recognize a SLE 15 / SLES 12 ISO from its media.N/products file.

    isoinfo[1] maps file names to contents.  The first line of
    'media.1/products' (or, failing that, 'media.2/products') is split on
    spaces; its last field supplies the version and decides which release
    family applies.  Returns an identity dict with 'name', 'method' and
    'subname' (the disk number as a string), or None when the media is not
    recognized.  The architecture is always reported as x86_64.
    """
    contents = isoinfo[1]
    for medianame in ('media.1/products', 'media.2/products'):
        if medianame in contents:
            break
    else:
        return None

    text = contents[medianame]
    if not isinstance(text, str):
        text = text.decode('utf8')
    fields = text.split('\n')[0].split(' ')
    location = fields[0]
    release = fields[-1]
    ver = release.split('-')[0]

    distro = ''
    disk = None
    if release.startswith('15'):
        distro = 'sle'
        # For 15, the disk number is derived from the first field.
        if location == '/':
            disk = '1'
        elif location.startswith('/Module'):
            disk = '2'
    elif release.startswith('12'):
        if 'SLES' in fields[1]:
            distro = 'sles'
        # For 12, the disk number is derived from the media directory name.
        if '.1' in medianame:
            disk = '1'
        elif '.2' in medianame:
            disk = '2'

    if not (disk and distro):
        return None
    return {'name': '{0}-{1}-{2}'.format(distro, ver, 'x86_64'),
            'method': EXTRACT, 'subname': disk}


def check_rhel(isoinfo):
    """Recognize a RHEL 7 or 8 ISO from its file listing.

    isoinfo[0] is iterated for a name containing 'redhat-release-7' or
    'redhat-release-8'; version and architecture are parsed from the first
    such name.  Returns an identity dict with 'name' and 'method' keys, or
    None when no such name is present.
    """
    for path in isoinfo[0]:
        is_rhel7 = 'redhat-release-7' in path
        if not is_rhel7 and 'redhat-release-8' not in path:
            continue
        dotted = path.split('.')
        # The architecture is the next-to-last dot-separated field.
        arch = dotted[-2]
        if is_rhel7:
            # Text after 'release-' before the first dot, dashes -> dots.
            ver = dotted[0].rpartition('release-')[2].replace('-', '.')
        else:
            # Third dash-separated field of the name.
            ver = path.split('-')[2]
        return {'name': 'rhel-{0}-{1}'.format(ver, arch), 'method': EXTRACT}
    return None

def scan_iso(filename):
    """Collect member sizes and selected file contents from an archive.

    Returns a (filesizes, filecontents) tuple: filesizes maps every member
    name (except those ending in 'TRANS.TBL') to its size, filecontents
    maps the members listed in READFILES to their raw bytes.
    """
    sizes = {}
    contents = {}
    with libarchive.file_reader(filename) as reader:
        for member in reader:
            name = str(member)
            if name.endswith('TRANS.TBL'):
                continue
            sizes[name] = member.size
            if name not in READFILES:
                continue
            contents[name] = b''.join(
                bytes(block) for block in member.get_blocks())
    return sizes, contents

def fingerprint(filename):
    """Identify the OS image stored in *filename*.

    A file carrying the ISO 9660 'CD001' signature at offset 32769 is
    scanned with scan_iso() and offered to every module-level ``check_*``
    function until one recognizes it.  Any other file is matched by
    SHA-256: the first 32768 bytes must hash to a member of HEADERSUMS
    before the whole file is hashed and looked up in HASHPRINTS.

    Returns ``(identity, filesizes)`` for a recognized ISO,
    ``(identity, None)`` for a recognized non-ISO image, and None when the
    image is not recognized.
    """
    # Open the path we were given (not sys.argv[1]) so the function also
    # works when called with anything other than the first CLI argument.
    with open(filename, 'rb') as archive:
        header = archive.read(32768)
        archive.seek(32769)
        if archive.read(6) == b'CD001\x01':
            # ISO image
            isoinfo = scan_iso(filename)
            # Snapshot the globals so the lookup is unaffected by whatever
            # a checker might do to the module namespace.
            for funname, fun in list(globals().items()):
                if funname.startswith('check_'):
                    identity = fun(isoinfo)
                    if identity:
                        return identity, isoinfo[0]
            return None
        digest = hashlib.sha256(header)
        if digest.digest() not in HEADERSUMS:
            # Unknown header: skip hashing the remainder of the file.
            return None
        archive.seek(32768)
        chunk = archive.read(32768)
        while chunk:
            digest.update(chunk)
            chunk = archive.read(32768)
        imginfo = HASHPRINTS.get(digest.hexdigest(), None)
        if imginfo:
            return imginfo, None
    return None


def printit(info):
    """Write info['progress'] (a fraction) to stdout as a percentage.

    The text contains a carriage return and no newline, and stdout is
    flushed after each write.
    """
    percent = 100 * info['progress']
    stream = sys.stdout
    stream.write('     \r{:.2f}%'.format(percent))
    stream.flush()

def import_image(filename):
    """Identify *filename* and import it into the distributions directory.

    The target is /var/lib/confluent/distributions/<name>[/<subname>],
    created if necessary.  Depending on the identified method the image is
    either extracted there (EXTRACT) or copied there as a file (COPY).
    Progress is printed to stdout.  Note that the process working directory
    is changed to the target directory.

    Returns -1 when the image cannot be identified, otherwise None.
    Raises OSError if the target directory cannot be created or entered.
    """
    import errno  # local import: only needed for the EEXIST check below

    identity = fingerprint(filename)
    if not identity:
        return -1
    identity, imginfo = identity
    targpath = identity['name']
    if identity.get('subname', None):
        targpath += '/' + identity['subname']
    targpath = '/var/lib/confluent/distributions/' + targpath
    try:
        os.makedirs(targpath)
    except OSError as e:
        # An already existing target is fine; re-raise everything else.
        if e.errno != errno.EEXIST:
            raise
    # Resolve before chdir so a relative path keeps pointing at the image.
    filename = os.path.abspath(filename)
    os.chdir(targpath)
    print('Importing OS to ' + targpath + ':')
    printit({'progress': 0.0})
    if identity['method'] == EXTRACT:
        extract_file(filename, callback=printit, imginfo=imginfo)
    elif identity['method'] == COPY:
        targpath = os.path.join(targpath, os.path.basename(filename))
        shutil.copyfile(filename, targpath)
    printit({'progress': 1.0})
    sys.stdout.write('\n')


if __name__ == '__main__':
    # The exit status is import_image's return value (-1 when the image
    # given as the first command-line argument is not recognized).
    status = import_image(sys.argv[1])
    sys.exit(status)