Learn more  » Push, build, and install  RubyGems npm packages Python packages Maven artifacts PHP packages Go Modules Bower components Debian packages RPM packages NuGet packages

arrow-adbc-nightlies / adbc-driver-manager   python

Repository URL to install this package:

Version: 1.3.0 

/ dbapi.py

# Licensed to the Apache Software Foundation (ASF) under one
# or more contributor license agreements.  See the NOTICE file
# distributed with this work for additional information
# regarding copyright ownership.  The ASF licenses this file
# to you under the Apache License, Version 2.0 (the
# "License"); you may not use this file except in compliance
# with the License.  You may obtain a copy of the License at
#
#   http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing,
# software distributed under the License is distributed on an
# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
# KIND, either express or implied.  See the License for the
# specific language governing permissions and limitations
# under the License.

"""PEP 249 (DB-API 2.0) API wrapper for the ADBC Driver Manager.

Resource Management
===================

You must ``close()`` Connection and Cursor objects, or else driver
resources may be leaked.  A ``__del__`` is implemented as a fallback,
but Python does not guarantee the timing of when this is called.  For
development, ``__del__`` will raise a ResourceWarning when running
under pytest, or when the environment variable
``_ADBC_DRIVER_MANAGER_WARN_UNCLOSED_RESOURCE`` is set to ``1``.

"""

import abc
import datetime
import os
import threading
import time
import typing
import warnings
from typing import Any, Dict, List, Literal, Optional, Tuple, Union

try:
    import pyarrow
except ImportError as e:
    raise ImportError("PyArrow is required for the DBAPI-compatible interface") from e

try:
    import pyarrow.dataset
except ImportError:
    _pya_dataset = ()
    _pya_scanner = ()
else:
    _pya_dataset = (pyarrow.dataset.Dataset,)
    _pya_scanner = (pyarrow.dataset.Scanner,)

import adbc_driver_manager

from . import _lib, _reader
from ._lib import _blocking_call

if typing.TYPE_CHECKING:
    import pandas
    from typing_extensions import Self

# ----------------------------------------------------------
# Globals

#: The DB-API API level (2.0).
apilevel = "2.0"
#: The thread safety level (connections may not be shared).
#: Per PEP 249, level 1 means threads may share the module, but not
#: connections.
threadsafety = 1
#: The parameter style (qmark). This is hardcoded, but actually
#: depends on the driver.
paramstyle = "qmark"

# PEP 249 exception hierarchy, re-exported from the native ``_lib``
# module.  Note that binding ``Warning`` here shadows the builtin
# ``Warning`` everywhere else in this module (e.g. in a
# ``warnings.warn(..., category=Warning)`` call).
Warning = _lib.Warning
Error = _lib.Error
InterfaceError = _lib.InterfaceError
DatabaseError = _lib.DatabaseError
DataError = _lib.DataError
OperationalError = _lib.OperationalError
IntegrityError = _lib.IntegrityError
InternalError = _lib.InternalError
ProgrammingError = _lib.ProgrammingError
NotSupportedError = _lib.NotSupportedError

# Readable names for integer info codes.
# NOTE(review): presumably these mirror the ADBC_INFO_* codes from adbc.h
# (vendor codes 0-2, driver codes 100-103). Only this subset is named
# here; confirm against the code that consumes this table.
_KNOWN_INFO_VALUES = {
    0: "vendor_name",
    1: "vendor_version",
    2: "vendor_arrow_version",
    100: "driver_name",
    101: "driver_version",
    102: "driver_arrow_version",
    103: "driver_adbc_version",
}

# ----------------------------------------------------------
# Types

#: The type for date values.
Date = datetime.date
#: The type for time values.
Time = datetime.time
#: The type for timestamp values.
Timestamp = datetime.datetime


def DateFromTicks(ticks: int) -> Date:
    """Construct a date value from a count of seconds.

    ``ticks`` is a count of seconds since the epoch. It is converted with
    :func:`time.localtime`, so the result is the *local* calendar date.
    This follows the reference implementation given in PEP 249.
    """
    local = time.localtime(ticks)
    return Date(local.tm_year, local.tm_mon, local.tm_mday)


def TimeFromTicks(ticks: int) -> Time:
    """Construct a time value from a count of seconds.

    Only the local wall-clock hour, minute and second are kept. The date
    part is dropped, and :func:`time.localtime` truncates any fractional
    second. The result carries no tzinfo.
    """
    local = time.localtime(ticks)
    return Time(local.tm_hour, local.tm_min, local.tm_sec)


def TimestampFromTicks(ticks: int) -> Timestamp:
    """Construct a timestamp value from a count of seconds.

    The result is the naive local date and time with whole-second
    resolution, as produced by :func:`time.localtime`.
    """
    local = time.localtime(ticks)
    return Timestamp(
        local.tm_year,
        local.tm_mon,
        local.tm_mday,
        local.tm_hour,
        local.tm_min,
        local.tm_sec,
    )


