# Copyright (C) 2016-present the asyncpg authors and contributors
# <see AUTHORS file>
#
# This module is part of asyncpg and is released under
# the Apache 2.0 License: http://www.apache.org/licenses/LICENSE-2.0
# cython: language_level=3
cimport cython
cimport cpython
import asyncio
import builtins
import codecs
import collections
import socket
import time
import weakref
from asyncpg.pgproto.pgproto cimport (
WriteBuffer,
ReadBuffer,
FRBuffer,
frb_init,
frb_read,
frb_read_all,
frb_slice_from,
frb_check,
frb_set_len,
frb_get_len,
)
from asyncpg.pgproto cimport pgproto
from asyncpg.protocol cimport cpythonx
from asyncpg.protocol cimport record
from libc.stdint cimport int8_t, uint8_t, int16_t, uint16_t, \
int32_t, uint32_t, int64_t, uint64_t, \
UINT32_MAX
from asyncpg.exceptions import _base as apg_exc_base
from asyncpg import compat
from asyncpg import types as apg_types
from asyncpg import exceptions as apg_exc
from asyncpg.pgproto cimport hton
include "consts.pxi"
include "pgtypes.pxi"
include "encodings.pyx"
include "settings.pyx"
include "codecs/base.pyx"
include "codecs/textutils.pyx"
# register codecs provided by pgproto
include "codecs/pgproto.pyx"
# nonscalar
include "codecs/array.pyx"
include "codecs/range.pyx"
include "codecs/record.pyx"
include "coreproto.pyx"
include "prepared_stmt.pyx"
# Unique module-level sentinel object.
# NOTE(review): the name suggests it marks "no timeout supplied" as
# distinct from an explicit ``None`` -- confirm against its users.
NO_TIMEOUT = object()
cdef class BaseProtocol(CoreProtocol):
def __init__(self, addr, connected_fut, con_params, loop):
    """Initialize protocol state for a new server connection.

    :param addr: server address; stored as ``self.address`` and combined
        with the database name to build ``self.settings``.
    :param connected_fut: future stored as the initial ``self.waiter``.
    :param con_params: connection parameters (``_ConnectionParameters``);
        forwarded to ``CoreProtocol.__init__``, and its ``database``
        attribute is used to key the connection settings.
    :param loop: event loop this protocol runs on.
    """
    # type of `con_params` is `_ConnectionParameters`
    CoreProtocol.__init__(self, con_params)

    self.loop = loop
    self.transport = None
    self.waiter = connected_fut
    self.cancel_waiter = None
    self.cancel_sent_waiter = None

    self.address = addr
    self.settings = ConnectionSettings((self.address, con_params.database))

    self.statement = None
    self.return_extra = False

    self.last_query = None

    self.closing = False
    self.is_reading = True

    # asyncio.Event's ``loop`` argument was deprecated in Python 3.8 and
    # removed in 3.10, where passing it raises TypeError.  Keep binding the
    # event to ``loop`` explicitly where that is still accepted, and fall
    # back to the loop-less constructor otherwise.
    try:
        self.writing_allowed = asyncio.Event(loop=self.loop)
    except TypeError:
        self.writing_allowed = asyncio.Event()
    self.writing_allowed.set()

    self.timeout_handle = None
    self.timeout_callback = self._on_timeout
    self.completed_callback = self._on_waiter_completed

    self.queries_count = 0

    # Event loops without ``create_future`` get a local fallback factory.
    try:
        self.create_future = loop.create_future
    except AttributeError:
        self.create_future = self._create_future_fallback
def set_connection(self, connection):
    """Remember *connection* without keeping it alive.

    Only a weak reference is stored in ``self.conref``, so the protocol
    does not extend the connection object's lifetime.
    """
    connection_ref = weakref.ref(connection)
    self.conref = connection_ref
cdef get_connection(self):
    # Dereference the weak reference saved by set_connection().
    # Yields None when no connection was registered, or when the weakly
    # referenced object no longer exists (calling a dead weakref.ref
    # returns None).
    conref = self.conref
    if conref is None:
        return None
    return conref()
def get_server_pid(self):
    """Return the server process ID held in ``self.backend_pid``."""
    pid = self.backend_pid
    return pid
def get_settings(self):
    """Return the per-connection settings object (``self.settings``)."""
    settings = self.settings
    return settings
