Repository URL to install this package:
|
Version:
0.2.3 ▾
|
import click
import boto3
from pssh.pssh_client import ParallelSSHClient
import os
import time
# Package version string, echoed as "Version <x>" by the --version option
# callback (print_version) and by _print_version.
__version__ = '0.2.3'
class InstanceList(object):
    """
    A class for storing command line arguments.

    Holds the tag filters and SSH timeout given on the command line and turns
    them into an EC2 ``describe_instances`` query that only matches running
    instances.
    """

    def __init__(self, name=None, role=None, role_type=None, env=None,
                 availability_zone=None, asg=None, timeout=10):
        # Each key is matched as an EC2 tag ("tag:<key>"); options left as
        # None are dropped by _buildFilters so they do not constrain the query.
        # NOTE(review): AvailabilityZone is matched as a *tag*, not via EC2's
        # built-in "availability-zone" filter -- confirm instances are tagged
        # this way.
        self.filter_options = {
            'Name': name,
            'Role': role,
            'RoleType': role_type,
            'Environment': env,
            'AvailabilityZone': availability_zone,
            'aws:autoscaling:groupName': asg,
        }
        # SSH timeout, read by the CLI commands when they run remote commands.
        self.timeout = timeout
        self.filters = self._buildFilters()
        # Only running instances are ever returned.
        self.filters.append({
            'Name': 'instance-state-name',
            'Values': ['running']
        })

    def get_instance_list(self):
        """
        Return every running EC2 instance matching ``self.filters``.

        Uses the ``describe_instances`` paginator so that results spread over
        several response pages are all collected, not just the first page.

        :returns: a flat list of instance dicts as returned by boto3.
        """
        client = boto3.client('ec2')
        paginator = client.get_paginator('describe_instances')
        instance_list = []
        for page in paginator.paginate(Filters=self.filters):
            for reservation in page['Reservations']:
                # extend() instead of `a = a + b`, which re-copies the list
                # on every reservation.
                instance_list.extend(reservation['Instances'])
        return instance_list

    def _buildFilters(self):
        """
        Build filters for boto3 query.

        :returns: one ``{'Name': 'tag:<key>', 'Values': [<value>]}`` dict per
            filter option that was supplied (i.e. is not None), in the order
            of ``self.filter_options``.
        """
        return [
            {'Name': 'tag:' + key, 'Values': [value]}
            for key, value in self.filter_options.items()
            if value is not None
        ]
