Repository URL to install this package:
|
Version:
0.24.2 ▾
|
# This code is part of Qiskit.
#
# (C) Copyright IBM 2017, 2018.
#
# This code is licensed under the Apache License, Version 2.0. You may
# obtain a copy of this license in the LICENSE.txt file in the root directory
# of this source tree or at http://www.apache.org/licenses/LICENSE-2.0.
#
# Any modifications or derivative works of this code must retain this
# copyright notice, and modified files need to carry a notice indicating
# that they have been altered from the originals.
"""Decorator for using with Qiskit unit tests."""
import collections.abc
import functools
import os
import socket
import sys
from typing import Union, Callable, Type, Iterable
import unittest
from qiskit.utils import wrap_method
from qiskit.utils.deprecation import deprecate_func
from .testing_options import get_test_options
HAS_NET_CONNECTION = None
def _has_connection(hostname, port):
"""Checks if internet connection exists to host via specified port.
If any exception is raised while trying to open a socket this will return
false.
Args:
hostname (str): Hostname to connect to.
port (int): Port to connect to
Returns:
bool: Has connection or not
"""
try:
host = socket.gethostbyname(hostname)
socket.create_connection((host, port), 2).close()
return True
except Exception: # pylint: disable=broad-except
return False
def is_aer_provider_available():
    """Report whether the Aer simulator provider can be used by the tests.

    On macOS (``sys.platform == "darwin"``) the answer is always ``False``; on every other
    platform it is ``True`` exactly when ``qiskit.providers.aer`` can be imported.

    Returns:
        bool: True if simulator executable is available
    """
    # TODO: HACK FROM THE DEPTHS OF DESPAIR AS AER DOES NOT WORK ON MAC
    running_on_mac = sys.platform == "darwin"
    if not running_on_mac:
        try:
            import qiskit.providers.aer  # pylint: disable=unused-import
        except ImportError:
            pass
        else:
            return True
    return False
def requires_aer_provider(test_item):
    """Skip the decorated test function or test class when the Aer provider is unavailable.

    Availability is decided by ``is_aer_provider_available`` at decoration time.

    Args:
        test_item (callable): function or class to be decorated.

    Returns:
        callable: the decorated function.
    """
    aer_missing = not is_aer_provider_available()
    skip_when_missing = unittest.skipIf(aer_missing, "Aer provider not found, skipping test")
    return skip_when_missing(test_item)
def slow_test(func):
    """Mark a test as slow (it takes minutes to run).

    The wrapped test only runs when ``TEST_OPTIONS["run_slow"]`` is truthy; otherwise it raises
    ``unittest.SkipTest``. The option is looked up each time the test is invoked, not when the
    decorator is applied.

    Args:
        func (callable): test function to be decorated.

    Returns:
        callable: the decorated function.
    """

    @functools.wraps(func)
    def _run_if_enabled(*args, **kwargs):
        if TEST_OPTIONS["run_slow"]:
            return func(*args, **kwargs)
        raise unittest.SkipTest("Skipping slow tests")

    return _run_if_enabled
