Why Gemfury? Push, build, and install RubyGems, npm packages, Python packages, Maven artifacts, PHP packages, Go Modules, Debian packages, RPM packages, and NuGet packages.

Repository URL to install this package:

Details    
aiobotocore / aiobotocore / utils.py
Size: Mime:
import asyncio
import inspect
import json
import logging

import aiohttp.client_exceptions
import botocore.awsrequest
from botocore.exceptions import (
    InvalidIMDSEndpointError,
    MetadataRetrievalError,
)
from botocore.utils import (
    DEFAULT_METADATA_SERVICE_TIMEOUT,
    METADATA_BASE_URL,
    BadIMDSRequestError,
    ClientError,
    ContainerMetadataFetcher,
    HTTPClientError,
    IMDSFetcher,
    IMDSRegionProvider,
    InstanceMetadataFetcher,
    InstanceMetadataRegionFetcher,
    ReadTimeoutError,
    S3RegionRedirector,
    get_environ_proxies,
    os,
    resolve_imds_endpoint_mode,
)

import aiobotocore.httpsession
from aiobotocore._helpers import asynccontextmanager

logger = logging.getLogger(name=__name__)

# aiohttp client errors plus asyncio timeouts. The metadata fetchers in
# this module treat these transport-level failures as retryable.
RETRYABLE_HTTP_ERRORS = (
    aiohttp.client_exceptions.ClientError,
    asyncio.TimeoutError,
)


class _RefCountedSession(aiobotocore.httpsession.AIOHTTPSession):
    """AIOHTTPSession that can be shared by concurrent users.

    The first active :meth:`acquire` opens the underlying session via
    ``__aenter__``; when the last active user releases it, the session is
    closed via ``__aexit__``.
    """

    def __init__(self, *args, **kwargs):
        super().__init__(*args, **kwargs)
        # Number of callers currently inside ``acquire``.
        self.__ref_count = 0
        # Created lazily inside ``acquire``, presumably so the lock is built
        # while an event loop is running — TODO confirm.
        self.__lock = None

    @asynccontextmanager
    async def acquire(self):
        """Yield this session, opening it on first use.

        The ref count is updated under a lock so that opening and closing
        never race between concurrent users.
        """
        if self.__lock is None:
            self.__lock = asyncio.Lock()

        # ensure we have a session
        async with self.__lock:
            self.__ref_count += 1
            try:
                if self.__ref_count == 1:
                    await self.__aenter__()
            except BaseException:
                # Opening failed: undo our reference so the next caller
                # tries to open the session again.
                self.__ref_count -= 1
                raise

        try:
            yield self
        finally:
            async with self.__lock:
                # Drop our reference *before* closing. If ``__aexit__``
                # raises, the count is still correct (0), so the next
                # ``acquire`` re-opens the session instead of reusing a
                # closed one.
                self.__ref_count -= 1
                if self.__ref_count == 0:
                    await self.__aexit__(None, None, None)


