# Copyright 2016 gRPC authors.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
"""Service-side implementation of gRPC Python."""
import collections
import enum
import logging
import threading
import time
from concurrent import futures
import six
import grpc
from grpc import _common
from grpc import _compression
from grpc import _interceptor
from grpc._cython import cygrpc
# Module-level logger, named after this module.
_LOGGER = logging.getLogger(__name__)

# Tag strings. Their consumers are outside this view -- presumably they label
# completion-queue events for server shutdown and incoming calls; confirm
# against the server loop.
_SHUTDOWN_TAG = 'shutdown'
_REQUEST_CALL_TAG = 'request_call'

# Tokens identifying batches of operations. A token is added to an RPC state's
# `due` set when its batch is started and removed again by
# _possibly_finish_call when the batch's completion callback runs.
_RECEIVE_CLOSE_ON_SERVER_TOKEN = 'receive_close_on_server'
_SEND_INITIAL_METADATA_TOKEN = 'send_initial_metadata'
_RECEIVE_MESSAGE_TOKEN = 'receive_message'
_SEND_MESSAGE_TOKEN = 'send_message'
_SEND_INITIAL_METADATA_AND_SEND_MESSAGE_TOKEN = (
    'send_initial_metadata * send_message')
_SEND_STATUS_FROM_SERVER_TOKEN = 'send_status_from_server'
_SEND_INITIAL_METADATA_AND_SEND_STATUS_FROM_SERVER_TOKEN = (
    'send_initial_metadata * send_status_from_server')

# Values taken by an RPC state's `client` attribute. They are compared with
# `is`, so these exact objects must be used rather than equal strings.
_OPEN = 'open'
_CLOSED = 'closed'
_CANCELLED = 'cancelled'

# Flags argument passed to the cygrpc operation constructors in this module.
_EMPTY_FLAGS = 0

# NOTE(review): not used in the visible part of the file; by name, a polling
# period in seconds and a stand-in for "no timeout" -- confirm at use sites.
_DEALLOCATED_SERVER_CHECK_PERIOD_S = 1.0
_INF_TIMEOUT = 1e9
def _serialized_request(request_event):
    """Return the message carried by the event's first batch operation.

    Args:
      request_event: An event whose ``batch_operations`` sequence has a first
        element exposing a ``message()`` method.

    Returns:
      Whatever ``message()`` returns for that first operation.
    """
    first_operation = request_event.batch_operations[0]
    return first_operation.message()
def _application_code(code):
    """Translate a status code through the ``_common`` lookup table.

    Args:
      code: The key to look up in
        ``_common.STATUS_CODE_TO_CYGRPC_STATUS_CODE``.

    Returns:
      The table entry for ``code``, or ``cygrpc.StatusCode.unknown`` when the
      lookup yields None.
    """
    translated = _common.STATUS_CODE_TO_CYGRPC_STATUS_CODE.get(code)
    if translated is None:
        return cygrpc.StatusCode.unknown
    return translated
def _completion_code(state):
    """Pick the status code for a call that is completing.

    Returns:
      The translation of ``state.code`` when a code has been set on the
      state, otherwise ``cygrpc.StatusCode.ok``.
    """
    if state.code is not None:
        return _application_code(state.code)
    return cygrpc.StatusCode.ok
def _abortion_code(state, code):
    """Pick the status code for a call that is being aborted.

    Args:
      state: Object whose ``code`` attribute is None or a previously set code.
      code: Fallback returned unchanged when ``state.code`` is None.

    Returns:
      The translation of ``state.code`` when one has been set, else ``code``.
    """
    if state.code is not None:
        return _application_code(state.code)
    return code
def _details(state):
    """Return ``state.details``, substituting ``b''`` when it is None."""
    details = state.details
    if details is None:
        return b''
    return details
class _HandlerCallDetails(
        collections.namedtuple('_HandlerCallDetails',
                               ('method', 'invocation_metadata')),
        grpc.HandlerCallDetails):
    """Named-tuple implementation of grpc.HandlerCallDetails.

    Fields, in order: ``method`` and ``invocation_metadata``.
    """
