# Copyright 2017 Lenovo
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.

import argparse
import ctypes
import fcntl
import json
from select import select
import socket
import struct
import os
import subprocess
import sys
import time

# Locate the confluent apiclient script, preferring the copy under
# /opt/confluent/bin over the one under /etc/confluent.
# NOTE(review): when neither path exists, `apiclient` is never bound, so the
# subprocess calls in main() that reference it would fail with NameError.
if os.path.exists('/opt/confluent/bin/apiclient'):
    apiclient = '/opt/confluent/bin/apiclient'
elif os.path.exists('/etc/confluent/apiclient'):
    apiclient = '/etc/confluent/apiclient'

class IpmiMsg(ctypes.Structure):
    """ctypes layout of one IPMI message as exchanged with the device node.

    Holds the network function, the command number, the payload length and a
    pointer to the payload bytes.
    NOTE(review): presumably mirrors ``struct ipmi_msg`` from the kernel's
    ipmi.h -- confirm field widths against that header.
    """
    _fields_ = [('netfn', ctypes.c_ubyte),
                ('cmd', ctypes.c_ubyte),
                ('data_len', ctypes.c_short),
                ('data', ctypes.POINTER(ctypes.c_ubyte))]


class IpmiSystemInterfaceAddr(ctypes.Structure):
    """ctypes layout of the address attached to each request and response.

    Carries an address type, a channel number and a LUN.
    NOTE(review): presumably mirrors ``struct ipmi_system_interface_addr``
    from the kernel's ipmi.h -- confirm.
    """
    _fields_ = [('addr_type', ctypes.c_int),
                ('channel', ctypes.c_short),
                ('lun', ctypes.c_ubyte)]


class IpmiRecv(ctypes.Structure):
    """ctypes layout of the argument passed to the receive ioctl.

    Combines a receive type, a pointer to the address structure with its
    length, a message id and the embedded message.
    NOTE(review): presumably mirrors ``struct ipmi_recv`` from the kernel's
    ipmi.h -- confirm.
    """
    _fields_ = [('recv_type', ctypes.c_int),
                ('addr', ctypes.POINTER(IpmiSystemInterfaceAddr)),
                ('addr_len', ctypes.c_uint),
                ('msgid', ctypes.c_long),
                ('msg', IpmiMsg)]


class IpmiReq(ctypes.Structure):
    """ctypes layout of the argument passed to the send-command ioctl.

    Combines a pointer to the address structure with its length, a message
    id and the embedded message.
    NOTE(review): presumably mirrors ``struct ipmi_req`` from the kernel's
    ipmi.h -- confirm.
    """
    _fields_ = [('addr', ctypes.POINTER(IpmiSystemInterfaceAddr)),
                ('addr_len', ctypes.c_uint),
                ('msgid', ctypes.c_long),
                ('msg', IpmiMsg)]


# Direction values placed in the top bits of the ioctl request numbers below.
_IONONE = 0
_IOWRITE = 1
_IOREAD = 2
# Each request number is packed as:
#   direction << 30 | sizeof(argument) << 16 | ord('i') << 8 | command number
IPMICTL_SET_MY_ADDRESS_CMD = (
    _IOREAD << 30 | ctypes.sizeof(ctypes.c_uint) << 16
    | ord('i') << 8 | 17)  # from ipmi.h
IPMICTL_SEND_COMMAND = (
    _IOREAD << 30 | ctypes.sizeof(IpmiReq) << 16
    | ord('i') << 8 | 13)  # from ipmi.h
# The next value is really IPMICTL_RECEIVE_MSG_TRUNC; it is the only receive
# request this module issues, so it is simply named IPMICTL_RECV here.
IPMICTL_RECV = (
    (_IOWRITE | _IOREAD) << 30 | ctypes.sizeof(IpmiRecv) << 16
    | ord('i') << 8 | 11)  # from ipmi.h
# Argument handed to IPMICTL_SET_MY_ADDRESS_CMD when a Session is created.
BMC_SLAVE_ADDR = ctypes.c_uint(0x20)
# Channel and address type written into the address structure for every
# command sent by Session.raw_command.
# NOTE(review): presumably IPMI_BMC_CHANNEL and
# IPMI_SYSTEM_INTERFACE_ADDR_TYPE from ipmi.h -- confirm.
CURRCHAN = 0xf
ADDRTYPE = 0xc