class AioIMDSFetcher(IMDSFetcher):
    """Asyncio counterpart of botocore's ``IMDSFetcher``.

    Requests to the instance metadata service are sent through a
    ``_RefCountedSession``, or through a caller-supplied ``session`` that
    exposes an ``acquire()`` async context manager.
    """

    def __init__(
        self,
        timeout=DEFAULT_METADATA_SERVICE_TIMEOUT,  # noqa: E501, lgtm [py/missing-call-to-init]
        num_attempts=1,
        base_url=METADATA_BASE_URL,
        env=None,
        user_agent=None,
        config=None,
        session=None,
    ):
        # IMDSFetcher.__init__ is not called (see the lgtm suppression);
        # all instance state is initialised here.
        self._timeout = timeout
        self._num_attempts = num_attempts
        config = {} if config is None else config
        self._base_url = self._select_base_url(base_url, config)
        self._config = config

        environ = os.environ.copy() if env is None else env
        disabled_flag = environ.get('AWS_EC2_METADATA_DISABLED', 'false')
        self._disabled = disabled_flag.lower() == 'true'
        self._user_agent = user_agent

        if session:
            self._session = session
        else:
            self._session = _RefCountedSession(
                timeout=self._timeout,
                proxies=get_environ_proxies(self._base_url),
            )

    async def _fetch_metadata_token(self):
        """PUT to the token path and return the token text, or ``None``.

        Returns ``None`` on 403/404/405, on a read timeout, or when every
        attempt ends in a retryable error or another status. Raises
        ``BadIMDSRequestError`` on 400 and ``InvalidIMDSEndpointError`` when
        the endpoint host cannot be resolved.
        """
        self._assert_enabled()
        token_url = self._construct_url(self._TOKEN_PATH)
        token_headers = {
            'x-aws-ec2-metadata-token-ttl-seconds': self._TOKEN_TTL,
        }
        self._add_user_agent(token_headers)
        put_request = botocore.awsrequest.AWSRequest(
            method='PUT', url=token_url, headers=token_headers
        )

        async with self._session.acquire() as http:
            for _attempt in range(self._num_attempts):
                try:
                    response = await http.send(put_request.prepare())
                    status = response.status_code
                    if status == 200:
                        return await response.text
                    if status in (404, 403, 405):
                        return None
                    if status == 400:
                        raise BadIMDSRequestError(put_request)
                    # Any other status falls through to the next attempt.
                except ReadTimeoutError:
                    # Must stay ahead of the HTTPClientError handler below.
                    return None
                except RETRYABLE_HTTP_ERRORS as exc:
                    logger.debug(
                        "Caught retryable HTTP exception while making metadata "
                        "service request to %s: %s",
                        token_url,
                        exc,
                        exc_info=True,
                    )
                except HTTPClientError as exc:
                    cause = exc.kwargs.get('error')
                    # Two checks, per the original "threaded vs async
                    # resolver" note: an errno of 8 or a "Domain name not
                    # found" os_error both mean the host did not resolve.
                    unresolvable = (
                        cause and getattr(cause, 'errno', None) == 8
                    ) or (
                        str(getattr(cause, 'os_error', None))
                        == 'Domain name not found'
                    )
                    if not unresolvable:
                        raise
                    raise InvalidIMDSEndpointError(
                        endpoint=token_url, error=exc
                    )

        return None

    async def _get_request(self, url_path, retry_func, token=None):
        """GET ``url_path`` until ``retry_func`` accepts the response.

        ``retry_func`` may be sync or async and defaults to
        ``_default_retry``. Raises ``_RETRIES_EXCEEDED_ERROR_CLS`` once all
        ``_num_attempts`` attempts are used up.
        """
        self._assert_enabled()
        needs_retry = self._default_retry if retry_func is None else retry_func
        target_url = self._construct_url(url_path)
        request_headers = {}
        if token is not None:
            request_headers['x-aws-ec2-metadata-token'] = token
        self._add_user_agent(request_headers)

        async with self._session.acquire() as http:
            for _attempt in range(self._num_attempts):
                try:
                    get_request = botocore.awsrequest.AWSRequest(
                        method='GET', url=target_url, headers=request_headers
                    )
                    response = await http.send(get_request.prepare())
                    verdict = needs_retry(response)
                    if inspect.isawaitable(verdict):
                        verdict = await verdict
                    if not verdict:
                        return response
                except RETRYABLE_HTTP_ERRORS as exc:
                    logger.debug(
                        "Caught retryable HTTP exception while making metadata "
                        "service request to %s: %s",
                        target_url,
                        exc,
                        exc_info=True,
                    )
        raise self._RETRIES_EXCEEDED_ERROR_CLS()

    async def _default_retry(self, response):
        """Retry when the response is non-200 or has an empty body."""
        if await self._is_non_ok_response(response):
            return True
        return await self._is_empty(response)

    async def _is_non_ok_response(self, response):
        """Return True (and log the response) if the status is not 200."""
        if response.status_code == 200:
            return False
        await self._log_imds_response(response, 'non-200', log_body=True)
        return True

    async def _is_empty(self, response):
        """Return True (and log the response) if the body is empty."""
        body = await response.content
        if body:
            return False
        await self._log_imds_response(response, 'no body', log_body=True)
        return True

    async def _log_imds_response(
        self, response, reason_to_log, log_body=False
    ):
        """Emit a debug record describing an unusable IMDS response."""
        message = (
            "Metadata service returned %s response "
            "with status code of %s for url: %s"
        )
        args = (reason_to_log, response.status_code, response.url)
        if log_body:
            message += ", content body: %s"
            args += (await response.content,)
        logger.debug(message, *args)


