2
0
mirror of https://github.com/xcat2/confluent.git synced 2025-08-29 14:28:18 +00:00

Have async and traditional client

Since a lot of the traditional client did not need async,
make life easier by just having them in parallel for now.

The server must use the async client, but the client applications can
stick with the somewhat more straightforward synchronous client.
This commit is contained in:
Jarrod Johnson
2024-05-29 12:23:05 -04:00
parent 4a2349d9ad
commit 4c3f93765f
10 changed files with 1291 additions and 272 deletions

View File

@@ -41,7 +41,7 @@
# esc-( would interfere with normal esc use too much
# ~ I will not use for now...
import asyncio
import math
import getpass
import optparse
import os
@@ -51,7 +51,6 @@ import signal
import socket
import struct
import sys
import concurrent.futures
import time
try:
import fcntl
@@ -235,23 +234,14 @@ session = None
def completer(text, state):
try:
with concurrent.futures.ThreadPoolExecutor() as executor:
future = executor.submit(run_completer, text, state)
return future.result()
except Exception:
return rcompleter(text, state)
except:
pass
import traceback
traceback.print_exc()
#import traceback
#traceback.print_exc()
def run_completer(text, state):
loop = asyncio.new_event_loop()
asyncio.set_event_loop(loop)
return loop.run_until_complete(
rcompleter(text, state))
async def rcompleter(text, state):
def rcompleter(text, state):
global candidates
global valid_commands
cline = readline.get_line_buffer()
@@ -281,7 +271,7 @@ async def rcompleter(text, state):
if candidates is None:
candidates = []
targpath = fullpath_target(lastarg)
async for res in session.read(targpath):
for res in session.read(targpath):
if 'item' in res: # a link relation
if type(res['item']) == dict:
candidates.append(res['item']["href"])
@@ -370,7 +360,7 @@ def print_result(res):
print(output.encode('utf-8'))
async def do_command(command, server):
def do_command(command, server):
global exitcode
global target
global currconsole
@@ -409,7 +399,7 @@ async def do_command(command, server):
target = otarget
else:
foundchild = False
async for res in session.read(parentpath, server):
for res in session.read(parentpath, server):
try:
if res['item']['href'] == childname:
foundchild = True
@@ -444,7 +434,7 @@ async def do_command(command, server):
pass
else:
targpath = target
async for res in session.read(targpath):
for res in session.read(targpath):
if 'item' in res: # a link relation
if type(res['item']) == dict:
print(res['item']["href"])
@@ -484,9 +474,9 @@ async def do_command(command, server):
startconsole(nodename)
return
elif argv[0] == 'set':
await setvalues(argv[1:])
setvalues(argv[1:])
elif argv[0] == 'create':
await createresource(argv[1:])
createresource(argv[1:])
elif argv[0] in ('rm', 'delete', 'remove'):
delresource(argv[1])
elif argv[0] in ('unset', 'clear'):
@@ -501,7 +491,7 @@ def shutdown():
tlvdata.send(session.connection, {'operation': 'shutdown', 'path': '/'})
async def createresource(args):
def createresource(args):
resname = args[0]
attribs = args[1:]
keydata = parameterize_attribs(attribs)
@@ -514,12 +504,12 @@ async def createresource(args):
collection, _, resname = targpath.rpartition('/')
if 'name' not in keydata:
keydata['name'] = resname
await makecall(session.create, (collection, keydata))
makecall(session.create, (collection, keydata))
async def makecall(callout, args):
def makecall(callout, args):
global exitcode
async for response in callout(*args):
for response in callout(*args):
if 'deleted' in response:
print("Deleted: " + response['deleted'])
if 'created' in response:
@@ -550,12 +540,12 @@ def clearvalues(resource, attribs):
sys.stderr.write('Error: ' + res['error'] + '\n')
async def delresource(resname):
def delresource(resname):
resname = fullpath_target(resname)
await makecall(session.delete, (resname,))
makecall(session.delete, (resname,))
async def setvalues(attribs):
def setvalues(attribs):
global exitcode
if '=' in attribs[0]: # going straight to attribute
resource = attribs[0][:attribs[0].index("=")]
@@ -569,7 +559,7 @@ async def setvalues(attribs):
if not keydata:
return
targpath = fullpath_target(resource)
async for res in session.update(targpath, keydata):
for res in session.update(targpath, keydata):
if 'error' in res:
if 'errorcode' in res:
exitcode = res['errorcode']
@@ -864,7 +854,7 @@ opts, shellargs = parser.parse_args()
username = None
passphrase = None
async def server_connect():
def server_connect():
global session, username, passphrase
if opts.controlpath:
termhandler.TermHandler(opts.controlpath)
@@ -874,7 +864,7 @@ async def server_connect():
session = client.Command(os.environ['CONFLUENT_HOST'])
else: # unix domain
session = client.Command()
await session.ensure_connected()
# Next stop, reading and writing from whichever of stdin and server goes first.
#see pyghmi code for solconnect.py
if not session.authenticated and username is not None:
@@ -900,14 +890,11 @@ if sys.stdout.isatty():
import readline
async def main():
def main():
global inconsole
global consoleonly
global doexit
global doexit
try:
await server_connect()
except (EOFError, KeyboardInterrupt):
server_connect()
except (EOFError, KeyboardInterrupt) as _:
raise BailOut(0)
except socket.gaierror:
sys.stderr.write('Could not connect to confluent\n')
@@ -929,13 +916,14 @@ async def main():
doexit = False
inconsole = False
pendingcommand = ""
session_node = get_session_node(shellargs)
if session_node is not None:
consoleonly = True
await do_command("start /nodes/%s/console/session" % session_node, netserver)
do_command("start /nodes/%s/console/session" % session_node, netserver)
doexit = True
elif shellargs:
await do_command(shellargs, netserver)
do_command(shellargs, netserver)
quitconfetty(fullexit=True, fixterm=False)
powerstate = None
@@ -966,7 +954,7 @@ async def main():
else:
currcommand = prompt()
try:
await do_command(currcommand, netserver)
do_command(currcommand, netserver)
except socket.error:
try:
server_connect()
@@ -1041,10 +1029,10 @@ if __name__ == '__main__':
if opts.mintime:
deadline = os.times()[4] + float(opts.mintime)
try:
asyncio.get_event_loop().run_until_complete(main())
main()
except BailOut as e:
errcode = e.errorcode
except Exception:
except Exception as e:
import traceback
excinfo = traceback.print_exc()
try:

View File

@@ -15,7 +15,6 @@
# See the License for the specific language governing permissions and
# limitations under the License.
import asyncio
import optparse
import os
import signal
@@ -86,8 +85,4 @@ if options.previous:
def outhandler(node, res):
for k in res[node]:
client.cprint('{0}: {1}: {2}'.format(node, k.replace('inlet_', ''), res[node][k]))
async def main():
sys.exit(await session.simple_noderange_command(noderange, '/power/{0}'.format(powurl), setstate, promptover=options.maxnodes, key='state', outhandler=outhandler))
if __name__ == '__main__':
asyncio.get_event_loop().run_until_complete(main())
sys.exit(session.simple_noderange_command(noderange, '/power/{0}'.format(powurl), setstate, promptover=options.maxnodes, key='state', outhandler=outhandler))

View File

