#!/usr/bin/python2
# vim: tabstop=4 shiftwidth=4 softtabstop=4

# Copyright 2016-2017 Lenovo
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.

from collections import deque
import optparse
import os
import select
import shlex
import signal
import subprocess
import sys

# Put SIGPIPE back to its default disposition where the platform defines
# it, so a closed output pipe ends the program instead of raising in Python.
if hasattr(signal, 'SIGPIPE'):
    signal.signal(signal.SIGPIPE, signal.SIG_DFL)

# Make the bundled python library directory (../lib/python relative to this
# script) importable, but only for installs that live under /opt.
path = os.path.realpath(
    os.path.join(
        os.path.dirname(os.path.realpath(__file__)), '..', 'lib', 'python'))
if path.startswith('/opt'):
    sys.path.append(path)

import confluent.client as client
import confluent.screensqueeze as sq
import confluent.sortutil as sortutil


def run():
    """Rsync files to every node of a noderange with bounded concurrency.

    Command line: ``<file/directorylist> <noderange>:<destination>``.

    An rsync command template containing ``{node}`` is submitted to the
    confluent ``attributes/expression`` endpoint for the noderange, and the
    per-node ``value`` returned is split with ``shlex`` and executed.  At
    most ``--count`` processes are started up front; the rest are queued and
    one is started each time a running process's stdout is closed out.
    Progress parsed from rsync's stdout is shown via ``sq.ScreenPrinter``;
    anything written to stderr is collected and replayed, prefixed with the
    node name, after all transfers have finished.

    This function does not return: it always ends in ``sys.exit`` with the
    bitwise OR of any expression error codes and the process exit statuses
    (or 1 after printing help on bad usage).
    """
    argparser = optparse.OptionParser(
        usage="Usage: %prog <file/directorylist> <noderange>:<destination>",
    )
    argparser.add_option('-m', '--maxnodes', type='int',
                         help='Specify a maximum number of '
                              'nodes to run rsync to, '
                              'prompting if over the threshold')
    # The default concurrency is deliberately bounded: each child adds two
    # pipes to the select() set, which is subject to FD_SETSIZE among other
    # things.  Besides, spawning too many processes can be unkind for the
    # unaware on memory pressure and such...
    argparser.add_option('-f', '-c', '--count', type='int', default=168,
                         help='Number of nodes to concurrently rsync')
    argparser.disable_interspersed_args()
    (options, args) = argparser.parse_args()
    if len(args) < 2 or ':' not in args[-1]:
        argparser.print_help()
        sys.exit(1)
    concurrentprocs = options.count
    noderange, targpath = args[-1].split(':', 1)
    client.check_globbing(noderange)
    c = client.Command()
    # NOTE(review): sources are joined with spaces and later re-split with
    # shlex, so arguments containing whitespace will not survive intact.
    cmdstr = 'rsync -av --info=progress2 ' + ' '.join(args[:-1])
    cmdstr += ' {node}:' + targpath

    currprocs = 0
    pipes = set()  # every stdout/stderr pipe still being watched
    pipedesc = {}  # pipe -> {'node', 'popen', 'type'} (filled by run_cmdv)
    pendingexecs = deque()  # (node, cmdv) waiting for a free slot
    exitcode = 0
    c.stop_if_noderange_over(noderange, options.maxnodes)
    for exp in c.create(
            '/noderange/{0}/attributes/expression'.format(noderange),
            {'expression': cmdstr}):
        if 'error' in exp:
            sys.stderr.write(exp['error'] + '\n')
            exitcode |= exp.get('errorcode', 1)
        ex = exp.get('databynode', ())
        for node in ex:
            cmd = ex[node]['value']
            # shlex wants a native string; anything else (python2 unicode)
            # is encoded first.
            if not isinstance(cmd, (bytes, str)):
                cmd = cmd.encode('utf-8')
            cmdv = shlex.split(cmd)
            if currprocs < concurrentprocs:
                currprocs += 1
                run_cmdv(node, cmdv, pipes, pipedesc)
            else:
                pendingexecs.append((node, cmdv))
    if not pipes or exitcode:
        sys.exit(exitcode)
    rdy, _, _ = select.select(pipes, [], [], 10)
    nodeerrs = {}  # node -> accumulated stderr text
    pernodeout = {}  # node -> stdout text not yet terminated by \n or \r
    pernodefile = {}  # node -> label shown in front of the progress field
    output = sq.ScreenPrinter(noderange, c)

    while pipes:
        for r in rdy:
            desc = pipedesc[r]
            node = desc['node']
            data = True
            # Drain one byte at a time for as long as select() reports the
            # pipe readable, so a read never blocks on a quiet pipe.
            while data and select.select([r], [], [], 0)[0]:
                data = r.read(1)
                if data:
                    if desc['type'] == 'stdout':
                        pernodeout[node] = (
                            pernodeout.get(node, '') + client.stringify(data))
                        if '\n' in pernodeout[node]:
                            # Newline-terminated chunk: remember its
                            # basename as the current label for the node.
                            currout, pernodeout[node] = \
                                pernodeout[node].split('\n', 1)
                            if currout:
                                pernodefile[node] = os.path.basename(currout)
                        if '\r' in pernodeout[node]:
                            # Carriage-return-terminated chunk: a progress
                            # update whose second field is displayed.
                            currout, pernodeout[node] = \
                                pernodeout[node].split('\r', 1)
                            fields = currout.split()
                            if len(fields) > 1:
                                output.set_output(
                                    node, '{0}:{1}'.format(
                                        pernodefile.get(node, ''), fields[1]))
                            elif fields:
                                # A lone token is kept as the node's label.
                                # This must go into the per-node entry;
                                # rebinding the dict itself would break
                                # every later lookup.
                                pernodefile[node] = fields[0]
                    else:
                        output.set_output(node, 'error!')
                        nodeerrs[node] = (
                            nodeerrs.get(node, '') + client.stringify(data))
                else:
                    # Empty read means EOF on this pipe; retire it only once
                    # the process has actually exited.
                    pop = desc['popen']
                    ret = pop.poll()
                    if ret is not None:
                        exitcode = exitcode | ret
                        pipes.discard(r)
                        r.close()
                        if node not in nodeerrs:
                            output.set_output(node, 'complete')
                        # Start one queued transfer per finished process;
                        # keying on stdout avoids counting a process twice.
                        if desc['type'] == 'stdout' and pendingexecs:
                            nextnode, nextcmdv = pendingexecs.popleft()
                            run_cmdv(nextnode, nextcmdv, pipes, pipedesc)
        if pipes:
            rdy, _, _ = select.select(pipes, [], [], 10)
    for node in nodeerrs:
        for line in nodeerrs[node].split('\n'):
            sys.stderr.write('{0}: {1}\n'.format(node, line))
    sys.exit(exitcode)


def run_cmdv(node, cmdv, all, pipedesc):
    """Start cmdv as a child process and register its output pipes.

    node is the label stored with each pipe, cmdv is the argument vector
    handed to subprocess.Popen, all is the set that receives the child's
    stdout and stderr pipe objects, and pipedesc is the dict that maps each
    of those pipes to a descriptor with 'node', 'popen' and 'type' keys
    ('type' is 'stdout' or 'stderr').  Both containers are updated in place;
    nothing is returned.
    """
    proc = subprocess.Popen(
        cmdv, stderr=subprocess.PIPE, stdout=subprocess.PIPE)
    for kind, pipe in (('stdout', proc.stdout), ('stderr', proc.stderr)):
        pipedesc[pipe] = {'node': node, 'popen': proc, 'type': kind}
        all.add(pipe)


# Run the transfer only when this file is executed as a script; importing
# the module does not start anything.
if __name__ == '__main__':
    run()