class _RPCState(object):
    """Mutable per-RPC bookkeeping shared by the helpers in this module.

    Readers and writers in this module access the attributes while holding
    ``condition``.
    """

    def __init__(self):
        # Guards the attributes below; notified by completion callbacks.
        self.condition = threading.Condition()

        # Tokens of batches that were started and have not yet completed.
        self.due = set()
        # Callbacks registered through the context; replaced by None once
        # _possibly_finish_call hands them off.
        self.callbacks = []
        # Errors raised through _raise_rpc_error.
        self.rpc_errors = []

        # One of _OPEN, _CLOSED or _CANCELLED.
        self.client = _OPEN
        # Most recently deserialized request, if any.
        self.request = None

        # Cleared once initial metadata has been sent.
        self.initial_metadata_allowed = True
        self.compression_algorithm = None
        self.disable_next_compression = False

        # Status information supplied by the application.
        self.trailing_metadata = None
        self.code = None
        self.details = None

        # Set to True by _abort after it starts the status batch.
        self.statused = False
        # Set to True when the context's abort() is called.
        self.aborted = False
def _raise_rpc_error(state):
    """Record a new grpc.RpcError on the state, then raise it.

    Raises:
      grpc.RpcError: Always; the same instance is appended to
        ``state.rpc_errors`` first.
    """
    error = grpc.RpcError()
    state.rpc_errors.append(error)
    raise error
def _possibly_finish_call(state, token):
    """Mark ``token`` as no longer due and report whether the call is done.

    Args:
      state: The RPC state whose ``due`` set contains ``token``.
      token: The token of the batch that just completed.

    Returns:
      ``(state, callbacks)`` when the RPC is no longer active and nothing else
      is due; in that case ``state.callbacks`` is replaced by None. Otherwise
      ``(None, ())``.
    """
    state.due.remove(token)
    if _is_rpc_state_active(state) or state.due:
        # Still running, or other batches are outstanding.
        return None, ()
    pending_callbacks = state.callbacks
    state.callbacks = None
    return state, pending_callbacks
def _send_status_from_server(state, token):
    """Build the completion callback for a status-sending batch.

    Returns:
      A one-argument callable that ignores its event and, holding
      ``state.condition``, returns ``_possibly_finish_call(state, token)``.
    """

    def send_status_from_server(unused_send_status_from_server_event):
        with state.condition:
            outcome = _possibly_finish_call(state, token)
            return outcome

    return send_status_from_server
def _get_initial_metadata(state, metadata):
    """Combine caller metadata with the compression entry, if one is set.

    Args:
      state: RPC state; ``compression_algorithm`` is read under its condition.
      metadata: Iterable of metadata entries, or None.

    Returns:
      ``metadata`` unchanged when no compression algorithm is set. Otherwise
      a tuple starting with the compression entry, followed by the entries of
      ``metadata`` when it is not None.
    """
    with state.condition:
        algorithm = state.compression_algorithm
        if not algorithm:
            return metadata
        compression_entry = (
            _compression.compression_algorithm_to_metadata(algorithm),)
        if metadata is None:
            return compression_entry
        return compression_entry + tuple(metadata)
def _get_initial_metadata_operation(state, metadata):
    """Build the cygrpc operation that sends initial metadata.

    The metadata passed to the operation is the result of
    ``_get_initial_metadata(state, metadata)``.
    """
    effective_metadata = _get_initial_metadata(state, metadata)
    return cygrpc.SendInitialMetadataOperation(effective_metadata,
                                               _EMPTY_FLAGS)