@@ -0,0 +1,827 @@
# vim: tabstop=4 shiftwidth=4 softtabstop=4
# Copyright 2014 IBM Corporation
# Copyright 2015-2019 Lenovo
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
import asyncio
import ctypes
import ctypes.util
import dbm
import csv
import errno
import fnmatch
import hashlib
import os
import shlex
import socket
import ssl
import sys
import confluent.asynctlvdata as tlvdata
import confluent.sortutil as sortutil
libssl = ctypes.CDLL(ctypes.util.find_library('ssl'))
libssl.SSL_CTX_set_cert_verify_callback.argtypes = [
ctypes.c_void_p, ctypes.c_void_p, ctypes.c_void_p]
SO_PASSCRED = 16
_attraliases = {
'bmc': 'hardwaremanagement.manager',
'bmcuser': 'secret.hardwaremanagementuser',
'switchuser': 'secret.hardwaremanagementuser',
'bmcpass': 'secret.hardwaremanagementpassword',
'switchpass': 'secret.hardwaremanagementpassword',
}
try:
getinput = raw_input
except NameError:
getinput = input
class PyObject_HEAD(ctypes.Structure):
_fields_ = [
("ob_refcnt", ctypes.c_ssize_t),
("ob_type", ctypes.c_void_p),
]
# see main/Modules/_ssl.c, only caring about the SSL_CTX pointer
class PySSLContext(ctypes.Structure):
_fields_ = [
("ob_base", PyObject_HEAD),
("ctx", ctypes.c_void_p),
]
@ctypes.CFUNCTYPE(ctypes.c_int, ctypes.c_void_p, ctypes.c_void_p)
def verify_stub(store, misc):
return 1
class NestedDict(dict):
def __missing__(self, key):
value = self[key] = type(self)()
return value
def stringify(instr):
# Normalize unicode and bytes to 'str', correcting for
# current python version
if isinstance(instr, bytes) and not isinstance(instr, str):
return instr.decode('utf-8')
elif not isinstance(instr, bytes) and not isinstance(instr, str):
return instr.encode('utf-8')
return instr
class Tabulator(object):
def __init__(self, headers):
self.headers = headers
self.rows = []
def add_row(self, row):
self.rows.append(row)
def get_table(self, order=None):
i = 0
fmtstr = ''
separator = []
for head in self.headers:
if order and order == head:
order = i
neededlen = len(head)
for row in self.rows:
if len(row[i]) > neededlen:
neededlen = len(row[i])
separator.append('-' * (neededlen + 1))
fmtstr += '{{{0}:>{1}}}|'.format(i, neededlen + 1)
i = i + 1
fmtstr = fmtstr[:-1]
yield fmtstr.format(*self.headers)
yield fmtstr.format(*separator)
if order is not None:
for row in sorted(
self.rows,
key=lambda x: sortutil.naturalize_string(x[order])):
yield fmtstr.format(*row)
else:
for row in self.rows:
yield fmtstr.format(*row)
def write_csv(self, output, order=None):
output = csv.writer(output)
output.writerow(self.headers)
i = 0
for head in self.headers:
if order and order == head:
order = i
i = i + 1
if order is not None:
for row in sorted(
self.rows,
key=lambda x: sortutil.naturalize_string(x[order])):
output.writerow(row)
else:
for row in self.rows:
output.writerow(row)
def printerror(res, node=None):
exitcode = 0
if 'errorcode' in res:
exitcode = res['errorcode']
for node in res.get('databynode', {}):
exitcode = res['databynode'][node].get('errorcode', exitcode)
if 'error' in res['databynode'][node]:
sys.stderr.write(
'{0}: {1}\n'.format(node, res['databynode'][node]['error']))
if exitcode == 0:
exitcode = 1
if 'error' in res:
if node:
sys.stderr.write('{0}: {1}\n'.format(node, res['error']))
else:
sys.stderr.write('{0}\n'.format(res['error']))
if 'errorcode' not in res:
exitcode = 1
return exitcode
def cprint(txt):
try:
print(txt)
except UnicodeEncodeError:
print(txt.encode('utf8'))
sys.stdout.flush()
def _parseserver(string):
if ']:' in string:
server, port = string[1:].split(']:')
elif string[0] == '[':
server = string[1:-1]
port = '13001'
elif ':' in string:
server, port = string.split(':')
else:
server = string
port = '13001'
return server, port
class Command(object):
def __init__(self, server=None):
self._prevdict = None
self._prevkeyname = None
self.connection = None
self._currnoderange = None
self.unixdomain = False
if server is None:
if 'CONFLUENT_HOST' in os.environ:
self.serverloc = os.environ['CONFLUENT_HOST']
else:
self.serverloc = '/var/run/confluent/api.sock'
else:
self.serverloc = server
self.connected = False
async def ensure_connected(self):
if self.connected:
return True
if os.path.isabs(self.serverloc) and os.path.exists(self.serverloc):
self._connect_unix()
self.unixdomain = True
elif self.serverloc == '/var/run/confluent/api.sock':
raise Exception('Confluent service is not available')
else:
await self._connect_tls()
self.protversion = int((await tlvdata.recv(self.connection)).split(
b'--')[1].strip()[1:])
authdata = await tlvdata.recv(self.connection)
if authdata['authpassed'] == 1:
self.authenticated = True
else:
self.authenticated = False
if not self.authenticated and 'CONFLUENT_USER' in os.environ:
username = os.environ['CONFLUENT_USER']
passphrase = os.environ['CONFLUENT_PASSPHRASE']
await self.authenticate(username, passphrase)
self.connected = True
async def add_file(self, name, handle, mode):
await self.ensure_connected()
if self.protversion < 3:
raise Exception('Not supported with connected confluent server')
if not self.unixdomain:
raise Exception('Can only add a file to a unix domain connection')
tlvdata.send(self.connection, {'filename': name, 'mode': mode}, handle)
async def authenticate(self, username, password):
await tlvdata.send(self.connection,
{'username': username, 'password': password})
authdata = await tlvdata.recv(self.connection)
if authdata['authpassed'] == 1:
self.authenticated = True
def add_precede_key(self, keyname):
self._prevkeyname = keyname
def add_precede_dict(self, dict):
self._prevdict = dict
def handle_results(self, ikey, rc, res, errnodes=None, outhandler=None):
if 'error' in res:
if errnodes is not None:
errnodes.add(self._currnoderange)
sys.stderr.write('Error: {0}\n'.format(res['error']))
if 'errorcode' in res:
return res['errorcode']
else:
return 1
if 'databynode' not in res:
return 0
res = res['databynode']
for node in res:
if 'error' in res[node]:
if errnodes is not None:
errnodes.add(node)
sys.stderr.write('{0}: Error: {1}\n'.format(
node, res[node]['error']))
if 'errorcode' in res[node]:
rc |= res[node]['errorcode']
else:
rc |= 1
elif ikey in res[node]:
if 'value' in res[node][ikey]:
val = res[node][ikey]['value']
elif 'isset' in res[node][ikey]:
val = '********' if res[node][ikey] else ''
else:
val = repr(res[node][ikey])
if self._prevkeyname and self._prevkeyname in res[node]:
cprint('{0}: {2}->{1}'.format(
node, val, res[node][self._prevkeyname]['value']))
elif self._prevdict and node in self._prevdict:
cprint('{0}: {2}->{1}'.format(
node, val, self._prevdict[node]))
else:
cprint('{0}: {1}'.format(node, val))
elif outhandler:
outhandler(node, res)
return rc
async def simple_noderange_command(self, noderange, resource, input=None,
key=None, errnodes=None, promptover=None, outhandler=None, **kwargs):
try:
self._currnoderange = noderange
rc = 0
if resource[0] == '/':
resource = resource[1:]
# The implicit key is the resource basename
if key is None:
ikey = resource.rpartition('/')[-1]
else:
ikey = key
if input is None:
async for res in self.read('/noderange/{0}/{1}'.format(
noderange, resource)):
rc = self.handle_results(ikey, rc, res, errnodes, outhandler)
else:
await self.stop_if_noderange_over(noderange, promptover)
kwargs[ikey] = input
async for res in self.update('/noderange/{0}/{1}'.format(
noderange, resource), kwargs):
rc = self.handle_results(ikey, rc, res, errnodes, outhandler)
self._currnoderange = None
return rc
except KeyboardInterrupt:
cprint('')
return 0
async def stop_if_noderange_over(self, noderange, maxnodes):
if maxnodes is None:
return
nsize = await self.get_noderange_size(noderange)
if nsize > maxnodes:
if nsize == 1:
nodename = [x async for x in self.read(
'/noderange/{0}/nodes/'.format(noderange))][0].get('item', {}).get('href', None)
nodename = nodename[:-1]
p = getinput('Command is about to affect node {0}, continue (y/n)? '.format(nodename))
else:
p = getinput('Command is about to affect {0} nodes, continue (y/n)? '.format(nsize))
if p.lower() != 'y':
sys.stderr.write('Aborting at user request\n')
sys.exit(1)
raise Exception("Aborting at user request")
async def get_noderange_size(self, noderange):
numnodes = 0
async for node in self.read('/noderange/{0}/nodes/'.format(noderange)):
if node.get('item', {}).get('href', None):
numnodes += 1
else:
raise Exception("Error trying to size noderange {0}".format(noderange))
return numnodes
async def simple_nodegroups_command(self, noderange, resource, input=None, key=None, **kwargs):
try:
rc = 0
if resource[0] == '/':
resource = resource[1:]
# The implicit key is the resource basename
if key is None:
ikey = resource.rpartition('/')[-1]
else:
ikey = key
if input is None:
for res in await self.read('/nodegroups/{0}/{1}'.format(
noderange, resource)):
rc = self.handle_results(ikey, rc, res)
else:
kwargs[ikey] = input
for res in await self.update('/nodegroups/{0}/{1}'.format(
noderange, resource), kwargs):
rc = self.handle_results(ikey, rc, res)
return rc
except KeyboardInterrupt:
cprint('')
return 0
async def read(self, path, parameters=None):
await self.ensure_connected()
if not self.authenticated:
raise Exception('Unauthenticated')
async for rsp in send_request(
'retrieve', path, self.connection, parameters):
yield rsp
async def update(self, path, parameters=None):
await self.ensure_connected()
if not self.authenticated:
raise Exception('Unauthenticated')
async for rsp in send_request(
'update', path, self.connection, parameters):
yield rsp
async def create(self, path, parameters=None):
await self.ensure_connected()
if not self.authenticated:
raise Exception('Unauthenticated')
async for rsp in send_request(
'create', path, self.connection, parameters):
yield rsp
async def delete(self, path, parameters=None):
await self.ensure_connected()
if not self.authenticated:
raise Exception('Unauthenticated')
async for rsp in send_request(
'delete', path, self.connection, parameters):
yield rsp
def _connect_unix(self):
self.connection = socket.socket(socket.AF_UNIX, socket.SOCK_STREAM)
self.connection.setsockopt(socket.SOL_SOCKET, SO_PASSCRED, 1)
self.connection.connect(self.serverloc)
async def _connect_tls(self):
server, port = _parseserver(self.serverloc)
for res in socket.getaddrinfo(server, port, socket.AF_UNSPEC,
socket.SOCK_STREAM):
af, socktype, proto, canonname, sa = res
try:
self.connection = socket.socket(af, socktype, proto)
self.connection.setsockopt(
socket.IPPROTO_TCP, socket.TCP_NODELAY, 1)
except:
self.connection = None
continue
try:
self.connection.settimeout(5)
self.connection.connect(sa)
self.connection.settimeout(0)
except:
raise
self.connection.close()
self.connection = None
continue
break
if self.connection is None:
raise Exception("Failed to connect to %s" % self.serverloc)
#TODO(jbjohnso): server certificate validation
clientcfgdir = os.path.join(os.path.expanduser("~"), ".confluent")
try:
os.makedirs(clientcfgdir)
except OSError as exc:
if not (exc.errno == errno.EEXIST and os.path.isdir(clientcfgdir)):
raise
cacert = os.path.join(clientcfgdir, "ca.pem")
certreqs = ssl.CERT_REQUIRED
knownhosts = False
if not os.path.exists(cacert):
cacert = None
certreqs = ssl.CERT_NONE
knownhosts = True
ctx = ssl.SSLContext(ssl.PROTOCOL_TLSv1_2)
ssl_ctx = PySSLContext.from_address(id(ctx)).ctx
libssl.SSL_CTX_set_cert_verify_callback(ssl_ctx, verify_stub, 0)
sreader = asyncio.StreamReader()
sreaderprot = asyncio.StreamReaderProtocol(sreader)
cloop = asyncio.get_event_loop()
tport, _ = await cloop.create_connection(
lambda: sreaderprot, sock=self.connection, ssl=ctx, server_hostname='x')
swriter = asyncio.StreamWriter(tport, sreaderprot, sreader, cloop)
self.connection = (sreader, swriter)
#self.connection = ssl.wrap_socket(self.connection, ca_certs=cacert,
# cert_reqs=certreqs)
if knownhosts:
certdata = tport.get_extra_info('ssl_object').getpeercert(binary_form=True)
# certdata = self.connection.getpeercert(binary_form=True)
fingerprint = 'sha512$' + hashlib.sha512(certdata).hexdigest()
fingerprint = fingerprint.encode('utf-8')
hostid = '@'.join((port, server))
khf = dbm.open(os.path.join(clientcfgdir, "knownhosts"), 'c', 384)
if hostid in khf:
if fingerprint == khf[hostid]:
return
else:
replace = getinput(
"MISMATCHED CERTIFICATE DATA, ACCEPT NEW? (y/n):")
if replace not in ('y', 'Y'):
raise Exception("BAD CERTIFICATE")
cprint('Adding new key for %s:%s' % (server, port))
khf[hostid] = fingerprint
async def send_request(operation, path, server, parameters=None):
"""This function iterates over all the responses
received from the server.
:param operation: The operation to request, retrieve, update, delete,
create, start, stop
:param path: The URI path to the resource to operate on
:param server: The socket to send data over
:param parameters: Parameters if any to send along with the request
"""
payload = {'operation': operation, 'path': path}
if parameters is not None:
payload['parameters'] = parameters
await tlvdata.send(server, payload)
result = await tlvdata.recv(server)
while '_requestdone' not in result:
try:
yield result
except GeneratorExit:
while '_requestdone' not in result:
result = await tlvdata.recv(server)
raise
result = await tlvdata.recv(server)
def attrrequested(attr, attrlist, seenattributes, node=None):
for candidate in attrlist:
truename = candidate
if candidate.startswith('hm'):
candidate = candidate.replace('hm', 'hardwaremanagement', 1)
if candidate in _attraliases:
candidate = _attraliases[candidate]
if fnmatch.fnmatch(attr.lower(), candidate.lower()):
if node is None:
seenattributes.add(truename)
else:
seenattributes[node][truename] = True
return True
elif attr.lower().startswith(candidate.lower() + '.'):
if node is None:
seenattributes.add(truename)
else:
seenattributes[node][truename] = 1
return True
return False
async def printattributes(session, requestargs, showtype, nodetype, noderange, options):
path = '/{0}/{1}/attributes/{2}'.format(nodetype, noderange, showtype)
return await print_attrib_path(path, session, requestargs, options)
def _sort_attrib(k):
if isinstance(k[1], dict) and k[1].get('sortid', None) is not None:
return k[1]['sortid']
return k[0]
async def print_attrib_path(path, session, requestargs, options, rename=None, attrprefix=None):
exitcode = 0
seenattributes = NestedDict()
allnodes = set([])
async for res in session.read(path):
if 'error' in res:
sys.stderr.write(res['error'] + '\n')
exitcode = 1
continue
for node in sorted(res['databynode']):
allnodes.add(node)
for attr, val in sorted(res['databynode'][node].items(), key=_sort_attrib):
if attr == 'error':
sys.stderr.write('{0}: Error: {1}\n'.format(node, val))
continue
if attr == 'errorcode':
exitcode |= val
continue
seenattributes[node][attr] = True
if rename:
printattr = rename.get(attr, attr)
else:
printattr = attr
if attrprefix:
printattr = attrprefix + printattr
currattr = res['databynode'][node][attr]
if show_attr(attr, requestargs, seenattributes, options, node):
if 'value' in currattr:
if currattr['value'] is not None:
val = currattr['value']
if isinstance(val, list):
val = ','.join(val)
attrout = '{0}: {1}: {2}'.format(
node, printattr, val).strip()
else:
attrout = '{0}: {1}:'.format(node, printattr)
elif 'isset' in currattr:
if currattr['isset']:
attrout = '{0}: {1}: ********'.format(node,
printattr)
else:
attrout = '{0}: {1}:'.format(node, printattr)
elif isinstance(currattr, dict) and 'broken' in currattr:
attrout = '{0}: {1}: *ERROR* BROKEN EXPRESSION: ' \
'{2}'.format(node, printattr,
currattr['broken'])
elif isinstance(currattr, list) or isinstance(currattr, tuple):
attrout = '{0}: {1}: {2}'.format(node, attr, ','.join(map(str, currattr)))
elif isinstance(currattr, dict):
dictout = []
for k, v in currattr.items:
dictout.append("{0}={1}".format(k, v))
attrout = '{0}: {1}: {2}'.format(node, printattr, ','.join(map(str, dictout)))
else:
cprint("CODE ERROR" + repr(attr))
try:
blame = options.blame
except AttributeError:
blame = False
if blame or (isinstance(currattr, dict) and 'broken' in currattr):
blamedata = []
if 'inheritedfrom' in currattr:
blamedata.append('inherited from group {0}'.format(
currattr['inheritedfrom']
))
if 'expression' in currattr:
blamedata.append(
'derived from expression "{0}"'.format(
currattr['expression']))
if blamedata:
attrout += ' (' + ', '.join(blamedata) + ')'
try:
comparedefault = options.comparedefault
except AttributeError:
comparedefault = False
if comparedefault:
try:
exclude = options.exclude
except AttributeError:
exclude = False
if ((requestargs and not exclude) or
(currattr.get('default', None) is not None and
currattr.get('value', None) is not None and
currattr['value'] != currattr['default'])):
cval = ','.join(currattr['value']) if isinstance(
currattr['value'], list) else currattr['value']
dval = ','.join(currattr['default']) if isinstance(
currattr['default'], list) else currattr['default']
cprint('{0}: {1}: {2} (Default: {3})'.format(
node, printattr, cval, dval))
else:
try:
details = options.detail
except AttributeError:
details = False
if details:
if currattr.get('help', None):
attrout += u' (Help: {0})'.format(
currattr['help'])
if currattr.get('possible', None):
try:
attrout += u' (Choices: {0})'.format(
','.join(currattr['possible']))
except TypeError:
pass
cprint(attrout)
somematched = set([])
printmissing = set([])
badnodes = NestedDict()
if not exitcode:
if requestargs:
for attr in requestargs:
for node in allnodes:
if attr in seenattributes[node]:
somematched.add(attr)
else:
badnodes[node][attr] = True
exitcode = 1
for node in sortutil.natural_sort(badnodes):
for attr in badnodes[node]:
if attr in somematched:
sys.stderr.write(
'Error: {0} matches no valid value for {1}\n'.format(
attr, node))
else:
printmissing.add(attr)
for missing in printmissing:
sys.stderr.write('Error: {0} not a valid attribute\n'.format(missing))
return exitcode
def show_attr(attr, requestargs, seenattributes, options, node):
try:
reverse = options.exclude
except AttributeError:
reverse = False
if requestargs is None or requestargs == []:
return True
processattr = attrrequested(attr, requestargs, seenattributes, node)
if reverse:
processattr = not processattr
return processattr
def printgroupattributes(session, requestargs, showtype, nodetype, noderange, options):
exitcode = 0
seenattributes = set([])
for res in session.read('/{0}/{1}/attributes/{2}'.format(nodetype, noderange, showtype)):
if 'error' in res:
sys.stderr.write(res['error'] + '\n')
exitcode = 1
continue
for attr in res:
seenattributes.add(attr)
currattr = res[attr]
if (requestargs is None or requestargs == [] or attrrequested(attr, requestargs, seenattributes)):
if 'value' in currattr:
if currattr['value'] is not None:
attrout = '{0}: {1}: {2}'.format(
noderange, attr, currattr['value'])
else:
attrout = '{0}: {1}:'.format(noderange, attr)
elif 'isset' in currattr:
if currattr['isset']:
attrout = '{0}: {1}: ********'.format(noderange, attr)
else:
attrout = '{0}: {1}:'.format(noderange, attr)
elif isinstance(currattr, dict) and 'broken' in currattr:
attrout = '{0}: {1}: *ERROR* BROKEN EXPRESSION: ' \
'{2}'.format(noderange, attr,
currattr['broken'])
elif 'expression' in currattr:
attrout = '{0}: {1}: (will derive from expression {2})'.format(noderange, attr, currattr['expression'])
elif isinstance(currattr, list) or isinstance(currattr, tuple):
attrout = '{0}: {1}: {2}'.format(noderange, attr, ','.join(map(str, currattr)))
elif isinstance(currattr, dict):
dictout = []
for k, v in currattr.items:
dictout.append("{0}={1}".format(k, v))
attrout = '{0}: {1}: {2}'.format(noderange, attr, ','.join(map(str, dictout)))
else:
cprint("CODE ERROR" + repr(attr))
cprint(attrout)
if not exitcode:
if requestargs:
for attr in requestargs:
if attr not in seenattributes:
sys.stderr.write('Error: {0} not a valid attribute\n'.format(attr))
exitcode = 1
return exitcode
async def updateattrib(session, updateargs, nodetype, noderange, options, dictassign=None):
# update attribute
exitcode = 0
if options.clear:
targpath = '/{0}/{1}/attributes/all'.format(nodetype, noderange)
keydata = {}
for attrib in updateargs[1:]:
keydata[attrib] = None
async for res in session.update(targpath, keydata):
for node in res.get('databynode', {}):
for warnmsg in res['databynode'][node].get('_warnings', []):
sys.stderr.write('Warning: ' + warnmsg + '\n')
if 'error' in res:
if 'errorcode' in res:
exitcode = res['errorcode']
sys.stderr.write('Error: ' + res['error'] + '\n')
sys.exit(exitcode)
elif hasattr(options, 'environment') and options.environment:
for key in updateargs[1:]:
key = key.replace('.', '_')
value = os.environ.get(
key, os.environ[key.upper()])
# Let's do one pass to make sure that there's not a usage problem
for key in updateargs[1:]:
key = key.replace('.', '_')
value = os.environ.get(
key, os.environ[key.upper()])
if (nodetype == "nodegroups"):
exitcode = await session.simple_nodegroups_command(noderange,
'attributes/all',
value, key)
else:
exitcode = await session.simple_noderange_command(noderange,
'attributes/all',
value, key)
sys.exit(exitcode)
elif dictassign:
for key in dictassign:
if nodetype == 'nodegroups':
exitcode = await session.simple_nodegroups_command(
noderange, 'attributes/all', dictassign[key], key)
else:
exitcode = await session.simple_noderange_command(
noderange, 'attributes/all', dictassign[key], key)
else:
if "=" in updateargs[1]:
update_ready = True
for arg in updateargs[1:]:
if not '=' in arg:
update_ready = False
exitcode = 1
if not update_ready:
sys.stderr.write('Error: {0} Can not set and read at the same time!\n'.format(str(updateargs[1:])))
sys.exit(exitcode)
try:
for val in updateargs[1:]:
val = val.split('=', 1)
if val[0][-1] in (',', '-', '^'):
key = val[0][:-1]
if val[0][-1] == ',':
value = {'prepend': val[1]}
elif val[0][-1] in ('-', '^'):
value = {'remove': val[1]}
else:
key = val[0]
value = val[1]
if (nodetype == "nodegroups"):
exitcode = await session.simple_nodegroups_command(noderange, 'attributes/all',
value, key)
else:
exitcode = await session.simple_noderange_command(noderange, 'attributes/all',
value, key)
except Exception:
sys.stderr.write('Error: {0} not a valid expression\n'.format(str(updateargs[1:])))
exitcode = 1
sys.exit(exitcode)
return exitcode
# So we try to prevent bad things from happening when globbing
# We tried to head this off at the shell, but the various solutions would end
# up breaking the shell in various ways (breaking pipe capability if using
# DEBUG, breaking globbing if in pipe, etc)
# Then we tried to parse the original commandline instead, however shlex isn't
# going to parse full bourne language (e.g. knowing that '|' and '>' and
# a world of other things would not be in our command line
# so finally, just make sure the noderange appears verbatim in the command line
# if we glob to something, then bash will change noderange and this should
# detect it and save the user from tragedy
def check_globbing(noderange):
if not os.path.exists(noderange):
return True
rawargs = os.environ.get('CURRENT_CMDLINE', None)
if rawargs:
rawargs = shlex.split(rawargs)
for arg in rawargs:
if arg.startswith('$'):
arg = arg[1:]
if arg.endswith(';'):
arg = arg[:-1]
arg = os.environ.get(arg, '$' + arg)
if arg.startswith(noderange):
break
else:
sys.stderr.write(
'Shell glob conflict detected, specified target "{0}" '
'not in command line, but is a file. You can use "set -f" in '
'bash or change directories such that there is no filename '
'that would conflict.'
'\n'.format(noderange))
sys.exit(1)