def _get_credentials():
"""Finds the credentials for a specific test and options.
Returns:
Credentials: set of credentials
Raises:
SkipTest: when credentials can't be found
"""
try:
from qiskit.providers.ibmq.credentials import Credentials, discover_credentials
except ImportError as ex:
raise unittest.SkipTest(
"qiskit-ibmq-provider could not be found, "
"and is required for executing online tests. "
'To install, run "pip install qiskit-ibmq-provider" '
"or check your installation."
) from ex
if os.getenv("IBMQ_TOKEN") and os.getenv("IBMQ_URL"):
return Credentials(os.getenv("IBMQ_TOKEN"), os.getenv("IBMQ_URL"))
elif os.getenv("QISKIT_TESTS_USE_CREDENTIALS_FILE"):
# Attempt to read the standard credentials.
discovered_credentials = discover_credentials()
if discovered_credentials:
# Decide which credentials to use for testing.
if len(discovered_credentials) > 1:
raise unittest.SkipTest(
"More than 1 credential set found, use: "
"IBMQ_TOKEN and IBMQ_URL env variables to "
"set credentials explicitly"
)
# Use the first available credentials.
return list(discovered_credentials.values())[0]
raise unittest.SkipTest(
"No IBMQ credentials found for running the test. This is required for running online tests."
)
@deprecate_func(additional_msg="Instead, use ``online_test``", since="0.17.0")
def requires_qe_access(func):
    """Deprecated in favor of ``online_test``.

    The wrapped test is skipped when ``TEST_OPTIONS["skip_online"]`` is truthy. Otherwise the
    credentials from ``_get_credentials`` are injected as the ``qe_token`` and ``qe_url`` keyword
    arguments, and ``self.using_ibmq_credentials`` is set. Unlike ``online_test``, no network
    connectivity probe is performed.
    """

    @functools.wraps(func)
    def _with_credentials(self, *args, **kwargs):
        if TEST_OPTIONS["skip_online"]:
            raise unittest.SkipTest("Skipping online tests")
        creds = _get_credentials()
        self.using_ibmq_credentials = creds.is_ibmq()
        kwargs["qe_token"] = creds.token
        kwargs["qe_url"] = creds.url
        return func(self, *args, **kwargs)

    return _with_credentials
def online_test(func):
    """Decorator that signals that the test uses the network (and the online API).

    Each time the decorated test runs, the wrapper:

    * skips the test if the ``skip_online`` entry of ``TEST_OPTIONS`` is truthy;
    * skips the test if a TCP connection to ``qiskit.org`` on port 443 cannot be opened. The
      probe runs only once per process; its result is cached in the module-level
      ``HAS_NET_CONNECTION``;
    * reads credentials with ``_get_credentials``. These come from the ``IBMQ_TOKEN`` and
      ``IBMQ_URL`` environment variables when both are set. Otherwise, and only if
      ``QISKIT_TESTS_USE_CREDENTIALS_FILE`` is set, they come from ``discover_credentials`` in
      ``qiskit-ibmq-provider``. The test is skipped if the provider is missing or no single
      credential set can be determined;
    * sets ``self.using_ibmq_credentials`` from ``credentials.is_ibmq()`` and passes the
      credentials' token and URL to the test as the ``qe_token`` and ``qe_url`` keyword
      arguments.

    Args:
        func (callable): test method to be decorated; it must accept ``qe_token`` and
            ``qe_url`` keyword arguments.

    Returns:
        callable: the decorated function.
    """

    @functools.wraps(func)
    def _wrapper(self, *args, **kwargs):
        # To avoid checking the connection in each test
        global HAS_NET_CONNECTION  # pylint: disable=global-statement
        if TEST_OPTIONS["skip_online"]:
            raise unittest.SkipTest("Skipping online tests")
        if HAS_NET_CONNECTION is None:
            HAS_NET_CONNECTION = _has_connection("qiskit.org", 443)
        if not HAS_NET_CONNECTION:
            raise unittest.SkipTest("Test requires internet connection.")
        credentials = _get_credentials()
        self.using_ibmq_credentials = credentials.is_ibmq()
        kwargs.update({"qe_token": credentials.token, "qe_url": credentials.url})
        return func(self, *args, **kwargs)

    return _wrapper