def _abort(state, call, code, details):
    """Start a batch that sends a failure status, unless the client cancelled.

    A code or details already set on the state take precedence over the
    ``code`` and ``details`` arguments. If initial metadata is still allowed
    it is sent in the same batch, ahead of the status.

    Args:
      state: The RPC state; updated with ``statused = True`` and the token of
        the batch that was started.
      call: Object providing ``start_server_batch``.
      code: Fallback status code.
      details: Fallback status details.
    """
    if state.client is _CANCELLED:
        return

    status_code = _abortion_code(state, code)
    status_details = state.details if state.details is not None else details

    if state.initial_metadata_allowed:
        leading_operations = (_get_initial_metadata_operation(state, None),)
        token = _SEND_INITIAL_METADATA_AND_SEND_STATUS_FROM_SERVER_TOKEN
    else:
        leading_operations = ()
        token = _SEND_STATUS_FROM_SERVER_TOKEN

    status_operation = cygrpc.SendStatusFromServerOperation(
        state.trailing_metadata, status_code, status_details, _EMPTY_FLAGS)
    call.start_server_batch(leading_operations + (status_operation,),
                            _send_status_from_server(state, token))
    state.statused = True
    state.due.add(token)
def _receive_close_on_server(state):
    """Build the completion callback for the receive-close-on-server batch.

    Returns:
      A one-argument callable. Holding ``state.condition`` it records the
      client as cancelled when the event's first operation reports
      cancellation, or as closed when the client was still open. It then
      notifies waiters and returns the result of ``_possibly_finish_call``.
    """

    def receive_close_on_server(receive_close_on_server_event):
        with state.condition:
            was_cancelled = (
                receive_close_on_server_event.batch_operations[0].cancelled())
            if was_cancelled:
                state.client = _CANCELLED
            elif state.client is _OPEN:
                state.client = _CLOSED
            state.condition.notify_all()
            return _possibly_finish_call(state,
                                         _RECEIVE_CLOSE_ON_SERVER_TOKEN)

    return receive_close_on_server
def _receive_message(state, call, request_deserializer):
    """Build the completion callback for a receive-message batch.

    Returns:
      A one-argument callable that handles three outcomes:
        * no message arrived: a still-open client is marked closed;
        * deserialization yielded None: the call is aborted with an
          internal-error status;
        * otherwise the deserialized request is stored on the state.
      In every case waiters are notified and the result of
      ``_possibly_finish_call`` is returned.
    """

    def receive_message(receive_message_event):
        payload = _serialized_request(receive_message_event)
        # Deserialization happens before the lock is taken.
        if payload is None:
            request = None
        else:
            request = _common.deserialize(payload, request_deserializer)

        with state.condition:
            if payload is None:
                if state.client is _OPEN:
                    state.client = _CLOSED
            elif request is None:
                _abort(state, call, cygrpc.StatusCode.internal,
                       b'Exception deserializing request!')
            else:
                state.request = request
            state.condition.notify_all()
            return _possibly_finish_call(state, _RECEIVE_MESSAGE_TOKEN)

    return receive_message
def _send_initial_metadata(state):
    """Build the completion callback for a send-initial-metadata batch.

    Returns:
      A one-argument callable that ignores its event and, holding
      ``state.condition``, returns the result of ``_possibly_finish_call``
      for ``_SEND_INITIAL_METADATA_TOKEN``.
    """

    def send_initial_metadata(unused_send_initial_metadata_event):
        with state.condition:
            outcome = _possibly_finish_call(state,
                                            _SEND_INITIAL_METADATA_TOKEN)
            return outcome

    return send_initial_metadata
def _send_message(state, token):
    """Build the completion callback for a message-sending batch.

    Returns:
      A one-argument callable that ignores its event. Holding
      ``state.condition`` it notifies all waiters, then returns
      ``_possibly_finish_call(state, token)``.
    """

    def send_message(unused_send_message_event):
        with state.condition:
            state.condition.notify_all()
            outcome = _possibly_finish_call(state, token)
            return outcome

    return send_message