class AioInstanceMetadataFetcher(AioIMDSFetcher, InstanceMetadataFetcher):
    """Async fetcher for the IAM role credentials published through IMDS."""

    async def retrieve_iam_role_credentials(self):
        """Fetch the role name, then that role's credentials, from IMDS.

        :return: on success, a dict with ``role_name``, ``access_key``,
            ``secret_key``, ``token`` and ``expiry_time`` (after
            ``_evaluate_expiration`` has inspected it). Returns ``{}`` when
            the credential document is incomplete, retries are exhausted, or
            IMDS rejects the request as bad. Other exceptions propagate.
        """
        try:
            token = await self._fetch_metadata_token()
            role_name = await self._get_iam_role(token)
            credentials = await self._get_credentials(role_name, token)
            if self._contains_all_credential_fields(credentials):
                credentials = {
                    'role_name': role_name,
                    'access_key': credentials['AccessKeyId'],
                    'secret_key': credentials['SecretAccessKey'],
                    'token': credentials['Token'],
                    'expiry_time': credentials['Expiration'],
                }
                self._evaluate_expiration(credentials)
                return credentials
            if 'Code' in credentials and 'Message' in credentials:
                # Fixed: the two string fragments previously concatenated
                # to "retrievingcredentials" (missing space).
                logger.debug(
                    'Error response received when retrieving '
                    'credentials: %s.',
                    credentials,
                )
        except self._RETRIES_EXCEEDED_ERROR_CLS:
            logger.debug(
                "Max number of attempts exceeded (%s) when "
                "attempting to retrieve data from metadata service.",
                self._num_attempts,
            )
        except BadIMDSRequestError as e:
            logger.debug("Bad IMDS request: %s", e.request)
        return {}

    async def _get_iam_role(self, token=None):
        """Return the role name listed at ``_URL_PATH``."""
        response = await self._get_request(
            url_path=self._URL_PATH,
            retry_func=self._needs_retry_for_role_name,
            token=token,
        )
        return await response.text

    async def _get_credentials(self, role_name, token=None):
        """Return the parsed JSON credential document for ``role_name``."""
        response = await self._get_request(
            url_path=self._URL_PATH + role_name,
            retry_func=self._needs_retry_for_credentials,
            token=token,
        )
        return json.loads(await response.text)

    async def _is_invalid_json(self, response):
        """Return True (and log) if the response body is not valid JSON."""
        try:
            json.loads(await response.text)
            return False
        except ValueError:
            await self._log_imds_response(response, 'invalid json')
            return True

    async def _needs_retry_for_role_name(self, response):
        """Retry the role-name lookup on non-200 or empty responses."""
        return await self._is_non_ok_response(
            response
        ) or await self._is_empty(response)

    async def _needs_retry_for_credentials(self, response):
        """Retry on non-200, empty, or non-JSON credential responses."""
        return (
            await self._is_non_ok_response(response)
            or await self._is_empty(response)
            or await self._is_invalid_json(response)
        )


class AioIMDSRegionProvider(IMDSRegionProvider):
    """Async region provider that asks the EC2 metadata service."""

    async def provide(self):
        """Return the region reported by IMDS (``None`` if unavailable)."""
        return await self._get_instance_metadata_region()

    async def _get_instance_metadata_region(self):
        # ``_get_fetcher`` is inherited; presumably it caches the object
        # built by ``_create_fetcher`` — TODO confirm against botocore.
        return await self._get_fetcher().retrieve_region()

    def _create_fetcher(self):
        """Build an ``AioInstanceMetadataRegionFetcher`` from session config."""
        session = self._session
        timeout = session.get_config_variable('metadata_service_timeout')
        attempts = session.get_config_variable(
            'metadata_service_num_attempts'
        )
        endpoint = session.get_config_variable(
            'ec2_metadata_service_endpoint'
        )
        endpoint_mode = resolve_imds_endpoint_mode(session)
        return AioInstanceMetadataRegionFetcher(
            timeout=timeout,
            num_attempts=attempts,
            env=self._environ,
            user_agent=session.user_agent(),
            config={
                'ec2_metadata_service_endpoint': endpoint,
                'ec2_metadata_service_endpoint_mode': endpoint_mode,
            },
        )