def enforce_subclasses_call(
    methods: Union[str, Iterable[str]], attr: str = "_enforce_subclasses_call_cache"
) -> Callable[[Type], Type]:
    """Class decorator which enforces that if any subclasses define one of the ``methods``, they
    must call ``super().<method>()`` or face a ``ValueError`` at runtime.

    This is unlikely to be useful for concrete test classes, which are not normally subclassed.
    It should not be used on user-facing code, because it prevents subclasses from being free to
    override parent-class behavior, even when the parent-class behavior is not needed.

    This adds behavior to the ``__init__`` and ``__init_subclass__`` methods of the class, in
    addition to the named methods of this class and all subclasses. The checks could be averted in
    grandchildren if a child class overrides ``__init_subclass__`` without up-calling the decorated
    class's method, though this would typically break inheritance principles.

    Arguments:
        methods:
            Name, or iterable of names, of the methods to add the enforcement to. These do not
            necessarily need to be defined in the class body, provided they are somewhere in the
            method-resolution tree.
        attr:
            The attribute (a ``set`` of method names) used to manage the call enforcement. It is
            set on the decorated class itself and, after ``__init__``, on instances. This can be
            changed to avoid clashes.

    Returns:
        A decorator, which returns its input class with the relevant methods modified to include
        checks, and injection code in the ``__init_subclass__`` method.
    """
    # Normalize to a set so a single name and an iterable of names are handled alike, and so it
    # can be intersected with a class ``__dict__`` in ``wrap_subclass_methods``.
    methods = {methods} if isinstance(methods, str) else set(methods)

    def initialize_call_memory(self, *_args, **_kwargs):
        """Add the extra attribute used for tracking the method calls.

        Called directly on the decorated class and also installed as an ``after`` hook of
        ``__init__``, hence the ignored extra arguments.
        """
        setattr(self, attr, set())

    def save_call_status(name: str):
        """Factory, whose returned hook records that the top-level method ``name`` was called."""

        def out(self, *_args, **_kwargs):
            getattr(self, attr).add(name)

        return out

    def clear_call_status(name: str):
        """Factory, whose returned hook clears the call status of the method ``name``. This
        prepares the call tracking for the child class's method call."""

        def out(self, *_args, **_kwargs):
            getattr(self, attr).discard(name)

        return out

    def enforce_call_occurred(name: str):
        """Factory, whose returned hook checks that the top-level method call occurred, and raises
        ``ValueError`` if not. Concretely, this is an assertion that ``save_call_status`` ran."""

        def out(self, *_args, **_kwargs):
            cache = getattr(self, attr)
            if name not in cache:
                # ``self`` is the class itself when the method is a classmethod.
                classname = self.__name__ if isinstance(self, type) else type(self).__name__
                raise ValueError(
                    f"Parent '{name}' method was not called by '{classname}.{name}'."
                    f" Ensure you have put in calls to 'super().{name}()'."
                )

        return out

    def wrap_subclass_methods(cls: Type) -> None:
        """Wrap all the ``methods`` of ``cls`` with the call-tracking assertions that the top-level
        versions of the methods were called (likely via ``super()``)."""
        # Only wrap methods which are directly defined in this class; if we're resolving to a
        # method higher up the food chain, then it will already have been wrapped.
        for name in set(cls.__dict__) & methods:
            wrap_method(
                cls,
                name,
                before=clear_call_status(name),
                after=enforce_call_occurred(name),
            )

    def decorator(cls: Type) -> Type:
        # NOTE(review): ``wrap_method`` comes from ``qiskit.utils``; this code assumes that the
        # ``before``/``after`` callables run immediately before/after the wrapped method and
        # receive the same arguments -- confirm against its documentation.
        # Intended flow: the decorated class's versions record their name (``save_call_status``);
        # a subclass override first clears the record and afterwards checks it, so the check only
        # passes if the override's ``super()`` chain reached the decorated class's method.
        #
        # Add a class-level memory on, so class methods will work as well. Instances will override
        # this on instantiation, to keep the "namespace" of class- and instance-methods separate.
        initialize_call_memory(cls)
        # Do the extra bits after the main body of __init__ so we can check we're not overwriting
        # anything, and after __init_subclass__ in case the decorated class wants to influence the
        # creation of the subclass's methods before we get to them.
        wrap_method(cls, "__init__", after=initialize_call_memory)
        for name in methods:
            wrap_method(cls, name, before=save_call_status(name))
        wrap_method(cls, "__init_subclass__", after=wrap_subclass_methods)
        return cls

    return decorator
class _TestOptions(collections.abc.Mapping):
"""Lazy-loading view onto the test options retrieved from the environment."""
__slots__ = ("_options",)
def __init__(self):
self._options = None
def _load(self):
if self._options is None:
self._options = get_test_options()
def __getitem__(self, key):
self._load()
return self._options[key]
def __iter__(self):
self._load()
return iter(self._options)
def __len__(self):
self._load()
return len(self._options)
TEST_OPTIONS = _TestOptions()