# vim: tabstop=4 shiftwidth=4 softtabstop=4
# Copyright (c) 2013 TrilioData, Inc.
# All Rights Reserved.
"""Utilities and helper functions."""
import argparse
import contextlib
import datetime
import errno
import functools
import hashlib
import inspect
import os
import pyclbr
import random
import re
import shlex
import shutil
import signal
import socket
import sys
import tempfile
import time
import pytz
import json
import paramiko
import configparser
from urllib.parse import urlparse
import uuid
from datetime import time as datetime_time
from datetime import timezone as datetime_timezone
from pathlib import Path
from defusedxml import minidom
from xml.parsers import expat
from defusedxml.sax import expatreader
from xml.sax import saxutils # nosec
import stat
import netaddr
import netifaces as ni
from netifaces import AF_INET, AF_INET6, AF_LINK, AF_PACKET, AF_BRIDGE
from netifaces import interfaces, ifaddresses
from distutils.sysconfig import EXEC_PREFIX
from eventlet import event
from eventlet.green import subprocess
from eventlet import greenthread
from eventlet import pools
from oslo_utils import strutils
from oslo_config import cfg
from workloadmgr import exception
from workloadmgr import flags
from workloadmgr.pyvmomi_tools import service_instance as pyvmomi_si
from workloadmgr.openstack.common.gettextutils import _
from workloadmgr.openstack.common import excutils
from workloadmgr.openstack.common import importutils
from workloadmgr.openstack.common import lockutils
from workloadmgr.openstack.common import log as logging
from workloadmgr.openstack.common import timeutils
from cryptography.fernet import Fernet
CONF = cfg.CONF
LOG = logging.getLogger(__name__)
util_opts = [
cfg.StrOpt('triliovault_hostnames',
default='127.0.0.1',
help='Specify trilio hostname=ip pair'),
]
CONF.register_opts(util_opts)
ISO_TIME_FORMAT = "%Y-%m-%dT%H:%M:%S"
PERFECT_TIME_FORMAT = "%Y-%m-%dT%H:%M:%S.%f"
FLAGS = flags.FLAGS
_IS_NEUTRON_ATTEMPTED = False
_IS_NEUTRON = False
synchronized = lockutils.synchronized_with_prefix('workloadmgr-')
tvault_key_file_name = '/etc/triliovault-wlm/tvault_key.pem'
_FORMAT_PATTERNS_1 = [r'(%(key)s[0-9]*\s*[=]\s*)[^\s^\'^\"]+']
_FORMAT_PATTERNS_2 = [r'(%(key)s[0-9]*\s*[=]\s*[\"\'])[^\"\']*([\"\'])',
r'([-]{2}%(key)s[0-9]*\s+)[^\'^\"^=^\s]+([\s]*)']
_SANITIZE_KEYS = ["data", "secret"]
_SANITIZE_PATTERNS_2 = {}
_SANITIZE_PATTERNS_1 = {}
for key in _SANITIZE_KEYS:
_SANITIZE_PATTERNS_2[key] = []
for pattern in _FORMAT_PATTERNS_2:
reg_ex = re.compile(pattern % {'key': key}, re.DOTALL | re.IGNORECASE)
_SANITIZE_PATTERNS_2[key].append(reg_ex)
_SANITIZE_PATTERNS_1[key] = []
for pattern in _FORMAT_PATTERNS_1:
reg_ex = re.compile(pattern % {'key': key}, re.DOTALL | re.IGNORECASE)
_SANITIZE_PATTERNS_1[key].append(reg_ex)
def sanitize_message(message, secret="***"):
substitute1 = r'\g<1>' + secret
substitute2 = r'\g<1>' + secret + r'\g<2>'
if isinstance(message, list):
message = " ".join(message)
if isinstance(message, str):
for key in _SANITIZE_KEYS:
if key in message.lower():
for pattern in _SANITIZE_PATTERNS_2[key]:
message = re.sub(pattern, substitute2, message)
for pattern in _SANITIZE_PATTERNS_1[key]:
message = re.sub(pattern, substitute1, message)
    elif isinstance(message, dict):
        message = strutils.mask_dict_password(message)
return message
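# Example (illustrative): with the built-in "data"/"secret" keys, values
# assigned in a message or on a command line are masked:
#
#   >>> sanitize_message('secret = hunter2')
#   'secret = ***'
#   >>> sanitize_message('data = "payload"')
#   'data = "***"'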
def find_config(config_path):
"""Find a configuration file using the given hint.
:param config_path: Full or relative path to the config.
:returns: Full path of the config, if it exists.
:raises: `workloadmgr.exception.ConfigNotFound`
"""
possible_locations = [
config_path,
os.path.join(FLAGS.state_path, "etc", "triliovault-wlm", config_path),
os.path.join(FLAGS.state_path, "etc", config_path),
os.path.join(FLAGS.state_path, config_path),
"/etc/triliovault-wlm/%s" % config_path,
]
for path in possible_locations:
if os.path.exists(path):
return os.path.abspath(path)
raise exception.ConfigNotFound(path=os.path.abspath(config_path))
def fetchfile(url, target):
LOG.debug(_('Fetching %s') % url)
execute('curl', '--fail', url, '-o', target)
def _subprocess_setup():
# Python installs a SIGPIPE handler by default. This is usually not what
# non-Python subprocesses expect.
signal.signal(signal.SIGPIPE, signal.SIG_DFL)
def execute(*cmd, **kwargs):
"""Helper method to execute command with optional retry.
If you add a run_as_root=True command, don't forget to add the
corresponding filter to etc/triliovault-wlm/rootwrap.d !
:param cmd: Passed to subprocess.Popen.
:param process_input: Send to opened process.
    :param check_exit_code: Single bool, int, or list of allowed exit
                            codes. Defaults to [0]. Raises
                            exception.ProcessExecutionError unless the
                            program exits with one of these codes.
:param delay_on_retry: True | False. Defaults to True. If set to
True, wait a short amount of time
before retrying.
:param attempts: How many times to retry cmd.
:param run_as_root: True | False. Defaults to False. If set to True,
the command is prefixed by the command specified
in the root_helper FLAG.
:raises exception.Error: on receiving unknown arguments
:raises exception.ProcessExecutionError:
:returns: a tuple, (stdout, stderr) from the spawned process, or None if
the command fails.
"""
process_input = kwargs.pop('process_input', None)
check_exit_code = kwargs.pop('check_exit_code', [0])
ignore_exit_code = False
if isinstance(check_exit_code, bool):
ignore_exit_code = not check_exit_code
check_exit_code = [0]
elif isinstance(check_exit_code, int):
check_exit_code = [check_exit_code]
delay_on_retry = kwargs.pop('delay_on_retry', True)
attempts = kwargs.pop('attempts', 1)
run_as_root = kwargs.pop('run_as_root', False)
shell = kwargs.pop('shell', False)
if len(kwargs):
raise exception.Error(_('Got unknown keyword args '
'to utils.execute: %r') % kwargs)
if run_as_root:
if FLAGS.rootwrap_config is None or FLAGS.root_helper != 'sudo':
LOG.deprecated(_('The root_helper option (which lets you specify '
'a root wrapper different from workloadmgr-rootwrap, '
'and defaults to using sudo) is now deprecated. '
'You should use the rootwrap_config option '
'instead.'))
if (FLAGS.rootwrap_config is not None):
cmd = ['sudo', '%s/bin/workloadmgr-rootwrap'%(EXEC_PREFIX),
FLAGS.rootwrap_config] + list(cmd)
else:
cmd = shlex.split(FLAGS.root_helper) + list(cmd)
cmd = list(map(str, cmd))
sanitize_cmd = sanitize_message(cmd)
while attempts > 0:
attempts -= 1
try:
LOG.debug(_('Running cmd (subprocess): %s'), (sanitize_cmd))
_PIPE = subprocess.PIPE # pylint: disable=E1101
obj = subprocess.Popen(cmd,
stdin=_PIPE,
stdout=_PIPE,
stderr=_PIPE,
close_fds=True,
preexec_fn=_subprocess_setup,
shell=shell)
result = None
if process_input is not None:
result = obj.communicate(bytes(process_input, 'utf-8'))
else:
result = obj.communicate()
obj.stdin.close() # pylint: disable=E1101
_returncode = obj.returncode # pylint: disable=E1101
if _returncode:
LOG.debug(_('Result was %s') % _returncode)
if not ignore_exit_code and _returncode not in check_exit_code:
(stdout, stderr) = result
raise exception.ProcessExecutionError(
exit_code=_returncode,
stdout=stdout,
stderr=stderr,
cmd=sanitize_cmd)
return (str(result[0], encoding='utf-8'), str(result[1], encoding='utf-8'))
except exception.ProcessExecutionError:
if not attempts:
raise
else:
LOG.debug(_('%r failed. Retrying.'), sanitize_cmd)
if delay_on_retry:
greenthread.sleep(random.SystemRandom().randint(20, 200) / 100.0)
finally:
# NOTE(termie): this appears to be necessary to let the subprocess
# call clean something up in between calls, without
# it two execute calls in a row hangs the second one
greenthread.sleep(0)
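# Illustrative usage (a sketch): run a command with one retry, treating
# exit codes 0 and 1 as success. `run_as_root=True` would additionally
# require a matching rootwrap filter, as noted in the docstring.
#
#   out, err = execute('ls', '-l', '/tmp',
#                      check_exit_code=[0, 1], attempts=2)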
def trycmd(*args, **kwargs):
"""
A wrapper around execute() to more easily handle warnings and errors.
Returns an (out, err) tuple of strings containing the output of
the command's stdout and stderr. If 'err' is not empty then the
command can be considered to have failed.
:discard_warnings True | False. Defaults to False. If set to True,
then for succeeding commands, stderr is cleared
"""
discard_warnings = kwargs.pop('discard_warnings', False)
try:
out, err = execute(*args, **kwargs)
failed = False
except exception.ProcessExecutionError as exn:
out, err = '', str(exn)
LOG.debug(err)
failed = True
if not failed and discard_warnings and err:
# Handle commands that output to stderr but otherwise succeed
LOG.debug(err)
err = ''
return out, err
def create_channel(client, width, height):
"""Invoke an interactive shell session on server."""
channel = client.invoke_shell()
channel.resize_pty(width, height)
return channel
class SSHPool(pools.Pool):
"""A simple eventlet pool to hold ssh connections."""
def __init__(self, ip, port, conn_timeout, login, password=None,
privatekey=None, *args, **kwargs):
self.ip = ip
self.port = port
self.login = login
self.password = password
self.conn_timeout = conn_timeout if conn_timeout else None
self.privatekey = privatekey
super(SSHPool, self).__init__(*args, **kwargs)
def create(self):
try:
ssh = paramiko.SSHClient()
ssh.set_missing_host_key_policy(paramiko.AutoAddPolicy())
if self.password:
ssh.connect(self.ip,
port=self.port,
username=self.login,
password=self.password,
timeout=self.conn_timeout)
elif self.privatekey:
pkfile = os.path.expanduser(self.privatekey)
privatekey = paramiko.RSAKey.from_private_key_file(pkfile)
ssh.connect(self.ip,
port=self.port,
username=self.login,
pkey=privatekey,
timeout=self.conn_timeout)
else:
msg = _("Specify a password or private_key")
raise exception.WorkloadMgrException(msg)
# Paramiko by default sets the socket timeout to 0.1 seconds,
# ignoring what we set thru the sshclient. This doesn't help for
# keeping long lived connections. Hence we have to bypass it, by
# overriding it after the transport is initialized. We are setting
# the sockettimeout to None and setting a keepalive packet so that,
# the server will keep the connection open. All that does is send
# a keepalive packet every ssh_conn_timeout seconds.
if self.conn_timeout:
transport = ssh.get_transport()
transport.sock.settimeout(None)
transport.set_keepalive(self.conn_timeout)
return ssh
except Exception as e:
msg = _("Error connecting via ssh: %s") % e
LOG.error(msg)
raise paramiko.SSHException(msg)
def get(self):
"""
Return an item from the pool, when one is available. This may
cause the calling greenthread to block. Check if a connection is active
before returning it. For dead connections create and return a new
connection.
"""
if self.free_items:
conn = self.free_items.popleft()
if conn:
if conn.get_transport().is_active():
return conn
else:
conn.close()
return self.create()
if self.current_size < self.max_size:
created = self.create()
self.current_size += 1
return created
return self.channel.get()
    def remove(self, ssh):
        """Close an ssh client and remove it from free_items."""
        ssh.close()
        if ssh in self.free_items:
            self.free_items.remove(ssh)
        if self.current_size > 0:
            self.current_size -= 1
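# Illustrative usage (host and credentials are placeholders): create a
# small pool and borrow a connection via item(), the context manager
# eventlet's Pool provides.
#
#   pool = SSHPool('192.0.2.10', 22, conn_timeout=30,
#                  login='admin', password='secret', max_size=4)
#   with pool.item() as ssh:
#       stdin, stdout, stderr = ssh.exec_command('uname -a')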
def workloadmgrdir():
import workloadmgr
return os.path.abspath(workloadmgr.__file__).split(
'workloadmgr/__init__.py')[0]
def debug(arg):
LOG.debug(_('debug in callback: %s'), arg)
return arg
def generate_uid(topic, size=8):
characters = '01234567890abcdefghijklmnopqrstuvwxyz'
choices = [random.SystemRandom().choice(characters) for x in range(size)]
return '%s-%s' % (topic, ''.join(choices))
def last_octet(address):
return int(address.split('.')[-1])
def get_my_linklocal(interface):
try:
if_str = execute('ip', '-f', 'inet6', '-o', 'addr', 'show', interface)
        condition = r'\s+inet6\s+([0-9a-f:]+)/\d+\s+scope\s+link'
        links = [re.search(condition, x) for x in if_str[0].split('\n')]
        address = [w.group(1) for w in links if w is not None]
        if address:
            return address[0]
        else:
            raise exception.Error(_('Link Local address is not found.:%s')
                                  % if_str)
except Exception as ex:
raise exception.Error(_("Couldn't get Link Local IP of %(interface)s"
" :%(ex)s") % locals())
def parse_mailmap(mailmap='.mailmap'):
mapping = {}
if os.path.exists(mailmap):
fp = open(mailmap, 'r')
for ll in fp:
ll = ll.strip()
if not ll.startswith('#') and ' ' in ll:
canonical_email, alias = ll.split(' ')
mapping[alias.lower()] = canonical_email.lower()
return mapping
def str_dict_replace(s, mapping):
for s1, s2 in mapping.items():
s = s.replace(s1, s2)
return s
class LazyPluggable(object):
"""A pluggable backend loaded lazily based on some value."""
def __init__(self, pivot, **backends):
self.__backends = backends
self.__pivot = pivot
self.__backend = None
def __get_backend(self):
if not self.__backend:
backend_name = FLAGS[self.__pivot]
if backend_name not in self.__backends:
raise exception.Error(_('Invalid backend: %s') % backend_name)
backend = self.__backends[backend_name]
if isinstance(backend, tuple):
name = backend[0]
fromlist = backend[1]
else:
name = backend
fromlist = backend
self.__backend = __import__(name, None, None, fromlist)
LOG.debug(_('backend %s'), self.__backend)
return self.__backend
def __getattr__(self, key):
backend = self.__get_backend()
return getattr(backend, key)
class LoopingCallDone(Exception):
"""Exception to break out and stop a LoopingCall.
The poll-function passed to LoopingCall can raise this exception to
break out of the loop normally. This is somewhat analogous to
StopIteration.
An optional return-value can be included as the argument to the exception;
this return-value will be returned by LoopingCall.wait()
"""
def __init__(self, retvalue=True):
""":param retvalue: Value that LoopingCall.wait() should return."""
self.retvalue = retvalue
class LoopingCall(object):
def __init__(self, f=None, *args, **kw):
self.args = args
self.kw = kw
self.f = f
self._running = False
def start(self, interval, initial_delay=None):
self._running = True
done = event.Event()
def _inner():
if initial_delay:
greenthread.sleep(initial_delay)
try:
while self._running:
self.f(*self.args, **self.kw)
if not self._running:
break
greenthread.sleep(interval)
except LoopingCallDone as e:
self.stop()
done.send(e.retvalue)
except Exception:
LOG.exception(_('in looping call'))
done.send_exception(*sys.exc_info())
return
else:
done.send(True)
self.done = done
greenthread.spawn(_inner)
return self.done
def stop(self):
self._running = False
def wait(self):
return self.done.wait()
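# Illustrative usage: poll until done, signalling completion by raising
# LoopingCallDone (is_finished() is a hypothetical predicate):
#
#   def _poll():
#       if is_finished():
#           raise LoopingCallDone(retvalue='done')
#
#   timer = LoopingCall(_poll)
#   timer.start(interval=5.0)
#   result = timer.wait()   # returns 'done'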
class ProtectedExpatParser(expatreader.DefusedExpatParser):
"""An expat parser which disables DTD's and entities by default."""
def __init__(self, forbid_dtd=True, forbid_entities=True,
*args, **kwargs):
# Python 2.x old style class
expatreader.DefusedExpatParser.__init__(self, *args, **kwargs)
self.forbid_dtd = forbid_dtd
self.forbid_entities = forbid_entities
def start_doctype_decl(self, name, sysid, pubid, has_internal_subset):
raise ValueError("Inline DTD forbidden")
def entity_decl(self, entityName, is_parameter_entity, value, base,
systemId, publicId, notationName):
raise ValueError("<!ENTITY> forbidden")
def unparsed_entity_decl(self, name, base, sysid, pubid, notation_name):
# expat 1.2
raise ValueError("<!ENTITY> forbidden")
def reset(self):
expatreader.DefusedExpatParser.reset(self)
if self.forbid_dtd:
self._parser.StartDoctypeDeclHandler = self.start_doctype_decl
if self.forbid_entities:
self._parser.EntityDeclHandler = self.entity_decl
self._parser.UnparsedEntityDeclHandler = self.unparsed_entity_decl
def safe_minidom_parse_string(xml_string):
    """Parse an XML string using minidom safely."""
    return minidom.parseString(xml_string, parser=ProtectedExpatParser())
def xhtml_escape(value):
"""Escapes a string so it is valid within XML or XHTML.
"""
    return saxutils.escape(value, {'"': '&quot;', "'": '&apos;'})
def utf8(value):
    """Try to turn a string into utf-8 if possible.
    Code is directly from the utf8 function in
    http://github.com/facebook/tornado/blob/master/tornado/escape.py
    """
    if isinstance(value, str):
        return value.encode('utf-8')
    assert isinstance(value, bytes)
    return value
def delete_if_exists(pathname):
"""delete a file, but ignore file not found error"""
try:
os.unlink(pathname)
except OSError as e:
if e.errno == errno.ENOENT:
return
else:
raise
def get_from_path(items, path):
"""Returns a list of items matching the specified path.
Takes an XPath-like expression e.g. prop1/prop2/prop3, and for each item
in items, looks up items[prop1][prop2][prop3]. Like XPath, if any of the
intermediate results are lists it will treat each list item individually.
A 'None' in items or any child expressions will be ignored, this function
will not throw because of None (anywhere) in items. The returned list
will contain no None values.
"""
if path is None:
raise exception.Error('Invalid mini_xpath')
(first_token, sep, remainder) = path.partition('/')
if first_token == '':
raise exception.Error('Invalid mini_xpath')
results = []
if items is None:
return results
if not isinstance(items, list):
# Wrap single objects in a list
items = [items]
for item in items:
if item is None:
continue
get_method = getattr(item, 'get', None)
if get_method is None:
continue
child = get_method(first_token)
if child is None:
continue
if isinstance(child, list):
# Flatten intermediate lists
for x in child:
results.append(x)
else:
results.append(child)
if not sep:
# No more tokens
return results
else:
return get_from_path(results, remainder)
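# Example (illustrative): intermediate lists are flattened and None
# values are skipped:
#
#   >>> get_from_path([{'a': {'b': 1}}, {'a': {'b': 2}}, None], 'a/b')
#   [1, 2]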
def flatten_dict(dict_, flattened=None):
    """Recursively flatten a nested dictionary."""
    flattened = flattened or {}
    for key, value in dict_.items():
        if isinstance(value, dict):
            flatten_dict(value, flattened)
        else:
            flattened[key] = value
    return flattened
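# Example (illustrative): nested values are merged into one level; later
# keys overwrite earlier ones on collision:
#
#   >>> flatten_dict({'a': 1, 'b': {'c': 2}})
#   {'a': 1, 'c': 2}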
def partition_dict(dict_, keys):
"""Return two dicts, one with `keys` the other with everything else."""
intersection = {}
difference = {}
for key, value in dict_.items():
if key in keys:
intersection[key] = value
else:
difference[key] = value
return intersection, difference
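# Example (illustrative):
#
#   >>> partition_dict({'a': 1, 'b': 2, 'c': 3}, ['a'])
#   ({'a': 1}, {'b': 2, 'c': 3})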
def map_dict_keys(dict_, key_map):
"""Return a dict in which the dictionaries keys are mapped to new keys."""
mapped = {}
for key, value in dict_.items():
mapped_key = key_map[key] if key in key_map else key
mapped[mapped_key] = value
return mapped
def subset_dict(dict_, keys):
"""Return a dict that only contains a subset of keys."""
subset = partition_dict(dict_, keys)[0]
return subset
def check_isinstance(obj, cls):
"""Checks that obj is of type cls, and lets PyLint infer types."""
if isinstance(obj, cls):
return obj
raise Exception(_('Expected object of type: %s') % (str(cls)))
# TODO(justinsb): Can we make this better??
def is_valid_boolstr(val):
    """Check if the provided string is a valid bool string or not."""
    val = str(val).lower()
    return val in ('true', 'false', 'yes', 'no', 'y', 'n', '1', '0')
def monkey_patch():
    """If FLAGS.monkey_patch is set to True, this function patches a
    decorator onto all functions in the specified modules.
    You can set decorators for each module
    using FLAGS.monkey_patch_modules.
    The format is "Module path:Decorator function".
    Example: 'workloadmgr.api.ec2.cloud:' \
    workloadmgr.openstack.common.notifier.api.notify_decorator'
    The parameters of the decorator are as follows.
    (See workloadmgr.openstack.common.notifier.api.notify_decorator)
    name - name of the function
    function - object of the function
    """
    # If FLAGS.monkey_patch is not True, this function does nothing.
    if not FLAGS.monkey_patch:
        return
# Get list of modules and decorators
for module_and_decorator in FLAGS.monkey_patch_modules:
module, decorator_name = module_and_decorator.split(':')
# import decorator function
decorator = importutils.import_class(decorator_name)
__import__(module)
# Retrieve module information using pyclbr
module_data = pyclbr.readmodule_ex(module)
for key in list(module_data.keys()):
# set the decorator for the class methods
if isinstance(module_data[key], pyclbr.Class):
clz = importutils.import_class("%s.%s" % (module, key))
for method, func in inspect.getmembers(clz, inspect.ismethod):
setattr(
clz, method,
decorator("%s.%s.%s" % (module, key, method), func))
# set the decorator for the function
if isinstance(module_data[key], pyclbr.Function):
func = importutils.import_class("%s.%s" % (module, key))
setattr(sys.modules[module], key,
decorator("%s.%s" % (module, key), func))
def convert_to_list_dict(lst, label):
"""Convert a value or list into a list of dicts"""
if not lst:
return None
if not isinstance(lst, list):
lst = [lst]
return [{label: x} for x in lst]
def timefunc(func):
"""Decorator that logs how long a particular function took to execute"""
@functools.wraps(func)
def inner(*args, **kwargs):
start_time = time.time()
try:
return func(*args, **kwargs)
finally:
total_time = time.time() - start_time
LOG.debug(_("timefunc: '%(name)s' took %(total_time).2f secs") %
dict(name=func.__name__, total_time=total_time))
return inner
@contextlib.contextmanager
def logging_error(message):
"""Catches exception, write message to the log, re-raise.
This is a common refinement of save_and_reraise that writes a specific
message to the log.
"""
try:
yield
except Exception as error:
with excutils.save_and_reraise_exception():
LOG.exception(message)
@contextlib.contextmanager
def remove_path_on_error(path):
"""Protect code that wants to operate on PATH atomically.
Any exception will cause PATH to be removed.
"""
try:
yield
except Exception:
with excutils.save_and_reraise_exception():
delete_if_exists(path)
def make_dev_path(dev, partition=None, base='/dev'):
"""Return a path to a particular device.
>>> make_dev_path('xvdc')
/dev/xvdc
>>> make_dev_path('xvdc', 1)
/dev/xvdc1
"""
path = os.path.join(base, dev)
if partition:
path += str(partition)
return path
def total_seconds(td):
"""Local total_seconds implementation for compatibility with python 2.6"""
if hasattr(td, 'total_seconds'):
return td.total_seconds()
else:
return ((td.days * 86400 + td.seconds) * 10 ** 6 +
td.microseconds) / 10.0 ** 6
def sanitize_hostname(hostname):
"""Return a hostname which conforms to RFC-952 and RFC-1123 specs."""
if isinstance(hostname, str):
hostname = hostname.encode('utf-8').decode()
hostname = re.sub('[ _]', '-', hostname)
    hostname = re.sub(r'[^\w.-]+', '', hostname)
hostname = hostname.lower()
hostname = hostname.strip('.-')
return hostname
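# Example (illustrative): spaces and underscores become hyphens, other
# disallowed characters are dropped, and the result is lowercased:
#
#   >>> sanitize_hostname('My Host_01!')
#   'my-host-01'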
def read_cached_file(filename, cache_info, reload_func=None):
"""Read from a file if it has been modified.
:param cache_info: dictionary to hold opaque cache.
:param reload_func: optional function to be called with data when
file is reloaded due to a modification.
:returns: data from file
"""
mtime = os.path.getmtime(filename)
if not cache_info or mtime != cache_info.get('mtime'):
with open(filename) as fap:
cache_info['data'] = fap.read()
cache_info['mtime'] = mtime
if reload_func:
reload_func(cache_info['data'])
return cache_info['data']
def file_open(*args, **kwargs):
    """Open a file.
    See the built-in open() documentation for more details.
    Note: The reason this is kept in a separate module is to easily
    be able to provide a stub module that doesn't alter system
    state at all (for unit tests)
    """
    return open(*args, **kwargs)
def hash_file(file_like_object):
    """Generate a hash for the contents of a file."""
    checksum = hashlib.sha1()
    while True:
        chunk = file_like_object.read(32768)
        if not chunk:
            break
        if isinstance(chunk, str):
            chunk = chunk.encode('utf-8')
        checksum.update(chunk)
    return checksum.hexdigest()
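# Illustrative usage (the path is a placeholder): hash a file opened in
# binary mode.
#
#   with open('/tmp/example.bin', 'rb') as f:
#       digest = hash_file(f)   # hex SHA-1 of the file contents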
@contextlib.contextmanager
def temporary_mutation(obj, **kwargs):
"""Temporarily set the attr on a particular object to a given value then
revert when finished.
One use of this is to temporarily set the read_deleted flag on a context
object:
with temporary_mutation(context, read_deleted="yes"):
do_something_that_needed_deleted_objects()
"""
NOT_PRESENT = object()
old_values = {}
for attr, new_value in list(kwargs.items()):
old_values[attr] = getattr(obj, attr, NOT_PRESENT)
setattr(obj, attr, new_value)
try:
yield
finally:
for attr, old_value in list(old_values.items()):
if old_value is NOT_PRESENT:
del obj[attr]
else:
setattr(obj, attr, old_value)
def service_is_up(service):
"""Check whether a service is up based on last heartbeat."""
last_heartbeat = service['updated_at'] or service['created_at']
# Timestamps in DB are UTC.
elapsed = total_seconds(timeutils.utcnow() - last_heartbeat)
return abs(elapsed) <= FLAGS.service_down_time
def generate_mac_address():
"""Generate an Ethernet MAC address."""
# NOTE(vish): We would prefer to use 0xfe here to ensure that linux
# bridge mac addresses don't change, but it appears to
# conflict with libvirt, so we use the next highest octet
# that has the unicast and locally administered bits set
# properly: 0xfa.
# Discussion: https://bugs.launchpad.net/workloadmgr/+bug/921838
mac = [0xfa, 0x16, 0x3e,
random.SystemRandom().randint(0x00, 0x7f),
random.SystemRandom().randint(0x00, 0xff),
random.SystemRandom().randint(0x00, 0xff)]
return ':'.join(["%02x" % x for x in mac])
def read_file_as_root(file_path):
"""Secure helper to read file as root."""
try:
out, _err = execute('cat', file_path)
return out
except exception.ProcessExecutionError as ex:
if hasattr(ex, 'cmd'):
ex.cmd = sanitize_message(ex.cmd)
raise exception.FileNotFound(file_path=file_path)
@contextlib.contextmanager
def temporary_chown(path, owner_uid=None):
"""Temporarily chown a path.
    :param owner_uid: UID of temporary owner (defaults to current user)
"""
if owner_uid is None:
owner_uid = os.getuid()
orig_uid = os.stat(path).st_uid
if orig_uid != owner_uid:
execute('chown', owner_uid, path)
try:
yield
finally:
if orig_uid != owner_uid:
execute('chown', orig_uid, path)
def chmod(path, mode):
"""change the mode of the file.
:params mode
"""
execute('chmod', mode, path)
@contextlib.contextmanager
def tempdir(**kwargs):
tmpdir = tempfile.mkdtemp(**kwargs)
try:
yield tmpdir
finally:
try:
shutil.rmtree(tmpdir)
except OSError as e:
LOG.debug(_('Could not remove tmpdir: %s'), str(e))
def strcmp_const_time(s1, s2):
"""Constant-time string comparison.
    :param s1: the first string
    :param s2: the second string
:return: True if the strings are equal.
This function takes two strings and compares them. It is intended to be
used when doing a comparison for authentication purposes to help guard
against timing attacks.
"""
if len(s1) != len(s2):
return False
result = 0
for (a, b) in zip(s1, s2):
result |= ord(a) ^ ord(b)
return result == 0
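# Example (illustrative): for equal-length inputs the comparison always
# scans the full string, so timing does not reveal the first differing
# position:
#
#   >>> strcmp_const_time('token123', 'token123')
#   True
#   >>> strcmp_const_time('token123', 'token124')
#   False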
def walk_class_hierarchy(clazz, encountered=None):
"""Walk class hierarchy, yielding most derived classes first"""
if not encountered:
encountered = []
for subclass in clazz.__subclasses__():
if subclass not in encountered:
encountered.append(subclass)
# drill down to leaves first
for subsubclass in walk_class_hierarchy(subclass, encountered):
yield subsubclass
yield subclass
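# Example (illustrative): leaves are yielded before their parents:
#
#   class A(object): pass
#   class B(A): pass
#   class C(B): pass
#
#   list(walk_class_hierarchy(A))   # [C, B] (most derived first)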
class UndoManager(object):
"""Provides a mechanism to facilitate rolling back a series of actions
when an exception is raised.
"""
def __init__(self):
self.undo_stack = []
def undo_with(self, undo_func):
self.undo_stack.append(undo_func)
def _rollback(self):
for undo_func in reversed(self.undo_stack):
undo_func()
def rollback_and_reraise(self, msg=None, **kwargs):
"""Rollback a series of actions then re-raise the exception.
.. note:: (sirp) This should only be called within an
exception handler.
"""
with excutils.save_and_reraise_exception():
if msg:
LOG.exception(msg, **kwargs)
self._rollback()
def ensure_tree(path):
"""Create a directory (and any ancestor directories required)
:param path: Directory to create
"""
try:
os.makedirs(path)
except OSError as exc:
if exc.errno == errno.EEXIST:
if not os.path.isdir(path):
raise
else:
raise
def to_bytes(text, default=0):
"""Try to turn a string into a number of bytes. Looks at the last
characters of the text to determine what conversion is needed to
turn the input text into a byte number.
Supports: B/b, K/k, M/m, G/g, T/t, KiB, MiB, GiB, TiB,
(or the same with b/B on the end)
"""
BYTE_MULTIPLIERS = {
'': 1,
't': 1024 ** 4,
'g': 1024 ** 3,
'm': 1024 ** 2,
'k': 1024,
}
# Take off everything not number 'like' (which should leave
# only the byte 'identifier' left)
mult_key_org = text.lstrip('-1234567890.')
mult_key = mult_key_org.strip().lower()
mult_key_len = len(mult_key)
if mult_key.endswith("ib"):
mult_key = mult_key[0:-2]
if mult_key.endswith("b"):
mult_key = mult_key[0:-1]
try:
multiplier = BYTE_MULTIPLIERS[mult_key]
if mult_key_len:
# Empty cases shouldn't cause text[0:-0]
text = text[0:-mult_key_len].strip()
return int(float(text) * multiplier)
except KeyError:
msg = _('Unknown byte multiplier: %s') % mult_key_org
raise TypeError(msg)
except ValueError:
return default
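# Example (illustrative): decimal-style suffixes are interpreted as
# binary multiples (K == 1024):
#
#   >>> to_bytes('10KB')
#   10240
#   >>> to_bytes('1GiB')
#   1073741824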
class LoopingCallBase(object):
def __init__(self, f=None, *args, **kw):
self.args = args
self.kw = kw
self.f = f
self._running = False
self.done = None
def stop(self):
self._running = False
def wait(self):
return self.done.wait()
class FixedIntervalLoopingCall(LoopingCallBase):
"""A looping call which happens at a fixed interval."""
def start(self, interval, initial_delay=None):
self._running = True
done = event.Event()
def _inner():
if initial_delay:
greenthread.sleep(initial_delay)
try:
while self._running:
self.f(*self.args, **self.kw)
if not self._running:
break
greenthread.sleep(interval)
except LoopingCallDone as e:
self.stop()
done.send(e.retvalue)
except Exception:
LOG.exception(_('in fixed duration looping call'))
done.send_exception(*sys.exc_info())
return
else:
done.send(True)
self.done = done
greenthread.spawn(_inner)
return self.done
def is_valid_ipv4(address):
"""Verify that address represents a valid IPv4 address."""
try:
return netaddr.valid_ipv4(address)
except Exception:
return False
def is_valid_ipv6(address):
try:
return netaddr.valid_ipv6(address)
except Exception:
return False
def is_valid_ipv6_cidr(address):
try:
str(netaddr.IPNetwork(address, version=6).cidr)
return True
except Exception:
return False
def get_shortened_ipv6(address):
addr = netaddr.IPAddress(address, version=6)
return str(addr.ipv6())
def get_shortened_ipv6_cidr(address):
net = netaddr.IPNetwork(address, version=6)
return str(net.cidr)
def is_valid_cidr(address):
"""Check if address is valid
The provided address can be a IPv6 or a IPv4
CIDR address.
"""
try:
# Validate the correct CIDR Address
netaddr.IPNetwork(address)
except netaddr.core.AddrFormatError:
return False
except UnboundLocalError:
# NOTE(MotoKen): work around bug in netaddr 0.7.5 (see detail in
# https://github.com/drkjam/netaddr/issues/2)
return False
# Prior validation partially verify /xx part
# Verify it here
ip_segment = address.split('/')
if (len(ip_segment) <= 1 or
ip_segment[1] == ''):
return False
return True
def get_ip_version(network):
"""Returns the IP version of a network (IPv4 or IPv6).
Raises AddrFormatError if invalid network.
"""
if netaddr.IPNetwork(network).version == 6:
return "IPv6"
elif netaddr.IPNetwork(network).version == 4:
return "IPv4"
def is_neutron():
global _IS_NEUTRON_ATTEMPTED
global _IS_NEUTRON
if _IS_NEUTRON_ATTEMPTED:
return _IS_NEUTRON
try:
# compatibility with Folsom/Grizzly configs
cls_name = 'workloadmgr.network.neutronv2.api.API' # CONF.network_api_class
if cls_name == 'workloadmgr.network.quantumv2.api.API':
cls_name = 'workloadmgr.network.neutronv2.api.API'
_IS_NEUTRON_ATTEMPTED = True
from workloadmgr.network.neutronv2 import api as neutron_api
_IS_NEUTRON = issubclass(importutils.import_class(cls_name),
neutron_api.API)
except ImportError:
_IS_NEUTRON = False
return _IS_NEUTRON
def reset_is_neutron():
global _IS_NEUTRON_ATTEMPTED
global _IS_NEUTRON
_IS_NEUTRON_ATTEMPTED = False
_IS_NEUTRON = False
def move_file(src, dest):
execute('mv', src, dest)
def copy_file(src, dest):
execute('cp', src, dest)
def append_unique(items, new_item, key="id"):
    """Append new_item to items unless an item with the same key exists."""
    for item in items:
        if item[key] == new_item[key]:
            return
    items.append(new_item)
def get_file_mode(path):
"""This primarily exists to make unit testing easier."""
return stat.S_IMODE(os.stat(path).st_mode)
def get_file_gid(path):
"""This primarily exists to make unit testing easier."""
return os.stat(path).st_gid
class ChunkedFile(object):
    """An iterator over the contents of a large file, read in chunks."""
CHUNKSIZE = 65536
def __init__(self, filepath, update=None):
self.filepath = filepath
self.fp = open(self.filepath, 'rb')
self.update = update
self.uploaded_size_incremental = 0
def _update(self, size):
self.uploaded_size_incremental = self.uploaded_size_incremental + size
        # os.SEEK_END is just the constant 2, so comparing tell() against it
        # never detected end-of-file; compare against the file size instead.
        if self.update and ((self.uploaded_size_incremental > (
                5 * 1024 * 1024)) or (self.tell() == len(self))):
            updated_obj = self.update['function'](
                self.update['context'], self.update['id'],
                {'uploaded_size_incremental': self.uploaded_size_incremental})
            LOG.debug(_("progress_percent: %(progress_percent)s") %
                      {'progress_percent': updated_obj.progress_percent, })
            self.uploaded_size_incremental = 0
def __iter__(self):
"""Return an iterator over the file"""
try:
if self.fp:
while True:
chunk = self.fp.read(ChunkedFile.CHUNKSIZE)
if chunk:
self._update(len(chunk))
yield chunk
else:
break
finally:
self.close()
def __len__(self):
return os.path.getsize(self.filepath)
def tell(self):
if self.fp:
return self.fp.tell()
def seek(self, offset, whence=os.SEEK_SET):
if self.fp:
return self.fp.seek(offset, whence)
def read(self, size):
if self.fp:
data = self.fp.read(size)
if data:
self._update(len(data))
return data
def close(self):
"""Close the internal file pointer"""
if self.fp:
self.fp.close()
self.fp = None
def get_instance_restore_options(restore_options, instance_id, platform):
if restore_options and platform in restore_options:
if 'instances' in restore_options[platform]:
for instance in restore_options[platform]['instances']:
if instance['id'] == instance_id:
return instance
return {}
def get_vm_migration_options(migration_options, vm_id, platform):
if migration_options and platform in migration_options:
if 'vms' in migration_options[platform]:
for vm in migration_options[platform]['vms']:
if vm['id'] == vm_id:
return vm
return {}
def last_completed_audit_period(unit=None):
"""This method gives you the most recently *completed* audit period.
    arguments:
        unit: string, one of 'hour', 'day', 'month', 'year'
Periods normally begin at the beginning (UTC) of the
period unit (So a 'day' period begins at midnight UTC,
a 'month' unit on the 1st, a 'year' on Jan, 1)
unit string may be appended with an optional offset
like so: 'day@18' This will begin the period at 18:00
UTC. 'month@15' starts a monthly period on the 15th,
and year@3 begins a yearly one on March 1st.
returns: 2 tuple of datetimes (begin, end)
The begin timestamp of this audit period is the same as the
end of the previous.
"""
if not unit:
unit = CONF.volume_usage_audit_period
offset = 0
if '@' in unit:
unit, offset = unit.split("@", 1)
offset = int(offset)
rightnow = timeutils.utcnow()
if unit not in ('month', 'day', 'year', 'hour'):
raise ValueError('Time period must be hour, day, month or year')
if unit == 'month':
if offset == 0:
offset = 1
end = datetime.datetime(day=offset,
month=rightnow.month,
year=rightnow.year)
if end >= rightnow:
year = rightnow.year
if 1 >= rightnow.month:
year -= 1
month = 12 + (rightnow.month - 1)
else:
month = rightnow.month - 1
end = datetime.datetime(day=offset,
month=month,
year=year)
year = end.year
if 1 >= end.month:
year -= 1
month = 12 + (end.month - 1)
else:
month = end.month - 1
begin = datetime.datetime(day=offset, month=month, year=year)
elif unit == 'year':
if offset == 0:
offset = 1
end = datetime.datetime(day=1, month=offset, year=rightnow.year)
if end >= rightnow:
end = datetime.datetime(day=1,
month=offset,
year=rightnow.year - 1)
begin = datetime.datetime(day=1,
month=offset,
year=rightnow.year - 2)
else:
begin = datetime.datetime(day=1,
month=offset,
year=rightnow.year - 1)
elif unit == 'day':
end = datetime.datetime(hour=offset,
day=rightnow.day,
month=rightnow.month,
year=rightnow.year)
if end >= rightnow:
end = end - datetime.timedelta(days=1)
begin = end - datetime.timedelta(days=1)
elif unit == 'hour':
end = rightnow.replace(minute=offset, second=0, microsecond=0)
if end >= rightnow:
end = end - datetime.timedelta(hours=1)
begin = end - datetime.timedelta(hours=1)
return (begin, end)
def check_ssh_injection(cmd_list):
ssh_injection_pattern = ['`', '$', '|', '||', ';', '&', '&&', '>', '>>',
'<']
# Check whether injection attacks exist
for arg in cmd_list:
arg = arg.strip()
# Check for matching quotes on the ends
is_quoted = re.match('^(?P<quote>[\'"])(?P<quoted>.*)(?P=quote)$', arg)
if is_quoted:
# Check for unescaped quotes within the quoted argument
quoted = is_quoted.group('quoted')
if quoted:
if (re.match('[\'"]', quoted) or
re.search('[^\\\\][\'"]', quoted)):
raise exception.SSHInjectionThreat(command=str(cmd_list))
else:
# We only allow spaces within quoted arguments, and that
# is the only special character allowed within quotes
if len(arg.split()) > 1:
raise exception.SSHInjectionThreat(command=str(cmd_list))
        # Second, check whether a dangerous character appears inside the
        # argument. A shell special operator must be its own argument.
for c in ssh_injection_pattern:
if arg == c:
continue
result = arg.find(c)
if not result == -1:
if result == 0 or not arg[result - 1] == '\\':
raise exception.SSHInjectionThreat(command=cmd_list)
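# Illustrative usage: a shell operator embedded in an argument raises
# SSHInjectionThreat, while plain arguments pass:
#
#   check_ssh_injection(['ls', '-l'])      # OK, returns None
#   check_ssh_injection(['ls;rm', '-rf'])  # raises SSHInjectionThreat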
def get_ip_addresses():
# we will configure only br-eth0 and eth1 on the appliance
ip_addresses = []
try:
for ifaceName in interfaces():
            ip_addresses += [i['addr'] for i in ifaddresses(
                ifaceName).setdefault(AF_INET, [{'addr': 'No IP addr'}])
                if i['addr'] not in ('No IP addr', '127.0.0.1')]
            # split('%')[0] drops the IPv6 zone-id suffix; strip() would
            # remove individual characters from both ends instead.
            ip_addresses += [i['addr'].split('%')[0] for i in ifaddresses(
                ifaceName).setdefault(AF_INET6, [{'addr': 'No IP addr'}])
                if i['addr'] not in ('No IP addr', '::1')]
tvault_hostnames = CONF.triliovault_hostnames
for thost in tvault_hostnames.split(','):
if '=' in thost:
ip_address, host = thost.split('=')
else:
ip_address = thost
host = socket.getfqdn()
if host in socket.gethostname() or \
socket.gethostname() in host:
ip_addresses.insert(0, str(ip_address))
except Exception as ex:
LOG.exception(ex)
return list(set(ip_addresses))
def sizeof_fmt(num, suffix='B'):
try:
for unit in ['', 'K', 'M', 'G', 'T', 'P', 'E', 'Z']:
if abs(num) < 1024.0:
return "%3.1f %s%s" % (num, unit, suffix)
num /= 1024.0
return "%.1f%s%s" % (num, 'Yi', suffix)
except Exception as ex:
LOG.exception(ex)
return num
def get_hosts_from_transport_url(url):
try:
hosts = []
for u in url.split(','):
parsed_url = urlparse(u)
parsed_host = parsed_url.hostname
hosts.append(parsed_host)
return hosts
except Exception as ex:
LOG.exception(ex)
raise
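# Example (illustrative, with placeholder credentials):
#
#   >>> get_hosts_from_transport_url(
#   ...     'rabbit://user:pass@host1:5672,rabbit://user:pass@host2:5672')
#   ['host1', 'host2']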
def load_key(filename):
"""
Load the secret key
"""
try:
with open(filename, "rb") as key:
return key.read()
except Exception as ex:
LOG.exception(ex)
raise
def encrypt_password(raw_password, filename):
"""
Encrypts the password.
"""
try:
key = load_key(filename)
data = raw_password.encode()
fer_data = Fernet(key)
return fer_data.encrypt(data)
except Exception as ex:
LOG.exception(ex)
raise
def decrypt_password(coded_password, filename):
"""
Decrypt the password
"""
try:
coded_password = bytes(coded_password, 'utf-8')
key = load_key(filename)
fer_data = Fernet(key)
return fer_data.decrypt(coded_password).decode("utf-8")
except Exception as ex:
LOG.exception(ex)
raise
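# Illustrative round trip (assumes the key file holds a Fernet key, such
# as one produced by cryptography's Fernet.generate_key()):
#
#   token = encrypt_password('s3cret', tvault_key_file_name)  # bytes
#   plain = decrypt_password(token.decode('utf-8'), tvault_key_file_name)
#   assert plain == 's3cret'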
def get_ssl_cert_path(filename=None):
""" return the Cert file path as per unix distribution"""
ssl_cert = [("/etc/ssl/certs/ca-certificates.crt", "/usr/local/share/ca-certificates"),
("/etc/pki/ca-trust/extracted/pem/tls-ca-bundle.pem",
"/etc/pki/ca-trust/source/anchors")]
for path in ssl_cert:
try:
if not filename and os.path.exists(path[0]):
return path[0]
elif os.path.exists(path[1]) and os.path.isdir(path[1]):
if os.path.exists(os.path.join(path[1], filename)):
cert = os.path.join(path[1], filename) + str(uuid.uuid4())
return cert
else:
cert = os.path.join(path[1], filename)
return cert
except Exception as ex:
LOG.exception("Error occurred {}:".format(ex))
raise
def parse_encrypted_image_backing_file(backing_file):
    """Parse the backing-file spec and return the embedded filename.
    Single quotes are replaced with double quotes so the payload can be
    parsed as JSON.
    """
    try:
        normalized = backing_file.replace("'", '"')
        json_key_position = normalized.find(':') + 1
        backing_file_context = normalized[json_key_position:]
        return json.loads(backing_file_context)['file']['filename']
    except Exception:
        return backing_file
def get_vcenter_service_instance():
url_components = urlparse(CONF.vcenter_migration.vcenter_url)
kwargs = {}
kwargs['host'] = url_components.hostname
if url_components.scheme == "https":
kwargs['port'] = 443
else:
kwargs['port'] = 80
if url_components.port:
kwargs['port'] = url_components.port
kwargs['user'] = CONF.vcenter_migration.vcenter_username
kwargs['password'] = CONF.vcenter_migration.vcenter_password
kwargs['disable_ssl_verification'] = CONF.vcenter_migration.vcenter_nossl
kwargs['vcenter_cert_path'] = CONF.vcenter_migration.vcenter_cert_path
args = argparse.Namespace(**kwargs)
si = pyvmomi_si.connect(args)
return si
def convert_jobschedule_date_tz(jobschedule):
if jobschedule.get('timezone') and jobschedule['timezone'] != 'UTC':
start_date_time_str = jobschedule['start_date'] + ' ' + jobschedule['start_time']
output_format = "%m/%d/%Y %I:%M %p"
try:
datetime.datetime.strptime(start_date_time_str,
"%m/%d/%Y %H:%M").strftime("%m/%d/%Y %I:%M %p")
input_format = "%m/%d/%Y %H:%M"
        except ValueError:
datetime.datetime.strptime(start_date_time_str,
"%m/%d/%Y %I:%M %p").strftime("%m/%d/%Y %I:%M %p")
input_format = "%m/%d/%Y %I:%M %p"
from_zone = pytz.timezone(jobschedule['timezone'])
local_dt = datetime.datetime.strptime(start_date_time_str, input_format)
local_dt = from_zone.localize(local_dt)
dt_horizon = local_dt.astimezone(pytz.timezone('UTC'))
dt_utc_start_date_time = dt_horizon.strftime(output_format).split()
jobschedule['start_date'] = dt_utc_start_date_time[0]
jobschedule['start_time'] = ' '.join(dt_utc_start_date_time[1:])
jobschedule['timezone'] = 'UTC'
jobschedule.pop('appliance_timezone', None)
return jobschedule
def get_next_15_min_interval(current_time):
current_round_off_time = current_time
    # round up to the next 15-minute interval for the workload-creation
    # job scheduler
    round_off_datetime = current_time + (datetime.datetime.min - current_round_off_time.replace(tzinfo=None)) % datetime.timedelta(minutes=15)
    # add 1 day if rounding off required a date change
if round_off_datetime.time() == datetime_time(12, 0):
round_off_datetime += datetime.timedelta(days=1)
# set the default start_date and start_time
start_date = round_off_datetime.strftime('%m/%d/%Y')
start_time = round_off_datetime.strftime("%I:%M %p")
return start_date, start_time
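# Example (illustrative): 10:07 rounds up to the next quarter hour:
#
#   >>> get_next_15_min_interval(datetime.datetime(2024, 1, 2, 10, 7))
#   ('01/02/2024', '10:15 AM')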
def get_vcenter_snapshots_by_name_recursively(snapshots, snapname):
snap_obj = []
for snapshot in snapshots:
if snapshot.name == snapname:
snap_obj.append(snapshot)
else:
snap_obj = snap_obj + get_vcenter_snapshots_by_name_recursively(
snapshot.childSnapshotList, snapname)
return snap_obj
def calculate_difference_in_time(updated_at):
updated_time = updated_at.replace(tzinfo=datetime_timezone.utc)
time_difference = datetime.datetime.now(datetime_timezone.utc) - updated_time
return time_difference