#!/usr/bin/python2
import argparse
import os
import re
import signal
import sys
# Put SIGPIPE back to its default disposition. The AttributeError guard
# covers platforms whose signal module does not define SIGPIPE.
# NOTE(review): presumably this is so that piping output into a command that
# exits early ends the script quietly instead of with a traceback - confirm.
try:
    signal.signal(signal.SIGPIPE, signal.SIG_DFL)
except AttributeError:
    pass

# Locate ../lib/python relative to the real (symlink-resolved) location of
# this script, and add it to the module search path only when that directory
# lives under a path starting with '/opt'.
path = os.path.dirname(os.path.realpath(__file__))
path = os.path.realpath(os.path.join(path, '..', 'lib', 'python'))
if path.startswith('/opt'):
    sys.path.append(path)
import confluent.client as client
import confluent.sortutil as sortutil

def partitionhostsline(line):
    """Split one hosts-file line into its address, names and comment.

    :param line: A single line of a hosts-format file, without the newline.
    :returns: A tuple ``(ipaddr, names, comment)``.  ``ipaddr`` is the first
              whitespace-delimited field ('' when the line has no fields),
              ``names`` is the list of remaining fields (possibly empty) and
              ``comment`` is everything from the first '#' onward ('' when
              the line has no '#').
    """
    comment = ''
    if '#' in line:
        cmdidx = line.index('#')
        comment = line[cmdidx:]
        line = line[:cmdidx]
    # Tokenise once rather than unpacking split(None, 1): a line holding only
    # an address has a single field and must not raise ValueError.
    fields = line.split()
    if not fields:
        return '', [], comment
    return fields[0], fields[1:], comment

class HostMerger(object):
    """Merge new host entries into the contents of a hosts-format file.

    New entries are collected in ``sourcelines`` (via ``read_source`` or
    ``add_entry``).  ``read_target`` then walks an existing file: a line whose
    address or one of whose names matches a new entry is replaced by that
    entry, every other line is kept.  ``write_out`` writes the merged lines
    followed by any new entries that replaced nothing.
    """

    def __init__(self):
        # address -> index into sourcelines of the entry for that address
        self.byip = {}
        # name -> index into sourcelines, for entries whose address has no ':'
        self.byname = {}
        # name -> index into sourcelines, for entries whose address has a ':'
        self.byname6 = {}
        # New entries; a slot is set to None once read_target has used it.
        self.sourcelines = []
        # Output lines accumulated by read_target, in target-file order.
        self.targlines = []

    def _nameindex(self, ip):
        """Return the name index to use for ip (byname6 if it contains ':')."""
        return self.byname6 if ':' in ip else self.byname

    def _consume(self, idx):
        """Move source entry idx to the output unless it was already used."""
        line = self.sourcelines[idx]
        if line is not None:
            self.targlines.append(line)
            self.sourcelines[idx] = None

    def read_source(self, sourcefile):
        """Load new entries from sourcefile and index them by address/name.

        Trailing empty lines are dropped.  A file that is empty or holds only
        empty lines results in no entries.
        """
        with open(sourcefile, 'r') as hfile:
            self.sourcelines = hfile.read().split('\n')
        # Guard on emptiness: without it a blank-only file raises IndexError.
        while self.sourcelines and not self.sourcelines[-1]:
            self.sourcelines.pop()
        for idx, line in enumerate(self.sourcelines):
            currip, names, _ = partitionhostsline(line)
            if currip:
                self.byip[currip] = idx
            byname = self._nameindex(currip)
            for name in names:
                byname[name] = idx

    def add_entry(self, ip, names):
        """Append a new entry.

        :param ip: Address string for the entry.
        :param names: Whitespace-separated names for that address.
        """
        idx = len(self.sourcelines)
        self.sourcelines.append('{:<39} {}'.format(ip, names))
        byname = self._nameindex(ip)
        for name in names.split():
            byname[name] = idx
        self.byip[ip] = idx

    def read_target(self, targetfile):
        """Merge the lines of targetfile with the new entries.

        Does nothing when targetfile does not exist.  A line matching a new
        entry by address, or else by name, is replaced by that entry the
        first time the entry is matched and dropped on later matches; only
        the first name that matches is considered.  Lines matching nothing
        are kept verbatim.
        """
        if not os.path.exists(targetfile):
            return
        with open(targetfile, 'r') as hfile:
            lines = hfile.read().split('\n')
        while lines and not lines[-1]:
            lines.pop()
        for line in lines:
            currip, names, _ = partitionhostsline(line)
            if currip in self.byip:
                self._consume(self.byip[currip])
                continue
            byname = self._nameindex(currip)
            for name in names:
                if name in byname:
                    self._consume(byname[name])
                    break
            else:
                # No name matched a new entry: keep the existing line as is.
                self.targlines.append(line)

    def write_out(self, targetfile):
        """Write merged lines, then the unused new entries, to targetfile.

        Trailing empty (or already used) slots are trimmed from both lists
        before writing.
        """
        while self.targlines and not self.targlines[-1]:
            self.targlines.pop()
        while self.sourcelines and not self.sourcelines[-1]:
            self.sourcelines.pop()
        with open(targetfile, 'w') as hosts:
            for line in self.targlines:
                hosts.write(line + '\n')
            for line in self.sourcelines:
                if line is not None:
                    hosts.write(line + '\n')