class _TypeSet(frozenset):
    """A set of PyArrow type IDs that compares equal to subsets of self."""

    def __eq__(self, other: Any) -> bool:
        if isinstance(other, _TypeSet):
            return not (other - self)
        elif isinstance(other, pyarrow.DataType):
            return other.id in self
        return False


# DB-API 2.0 type objects. Each is a _TypeSet of pyarrow type IDs, so
# e.g. ``STRING == pyarrow.large_string()`` is true because that type's
# ``.id`` is a member of the set.

#: The type of binary columns.
BINARY = _TypeSet({pyarrow.binary().id, pyarrow.large_binary().id})
# Only ``.id`` is stored. The unit arguments below presumably just
# satisfy the constructors (pyarrow type IDs do not encode the unit), so
# time32/time64/timestamp columns of any unit should match. Confirm if
# this matters.
#: The type of datetime columns.
DATETIME = _TypeSet(
    [
        pyarrow.date32().id,
        pyarrow.date64().id,
        pyarrow.time32("s").id,
        pyarrow.time64("ns").id,
        pyarrow.timestamp("s").id,
    ]
)
# NOTE(review): decimal128/decimal256 and float16 are not listed, so such
# columns will not compare equal to NUMBER. Confirm this is intended.
#: The type of numeric columns.
NUMBER = _TypeSet(
    [
        pyarrow.int8().id,
        pyarrow.int16().id,
        pyarrow.int32().id,
        pyarrow.int64().id,
        pyarrow.uint8().id,
        pyarrow.uint16().id,
        pyarrow.uint32().id,
        pyarrow.uint64().id,
        pyarrow.float32().id,
        pyarrow.float64().id,
    ]
)
#: The type of "row ID" columns.
ROWID = _TypeSet([pyarrow.int64().id])
#: The type of string columns.
STRING = _TypeSet([pyarrow.string().id, pyarrow.large_string().id])

# ----------------------------------------------------------
# Functions


def connect(
    *,
    driver: str,
    entrypoint: Optional[str] = None,
    db_kwargs: Optional[Dict[str, str]] = None,
    conn_kwargs: Optional[Dict[str, str]] = None,
    autocommit: bool = False,
) -> "Connection":
    """
    Connect to a database via ADBC.

    Parameters
    ----------
    driver
        The driver name. For example, "adbc_driver_sqlite" will
        attempt to load libadbc_driver_sqlite.so on Linux systems,
        libadbc_driver_sqlite.dylib on MacOS, and
        adbc_driver_sqlite.dll on Windows. This may also be a path to
        the library to load.
    entrypoint
        The driver-specific entrypoint, if different than the default.
        An empty string is treated the same as None.
    db_kwargs
        Key-value parameters to pass to the driver to initialize the
        database. The dict is copied; the caller's dict is not modified.
        ``driver`` and ``entrypoint`` override same-named keys in it.
    conn_kwargs
        Key-value parameters to pass to the driver to initialize the
        connection.
    autocommit
        Whether to enable autocommit.  For compliance with DB-API,
        this is disabled by default.  A warning will be emitted if it
        cannot be disabled.

    Returns
    -------
    Connection
        The new connection. It owns the database and connection handles
        opened here and releases them when closed.

    Raises
    ------
    Exception
        Any error raised while opening the database, opening the
        connection, or constructing the Connection propagates. Before it
        does, every handle opened so far is closed.
    """
    db = None
    conn = None

    db_kwargs = dict(db_kwargs or {})
    db_kwargs["driver"] = driver
    if entrypoint:
        db_kwargs["entrypoint"] = entrypoint
    if conn_kwargs is None:
        conn_kwargs = {}

    try:
        db = _lib.AdbcDatabase(**db_kwargs)
        conn = _lib.AdbcConnection(db, **conn_kwargs)
        return Connection(db, conn, conn_kwargs, autocommit=autocommit)
    except Exception:
        # Release whatever was opened, innermost first.  The nested
        # try/finally guarantees the database is released even if closing
        # the connection fails during cleanup.
        try:
            if conn is not None:
                conn.close()
        finally:
            if db is not None:
                db.close()
        raise


# ----------------------------------------------------------
# Classes


class _Closeable(abc.ABC):
    """Base class providing context manager interface."""

    def __enter__(self) -> "Self":
        return self

    def __exit__(self, exc_type, exc_val, exc_tb) -> None:
        self.close()

    @abc.abstractmethod
    def close(self) -> None: ...


