Why Gemfury? Push, build, and install  RubyGems npm packages Python packages Maven artifacts PHP packages Go Modules Debian packages RPM packages NuGet packages

Repository URL to install this package:

Details    
qiskit-ibm-provider / api / session.py
Size: Mime:
# This code is part of Qiskit.
#
# (C) Copyright IBM 2021.
#
# This code is licensed under the Apache License, Version 2.0. You may
# obtain a copy of this license in the LICENSE.txt file in the root directory
# of this source tree or at http://www.apache.org/licenses/LICENSE-2.0.
#
# Any modifications or derivative works of this code must retain this
# copyright notice, and modified files need to carry a notice indicating
# that they have been altered from the originals.

"""Session customized for IBM Quantum access."""

import os
import re
import logging
from typing import Dict, Optional, Any, Tuple, Union
import pkg_resources

from requests import Session, RequestException, Response
from requests.adapters import HTTPAdapter
from requests.auth import AuthBase
from urllib3.util.retry import Retry

from qiskit_ibm_provider.utils.utils import filter_data
from .exceptions import RequestsApiError
from ..version import __version__ as ibm_provider_version

# HTTP status codes handed to the session's retry policy as its
# ``status_forcelist`` (see ``RetrySession._initialize_retry``).
STATUS_FORCELIST = (
    500,  # General server error
    502,  # Bad Gateway
    503,  # Service Unavailable
    504,  # Gateway Timeout
    520,  # Cloudflare general error
    521,  # Cloudflare web server is down
    522,  # Cloudflare connection timeout
    524,  # Cloudflare Timeout
)
# Name of the environment variable whose value, when set and non-empty, is
# appended to the client application header sent with every request.
CUSTOM_HEADER_ENV_VAR = "QISKIT_IBM_CUSTOM_CLIENT_APP_HEADER"
# Module-level logger, named after this module.
logger = logging.getLogger(__name__)
# Regex used to match the `/devices` endpoint (case-insensitively), capturing
# the device name as group(2).
# The device name must be at least two characters long and may contain neither
# `/` nor `}`; the minimum length is what keeps the pattern from matching the
# `/devices/v/1` endpoint.
# Capture groups: (<prefix>/devices/)(<device_name>)(<optional rest of the url>)
RE_DEVICES_ENDPOINT = re.compile(r"^(.*/devices/)([^/}]{2,})(.*)$", re.IGNORECASE)


def _get_client_header() -> str:
    """Build the string that identifies this client and its version.

    Returns:
        ``"qiskit/<version>"`` when the ``qiskit`` distribution can be
        resolved. Otherwise a string of the form
        ``"<name>,<name>/<version>,<version>"`` that always lists
        ``qiskit-ibm-provider`` and additionally lists ``qiskit-terra``
        when that distribution can be resolved.
    """
    try:
        return "qiskit/" + pkg_resources.get_distribution("qiskit").version
    except Exception:  # pylint: disable=broad-except
        # The metapackage could not be resolved; fall back to listing the
        # individual packages below.
        pass

    # This package's own version is always reported, and always first.
    versions_by_name = {"qiskit-ibm-provider": ibm_provider_version}
    for optional_pkg in ("qiskit-terra",):
        try:
            installed = pkg_resources.get_distribution(optional_pkg).version
        except Exception:  # pylint: disable=broad-except
            # Package not resolvable: leave it out of the header.
            continue
        versions_by_name[optional_pkg] = installed

    names_part = ",".join(versions_by_name)
    versions_part = ",".join(versions_by_name.values())
    return names_part + "/" + versions_part


# Client identification string, computed once at import time; it is the base
# value of the ``X-Qx-Client-Application`` header set on each ``RetrySession``.
CLIENT_APPLICATION = _get_client_header()


class PostForcelistRetry(Retry):
    """Custom ``urllib3.Retry`` class that performs retry on ``POST`` errors in the force list.

    Retrying of ``POST`` requests is allowed *only* when the status code
    returned is on the ``STATUS_FORCELIST``. While ``POST``
    requests are recommended not to be retried due to not being idempotent,
    the IBM Quantum API guarantees that retrying on specific 5xx errors is safe.
    """

    def increment(  # type: ignore[no-untyped-def]
        self,
        method=None,
        url=None,
        response=None,
        error=None,
        _pool=None,
        _stacktrace=None,
    ):
        """Log the retry attempt at debug level, then defer to the parent class.

        Args:
            method: Request method, forwarded unchanged.
            url: Request URL, forwarded unchanged.
            response: Response that triggered the retry, if any. Its
                ``status``, ``data`` and ``headers`` are included in the log.
            error: Error that triggered the retry, if any.
            _pool: Forwarded unchanged to the parent class.
            _stacktrace: Forwarded unchanged to the parent class.

        Returns:
            Whatever the parent class's ``increment`` returns.
        """
        # ``isEnabledFor`` is used instead of comparing the effective level to
        # ``logging.DEBUG``: it also holds when the logger is configured to be
        # more verbose than DEBUG, and it does not rely on integer identity.
        if logger.isEnabledFor(logging.DEBUG):
            status = data = headers = None
            if response:
                status = response.status
                data = response.data
                headers = response.headers
            logger.debug(
                "Retrying method=%s, url=%s, status=%s, error=%s, data=%s, headers=%s",
                method,
                url,
                status,
                error,
                data,
                headers,
            )
        return super().increment(
            method=method,
            url=url,
            response=response,
            error=error,
            _pool=_pool,
            _stacktrace=_stacktrace,
        )

    def is_retry(
        self, method: str, status_code: int, has_retry_after: bool = False
    ) -> bool:
        """Indicate whether the request should be retried.

        Args:
            method: Request method.
            status_code: Status code.
            has_retry_after: Whether the response carried a ``Retry-After``
                header. Forwarded unchanged to the parent class.

        Returns:
            ``True`` if the request should be retried, ``False`` otherwise.
        """
        # ``status_forcelist`` may be unset; guard so the membership test
        # cannot raise and the decision falls through to the parent class.
        forcelist = self.status_forcelist
        if method.upper() == "POST" and forcelist and status_code in forcelist:
            return True

        return super().is_retry(method, status_code, has_retry_after)