def _split_hostnames(value):
    """Split a hostname attribute on whitespace/commas, dropping empty tokens."""
    return [name for name in re.split(r'\s+|,', value) if name]


def _names_with_domain(names, domain):
    """Return names, plus 'name.domain' forms unless a name already has domain.

    The list is returned unchanged (as a copy) when domain is empty or when
    any of the names already contains the domain string.
    """
    names = list(names)
    if domain and not any(domain in name for name in names):
        names.extend(['{0}.{1}'.format(name, domain) for name in names])
    return names


def _collect_net_attributes(c, noderange):
    """Read dns.domain and net.* attributes for noderange.

    :returns: ``(ip4bynode, ip6bynode, namesbynode, domainsbynode)``.  The
              first three map node -> {network: value}, where network is the
              middle part of the attribute name (None for 'net.<attr>');
              domainsbynode maps node -> dns.domain value.
    :raises SystemExit: if the server reports an error for the request.
    """
    ip4bynode = {}
    ip6bynode = {}
    namesbynode = {}
    domainsbynode = {}
    url = '/noderange/{0}/attributes/current'.format(noderange)
    for ent in c.read(url):
        if 'error' in ent:
            # Same reporting as expand_expression; abort before any file is
            # touched.  Assumes read() error entries are shaped like
            # create() ones - TODO confirm.
            sys.stderr.write(ent['error'] + '\n')
            sys.exit(ent.get('errorcode', 1))
        databynode = ent.get('databynode', {})
        for node in databynode:
            for attrib in databynode[node]:
                val = databynode[node][attrib]
                if not isinstance(val, dict):
                    continue
                val = val.get('value', None)
                if attrib == 'dns.domain':
                    domainsbynode[node] = val
                if not attrib.startswith('net.'):
                    continue
                nameparts = attrib.split('.')
                if len(nameparts) == 2:
                    currnet = None
                else:
                    currnet = '.'.join(nameparts[1:-1])
                # Any net.* attribute registers the node in all three maps,
                # so later lookups by node never raise KeyError.
                for currdict in (ip4bynode, ip6bynode, namesbynode):
                    currdict.setdefault(node, {})
                if not val:
                    # Unset value: nothing to record (an unset hostname then
                    # falls back to the node name).
                    continue
                if attrib.endswith('.ipv4_address'):
                    ip4bynode[node][currnet] = val.split('/', 1)[0]
                elif attrib.endswith('.ipv6_address'):
                    ip6bynode[node][currnet] = val.split('/', 1)[0]
                elif attrib.endswith('.hostname'):
                    namesbynode[node][currnet] = _split_hostnames(val)
    return ip4bynode, ip6bynode, namesbynode, domainsbynode


