mirror of
https://github.com/xcat2/confluent.git
synced 2025-08-29 14:28:18 +00:00
Have async and traditional client
Since a lot of the traditional client did not need async, make life easier by just having them in parallel for now. The server must use the async client, but the client applications can stick with the somewhat more straightforward synchronous client.
This commit is contained in:
@@ -41,7 +41,7 @@
|
||||
# esc-( would interfere with normal esc use too much
|
||||
# ~ I will not use for now...
|
||||
|
||||
import asyncio
|
||||
import math
|
||||
import getpass
|
||||
import optparse
|
||||
import os
|
||||
@@ -51,7 +51,6 @@ import signal
|
||||
import socket
|
||||
import struct
|
||||
import sys
|
||||
import concurrent.futures
|
||||
import time
|
||||
try:
|
||||
import fcntl
|
||||
@@ -235,23 +234,14 @@ session = None
|
||||
|
||||
def completer(text, state):
|
||||
try:
|
||||
with concurrent.futures.ThreadPoolExecutor() as executor:
|
||||
future = executor.submit(run_completer, text, state)
|
||||
return future.result()
|
||||
except Exception:
|
||||
return rcompleter(text, state)
|
||||
except:
|
||||
pass
|
||||
import traceback
|
||||
traceback.print_exc()
|
||||
#import traceback
|
||||
#traceback.print_exc()
|
||||
|
||||
|
||||
def run_completer(text, state):
|
||||
loop = asyncio.new_event_loop()
|
||||
asyncio.set_event_loop(loop)
|
||||
return loop.run_until_complete(
|
||||
rcompleter(text, state))
|
||||
|
||||
|
||||
async def rcompleter(text, state):
|
||||
def rcompleter(text, state):
|
||||
global candidates
|
||||
global valid_commands
|
||||
cline = readline.get_line_buffer()
|
||||
@@ -281,7 +271,7 @@ async def rcompleter(text, state):
|
||||
if candidates is None:
|
||||
candidates = []
|
||||
targpath = fullpath_target(lastarg)
|
||||
async for res in session.read(targpath):
|
||||
for res in session.read(targpath):
|
||||
if 'item' in res: # a link relation
|
||||
if type(res['item']) == dict:
|
||||
candidates.append(res['item']["href"])
|
||||
@@ -370,7 +360,7 @@ def print_result(res):
|
||||
print(output.encode('utf-8'))
|
||||
|
||||
|
||||
async def do_command(command, server):
|
||||
def do_command(command, server):
|
||||
global exitcode
|
||||
global target
|
||||
global currconsole
|
||||
@@ -409,7 +399,7 @@ async def do_command(command, server):
|
||||
target = otarget
|
||||
else:
|
||||
foundchild = False
|
||||
async for res in session.read(parentpath, server):
|
||||
for res in session.read(parentpath, server):
|
||||
try:
|
||||
if res['item']['href'] == childname:
|
||||
foundchild = True
|
||||
@@ -444,7 +434,7 @@ async def do_command(command, server):
|
||||
pass
|
||||
else:
|
||||
targpath = target
|
||||
async for res in session.read(targpath):
|
||||
for res in session.read(targpath):
|
||||
if 'item' in res: # a link relation
|
||||
if type(res['item']) == dict:
|
||||
print(res['item']["href"])
|
||||
@@ -484,9 +474,9 @@ async def do_command(command, server):
|
||||
startconsole(nodename)
|
||||
return
|
||||
elif argv[0] == 'set':
|
||||
await setvalues(argv[1:])
|
||||
setvalues(argv[1:])
|
||||
elif argv[0] == 'create':
|
||||
await createresource(argv[1:])
|
||||
createresource(argv[1:])
|
||||
elif argv[0] in ('rm', 'delete', 'remove'):
|
||||
delresource(argv[1])
|
||||
elif argv[0] in ('unset', 'clear'):
|
||||
@@ -501,7 +491,7 @@ def shutdown():
|
||||
tlvdata.send(session.connection, {'operation': 'shutdown', 'path': '/'})
|
||||
|
||||
|
||||
async def createresource(args):
|
||||
def createresource(args):
|
||||
resname = args[0]
|
||||
attribs = args[1:]
|
||||
keydata = parameterize_attribs(attribs)
|
||||
@@ -514,12 +504,12 @@ async def createresource(args):
|
||||
collection, _, resname = targpath.rpartition('/')
|
||||
if 'name' not in keydata:
|
||||
keydata['name'] = resname
|
||||
await makecall(session.create, (collection, keydata))
|
||||
makecall(session.create, (collection, keydata))
|
||||
|
||||
|
||||
async def makecall(callout, args):
|
||||
def makecall(callout, args):
|
||||
global exitcode
|
||||
async for response in callout(*args):
|
||||
for response in callout(*args):
|
||||
if 'deleted' in response:
|
||||
print("Deleted: " + response['deleted'])
|
||||
if 'created' in response:
|
||||
@@ -550,12 +540,12 @@ def clearvalues(resource, attribs):
|
||||
sys.stderr.write('Error: ' + res['error'] + '\n')
|
||||
|
||||
|
||||
async def delresource(resname):
|
||||
def delresource(resname):
|
||||
resname = fullpath_target(resname)
|
||||
await makecall(session.delete, (resname,))
|
||||
makecall(session.delete, (resname,))
|
||||
|
||||
|
||||
async def setvalues(attribs):
|
||||
def setvalues(attribs):
|
||||
global exitcode
|
||||
if '=' in attribs[0]: # going straight to attribute
|
||||
resource = attribs[0][:attribs[0].index("=")]
|
||||
@@ -569,7 +559,7 @@ async def setvalues(attribs):
|
||||
if not keydata:
|
||||
return
|
||||
targpath = fullpath_target(resource)
|
||||
async for res in session.update(targpath, keydata):
|
||||
for res in session.update(targpath, keydata):
|
||||
if 'error' in res:
|
||||
if 'errorcode' in res:
|
||||
exitcode = res['errorcode']
|
||||
@@ -864,7 +854,7 @@ opts, shellargs = parser.parse_args()
|
||||
|
||||
username = None
|
||||
passphrase = None
|
||||
async def server_connect():
|
||||
def server_connect():
|
||||
global session, username, passphrase
|
||||
if opts.controlpath:
|
||||
termhandler.TermHandler(opts.controlpath)
|
||||
@@ -874,7 +864,7 @@ async def server_connect():
|
||||
session = client.Command(os.environ['CONFLUENT_HOST'])
|
||||
else: # unix domain
|
||||
session = client.Command()
|
||||
await session.ensure_connected()
|
||||
|
||||
# Next stop, reading and writing from whichever of stdin and server goes first.
|
||||
#see pyghmi code for solconnect.py
|
||||
if not session.authenticated and username is not None:
|
||||
@@ -900,14 +890,11 @@ if sys.stdout.isatty():
|
||||
import readline
|
||||
|
||||
|
||||
async def main():
|
||||
def main():
|
||||
global inconsole
|
||||
global consoleonly
|
||||
global doexit
|
||||
global doexit
|
||||
try:
|
||||
await server_connect()
|
||||
except (EOFError, KeyboardInterrupt):
|
||||
server_connect()
|
||||
except (EOFError, KeyboardInterrupt) as _:
|
||||
raise BailOut(0)
|
||||
except socket.gaierror:
|
||||
sys.stderr.write('Could not connect to confluent\n')
|
||||
@@ -929,13 +916,14 @@ async def main():
|
||||
|
||||
doexit = False
|
||||
inconsole = False
|
||||
pendingcommand = ""
|
||||
session_node = get_session_node(shellargs)
|
||||
if session_node is not None:
|
||||
consoleonly = True
|
||||
await do_command("start /nodes/%s/console/session" % session_node, netserver)
|
||||
do_command("start /nodes/%s/console/session" % session_node, netserver)
|
||||
doexit = True
|
||||
elif shellargs:
|
||||
await do_command(shellargs, netserver)
|
||||
do_command(shellargs, netserver)
|
||||
quitconfetty(fullexit=True, fixterm=False)
|
||||
|
||||
powerstate = None
|
||||
@@ -966,7 +954,7 @@ async def main():
|
||||
else:
|
||||
currcommand = prompt()
|
||||
try:
|
||||
await do_command(currcommand, netserver)
|
||||
do_command(currcommand, netserver)
|
||||
except socket.error:
|
||||
try:
|
||||
server_connect()
|
||||
@@ -1041,10 +1029,10 @@ if __name__ == '__main__':
|
||||
if opts.mintime:
|
||||
deadline = os.times()[4] + float(opts.mintime)
|
||||
try:
|
||||
asyncio.get_event_loop().run_until_complete(main())
|
||||
main()
|
||||
except BailOut as e:
|
||||
errcode = e.errorcode
|
||||
except Exception:
|
||||
except Exception as e:
|
||||
import traceback
|
||||
excinfo = traceback.print_exc()
|
||||
try:
|
||||
|
@@ -15,7 +15,6 @@
|
||||
# See the License for the specific language governing permissions and
|
||||
# limitations under the License.
|
||||
|
||||
import asyncio
|
||||
import optparse
|
||||
import os
|
||||
import signal
|
||||
@@ -86,8 +85,4 @@ if options.previous:
|
||||
def outhandler(node, res):
|
||||
for k in res[node]:
|
||||
client.cprint('{0}: {1}: {2}'.format(node, k.replace('inlet_', ''), res[node][k]))
|
||||
async def main():
|
||||
sys.exit(await session.simple_noderange_command(noderange, '/power/{0}'.format(powurl), setstate, promptover=options.maxnodes, key='state', outhandler=outhandler))
|
||||
|
||||
if __name__ == '__main__':
|
||||
asyncio.get_event_loop().run_until_complete(main())
|
||||
sys.exit(session.simple_noderange_command(noderange, '/power/{0}'.format(powurl), setstate, promptover=options.maxnodes, key='state', outhandler=outhandler))
|
||||
|
827
confluent_client/confluent/asynclient.py
Normal file
827
confluent_client/confluent/asynclient.py
Normal file
@@ -0,0 +1,827 @@
|
||||
# vim: tabstop=4 shiftwidth=4 softtabstop=4
|
||||
|
||||
# Copyright 2014 IBM Corporation
|
||||
# Copyright 2015-2019 Lenovo
|
||||
#
|
||||
# Licensed under the Apache License, Version 2.0 (the "License");
|
||||
# you may not use this file except in compliance with the License.
|
||||
# You may obtain a copy of the License at
|
||||
#
|
||||
# http://www.apache.org/licenses/LICENSE-2.0
|
||||
#
|
||||
# Unless required by applicable law or agreed to in writing, software
|
||||
# distributed under the License is distributed on an "AS IS" BASIS,
|
||||
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
# See the License for the specific language governing permissions and
|
||||
# limitations under the License.
|
||||
|
||||
import asyncio
|
||||
import ctypes
|
||||
import ctypes.util
|
||||
import dbm
|
||||
import csv
|
||||
import errno
|
||||
import fnmatch
|
||||
import hashlib
|
||||
import os
|
||||
import shlex
|
||||
import socket
|
||||
import ssl
|
||||
import sys
|
||||
import confluent.asynctlvdata as tlvdata
|
||||
import confluent.sortutil as sortutil
|
||||
libssl = ctypes.CDLL(ctypes.util.find_library('ssl'))
|
||||
libssl.SSL_CTX_set_cert_verify_callback.argtypes = [
|
||||
ctypes.c_void_p, ctypes.c_void_p, ctypes.c_void_p]
|
||||
|
||||
SO_PASSCRED = 16
|
||||
|
||||
_attraliases = {
|
||||
'bmc': 'hardwaremanagement.manager',
|
||||
'bmcuser': 'secret.hardwaremanagementuser',
|
||||
'switchuser': 'secret.hardwaremanagementuser',
|
||||
'bmcpass': 'secret.hardwaremanagementpassword',
|
||||
'switchpass': 'secret.hardwaremanagementpassword',
|
||||
}
|
||||
|
||||
try:
|
||||
getinput = raw_input
|
||||
except NameError:
|
||||
getinput = input
|
||||
|
||||
|
||||
class PyObject_HEAD(ctypes.Structure):
|
||||
_fields_ = [
|
||||
("ob_refcnt", ctypes.c_ssize_t),
|
||||
("ob_type", ctypes.c_void_p),
|
||||
]
|
||||
|
||||
|
||||
# see main/Modules/_ssl.c, only caring about the SSL_CTX pointer
|
||||
class PySSLContext(ctypes.Structure):
|
||||
_fields_ = [
|
||||
("ob_base", PyObject_HEAD),
|
||||
("ctx", ctypes.c_void_p),
|
||||
]
|
||||
|
||||
|
||||
@ctypes.CFUNCTYPE(ctypes.c_int, ctypes.c_void_p, ctypes.c_void_p)
|
||||
def verify_stub(store, misc):
|
||||
return 1
|
||||
|
||||
|
||||
class NestedDict(dict):
|
||||
def __missing__(self, key):
|
||||
value = self[key] = type(self)()
|
||||
return value
|
||||
|
||||
|
||||
def stringify(instr):
|
||||
# Normalize unicode and bytes to 'str', correcting for
|
||||
# current python version
|
||||
if isinstance(instr, bytes) and not isinstance(instr, str):
|
||||
return instr.decode('utf-8')
|
||||
elif not isinstance(instr, bytes) and not isinstance(instr, str):
|
||||
return instr.encode('utf-8')
|
||||
return instr
|
||||
|
||||
|
||||
class Tabulator(object):
|
||||
def __init__(self, headers):
|
||||
self.headers = headers
|
||||
self.rows = []
|
||||
|
||||
def add_row(self, row):
|
||||
self.rows.append(row)
|
||||
|
||||
def get_table(self, order=None):
|
||||
i = 0
|
||||
fmtstr = ''
|
||||
separator = []
|
||||
for head in self.headers:
|
||||
if order and order == head:
|
||||
order = i
|
||||
neededlen = len(head)
|
||||
for row in self.rows:
|
||||
if len(row[i]) > neededlen:
|
||||
neededlen = len(row[i])
|
||||
separator.append('-' * (neededlen + 1))
|
||||
fmtstr += '{{{0}:>{1}}}|'.format(i, neededlen + 1)
|
||||
i = i + 1
|
||||
fmtstr = fmtstr[:-1]
|
||||
yield fmtstr.format(*self.headers)
|
||||
yield fmtstr.format(*separator)
|
||||
if order is not None:
|
||||
for row in sorted(
|
||||
self.rows,
|
||||
key=lambda x: sortutil.naturalize_string(x[order])):
|
||||
yield fmtstr.format(*row)
|
||||
else:
|
||||
for row in self.rows:
|
||||
yield fmtstr.format(*row)
|
||||
|
||||
def write_csv(self, output, order=None):
|
||||
output = csv.writer(output)
|
||||
output.writerow(self.headers)
|
||||
i = 0
|
||||
for head in self.headers:
|
||||
if order and order == head:
|
||||
order = i
|
||||
i = i + 1
|
||||
if order is not None:
|
||||
for row in sorted(
|
||||
self.rows,
|
||||
key=lambda x: sortutil.naturalize_string(x[order])):
|
||||
output.writerow(row)
|
||||
else:
|
||||
for row in self.rows:
|
||||
output.writerow(row)
|
||||
|
||||
|
||||
def printerror(res, node=None):
|
||||
exitcode = 0
|
||||
if 'errorcode' in res:
|
||||
exitcode = res['errorcode']
|
||||
for node in res.get('databynode', {}):
|
||||
exitcode = res['databynode'][node].get('errorcode', exitcode)
|
||||
if 'error' in res['databynode'][node]:
|
||||
sys.stderr.write(
|
||||
'{0}: {1}\n'.format(node, res['databynode'][node]['error']))
|
||||
if exitcode == 0:
|
||||
exitcode = 1
|
||||
if 'error' in res:
|
||||
if node:
|
||||
sys.stderr.write('{0}: {1}\n'.format(node, res['error']))
|
||||
else:
|
||||
sys.stderr.write('{0}\n'.format(res['error']))
|
||||
if 'errorcode' not in res:
|
||||
exitcode = 1
|
||||
return exitcode
|
||||
|
||||
|
||||
def cprint(txt):
|
||||
try:
|
||||
print(txt)
|
||||
except UnicodeEncodeError:
|
||||
print(txt.encode('utf8'))
|
||||
sys.stdout.flush()
|
||||
|
||||
def _parseserver(string):
|
||||
if ']:' in string:
|
||||
server, port = string[1:].split(']:')
|
||||
elif string[0] == '[':
|
||||
server = string[1:-1]
|
||||
port = '13001'
|
||||
elif ':' in string:
|
||||
server, port = string.split(':')
|
||||
else:
|
||||
server = string
|
||||
port = '13001'
|
||||
return server, port
|
||||
|
||||
|
||||
class Command(object):
|
||||
def __init__(self, server=None):
|
||||
self._prevdict = None
|
||||
self._prevkeyname = None
|
||||
self.connection = None
|
||||
self._currnoderange = None
|
||||
self.unixdomain = False
|
||||
if server is None:
|
||||
if 'CONFLUENT_HOST' in os.environ:
|
||||
self.serverloc = os.environ['CONFLUENT_HOST']
|
||||
else:
|
||||
self.serverloc = '/var/run/confluent/api.sock'
|
||||
else:
|
||||
self.serverloc = server
|
||||
self.connected = False
|
||||
|
||||
async def ensure_connected(self):
|
||||
if self.connected:
|
||||
return True
|
||||
if os.path.isabs(self.serverloc) and os.path.exists(self.serverloc):
|
||||
self._connect_unix()
|
||||
self.unixdomain = True
|
||||
elif self.serverloc == '/var/run/confluent/api.sock':
|
||||
raise Exception('Confluent service is not available')
|
||||
else:
|
||||
await self._connect_tls()
|
||||
self.protversion = int((await tlvdata.recv(self.connection)).split(
|
||||
b'--')[1].strip()[1:])
|
||||
authdata = await tlvdata.recv(self.connection)
|
||||
if authdata['authpassed'] == 1:
|
||||
self.authenticated = True
|
||||
else:
|
||||
self.authenticated = False
|
||||
if not self.authenticated and 'CONFLUENT_USER' in os.environ:
|
||||
username = os.environ['CONFLUENT_USER']
|
||||
passphrase = os.environ['CONFLUENT_PASSPHRASE']
|
||||
await self.authenticate(username, passphrase)
|
||||
self.connected = True
|
||||
|
||||
async def add_file(self, name, handle, mode):
|
||||
await self.ensure_connected()
|
||||
if self.protversion < 3:
|
||||
raise Exception('Not supported with connected confluent server')
|
||||
if not self.unixdomain:
|
||||
raise Exception('Can only add a file to a unix domain connection')
|
||||
tlvdata.send(self.connection, {'filename': name, 'mode': mode}, handle)
|
||||
|
||||
async def authenticate(self, username, password):
|
||||
await tlvdata.send(self.connection,
|
||||
{'username': username, 'password': password})
|
||||
authdata = await tlvdata.recv(self.connection)
|
||||
if authdata['authpassed'] == 1:
|
||||
self.authenticated = True
|
||||
|
||||
def add_precede_key(self, keyname):
|
||||
self._prevkeyname = keyname
|
||||
|
||||
def add_precede_dict(self, dict):
|
||||
self._prevdict = dict
|
||||
|
||||
def handle_results(self, ikey, rc, res, errnodes=None, outhandler=None):
|
||||
if 'error' in res:
|
||||
if errnodes is not None:
|
||||
errnodes.add(self._currnoderange)
|
||||
sys.stderr.write('Error: {0}\n'.format(res['error']))
|
||||
if 'errorcode' in res:
|
||||
return res['errorcode']
|
||||
else:
|
||||
return 1
|
||||
if 'databynode' not in res:
|
||||
return 0
|
||||
res = res['databynode']
|
||||
for node in res:
|
||||
if 'error' in res[node]:
|
||||
if errnodes is not None:
|
||||
errnodes.add(node)
|
||||
sys.stderr.write('{0}: Error: {1}\n'.format(
|
||||
node, res[node]['error']))
|
||||
if 'errorcode' in res[node]:
|
||||
rc |= res[node]['errorcode']
|
||||
else:
|
||||
rc |= 1
|
||||
elif ikey in res[node]:
|
||||
if 'value' in res[node][ikey]:
|
||||
val = res[node][ikey]['value']
|
||||
elif 'isset' in res[node][ikey]:
|
||||
val = '********' if res[node][ikey] else ''
|
||||
else:
|
||||
val = repr(res[node][ikey])
|
||||
if self._prevkeyname and self._prevkeyname in res[node]:
|
||||
cprint('{0}: {2}->{1}'.format(
|
||||
node, val, res[node][self._prevkeyname]['value']))
|
||||
elif self._prevdict and node in self._prevdict:
|
||||
cprint('{0}: {2}->{1}'.format(
|
||||
node, val, self._prevdict[node]))
|
||||
else:
|
||||
cprint('{0}: {1}'.format(node, val))
|
||||
elif outhandler:
|
||||
outhandler(node, res)
|
||||
return rc
|
||||
|
||||
async def simple_noderange_command(self, noderange, resource, input=None,
|
||||
key=None, errnodes=None, promptover=None, outhandler=None, **kwargs):
|
||||
try:
|
||||
self._currnoderange = noderange
|
||||
rc = 0
|
||||
if resource[0] == '/':
|
||||
resource = resource[1:]
|
||||
# The implicit key is the resource basename
|
||||
if key is None:
|
||||
ikey = resource.rpartition('/')[-1]
|
||||
else:
|
||||
ikey = key
|
||||
if input is None:
|
||||
async for res in self.read('/noderange/{0}/{1}'.format(
|
||||
noderange, resource)):
|
||||
rc = self.handle_results(ikey, rc, res, errnodes, outhandler)
|
||||
else:
|
||||
await self.stop_if_noderange_over(noderange, promptover)
|
||||
kwargs[ikey] = input
|
||||
async for res in self.update('/noderange/{0}/{1}'.format(
|
||||
noderange, resource), kwargs):
|
||||
rc = self.handle_results(ikey, rc, res, errnodes, outhandler)
|
||||
self._currnoderange = None
|
||||
return rc
|
||||
except KeyboardInterrupt:
|
||||
cprint('')
|
||||
return 0
|
||||
|
||||
async def stop_if_noderange_over(self, noderange, maxnodes):
|
||||
if maxnodes is None:
|
||||
return
|
||||
nsize = await self.get_noderange_size(noderange)
|
||||
if nsize > maxnodes:
|
||||
if nsize == 1:
|
||||
nodename = [x async for x in self.read(
|
||||
'/noderange/{0}/nodes/'.format(noderange))][0].get('item', {}).get('href', None)
|
||||
nodename = nodename[:-1]
|
||||
p = getinput('Command is about to affect node {0}, continue (y/n)? '.format(nodename))
|
||||
else:
|
||||
p = getinput('Command is about to affect {0} nodes, continue (y/n)? '.format(nsize))
|
||||
if p.lower() != 'y':
|
||||
sys.stderr.write('Aborting at user request\n')
|
||||
sys.exit(1)
|
||||
raise Exception("Aborting at user request")
|
||||
|
||||
|
||||
async def get_noderange_size(self, noderange):
|
||||
numnodes = 0
|
||||
async for node in self.read('/noderange/{0}/nodes/'.format(noderange)):
|
||||
if node.get('item', {}).get('href', None):
|
||||
numnodes += 1
|
||||
else:
|
||||
raise Exception("Error trying to size noderange {0}".format(noderange))
|
||||
return numnodes
|
||||
|
||||
async def simple_nodegroups_command(self, noderange, resource, input=None, key=None, **kwargs):
|
||||
try:
|
||||
rc = 0
|
||||
if resource[0] == '/':
|
||||
resource = resource[1:]
|
||||
# The implicit key is the resource basename
|
||||
if key is None:
|
||||
ikey = resource.rpartition('/')[-1]
|
||||
else:
|
||||
ikey = key
|
||||
if input is None:
|
||||
for res in await self.read('/nodegroups/{0}/{1}'.format(
|
||||
noderange, resource)):
|
||||
rc = self.handle_results(ikey, rc, res)
|
||||
else:
|
||||
kwargs[ikey] = input
|
||||
for res in await self.update('/nodegroups/{0}/{1}'.format(
|
||||
noderange, resource), kwargs):
|
||||
rc = self.handle_results(ikey, rc, res)
|
||||
return rc
|
||||
except KeyboardInterrupt:
|
||||
cprint('')
|
||||
return 0
|
||||
|
||||
async def read(self, path, parameters=None):
|
||||
await self.ensure_connected()
|
||||
if not self.authenticated:
|
||||
raise Exception('Unauthenticated')
|
||||
async for rsp in send_request(
|
||||
'retrieve', path, self.connection, parameters):
|
||||
yield rsp
|
||||
|
||||
async def update(self, path, parameters=None):
|
||||
await self.ensure_connected()
|
||||
if not self.authenticated:
|
||||
raise Exception('Unauthenticated')
|
||||
async for rsp in send_request(
|
||||
'update', path, self.connection, parameters):
|
||||
yield rsp
|
||||
|
||||
async def create(self, path, parameters=None):
|
||||
await self.ensure_connected()
|
||||
if not self.authenticated:
|
||||
raise Exception('Unauthenticated')
|
||||
async for rsp in send_request(
|
||||
'create', path, self.connection, parameters):
|
||||
yield rsp
|
||||
|
||||
async def delete(self, path, parameters=None):
|
||||
await self.ensure_connected()
|
||||
if not self.authenticated:
|
||||
raise Exception('Unauthenticated')
|
||||
async for rsp in send_request(
|
||||
'delete', path, self.connection, parameters):
|
||||
yield rsp
|
||||
|
||||
def _connect_unix(self):
|
||||
self.connection = socket.socket(socket.AF_UNIX, socket.SOCK_STREAM)
|
||||
self.connection.setsockopt(socket.SOL_SOCKET, SO_PASSCRED, 1)
|
||||
self.connection.connect(self.serverloc)
|
||||
|
||||
async def _connect_tls(self):
|
||||
server, port = _parseserver(self.serverloc)
|
||||
for res in socket.getaddrinfo(server, port, socket.AF_UNSPEC,
|
||||
socket.SOCK_STREAM):
|
||||
af, socktype, proto, canonname, sa = res
|
||||
try:
|
||||
self.connection = socket.socket(af, socktype, proto)
|
||||
self.connection.setsockopt(
|
||||
socket.IPPROTO_TCP, socket.TCP_NODELAY, 1)
|
||||
except:
|
||||
self.connection = None
|
||||
continue
|
||||
try:
|
||||
self.connection.settimeout(5)
|
||||
self.connection.connect(sa)
|
||||
self.connection.settimeout(0)
|
||||
except:
|
||||
raise
|
||||
self.connection.close()
|
||||
self.connection = None
|
||||
continue
|
||||
break
|
||||
if self.connection is None:
|
||||
raise Exception("Failed to connect to %s" % self.serverloc)
|
||||
#TODO(jbjohnso): server certificate validation
|
||||
clientcfgdir = os.path.join(os.path.expanduser("~"), ".confluent")
|
||||
try:
|
||||
os.makedirs(clientcfgdir)
|
||||
except OSError as exc:
|
||||
if not (exc.errno == errno.EEXIST and os.path.isdir(clientcfgdir)):
|
||||
raise
|
||||
cacert = os.path.join(clientcfgdir, "ca.pem")
|
||||
certreqs = ssl.CERT_REQUIRED
|
||||
knownhosts = False
|
||||
if not os.path.exists(cacert):
|
||||
cacert = None
|
||||
certreqs = ssl.CERT_NONE
|
||||
knownhosts = True
|
||||
ctx = ssl.SSLContext(ssl.PROTOCOL_TLSv1_2)
|
||||
ssl_ctx = PySSLContext.from_address(id(ctx)).ctx
|
||||
libssl.SSL_CTX_set_cert_verify_callback(ssl_ctx, verify_stub, 0)
|
||||
sreader = asyncio.StreamReader()
|
||||
sreaderprot = asyncio.StreamReaderProtocol(sreader)
|
||||
cloop = asyncio.get_event_loop()
|
||||
tport, _ = await cloop.create_connection(
|
||||
lambda: sreaderprot, sock=self.connection, ssl=ctx, server_hostname='x')
|
||||
swriter = asyncio.StreamWriter(tport, sreaderprot, sreader, cloop)
|
||||
self.connection = (sreader, swriter)
|
||||
#self.connection = ssl.wrap_socket(self.connection, ca_certs=cacert,
|
||||
# cert_reqs=certreqs)
|
||||
if knownhosts:
|
||||
certdata = tport.get_extra_info('ssl_object').getpeercert(binary_form=True)
|
||||
# certdata = self.connection.getpeercert(binary_form=True)
|
||||
fingerprint = 'sha512$' + hashlib.sha512(certdata).hexdigest()
|
||||
fingerprint = fingerprint.encode('utf-8')
|
||||
hostid = '@'.join((port, server))
|
||||
khf = dbm.open(os.path.join(clientcfgdir, "knownhosts"), 'c', 384)
|
||||
if hostid in khf:
|
||||
if fingerprint == khf[hostid]:
|
||||
return
|
||||
else:
|
||||
replace = getinput(
|
||||
"MISMATCHED CERTIFICATE DATA, ACCEPT NEW? (y/n):")
|
||||
if replace not in ('y', 'Y'):
|
||||
raise Exception("BAD CERTIFICATE")
|
||||
cprint('Adding new key for %s:%s' % (server, port))
|
||||
khf[hostid] = fingerprint
|
||||
|
||||
|
||||
async def send_request(operation, path, server, parameters=None):
|
||||
"""This function iterates over all the responses
|
||||
received from the server.
|
||||
|
||||
:param operation: The operation to request, retrieve, update, delete,
|
||||
create, start, stop
|
||||
:param path: The URI path to the resource to operate on
|
||||
:param server: The socket to send data over
|
||||
:param parameters: Parameters if any to send along with the request
|
||||
"""
|
||||
payload = {'operation': operation, 'path': path}
|
||||
if parameters is not None:
|
||||
payload['parameters'] = parameters
|
||||
await tlvdata.send(server, payload)
|
||||
result = await tlvdata.recv(server)
|
||||
while '_requestdone' not in result:
|
||||
try:
|
||||
yield result
|
||||
except GeneratorExit:
|
||||
while '_requestdone' not in result:
|
||||
result = await tlvdata.recv(server)
|
||||
raise
|
||||
result = await tlvdata.recv(server)
|
||||
|
||||
|
||||
def attrrequested(attr, attrlist, seenattributes, node=None):
|
||||
for candidate in attrlist:
|
||||
truename = candidate
|
||||
if candidate.startswith('hm'):
|
||||
candidate = candidate.replace('hm', 'hardwaremanagement', 1)
|
||||
if candidate in _attraliases:
|
||||
candidate = _attraliases[candidate]
|
||||
if fnmatch.fnmatch(attr.lower(), candidate.lower()):
|
||||
if node is None:
|
||||
seenattributes.add(truename)
|
||||
else:
|
||||
seenattributes[node][truename] = True
|
||||
return True
|
||||
elif attr.lower().startswith(candidate.lower() + '.'):
|
||||
if node is None:
|
||||
seenattributes.add(truename)
|
||||
else:
|
||||
seenattributes[node][truename] = 1
|
||||
return True
|
||||
return False
|
||||
|
||||
|
||||
async def printattributes(session, requestargs, showtype, nodetype, noderange, options):
|
||||
path = '/{0}/{1}/attributes/{2}'.format(nodetype, noderange, showtype)
|
||||
return await print_attrib_path(path, session, requestargs, options)
|
||||
|
||||
def _sort_attrib(k):
|
||||
if isinstance(k[1], dict) and k[1].get('sortid', None) is not None:
|
||||
return k[1]['sortid']
|
||||
return k[0]
|
||||
|
||||
async def print_attrib_path(path, session, requestargs, options, rename=None, attrprefix=None):
|
||||
exitcode = 0
|
||||
seenattributes = NestedDict()
|
||||
allnodes = set([])
|
||||
async for res in session.read(path):
|
||||
if 'error' in res:
|
||||
sys.stderr.write(res['error'] + '\n')
|
||||
exitcode = 1
|
||||
continue
|
||||
for node in sorted(res['databynode']):
|
||||
allnodes.add(node)
|
||||
for attr, val in sorted(res['databynode'][node].items(), key=_sort_attrib):
|
||||
if attr == 'error':
|
||||
sys.stderr.write('{0}: Error: {1}\n'.format(node, val))
|
||||
continue
|
||||
if attr == 'errorcode':
|
||||
exitcode |= val
|
||||
continue
|
||||
seenattributes[node][attr] = True
|
||||
if rename:
|
||||
printattr = rename.get(attr, attr)
|
||||
else:
|
||||
printattr = attr
|
||||
if attrprefix:
|
||||
printattr = attrprefix + printattr
|
||||
currattr = res['databynode'][node][attr]
|
||||
if show_attr(attr, requestargs, seenattributes, options, node):
|
||||
if 'value' in currattr:
|
||||
if currattr['value'] is not None:
|
||||
val = currattr['value']
|
||||
if isinstance(val, list):
|
||||
val = ','.join(val)
|
||||
attrout = '{0}: {1}: {2}'.format(
|
||||
node, printattr, val).strip()
|
||||
else:
|
||||
attrout = '{0}: {1}:'.format(node, printattr)
|
||||
elif 'isset' in currattr:
|
||||
if currattr['isset']:
|
||||
attrout = '{0}: {1}: ********'.format(node,
|
||||
printattr)
|
||||
else:
|
||||
attrout = '{0}: {1}:'.format(node, printattr)
|
||||
elif isinstance(currattr, dict) and 'broken' in currattr:
|
||||
attrout = '{0}: {1}: *ERROR* BROKEN EXPRESSION: ' \
|
||||
'{2}'.format(node, printattr,
|
||||
currattr['broken'])
|
||||
elif isinstance(currattr, list) or isinstance(currattr, tuple):
|
||||
attrout = '{0}: {1}: {2}'.format(node, attr, ','.join(map(str, currattr)))
|
||||
elif isinstance(currattr, dict):
|
||||
dictout = []
|
||||
for k, v in currattr.items:
|
||||
dictout.append("{0}={1}".format(k, v))
|
||||
attrout = '{0}: {1}: {2}'.format(node, printattr, ','.join(map(str, dictout)))
|
||||
else:
|
||||
cprint("CODE ERROR" + repr(attr))
|
||||
try:
|
||||
blame = options.blame
|
||||
except AttributeError:
|
||||
blame = False
|
||||
if blame or (isinstance(currattr, dict) and 'broken' in currattr):
|
||||
blamedata = []
|
||||
if 'inheritedfrom' in currattr:
|
||||
blamedata.append('inherited from group {0}'.format(
|
||||
currattr['inheritedfrom']
|
||||
))
|
||||
if 'expression' in currattr:
|
||||
blamedata.append(
|
||||
'derived from expression "{0}"'.format(
|
||||
currattr['expression']))
|
||||
if blamedata:
|
||||
attrout += ' (' + ', '.join(blamedata) + ')'
|
||||
try:
|
||||
comparedefault = options.comparedefault
|
||||
except AttributeError:
|
||||
comparedefault = False
|
||||
if comparedefault:
|
||||
try:
|
||||
exclude = options.exclude
|
||||
except AttributeError:
|
||||
exclude = False
|
||||
if ((requestargs and not exclude) or
|
||||
(currattr.get('default', None) is not None and
|
||||
currattr.get('value', None) is not None and
|
||||
currattr['value'] != currattr['default'])):
|
||||
cval = ','.join(currattr['value']) if isinstance(
|
||||
currattr['value'], list) else currattr['value']
|
||||
dval = ','.join(currattr['default']) if isinstance(
|
||||
currattr['default'], list) else currattr['default']
|
||||
cprint('{0}: {1}: {2} (Default: {3})'.format(
|
||||
node, printattr, cval, dval))
|
||||
else:
|
||||
|
||||
try:
|
||||
details = options.detail
|
||||
except AttributeError:
|
||||
details = False
|
||||
if details:
|
||||
if currattr.get('help', None):
|
||||
attrout += u' (Help: {0})'.format(
|
||||
currattr['help'])
|
||||
if currattr.get('possible', None):
|
||||
try:
|
||||
attrout += u' (Choices: {0})'.format(
|
||||
','.join(currattr['possible']))
|
||||
except TypeError:
|
||||
pass
|
||||
cprint(attrout)
|
||||
somematched = set([])
|
||||
printmissing = set([])
|
||||
badnodes = NestedDict()
|
||||
if not exitcode:
|
||||
if requestargs:
|
||||
for attr in requestargs:
|
||||
for node in allnodes:
|
||||
if attr in seenattributes[node]:
|
||||
somematched.add(attr)
|
||||
else:
|
||||
badnodes[node][attr] = True
|
||||
exitcode = 1
|
||||
for node in sortutil.natural_sort(badnodes):
|
||||
for attr in badnodes[node]:
|
||||
if attr in somematched:
|
||||
sys.stderr.write(
|
||||
'Error: {0} matches no valid value for {1}\n'.format(
|
||||
attr, node))
|
||||
else:
|
||||
printmissing.add(attr)
|
||||
for missing in printmissing:
|
||||
sys.stderr.write('Error: {0} not a valid attribute\n'.format(missing))
|
||||
return exitcode
|
||||
|
||||
|
||||
def show_attr(attr, requestargs, seenattributes, options, node):
|
||||
try:
|
||||
reverse = options.exclude
|
||||
except AttributeError:
|
||||
reverse = False
|
||||
if requestargs is None or requestargs == []:
|
||||
return True
|
||||
processattr = attrrequested(attr, requestargs, seenattributes, node)
|
||||
if reverse:
|
||||
processattr = not processattr
|
||||
return processattr
|
||||
|
||||
|
||||
def printgroupattributes(session, requestargs, showtype, nodetype, noderange, options):
|
||||
exitcode = 0
|
||||
seenattributes = set([])
|
||||
for res in session.read('/{0}/{1}/attributes/{2}'.format(nodetype, noderange, showtype)):
|
||||
if 'error' in res:
|
||||
sys.stderr.write(res['error'] + '\n')
|
||||
exitcode = 1
|
||||
continue
|
||||
for attr in res:
|
||||
seenattributes.add(attr)
|
||||
currattr = res[attr]
|
||||
if (requestargs is None or requestargs == [] or attrrequested(attr, requestargs, seenattributes)):
|
||||
if 'value' in currattr:
|
||||
if currattr['value'] is not None:
|
||||
attrout = '{0}: {1}: {2}'.format(
|
||||
noderange, attr, currattr['value'])
|
||||
else:
|
||||
attrout = '{0}: {1}:'.format(noderange, attr)
|
||||
elif 'isset' in currattr:
|
||||
if currattr['isset']:
|
||||
attrout = '{0}: {1}: ********'.format(noderange, attr)
|
||||
else:
|
||||
attrout = '{0}: {1}:'.format(noderange, attr)
|
||||
elif isinstance(currattr, dict) and 'broken' in currattr:
|
||||
attrout = '{0}: {1}: *ERROR* BROKEN EXPRESSION: ' \
|
||||
'{2}'.format(noderange, attr,
|
||||
currattr['broken'])
|
||||
elif 'expression' in currattr:
|
||||
attrout = '{0}: {1}: (will derive from expression {2})'.format(noderange, attr, currattr['expression'])
|
||||
elif isinstance(currattr, list) or isinstance(currattr, tuple):
|
||||
attrout = '{0}: {1}: {2}'.format(noderange, attr, ','.join(map(str, currattr)))
|
||||
elif isinstance(currattr, dict):
|
||||
dictout = []
|
||||
for k, v in currattr.items:
|
||||
dictout.append("{0}={1}".format(k, v))
|
||||
attrout = '{0}: {1}: {2}'.format(noderange, attr, ','.join(map(str, dictout)))
|
||||
else:
|
||||
cprint("CODE ERROR" + repr(attr))
|
||||
cprint(attrout)
|
||||
if not exitcode:
|
||||
if requestargs:
|
||||
for attr in requestargs:
|
||||
if attr not in seenattributes:
|
||||
sys.stderr.write('Error: {0} not a valid attribute\n'.format(attr))
|
||||
exitcode = 1
|
||||
return exitcode
|
||||
|
||||
async def updateattrib(session, updateargs, nodetype, noderange, options, dictassign=None):
|
||||
# update attribute
|
||||
exitcode = 0
|
||||
if options.clear:
|
||||
targpath = '/{0}/{1}/attributes/all'.format(nodetype, noderange)
|
||||
keydata = {}
|
||||
for attrib in updateargs[1:]:
|
||||
keydata[attrib] = None
|
||||
async for res in session.update(targpath, keydata):
|
||||
for node in res.get('databynode', {}):
|
||||
for warnmsg in res['databynode'][node].get('_warnings', []):
|
||||
sys.stderr.write('Warning: ' + warnmsg + '\n')
|
||||
if 'error' in res:
|
||||
if 'errorcode' in res:
|
||||
exitcode = res['errorcode']
|
||||
sys.stderr.write('Error: ' + res['error'] + '\n')
|
||||
sys.exit(exitcode)
|
||||
elif hasattr(options, 'environment') and options.environment:
|
||||
for key in updateargs[1:]:
|
||||
key = key.replace('.', '_')
|
||||
value = os.environ.get(
|
||||
key, os.environ[key.upper()])
|
||||
# Let's do one pass to make sure that there's not a usage problem
|
||||
for key in updateargs[1:]:
|
||||
key = key.replace('.', '_')
|
||||
value = os.environ.get(
|
||||
key, os.environ[key.upper()])
|
||||
if (nodetype == "nodegroups"):
|
||||
exitcode = await session.simple_nodegroups_command(noderange,
|
||||
'attributes/all',
|
||||
value, key)
|
||||
else:
|
||||
exitcode = await session.simple_noderange_command(noderange,
|
||||
'attributes/all',
|
||||
value, key)
|
||||
sys.exit(exitcode)
|
||||
elif dictassign:
|
||||
for key in dictassign:
|
||||
if nodetype == 'nodegroups':
|
||||
exitcode = await session.simple_nodegroups_command(
|
||||
noderange, 'attributes/all', dictassign[key], key)
|
||||
else:
|
||||
exitcode = await session.simple_noderange_command(
|
||||
noderange, 'attributes/all', dictassign[key], key)
|
||||
else:
|
||||
if "=" in updateargs[1]:
|
||||
update_ready = True
|
||||
for arg in updateargs[1:]:
|
||||
if not '=' in arg:
|
||||
update_ready = False
|
||||
exitcode = 1
|
||||
if not update_ready:
|
||||
sys.stderr.write('Error: {0} Can not set and read at the same time!\n'.format(str(updateargs[1:])))
|
||||
sys.exit(exitcode)
|
||||
try:
|
||||
for val in updateargs[1:]:
|
||||
val = val.split('=', 1)
|
||||
if val[0][-1] in (',', '-', '^'):
|
||||
key = val[0][:-1]
|
||||
if val[0][-1] == ',':
|
||||
value = {'prepend': val[1]}
|
||||
elif val[0][-1] in ('-', '^'):
|
||||
value = {'remove': val[1]}
|
||||
else:
|
||||
key = val[0]
|
||||
value = val[1]
|
||||
if (nodetype == "nodegroups"):
|
||||
exitcode = await session.simple_nodegroups_command(noderange, 'attributes/all',
|
||||
value, key)
|
||||
else:
|
||||
exitcode = await session.simple_noderange_command(noderange, 'attributes/all',
|
||||
value, key)
|
||||
except Exception:
|
||||
sys.stderr.write('Error: {0} not a valid expression\n'.format(str(updateargs[1:])))
|
||||
exitcode = 1
|
||||
sys.exit(exitcode)
|
||||
return exitcode
|
||||
|
||||
|
||||
# So we try to prevent bad things from happening when globbing
|
||||
# We tried to head this off at the shell, but the various solutions would end
|
||||
# up breaking the shell in various ways (breaking pipe capability if using
|
||||
# DEBUG, breaking globbing if in pipe, etc)
|
||||
# Then we tried to parse the original commandline instead, however shlex isn't
|
||||
# going to parse full bourne language (e.g. knowing that '|' and '>' and
|
||||
# a world of other things would not be in our command line
|
||||
# so finally, just make sure the noderange appears verbatim in the command line
|
||||
# if we glob to something, then bash will change noderange and this should
|
||||
# detect it and save the user from tragedy
|
||||
def check_globbing(noderange):
|
||||
if not os.path.exists(noderange):
|
||||
return True
|
||||
rawargs = os.environ.get('CURRENT_CMDLINE', None)
|
||||
if rawargs:
|
||||
rawargs = shlex.split(rawargs)
|
||||
for arg in rawargs:
|
||||
if arg.startswith('$'):
|
||||
arg = arg[1:]
|
||||
if arg.endswith(';'):
|
||||
arg = arg[:-1]
|
||||
arg = os.environ.get(arg, '$' + arg)
|
||||
if arg.startswith(noderange):
|
||||
break
|
||||
else:
|
||||
sys.stderr.write(
|
||||
'Shell glob conflict detected, specified target "{0}" '
|
||||
'not in command line, but is a file. You can use "set -f" in '
|
||||
'bash or change directories such that there is no filename '
|
||||
'that would conflict.'
|
||||
'\n'.format(noderange))
|
||||
sys.exit(1)
|
318
confluent_client/confluent/asynctlvdata.py
Normal file
318
confluent_client/confluent/asynctlvdata.py
Normal file
@@ -0,0 +1,318 @@
|
||||
# vim: tabstop=4 shiftwidth=4 softtabstop=4
|
||||
|
||||
# Copyright 2014 IBM Corporation
|
||||
# Copyright 2015 Lenovo
|
||||
#
|
||||
# Licensed under the Apache License, Version 2.0 (the "License");
|
||||
# you may not use this file except in compliance with the License.
|
||||
# You may obtain a copy of the License at
|
||||
#
|
||||
# http://www.apache.org/licenses/LICENSE-2.0
|
||||
#
|
||||
# Unless required by applicable law or agreed to in writing, software
|
||||
# distributed under the License is distributed on an "AS IS" BASIS,
|
||||
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
# See the License for the specific language governing permissions and
|
||||
# limitations under the License.
|
||||
|
||||
import array
|
||||
import asyncio
|
||||
import ctypes
|
||||
import ctypes.util
|
||||
import confluent.tlv as tlv
|
||||
import socket
|
||||
from datetime import datetime
|
||||
import json
|
||||
import os
|
||||
import struct
|
||||
|
||||
try:
|
||||
unicode
|
||||
except NameError:
|
||||
unicode = str
|
||||
|
||||
try:
|
||||
range = xrange
|
||||
except NameError:
|
||||
pass
|
||||
|
||||
|
||||
class iovec(ctypes.Structure): # from uio.h
|
||||
_fields_ = [('iov_base', ctypes.c_void_p),
|
||||
('iov_len', ctypes.c_size_t)]
|
||||
|
||||
|
||||
iovec_ptr = ctypes.POINTER(iovec)
|
||||
|
||||
|
||||
class cmsghdr(ctypes.Structure): # also from bits/socket.h
|
||||
_fields_ = [('cmsg_len', ctypes.c_size_t),
|
||||
('cmsg_level', ctypes.c_int),
|
||||
('cmsg_type', ctypes.c_int)]
|
||||
|
||||
@classmethod
|
||||
def init_data(cls, cmsg_len, cmsg_level, cmsg_type, cmsg_data):
|
||||
Data = ctypes.c_ubyte * ctypes.sizeof(cmsg_data)
|
||||
|
||||
class _flexhdr(ctypes.Structure):
|
||||
_fields_ = cls._fields_ + [('cmsg_data', Data)]
|
||||
|
||||
datab = Data(*bytearray(cmsg_data))
|
||||
return _flexhdr(cmsg_len=cmsg_len, cmsg_level=cmsg_level,
|
||||
cmsg_type=cmsg_type, cmsg_data=datab)
|
||||
|
||||
|
||||
def CMSG_LEN(length):
|
||||
sizeof_cmshdr = ctypes.sizeof(cmsghdr)
|
||||
return ctypes.c_size_t(CMSG_ALIGN(sizeof_cmshdr).value + length)
|
||||
|
||||
|
||||
SCM_RIGHTS = 1
|
||||
|
||||
|
||||
class msghdr(ctypes.Structure): # from bits/socket.h
|
||||
_fields_ = [('msg_name', ctypes.c_void_p),
|
||||
('msg_namelen', ctypes.c_uint),
|
||||
('msg_iov', ctypes.POINTER(iovec)),
|
||||
('msg_iovlen', ctypes.c_size_t),
|
||||
('msg_control', ctypes.c_void_p),
|
||||
('msg_controllen', ctypes.c_size_t),
|
||||
('msg_flags', ctypes.c_int)]
|
||||
|
||||
|
||||
def CMSG_ALIGN(length): # bits/socket.h
|
||||
ret = (length + ctypes.sizeof(ctypes.c_size_t) - 1
|
||||
& ~(ctypes.sizeof(ctypes.c_size_t) - 1))
|
||||
return ctypes.c_size_t(ret)
|
||||
|
||||
|
||||
def CMSG_SPACE(length): # bits/socket.h
|
||||
ret = CMSG_ALIGN(length).value + CMSG_ALIGN(ctypes.sizeof(cmsghdr)).value
|
||||
return ctypes.c_size_t(ret)
|
||||
|
||||
|
||||
class ClientFile(object):
|
||||
def __init__(self, name, mode, fd):
|
||||
self.fileobject = os.fdopen(fd, mode)
|
||||
self.filename = name
|
||||
|
||||
|
||||
|
||||
|
||||
def _sendmsg(loop, fut, sock, msg, fds, rfd):
|
||||
if rfd is not None:
|
||||
loop.remove_reader(rfd)
|
||||
if fut.cancelled():
|
||||
return
|
||||
try:
|
||||
retdata = sock.sendmsg(
|
||||
[msg],
|
||||
[(socket.SOL_SOCKET, socket.SCM_RIGHTS, array.array("i", fds))])
|
||||
except (BlockingIOError, InterruptedError):
|
||||
fd = sock.fileno()
|
||||
loop.add_reader(fd, _sendmsg, loop, fut, sock, fd)
|
||||
except Exception as exc:
|
||||
fut.set_exception(exc)
|
||||
else:
|
||||
fut.set_result(retdata)
|
||||
|
||||
|
||||
def send_fds(sock, msg, fds):
|
||||
cloop = asyncio.get_event_loop()
|
||||
fut = cloop.create_future()
|
||||
_sendmsg(cloop, fut, sock, msg, fds, None)
|
||||
return fut
|
||||
|
||||
|
||||
def _recvmsg(loop, fut, sock, msglen, maxfds, rfd):
|
||||
if rfd is not None:
|
||||
loop.remove_reader(rfd)
|
||||
fds = array.array("i") # Array of ints
|
||||
try:
|
||||
msg, ancdata, flags, addr = sock.recvmsg(
|
||||
msglen, socket.CMSG_LEN(maxfds * fds.itemsize))
|
||||
except (BlockingIOError, InterruptedError):
|
||||
fd = sock.fileno()
|
||||
loop.add_reader(fd, _recvmsg, loop, fut, sock, fd)
|
||||
except Exception as exc:
|
||||
fut.set_exception(exc)
|
||||
else:
|
||||
for cmsg_level, cmsg_type, cmsg_data in ancdata:
|
||||
if (cmsg_level == socket.SOL_SOCKET
|
||||
and cmsg_type == socket.SCM_RIGHTS):
|
||||
# Append data, ignoring any truncated integers at the end.
|
||||
fds.frombytes(
|
||||
cmsg_data[
|
||||
:len(cmsg_data) - (len(cmsg_data) % fds.itemsize)])
|
||||
fut.set_result(msglen, list(fds))
|
||||
|
||||
|
||||
def recv_fds(sock, msglen, maxfds):
|
||||
cloop = asyncio.get_event_loop()
|
||||
fut = cloop.create_future()
|
||||
_recvmsg(cloop, fut, sock, msglen, maxfds, None)
|
||||
return fut
|
||||
|
||||
|
||||
def decodestr(value):
|
||||
ret = None
|
||||
try:
|
||||
ret = value.decode('utf-8')
|
||||
except UnicodeDecodeError:
|
||||
try:
|
||||
ret = value.decode('cp437')
|
||||
except UnicodeDecodeError:
|
||||
ret = value
|
||||
except AttributeError:
|
||||
return value
|
||||
return ret
|
||||
|
||||
|
||||
def unicode_dictvalues(dictdata):
|
||||
for key in dictdata:
|
||||
if isinstance(dictdata[key], bytes):
|
||||
dictdata[key] = decodestr(dictdata[key])
|
||||
elif isinstance(dictdata[key], datetime):
|
||||
dictdata[key] = dictdata[key].strftime('%Y-%m-%dT%H:%M:%S')
|
||||
elif isinstance(dictdata[key], list):
|
||||
_unicode_list(dictdata[key])
|
||||
elif isinstance(dictdata[key], dict):
|
||||
unicode_dictvalues(dictdata[key])
|
||||
|
||||
|
||||
def _unicode_list(currlist):
|
||||
for i in range(len(currlist)):
|
||||
if isinstance(currlist[i], str):
|
||||
currlist[i] = decodestr(currlist[i])
|
||||
elif isinstance(currlist[i], dict):
|
||||
unicode_dictvalues(currlist[i])
|
||||
elif isinstance(currlist[i], list):
|
||||
_unicode_list(currlist[i])
|
||||
|
||||
|
||||
async def sendall(handle, data):
|
||||
if isinstance(handle, tuple):
|
||||
handle[1].write(data)
|
||||
return await handle[1].drain()
|
||||
else:
|
||||
cloop = asyncio.get_event_loop()
|
||||
return await cloop.sock_sendall(handle, data)
|
||||
|
||||
|
||||
async def send(handle, data, filehandle=None):
|
||||
cloop = asyncio.get_event_loop()
|
||||
if isinstance(data, unicode):
|
||||
try:
|
||||
data = data.encode('utf-8')
|
||||
except AttributeError:
|
||||
pass
|
||||
if isinstance(data, bytes) or isinstance(data, unicode):
|
||||
# plain text, e.g. console data
|
||||
tl = len(data)
|
||||
if tl == 0:
|
||||
# if you don't have anything to say, don't say anything at all
|
||||
return
|
||||
if tl < 16777216:
|
||||
# type for string is '0', so we don't need
|
||||
# to xor anything in
|
||||
await sendall(handle, struct.pack("!I", tl))
|
||||
else:
|
||||
raise Exception("String data length exceeds protocol")
|
||||
await sendall(handle, data)
|
||||
elif isinstance(data, dict): # JSON currently only goes to 4 bytes
|
||||
# Some structured message, like what would be seen in http responses
|
||||
unicode_dictvalues(data) # make everything unicode, assuming UTF-8
|
||||
sdata = json.dumps(data, ensure_ascii=False, separators=(',', ':'))
|
||||
sdata = sdata.encode('utf-8')
|
||||
tl = len(sdata)
|
||||
if tl > 16777215:
|
||||
raise Exception("JSON data exceeds protocol limits")
|
||||
# xor in the type (0b1 << 24)
|
||||
if filehandle is None:
|
||||
tl |= 16777216
|
||||
await sendall(handle, struct.pack("!I", tl))
|
||||
await sendall(handle, sdata)
|
||||
elif isinstance(handle, tuple):
|
||||
raise Exception("Cannot send filehandle over network socket")
|
||||
else:
|
||||
tl |= (2 << 24)
|
||||
await cloop.sock_sendall(handle, struct.pack("!I", tl))
|
||||
await send_fds(handle, b'', [filehandle])
|
||||
|
||||
|
||||
async def _grabhdl(handle, size):
|
||||
if isinstance(handle, tuple):
|
||||
return await handle[0].read(size)
|
||||
else:
|
||||
cloop = asyncio.get_event_loop()
|
||||
return await cloop.sock_recv(handle, size)
|
||||
|
||||
|
||||
async def recvall(handle, size):
|
||||
rd = await _grabhdl(handle, size)
|
||||
while len(rd) < size:
|
||||
nd = await _grabhdl(handle, size - len(rd))
|
||||
if not nd:
|
||||
raise Exception("Error reading data")
|
||||
rd += nd
|
||||
return rd
|
||||
|
||||
|
||||
async def recv(handle):
|
||||
tl = await _grabhdl(handle, 4)
|
||||
if not tl:
|
||||
return None
|
||||
while len(tl) < 4:
|
||||
ndata = await _grabhdl(handle, 4 - len(tl))
|
||||
if not ndata:
|
||||
raise Exception("Error reading data")
|
||||
tl += ndata
|
||||
if len(tl) == 0:
|
||||
return None
|
||||
tl = struct.unpack("!I", tl)[0]
|
||||
if tl & 0b10000000000000000000000000000000:
|
||||
raise Exception("Protocol Violation, reserved bit set")
|
||||
# 4 byte tlv
|
||||
dlen = tl & 16777215 # grab lower 24 bits
|
||||
datatype = (tl & 2130706432) >> 24 # grab 7 bits from near beginning
|
||||
if dlen == 0:
|
||||
return None
|
||||
if datatype == tlv.Types.filehandle:
|
||||
if isinstance(handle, tuple):
|
||||
raise Exception('Filehandle not supported over TLS socket')
|
||||
filehandles = array.array('i')
|
||||
rawbuffer = bytearray(2048)
|
||||
pkttype = ctypes.c_ubyte * 2048
|
||||
data = pkttype.from_buffer(rawbuffer)
|
||||
cmsgsize = CMSG_SPACE(ctypes.sizeof(ctypes.c_int)).value
|
||||
cmsgarr = bytearray(cmsgsize)
|
||||
cmtype = ctypes.c_ubyte * cmsgsize
|
||||
cmsg = cmtype.from_buffer(cmsgarr)
|
||||
cmsg.cmsg_level = socket.SOL_SOCKET
|
||||
cmsg.cmsg_type = SCM_RIGHTS
|
||||
cmsg.cmsg_len = CMSG_LEN(ctypes.sizeof(ctypes.c_int))
|
||||
iov = iovec()
|
||||
iov.iov_base = ctypes.addressof(data)
|
||||
iov.iov_len = 2048
|
||||
msg = msghdr()
|
||||
msg.msg_iov = ctypes.pointer(iov)
|
||||
msg.msg_iovlen = 1
|
||||
msg.msg_control = ctypes.addressof(cmsg)
|
||||
msg.msg_controllen = ctypes.sizeof(cmsg)
|
||||
i = await recv_fds(handle, 2048, 4)
|
||||
print(repr(i))
|
||||
data = i[0]
|
||||
filehandles = i[1]
|
||||
data = json.loads(bytes(data))
|
||||
return ClientFile(data['filename'], data['mode'], filehandles[0])
|
||||
else:
|
||||
data = await _grabhdl(handle, dlen)
|
||||
while len(data) < dlen:
|
||||
ndata = await _grabhdl(handle, dlen - len(data))
|
||||
if not ndata:
|
||||
raise Exception("Error reading data")
|
||||
data += ndata
|
||||
if datatype == tlv.Types.text:
|
||||
return data
|
||||
elif datatype == tlv.Types.json:
|
||||
return json.loads(data)
|
@@ -15,10 +15,10 @@
|
||||
# See the License for the specific language governing permissions and
|
||||
# limitations under the License.
|
||||
|
||||
import asyncio
|
||||
import ctypes
|
||||
import ctypes.util
|
||||
import dbm
|
||||
try:
|
||||
import anydbm as dbm
|
||||
except ImportError:
|
||||
import dbm
|
||||
import csv
|
||||
import errno
|
||||
import fnmatch
|
||||
@@ -30,9 +30,6 @@ import ssl
|
||||
import sys
|
||||
import confluent.tlvdata as tlvdata
|
||||
import confluent.sortutil as sortutil
|
||||
libssl = ctypes.CDLL(ctypes.util.find_library('ssl'))
|
||||
libssl.SSL_CTX_set_cert_verify_callback.argtypes = [
|
||||
ctypes.c_void_p, ctypes.c_void_p, ctypes.c_void_p]
|
||||
|
||||
SO_PASSCRED = 16
|
||||
|
||||
@@ -50,26 +47,6 @@ except NameError:
|
||||
getinput = input
|
||||
|
||||
|
||||
class PyObject_HEAD(ctypes.Structure):
|
||||
_fields_ = [
|
||||
("ob_refcnt", ctypes.c_ssize_t),
|
||||
("ob_type", ctypes.c_void_p),
|
||||
]
|
||||
|
||||
|
||||
# see main/Modules/_ssl.c, only caring about the SSL_CTX pointer
|
||||
class PySSLContext(ctypes.Structure):
|
||||
_fields_ = [
|
||||
("ob_base", PyObject_HEAD),
|
||||
("ctx", ctypes.c_void_p),
|
||||
]
|
||||
|
||||
|
||||
@ctypes.CFUNCTYPE(ctypes.c_int, ctypes.c_void_p, ctypes.c_void_p)
|
||||
def verify_stub(store, misc):
|
||||
return 1
|
||||
|
||||
|
||||
class NestedDict(dict):
|
||||
def __missing__(self, key):
|
||||
value = self[key] = type(self)()
|
||||
@@ -85,7 +62,6 @@ def stringify(instr):
|
||||
return instr.encode('utf-8')
|
||||
return instr
|
||||
|
||||
|
||||
class Tabulator(object):
|
||||
def __init__(self, headers):
|
||||
self.headers = headers
|
||||
@@ -145,8 +121,7 @@ def printerror(res, node=None):
|
||||
for node in res.get('databynode', {}):
|
||||
exitcode = res['databynode'][node].get('errorcode', exitcode)
|
||||
if 'error' in res['databynode'][node]:
|
||||
sys.stderr.write(
|
||||
'{0}: {1}\n'.format(node, res['databynode'][node]['error']))
|
||||
sys.stderr.write('{0}: {1}\n'.format(node, res['databynode'][node]['error']))
|
||||
if exitcode == 0:
|
||||
exitcode = 1
|
||||
if 'error' in res:
|
||||
@@ -194,21 +169,16 @@ class Command(object):
|
||||
self.serverloc = '/var/run/confluent/api.sock'
|
||||
else:
|
||||
self.serverloc = server
|
||||
self.connected = False
|
||||
|
||||
async def ensure_connected(self):
|
||||
if self.connected:
|
||||
return True
|
||||
if os.path.isabs(self.serverloc) and os.path.exists(self.serverloc):
|
||||
self._connect_unix()
|
||||
self.unixdomain = True
|
||||
elif self.serverloc == '/var/run/confluent/api.sock':
|
||||
raise Exception('Confluent service is not available')
|
||||
else:
|
||||
await self._connect_tls()
|
||||
self.protversion = int((await tlvdata.recv(self.connection)).split(
|
||||
self._connect_tls()
|
||||
self.protversion = int(tlvdata.recv(self.connection).split(
|
||||
b'--')[1].strip()[1:])
|
||||
authdata = await tlvdata.recv(self.connection)
|
||||
authdata = tlvdata.recv(self.connection)
|
||||
if authdata['authpassed'] == 1:
|
||||
self.authenticated = True
|
||||
else:
|
||||
@@ -216,21 +186,19 @@ class Command(object):
|
||||
if not self.authenticated and 'CONFLUENT_USER' in os.environ:
|
||||
username = os.environ['CONFLUENT_USER']
|
||||
passphrase = os.environ['CONFLUENT_PASSPHRASE']
|
||||
await self.authenticate(username, passphrase)
|
||||
self.connected = True
|
||||
self.authenticate(username, passphrase)
|
||||
|
||||
async def add_file(self, name, handle, mode):
|
||||
await self.ensure_connected()
|
||||
def add_file(self, name, handle, mode):
|
||||
if self.protversion < 3:
|
||||
raise Exception('Not supported with connected confluent server')
|
||||
if not self.unixdomain:
|
||||
raise Exception('Can only add a file to a unix domain connection')
|
||||
tlvdata.send(self.connection, {'filename': name, 'mode': mode}, handle)
|
||||
|
||||
async def authenticate(self, username, password):
|
||||
await tlvdata.send(self.connection,
|
||||
{'username': username, 'password': password})
|
||||
authdata = await tlvdata.recv(self.connection)
|
||||
def authenticate(self, username, password):
|
||||
tlvdata.send(self.connection,
|
||||
{'username': username, 'password': password})
|
||||
authdata = tlvdata.recv(self.connection)
|
||||
if authdata['authpassed'] == 1:
|
||||
self.authenticated = True
|
||||
|
||||
@@ -281,7 +249,7 @@ class Command(object):
|
||||
outhandler(node, res)
|
||||
return rc
|
||||
|
||||
async def simple_noderange_command(self, noderange, resource, input=None,
|
||||
def simple_noderange_command(self, noderange, resource, input=None,
|
||||
key=None, errnodes=None, promptover=None, outhandler=None, **kwargs):
|
||||
try:
|
||||
self._currnoderange = noderange
|
||||
@@ -294,13 +262,13 @@ class Command(object):
|
||||
else:
|
||||
ikey = key
|
||||
if input is None:
|
||||
async for res in self.read('/noderange/{0}/{1}'.format(
|
||||
for res in self.read('/noderange/{0}/{1}'.format(
|
||||
noderange, resource)):
|
||||
rc = self.handle_results(ikey, rc, res, errnodes, outhandler)
|
||||
else:
|
||||
await self.stop_if_noderange_over(noderange, promptover)
|
||||
self.stop_if_noderange_over(noderange, promptover)
|
||||
kwargs[ikey] = input
|
||||
async for res in self.update('/noderange/{0}/{1}'.format(
|
||||
for res in self.update('/noderange/{0}/{1}'.format(
|
||||
noderange, resource), kwargs):
|
||||
rc = self.handle_results(ikey, rc, res, errnodes, outhandler)
|
||||
self._currnoderange = None
|
||||
@@ -309,14 +277,14 @@ class Command(object):
|
||||
cprint('')
|
||||
return 0
|
||||
|
||||
async def stop_if_noderange_over(self, noderange, maxnodes):
|
||||
def stop_if_noderange_over(self, noderange, maxnodes):
|
||||
if maxnodes is None:
|
||||
return
|
||||
nsize = await self.get_noderange_size(noderange)
|
||||
nsize = self.get_noderange_size(noderange)
|
||||
if nsize > maxnodes:
|
||||
if nsize == 1:
|
||||
nodename = [x async for x in self.read(
|
||||
'/noderange/{0}/nodes/'.format(noderange))][0].get('item', {}).get('href', None)
|
||||
nodename = list(self.read(
|
||||
'/noderange/{0}/nodes/'.format(noderange)))[0].get('item', {}).get('href', None)
|
||||
nodename = nodename[:-1]
|
||||
p = getinput('Command is about to affect node {0}, continue (y/n)? '.format(nodename))
|
||||
else:
|
||||
@@ -327,16 +295,16 @@ class Command(object):
|
||||
raise Exception("Aborting at user request")
|
||||
|
||||
|
||||
async def get_noderange_size(self, noderange):
|
||||
def get_noderange_size(self, noderange):
|
||||
numnodes = 0
|
||||
async for node in self.read('/noderange/{0}/nodes/'.format(noderange)):
|
||||
for node in self.read('/noderange/{0}/nodes/'.format(noderange)):
|
||||
if node.get('item', {}).get('href', None):
|
||||
numnodes += 1
|
||||
else:
|
||||
raise Exception("Error trying to size noderange {0}".format(noderange))
|
||||
return numnodes
|
||||
|
||||
async def simple_nodegroups_command(self, noderange, resource, input=None, key=None, **kwargs):
|
||||
def simple_nodegroups_command(self, noderange, resource, input=None, key=None, **kwargs):
|
||||
try:
|
||||
rc = 0
|
||||
if resource[0] == '/':
|
||||
@@ -347,12 +315,12 @@ class Command(object):
|
||||
else:
|
||||
ikey = key
|
||||
if input is None:
|
||||
for res in await self.read('/nodegroups/{0}/{1}'.format(
|
||||
for res in self.read('/nodegroups/{0}/{1}'.format(
|
||||
noderange, resource)):
|
||||
rc = self.handle_results(ikey, rc, res)
|
||||
else:
|
||||
kwargs[ikey] = input
|
||||
for res in await self.update('/nodegroups/{0}/{1}'.format(
|
||||
for res in self.update('/nodegroups/{0}/{1}'.format(
|
||||
noderange, resource), kwargs):
|
||||
rc = self.handle_results(ikey, rc, res)
|
||||
return rc
|
||||
@@ -360,44 +328,32 @@ class Command(object):
|
||||
cprint('')
|
||||
return 0
|
||||
|
||||
async def read(self, path, parameters=None):
|
||||
await self.ensure_connected()
|
||||
def read(self, path, parameters=None):
|
||||
if not self.authenticated:
|
||||
raise Exception('Unauthenticated')
|
||||
async for rsp in send_request(
|
||||
'retrieve', path, self.connection, parameters):
|
||||
yield rsp
|
||||
return send_request('retrieve', path, self.connection, parameters)
|
||||
|
||||
async def update(self, path, parameters=None):
|
||||
await self.ensure_connected()
|
||||
def update(self, path, parameters=None):
|
||||
if not self.authenticated:
|
||||
raise Exception('Unauthenticated')
|
||||
async for rsp in send_request(
|
||||
'update', path, self.connection, parameters):
|
||||
yield rsp
|
||||
return send_request('update', path, self.connection, parameters)
|
||||
|
||||
async def create(self, path, parameters=None):
|
||||
await self.ensure_connected()
|
||||
def create(self, path, parameters=None):
|
||||
if not self.authenticated:
|
||||
raise Exception('Unauthenticated')
|
||||
async for rsp in send_request(
|
||||
'create', path, self.connection, parameters):
|
||||
yield rsp
|
||||
return send_request('create', path, self.connection, parameters)
|
||||
|
||||
async def delete(self, path, parameters=None):
|
||||
await self.ensure_connected()
|
||||
def delete(self, path, parameters=None):
|
||||
if not self.authenticated:
|
||||
raise Exception('Unauthenticated')
|
||||
async for rsp in send_request(
|
||||
'delete', path, self.connection, parameters):
|
||||
yield rsp
|
||||
return send_request('delete', path, self.connection, parameters)
|
||||
|
||||
def _connect_unix(self):
|
||||
self.connection = socket.socket(socket.AF_UNIX, socket.SOCK_STREAM)
|
||||
self.connection.setsockopt(socket.SOL_SOCKET, SO_PASSCRED, 1)
|
||||
self.connection.connect(self.serverloc)
|
||||
|
||||
async def _connect_tls(self):
|
||||
def _connect_tls(self):
|
||||
server, port = _parseserver(self.serverloc)
|
||||
for res in socket.getaddrinfo(server, port, socket.AF_UNSPEC,
|
||||
socket.SOCK_STREAM):
|
||||
@@ -412,7 +368,7 @@ class Command(object):
|
||||
try:
|
||||
self.connection.settimeout(5)
|
||||
self.connection.connect(sa)
|
||||
self.connection.settimeout(0)
|
||||
self.connection.settimeout(None)
|
||||
except:
|
||||
raise
|
||||
self.connection.close()
|
||||
@@ -435,21 +391,10 @@ class Command(object):
|
||||
cacert = None
|
||||
certreqs = ssl.CERT_NONE
|
||||
knownhosts = True
|
||||
ctx = ssl.SSLContext(ssl.PROTOCOL_TLSv1_2)
|
||||
ssl_ctx = PySSLContext.from_address(id(ctx)).ctx
|
||||
libssl.SSL_CTX_set_cert_verify_callback(ssl_ctx, verify_stub, 0)
|
||||
sreader = asyncio.StreamReader()
|
||||
sreaderprot = asyncio.StreamReaderProtocol(sreader)
|
||||
cloop = asyncio.get_event_loop()
|
||||
tport, _ = await cloop.create_connection(
|
||||
lambda: sreaderprot, sock=self.connection, ssl=ctx, server_hostname='x')
|
||||
swriter = asyncio.StreamWriter(tport, sreaderprot, sreader, cloop)
|
||||
self.connection = (sreader, swriter)
|
||||
#self.connection = ssl.wrap_socket(self.connection, ca_certs=cacert,
|
||||
# cert_reqs=certreqs)
|
||||
self.connection = ssl.wrap_socket(self.connection, ca_certs=cacert,
|
||||
cert_reqs=certreqs)
|
||||
if knownhosts:
|
||||
certdata = tport.get_extra_info('ssl_object').getpeercert(binary_form=True)
|
||||
# certdata = self.connection.getpeercert(binary_form=True)
|
||||
certdata = self.connection.getpeercert(binary_form=True)
|
||||
fingerprint = 'sha512$' + hashlib.sha512(certdata).hexdigest()
|
||||
fingerprint = fingerprint.encode('utf-8')
|
||||
hostid = '@'.join((port, server))
|
||||
@@ -466,7 +411,7 @@ class Command(object):
|
||||
khf[hostid] = fingerprint
|
||||
|
||||
|
||||
async def send_request(operation, path, server, parameters=None):
|
||||
def send_request(operation, path, server, parameters=None):
|
||||
"""This function iterates over all the responses
|
||||
received from the server.
|
||||
|
||||
@@ -479,16 +424,16 @@ async def send_request(operation, path, server, parameters=None):
|
||||
payload = {'operation': operation, 'path': path}
|
||||
if parameters is not None:
|
||||
payload['parameters'] = parameters
|
||||
await tlvdata.send(server, payload)
|
||||
result = await tlvdata.recv(server)
|
||||
tlvdata.send(server, payload)
|
||||
result = tlvdata.recv(server)
|
||||
while '_requestdone' not in result:
|
||||
try:
|
||||
yield result
|
||||
except GeneratorExit:
|
||||
while '_requestdone' not in result:
|
||||
result = await tlvdata.recv(server)
|
||||
result = tlvdata.recv(server)
|
||||
raise
|
||||
result = await tlvdata.recv(server)
|
||||
result = tlvdata.recv(server)
|
||||
|
||||
|
||||
def attrrequested(attr, attrlist, seenattributes, node=None):
|
||||
@@ -513,20 +458,20 @@ def attrrequested(attr, attrlist, seenattributes, node=None):
|
||||
return False
|
||||
|
||||
|
||||
async def printattributes(session, requestargs, showtype, nodetype, noderange, options):
|
||||
def printattributes(session, requestargs, showtype, nodetype, noderange, options):
|
||||
path = '/{0}/{1}/attributes/{2}'.format(nodetype, noderange, showtype)
|
||||
return await print_attrib_path(path, session, requestargs, options)
|
||||
return print_attrib_path(path, session, requestargs, options)
|
||||
|
||||
def _sort_attrib(k):
|
||||
if isinstance(k[1], dict) and k[1].get('sortid', None) is not None:
|
||||
return k[1]['sortid']
|
||||
return k[0]
|
||||
|
||||
async def print_attrib_path(path, session, requestargs, options, rename=None, attrprefix=None):
|
||||
def print_attrib_path(path, session, requestargs, options, rename=None, attrprefix=None):
|
||||
exitcode = 0
|
||||
seenattributes = NestedDict()
|
||||
allnodes = set([])
|
||||
async for res in session.read(path):
|
||||
for res in session.read(path):
|
||||
if 'error' in res:
|
||||
sys.stderr.write(res['error'] + '\n')
|
||||
exitcode = 1
|
||||
@@ -714,7 +659,7 @@ def printgroupattributes(session, requestargs, showtype, nodetype, noderange, op
|
||||
exitcode = 1
|
||||
return exitcode
|
||||
|
||||
async def updateattrib(session, updateargs, nodetype, noderange, options, dictassign=None):
|
||||
def updateattrib(session, updateargs, nodetype, noderange, options, dictassign=None):
|
||||
# update attribute
|
||||
exitcode = 0
|
||||
if options.clear:
|
||||
@@ -722,7 +667,7 @@ async def updateattrib(session, updateargs, nodetype, noderange, options, dictas
|
||||
keydata = {}
|
||||
for attrib in updateargs[1:]:
|
||||
keydata[attrib] = None
|
||||
async for res in session.update(targpath, keydata):
|
||||
for res in session.update(targpath, keydata):
|
||||
for node in res.get('databynode', {}):
|
||||
for warnmsg in res['databynode'][node].get('_warnings', []):
|
||||
sys.stderr.write('Warning: ' + warnmsg + '\n')
|
||||
@@ -742,21 +687,21 @@ async def updateattrib(session, updateargs, nodetype, noderange, options, dictas
|
||||
value = os.environ.get(
|
||||
key, os.environ[key.upper()])
|
||||
if (nodetype == "nodegroups"):
|
||||
exitcode = await session.simple_nodegroups_command(noderange,
|
||||
exitcode = session.simple_nodegroups_command(noderange,
|
||||
'attributes/all',
|
||||
value, key)
|
||||
else:
|
||||
exitcode = await session.simple_noderange_command(noderange,
|
||||
exitcode = session.simple_noderange_command(noderange,
|
||||
'attributes/all',
|
||||
value, key)
|
||||
sys.exit(exitcode)
|
||||
elif dictassign:
|
||||
for key in dictassign:
|
||||
if nodetype == 'nodegroups':
|
||||
exitcode = await session.simple_nodegroups_command(
|
||||
exitcode = session.simple_nodegroups_command(
|
||||
noderange, 'attributes/all', dictassign[key], key)
|
||||
else:
|
||||
exitcode = await session.simple_noderange_command(
|
||||
exitcode = session.simple_noderange_command(
|
||||
noderange, 'attributes/all', dictassign[key], key)
|
||||
else:
|
||||
if "=" in updateargs[1]:
|
||||
@@ -781,12 +726,12 @@ async def updateattrib(session, updateargs, nodetype, noderange, options, dictas
|
||||
key = val[0]
|
||||
value = val[1]
|
||||
if (nodetype == "nodegroups"):
|
||||
exitcode = await session.simple_nodegroups_command(noderange, 'attributes/all',
|
||||
exitcode = session.simple_nodegroups_command(noderange, 'attributes/all',
|
||||
value, key)
|
||||
else:
|
||||
exitcode = await session.simple_noderange_command(noderange, 'attributes/all',
|
||||
exitcode = session.simple_noderange_command(noderange, 'attributes/all',
|
||||
value, key)
|
||||
except Exception:
|
||||
except:
|
||||
sys.stderr.write('Error: {0} not a valid expression\n'.format(str(updateargs[1:])))
|
||||
exitcode = 1
|
||||
sys.exit(exitcode)
|
||||
|
@@ -16,11 +16,15 @@
|
||||
# limitations under the License.
|
||||
|
||||
import array
|
||||
import asyncio
|
||||
import ctypes
|
||||
import ctypes.util
|
||||
import confluent.tlv as tlv
|
||||
import socket
|
||||
try:
|
||||
import eventlet.green.socket as socket
|
||||
import eventlet.green.select as select
|
||||
except ImportError:
|
||||
import socket
|
||||
import select
|
||||
from datetime import datetime
|
||||
import json
|
||||
import os
|
||||
@@ -36,7 +40,6 @@ try:
|
||||
except NameError:
|
||||
pass
|
||||
|
||||
|
||||
class iovec(ctypes.Structure): # from uio.h
|
||||
_fields_ = [('iov_base', ctypes.c_void_p),
|
||||
('iov_len', ctypes.c_size_t)]
|
||||
@@ -53,7 +56,6 @@ class cmsghdr(ctypes.Structure): # also from bits/socket.h
|
||||
@classmethod
|
||||
def init_data(cls, cmsg_len, cmsg_level, cmsg_type, cmsg_data):
|
||||
Data = ctypes.c_ubyte * ctypes.sizeof(cmsg_data)
|
||||
|
||||
class _flexhdr(ctypes.Structure):
|
||||
_fields_ = cls._fields_ + [('cmsg_data', Data)]
|
||||
|
||||
@@ -96,63 +98,13 @@ class ClientFile(object):
|
||||
self.fileobject = os.fdopen(fd, mode)
|
||||
self.filename = name
|
||||
|
||||
|
||||
|
||||
|
||||
def _sendmsg(loop, fut, sock, msg, fds, rfd):
|
||||
if rfd is not None:
|
||||
loop.remove_reader(rfd)
|
||||
if fut.cancelled():
|
||||
return
|
||||
try:
|
||||
retdata = sock.sendmsg(
|
||||
[msg],
|
||||
[(socket.SOL_SOCKET, socket.SCM_RIGHTS, array.array("i", fds))])
|
||||
except (BlockingIOError, InterruptedError):
|
||||
fd = sock.fileno()
|
||||
loop.add_reader(fd, _sendmsg, loop, fut, sock, fd)
|
||||
except Exception as exc:
|
||||
fut.set_exception(exc)
|
||||
else:
|
||||
fut.set_result(retdata)
|
||||
|
||||
|
||||
def send_fds(sock, msg, fds):
|
||||
cloop = asyncio.get_event_loop()
|
||||
fut = cloop.create_future()
|
||||
_sendmsg(cloop, fut, sock, msg, fds, None)
|
||||
return fut
|
||||
|
||||
|
||||
def _recvmsg(loop, fut, sock, msglen, maxfds, rfd):
|
||||
if rfd is not None:
|
||||
loop.remove_reader(rfd)
|
||||
fds = array.array("i") # Array of ints
|
||||
try:
|
||||
msg, ancdata, flags, addr = sock.recvmsg(
|
||||
msglen, socket.CMSG_LEN(maxfds * fds.itemsize))
|
||||
except (BlockingIOError, InterruptedError):
|
||||
fd = sock.fileno()
|
||||
loop.add_reader(fd, _recvmsg, loop, fut, sock, fd)
|
||||
except Exception as exc:
|
||||
fut.set_exception(exc)
|
||||
else:
|
||||
for cmsg_level, cmsg_type, cmsg_data in ancdata:
|
||||
if (cmsg_level == socket.SOL_SOCKET
|
||||
and cmsg_type == socket.SCM_RIGHTS):
|
||||
# Append data, ignoring any truncated integers at the end.
|
||||
fds.frombytes(
|
||||
cmsg_data[
|
||||
:len(cmsg_data) - (len(cmsg_data) % fds.itemsize)])
|
||||
fut.set_result(msglen, list(fds))
|
||||
|
||||
|
||||
def recv_fds(sock, msglen, maxfds):
|
||||
cloop = asyncio.get_event_loop()
|
||||
fut = cloop.create_future()
|
||||
_recvmsg(cloop, fut, sock, msglen, maxfds, None)
|
||||
return fut
|
||||
|
||||
libc = ctypes.CDLL(ctypes.util.find_library('c'))
|
||||
recvmsg = libc.recvmsg
|
||||
recvmsg.argtypes = [ctypes.c_int, ctypes.POINTER(msghdr), ctypes.c_int]
|
||||
recvmsg.restype = ctypes.c_int
|
||||
sendmsg = libc.sendmsg
|
||||
sendmsg.argtypes = [ctypes.c_int, ctypes.POINTER(msghdr), ctypes.c_int]
|
||||
sendmsg.restype = ctypes.c_size_t
|
||||
|
||||
def decodestr(value):
|
||||
ret = None
|
||||
@@ -167,7 +119,6 @@ def decodestr(value):
|
||||
return value
|
||||
return ret
|
||||
|
||||
|
||||
def unicode_dictvalues(dictdata):
|
||||
for key in dictdata:
|
||||
if isinstance(dictdata[key], bytes):
|
||||
@@ -190,17 +141,7 @@ def _unicode_list(currlist):
|
||||
_unicode_list(currlist[i])
|
||||
|
||||
|
||||
async def sendall(handle, data):
|
||||
if isinstance(handle, tuple):
|
||||
handle[1].write(data)
|
||||
return await handle[1].drain()
|
||||
else:
|
||||
cloop = asyncio.get_event_loop()
|
||||
return await cloop.sock_sendall(handle, data)
|
||||
|
||||
|
||||
async def send(handle, data, filehandle=None):
|
||||
cloop = asyncio.get_event_loop()
|
||||
def send(handle, data, filehandle=None):
|
||||
if isinstance(data, unicode):
|
||||
try:
|
||||
data = data.encode('utf-8')
|
||||
@@ -215,10 +156,10 @@ async def send(handle, data, filehandle=None):
|
||||
if tl < 16777216:
|
||||
# type for string is '0', so we don't need
|
||||
# to xor anything in
|
||||
await sendall(handle, struct.pack("!I", tl))
|
||||
handle.sendall(struct.pack("!I", tl))
|
||||
else:
|
||||
raise Exception("String data length exceeds protocol")
|
||||
await sendall(handle, data)
|
||||
handle.sendall(data)
|
||||
elif isinstance(data, dict): # JSON currently only goes to 4 bytes
|
||||
# Some structured message, like what would be seen in http responses
|
||||
unicode_dictvalues(data) # make everything unicode, assuming UTF-8
|
||||
@@ -230,40 +171,41 @@ async def send(handle, data, filehandle=None):
|
||||
# xor in the type (0b1 << 24)
|
||||
if filehandle is None:
|
||||
tl |= 16777216
|
||||
await sendall(handle, struct.pack("!I", tl))
|
||||
await sendall(handle, sdata)
|
||||
elif isinstance(handle, tuple):
|
||||
raise Exception("Cannot send filehandle over network socket")
|
||||
handle.sendall(struct.pack("!I", tl))
|
||||
handle.sendall(sdata)
|
||||
else:
|
||||
tl |= (2 << 24)
|
||||
await cloop.sock_sendall(handle, struct.pack("!I", tl))
|
||||
await send_fds(handle, b'', [filehandle])
|
||||
handle.sendall(struct.pack("!I", tl))
|
||||
cdtype = ctypes.c_ubyte * len(sdata)
|
||||
cdata = cdtype.from_buffer(bytearray(sdata))
|
||||
ciov = iovec(iov_base=ctypes.addressof(cdata),
|
||||
iov_len=ctypes.c_size_t(ctypes.sizeof(cdata)))
|
||||
fd = ctypes.c_int(filehandle)
|
||||
cmh = cmsghdr.init_data(
|
||||
cmsg_len=CMSG_LEN(
|
||||
ctypes.sizeof(fd)), cmsg_level=socket.SOL_SOCKET,
|
||||
cmsg_type=SCM_RIGHTS, cmsg_data=fd)
|
||||
mh = msghdr(msg_name=None, msg_len=0, msg_iov=iovec_ptr(ciov),
|
||||
msg_iovlen=1, msg_control=ctypes.addressof(cmh),
|
||||
msg_controllen=ctypes.c_size_t(ctypes.sizeof(cmh)))
|
||||
sendmsg(handle.fileno(), mh, 0)
|
||||
|
||||
|
||||
async def _grabhdl(handle, size):
|
||||
if isinstance(handle, tuple):
|
||||
return await handle[0].read(size)
|
||||
else:
|
||||
cloop = asyncio.get_event_loop()
|
||||
return await cloop.sock_recv(handle, size)
|
||||
|
||||
|
||||
async def recvall(handle, size):
|
||||
rd = await _grabhdl(handle, size)
|
||||
def recvall(handle, size):
|
||||
rd = handle.recv(size)
|
||||
while len(rd) < size:
|
||||
nd = await _grabhdl(handle, size - len(rd))
|
||||
nd = handle.recv(size - len(rd))
|
||||
if not nd:
|
||||
raise Exception("Error reading data")
|
||||
rd += nd
|
||||
return rd
|
||||
|
||||
|
||||
async def recv(handle):
|
||||
tl = await _grabhdl(handle, 4)
|
||||
def recv(handle):
|
||||
tl = handle.recv(4)
|
||||
if not tl:
|
||||
return None
|
||||
while len(tl) < 4:
|
||||
ndata = await _grabhdl(handle, 4 - len(tl))
|
||||
ndata = handle.recv(4 - len(tl))
|
||||
if not ndata:
|
||||
raise Exception("Error reading data")
|
||||
tl += ndata
|
||||
@@ -278,8 +220,6 @@ async def recv(handle):
|
||||
if dlen == 0:
|
||||
return None
|
||||
if datatype == tlv.Types.filehandle:
|
||||
if isinstance(handle, tuple):
|
||||
raise Exception('Filehandle not supported over TLS socket')
|
||||
filehandles = array.array('i')
|
||||
rawbuffer = bytearray(2048)
|
||||
pkttype = ctypes.c_ubyte * 2048
|
||||
@@ -299,16 +239,23 @@ async def recv(handle):
|
||||
msg.msg_iovlen = 1
|
||||
msg.msg_control = ctypes.addressof(cmsg)
|
||||
msg.msg_controllen = ctypes.sizeof(cmsg)
|
||||
i = await recv_fds(handle, 2048, 4)
|
||||
print(repr(i))
|
||||
data = i[0]
|
||||
filehandles = i[1]
|
||||
select.select([handle], [], [])
|
||||
i = recvmsg(handle.fileno(), ctypes.pointer(msg), 0)
|
||||
cdata = cmsgarr[CMSG_LEN(0).value:]
|
||||
data = rawbuffer[:i]
|
||||
if cmsg.cmsg_level == socket.SOL_SOCKET and cmsg.cmsg_type == SCM_RIGHTS:
|
||||
try:
|
||||
filehandles.fromstring(bytes(
|
||||
cdata[:len(cdata) - len(cdata) % filehandles.itemsize]))
|
||||
except AttributeError:
|
||||
filehandles.frombytes(bytes(
|
||||
cdata[:len(cdata) - len(cdata) % filehandles.itemsize]))
|
||||
data = json.loads(bytes(data))
|
||||
return ClientFile(data['filename'], data['mode'], filehandles[0])
|
||||
else:
|
||||
data = await _grabhdl(handle, dlen)
|
||||
data = handle.recv(dlen)
|
||||
while len(data) < dlen:
|
||||
ndata = await _grabhdl(handle, dlen - len(data))
|
||||
ndata = handle.recv(dlen - len(data))
|
||||
if not ndata:
|
||||
raise Exception("Error reading data")
|
||||
data += ndata
|
||||
|
@@ -21,7 +21,7 @@ import confluent.config.configmanager as cfm
|
||||
import confluent.exceptions as exc
|
||||
import confluent.log as log
|
||||
import confluent.noderange as noderange
|
||||
import confluent.tlvdata as tlvdata
|
||||
import confluent.asynctlvdata as tlvdata
|
||||
import confluent.util as util
|
||||
import socket
|
||||
import ssl
|
||||
|
@@ -37,7 +37,7 @@ import asyncio
|
||||
import confluent
|
||||
import confluent.alerts as alerts
|
||||
import confluent.log as log
|
||||
import confluent.tlvdata as tlvdata
|
||||
import confluent.asynctlvdata as tlvdata
|
||||
import confluent.config.attributes as attrscheme
|
||||
import confluent.config.configmanager as cfm
|
||||
import confluent.collective.manager as collective
|
||||
|
@@ -30,7 +30,7 @@ from aiohttp import web, web_urldispatcher, connector, ClientSession, WSMsgType
|
||||
import confluent.auth as auth
|
||||
import confluent.config.attributes as attribs
|
||||
import confluent.config.configmanager as configmanager
|
||||
import confluent.consoleserver as consoleserver
|
||||
#import confluent.consoleserver as consoleserver
|
||||
import confluent.discovery.core as disco
|
||||
import confluent.forwarder as forwarder
|
||||
import confluent.exceptions as exc
|
||||
@@ -40,7 +40,7 @@ import confluent.core as pluginapi
|
||||
import confluent.asynchttp
|
||||
import confluent.selfservice as selfservice
|
||||
import confluent.shellserver as shellserver
|
||||
import confluent.tlvdata
|
||||
import confluent.asynctlvdata as tlvdata
|
||||
import confluent.util as util
|
||||
import copy
|
||||
import json
|
||||
@@ -52,7 +52,6 @@ try:
|
||||
import urlparse
|
||||
except ModuleNotFoundError:
|
||||
import urllib.parse as urlparse
|
||||
tlvdata = confluent.tlvdata
|
||||
|
||||
|
||||
_cleaner = None
|
||||
|
@@ -38,7 +38,7 @@ import ssl
|
||||
import confluent.auth as auth
|
||||
import confluent.credserver as credserver
|
||||
import confluent.config.conf as conf
|
||||
import confluent.tlvdata as tlvdata
|
||||
import confluent.asynctlvdata as tlvdata
|
||||
#import confluent.consoleserver as consoleserver
|
||||
import confluent.config.configmanager as configmanager
|
||||
import confluent.exceptions as exc
|
||||
|
Reference in New Issue
Block a user