Why Gemfury? Push, build, and install  RubyGems npm packages Python packages Maven artifacts PHP packages Go Modules Bower components Debian packages RPM packages NuGet packages

edgify / grpcio   python

Repository URL to install this package:

Version: 1.29.0 

/ experimental / aio / _call.py

# Copyright 2019 gRPC authors.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
"""Invocation-side implementation of gRPC Asyncio Python."""

import asyncio
import enum
import inspect
import logging
from functools import partial
from typing import AsyncIterable, Optional, Tuple

import grpc
from grpc import _common
from grpc._cython import cygrpc

from . import _base_call
from ._typing import (DeserializingFunction, DoneCallbackType, MetadataType,
                      MetadatumType, RequestIterableType, RequestType,
                      ResponseType, SerializingFunction)

__all__ = 'AioRpcError', 'Call', 'UnaryUnaryCall', 'UnaryStreamCall'

_LOCAL_CANCELLATION_DETAILS = 'Locally cancelled by application!'
_GC_CANCELLATION_DETAILS = 'Cancelled upon garbage collection!'
_RPC_ALREADY_FINISHED_DETAILS = 'RPC already finished.'
_RPC_HALF_CLOSED_DETAILS = 'RPC is half closed after calling "done_writing".'

_OK_CALL_REPRESENTATION = ('<{} of RPC that terminated with:\n'
                           '\tstatus = {}\n'
                           '\tdetails = "{}"\n'
                           '>')

_NON_OK_CALL_REPRESENTATION = ('<{} of RPC that terminated with:\n'
                               '\tstatus = {}\n'
                               '\tdetails = "{}"\n'
                               '\tdebug_error_string = "{}"\n'
                               '>')

_LOGGER = logging.getLogger(__name__)


class AioRpcError(grpc.RpcError):
    """An implementation of RpcError to be used by the asynchronous API.

    Raised RpcError is a snapshot of the final status of the RPC, values are
    determined. Hence, its methods no longer needs to be coroutines.
    """

    # TODO(https://github.com/grpc/grpc/issues/20144) Metadata
    # type returned by `initial_metadata` and `trailing_metadata`
    # and also taken in the constructor needs to be revisit and make
    # it more specific.

    _code: grpc.StatusCode
    _details: Optional[str]
    _initial_metadata: Optional[MetadataType]
    _trailing_metadata: Optional[MetadataType]
    _debug_error_string: Optional[str]

    def __init__(self,
                 code: grpc.StatusCode,
                 details: Optional[str] = None,
                 initial_metadata: Optional[MetadataType] = None,
                 trailing_metadata: Optional[MetadataType] = None,
                 debug_error_string: Optional[str] = None) -> None:
        """Constructor.

        Args:
          code: The status code with which the RPC has been finalized.
          details: Optional details explaining the reason of the error.
          initial_metadata: Optional initial metadata that could be sent by the
            Server.
          trailing_metadata: Optional metadata that could be sent by the Server.
        """

        super().__init__(self)
        self._code = code
        self._details = details
        self._initial_metadata = initial_metadata
        self._trailing_metadata = trailing_metadata
        self._debug_error_string = debug_error_string

    def code(self) -> grpc.StatusCode:
        """Accesses the status code sent by the server.

        Returns:
          The `grpc.StatusCode` status code.
        """
        return self._code

    def details(self) -> Optional[str]:
        """Accesses the details sent by the server.

        Returns:
          The description of the error.
        """
        return self._details

    def initial_metadata(self) -> Optional[MetadataType]:
        """Accesses the initial metadata sent by the server.

        Returns:
          The initial metadata received.
        """
        return self._initial_metadata

    def trailing_metadata(self) -> Optional[MetadataType]:
        """Accesses the trailing metadata sent by the server.

        Returns:
          The trailing metadata received.
        """
        return self._trailing_metadata

    def debug_error_string(self) -> str:
        """Accesses the debug error string sent by the server.

        Returns:
          The debug error string received.
        """
        return self._debug_error_string

    def _repr(self) -> str:
        """Assembles the error string for the RPC error."""
        return _NON_OK_CALL_REPRESENTATION.format(self.__class__.__name__,
                                                  self._code, self._details,
                                                  self._debug_error_string)

    def __repr__(self) -> str:
        return self._repr()

    def __str__(self) -> str:
        return self._repr()


