Why Gemfury? Push, build, and install  RubyGems npm packages Python packages Maven artifacts PHP packages Go Modules Debian packages RPM packages NuGet packages

Repository URL to install this package:

Details    
Size: Mime:
"""
Transaction context managers returned by Connection.transaction() and
AsyncConnection.transaction()
"""

# Copyright (C) 2020 The Psycopg Team

import logging

from types import TracebackType
from typing import Generic, Iterator, Optional, Type, Union, TypeVar, TYPE_CHECKING

from . import pq
from . import sql
from . import errors as e
from .abc import ConnectionType, PQGen
from .pq.misc import connection_summary

if TYPE_CHECKING:
    from typing import Any
    from .connection import Connection
    from .connection_async import AsyncConnection

logger = logging.getLogger(__name__)

# Shorthands for the libpq status values this module compares against.
IDLE = pq.TransactionStatus.IDLE  # no transaction block open on the connection
OK = pq.ConnStatus.OK  # connection in good state; safe to send commands


class Rollback(Exception):
    """
    Leave the current `Transaction` block immediately, rolling back the
    changes made inside it.

    If a transaction is passed to the constructor, every enclosing
    transaction block up to and including that one is rolled back; that
    block then swallows the exception.
    """

    __module__ = "psycopg"

    def __init__(
        self,
        transaction: Union["Transaction", "AsyncTransaction", None] = None,
    ):
        # The block to unwind to; None means "only the innermost block".
        self.transaction = transaction

    def __repr__(self) -> str:
        cls_name = type(self).__qualname__
        return "%s(%r)" % (cls_name, self.transaction)


class OutOfOrderTransactionNesting(e.ProgrammingError):
    """
    Out-of-order transaction nesting detected.

    Raised when a transaction block is committed or rolled back at a nesting
    level different from the one it was entered at, for instance when an
    outer block is exited while an inner one is still open.
    """