class RetrySession(Session):
    """Custom session with retry and handling of specific parameters.

    This is a child class of ``requests.Session``. It has its own retry
    policy and handles IBM Quantum specific parameters.
    """

    def __init__(
        self,
        base_url: str,
        retries_total: int = 5,
        retries_connect: int = 3,
        backoff_factor: float = 0.5,
        verify: bool = True,
        proxies: Optional[Dict[str, str]] = None,
        auth: Optional[AuthBase] = None,
        timeout: Tuple[float, Union[float, None]] = (5.0, None),
    ) -> None:
        """RetrySession constructor.

        Args:
            base_url: Base URL for the session's requests.
            retries_total: Number of total retries for the requests.
            retries_connect: Number of connect retries for the requests.
            backoff_factor: Backoff factor between retry attempts.
            verify: Whether to enable SSL verification.
            proxies: Proxy URLs mapped by protocol.
            auth: Authentication handler.
            timeout: Timeout for the requests. It is handed to ``requests``
                unchanged, which interprets a two-element tuple as
                (connect_timeout, read_timeout).
        """
        super().__init__()

        self.base_url = base_url

        self._initialize_retry(retries_total, retries_connect, backoff_factor)
        self._initialize_session_parameters(verify, proxies or {}, auth)
        self._timeout = timeout

    def __del__(self) -> None:
        """RetrySession destructor. Closes the session."""
        try:
            self.close()
        except Exception:  # pylint: disable=broad-except
            # ignore errors that may happen during cleanup
            pass

    def _initialize_retry(
        self, retries_total: int, retries_connect: int, backoff_factor: float
    ) -> None:
        """Set the session retry policy.

        Args:
            retries_total: Number of total retries for the requests.
            retries_connect: Number of connect retries for the requests.
            backoff_factor: Backoff factor between retry attempts.
        """
        retry = PostForcelistRetry(
            total=retries_total,
            connect=retries_connect,
            backoff_factor=backoff_factor,
            status_forcelist=STATUS_FORCELIST,
        )

        # The same adapter (and therefore the same retry policy) serves both
        # URL schemes.
        retry_adapter = HTTPAdapter(max_retries=retry)
        self.mount("http://", retry_adapter)
        self.mount("https://", retry_adapter)

    def _initialize_session_parameters(
        self, verify: bool, proxies: Dict[str, str], auth: Optional[AuthBase] = None
    ) -> None:
        """Set the session parameters and attributes.

        Args:
            verify: Whether to enable SSL verification.
            proxies: Proxy URLs mapped by protocol.
            auth: Authentication handler.
        """
        client_app_header = CLIENT_APPLICATION

        # Append custom header to the end if specified
        custom_header = os.getenv(CUSTOM_HEADER_ENV_VAR)
        if custom_header:
            client_app_header += "/" + custom_header

        self.headers.update({"X-Qx-Client-Application": client_app_header})

        self.auth = auth
        self.proxies = proxies or {}
        self.verify = verify

    def request(  # type: ignore[override]
        self, method: str, url: str, bare: bool = False, **kwargs: Any
    ) -> Response:
        """Construct, prepare, and send a ``Request``.

        If `bare` is not specified, prepend the base URL to the input `url`.
        The session's default timeout is applied only when no proxies are
        configured and the caller did not pass a ``timeout`` of their own.

        Args:
            method: Method for the new request (e.g. ``POST``).
            url: URL for the new request.
            bare: If ``True``, do not send IBM Quantum specific information
                (such as access token) in the request or modify the input `url`.
            **kwargs: Additional arguments for the request. A ``params``
                mapping passed here is copied, never modified in place.

        Returns:
            Response object.

        Raises:
            RequestsApiError: If the request failed. Its status code is that
                of the response, or ``-1`` when no response was received.
        """
        # pylint: disable=arguments-differ
        if bare:
            final_url = url
            # Explicitly pass `None` as the `access_token` param, disabling it.
            # Work on a copy so the caller's dict is left untouched, and
            # tolerate an explicit ``params=None``.
            params = dict(kwargs.get("params") or {})
            params["access_token"] = None
            kwargs["params"] = params
        else:
            final_url = self.base_url + url

        # Add a timeout to the connection for non-proxy connections.
        if not self.proxies and "timeout" not in kwargs:
            kwargs["timeout"] = self._timeout

        # Per-request headers override the session defaults; an explicit
        # ``headers=None`` is treated the same as passing no headers.
        headers = self.headers.copy()
        headers.update(kwargs.pop("headers", None) or {})

        try:
            self._log_request_info(final_url, method, kwargs)
            response = super().request(method, final_url, headers=headers, **kwargs)
            response.raise_for_status()
        except RequestException as ex:
            # Wrap the requests exceptions into a IBM Q custom one, for
            # compatibility.
            message = str(ex)
            status_code = -1
            if ex.response is not None:
                status_code = ex.response.status_code
                try:
                    error_json = ex.response.json()["error"]
                    message += ". {}, Error code: {}.".format(
                        error_json["message"], error_json["code"]
                    )
                except Exception:  # pylint: disable=broad-except
                    # the response did not contain the expected json.
                    message += f". {ex.response.text}"
                else:
                    # Read the trace id with ``get`` and outside the ``try``:
                    # a missing header must not be mistaken for malformed
                    # JSON, which would append the raw body on top of the
                    # already-parsed error message.
                    logger.debug(
                        "Response uber-trace-id: %s",
                        ex.response.headers.get("uber-trace-id"),
                    )

            raise RequestsApiError(message, status_code) from ex

        return response

    def _log_request_info(
        self, url: str, method: str, request_data: Dict[str, Any]
    ) -> None:
        """Log the request data, filtering out specific information.

        Note:
            The string ``...`` is used to denote information that has been filtered out
            from the request, within the url and request data. Currently, the backend name
            is filtered out from endpoint URLs, using a regex to capture the name, and from
            the data sent to the server when submitting a job.

            The filtered request data is included in the log line only when the
            filtered URL is exactly ``/devices/.../properties`` or ``/Jobs``;
            for every other logged URL only the endpoint and method are logged.

            Nothing is logged unless the URL passes ``_is_worth_logging`` and
            debug logging is enabled. Errors raised while building or emitting
            the log line are caught and reported at ``INFO`` level instead of
            propagating to the caller.

        Args:
            url: URL for the new request.
            method: Method for the new request (e.g. ``POST``)
            request_data: Additional arguments for the request.
        """
        # Replace the device name in the URL with `...` if it matches, otherwise leave it as is.
        filtered_url = RE_DEVICES_ENDPOINT.sub(r"\1...\3", url)

        if not self._is_worth_logging(filtered_url):
            return

        try:
            # ``isEnabledFor`` also holds for levels more verbose than DEBUG
            # and avoids an identity comparison between integers.
            if not logger.isEnabledFor(logging.DEBUG):
                return

            request_data_to_log = ""
            # NOTE(review): ``request`` passes the fully-built URL here, so
            # this exact match only succeeds when that URL is literally one of
            # these paths. ``_is_worth_logging`` also rejects anything ending
            # in ``/.../properties``, so the first entry is never reached.
            # Confirm the intended behaviour before changing it.
            if filtered_url in ("/devices/.../properties", "/Jobs"):
                # Log filtered request data for these endpoints.
                request_data_to_log = "Request Data: {}.".format(
                    filter_data(request_data)
                )
            logger.debug(
                "Endpoint: %s. Method: %s. %s",
                filtered_url,
                method.upper(),
                request_data_to_log,
            )
        except Exception as ex:  # pylint: disable=broad-except
            # Catch general exception so as not to disturb the program if filtering fails.
            logger.info(
                "Filtering failed when logging request information: %s", str(ex)
            )

    def _is_worth_logging(self, endpoint_url: str) -> bool:
        """Returns whether the endpoint URL should be logged.

        The checks in place help filter out endpoint URL logs that would add noise
        and no helpful information.

        Args:
            endpoint_url: The endpoint URL that will be logged.

        Returns:
            Whether the endpoint URL should be logged.
        """
        # Endpoints identified by how the URL ends.
        if endpoint_url.endswith(
            (
                "/queue/status",
                "/devices/v/1",
                "/Jobs/status",
                "/.../properties",
                "/.../defaults",
            )
        ):
            return False
        # Endpoints identified by how the URL starts.
        if endpoint_url.startswith(("/users", "/version")):
            return False
        if endpoint_url == "/Network":
            return False
        # Any URL mentioning these substrings is skipped as well.
        if "objectstorage" in endpoint_url:
            return False
        if "bookings" in endpoint_url:
            return False

        return True

    def __getstate__(self) -> Dict:
        """Overwrite Session's getstate to include all attributes."""
        state = super().__getstate__()  # type: ignore
        # Add every instance attribute on top of what the parent class returns.
        state.update(self.__dict__)
        return state