def _create_rpc_error(initial_metadata: Optional[MetadataType],
                      status: cygrpc.AioRpcStatus) -> AioRpcError:
    return AioRpcError(
        _common.CYGRPC_STATUS_CODE_TO_STATUS_CODE[status.code()],
        status.details(),
        initial_metadata,
        status.trailing_metadata(),
        status.debug_error_string(),
    )


class Call:
    """Base implementation of client RPC Call object.

    Implements logic around final status, metadata and cancellation.
    """
    _loop: asyncio.AbstractEventLoop
    _code: grpc.StatusCode
    _cython_call: cygrpc._AioCall
    _metadata: Tuple[MetadatumType]
    _request_serializer: SerializingFunction
    _response_deserializer: DeserializingFunction

    def __init__(self, cython_call: cygrpc._AioCall, metadata: MetadataType,
                 request_serializer: SerializingFunction,
                 response_deserializer: DeserializingFunction,
                 loop: asyncio.AbstractEventLoop) -> None:
        self._loop = loop
        self._cython_call = cython_call
        self._metadata = tuple(metadata)
        self._request_serializer = request_serializer
        self._response_deserializer = response_deserializer

    def __del__(self) -> None:
        # The '_cython_call' object might be destructed before Call object
        if hasattr(self, '_cython_call'):
            if not self._cython_call.done():
                self._cancel(_GC_CANCELLATION_DETAILS)

    def cancelled(self) -> bool:
        return self._cython_call.cancelled()

    def _cancel(self, details: str) -> bool:
        """Forwards the application cancellation reasoning."""
        if not self._cython_call.done():
            self._cython_call.cancel(details)
            return True
        else:
            return False

    def cancel(self) -> bool:
        return self._cancel(_LOCAL_CANCELLATION_DETAILS)

    def done(self) -> bool:
        return self._cython_call.done()

    def add_done_callback(self, callback: DoneCallbackType) -> None:
        cb = partial(callback, self)
        self._cython_call.add_done_callback(cb)

    def time_remaining(self) -> Optional[float]:
        return self._cython_call.time_remaining()

    async def initial_metadata(self) -> MetadataType:
        return await self._cython_call.initial_metadata()

    async def trailing_metadata(self) -> MetadataType:
        return (await self._cython_call.status()).trailing_metadata()

    async def code(self) -> grpc.StatusCode:
        cygrpc_code = (await self._cython_call.status()).code()
        return _common.CYGRPC_STATUS_CODE_TO_STATUS_CODE[cygrpc_code]

    async def details(self) -> str:
        return (await self._cython_call.status()).details()

    async def debug_error_string(self) -> str:
        return (await self._cython_call.status()).debug_error_string()

    async def _raise_for_status(self) -> None:
        if self._cython_call.is_locally_cancelled():
            raise asyncio.CancelledError()
        code = await self.code()
        if code != grpc.StatusCode.OK:
            raise _create_rpc_error(await self.initial_metadata(), await
                                    self._cython_call.status())

    def _repr(self) -> str:
        return repr(self._cython_call)

    def __repr__(self) -> str:
        return self._repr()

    def __str__(self) -> str:
        return self._repr()


class _APIStyle(enum.IntEnum):
    UNKNOWN = 0
    ASYNC_GENERATOR = 1
    READER_WRITER = 2