class Session(object):
    """In-band IPMI session speaking to the BMC through a local device node.

    Outgoing payloads and incoming responses are staged in one shared
    4096 byte buffer owned by the session.
    """

    def __init__(self, devnode='/dev/ipmi0'):
        """Create a local session inband

        :param: devnode: The path to the ipmi device
        """
        self.ipmidev = open(devnode, 'r+')
        fcntl.ioctl(self.ipmidev, IPMICTL_SET_MY_ADDRESS_CMD, BMC_SLAVE_ADDR)
        # The interface is ready; allocate the structures every command reuses.
        self.databuffer = ctypes.create_string_buffer(4096)
        self.req = IpmiReq()
        self.rsp = IpmiRecv()
        self.addr = IpmiSystemInterfaceAddr()
        # Request and response both point at the same staging buffer.
        payload = ctypes.cast(
            ctypes.addressof(self.databuffer),
            ctypes.POINTER(ctypes.c_ubyte))
        self.req.msg.data = payload
        self.rsp.msg.data = payload
        self.userid = None
        self.password = None

    def await_reply(self):
        """Block until the device node reports readable data.

        Polls with a one second select timeout and never gives up.
        """
        watched = (self.ipmidev,)
        while True:
            readable, _, _ = select(watched, (), (), 1)
            if readable:
                return

    def pause(self, seconds):
        """Sleep for the given number of seconds."""
        time.sleep(seconds)

    @property
    def parsed_rsp(self):
        """Return the last response as a dict.

        The first buffer byte is reported as 'code'; the bytes from offset 1
        up to the response data length are reported as 'data'.
        """
        message = self.rsp.msg
        return {
            'netfn': message.netfn,
            'command': message.cmd,
            'code': bytearray(self.databuffer.raw)[0],
            'data': bytearray(self.databuffer.raw[1:message.data_len]),
        }

    def await_config(s, bmccfg, channel):
        """Placeholder: look up the requested settings and return None.

        The values are read from bmccfg but nothing is done with them yet.
        """
        wanted_vlan = bmccfg.get('bmcvlan', None)
        wanted_ipv4 = bmccfg.get('bmcipv4', None)
        wanted_prefix = bmccfg.get('prefixv4', None)
        wanted_gw = bmccfg.get('bmcgw', None)

    def raw_command(self,
                    netfn,
                    command,
                    data=(),
                    bridge_request=None,
                    retry=True,
                    delay_xmit=None,
                    timeout=None,
                    waitall=False, rslun=0):
        """Send one IPMI command and return the parsed response.

        :param netfn: network function of the request
        :param command: command number of the request
        :param data: iterable of payload bytes (may be empty)
        :param rslun: LUN stored in the address structure
        :returns: dict with 'netfn', 'command', 'code' and 'data' keys

        bridge_request, retry, delay_xmit, timeout and waitall are accepted
        but not consulted by this implementation.
        """
        addrsize = ctypes.sizeof(IpmiSystemInterfaceAddr)
        address = self.addr
        address.channel = CURRCHAN
        address.addr_type = ADDRTYPE
        address.lun = rslun
        request = self.req
        request.addr_len = addrsize
        request.addr = ctypes.pointer(address)
        request.msg.netfn = netfn
        request.msg.cmd = command
        if data:
            data = memoryview(bytearray(data))
            size = len(data)
            try:
                self.databuffer[:size] = data[:size]
            except ValueError:
                # Fall back to a plain bytes copy of the payload.
                self.databuffer[:size] = data[:size].tobytes()
        request.msg.data_len = len(data)
        fcntl.ioctl(self.ipmidev, IPMICTL_SEND_COMMAND, request)
        self.await_reply()
        # Advertise the full buffer size before asking for the response.
        response = self.rsp
        response.msg.data_len = 4096
        response.addr = ctypes.pointer(address)
        response.addr_len = addrsize
        fcntl.ioctl(self.ipmidev, IPMICTL_RECV, response)
        return self.parsed_rsp