class _SharedDatabase(_Closeable):
    """
    Reference-counted owner of a single AdbcDatabase.

    The holder starts with one reference. Each ``clone()`` adds one
    reference and must be balanced by a ``close()``. The underlying
    database is closed only when the count drops to exactly zero, so
    surplus ``close()`` calls drive the count negative without closing
    it again. Count updates are serialized with a lock.
    """

    def __init__(self, db: _lib.AdbcDatabase) -> None:
        self._db = db
        self._lock = threading.Lock()
        # The creator holds the first reference.
        self._refcount = 1

    def clone(self) -> "Self":
        """Take another reference and return this same holder."""
        self._inc()
        return self

    def close(self) -> None:
        """Drop one reference; close the database once none remain."""
        remaining = self._dec()
        if remaining == 0:
            self._db.close()

    def _inc(self) -> None:
        """Increment the reference count under the lock."""
        with self._lock:
            self._refcount = self._refcount + 1

    def _dec(self) -> int:
        """Decrement the reference count under the lock; return the new value."""
        with self._lock:
            remaining = self._refcount - 1
            self._refcount = remaining
        return remaining


class Connection(_Closeable):
    """
    A DB-API 2.0 (PEP 249) connection.

    Do not create this object directly; use connect().
    """

    # Optional extension: expose exception classes on Connection
    # (PEP 249 "Connection.Error" etc.). Code that only holds a connection
    # can then write ``except conn.Error:``. These are the same
    # ``_lib`` classes the module itself re-exports.
    Warning = _lib.Warning
    Error = _lib.Error
    InterfaceError = _lib.InterfaceError
    DatabaseError = _lib.DatabaseError
    DataError = _lib.DataError
    OperationalError = _lib.OperationalError
    IntegrityError = _lib.IntegrityError
    InternalError = _lib.InternalError
    ProgrammingError = _lib.ProgrammingError
    NotSupportedError = _lib.NotSupportedError

    def __init__(
        self,
        db: Union[_lib.AdbcDatabase, _SharedDatabase],
        conn: _lib.AdbcConnection,
        conn_kwargs: Optional[Dict[str, str]] = None,
        *,
        autocommit: bool = False,
    ) -> None:
        """
        Wrap an already-opened ADBC connection.

        Parameters
        ----------
        db
            The database ``conn`` was opened from.  A bare AdbcDatabase
            is wrapped in a new _SharedDatabase.  An existing
            _SharedDatabase gains one more reference via ``clone()``.
        conn
            The open connection.  This object closes it in ``close()``.
        conn_kwargs
            The options ``conn`` was opened with.  They are kept, uncopied,
            as ``self._conn_kwargs``.
        autocommit
            Whether autocommit should end up enabled.  If the driver
            cannot turn autocommit off, the connection stays in autocommit
            mode regardless and ``commit()`` becomes a no-op.  In that case
            a warning is emitted unless autocommit was requested anyway.
        """
        # NOTE(review): presumably assigned first so that ``_closed`` exists
        # even if the rest of __init__ raises (e.g. for the ``__del__``
        # fallback mentioned in the module docstring). Confirm.
        self._closed = False
        if isinstance(db, _SharedDatabase):
            self._db = db.clone()
        else:
            self._db = _SharedDatabase(db)
        self._conn = conn
        self._conn_kwargs = conn_kwargs

        # Turning autocommit off doubles as a probe for transaction
        # support.  A driver that raises NotSupportedError here cannot do
        # manual commits at all.
        try:
            self._conn.set_autocommit(False)
        except _lib.NotSupportedError:
            self._commit_supported = False
            if not autocommit:
                # Only warn when the caller wanted DB-API's default
                # (autocommit off) and it could not be honoured.
                # ``Warning`` resolves to the module-level DB-API alias
                # (_lib.Warning), not the builtin.
                warnings.warn(
                    "Cannot disable autocommit; conn will not be DB-API 2.0 compliant",
                    category=Warning,
                )
            self._autocommit = True
        else:
            self._autocommit = False
            self._commit_supported = True

        # Autocommit was switched off above for the probe; switch it back
        # on if the caller asked for it.
        if autocommit and self._commit_supported:
            self._conn.set_autocommit(True)
            self._autocommit = True

    def close(self) -> None:
        """
        Close the connection and release its reference to the database.

        Calling this again after a successful close does nothing.

        Warnings
        --------
        Failure to close a connection may leak memory or database
        connections.
        """
        if not self._closed:
            # The connection goes first, then this object's database
            # reference.  The flag is only set once both succeed, so a
            # close that raised can be retried.
            self._conn.close()
            self._db.close()
            self._closed = True

    def commit(self) -> None:
        """
        Explicitly commit the current transaction.

        This does nothing when the driver refused to disable autocommit,
        since the connection then remains in autocommit mode.
        """
        if not self._commit_supported:
            return
        self._conn.commit()

    def cursor(self) -> "Cursor":
        """
        Create a new cursor for querying the database.

        Every call returns a fresh Cursor bound to this connection.
        """
        new_cursor = Cursor(self)
        return new_cursor

    def rollback(self) -> None:
Loading ...