View File

@@ -0,0 +1,318 @@
# vim: tabstop=4 shiftwidth=4 softtabstop=4
# Copyright 2014 IBM Corporation
# Copyright 2015 Lenovo
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
import array
import asyncio
import ctypes
import ctypes.util
import confluent.tlv as tlv
import socket
from datetime import datetime
import json
import os
import struct
try:
unicode
except NameError:
unicode = str
try:
range = xrange
except NameError:
pass
class iovec(ctypes.Structure): # from uio.h
_fields_ = [('iov_base', ctypes.c_void_p),
('iov_len', ctypes.c_size_t)]
iovec_ptr = ctypes.POINTER(iovec)
class cmsghdr(ctypes.Structure): # also from bits/socket.h
_fields_ = [('cmsg_len', ctypes.c_size_t),
('cmsg_level', ctypes.c_int),
('cmsg_type', ctypes.c_int)]
@classmethod
def init_data(cls, cmsg_len, cmsg_level, cmsg_type, cmsg_data):
Data = ctypes.c_ubyte * ctypes.sizeof(cmsg_data)
class _flexhdr(ctypes.Structure):
_fields_ = cls._fields_ + [('cmsg_data', Data)]
datab = Data(*bytearray(cmsg_data))
return _flexhdr(cmsg_len=cmsg_len, cmsg_level=cmsg_level,
cmsg_type=cmsg_type, cmsg_data=datab)
def CMSG_LEN(length):
sizeof_cmshdr = ctypes.sizeof(cmsghdr)
return ctypes.c_size_t(CMSG_ALIGN(sizeof_cmshdr).value + length)
SCM_RIGHTS = 1
class msghdr(ctypes.Structure): # from bits/socket.h
_fields_ = [('msg_name', ctypes.c_void_p),
('msg_namelen', ctypes.c_uint),
('msg_iov', ctypes.POINTER(iovec)),
('msg_iovlen', ctypes.c_size_t),
('msg_control', ctypes.c_void_p),
('msg_controllen', ctypes.c_size_t),
('msg_flags', ctypes.c_int)]
def CMSG_ALIGN(length): # bits/socket.h
ret = (length + ctypes.sizeof(ctypes.c_size_t) - 1
& ~(ctypes.sizeof(ctypes.c_size_t) - 1))
return ctypes.c_size_t(ret)
def CMSG_SPACE(length): # bits/socket.h
ret = CMSG_ALIGN(length).value + CMSG_ALIGN(ctypes.sizeof(cmsghdr)).value
return ctypes.c_size_t(ret)
class ClientFile(object):
def __init__(self, name, mode, fd):
self.fileobject = os.fdopen(fd, mode)
self.filename = name
def _sendmsg(loop, fut, sock, msg, fds, rfd):
if rfd is not None:
loop.remove_reader(rfd)
if fut.cancelled():
return
try:
retdata = sock.sendmsg(
[msg],
[(socket.SOL_SOCKET, socket.SCM_RIGHTS, array.array("i", fds))])
except (BlockingIOError, InterruptedError):
fd = sock.fileno()
loop.add_reader(fd, _sendmsg, loop, fut, sock, fd)
except Exception as exc:
fut.set_exception(exc)
else:
fut.set_result(retdata)
def send_fds(sock, msg, fds):
cloop = asyncio.get_event_loop()
fut = cloop.create_future()
_sendmsg(cloop, fut, sock, msg, fds, None)
return fut
def _recvmsg(loop, fut, sock, msglen, maxfds, rfd):
if rfd is not None:
loop.remove_reader(rfd)
fds = array.array("i") # Array of ints
try:
msg, ancdata, flags, addr = sock.recvmsg(
msglen, socket.CMSG_LEN(maxfds * fds.itemsize))
except (BlockingIOError, InterruptedError):
fd = sock.fileno()
loop.add_reader(fd, _recvmsg, loop, fut, sock, fd)
except Exception as exc:
fut.set_exception(exc)
else:
for cmsg_level, cmsg_type, cmsg_data in ancdata:
if (cmsg_level == socket.SOL_SOCKET
and cmsg_type == socket.SCM_RIGHTS):
# Append data, ignoring any truncated integers at the end.
fds.frombytes(
cmsg_data[
:len(cmsg_data) - (len(cmsg_data) % fds.itemsize)])
fut.set_result(msglen, list(fds))
def recv_fds(sock, msglen, maxfds):
cloop = asyncio.get_event_loop()
fut = cloop.create_future()
_recvmsg(cloop, fut, sock, msglen, maxfds, None)
return fut
def decodestr(value):
ret = None
try:
ret = value.decode('utf-8')
except UnicodeDecodeError:
try:
ret = value.decode('cp437')
except UnicodeDecodeError:
ret = value
except AttributeError:
return value
return ret
def unicode_dictvalues(dictdata):
for key in dictdata:
if isinstance(dictdata[key], bytes):
dictdata[key] = decodestr(dictdata[key])
elif isinstance(dictdata[key], datetime):
dictdata[key] = dictdata[key].strftime('%Y-%m-%dT%H:%M:%S')
elif isinstance(dictdata[key], list):
_unicode_list(dictdata[key])
elif isinstance(dictdata[key], dict):
unicode_dictvalues(dictdata[key])
def _unicode_list(currlist):
for i in range(len(currlist)):
if isinstance(currlist[i], str):
currlist[i] = decodestr(currlist[i])
elif isinstance(currlist[i], dict):
unicode_dictvalues(currlist[i])
elif isinstance(currlist[i], list):
_unicode_list(currlist[i])
async def sendall(handle, data):
if isinstance(handle, tuple):
handle[1].write(data)
return await handle[1].drain()
else:
cloop = asyncio.get_event_loop()
return await cloop.sock_sendall(handle, data)
async def send(handle, data, filehandle=None):
cloop = asyncio.get_event_loop()
if isinstance(data, unicode):
try:
data = data.encode('utf-8')
except AttributeError:
pass
if isinstance(data, bytes) or isinstance(data, unicode):
# plain text, e.g. console data
tl = len(data)
if tl == 0:
# if you don't have anything to say, don't say anything at all
return
if tl < 16777216:
# type for string is '0', so we don't need
# to xor anything in
await sendall(handle, struct.pack("!I", tl))
else:
raise Exception("String data length exceeds protocol")
await sendall(handle, data)
elif isinstance(data, dict): # JSON currently only goes to 4 bytes
# Some structured message, like what would be seen in http responses
unicode_dictvalues(data) # make everything unicode, assuming UTF-8
sdata = json.dumps(data, ensure_ascii=False, separators=(',', ':'))
sdata = sdata.encode('utf-8')
tl = len(sdata)
if tl > 16777215:
raise Exception("JSON data exceeds protocol limits")
# xor in the type (0b1 << 24)
if filehandle is None:
tl |= 16777216
await sendall(handle, struct.pack("!I", tl))
await sendall(handle, sdata)
elif isinstance(handle, tuple):
raise Exception("Cannot send filehandle over network socket")
else:
tl |= (2 << 24)
await cloop.sock_sendall(handle, struct.pack("!I", tl))
await send_fds(handle, b'', [filehandle])
async def _grabhdl(handle, size):
if isinstance(handle, tuple):
return await handle[0].read(size)
else:
cloop = asyncio.get_event_loop()
return await cloop.sock_recv(handle, size)
async def recvall(handle, size):
rd = await _grabhdl(handle, size)
while len(rd) < size:
nd = await _grabhdl(handle, size - len(rd))
if not nd:
raise Exception("Error reading data")
rd += nd
return rd
async def recv(handle):
tl = await _grabhdl(handle, 4)
if not tl:
return None
while len(tl) < 4:
ndata = await _grabhdl(handle, 4 - len(tl))
if not ndata:
raise Exception("Error reading data")
tl += ndata
if len(tl) == 0:
return None
tl = struct.unpack("!I", tl)[0]
if tl & 0b10000000000000000000000000000000:
raise Exception("Protocol Violation, reserved bit set")
# 4 byte tlv
dlen = tl & 16777215 # grab lower 24 bits
datatype = (tl & 2130706432) >> 24 # grab 7 bits from near beginning
if dlen == 0:
return None
if datatype == tlv.Types.filehandle:
if isinstance(handle, tuple):
raise Exception('Filehandle not supported over TLS socket')
filehandles = array.array('i')
rawbuffer = bytearray(2048)
pkttype = ctypes.c_ubyte * 2048
data = pkttype.from_buffer(rawbuffer)
cmsgsize = CMSG_SPACE(ctypes.sizeof(ctypes.c_int)).value
cmsgarr = bytearray(cmsgsize)
cmtype = ctypes.c_ubyte * cmsgsize
cmsg = cmtype.from_buffer(cmsgarr)
cmsg.cmsg_level = socket.SOL_SOCKET
cmsg.cmsg_type = SCM_RIGHTS
cmsg.cmsg_len = CMSG_LEN(ctypes.sizeof(ctypes.c_int))
iov = iovec()
iov.iov_base = ctypes.addressof(data)
iov.iov_len = 2048
msg = msghdr()
msg.msg_iov = ctypes.pointer(iov)
msg.msg_iovlen = 1
msg.msg_control = ctypes.addressof(cmsg)
msg.msg_controllen = ctypes.sizeof(cmsg)
i = await recv_fds(handle, 2048, 4)
print(repr(i))
data = i[0]
filehandles = i[1]
data = json.loads(bytes(data))
return ClientFile(data['filename'], data['mode'], filehandles[0])
else:
data = await _grabhdl(handle, dlen)
while len(data) < dlen:
ndata = await _grabhdl(handle, dlen - len(data))
if not ndata:
raise Exception("Error reading data")
data += ndata
if datatype == tlv.Types.text:
return data
elif datatype == tlv.Types.json:
return json.loads(data)