class _UnaryResponseMixin(Call):
    _call_response: asyncio.Task

    def _init_unary_response_mixin(self, response_task: asyncio.Task):
        self._call_response = response_task

    def cancel(self) -> bool:
        if super().cancel():
            self._call_response.cancel()
            return True
        else:
            return False

    def __await__(self) -> ResponseType:
        """Wait till the ongoing RPC request finishes."""
        try:
            response = yield from self._call_response
        except asyncio.CancelledError:
            # Even if we caught all other CancelledError, there is still
            # this corner case. If the application cancels immediately after
            # the Call object is created, we will observe this
            # `CancelledError`.
            if not self.cancelled():
                self.cancel()
            raise

        # NOTE(lidiz) If we raise RpcError in the task, and users doesn't
        # 'await' on it. AsyncIO will log 'Task exception was never retrieved'.
        # Instead, if we move the exception raising here, the spam stops.
        # Unfortunately, there can only be one 'yield from' in '__await__'. So,
        # we need to access the private instance variable.
        if response is cygrpc.EOF:
            if self._cython_call.is_locally_cancelled():
                raise asyncio.CancelledError()
            else:
                raise _create_rpc_error(self._cython_call._initial_metadata,
                                        self._cython_call._status)
        else:
            return response


class _StreamResponseMixin(Call):
    _message_aiter: AsyncIterable[ResponseType]
    _preparation: asyncio.Task
    _response_style: _APIStyle

    def _init_stream_response_mixin(self, preparation: asyncio.Task):
        self._message_aiter = None
        self._preparation = preparation
        self._response_style = _APIStyle.UNKNOWN

    def _update_response_style(self, style: _APIStyle):
        if self._response_style is _APIStyle.UNKNOWN:
            self._response_style = style
        elif self._response_style is not style:
            raise cygrpc.UsageError(
                'Please don\'t mix two styles of API for streaming responses')

    def cancel(self) -> bool:
        if super().cancel():
            self._preparation.cancel()
            return True
        else:
            return False

    async def _fetch_stream_responses(self) -> ResponseType:
        message = await self._read()
        while message is not cygrpc.EOF:
            yield message
            message = await self._read()

    def __aiter__(self) -> AsyncIterable[ResponseType]:
        self._update_response_style(_APIStyle.ASYNC_GENERATOR)
        if self._message_aiter is None:
            self._message_aiter = self._fetch_stream_responses()
        return self._message_aiter

    async def _read(self) -> ResponseType:
        # Wait for the request being sent
        await self._preparation

        # Reads response message from Core
        try:
            raw_response = await self._cython_call.receive_serialized_message()
        except asyncio.CancelledError:
            if not self.cancelled():
                self.cancel()
            await self._raise_for_status()

        if raw_response is cygrpc.EOF:
            return cygrpc.EOF
        else:
            return _common.deserialize(raw_response,
                                       self._response_deserializer)

    async def read(self) -> ResponseType:
        if self.done():
            await self._raise_for_status()
            return cygrpc.EOF
        self._update_response_style(_APIStyle.READER_WRITER)

        response_message = await self._read()

        if response_message is cygrpc.EOF:
            # If the read operation failed, Core should explain why.
            await self._raise_for_status()
        return response_message


