Why Gemfury? Push, build, and install: RubyGems, npm packages, Python packages, Maven artifacts, PHP packages, Go Modules, Debian packages, RPM packages, NuGet packages

Repository URL to install this package:

Details    
zco / zco / zco.py
Size: Mime:
import click
import boto3
from pssh.pssh_client import ParallelSSHClient
import os
import time


# Version string of this tool; echoed (prefixed with 'Version ') by
# print_version and _print_version.
__version__ = '0.2.3'


class InstanceList(object):
    """
    Hold the CLI filter options and look up matching running EC2 instances.

    Attributes:
        filter_options: dict mapping an EC2 tag key to the requested value;
            a value of None means "do not filter on this tag".
        timeout: stored unchanged for callers to read (the ``cli`` group
            passes its ``--timeout`` option here).
        filters: boto3 ``Filters`` list built from the non-None options plus
            a filter restricting results to instances in the running state.
    """

    def __init__(self, name=None, role=None, role_type=None, env=None, availability_zone=None, asg=None, timeout=10):
        self.filter_options = {
            'Name': name,
            'Role': role,
            'RoleType': role_type,
            'Environment': env,
            'AvailabilityZone': availability_zone,
            'aws:autoscaling:groupName': asg,
        }
        self.timeout = timeout
        self.filters = self._buildFilters()
        # Only running instances are of interest, whatever tags were given.
        self.filters.append({
            'Name': 'instance-state-name',
            'Values': ['running']
        })

    def get_instance_list(self):
        """
        Query EC2 and return every instance matching ``self.filters``.

        Returns:
            A flat list of the instance dicts found under each reservation
            of the DescribeInstances response.
        """
        client = boto3.client('ec2')
        # Walk every page: a single describe_instances call may return only
        # part of the result set together with a continuation token.
        paginator = client.get_paginator('describe_instances')
        instance_list = []
        for page in paginator.paginate(Filters=self.filters):
            for reservation in page['Reservations']:
                instance_list.extend(reservation['Instances'])
        return instance_list

    def _buildFilters(self):
        """ Build the tag filters for the boto3 query, skipping unset options. """
        return [
            {'Name': 'tag:' + key, 'Values': [value]}
            for key, value in self.filter_options.items()
            if value is not None
        ]


def _list_by_tag(instances, tag_name='Name'):
    """
    Extract one tag's value from each boto3 EC2 instance dict.

    Args:
        instances: iterable of instance dicts.
        tag_name: tag key to look up on every instance (default 'Name').

    Returns:
        A list with one entry per instance, in input order; each entry is
        whatever _get_tag_value returns for that instance.
    """
    # Return a real list rather than a lazy map object so the result can be
    # measured and iterated more than once. A comprehension is used instead
    # of list(map(...)) because this module rebinds the name ``list`` to a
    # click command.
    return [_get_tag_value(instance, tag_name) for instance in instances]


def _get_tag_value(instance, tag_name):
    """ Iterate through a tag list and get result by tag_name. """

    return next((tag.get('Value', '') for tag in instance.get('Tags', {}) if tag.get('Key', '') == tag_name), '')


def _extract_time(time_string):
    """ Convert user input time str (1, 23s, 1.5m, 1h) to seconds. """

    try:
        if time_string.isdigit():  # If input are int string
            return int(time_string)
        factor = 1
        unit = time_string[-1].lower()
        if unit == 's':
            factor = 1
        elif unit == 'm':
            factor = 60
        elif unit == 'h':
            factor = 3600
        else:
            raise Exception("Unsupported format")

        num = float(time_string[:-1])
        res = int(num * factor)
        return res if res >= 0 else None
    except Exception:
        return None