def is_in_transaction(self):
    """Return True when the connection is inside a transaction block.

    Both a healthy transaction and a failed one count as "in a
    transaction".
    """
    status = self.xact_status
    # PQTRANS_INTRANS: idle, within a transaction block
    # PQTRANS_INERROR: idle, within a failed transaction
    return status == PQTRANS_INTRANS or status == PQTRANS_INERROR
cdef inline resume_reading(self):
    # Ask the transport to deliver incoming data again.  The is_reading
    # flag makes repeated calls a no-op.
    if self.is_reading:
        return
    self.is_reading = True
    self.transport.resume_reading()
cdef inline pause_reading(self):
    # Ask the transport to stop delivering incoming data.  The is_reading
    # flag makes repeated calls a no-op.
    if not self.is_reading:
        return
    self.is_reading = False
    self.transport.pause_reading()
@cython.iterable_coroutine
async def prepare(self, stmt_name, query, timeout,
                  PreparedStatementState state=None):
    """Prepare *query* under *stmt_name* and await the server's reply.

    Any pending cancellation is awaited first.  When *state* is omitted,
    a new ``PreparedStatementState`` is created and recorded as the
    current statement.  An ``Exception`` raised while issuing the request
    is set on the waiter and therefore re-raised by the final ``await``.
    """
    # Never interleave with an in-flight cancellation request.
    pending_cancel = self.cancel_waiter
    if pending_cancel is not None:
        await pending_cancel
    cancel_sent = self.cancel_sent_waiter
    if cancel_sent is not None:
        await cancel_sent
        self.cancel_sent_waiter = None

    self._check_state()
    effective_timeout = self._get_timeout_impl(timeout)

    result_waiter = self._new_waiter(effective_timeout)
    try:
        self._prepare(stmt_name, query)  # network op
        self.last_query = query
        if state is None:
            state = PreparedStatementState(stmt_name, query, self)
        self.statement = state
    except Exception as exc:
        # Route send-time failures through the waiter.
        result_waiter.set_exception(exc)
        self._coreproto_error()
    finally:
        return await result_waiter
@cython.iterable_coroutine
async def bind_execute(self, PreparedStatementState state, args,
                       str portal_name, int limit, return_extra,
                       timeout):
    """Encode *args* for *state*, bind and execute it, and await the result.

    The request is issued through ``_bind_execute`` on *portal_name* with
    *limit*.  Argument-encoding errors are raised directly, because the
    waiter does not exist yet.  Errors while issuing the request are set
    on the waiter and re-raised by the final ``await``.
    """
    # Never interleave with an in-flight cancellation request.
    pending_cancel = self.cancel_waiter
    if pending_cancel is not None:
        await pending_cancel
    cancel_sent = self.cancel_sent_waiter
    if cancel_sent is not None:
        await cancel_sent
        self.cancel_sent_waiter = None

    self._check_state()
    effective_timeout = self._get_timeout_impl(timeout)
    bind_payload = state._encode_bind_msg(args)

    result_waiter = self._new_waiter(effective_timeout)
    try:
        self._bind_execute(
            portal_name,
            state.name,
            bind_payload,
            limit)  # network op
        self.last_query = state.query
        self.statement = state
        self.return_extra = return_extra
        self.queries_count += 1
    except Exception as exc:
        result_waiter.set_exception(exc)
        self._coreproto_error()
    finally:
        return await result_waiter
@cython.iterable_coroutine
async def bind_execute_many(self, PreparedStatementState state, args,
                            str portal_name, timeout):
    """Execute *state* once per argument set in *args* and await the result.

    Bind messages are produced lazily from *args* and handed to
    ``_bind_execute_many`` together with *portal_name*.  Errors raised
    inside the ``try`` below are set on the waiter and re-raised by the
    final ``await``.
    """
    # Never interleave with an in-flight cancellation request.
    if self.cancel_waiter is not None:
        await self.cancel_waiter
    if self.cancel_sent_waiter is not None:
        await self.cancel_sent_waiter
        self.cancel_sent_waiter = None

    self._check_state()
    timeout = self._get_timeout_impl(timeout)

    # Encode the argument sequence lazily, one bind message at a time, to
    # keep memory pressure under control.  A generator expression is
    # already an iterator, so no extra iter() wrapper is needed.  Encoding
    # errors therefore surface only when the generator is consumed
    # (presumably inside _bind_execute_many -- verify), not here.
    arg_bufs = (state._encode_bind_msg(b) for b in args)

    waiter = self._new_waiter(timeout)
    try:
        self._bind_execute_many(
            portal_name,
            state.name,
            arg_bufs)  # network op
        self.last_query = state.query
        self.statement = state
        self.return_extra = False
        self.queries_count += 1
    except Exception as ex:
        waiter.set_exception(ex)
        self._coreproto_error()
    finally:
        return await waiter