class BaseTransaction(Generic[ConnectionType]):
    """
    Logic shared by `Transaction` and `AsyncTransaction`.

    The enter/exit steps are implemented as `PQGen` generators so that both
    the sync and the async subclass can run them through their connection's
    ``wait()`` method.

    Nesting is tracked on the connection: each entered block increments
    ``connection._num_transactions`` and remembers the previous value in
    ``_stack_index``, so that exiting at the wrong level can be detected.
    """

    def __init__(
        self,
        connection: ConnectionType,
        savepoint_name: Optional[str] = None,
        force_rollback: bool = False,
    ):
        self._conn = connection
        self.pgconn = self._conn.pgconn
        # Empty string means "no savepoint"; inner blocks get an automatic
        # name when they are entered (see _push_savepoint).
        self._savepoint_name = savepoint_name or ""
        # If true, roll back on exit even if the block completes normally.
        self.force_rollback = force_rollback
        self._entered = self._exited = False
        # True if the block was entered with the connection idle: this block
        # then issues the BEGIN and the final COMMIT/ROLLBACK.
        self._outer_transaction = False
        # Value of connection._num_transactions before this block was entered.
        self._stack_index = -1

    @property
    def savepoint_name(self) -> Optional[str]:
        """
        The name of the savepoint; empty if the block doesn't use one.

        Only the outermost block, handling the main transaction, can be
        without a savepoint: nested blocks always get a name when entered.
        """
        # Yes, it may change on __enter__. No, I don't care, because the
        # un-entered state is outside the public interface.
        return self._savepoint_name

    def __repr__(self) -> str:
        cls = f"{self.__class__.__module__}.{self.__class__.__qualname__}"
        info = connection_summary(self.pgconn)
        if not self._entered:
            status = "inactive"
        elif not self._exited:
            status = "active"
        else:
            status = "terminated"

        sp = f"{self.savepoint_name!r} " if self.savepoint_name else ""
        return f"<{cls} {sp}({status}) {info} at 0x{id(self):x}>"

    def _enter_gen(self) -> PQGen[None]:
        """
        Start the block, sending BEGIN and/or SAVEPOINT as needed.

        :raise TypeError: if the object was already entered once.
        """
        if self._entered:
            raise TypeError("transaction blocks can be used only once")
        self._entered = True

        self._push_savepoint()
        try:
            for command in self._get_enter_commands():
                yield from self._conn._exec_command(command)
        except Exception:
            # If entering fails, the exit method is never called: undo the
            # push here, otherwise the connection keeps counting a block that
            # is not active (and e.g. the assertion in _push_savepoint would
            # fail for the next outer block).
            self._conn._num_transactions -= 1
            self._exited = True
            raise

    def _exit_gen(
        self,
        exc_type: Optional[Type[BaseException]],
        exc_val: Optional[BaseException],
        exc_tb: Optional[TracebackType],
    ) -> PQGen[bool]:
        """
        Terminate the block: commit on success, roll back on error or if
        `force_rollback` is set.

        Return True if the exception raised in the block must be swallowed,
        which only happens for a `Rollback` targeting this block.
        """
        # Compare with None: an exception instance might define a false
        # truth value, and it must not be mistaken for a clean exit.
        if exc_val is None and not self.force_rollback:
            yield from self._commit_gen()
            return False
        else:
            # try to rollback, but if there are problems (connection in a bad
            # state) just warn without clobbering the exception bubbling up.
            try:
                return (yield from self._rollback_gen(exc_val))
            except OutOfOrderTransactionNesting:
                # Clobber an exception happened in the block with the exception
                # caused by out-of-order transaction detected, so make the
                # behaviour consistent with _commit_gen and to make sure the
                # user fixes this condition, which is unrelated from
                # operational error that might arise in the block.
                raise
            except Exception as exc2:
                logger.warning("error ignored in rollback of %s: %s", self, exc2)
                return False

    def _commit_gen(self) -> PQGen[None]:
        """
        Commit the block (RELEASE the savepoint and/or COMMIT).

        :raise OutOfOrderTransactionNesting: if exiting at the wrong level.
        """
        ex = self._pop_savepoint("commit")
        self._exited = True
        if ex:
            raise ex

        for command in self._get_commit_commands():
            yield from self._conn._exec_command(command)

    def _rollback_gen(self, exc_val: Optional[BaseException]) -> PQGen[bool]:
        """
        Roll back the block (ROLLBACK TO the savepoint and/or ROLLBACK).

        Return True if *exc_val* is a `Rollback` aimed at this block (or at
        no block in particular), meaning the exception must be swallowed.

        :raise OutOfOrderTransactionNesting: if exiting at the wrong level.
        """
        if isinstance(exc_val, Rollback):
            # Pass the exception explicitly rather than relying on
            # sys.exc_info(), which is implicit inside a generator.
            logger.debug(
                "%s: Explicit rollback from: ", self._conn, exc_info=exc_val
            )

        ex = self._pop_savepoint("rollback")
        self._exited = True
        if ex:
            raise ex

        for command in self._get_rollback_commands():
            yield from self._conn._exec_command(command)

        if isinstance(exc_val, Rollback):
            if exc_val.transaction is None or exc_val.transaction is self:
                return True  # Swallow the exception

        return False

    def _get_enter_commands(self) -> Iterator[bytes]:
        """
        Yield the commands starting the block: the connection's transaction
        start command if this is the outermost block, then SAVEPOINT if the
        block has a savepoint name.
        """
        if self._outer_transaction:
            yield self._conn._get_tx_start_command()

        if self._savepoint_name:
            yield (
                sql.SQL("SAVEPOINT {}")
                .format(sql.Identifier(self._savepoint_name))
                .as_bytes(self._conn)
            )

    def _get_commit_commands(self) -> Iterator[bytes]:
        """
        Yield the commands committing the block: RELEASE for a nested block,
        COMMIT for the outermost one (which also discards its savepoint).
        """
        if self._savepoint_name and not self._outer_transaction:
            yield (
                sql.SQL("RELEASE {}")
                .format(sql.Identifier(self._savepoint_name))
                .as_bytes(self._conn)
            )

        if self._outer_transaction:
            assert not self._conn._num_transactions
            yield b"COMMIT"

    def _get_rollback_commands(self) -> Iterator[bytes]:
        """
        Yield the commands rolling back the block: ROLLBACK TO + RELEASE for
        a nested block, ROLLBACK for the outermost one, followed by any
        maintenance command required after clearing the prepared statements.
        """
        if self._savepoint_name and not self._outer_transaction:
            yield (
                sql.SQL("ROLLBACK TO {n}")
                .format(n=sql.Identifier(self._savepoint_name))
                .as_bytes(self._conn)
            )
            yield (
                sql.SQL("RELEASE {n}")
                .format(n=sql.Identifier(self._savepoint_name))
                .as_bytes(self._conn)
            )

        if self._outer_transaction:
            assert not self._conn._num_transactions
            yield b"ROLLBACK"

        # Also clear the prepared statements cache. A true return value from
        # clear() presumably means server-side state needs cleaning up too.
        if self._conn._prepared.clear():
            yield from self._conn._prepared.get_maintenance_commands()

    def _push_savepoint(self) -> None:
        """
        Push the transaction on the connection transactions stack.

        Also set the internal state of the object and verify consistency.
        """
        self._outer_transaction = self.pgconn.transaction_status == IDLE
        if self._outer_transaction:
            # outer transaction: if no name it's only a begin, else
            # there will be an additional savepoint
            assert not self._conn._num_transactions
        else:
            # inner transaction: it always has a name
            if not self._savepoint_name:
                self._savepoint_name = f"_pg3_{self._conn._num_transactions + 1}"

        self._stack_index = self._conn._num_transactions
        self._conn._num_transactions += 1

    def _pop_savepoint(self, action: str) -> Optional[Exception]:
        """
        Pop the transaction from the connection transactions stack.

        Return an `OutOfOrderTransactionNesting` instead of raising it if the
        block is not the innermost one, so the caller can update its own
        state before raising; return None if the nesting is consistent.
        """
        self._conn._num_transactions -= 1
        if self._conn._num_transactions == self._stack_index:
            return None

        return OutOfOrderTransactionNesting(
            f"transaction {action} at the wrong nesting level: {self}"
        )