class AioInstanceMetadataRegionFetcher(
    AioIMDSFetcher, InstanceMetadataRegionFetcher
):
    """Async fetcher that derives the region from an IMDS lookup.

    The value read from ``_URL_PATH`` has its final character dropped to
    form the region name.
    """

    async def retrieve_region(self):
        """Return the region, or ``None`` if IMDS retries are exhausted."""
        try:
            return await self._get_region()
        except self._RETRIES_EXCEEDED_ERROR_CLS:
            logger.debug(
                "Max number of attempts exceeded (%s) when "
                "attempting to retrieve data from metadata service.",
                self._num_attempts,
            )
            return None

    async def _get_region(self):
        token = await self._fetch_metadata_token()
        reply = await self._get_request(
            url_path=self._URL_PATH,
            retry_func=self._default_retry,
            token=token,
        )
        zone = await reply.text
        # Drop the trailing zone letter to get the region.
        return zone[:-1]


class AioS3RegionRedirector(S3RegionRedirector):
    """Async version of botocore's ``S3RegionRedirector``.

    Re-targets an S3 request at the bucket's actual region when the
    service response indicates it was sent to the wrong one.
    """

    async def redirect_from_error(
        self, request_dict, response, operation, **kwargs
    ):
        """Rewrite ``request_dict`` for the bucket's region if needed.

        :return: ``0`` after redirecting (so the retry happens without
            waiting); ``None`` when no redirect is performed.
        """
        if response is None:
            # This could be none if there was a ConnectionError or other
            # transport error.
            return

        if self._is_s3_accesspoint(request_dict.get('context', {})):
            logger.debug(
                'S3 request was previously to an accesspoint, not redirecting.'
            )
            return

        if request_dict.get('context', {}).get('s3_redirected'):
            logger.debug(
                'S3 request was previously redirected, not redirecting.'
            )
            return

        error = response[1].get('Error', {})
        error_code = error.get('Code')
        response_metadata = response[1].get('ResponseMetadata', {})

        # We have to account for 400 responses because
        # if we sign a Head* request with the wrong region,
        # we'll get a 400 Bad Request but we won't get a
        # body saying it's an "AuthorizationHeaderMalformed".
        is_special_head_object = (
            error_code in ('301', '400') and operation.name == 'HeadObject'
        )
        is_special_head_bucket = (
            error_code in ('301', '400')
            and operation.name == 'HeadBucket'
            and 'x-amz-bucket-region'
            in response_metadata.get('HTTPHeaders', {})
        )
        is_wrong_signing_region = (
            error_code == 'AuthorizationHeaderMalformed' and 'Region' in error
        )
        is_redirect_status = response[0] is not None and response[
            0
        ].status_code in (301, 302, 307)
        is_permanent_redirect = error_code == 'PermanentRedirect'
        if not any(
            (
                is_special_head_object,
                is_wrong_signing_region,
                is_permanent_redirect,
                is_special_head_bucket,
                is_redirect_status,
            )
        ):
            return

        bucket = request_dict['context']['signing']['bucket']
        client_region = request_dict['context'].get('client_region')
        new_region = await self.get_bucket_region(bucket, response)

        if new_region is None:
            # Lazy %-args: the message is only formatted if DEBUG is enabled.
            logger.debug(
                "S3 client configured for region %s but the bucket %s is not "
                "in that region and the proper region could not be "
                "automatically determined.",
                client_region,
                bucket,
            )
            return

        logger.debug(
            "S3 client configured for region %s but the bucket %s is in region"
            " %s; Please configure the proper region to avoid multiple "
            "unnecessary redirects and signing attempts.",
            client_region,
            bucket,
            new_region,
        )
        endpoint = self._endpoint_resolver.resolve('s3', new_region)
        endpoint = endpoint['endpoint_url']

        signing_context = {
            'region': new_region,
            'bucket': bucket,
            'endpoint': endpoint,
        }
        request_dict['context']['signing'] = signing_context

        self._cache[bucket] = signing_context
        self.set_request_url(request_dict, request_dict['context'])

        request_dict['context']['s3_redirected'] = True

        # Return 0 so it doesn't wait to retry
        return 0

    async def get_bucket_region(self, bucket, response):
        """Determine ``bucket``'s region, or return ``None`` if unknown.

        Checks, in order: the ``x-amz-bucket-region`` response header, the
        ``Region`` field of the error body, and finally a ``head_bucket``
        call on the client.
        """
        # First try to source the region from the headers.
        service_response = response[1]
        response_headers = service_response['ResponseMetadata']['HTTPHeaders']
        if 'x-amz-bucket-region' in response_headers:
            return response_headers['x-amz-bucket-region']

        # Next, check the error body
        region = service_response.get('Error', {}).get('Region', None)
        if region is not None:
            return region

        # Finally, HEAD the bucket. No other choice sadly.
        try:
            head_response = await self._client.head_bucket(Bucket=bucket)
            headers = head_response['ResponseMetadata']['HTTPHeaders']
        except ClientError as e:
            headers = e.response['ResponseMetadata']['HTTPHeaders']

        return headers.get('x-amz-bucket-region', None)


