# Copyright (C) 2016-present the asyncpg authors and contributors
# <see AUTHORS file>
#
# This module is part of asyncpg and is released under
# the Apache 2.0 License: http://www.apache.org/licenses/LICENSE-2.0
import asyncio
import asyncpg
import collections
import collections.abc
import itertools
import struct
import sys
import time
import traceback
import warnings
from . import compat
from . import connect_utils
from . import cursor
from . import exceptions
from . import introspection
from . import prepared_stmt
from . import protocol
from . import serverversion
from . import transaction
from . import utils
class ConnectionMeta(type):
    """Metaclass letting ``isinstance()`` recognise connection proxies.

    The check ignores which class using this metaclass is being tested.
    An object passes if ``Connection`` or ``_ConnectionProxy`` appears
    anywhere in the MRO of its type.
    """

    def __instancecheck__(cls, instance):
        type_mro = type(instance).__mro__
        if Connection in type_mro:
            return True
        return _ConnectionProxy in type_mro
class Connection(metaclass=ConnectionMeta):
"""A representation of a database session.
Connections are created by calling :func:`~asyncpg.connection.connect`.
"""
__slots__ = ('_protocol', '_transport', '_loop',
'_top_xact', '_aborted',
'_pool_release_ctr', '_stmt_cache', '_stmts_to_close',
'_listeners', '_server_version', '_server_caps',
'_intro_query', '_reset_query', '_proxy',
'_stmt_exclusive_section', '_config', '_params', '_addr',
'_log_listeners', '_cancellations', '_source_traceback',
'__weakref__')
def __init__(self, protocol, transport, loop,
             addr: (str, int) or str,
             config: connect_utils._ClientConfiguration,
             params: connect_utils._ConnectionParameters):
    """Set up per-connection state around an already-built protocol.

    :param protocol: Protocol object; its settings supply the server
                     version string used to detect server capabilities.
    :param transport: Transport the protocol runs over (stored as-is).
    :param loop: Event loop; also handed to the statement cache, and its
                 debug flag decides whether a creation stack is captured.
    :param addr: Address the connection was made to (stored as-is).
    :param config: Client configuration; supplies the statement cache's
                   maximum size and maximum statement lifetime.
    :param params: Connection parameters (stored as-is).
    """
    self._protocol = protocol
    self._transport = transport
    self._loop = loop
    self._top_xact = None
    self._aborted = False

    # Bumped each time the connection is handed back to a pool, so that
    # resources tied to an earlier checkout (e.g. explicit prepared
    # statements) can detect use after release.
    self._pool_release_ctr = 0

    self._addr, self._config, self._params = addr, config, params

    self._stmt_cache = _StatementCache(
        loop=loop,
        max_size=config.statement_cache_size,
        on_remove=self._maybe_gc_stmt,
        max_lifetime=config.max_cached_statement_lifetime)
    self._stmts_to_close = set()

    self._listeners = {}
    self._log_listeners = set()
    self._cancellations = set()

    proto_settings = self._protocol.get_settings()
    parsed_version = serverversion.split_server_version_string(
        proto_settings.server_version)
    self._server_version = parsed_version
    self._server_caps = _detect_server_capabilities(
        parsed_version, proto_settings)

    self._intro_query = introspection.INTRO_LOOKUP_TYPES
    self._reset_query = None
    self._proxy = None

    # Serializes operations that may involve the anonymous statement, so
    # that "prepare an anonymous statement" followed by "use the
    # statement" happens atomically.  Meant for `con.fetchval()`,
    # `con.fetch()`, `con.fetchrow()`, `con.execute()` and
    # `con.executemany()`.
    self._stmt_exclusive_section = _Atomic()

    # The creation stack is only captured in asyncio debug mode.
    self._source_traceback = _extract_stack() if loop.get_debug() else None