View File

@@ -15,10 +15,10 @@
# See the License for the specific language governing permissions and
# limitations under the License.
import asyncio
import ctypes
import ctypes.util
import dbm
try:
import anydbm as dbm
except ImportError:
import dbm
import csv
import errno
import fnmatch
@@ -30,9 +30,6 @@ import ssl
import sys
import confluent.tlvdata as tlvdata
import confluent.sortutil as sortutil
libssl = ctypes.CDLL(ctypes.util.find_library('ssl'))
libssl.SSL_CTX_set_cert_verify_callback.argtypes = [
ctypes.c_void_p, ctypes.c_void_p, ctypes.c_void_p]
SO_PASSCRED = 16
@@ -50,26 +47,6 @@ except NameError:
getinput = input
class PyObject_HEAD(ctypes.Structure):
_fields_ = [
("ob_refcnt", ctypes.c_ssize_t),
("ob_type", ctypes.c_void_p),
]
# see main/Modules/_ssl.c, only caring about the SSL_CTX pointer
class PySSLContext(ctypes.Structure):
_fields_ = [
("ob_base", PyObject_HEAD),
("ctx", ctypes.c_void_p),
]
@ctypes.CFUNCTYPE(ctypes.c_int, ctypes.c_void_p, ctypes.c_void_p)
def verify_stub(store, misc):
return 1
class NestedDict(dict):
def __missing__(self, key):
value = self[key] = type(self)()
@@ -85,7 +62,6 @@ def stringify(instr):
return instr.encode('utf-8')
return instr
class Tabulator(object):
def __init__(self, headers):
self.headers = headers
@@ -145,8 +121,7 @@ def printerror(res, node=None):
for node in res.get('databynode', {}):
exitcode = res['databynode'][node].get('errorcode', exitcode)
if 'error' in res['databynode'][node]:
sys.stderr.write(
'{0}: {1}\n'.format(node, res['databynode'][node]['error']))
sys.stderr.write('{0}: {1}\n'.format(node, res['databynode'][node]['error']))
if exitcode == 0:
exitcode = 1
if 'error' in res:
@@ -194,21 +169,16 @@ class Command(object):
self.serverloc = '/var/run/confluent/api.sock'
else:
self.serverloc = server
self.connected = False
async def ensure_connected(self):
if self.connected:
return True
if os.path.isabs(self.serverloc) and os.path.exists(self.serverloc):
self._connect_unix()
self.unixdomain = True
elif self.serverloc == '/var/run/confluent/api.sock':
raise Exception('Confluent service is not available')
else:
await self._connect_tls()
self.protversion = int((await tlvdata.recv(self.connection)).split(
self._connect_tls()
self.protversion = int(tlvdata.recv(self.connection).split(
b'--')[1].strip()[1:])
authdata = await tlvdata.recv(self.connection)
authdata = tlvdata.recv(self.connection)
if authdata['authpassed'] == 1:
self.authenticated = True
else:
@@ -216,21 +186,19 @@ class Command(object):
if not self.authenticated and 'CONFLUENT_USER' in os.environ:
username = os.environ['CONFLUENT_USER']
passphrase = os.environ['CONFLUENT_PASSPHRASE']
await self.authenticate(username, passphrase)
self.connected = True
self.authenticate(username, passphrase)
async def add_file(self, name, handle, mode):
await self.ensure_connected()
def add_file(self, name, handle, mode):
if self.protversion < 3:
raise Exception('Not supported with connected confluent server')
if not self.unixdomain:
raise Exception('Can only add a file to a unix domain connection')
tlvdata.send(self.connection, {'filename': name, 'mode': mode}, handle)
async def authenticate(self, username, password):
await tlvdata.send(self.connection,
{'username': username, 'password': password})
authdata = await tlvdata.recv(self.connection)
def authenticate(self, username, password):
tlvdata.send(self.connection,
{'username': username, 'password': password})
authdata = tlvdata.recv(self.connection)
if authdata['authpassed'] == 1:
self.authenticated = True
@@ -281,7 +249,7 @@ class Command(object):
outhandler(node, res)
return rc
async def simple_noderange_command(self, noderange, resource, input=None,
def simple_noderange_command(self, noderange, resource, input=None,
key=None, errnodes=None, promptover=None, outhandler=None, **kwargs):
try:
self._currnoderange = noderange
@@ -294,13 +262,13 @@ class Command(object):
else:
ikey = key
if input is None:
async for res in self.read('/noderange/{0}/{1}'.format(
for res in self.read('/noderange/{0}/{1}'.format(
noderange, resource)):
rc = self.handle_results(ikey, rc, res, errnodes, outhandler)
else:
await self.stop_if_noderange_over(noderange, promptover)
self.stop_if_noderange_over(noderange, promptover)
kwargs[ikey] = input
async for res in self.update('/noderange/{0}/{1}'.format(
for res in self.update('/noderange/{0}/{1}'.format(
noderange, resource), kwargs):
rc = self.handle_results(ikey, rc, res, errnodes, outhandler)
self._currnoderange = None
@@ -309,14 +277,14 @@ class Command(object):
cprint('')
return 0
async def stop_if_noderange_over(self, noderange, maxnodes):
def stop_if_noderange_over(self, noderange, maxnodes):
if maxnodes is None:
return
nsize = await self.get_noderange_size(noderange)
nsize = self.get_noderange_size(noderange)
if nsize > maxnodes:
if nsize == 1:
nodename = [x async for x in self.read(
'/noderange/{0}/nodes/'.format(noderange))][0].get('item', {}).get('href', None)
nodename = list(self.read(
'/noderange/{0}/nodes/'.format(noderange)))[0].get('item', {}).get('href', None)
nodename = nodename[:-1]
p = getinput('Command is about to affect node {0}, continue (y/n)? '.format(nodename))
else:
@@ -327,16 +295,16 @@ class Command(object):
raise Exception("Aborting at user request")
async def get_noderange_size(self, noderange):
def get_noderange_size(self, noderange):
numnodes = 0
async for node in self.read('/noderange/{0}/nodes/'.format(noderange)):
for node in self.read('/noderange/{0}/nodes/'.format(noderange)):
if node.get('item', {}).get('href', None):
numnodes += 1
else:
raise Exception("Error trying to size noderange {0}".format(noderange))
return numnodes
async def simple_nodegroups_command(self, noderange, resource, input=None, key=None, **kwargs):
def simple_nodegroups_command(self, noderange, resource, input=None, key=None, **kwargs):
try:
rc = 0
if resource[0] == '/':
@@ -347,12 +315,12 @@ class Command(object):
else:
ikey = key
if input is None:
for res in await self.read('/nodegroups/{0}/{1}'.format(
for res in self.read('/nodegroups/{0}/{1}'.format(
noderange, resource)):
rc = self.handle_results(ikey, rc, res)
else:
kwargs[ikey] = input
for res in await self.update('/nodegroups/{0}/{1}'.format(
for res in self.update('/nodegroups/{0}/{1}'.format(
noderange, resource), kwargs):
rc = self.handle_results(ikey, rc, res)
return rc
@@ -360,44 +328,32 @@ class Command(object):
cprint('')
return 0
async def read(self, path, parameters=None):
await self.ensure_connected()
def read(self, path, parameters=None):
if not self.authenticated:
raise Exception('Unauthenticated')
async for rsp in send_request(
'retrieve', path, self.connection, parameters):
yield rsp
return send_request('retrieve', path, self.connection, parameters)
async def update(self, path, parameters=None):
await self.ensure_connected()
def update(self, path, parameters=None):
if not self.authenticated:
raise Exception('Unauthenticated')
async for rsp in send_request(
'update', path, self.connection, parameters):
yield rsp
return send_request('update', path, self.connection, parameters)
async def create(self, path, parameters=None):
await self.ensure_connected()
def create(self, path, parameters=None):
if not self.authenticated:
raise Exception('Unauthenticated')
async for rsp in send_request(
'create', path, self.connection, parameters):
yield rsp
return send_request('create', path, self.connection, parameters)
async def delete(self, path, parameters=None):
await self.ensure_connected()
def delete(self, path, parameters=None):
if not self.authenticated:
raise Exception('Unauthenticated')
async for rsp in send_request(
'delete', path, self.connection, parameters):
yield rsp
return send_request('delete', path, self.connection, parameters)
def _connect_unix(self):
self.connection = socket.socket(socket.AF_UNIX, socket.SOCK_STREAM)
self.connection.setsockopt(socket.SOL_SOCKET, SO_PASSCRED, 1)
self.connection.connect(self.serverloc)
async def _connect_tls(self):
def _connect_tls(self):
server, port = _parseserver(self.serverloc)
for res in socket.getaddrinfo(server, port, socket.AF_UNSPEC,
socket.SOCK_STREAM):
@@ -412,7 +368,7 @@ class Command(object):
try:
self.connection.settimeout(5)
self.connection.connect(sa)
self.connection.settimeout(0)
self.connection.settimeout(None)
except:
raise
self.connection.close()
@@ -435,21 +391,10 @@ class Command(object):
cacert = None
certreqs = ssl.CERT_NONE
knownhosts = True
ctx = ssl.SSLContext(ssl.PROTOCOL_TLSv1_2)
ssl_ctx = PySSLContext.from_address(id(ctx)).ctx
libssl.SSL_CTX_set_cert_verify_callback(ssl_ctx, verify_stub, 0)
sreader = asyncio.StreamReader()
sreaderprot = asyncio.StreamReaderProtocol(sreader)
cloop = asyncio.get_event_loop()
tport, _ = await cloop.create_connection(
lambda: sreaderprot, sock=self.connection, ssl=ctx, server_hostname='x')
swriter = asyncio.StreamWriter(tport, sreaderprot, sreader, cloop)
self.connection = (sreader, swriter)
#self.connection = ssl.wrap_socket(self.connection, ca_certs=cacert,
# cert_reqs=certreqs)
self.connection = ssl.wrap_socket(self.connection, ca_certs=cacert,
cert_reqs=certreqs)
if knownhosts:
certdata = tport.get_extra_info('ssl_object').getpeercert(binary_form=True)
# certdata = self.connection.getpeercert(binary_form=True)
certdata = self.connection.getpeercert(binary_form=True)
fingerprint = 'sha512$' + hashlib.sha512(certdata).hexdigest()
fingerprint = fingerprint.encode('utf-8')
hostid = '@'.join((port, server))
@@ -466,7 +411,7 @@ class Command(object):
khf[hostid] = fingerprint
async def send_request(operation, path, server, parameters=None):
def send_request(operation, path, server, parameters=None):
"""This function iterates over all the responses
received from the server.
@@ -479,16 +424,16 @@ async def send_request(operation, path, server, parameters=None):
payload = {'operation': operation, 'path': path}
if parameters is not None:
payload['parameters'] = parameters
await tlvdata.send(server, payload)
result = await tlvdata.recv(server)
tlvdata.send(server, payload)
result = tlvdata.recv(server)
while '_requestdone' not in result:
try:
yield result
except GeneratorExit:
while '_requestdone' not in result:
result = await tlvdata.recv(server)
result = tlvdata.recv(server)
raise
result = await tlvdata.recv(server)
result = tlvdata.recv(server)
def attrrequested(attr, attrlist, seenattributes, node=None):
@@ -513,20 +458,20 @@ def attrrequested(attr, attrlist, seenattributes, node=None):
return False
async def printattributes(session, requestargs, showtype, nodetype, noderange, options):
def printattributes(session, requestargs, showtype, nodetype, noderange, options):
path = '/{0}/{1}/attributes/{2}'.format(nodetype, noderange, showtype)
return await print_attrib_path(path, session, requestargs, options)
return print_attrib_path(path, session, requestargs, options)
def _sort_attrib(k):
if isinstance(k[1], dict) and k[1].get('sortid', None) is not None:
return k[1]['sortid']
return k[0]
async def print_attrib_path(path, session, requestargs, options, rename=None, attrprefix=None):
def print_attrib_path(path, session, requestargs, options, rename=None, attrprefix=None):
exitcode = 0
seenattributes = NestedDict()
allnodes = set([])
async for res in session.read(path):
for res in session.read(path):
if 'error' in res:
sys.stderr.write(res['error'] + '\n')
exitcode = 1
@@ -714,7 +659,7 @@ def printgroupattributes(session, requestargs, showtype, nodetype, noderange, op
exitcode = 1
return exitcode
async def updateattrib(session, updateargs, nodetype, noderange, options, dictassign=None):
def updateattrib(session, updateargs, nodetype, noderange, options, dictassign=None):
# update attribute
exitcode = 0
if options.clear:
@@ -722,7 +667,7 @@ async def updateattrib(session, updateargs, nodetype, noderange, options, dictas
keydata = {}
for attrib in updateargs[1:]:
keydata[attrib] = None
async for res in session.update(targpath, keydata):
for res in session.update(targpath, keydata):
for node in res.get('databynode', {}):
for warnmsg in res['databynode'][node].get('_warnings', []):
sys.stderr.write('Warning: ' + warnmsg + '\n')
@@ -742,21 +687,21 @@ async def updateattrib(session, updateargs, nodetype, noderange, options, dictas
value = os.environ.get(
key, os.environ[key.upper()])
if (nodetype == "nodegroups"):
exitcode = await session.simple_nodegroups_command(noderange,
exitcode = session.simple_nodegroups_command(noderange,
'attributes/all',
value, key)
else:
exitcode = await session.simple_noderange_command(noderange,
exitcode = session.simple_noderange_command(noderange,
'attributes/all',
value, key)
sys.exit(exitcode)
elif dictassign:
for key in dictassign:
if nodetype == 'nodegroups':
exitcode = await session.simple_nodegroups_command(
exitcode = session.simple_nodegroups_command(
noderange, 'attributes/all', dictassign[key], key)
else:
exitcode = await session.simple_noderange_command(
exitcode = session.simple_noderange_command(
noderange, 'attributes/all', dictassign[key], key)
else:
if "=" in updateargs[1]:
@@ -781,12 +726,12 @@ async def updateattrib(session, updateargs, nodetype, noderange, options, dictas
key = val[0]
value = val[1]
if (nodetype == "nodegroups"):
exitcode = await session.simple_nodegroups_command(noderange, 'attributes/all',
exitcode = session.simple_nodegroups_command(noderange, 'attributes/all',
value, key)
else:
exitcode = await session.simple_noderange_command(noderange, 'attributes/all',
exitcode = session.simple_noderange_command(noderange, 'attributes/all',
value, key)
except Exception:
except:
sys.stderr.write('Error: {0} not a valid expression\n'.format(str(updateargs[1:])))
exitcode = 1
sys.exit(exitcode)

View File

@@ -16,11 +16,15 @@
# limitations under the License.
import array
import asyncio
import ctypes
import ctypes.util
import confluent.tlv as tlv
import socket
try:
import eventlet.green.socket as socket
import eventlet.green.select as select
except ImportError:
import socket
import select
from datetime import datetime
import json
import os
@@ -36,7 +40,6 @@ try:
except NameError:
pass
class iovec(ctypes.Structure): # from uio.h
_fields_ = [('iov_base', ctypes.c_void_p),
('iov_len', ctypes.c_size_t)]
@@ -53,7 +56,6 @@ class cmsghdr(ctypes.Structure): # also from bits/socket.h
@classmethod
def init_data(cls, cmsg_len, cmsg_level, cmsg_type, cmsg_data):
Data = ctypes.c_ubyte * ctypes.sizeof(cmsg_data)
class _flexhdr(ctypes.Structure):
_fields_ = cls._fields_ + [('cmsg_data', Data)]
@@ -96,63 +98,13 @@ class ClientFile(object):
self.fileobject = os.fdopen(fd, mode)
self.filename = name
def _sendmsg(loop, fut, sock, msg, fds, rfd):
if rfd is not None:
loop.remove_reader(rfd)
if fut.cancelled():
return
try:
retdata = sock.sendmsg(
[msg],
[(socket.SOL_SOCKET, socket.SCM_RIGHTS, array.array("i", fds))])
except (BlockingIOError, InterruptedError):
fd = sock.fileno()
loop.add_reader(fd, _sendmsg, loop, fut, sock, fd)
except Exception as exc:
fut.set_exception(exc)
else:
fut.set_result(retdata)
def send_fds(sock, msg, fds):
cloop = asyncio.get_event_loop()
fut = cloop.create_future()
_sendmsg(cloop, fut, sock, msg, fds, None)
return fut
def _recvmsg(loop, fut, sock, msglen, maxfds, rfd):
if rfd is not None:
loop.remove_reader(rfd)
fds = array.array("i") # Array of ints
try:
msg, ancdata, flags, addr = sock.recvmsg(
msglen, socket.CMSG_LEN(maxfds * fds.itemsize))
except (BlockingIOError, InterruptedError):
fd = sock.fileno()
loop.add_reader(fd, _recvmsg, loop, fut, sock, fd)
except Exception as exc:
fut.set_exception(exc)
else:
for cmsg_level, cmsg_type, cmsg_data in ancdata:
if (cmsg_level == socket.SOL_SOCKET
and cmsg_type == socket.SCM_RIGHTS):
# Append data, ignoring any truncated integers at the end.
fds.frombytes(
cmsg_data[
:len(cmsg_data) - (len(cmsg_data) % fds.itemsize)])
fut.set_result(msglen, list(fds))
def recv_fds(sock, msglen, maxfds):
cloop = asyncio.get_event_loop()
fut = cloop.create_future()
_recvmsg(cloop, fut, sock, msglen, maxfds, None)
return fut
libc = ctypes.CDLL(ctypes.util.find_library('c'))
recvmsg = libc.recvmsg
recvmsg.argtypes = [ctypes.c_int, ctypes.POINTER(msghdr), ctypes.c_int]
recvmsg.restype = ctypes.c_int
sendmsg = libc.sendmsg
sendmsg.argtypes = [ctypes.c_int, ctypes.POINTER(msghdr), ctypes.c_int]
sendmsg.restype = ctypes.c_size_t
def decodestr(value):
ret = None
@@ -167,7 +119,6 @@ def decodestr(value):
return value
return ret
def unicode_dictvalues(dictdata):
for key in dictdata:
if isinstance(dictdata[key], bytes):
@@ -190,17 +141,7 @@ def _unicode_list(currlist):
_unicode_list(currlist[i])
async def sendall(handle, data):
if isinstance(handle, tuple):
handle[1].write(data)
return await handle[1].drain()
else:
cloop = asyncio.get_event_loop()
return await cloop.sock_sendall(handle, data)
async def send(handle, data, filehandle=None):
cloop = asyncio.get_event_loop()
def send(handle, data, filehandle=None):
if isinstance(data, unicode):
try:
data = data.encode('utf-8')
@@ -215,10 +156,10 @@ async def send(handle, data, filehandle=None):
if tl < 16777216:
# type for string is '0', so we don't need
# to xor anything in
await sendall(handle, struct.pack("!I", tl))
handle.sendall(struct.pack("!I", tl))
else:
raise Exception("String data length exceeds protocol")
await sendall(handle, data)
handle.sendall(data)
elif isinstance(data, dict): # JSON currently only goes to 4 bytes
# Some structured message, like what would be seen in http responses
unicode_dictvalues(data) # make everything unicode, assuming UTF-8
@@ -230,40 +171,41 @@ async def send(handle, data, filehandle=None):
# xor in the type (0b1 << 24)
if filehandle is None:
tl |= 16777216
await sendall(handle, struct.pack("!I", tl))
await sendall(handle, sdata)
elif isinstance(handle, tuple):
raise Exception("Cannot send filehandle over network socket")
handle.sendall(struct.pack("!I", tl))
handle.sendall(sdata)
else:
tl |= (2 << 24)
await cloop.sock_sendall(handle, struct.pack("!I", tl))
await send_fds(handle, b'', [filehandle])
handle.sendall(struct.pack("!I", tl))
cdtype = ctypes.c_ubyte * len(sdata)
cdata = cdtype.from_buffer(bytearray(sdata))
ciov = iovec(iov_base=ctypes.addressof(cdata),
iov_len=ctypes.c_size_t(ctypes.sizeof(cdata)))
fd = ctypes.c_int(filehandle)
cmh = cmsghdr.init_data(
cmsg_len=CMSG_LEN(
ctypes.sizeof(fd)), cmsg_level=socket.SOL_SOCKET,
cmsg_type=SCM_RIGHTS, cmsg_data=fd)
mh = msghdr(msg_name=None, msg_len=0, msg_iov=iovec_ptr(ciov),
msg_iovlen=1, msg_control=ctypes.addressof(cmh),
msg_controllen=ctypes.c_size_t(ctypes.sizeof(cmh)))
sendmsg(handle.fileno(), mh, 0)
async def _grabhdl(handle, size):
if isinstance(handle, tuple):
return await handle[0].read(size)
else:
cloop = asyncio.get_event_loop()
return await cloop.sock_recv(handle, size)
async def recvall(handle, size):
rd = await _grabhdl(handle, size)
def recvall(handle, size):
rd = handle.recv(size)
while len(rd) < size:
nd = await _grabhdl(handle, size - len(rd))
nd = handle.recv(size - len(rd))
if not nd:
raise Exception("Error reading data")
rd += nd
return rd
async def recv(handle):
tl = await _grabhdl(handle, 4)
def recv(handle):
tl = handle.recv(4)
if not tl:
return None
while len(tl) < 4:
ndata = await _grabhdl(handle, 4 - len(tl))
ndata = handle.recv(4 - len(tl))
if not ndata:
raise Exception("Error reading data")
tl += ndata
@@ -278,8 +220,6 @@ async def recv(handle):
if dlen == 0:
return None
if datatype == tlv.Types.filehandle:
if isinstance(handle, tuple):
raise Exception('Filehandle not supported over TLS socket')
filehandles = array.array('i')
rawbuffer = bytearray(2048)
pkttype = ctypes.c_ubyte * 2048
@@ -299,16 +239,23 @@ async def recv(handle):
msg.msg_iovlen = 1
msg.msg_control = ctypes.addressof(cmsg)
msg.msg_controllen = ctypes.sizeof(cmsg)
i = await recv_fds(handle, 2048, 4)
print(repr(i))
data = i[0]
filehandles = i[1]
select.select([handle], [], [])
i = recvmsg(handle.fileno(), ctypes.pointer(msg), 0)
cdata = cmsgarr[CMSG_LEN(0).value:]
data = rawbuffer[:i]
if cmsg.cmsg_level == socket.SOL_SOCKET and cmsg.cmsg_type == SCM_RIGHTS:
try:
filehandles.fromstring(bytes(
cdata[:len(cdata) - len(cdata) % filehandles.itemsize]))
except AttributeError:
filehandles.frombytes(bytes(
cdata[:len(cdata) - len(cdata) % filehandles.itemsize]))
data = json.loads(bytes(data))
return ClientFile(data['filename'], data['mode'], filehandles[0])
else:
data = await _grabhdl(handle, dlen)
data = handle.recv(dlen)
while len(data) < dlen:
ndata = await _grabhdl(handle, dlen - len(data))
ndata = handle.recv(dlen - len(data))
if not ndata:
raise Exception("Error reading data")
data += ndata

View File

@@ -21,7 +21,7 @@ import confluent.config.configmanager as cfm
import confluent.exceptions as exc
import confluent.log as log
import confluent.noderange as noderange
import confluent.tlvdata as tlvdata
import confluent.asynctlvdata as tlvdata
import confluent.util as util
import socket
import ssl

View File

@@ -37,7 +37,7 @@ import asyncio
import confluent
import confluent.alerts as alerts
import confluent.log as log
import confluent.tlvdata as tlvdata
import confluent.asynctlvdata as tlvdata
import confluent.config.attributes as attrscheme
import confluent.config.configmanager as cfm
import confluent.collective.manager as collective

View File

@@ -30,7 +30,7 @@ from aiohttp import web, web_urldispatcher, connector, ClientSession, WSMsgType
import confluent.auth as auth
import confluent.config.attributes as attribs
import confluent.config.configmanager as configmanager
import confluent.consoleserver as consoleserver
#import confluent.consoleserver as consoleserver
import confluent.discovery.core as disco
import confluent.forwarder as forwarder
import confluent.exceptions as exc
@@ -40,7 +40,7 @@ import confluent.core as pluginapi
import confluent.asynchttp
import confluent.selfservice as selfservice
import confluent.shellserver as shellserver
import confluent.tlvdata
import confluent.asynctlvdata as tlvdata
import confluent.util as util
import copy
import json
@@ -52,7 +52,6 @@ try:
import urlparse
except ModuleNotFoundError:
import urllib.parse as urlparse
tlvdata = confluent.tlvdata
_cleaner = None

View File

@@ -38,7 +38,7 @@ import ssl
import confluent.auth as auth
import confluent.credserver as credserver
import confluent.config.conf as conf
import confluent.tlvdata as tlvdata
import confluent.asynctlvdata as tlvdata
#import confluent.consoleserver as consoleserver
import confluent.config.configmanager as configmanager
import confluent.exceptions as exc