class _Context(grpc.ServicerContext):
    """grpc.ServicerContext backed by a request event and an RPC state.

    Methods that read or write the shared state do so while holding the
    state's condition; methods that only query the call object or the request
    event do not take it.
    """

    def __init__(self, rpc_event, state, request_deserializer):
        """Store the request event, RPC state and request deserializer."""
        self._rpc_event = rpc_event
        self._state = state
        self._request_deserializer = request_deserializer

    def is_active(self):
        """Return the result of _is_rpc_state_active for this RPC."""
        state = self._state
        with state.condition:
            return _is_rpc_state_active(state)

    def time_remaining(self):
        """Return the time left until the call's deadline, floored at 0."""
        deadline = self._rpc_event.call_details.deadline
        return max(deadline - time.time(), 0)

    def cancel(self):
        """Cancel the underlying call."""
        call = self._rpc_event.call
        call.cancel()

    def add_callback(self, callback):
        """Register ``callback`` on the state.

        Returns:
          True when the callback was appended; False when the state's
          callback list has already been replaced by None.
        """
        with self._state.condition:
            registered = self._state.callbacks
            if registered is None:
                return False
            registered.append(callback)
            return True

    def disable_next_message_compression(self):
        """Set the state's ``disable_next_compression`` flag."""
        state = self._state
        with state.condition:
            state.disable_next_compression = True

    def invocation_metadata(self):
        """Return the invocation metadata carried by the request event."""
        return self._rpc_event.invocation_metadata

    def peer(self):
        """Return the call's peer, decoded via ``_common.decode``."""
        raw_peer = self._rpc_event.call.peer()
        return _common.decode(raw_peer)

    def peer_identities(self):
        """Return ``cygrpc.peer_identities`` for the underlying call."""
        return cygrpc.peer_identities(self._rpc_event.call)

    def peer_identity_key(self):
        """Return the decoded peer identity key, or None when there is none."""
        id_key = cygrpc.peer_identity_key(self._rpc_event.call)
        if id_key is None:
            return None
        return _common.decode(id_key)

    def auth_context(self):
        """Return the call's auth context as a dict with decoded keys."""
        raw_context = cygrpc.auth_context(self._rpc_event.call)
        return dict((_common.decode(key), value)
                    for key, value in six.iteritems(raw_context))

    def set_compression(self, compression):
        """Record ``compression`` as the state's compression algorithm."""
        state = self._state
        with state.condition:
            state.compression_algorithm = compression

    def send_initial_metadata(self, initial_metadata):
        """Start a batch that sends ``initial_metadata`` to the client.

        Raises:
          grpc.RpcError: If the client has cancelled the call.
          ValueError: If initial metadata is no longer allowed.
        """
        state = self._state
        with state.condition:
            if state.client is _CANCELLED:
                _raise_rpc_error(state)
            elif not state.initial_metadata_allowed:
                raise ValueError('Initial metadata no longer allowed!')
            else:
                operation = _get_initial_metadata_operation(
                    state, initial_metadata)
                self._rpc_event.call.start_server_batch(
                    (operation,), _send_initial_metadata(state))
                state.initial_metadata_allowed = False
                state.due.add(_SEND_INITIAL_METADATA_TOKEN)

    def set_trailing_metadata(self, trailing_metadata):
        """Record the trailing metadata on the state."""
        state = self._state
        with state.condition:
            state.trailing_metadata = trailing_metadata

    def abort(self, code, details):
        """Record a failure status on the state and raise.

        Raises:
          Exception: Always, after the code, the encoded details and the
            ``aborted`` flag have been stored on the state.
        """
        # treat OK like other invalid arguments: fail the RPC
        if code == grpc.StatusCode.OK:
            _LOGGER.error(
                'abort() called with StatusCode.OK; returning UNKNOWN')
            code, details = grpc.StatusCode.UNKNOWN, ''
        state = self._state
        with state.condition:
            state.code = code
            state.details = _common.encode(details)
            state.aborted = True
            raise Exception()

    def abort_with_status(self, status):
        """Store the status' trailing metadata, then abort with its code."""
        self._state.trailing_metadata = status.trailing_metadata
        self.abort(status.code, status.details)

    def set_code(self, code):
        """Record ``code`` on the state."""
        state = self._state
        with state.condition:
            state.code = code

    def set_details(self, details):
        """Record ``details`` on the state, encoded via ``_common.encode``."""
        encoded = _common.encode(details)
        with self._state.condition:
            self._state.details = encoded

    def _finalize_state(self):
        """Do nothing."""
        pass
class _RequestIterator(object):
def __init__(self, state, call, request_deserializer):
self._state = state
self._call = call
self._request_deserializer = request_deserializer
def _raise_or_start_receive_message(self):
if self._state.client is _CANCELLED:
Loading ...