@cython.iterable_coroutine
async def bind(self, PreparedStatementState state, args,
               str portal_name, timeout):
    """Encode *args* for *state*, bind them to *portal_name*, and await.

    Argument-encoding errors are raised directly, because the waiter does
    not exist yet.  Errors from ``_bind`` are set on the waiter and
    re-raised by the final ``await``.
    """
    # Never interleave with an in-flight cancellation request.
    pending_cancel = self.cancel_waiter
    if pending_cancel is not None:
        await pending_cancel
    cancel_sent = self.cancel_sent_waiter
    if cancel_sent is not None:
        await cancel_sent
        self.cancel_sent_waiter = None

    self._check_state()
    effective_timeout = self._get_timeout_impl(timeout)
    bind_payload = state._encode_bind_msg(args)

    result_waiter = self._new_waiter(effective_timeout)
    try:
        self._bind(
            portal_name,
            state.name,
            bind_payload)  # network op
        self.last_query = state.query
        self.statement = state
    except Exception as exc:
        result_waiter.set_exception(exc)
        self._coreproto_error()
    finally:
        return await result_waiter
@cython.iterable_coroutine
async def execute(self, PreparedStatementState state,
                  str portal_name, int limit, return_extra,
                  timeout):
    """Run ``_execute`` on *portal_name* with *limit* and await the result.

    *state* is not sent; it is only recorded as the current statement and
    as the source of ``last_query``.  Errors while issuing the request are
    set on the waiter and re-raised by the final ``await``.
    """
    # Never interleave with an in-flight cancellation request.
    pending_cancel = self.cancel_waiter
    if pending_cancel is not None:
        await pending_cancel
    cancel_sent = self.cancel_sent_waiter
    if cancel_sent is not None:
        await cancel_sent
        self.cancel_sent_waiter = None

    self._check_state()
    effective_timeout = self._get_timeout_impl(timeout)

    result_waiter = self._new_waiter(effective_timeout)
    try:
        self._execute(
            portal_name,
            limit)  # network op
        self.last_query = state.query
        self.statement = state
        self.return_extra = return_extra
        self.queries_count += 1
    except Exception as exc:
        result_waiter.set_exception(exc)
        self._coreproto_error()
    finally:
        return await result_waiter
@cython.iterable_coroutine
async def query(self, query, timeout):
    """Send *query* via ``_simple_query`` and await the result.

    Errors while issuing the request are set on the waiter and re-raised
    by the final ``await``.
    """
    # Never interleave with an in-flight cancellation request.
    pending_cancel = self.cancel_waiter
    if pending_cancel is not None:
        await pending_cancel
    cancel_sent = self.cancel_sent_waiter
    if cancel_sent is not None:
        await cancel_sent
        self.cancel_sent_waiter = None

    self._check_state()
    # query() needs to call _get_timeout instead of _get_timeout_impl
    # for consistent validation, as it is called differently from
    # prepare/bind/execute methods.
    effective_timeout = self._get_timeout(timeout)

    result_waiter = self._new_waiter(effective_timeout)
    try:
        self._simple_query(query)  # network op
        self.last_query = query
        self.queries_count += 1
    except Exception as exc:
        result_waiter.set_exception(exc)
        self._coreproto_error()
    finally:
        return await result_waiter
@cython.iterable_coroutine
async def copy_out(self, copy_stmt, sink, timeout):
if self.cancel_waiter is not None:
await self.cancel_waiter
if self.cancel_sent_waiter is not None:
await self.cancel_sent_waiter
self.cancel_sent_waiter = None
self._check_state()
timeout = self._get_timeout_impl(timeout)
timer = Timer(timeout)
# The copy operation is guarded by a single timeout
# on the top level.
waiter = self._new_waiter(timer.get_remaining_budget())
self._copy_out(copy_stmt)
try:
while True:
self.resume_reading()
with timer:
buffer, done, status_msg = await waiter
# buffer will be empty if CopyDone was received apart from
# the last CopyData message.
Loading ...