def _is_tsm(model):
    """Return True when model is one of the machine types handled as TSM."""
    tsm_models = ('7y00', '7z01', '7y98', '7y99')
    return any(model == known for known in tsm_models)

def set_port(s, port, vendor, model):
    """Select the BMC network port and return the LAN channel to configure.

    :param s: session used to send raw commands
    :param port: requested port name
    :param vendor: system vendor string; only 'IBM' and 'Lenovo' are handled
    :param model: model identifier used to pick the TSM or XCC procedure
    :returns: the value from set_port_tsm for TSM models, otherwise 1
    :raises Exception: when the vendor is not supported
    """
    if vendor not in ('IBM', 'Lenovo'):
        raise Exception('{0} not implemented'.format(vendor))
    if not _is_tsm(model):
        set_port_xcc(s, port, model)
        return 1
    return set_port_tsm(s, port, model)


def get_remote_config_mod(vendor, model, waiters):
    """Name the remote configuration module for this system, if any.

    :param vendor: system vendor string
    :param model: model identifier
    :param waiters: iterable of flags; if any is truthy on a TSM system,
        90 one-second pauses are taken (with progress dots) before returning
    :returns: 'tsm' or 'xcc' for IBM/Lenovo systems, otherwise None
    """
    if vendor not in ('IBM', 'Lenovo'):
        return None
    if not _is_tsm(model):
        return 'xcc'
    if any(waiters):
        out = sys.stdout
        out.write('Waiting for TSM network to activate')
        for _ in range(0, 90):
            out.write('.')
            out.flush()
            time.sleep(1)
        out.write('Complete\n')
    return 'tsm'

def set_port_tsm(s, port, model):
    """Switch a TSM to the requested network port.

    Sends one 0x32/0x71 command, pauses 15 seconds while printing dots,
    sends a second 0x32/0x71 command and finally issues 0x32/0x72 for the
    chosen interface.

    :param s: session used to send raw commands
    :param port: 'ocp' or 'dedicated'
    :param model: accepted for signature parity; not used here
    :returns: first data byte of the 0x32/0x72 response, as an int
    :raises Exception: for any other port name
    """
    sys.stdout.write('Setting TSM port to "{}"...'.format(port))
    sys.stdout.flush()
    if port == 'ocp':
        before, after, iface = b'\x00\x01\x00', b'\x00\x00\x03', 0
    elif port == 'dedicated':
        before, after, iface = b'\x00\x00\x00', b'\x00\x01\x03', 1
    else:
        raise Exception("Unsupported port for TSM")
    s.raw_command(0x32, 0x71, before)
    for _ in range(15):
        time.sleep(1.0)
        sys.stdout.write('.')
        sys.stdout.flush()
    s.raw_command(0x32, 0x71, after)
    reply = s.raw_command(0x32, 0x72, bytearray([4, iface, 0]))
    print('Complete')
    return int(reply['data'][0])


def set_port_xcc(s, port, model):
    """Switch an XCC to the requested network port and wait for it to apply.

    :param s: session used to send raw commands
    :param port: 'dedicated', 'ml2', 'ocp' or 'lom' (case-insensitive), or a
        space separated list of decimal byte values to send verbatim
    :param model: model identifier; '7x58' uses a different 'lom' encoding
    :returns: None
    :raises Exception: when the port has not changed after 60 polls
    """
    oport = port
    selector = port.lower()
    if selector == 'dedicated':
        port = b'\x01'
    elif selector in ('ml2', 'ocp'):
        port = b'\x02\x00'
    elif selector == 'lom':
        port = b'\x00\x02' if model == '7x58' else b'\x00\x00'
    else:
        port = bytes(bytearray([int(x) for x in port.split(' ')]))
    currport = bytes(s.raw_command(0xc, 2, b'\x01\xc0\x00\x00')['data'][1:])
    if port == currport:
        sys.stdout.write('XCC port already set to "{}"\n'.format(oport))
        return
    sys.stdout.write('Setting XCC port to "{}"...'.format(oport))
    sys.stdout.flush()
    s.raw_command(0xc, 1, b'\x01\xc0' + port)
    tries = 60
    while currport != port and tries:
        tries -= 1
        time.sleep(0.5)
        sys.stdout.write('.')
        sys.stdout.flush()
        currport = bytes(s.raw_command(0xc, 2, b'\x01\xc0\x00\x00')['data'][1:])
    # Judge success by the port actually read back: the change may land on
    # the final poll, at which point the retry counter is already zero.
    if currport != port:
        raise Exception('Timeout attempting to set port')
    sys.stdout.write('Complete\n')