def __del__(self):
if not self.is_closed() and self._protocol is not None:
if self._source_traceback:
msg = "unclosed connection {!r}; created at:\n {}".format(
self, self._source_traceback)
else:
msg = (
"unclosed connection {!r}; run in asyncio debug "
"mode to show the traceback of connection "
"origin".format(self)
)
warnings.warn(msg, ResourceWarning)
if not self._loop.is_closed():
self.terminate()
async def add_listener(self, channel, callback):
    """Add a listener for Postgres notifications.

    ``LISTEN`` is only issued when no callback is registered for
    *channel* yet. Later callbacks for the same channel are simply added
    to the existing set.

    :param str channel: Channel to listen on.

    :param callable callback:
        A callable receiving the following arguments:
        **connection**: a Connection the callback is registered with;
        **pid**: PID of the Postgres server that sent the notification;
        **channel**: name of the channel the notification was sent to;
        **payload**: the payload.
    """
    self._check_open()
    if channel not in self._listeners:
        await self.fetch('LISTEN {}'.format(utils._quote_ident(channel)))
    # While LISTEN was being awaited, a concurrent add_listener() call for
    # the same channel may already have created the callback set.
    # setdefault() keeps that set instead of replacing it with an empty
    # one, which would silently drop the other caller's callback.
    self._listeners.setdefault(channel, set()).add(callback)
async def remove_listener(self, channel, callback):
    """Remove a listening callback on the specified channel.

    This is a no-op when the connection is closed, or when *callback* is
    not registered for *channel*. Removing the last callback of a
    channel forgets the channel and issues ``UNLISTEN`` for it.
    """
    if self.is_closed():
        return
    registered = self._listeners.get(channel)
    if registered is None or callback not in registered:
        return
    registered.remove(callback)
    if registered:
        # Other callbacks still listen on this channel.
        return
    del self._listeners[channel]
    await self.fetch('UNLISTEN {}'.format(utils._quote_ident(channel)))
def add_log_listener(self, callback):
    """Add a listener for Postgres log messages.

    It will be called when an asynchronous NoticeResponse is received
    from the connection.  Possible message types are: WARNING, NOTICE,
    DEBUG, INFO, or LOG.

    :param callable callback:
        A callable receiving the following arguments:
        **connection**: a Connection the callback is registered with;
        **message**: the `exceptions.PostgresLogMessage` message.

    :raises exceptions.InterfaceError: if the connection is closed.

    .. versionadded:: 0.12.0
    """
    if not self.is_closed():
        self._log_listeners.add(callback)
        return
    raise exceptions.InterfaceError('connection is closed')
def remove_log_listener(self, callback):
    """Remove a listening callback for log messages.

    Removing a callback that is not registered is silently ignored.

    .. versionadded:: 0.12.0
    """
    registered = self._log_listeners
    registered.discard(callback)
def get_server_pid(self):
    """Return the PID of the Postgres server the connection is bound to.

    The value is obtained from the underlying protocol object.
    """
    proto = self._protocol
    return proto.get_server_pid()
def get_server_version(self):
    """Return the version of the connected PostgreSQL server.

    The version is parsed once, when the connection is created.  The
    returned value is a named tuple similar to that in
    ``sys.version_info``:

    .. code-block:: pycon

        >>> con.get_server_version()
        ServerVersion(major=9, minor=6, micro=1,
                      releaselevel='final', serial=0)

    .. versionadded:: 0.8.0
    """
    version_info = self._server_version
    return version_info
def get_settings(self):
    """Return connection settings.

    The settings object comes straight from the underlying protocol.

    :return: :class:`~asyncpg.ConnectionSettings`.
    """
    proto = self._protocol
    return proto.get_settings()