class _StreamRequestMixin(Call):
    _metadata_sent: asyncio.Event
    _done_writing_flag: bool
    _async_request_poller: Optional[asyncio.Task]
    _request_style: _APIStyle

    def _init_stream_request_mixin(
            self, request_iterator: Optional[RequestIterableType]):
        self._metadata_sent = asyncio.Event(loop=self._loop)
        self._done_writing_flag = False

        # If user passes in an async iterator, create a consumer Task.
        if request_iterator is not None:
            self._async_request_poller = self._loop.create_task(
                self._consume_request_iterator(request_iterator))
            self._request_style = _APIStyle.ASYNC_GENERATOR
        else:
            self._async_request_poller = None
            self._request_style = _APIStyle.READER_WRITER

    def _raise_for_different_style(self, style: _APIStyle):
        if self._request_style is not style:
            raise cygrpc.UsageError(
                'Please don\'t mix two styles of API for streaming requests')

    def cancel(self) -> bool:
        if super().cancel():
            if self._async_request_poller is not None:
                self._async_request_poller.cancel()
            return True
        else:
            return False

    def _metadata_sent_observer(self):
        self._metadata_sent.set()

    async def _consume_request_iterator(self,
                                        request_iterator: RequestIterableType
                                       ) -> None:
        try:
            if inspect.isasyncgen(request_iterator):
                async for request in request_iterator:
                    await self._write(request)
            else:
                for request in request_iterator:
                    await self._write(request)

            await self._done_writing()
        except AioRpcError as rpc_error:
            # Rpc status should be exposed through other API. Exceptions raised
            # within this Task won't be retrieved by another coroutine. It's
            # better to suppress the error than spamming users' screen.
            _LOGGER.debug('Exception while consuming the request_iterator: %s',
                          rpc_error)

    async def _write(self, request: RequestType) -> None:
        if self.done():
            raise asyncio.InvalidStateError(_RPC_ALREADY_FINISHED_DETAILS)
        if self._done_writing_flag:
            raise asyncio.InvalidStateError(_RPC_HALF_CLOSED_DETAILS)
        if not self._metadata_sent.is_set():
            await self._metadata_sent.wait()
            if self.done():
                await self._raise_for_status()

        serialized_request = _common.serialize(request,
                                               self._request_serializer)

        try:
            await self._cython_call.send_serialized_message(serialized_request)
        except asyncio.CancelledError:
            if not self.cancelled():
                self.cancel()
            await self._raise_for_status()

    async def _done_writing(self) -> None:
        if self.done():
            # If the RPC is finished, do nothing.
            return
        if not self._done_writing_flag:
            # If the done writing is not sent before, try to send it.
            self._done_writing_flag = True
            try:
                await self._cython_call.send_receive_close()
            except asyncio.CancelledError:
                if not self.cancelled():
                    self.cancel()
                await self._raise_for_status()

    async def write(self, request: RequestType) -> None:
        self._raise_for_different_style(_APIStyle.READER_WRITER)
        await self._write(request)

    async def done_writing(self) -> None:
        """Signal peer that client is done writing.

        This method is idempotent.
        """
        self._raise_for_different_style(_APIStyle.READER_WRITER)
        await self._done_writing()

    async def wait_for_connection(self) -> None:
        await self._metadata_sent.wait()
        if self.done():
            await self._raise_for_status()


class UnaryUnaryCall(_UnaryResponseMixin, Call, _base_call.UnaryUnaryCall):
    """Object for managing unary-unary RPC calls.

    Returned when an instance of `UnaryUnaryMultiCallable` object is called.
    """
    _request: RequestType
    _invocation_task: asyncio.Task

    # pylint: disable=too-many-arguments
    def __init__(self, request: RequestType, deadline: Optional[float],
                 metadata: MetadataType,
                 credentials: Optional[grpc.CallCredentials],
                 wait_for_ready: Optional[bool], channel: cygrpc.AioChannel,
                 method: bytes, request_serializer: SerializingFunction,
                 response_deserializer: DeserializingFunction,
                 loop: asyncio.AbstractEventLoop) -> None:
        super().__init__(
            channel.call(method, deadline, credentials, wait_for_ready),
            metadata, request_serializer, response_deserializer, loop)
        self._request = request
        self._invocation_task = loop.create_task(self._invoke())
        self._init_unary_response_mixin(self._invocation_task)

    async def _invoke(self) -> ResponseType:
        serialized_request = _common.serialize(self._request,
                                               self._request_serializer)

        # NOTE(lidiz) asyncio.CancelledError is not a good transport for status,
        # because the asyncio.Task class do not cache the exception object.
        # https://github.com/python/cpython/blob/edad4d89e357c92f70c0324b937845d652b20afd/Lib/asyncio/tasks.py#L785
        try:
            serialized_response = await self._cython_call.unary_unary(
                serialized_request, self._metadata)
        except asyncio.CancelledError:
            if not self.cancelled():
                self.cancel()

        if self._cython_call.is_ok():
            return _common.deserialize(serialized_response,
                                       self._response_deserializer)
        else:
            return cygrpc.EOF

    async def wait_for_connection(self) -> None:
        await self._invocation_task
        if self.done():
            await self._raise_for_status()