def check_vlan(s, vlan, channel):
    """Report whether the BMC's VLAN setting matches the requested one.

    :param s: session used to send raw commands
    :param vlan: 'off' or a VLAN id (anything int() accepts)
    :param channel: LAN channel number to query
    :returns: True when the current setting equals the requested one

    A current value whose enable bit (top bit of the second byte) is clear
    is treated the same as 'off'.
    """
    if vlan == 'off':
        wanted = b'\x00\x00'
    else:
        vlanid = int(vlan)
        if vlanid:
            # A non-zero id is stored with the enable bit set.
            vlanid |= 32768
        wanted = struct.pack('<H', vlanid)
    reply = s.raw_command(0xc, 2, bytearray([channel, 0x14, 0, 0]))
    current = bytes(reply['data'][1:])
    enabled = bytearray(current)[1] & 0b10000000
    if not enabled:
        current = b'\x00\x00'
    return current == wanted


def set_vlan(s, vlan, channel):
    """Configure the BMC VLAN unless it is already set as requested.

    :param s: session used to send raw commands
    :param vlan: 'off' or a VLAN id (anything int() accepts)
    :param channel: LAN channel number to configure
    :returns: True when a change was accepted by the BMC and the caller
        should wait for it to take effect; False when nothing changed
        (already configured, or the BMC rejected the command)
    """
    ovlan = vlan
    if vlan == 'off':
        vlan = b'\x00\x00'
    else:
        vlan = int(vlan)
        if vlan:
            # A non-zero id is stored with the enable bit set.
            vlan = vlan | 32768
        vlan = struct.pack('<H', vlan)
    if check_vlan(s, ovlan, channel):
        sys.stdout.write('VLAN already configured to "{0}"\n'.format(ovlan))
        return False
    rsp = s.raw_command(0xc, 1, bytearray([channel, 0x14]) + vlan)
    if rsp.get('code', 1) == 0:
        print('VLAN configured to "{}"'.format(ovlan))
        return True
    print('Error setting vlan: ' + repr(rsp))
    # The command was rejected, so there is no pending change to wait for;
    # reporting True here would make the caller poll check_vlan forever.
    return False


def get_lan_channel(s):
    """Find the first channel that answers as a LAN channel.

    Channels 1 through 15 are probed with command 6/0x42; a channel whose
    medium type (low 7 bits of the second data byte) is 4 or 6 and which
    answers a LAN parameter read with completion code 0 is returned.

    :param s: session used to send raw commands
    :returns: the matching channel number, or 1 when none matched
    """
    for chan in range(1, 16):
        info = s.raw_command(6, 0x42, bytearray([chan]))['data']
        if info and (int(info[1]) & 0b1111111) in (4, 6):
            probe = s.raw_command(0xc, 2, bytearray([chan, 5, 0, 0]))
            if probe.get('code', 1) == 0:
                return chan
    return 1


def check_ipv4(s, ipaddr, channel):
    """Report whether the BMC's IPv4 address equals ipaddr.

    :param s: session used to send raw commands
    :param ipaddr: dotted-quad address string
    :param channel: LAN channel number to query
    :returns: True when the last four response bytes match the address
    """
    wanted = bytearray(socket.inet_aton(ipaddr))
    reply = s.raw_command(0xc, 2, bytearray([channel, 3, 0, 0]))
    current = reply['data'][-4:]
    return current == wanted