def transaction(self, *, isolation='read_committed', readonly=False,
                deferrable=False):
    """Create a :class:`~transaction.Transaction` object.

    The connection must be open.  The transaction is only created here;
    nothing is sent to the server by this call.  Refer to
    `PostgreSQL documentation`_ on the meaning of transaction parameters.

    :param isolation: Transaction isolation mode, can be one of:
                      `'serializable'`, `'repeatable_read'`,
                      `'read_committed'`.

    :param readonly: Specifies whether or not this transaction is
                     read-only.

    :param deferrable: Specifies whether or not this transaction is
                       deferrable.

    .. _`PostgreSQL documentation`:
            https://www.postgresql.org/docs/
            current/static/sql-set-transaction.html
    """
    self._check_open()
    xact_factory = transaction.Transaction
    return xact_factory(self, isolation, readonly, deferrable)
def is_in_transaction(self):
    """Return True if Connection is currently inside a transaction.

    The answer is delegated to the underlying protocol object.

    :return bool: True if inside transaction, False otherwise.

    .. versionadded:: 0.16.0
    """
    proto = self._protocol
    return proto.is_in_transaction()
async def execute(self, query: str, *args, timeout: float=None) -> str:
    """Execute an SQL command (or commands).

    Without arguments the query is passed to the protocol directly.  In
    that form several SQL commands can be executed at once.  With
    arguments the query goes through the extended-query path.

    Example:

    .. code-block:: pycon

        >>> await con.execute('''
        ...     CREATE TABLE mytab (a int);
        ...     INSERT INTO mytab (a) VALUES (100), (200), (300);
        ... ''')
        INSERT 0 3

        >>> await con.execute('''
        ...     INSERT INTO mytab (a) VALUES ($1), ($2)
        ... ''', 10, 20)
        INSERT 0 2

    :param args: Query arguments.
    :param float timeout: Optional timeout value in seconds.

    :return str: Status of the last SQL command.

    .. versionchanged:: 0.5.4
       Made it possible to pass query arguments.
    """
    self._check_open()
    if args:
        # The status tag arrives as bytes from the extended-query path.
        _, status_tag, _ = await self._execute(query, args, 0, timeout, True)
        return status_tag.decode()
    return await self._protocol.query(query, timeout)
async def executemany(self, command: str, args, *, timeout: float=None):
    """Execute an SQL *command* for each sequence of arguments in *args*.

    Example:

    .. code-block:: pycon

        >>> await con.executemany('''
        ...     INSERT INTO mytab (a, b, c) VALUES ($1, $2, $3);
        ... ''', [(1, 2, 3), (4, 5, 6)])

    :param command: Command to execute.
    :param args: An iterable containing sequences of arguments.
    :param float timeout: Optional timeout value in seconds.
    :return None: This method discards the results of the operations.

    .. note::

       When inserting a large number of rows,
       use :meth:`Connection.copy_records_to_table()` instead,
       it is much more efficient for this purpose.

    .. versionadded:: 0.7.0

    .. versionchanged:: 0.11.0
       `timeout` became a keyword-only parameter.
    """
    self._check_open()
    outcome = await self._executemany(command, args, timeout)
    return outcome
async def _get_statement(self, query, timeout, *, named: bool=False,
use_cache: bool=True):
if use_cache:
statement = self._stmt_cache.get(query)
if statement is not None:
return statement
# Only use the cache when:
# * `statement_cache_size` is greater than 0;
# * query size is less than `max_cacheable_statement_size`.
use_cache = self._stmt_cache.get_max_size() > 0
if (use_cache and
self._config.max_cacheable_statement_size and
len(query) > self._config.max_cacheable_statement_size):
use_cache = False
if use_cache or named:
stmt_name = self._get_unique_id('stmt')
else:
stmt_name = ''
statement = await self._protocol.prepare(stmt_name, query, timeout)
need_reprepare = False
types_with_missing_codecs = statement._init_types()
tries = 0
while types_with_missing_codecs:
settings = self._protocol.get_settings()
# Introspect newly seen types and populate the
# codec cache.
types, intro_stmt = await self._introspect_types(
types_with_missing_codecs, timeout)
settings.register_data_types(types)
# The introspection query has used an anonymous statement,
# which has blown away the anonymous statement we've prepared
# for the query, so we need to re-prepare it.
Loading ...