diff --git a/confluent_client/confluent/client.py b/confluent_client/confluent/client.py index c70662ec..7ae7b24c 100644 --- a/confluent_client/confluent/client.py +++ b/confluent_client/confluent/client.py @@ -39,6 +39,10 @@ _attraliases = { 'bmcpass': 'secret.hardwaremanagementpassword', } +try: + input = raw_input +except NameError: + pass def stringify(instr): # Normalize unicode and bytes to 'str', correcting for @@ -219,7 +223,7 @@ class Command(object): return rc def simple_noderange_command(self, noderange, resource, input=None, - key=None, errnodes=None, **kwargs): + key=None, errnodes=None, promptover=None, **kwargs): try: self._currnoderange = noderange rc = 0 @@ -235,6 +239,8 @@ class Command(object): noderange, resource)): rc = self.handle_results(ikey, rc, res, errnodes) else: + if promptover is not None: + self.stop_if_noderange_over(noderange, promptover) kwargs[ikey] = input for res in self.update('/noderange/{0}/{1}'.format( noderange, resource), kwargs): @@ -244,6 +250,23 @@ class Command(object): except KeyboardInterrupt: cprint('') return 0 + + def stop_if_noderange_over(self, noderange, maxnodes): + nsize = self.get_noderange_size(noderange) + if nsize > maxnodes: + p = input('Command is about to affect {0} nodes, continue (y/n)?'.format(nsize)) + if p.lower() != 'y': + raise Exception("Aborting at user request") + + + def get_noderange_size(self, noderange): + numnodes = 0 + for node in self.read('/noderange/{0}/nodes/'.format(noderange)): + if node.get('item', {}).get('href', None): + numnodes += 1 + else: + raise Exception("Error trying to size noderange {0}".format(noderange)) + return numnodes def simple_nodegroups_command(self, noderange, resource, input=None, key=None, **kwargs): try: @@ -344,12 +367,8 @@ class Command(object): if fingerprint == khf[hostid]: return else: - try: - replace = raw_input( - "MISMATCHED CERTIFICATE DATA, ACCEPT NEW? (y/n):") - except NameError: - replace = input( - "MISMATCHED CERTIFICATE DATA, ACCEPT NEW? (y/n):") + replace = input( + "MISMATCHED CERTIFICATE DATA, ACCEPT NEW? (y/n):") if replace not in ('y', 'Y'): raise Exception("BAD CERTIFICATE") cprint('Adding new key for %s:%s' % (server, port))