class AioContainerMetadataFetcher(ContainerMetadataFetcher):
    """Async fetcher for JSON documents served by the ECS metadata endpoint."""

    def __init__(
        self, session=None, sleep=asyncio.sleep
    ):  # noqa: E501, lgtm [py/missing-call-to-init]
        # ``sleep`` is injectable (defaults to ``asyncio.sleep``) and is
        # awaited between retry attempts.
        if session is None:
            session = _RefCountedSession(timeout=self.TIMEOUT_SECONDS)
        self._session = session
        self._sleep = sleep

    async def retrieve_full_uri(self, full_url, headers=None):
        """Retrieve JSON from an absolute URL after validating it.

        :param full_url: absolute URL, checked by ``_validate_allowed_url``.
        :param headers: optional extra request headers.
        :return: The parsed JSON response.
        """
        self._validate_allowed_url(full_url)
        return await self._retrieve_credentials(full_url, headers)

    async def retrieve_uri(self, relative_uri):
        """Retrieve JSON metadata from ECS metadata.

        :type relative_uri: str
        :param relative_uri: A relative URI, e.g "/foo/bar?id=123"

        :return: The parsed JSON response.

        """
        full_url = self.full_url(relative_uri)
        return await self._retrieve_credentials(full_url)

    async def _retrieve_credentials(self, full_url, extra_headers=None):
        """GET ``full_url`` with up to ``RETRY_ATTEMPTS`` attempts.

        Re-raises the last ``MetadataRetrievalError`` once attempts run out.
        """
        headers = {'Accept': 'application/json'}
        if extra_headers is not None:
            headers.update(extra_headers)
        attempts = 0
        while True:
            try:
                return await self._get_response(
                    full_url, headers, self.TIMEOUT_SECONDS
                )
            except MetadataRetrievalError as e:
                logger.debug(
                    "Received error when attempting to retrieve "
                    "container metadata: %s",
                    e,
                    exc_info=True,
                )
                attempts += 1
                if attempts >= self.RETRY_ATTEMPTS:
                    raise
                # Back off only when another attempt will follow; sleeping
                # right before re-raising would just delay the failure.
                await self._sleep(self.SLEEP_TIME)

    async def _get_response(self, full_url, headers, timeout):
        """Perform one GET and return the decoded JSON body.

        ``timeout`` is accepted for interface compatibility but is not used
        here; the session was created with ``TIMEOUT_SECONDS``.

        :raises MetadataRetrievalError: on a non-200 status, an unparsable
            body, or a retryable transport error.
        """
        try:
            async with self._session.acquire() as session:
                AWSRequest = botocore.awsrequest.AWSRequest
                request = AWSRequest(
                    method='GET', url=full_url, headers=headers
                )
                response = await session.send(request.prepare())
                response_text = (await response.content).decode('utf-8')

                if response.status_code != 200:
                    raise MetadataRetrievalError(
                        error_msg=(
                            "Received non 200 response (%s) from ECS metadata: %s"
                        )
                        % (response.status_code, response_text)
                    )
                try:
                    return json.loads(response_text)
                except ValueError:
                    error_msg = "Unable to parse JSON returned from ECS metadata services"
                    logger.debug('%s:%s', error_msg, response_text)
                    raise MetadataRetrievalError(error_msg=error_msg)

        except RETRYABLE_HTTP_ERRORS as e:
            error_msg = (
                "Received error when attempting to retrieve "
                "ECS metadata: %s" % e
            )
            raise MetadataRetrievalError(error_msg=error_msg) from e