def set_ipv4(s, ipaddr, channel):
    """Set the BMC's IPv4 address unless it already matches.

    If the address source reported by the BMC is not 1, it is first switched
    to 1 ("static" per the progress message) and polled up to 30 times at
    half-second intervals before the address itself is written.

    :param s: session used to send raw commands
    :param ipaddr: dotted-quad address string
    :param channel: LAN channel number to configure
    :returns: False when the address was already set, True after writing it
    """
    packed = bytearray(socket.inet_aton(ipaddr))
    if check_ipv4(s, ipaddr, channel):
        print('IP Address already set to {}'.format(ipaddr))
        return False

    def address_source():
        # Low nibble of the second data byte of LAN parameter 4.
        reply = s.raw_command(0xc, 2, bytearray([channel, 4, 0, 0]))
        return int(reply['data'][1]) & 0b1111

    source = address_source()
    if source != 1:
        sys.stdout.write("Changing configuration to static...")
        sys.stdout.flush()
        s.raw_command(0xc, 1, bytearray([channel, 4, 1]))
        for _ in range(30):
            if source == 1:
                break
            sys.stdout.write('.')
            sys.stdout.flush()
            time.sleep(0.5)
            source = address_source()
        sys.stdout.write('Complete\n')
        sys.stdout.flush()
    print('Setting IP to {}'.format(ipaddr))
    s.raw_command(0xc, 1, bytearray([channel, 3]) + packed)
    return True


def check_subnet(s, prefix, channel):
    """Report whether the BMC's subnet mask corresponds to prefix.

    :param s: session used to send raw commands
    :param prefix: prefix length (anything int() accepts)
    :param channel: LAN channel number to query
    :returns: True when the last four response bytes equal the netmask
    """
    bits = int(prefix)
    all_ones = 2**32 - 1
    host_part = 2**(32 - bits) - 1
    wanted = bytearray(struct.pack('!I', all_ones ^ host_part))
    reply = s.raw_command(0xc, 2, bytearray([channel, 6, 0, 0]))
    current = reply['data'][-4:]
    return current == wanted

def set_subnet(s, prefix, channel):
    """Set the BMC's subnet mask from a prefix length unless it matches.

    :param s: session used to send raw commands
    :param prefix: prefix length (anything int() accepts)
    :param channel: LAN channel number to configure
    :returns: False when the mask was already set, True after writing it
    """
    bits = int(prefix)
    netmask = bytearray(
        struct.pack('!I', (2**32 - 1) ^ (2**(32 - bits) - 1)))
    if not check_subnet(s, bits, channel):
        print('Setting subnet mask to /{}'.format(prefix))
        s.raw_command(0xc, 1, bytearray([channel, 6]) + netmask)
        return True
    print('Subnet Mask already set to /{}'.format(prefix))
    return False


def check_gateway(s, gw, channel):
    """Report whether the BMC's gateway address equals gw.

    :param s: session used to send raw commands
    :param gw: dotted-quad gateway address string
    :param channel: LAN channel number to query
    :returns: True when the last four response bytes match the address
    """
    wanted = bytearray(socket.inet_aton(gw))
    reply = s.raw_command(0xc, 2, bytearray([channel, 12, 0, 0]))
    current = reply['data'][-4:]
    return current == wanted

def set_gateway(s, gw, channel):
    """Set the BMC's gateway address unless it already matches.

    :param s: session used to send raw commands
    :param gw: dotted-quad gateway address string
    :param channel: LAN channel number to configure
    :returns: False when the gateway was already set, True after writing it
    """
    ogw = gw
    gw = bytearray(socket.inet_aton(gw))
    # check_gateway performs the only read that is needed; a second read
    # whose result was discarded used to be issued here.
    if check_gateway(s, ogw, channel):
        print('Gateway already set to {}'.format(ogw))
        return False
    print('Setting gateway to {}'.format(ogw))
    s.raw_command(0xc, 1, bytearray([channel, 12]) + gw)
    return True

def dotwait():
    """Print one progress dot immediately, then pause for half a second."""
    out = sys.stdout
    out.write('.')
    out.flush()
    time.sleep(0.5)