def _create_batches(instances, batchsize):
    """
    A dumb way to generate batches.
    batchsize is the number of instances in a batch.
    """

    batch_size = int(batchsize)
    if batch_size == 0:
        raise ValueError("Batchsize should not be 0")
        return
    batch_count = (len(instances) - 1) // batch_size + 1
    batches = [[] for i in range(batch_count)]
    for index, instance in enumerate(instances):
        batches[index // batch_size].append(instance)
    return batches


def _command_helper(command, instances, timeout):
    """
    Run one shell command over SSH on every given instance.

    Args:
        command: command line handed to the SSH client's run_command.
        instances: instance dicts; each must carry 'PrivateIpAddress'.
        timeout: passed to ParallelSSHClient as its timeout.

    Returns:
        A list of dicts, one per key of the run_command output, holding
        'host' (the instance's Name tag), 'ip' (its private IP) and
        'response' (the stdout lines joined with newlines).
    """
    # Hosts are addressed by private IP; remember which Name tag each IP
    # belongs to so results can be labelled by hostname.
    address_and_name = [
        (instance['PrivateIpAddress'], _get_tag_value(instance, 'Name'))
        for instance in instances
    ]
    addresses = [address for address, _ in address_and_name]
    name_by_address = dict(address_and_name)

    ssh = ParallelSSHClient(addresses, timeout=timeout)  # GOTCHA
    output = ssh.run_command(command)
    # The code treats the output as a mapping keyed by host address.
    return [
        {
            'host': name_by_address[address],
            'ip': address,
            'response': '\n'.join(output[address]['stdout']),
        }
        for address in output
    ]


def print_version(ctx, param, value):
    """
    Option callback for the version flag.

    When the flag is set and click is not in resilient-parsing mode, echo
    the version string and exit through the context; otherwise do nothing.
    """
    flag_requested = value and not ctx.resilient_parsing
    if flag_requested:
        click.echo('Version ' + __version__)
        ctx.exit()


def _print_version():
    """Echo the version banner without touching any click context."""
    banner = 'Version ' + __version__
    click.echo(banner)


@click.group()
@click.option('-v', '--version', is_flag=True, callback=print_version,
              expose_value=False, is_eager=True)
@click.option('-n', '--name', default=None, help='hostname filter')
@click.option('-r', '--role', default=None, help='role filter')
@click.option('-rt', '--roletype', default=None, help='roletype filter')
@click.option('-e', '--env', default=None, help='env to filter by')
@click.option('-az', '--availability-zone', default=None, help='az to filter by')
@click.option('-asg', '--autoscalinggroup', default=None, help='auto scaling group name to filter by')
@click.option('-t', '--timeout', help='SSH timeout', default=10, type=int)
@click.pass_context
def cli(context, name, role, roletype, env, availability_zone, autoscalinggroup, timeout):
    # Root command group. The filter options and SSH timeout given here are
    # bundled into one InstanceList and stored on the click context, where
    # every subcommand reads it back as ``context.obj``.
    # (Kept as comments: a docstring here would show up in the --help text.)
    selection = InstanceList(
        name=name,
        role=role,
        role_type=roletype,
        env=env,
        availability_zone=availability_zone,
        asg=autoscalinggroup,
        timeout=timeout,
    )
    context.obj = selection


@click.command()
@click.pass_context
@click.option('-v', '--verbose/--no-verbose', default=False)
def list(context, verbose):
    """List all ec2 instances by name."""
    # NOTE: this definition rebinds the builtin name ``list`` for the rest of
    # the module. The ``verbose`` flag is accepted but not used.
    matching = context.obj.get_instance_list()
    names = _list_by_tag(matching)
    listing = '\n'.join(names)
    click.echo(listing)


@click.command()
@click.pass_context
def run_more(context):
    """Run a series of commands on a group of hosts in parallel"""
    # The instance list is fetched once and reused for every command; it is
    # assumed not to need refreshing between runs.
    instances = context.obj.get_instance_list()
    instance_names = _list_by_tag(instances)
    click.echo('\n'.join(sorted(instance_names)))
    while True:
        command = click.prompt(
            'Enter a command to run on these instances (exit or ^C to cancel)'
        )
        # Leave the loop BEFORE dispatching: 'exit' ends the session and
        # must not be sent to the remote hosts as a command.
        if command == 'exit':
            break
        results = _command_helper(command, instances, context.obj.timeout)
        for result in results:
            # Host name on its own line, each output line indented beneath.
            click.echo(result['host'] + '\n\t' + '\n\t'.join(result['response'].split('\n')))


@click.command()
@click.pass_context
@click.option('-c', '--command', default='hostname')
def run(context, command):
    """Run command across multiple servers."""
    targets = context.obj.get_instance_list()
    ssh_timeout = context.obj.timeout
    for outcome in _command_helper(command, targets, ssh_timeout):
        host = outcome['host']
        # Separator placed between output lines: newline + tab, padded with
        # spaces up to the length of the host name.
        line_break = '\n\t'.ljust(len(host))
        output_lines = outcome['response'].split('\n')
        click.echo(host + '\t' + line_break.join(output_lines))


@click.command()
@click.pass_context
@click.option('-c', '--command', default='hostname', help="Command to run")
@click.option('-b', '--batchsize', default='1', help="Run command in groups of X hosts")
@click.option(
    '-t', '--timeout', default='10',
    help="Seconds(or mins, hrs) between batches. Accept: 15, 15s, 15m, 1.5h"
)
def run_batch(context, command, batchsize, timeout):
    """Run command in batches with timeout between runs."""
    # Two different timeouts are in play: this command's --timeout is the
    # pause between batches, while context.obj.timeout (from the parent
    # group) is handed to _command_helper for the SSH connections.
    try:
        timeout_seconds = _extract_time(timeout)
        if timeout_seconds is None:
            click.echo(
                "Invalid timeout format. "
                "Valid forms include: 1, 123s, 1.5m, 3h"
            )
            return
        batches = _create_batches(context.obj.get_instance_list(), batchsize)
        last_index = len(batches) - 1
        for index, batch in enumerate(batches):
            results = _command_helper(command, batch, context.obj.timeout)
            for result in results:
                click.echo(
                    result['host'] + '\t' +
                    '\n\t\t\t\t\t\t\t'.join(result['response'].split('\n'))
                )
            # No need to wait after the final batch.
            if index < last_index:
                click.echo(
                    "\nBatch %d of %d completed. Wait %d seconds\n" %
                    (index + 1, len(batches), timeout_seconds)
                )
                time.sleep(timeout_seconds)
    except ValueError:
        # Raised e.g. when batchsize is not a usable integer.
        click.echo("Please check parameter format. Or see 'zco run_batch --help'")
    except Exception as e:
        # Top-level boundary of the command: report the failure through
        # click, like all other output, instead of a traceback.
        click.echo('Unexpected exceptions happened.')
        click.echo(str(e))


@click.command()
@click.pass_context
def update_autocomplete(context):
    """Update ssh autocomplete file."""
    cb_hosts_file = os.path.expanduser('~/.ssh/cb_hosts')

    # Count the lines currently in the file (0 if it does not exist yet) so
    # the before/after host counts can be reported.
    if os.path.isfile(cb_hosts_file):
        with open(cb_hosts_file, 'r') as hosts_file:
            old_count = sum(1 for _ in hosts_file)
    else:
        old_count = 0

    # Resolve the host names BEFORE opening the file for writing: opening in
    # write mode truncates it, so a failed EC2 query would otherwise wipe the
    # existing host list.
    instance_names = _list_by_tag(context.obj.get_instance_list())
    hostnames = [
        instance_name + '.caffeine.io'
        for instance_name in instance_names
    ]

    with open(cb_hosts_file, 'w') as hosts_file:
        hosts_file.write('\n'.join(hostnames) + '\n')

    new_count = len(hostnames)
    click.echo('Successfully updated autocomplete hosts list!')
    click.echo(
        'Previous host count: ' + click.style(str(old_count), fg='red') +
        '\nNew host count: ' + click.style(str(new_count), fg='green')
    )


@click.command()
def install_autocomplete():
    """Configure ssh autocomplete."""
    # Make sure ~/.bash_completion/ exists, then (re)create a symlink in it
    # pointing at the autocomplete.sh that sits next to this module.
    directory = os.path.expanduser('~/.bash_completion/')
    if not os.path.exists(directory):
        os.makedirs(directory)
    source = os.path.join(
        os.path.dirname(os.path.realpath(__file__)), 'autocomplete.sh'
    )
    destination = os.path.expanduser('~/.bash_completion/autocomplete.sh')
    # lexists (unlike exists) is also true for a dangling symlink, e.g. one
    # left behind after the package moved; without removing it first,
    # os.symlink below would fail because the link name is already taken.
    if os.path.lexists(destination):
        click.echo('Have an existing symlink. Will renew it.')
        os.unlink(destination)
    os.symlink(source, destination)


# Register every subcommand on the root ``cli`` group.
# Note that ``list`` here is the click command defined in this module,
# not the builtin type.
cli.add_command(list)
cli.add_command(run)
cli.add_command(run_more)
cli.add_command(run_batch)
cli.add_command(install_autocomplete)
cli.add_command(update_autocomplete)