def _list_by_tag(instances, tag_name='Name'):
""" Extract tag value for boto3 EC2 objects. """
tag_values = map(
lambda instance: _get_tag_value(instance, tag_name),
instances
)
return tag_values
def _get_tag_value(instance, tag_name):
""" Iterate through a tag list and get result by tag_name. """
return next((tag.get('Value', '') for tag in instance.get('Tags', {}) if tag.get('Key', '') == tag_name), '')
def _extract_time(time_string):
""" Convert user input time str (1, 23s, 1.5m, 1h) to seconds. """
try:
if time_string.isdigit(): # If input are int string
return int(time_string)
factor = 1
unit = time_string[-1].lower()
if unit == 's':
factor = 1
elif unit == 'm':
factor = 60
elif unit == 'h':
factor = 3600
else:
raise Exception("Unsupported format")
num = float(time_string[:-1])
res = int(num * factor)
return res if res >= 0 else None
except Exception:
return None
def _create_batches(instances, batchsize):
"""
A dumb way to generate batches.
batchsize is the number of instances in a batch.
"""
batch_size = int(batchsize)
if batch_size == 0:
raise ValueError("Batchsize should not be 0")
return
batch_count = (len(instances) - 1) // batch_size + 1
batches = [[] for i in range(batch_count)]
for index, instance in enumerate(instances):
batches[index // batch_size].append(instance)
return batches
def _command_helper(command, instances, timeout):
    """
    Run `command` over SSH on every given instance in parallel.

    Hosts are addressed by private IP. Returns one dict per entry in the SSH
    client's output, with keys 'host' (the instance's Name tag), 'ip' and
    'response' (the stdout lines joined with newlines).
    """
    # (private IP, Name tag) per instance; the IP is what we SSH to, the name
    # is only used for display.
    pairs = [
        (instance['PrivateIpAddress'], _get_tag_value(instance, 'Name'))
        for instance in instances
    ]
    ips = [ip for ip, _ in pairs]
    name_by_ip = dict(pairs)
    # GOTCHA: flagged by the original author; the reason is not recorded here,
    # so review before changing how the client is constructed.
    ssh_client = ParallelSSHClient(ips, timeout=timeout)
    output = ssh_client.run_command(command)
    return [
        {
            'host': name_by_ip[ip],
            'ip': ip,
            'response': '\n'.join(output[ip]['stdout']),
        }
        for ip in output
    ]
def print_version(ctx, param, value):
    """
    Click callback for the eager --version flag.

    Prints the version and exits only when the flag was given and click is
    not in resilient (completion) parsing mode; otherwise it does nothing.
    """
    if value and not ctx.resilient_parsing:
        click.echo('Version ' + __version__)
        ctx.exit()
def _print_version():
    """Echo the package version line ("Version <x>") to stdout."""
    # Not referenced in the visible part of this module.
    message = 'Version ' + __version__
    click.echo(message)
@click.group()
@click.option('-v', '--version', is_flag=True, callback=print_version,
              expose_value=False, is_eager=True)
@click.option('-n', '--name', default=None, help='hostname filter')
@click.option('-r', '--role', default=None, help='role filter')
@click.option('-rt', '--roletype', default=None, help='roletype filter')
@click.option('-e', '--env', default=None, help='env to filter by')
@click.option('-az', '--availability-zone', default=None, help='az to filter by')
@click.option('-asg', '--autoscalinggroup', default=None, help='auto scaling group name to filter by')
@click.option('-t', '--timeout', help='SSH timeout', default=10, type=int)
@click.pass_context
def cli(context, name, role, roletype, env, availability_zone, autoscalinggroup, timeout):
    # Top-level click group: bundle the shared filter options and SSH timeout
    # into an InstanceList and hand it to subcommands through context.obj.
    # Deliberately no docstring: click would show it as the group's --help text.
    context.obj = InstanceList(
        name=name,
        role=role,
        role_type=roletype,
        env=env,
        availability_zone=availability_zone,
        asg=autoscalinggroup,
        timeout=timeout,
    )
@click.command()
@click.pass_context
@click.option('-v', '--verbose/--no-verbose', default=False)
def list(context, verbose):
    """List all ec2 instances by name."""
    # NOTE: this name shadows the builtin `list` at module level; it is the
    # name registered via cli.add_command(list), so it is kept as-is.
    # `verbose` is accepted but not used.
    names = _list_by_tag(context.obj.get_instance_list())
    output = '\n'.join(names)
    click.echo(output)
@click.command()
@click.pass_context
def run_more(context):
    """Run a series of commands on a group of hosts in parallel"""
    # The instance list is fetched once and reused for every command; we
    # assume it does not need refreshing between runs.
    instances = context.obj.get_instance_list()
    click.echo('\n'.join(sorted(_list_by_tag(instances))))
    while True:
        command = click.prompt(
            'Enter a command to run on these instances (exit or ^C to cancel)'
        )
        # Test for 'exit' before dispatching, so the word itself is never
        # executed on the remote hosts.
        if command.strip() == 'exit':
            break
        results = _command_helper(command, instances, context.obj.timeout)
        for result in results:
            # Host name on its own line, each response line indented by a tab.
            click.echo(result['host'] + '\n\t' + '\n\t'.join(result['response'].split('\n')))
@click.command()
@click.pass_context
@click.option('-c', '--command', default='hostname')
def run(context, command):
    """Run command across multiple servers."""
    results = _command_helper(
        command,
        context.obj.get_instance_list(),
        context.obj.timeout
    )
    for result in results:
        host = result['host']
        # Continuation lines get len(host) spaces *then* a tab, so they reach
        # the same tab stop as the first line's "host<TAB>" prefix and the
        # response stays in one column. Putting the tab before the padding
        # (as '\n\t'.ljust(...) did) lands on a different column.
        separator = '\n' + ' ' * len(host) + '\t'
        click.echo(host + '\t' + separator.join(result['response'].split('\n')))
@click.command()
@click.pass_context
@click.option('-c', '--command', default='hostname', help="Command to run")
@click.option('-b', '--batchsize', default='1', help="Run command in groups of X hosts")
@click.option('-t', '--timeout', default='10',
              help="Seconds(or mins, hrs) between batches. Accept: 15, 15s, 15m, 1.5h"
              )
def run_batch(context, command, batchsize, timeout):
    """Run command in batches with timeout between runs."""
    # This command's --timeout is the pause *between batches*; the SSH
    # timeout comes from the top-level group via context.obj.timeout.
    timeout_seconds = _extract_time(timeout)
    if timeout_seconds is None:
        click.echo(
            "Invalid timeout format. "
            "Valid forms include: 1, 123s, 1.5m, 3h"
        )
        return
    try:
        instances = context.obj.get_instance_list()
        try:
            batches = _create_batches(instances, batchsize)
        except ValueError:
            # Only batch-size parsing is treated as bad user input. A
            # ValueError raised later, while running commands, falls through
            # to the generic handler instead of being blamed on the parameters.
            click.echo("Please check parameter format. Or see 'zco run_batch --help'")
            return
        for index, batch in enumerate(batches):
            results = _command_helper(command, batch, context.obj.timeout)
            for result in results:
                host = result['host']
                # len(host) spaces then a tab: continuation lines reach the
                # same tab stop as the first line's "host<TAB>" prefix.
                separator = '\n' + ' ' * len(host) + '\t'
                click.echo(host + '\t' + separator.join(result['response'].split('\n')))
            if index < len(batches) - 1:
                click.echo(
                    "\nBatch %d of %d completed. Wait %d seconds\n" %
                    (index + 1, len(batches), timeout_seconds)
                )
                time.sleep(timeout_seconds)
    except Exception as e:
        # Best-effort top-level handler: report the error instead of dumping
        # a traceback.
        click.echo('Unexpected exceptions happened.')
        click.echo(str(e))
@click.command()
@click.pass_context
def update_autocomplete(context):
    """Update ssh autocomplete file."""
    cb_hosts_file = os.path.expanduser('~/.ssh/cb_hosts')
    # Query EC2 *before* touching the file. Opening it for writing truncates
    # it, so a failing AWS call must not be able to leave it empty.
    instance_names = _list_by_tag(context.obj.get_instance_list())
    # Instances without a Name tag come back as '' and would otherwise become
    # the bogus entry '.caffeine.io'.
    hostnames = [
        instance_name + '.caffeine.io'
        for instance_name in instance_names
        if instance_name
    ]
    if os.path.isfile(cb_hosts_file):
        # Read-only access is enough for counting the existing entries.
        with open(cb_hosts_file, 'r') as hosts_file:
            old_count = sum(1 for _ in hosts_file)
    else:
        old_count = 0
    hosts_dir = os.path.dirname(cb_hosts_file)
    if not os.path.isdir(hosts_dir):
        # ~/.ssh is conventionally private to the user.
        os.makedirs(hosts_dir, 0o700)
    with open(cb_hosts_file, 'w') as hosts_file:
        # One newline-terminated hostname per line; an empty result writes an
        # empty file rather than a lone blank line.
        hosts_file.write(''.join(hostname + '\n' for hostname in hostnames))
    new_count = len(hostnames)
    click.echo('Successfully updated autocomplete hosts list!')
    click.echo(
        'Previous host count: ' + click.style(str(old_count), fg='red') +
        '\nNew host count: ' + click.style(str(new_count), fg='green')
    )
@click.command()
def install_autocomplete():
    """Configure ssh autocomplete."""
    directory = os.path.expanduser('~/.bash_completion/')
    if not os.path.exists(directory):
        os.makedirs(directory)
    # autocomplete.sh is expected to ship next to this module.
    source = os.path.join(os.path.dirname(os.path.realpath(__file__)), 'autocomplete.sh')
    destination = os.path.join(directory, 'autocomplete.sh')
    # lexists(), not exists(): exists() follows the link and is False for a
    # dangling symlink (e.g. after the package moved), after which
    # os.symlink() would fail because the link path is still occupied.
    if os.path.lexists(destination):
        click.echo('Have an existing symlink. Will renew it.')
        os.unlink(destination)
    os.symlink(source, destination)
# Register every subcommand on the top-level `cli` group, in a fixed order.
for _subcommand in (list, run, run_more, run_batch,
                    install_autocomplete, update_autocomplete):
    cli.add_command(_subcommand)
del _subcommand  # keep the loop variable out of the module namespace