def main():
    """Parse the command line and create/amend /etc/hosts for a noderange.

    Addresses come either from the attribute database (-a) or from an
    expression (-i).  An existing /etc/hosts is merged with the generated
    entries and renamed to /etc/hosts.confluentbkup before the new file is
    written.  Exits non-zero on invalid option combinations or when the
    server reports an error.
    """
    ap = argparse.ArgumentParser(description="Create/amend /etc/hosts file for given noderange")
    ap.add_argument('noderange', help='Noderange to generate/update /etc/hosts for')
    ap.add_argument('-a', '--attrib', help='Pull ip addresses and hostnames from attribute database', action='store_true')
    ap.add_argument('-i', '--ip', help='Expression to generate addresses (e.g. 172.16.1.{n1} or fd2b:246f:8a50::{n1:x})')
    ap.add_argument('-n', '--name', help='Expression for name to add ({node}-compute, etc). If unspecified, "{node} {node}.{dns.domain}" will be used', action='append')
    args = ap.parse_args()
    c = client.Command()
    if args.name:
        names = ' '.join(args.name)
    else:
        names = '{node} {node}.{dns.domain}'
    merger = HostMerger()
    if not args.ip and not args.attrib:
        sys.stderr.write('-a or -i is currently required\n')
        sys.exit(1)
    if args.attrib and args.name:
        # Checked up front rather than per address inside the loop below.
        sys.stderr.write('Custom name and attribute addresses not currently supported together\n')
        sys.exit(1)
    if args.attrib:
        ip4bynode, ip6bynode, namesbynode, domainsbynode = \
            _collect_net_attributes(c, args.noderange)
        # ip4bynode holds every node that has any net.* attribute, so this
        # also covers nodes that only have IPv6 addresses.
        for node in ip4bynode:
            mydomain = domainsbynode.get(node, None)
            for ipdb in (ip4bynode, ip6bynode):
                for currnet in ipdb[node]:
                    nodenames = _names_with_domain(
                        namesbynode.get(node, {}).get(currnet, [node]),
                        mydomain)
                    merger.add_entry(ipdb[node][currnet], ' '.join(nodenames))
        # NOTE(review): this extra file looks like leftover debug output;
        # kept to preserve existing behavior - confirm whether it can go.
        merger.write_out('/etc/whatnowhosts')
    else:
        namesbynode = {}
        ipbynode = {}
        expurl = '/noderange/{0}/attributes/expression'.format(args.noderange)
        exitcode = 0
        exitcode |= expand_expression(c, namesbynode, expurl, names)
        exitcode |= expand_expression(c, ipbynode, expurl, args.ip)
        if exitcode:
            sys.exit(exitcode)
        for node in ipbynode:
            merger.add_entry(ipbynode[node], namesbynode[node])
    if os.path.exists('/etc/hosts'):
        merger.read_target('/etc/hosts')
        os.rename('/etc/hosts', '/etc/hosts.confluentbkup')
    merger.write_out('/etc/hosts')



def expand_expression(c, namesbynode, expurl, expression):
    """Expand expression per node via the server and record the results.

    :param c: Client object providing ``create(url, body)``.
    :param namesbynode: Dict updated in place with node -> expanded value.
    :param expurl: Resource to post the expression to.
    :param expression: Expression text to expand.
    :returns: 0 on success, otherwise the bitwise OR of the reported error
              codes (1 for an error without an 'errorcode').
    """
    status = 0
    request = {'expression': expression}
    for reply in c.create(expurl, request):
        if 'error' in reply:
            # Report the problem but keep going through the remaining replies.
            sys.stderr.write(reply['error'] + '\n')
            status |= reply.get('errorcode', 1)
        pernode = reply.get('databynode', ())
        namesbynode.update((node, pernode[node]['value']) for node in pernode)
    return status


if __name__ == '__main__':
    main()