class Transaction(BaseTransaction["Connection[Any]"]):
    """
    Context manager returned by `Connection.transaction()` to handle a
    transaction block.
    """

    __module__ = "psycopg"

    _Self = TypeVar("_Self", bound="Transaction")

    @property
    def connection(self) -> "Connection[Any]":
        """The connection the object is managing."""
        return self._conn

    def __enter__(self: _Self) -> _Self:
        with self._conn.lock:
            gen = self._enter_gen()
            self._conn.wait(gen)
        return self

    def __exit__(
        self,
        exc_type: Optional[Type[BaseException]],
        exc_val: Optional[BaseException],
        exc_tb: Optional[TracebackType],
    ) -> bool:
        # Don't try to terminate the block on a connection in bad state:
        # just let any exception propagate.
        if self.pgconn.status != OK:
            return False

        with self._conn.lock:
            gen = self._exit_gen(exc_type, exc_val, exc_tb)
            return self._conn.wait(gen)


class AsyncTransaction(BaseTransaction["AsyncConnection[Any]"]):
    """
    Async context manager returned by `AsyncConnection.transaction()` to
    handle a transaction block.
    """

    __module__ = "psycopg"

    _Self = TypeVar("_Self", bound="AsyncTransaction")

    @property
    def connection(self) -> "AsyncConnection[Any]":
        """The connection the object is managing."""
        return self._conn

    async def __aenter__(self: _Self) -> _Self:
        async with self._conn.lock:
            gen = self._enter_gen()
            await self._conn.wait(gen)
        return self

    async def __aexit__(
        self,
        exc_type: Optional[Type[BaseException]],
        exc_val: Optional[BaseException],
        exc_tb: Optional[TracebackType],
    ) -> bool:
        # Don't try to terminate the block on a connection in bad state:
        # just let any exception propagate.
        if self.pgconn.status != OK:
            return False

        async with self._conn.lock:
            gen = self._exit_gen(exc_type, exc_val, exc_tb)
            return await self._conn.wait(gen)