class UnaryStreamCall(_StreamResponseMixin, Call, _base_call.UnaryStreamCall):
    """Object for managing unary-stream RPC calls.

    Returned when an instance of `UnaryStreamMultiCallable` object is called.
    """
    _request: RequestType
    _send_unary_request_task: asyncio.Task

    # pylint: disable=too-many-arguments
    def __init__(self, request: RequestType, deadline: Optional[float],
                 metadata: MetadataType,
                 credentials: Optional[grpc.CallCredentials],
                 wait_for_ready: Optional[bool], channel: cygrpc.AioChannel,
                 method: bytes, request_serializer: SerializingFunction,
                 response_deserializer: DeserializingFunction,
                 loop: asyncio.AbstractEventLoop) -> None:
        super().__init__(
            channel.call(method, deadline, credentials, wait_for_ready),
            metadata, request_serializer, response_deserializer, loop)
        self._request = request
        self._send_unary_request_task = loop.create_task(
            self._send_unary_request())
        self._init_stream_response_mixin(self._send_unary_request_task)

    async def _send_unary_request(self) -> ResponseType:
        serialized_request = _common.serialize(self._request,
                                               self._request_serializer)
        try:
            await self._cython_call.initiate_unary_stream(
                serialized_request, self._metadata)
        except asyncio.CancelledError:
            if not self.cancelled():
                self.cancel()
            raise

    async def wait_for_connection(self) -> None:
        await self._send_unary_request_task
        if self.done():
            await self._raise_for_status()


class StreamUnaryCall(_StreamRequestMixin, _UnaryResponseMixin, Call,
                      _base_call.StreamUnaryCall):
    """Object for managing stream-unary RPC calls.

    Returned when an instance of `StreamUnaryMultiCallable` object is called.
    """

    # pylint: disable=too-many-arguments
    def __init__(self, request_iterator: Optional[RequestIterableType],
                 deadline: Optional[float], metadata: MetadataType,
                 credentials: Optional[grpc.CallCredentials],
                 wait_for_ready: Optional[bool], channel: cygrpc.AioChannel,
                 method: bytes, request_serializer: SerializingFunction,
                 response_deserializer: DeserializingFunction,
                 loop: asyncio.AbstractEventLoop) -> None:
        super().__init__(
            channel.call(method, deadline, credentials, wait_for_ready),
            metadata, request_serializer, response_deserializer, loop)

        self._init_stream_request_mixin(request_iterator)
        self._init_unary_response_mixin(loop.create_task(self._conduct_rpc()))

    async def _conduct_rpc(self) -> ResponseType:
        try:
            serialized_response = await self._cython_call.stream_unary(
                self._metadata, self._metadata_sent_observer)
        except asyncio.CancelledError:
            if not self.cancelled():
                self.cancel()

        if self._cython_call.is_ok():
            return _common.deserialize(serialized_response,
                                       self._response_deserializer)
        else:
            return cygrpc.EOF


class StreamStreamCall(_StreamRequestMixin, _StreamResponseMixin, Call,
                       _base_call.StreamStreamCall):
    """Object for managing stream-stream RPC calls.

    Returned when an instance of `StreamStreamMultiCallable` object is called.
    """
    _initializer: asyncio.Task

    # pylint: disable=too-many-arguments
    def __init__(self, request_iterator: Optional[RequestIterableType],
                 deadline: Optional[float], metadata: MetadataType,
                 credentials: Optional[grpc.CallCredentials],
                 wait_for_ready: Optional[bool], channel: cygrpc.AioChannel,
                 method: bytes, request_serializer: SerializingFunction,
                 response_deserializer: DeserializingFunction,
                 loop: asyncio.AbstractEventLoop) -> None:
        super().__init__(
            channel.call(method, deadline, credentials, wait_for_ready),
            metadata, request_serializer, response_deserializer, loop)
        self._initializer = self._loop.create_task(self._prepare_rpc())
        self._init_stream_request_mixin(request_iterator)
        self._init_stream_response_mixin(self._initializer)

    async def _prepare_rpc(self):
        """This method prepares the RPC for receiving/sending messages.

        All other operations around the stream should only happen after the
        completion of this method.
        """
        try:
            await self._cython_call.initiate_stream_stream(
                self._metadata, self._metadata_sent_observer)
        except asyncio.CancelledError:
            if not self.cancelled():
                self.cancel()
            # No need to raise RpcError here, because no one will `await` this task.