def main():
    """Configure the local BMC from command line options.

    Configuration comes from a JSON file (-i) or from the confluent API
    (-c); -p and -v override the port and VLAN. The requested port, VLAN,
    IPv4 address, prefix and gateway are applied in that order, each change
    is polled until the BMC reports it, and, for vendors that
    get_remote_config_mod recognizes, remote configuration of authentication
    is then requested through apiclient.
    """
    a = argparse.ArgumentParser(description='Locally configure a BMC device')
    a.add_argument('-v', help='vlan id or off to disable vlan tagging')
    a.add_argument('-p', help='Which port to use (dedicated, lom, ocp, ml2)')
    a.add_argument('-i', help='JSON configuration file to read for '
                   'configuration')
    a.add_argument('-c', help='Use Confluent API to direct BMC configuration',
                   action='store_true')
    args = a.parse_args()
    if args.i:
        # Close the configuration file as soon as it has been parsed.
        with open(args.i) as cfgfile:
            bmccfg = json.load(cfgfile)
    elif args.c:
        bmccfgsrc = subprocess.check_output(
            [sys.executable, apiclient, '/confluent-api/self/bmcconfig', '-j'])
        bmccfg = json.loads(bmccfgsrc)
    else:
        bmccfg = {}
    if args.p is not None:
        bmccfg['bmcport'] = args.p
    if args.v is not None:
        bmccfg['bmcvlan'] = args.v
    with open('/sys/devices/virtual/dmi/id/sys_vendor') as vendorfile:
        vendor = vendorfile.read()
    vendor = vendor.strip()
    try:
        with open('/sys/devices/virtual/dmi/id/product_sku') as skufile:
            model = skufile.read().strip()
        if model == 'none':
            raise Exception('No SKU')
    except Exception:
        # No usable SKU (unreadable or 'none'): use the product name instead.
        with open('/sys/devices/virtual/dmi/id/product_name') as namefile:
            model = namefile.read()
    if vendor in ('Lenovo', 'IBM'):
        # Keep only the bracketed part when present, then the first four
        # characters in lower case.
        if '[' in model and ']' in model:
            model = model.split('[')[1].split(']')[0]
        model = model[:4].lower()
    s = Session('/dev/ipmi0')
    if not bmccfg:
        print("No BMC configuration specified, exiting.")
        return
    if bmccfg.get('bmcport', None):
        channel = set_port(s, bmccfg['bmcport'], vendor, model)
    else:
        channel = get_lan_channel(s)
    awaitvlan = False
    awaitip = False
    awaitprefix = False
    awaitgw = False
    if bmccfg.get('bmcvlan', None):
        awaitvlan = set_vlan(s, bmccfg['bmcvlan'], channel)
    if bmccfg.get('bmcipv4', None):
        awaitip = set_ipv4(s, bmccfg['bmcipv4'], channel)
    if bmccfg.get('prefixv4', None):
        awaitprefix = set_subnet(s, bmccfg['prefixv4'], channel)
    if bmccfg.get('bmcgw', None):
        awaitgw = set_gateway(s, bmccfg['bmcgw'], channel)
    sys.stdout.write('Waiting for changes to take effect...')
    sys.stdout.flush()
    # Poll each setting that was changed until the BMC reports the new value.
    while awaitvlan and not check_vlan(s, bmccfg['bmcvlan'], channel):
        dotwait()
    while awaitip and not check_ipv4(s, bmccfg['bmcipv4'], channel):
        dotwait()
    while awaitprefix and not check_subnet(s, bmccfg['prefixv4'], channel):
        dotwait()
    while awaitgw and not check_gateway(s, bmccfg['bmcgw'], channel):
        dotwait()
    sys.stdout.write('done\n')
    sys.stdout.flush()
    cfgmod = get_remote_config_mod(vendor, model, (awaitip, awaitvlan, awaitprefix, awaitgw))
    if cfgmod:
        # apiclient is handed this file, written to the current directory.
        with open('configbmc.configmod', 'w+') as cm:
            cm.write('configmod: {0}\n'.format(cfgmod))
        sys.stdout.write('Requesting remote configuration of authentication...')
        sys.stdout.flush()
        subprocess.check_output(
            [sys.executable, apiclient, '/confluent-api/self/remoteconfigbmc', 'configbmc.configmod'])
        sys.stdout.write('done\n')
        sys.stdout.flush()


# Run the configuration flow only when executed as a script, not on import.
if __name__ == '__main__':
    main()