Why Gemfury? Push, build, and install  RubyGems npm packages Python packages Maven artifacts PHP packages Go Modules Bower components Debian packages RPM packages NuGet packages

arrow-adbc-nightlies / adbc-driver-manager   python

Repository URL to install this package:

Version: 1.5.0 

/ dbapi.py

# Licensed to the Apache Software Foundation (ASF) under one
# or more contributor license agreements.  See the NOTICE file
# distributed with this work for additional information
# regarding copyright ownership.  The ASF licenses this file
# to you under the Apache License, Version 2.0 (the
# "License"); you may not use this file except in compliance
# with the License.  You may obtain a copy of the License at
#
#   http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing,
# software distributed under the License is distributed on an
# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
# KIND, either express or implied.  See the License for the
# specific language governing permissions and limitations
# under the License.

"""PEP 249 (DB-API 2.0) API wrapper for the ADBC Driver Manager.

Resource Management
===================

You must ``close()`` Connection and Cursor objects, or else driver
resources may be leaked.  A ``__del__`` is implemented as a fallback,
but Python does not guarantee the timing of when this is called.  For
development, ``__del__`` will raise a ResourceWarning when running
under pytest, or when the environment variable
``_ADBC_DRIVER_MANAGER_WARN_UNCLOSED_RESOURCE`` is set to ``1``.

"""

import abc
import datetime
import os
import threading
import time
import typing
import warnings
from typing import Any, Dict, List, Literal, Optional, Tuple, Union

try:
    import pyarrow
except ImportError as e:
    raise ImportError("PyArrow is required for the DBAPI-compatible interface") from e

try:
    import pyarrow.dataset
except ImportError:
    _pya_dataset = ()
    _pya_scanner = ()
else:
    _pya_dataset = (pyarrow.dataset.Dataset,)
    _pya_scanner = (pyarrow.dataset.Scanner,)

import adbc_driver_manager

from . import _lib, _reader
from ._lib import _blocking_call

if typing.TYPE_CHECKING:
    import pandas
    from typing_extensions import Self

# ----------------------------------------------------------
# Globals

#: The DB-API API level (2.0).
apilevel = "2.0"
#: The thread safety level (connections may not be shared).
threadsafety = 1
#: The parameter style (qmark). This is hardcoded, but actually
#: depends on the driver.
paramstyle = "qmark"

# Re-export the exception hierarchy from ``_lib`` under the module-level
# names that PEP 249 requires.
Warning = _lib.Warning
Error = _lib.Error
InterfaceError = _lib.InterfaceError
DatabaseError = _lib.DatabaseError
DataError = _lib.DataError
OperationalError = _lib.OperationalError
IntegrityError = _lib.IntegrityError
InternalError = _lib.InternalError
ProgrammingError = _lib.ProgrammingError
NotSupportedError = _lib.NotSupportedError

# Integer info codes mapped to readable keys.  Connection.adbc_get_info()
# uses this to translate the ``info_name`` column; codes that are not listed
# here are returned unchanged as integers.
_KNOWN_INFO_VALUES = {
    0: "vendor_name",
    1: "vendor_version",
    2: "vendor_arrow_version",
    100: "driver_name",
    101: "driver_version",
    102: "driver_arrow_version",
    103: "driver_adbc_version",
}

# ----------------------------------------------------------
# Types

#: The type for date values (alias of ``datetime.date``).
Date = datetime.date
#: The type for time values (alias of ``datetime.time``).
Time = datetime.time
#: The type for timestamp values (alias of ``datetime.datetime``).
Timestamp = datetime.datetime


def DateFromTicks(ticks: int) -> Date:
    """Build a date from seconds since the epoch, in local time."""
    # Same conversion as the reference implementation in PEP 249
    local = time.localtime(ticks)
    return Date(local.tm_year, local.tm_mon, local.tm_mday)


def TimeFromTicks(ticks: int) -> Time:
    """Build a time of day from seconds since the epoch, in local time."""
    local = time.localtime(ticks)
    return Time(local.tm_hour, local.tm_min, local.tm_sec)


def TimestampFromTicks(ticks: int) -> Timestamp:
    """Build a timestamp from seconds since the epoch, in local time."""
    local = time.localtime(ticks)
    return Timestamp(
        local.tm_year,
        local.tm_mon,
        local.tm_mday,
        local.tm_hour,
        local.tm_min,
        local.tm_sec,
    )


class _TypeSet(frozenset):
    """A set of PyArrow type IDs that compares equal to subsets of self.

    Comparison rules:

    - against another ``_TypeSet``: equal when every ID in ``other`` is
      also in ``self`` (i.e. ``other`` is a subset of ``self``);
    - against a ``pyarrow.DataType``: equal when the type's ``id`` is a
      member of ``self``;
    - against anything else: never equal.

    Because ``__eq__`` is overridden without ``__hash__``, instances are
    unhashable.
    """

    def __eq__(self, other: Any) -> bool:
        if isinstance(other, _TypeSet):
            return not (other - self)
        elif isinstance(other, pyarrow.DataType):
            return other.id in self
        return False

    def __ne__(self, other: Any) -> bool:
        # frozenset defines its own __ne__, so without this override ``!=``
        # would use plain set inequality and could disagree with ``==``.
        return not self.__eq__(other)


#: The type of binary columns.
BINARY = _TypeSet(
    arrow_type.id for arrow_type in (pyarrow.binary(), pyarrow.large_binary())
)
#: The type of datetime columns.
DATETIME = _TypeSet(
    arrow_type.id
    for arrow_type in (
        pyarrow.date32(),
        pyarrow.date64(),
        pyarrow.time32("s"),
        pyarrow.time64("ns"),
        pyarrow.timestamp("s"),
    )
)
#: The type of numeric columns.
NUMBER = _TypeSet(
    arrow_type.id
    for arrow_type in (
        pyarrow.int8(),
        pyarrow.int16(),
        pyarrow.int32(),
        pyarrow.int64(),
        pyarrow.uint8(),
        pyarrow.uint16(),
        pyarrow.uint32(),
        pyarrow.uint64(),
        pyarrow.float32(),
        pyarrow.float64(),
    )
)
#: The type of "row ID" columns.
ROWID = _TypeSet((pyarrow.int64().id,))
#: The type of string columns.
STRING = _TypeSet(
    arrow_type.id for arrow_type in (pyarrow.string(), pyarrow.large_string())
)

# ----------------------------------------------------------
# Functions


def connect(
    *,
    driver: str,
    entrypoint: Optional[str] = None,
    db_kwargs: Optional[Dict[str, str]] = None,
    conn_kwargs: Optional[Dict[str, str]] = None,
    autocommit=False,
) -> "Connection":
    """
    Open a DB-API connection to a database through an ADBC driver.

    Parameters
    ----------
    driver
        The driver name. For example, "adbc_driver_sqlite" will
        attempt to load libadbc_driver_sqlite.so on Linux systems,
        libadbc_driver_sqlite.dylib on MacOS, and
        adbc_driver_sqlite.dll on Windows. This may also be a path to
        the library to load.
    entrypoint
        The driver-specific entrypoint, if different than the default.
    db_kwargs
        Key-value parameters to pass to the driver to initialize the
        database.  The caller's dict is copied, not modified.
    conn_kwargs
        Key-value parameters to pass to the driver to initialize the
        connection.
    autocommit
        Whether to enable autocommit.  For compliance with DB-API,
        this is disabled by default.  A warning will be emitted if it
        cannot be disabled.

    Returns
    -------
    Connection
        The new connection.  Any database/connection handle that was
        already created is closed before an error is re-raised.
    """
    # Work on a copy: "driver" (and "entrypoint") are injected below.
    database_options = dict(db_kwargs or {})
    database_options["driver"] = driver
    if entrypoint:
        database_options["entrypoint"] = entrypoint

    if conn_kwargs is None:
        conn_kwargs = {}

    # Track what has been created so far for cleanup on failure.
    database = None
    connection = None
    try:
        database = _lib.AdbcDatabase(**database_options)
        connection = _lib.AdbcConnection(database, **conn_kwargs)
        return Connection(database, connection, conn_kwargs, autocommit=autocommit)
    except Exception:
        # Tear down in reverse order of creation, then propagate.
        if connection:
            connection.close()
        if database:
            database.close()
        raise


# ----------------------------------------------------------
# Classes


class _Closeable(abc.ABC):
    """Abstract mixin that turns any object with ``close()`` into a context manager."""

    @abc.abstractmethod
    def close(self) -> None:
        """Release the resources held by this object."""

    def __enter__(self) -> "Self":
        # The managed object is the instance itself.
        return self

    def __exit__(self, exc_type, exc_val, exc_tb) -> None:
        # Always close; returning None lets any exception propagate.
        self.close()


class _SharedDatabase(_Closeable):
    """Reference-counted wrapper that lets connections share one AdbcDatabase.

    The wrapper starts with one reference.  ``clone()`` adds a reference
    and ``close()`` drops one; the underlying database is closed when the
    count reaches zero.
    """

    def __init__(self, db: _lib.AdbcDatabase) -> None:
        self._db = db
        # Guards every update of the reference count.
        self._lock = threading.Lock()
        self._refcount = 1

    def _inc(self) -> None:
        """Add one reference."""
        with self._lock:
            self._refcount = self._refcount + 1

    def _dec(self) -> int:
        """Drop one reference and return the number remaining."""
        with self._lock:
            remaining = self._refcount - 1
            self._refcount = remaining
        return remaining

    def clone(self) -> "Self":
        """Register another owner and return this same holder."""
        self._inc()
        return self

    def close(self) -> None:
        """Drop one reference; close the database when none remain."""
        remaining = self._dec()
        if remaining == 0:
            self._db.close()


class Connection(_Closeable):
    """
    A DB-API 2.0 (PEP 249) connection.

    Do not create this object directly; use connect().
    """

    # Optional extension: expose exception classes on Connection
    Warning = _lib.Warning
    Error = _lib.Error
    InterfaceError = _lib.InterfaceError
    DatabaseError = _lib.DatabaseError
    DataError = _lib.DataError
    OperationalError = _lib.OperationalError
    IntegrityError = _lib.IntegrityError
    InternalError = _lib.InternalError
    ProgrammingError = _lib.ProgrammingError
    NotSupportedError = _lib.NotSupportedError

    def __init__(
        self,
        db: Union[_lib.AdbcDatabase, _SharedDatabase],
        conn: _lib.AdbcConnection,
        conn_kwargs: Optional[Dict[str, str]] = None,
        *,
        autocommit=False,
    ) -> None:
        """
        Wrap an already-open ADBC connection.

        Parameters
        ----------
        db
            The database the connection belongs to.  A raw database is
            wrapped in a new shared holder; an existing shared holder
            gains one reference.
        conn
            The underlying ADBC connection.
        conn_kwargs
            The options the connection was created with; reused by
            adbc_clone().
        autocommit
            Whether to enable autocommit.
        """
        # Must be at top in case __init__ is interrupted and then __del__ is
        # called: close() needs both handles to exist.
        self._closed = True
        if isinstance(db, _SharedDatabase):
            self._db = db.clone()
        else:
            self._db = _SharedDatabase(db)
        self._conn = conn
        self._conn_kwargs = conn_kwargs
        self._closed = False

        # Always try to disable autocommit first: this probes whether the
        # driver supports transactions at all.
        try:
            self._conn.set_autocommit(False)
        except _lib.NotSupportedError:
            self._commit_supported = False
            if not autocommit:
                warnings.warn(
                    "Cannot disable autocommit; conn will not be DB-API 2.0 compliant",
                    category=Warning,
                )
            self._autocommit = True
        else:
            self._autocommit = False
            self._commit_supported = True

        if autocommit and self._commit_supported:
            self._conn.set_autocommit(True)
            self._autocommit = True

    def close(self) -> None:
        """
        Close the connection.

        Calling this on an already-closed connection does nothing.

        Warnings
        --------
        Failure to close a connection may leak memory or database
        connections.
        """
        if self._closed:
            return

        self._conn.close()
        # Drops this connection's reference to the shared database
        self._db.close()
        self._closed = True

    def commit(self) -> None:
        """Explicitly commit (no-op if autocommit could not be disabled)."""
        if self._commit_supported:
            self._conn.commit()

    def cursor(self) -> "Cursor":
        """Create a new cursor for querying the database."""
        return Cursor(self)

    def rollback(self) -> None:
        """Explicitly rollback (no-op if autocommit could not be disabled)."""
        if self._commit_supported:
            self._conn.rollback()

    def __del__(self) -> None:
        if not self._closed:
            self.close()
            _warn_unclosed("adbc_driver_manager.dbapi.Connection")

    # ------------------------------------------------------------
    # API Extensions
    # ------------------------------------------------------------

    def adbc_cancel(self) -> None:
        """
        Cancel any ongoing operations on this connection.

        Notes
        -----
        This is an extension and not part of the DBAPI standard.
        """
        self._conn.cancel()

    def adbc_clone(self) -> "Connection":
        """
        Create a new Connection sharing the same underlying database.

        The new connection is opened with the same connection options as
        this one.

        Notes
        -----
        This is an extension and not part of the DBAPI standard.
        """
        conn = _lib.AdbcConnection(self._db._db, **(self._conn_kwargs or {}))
        # Hand the options on so that a clone of the clone still gets them.
        return Connection(self._db, conn, self._conn_kwargs)

    def adbc_get_info(self) -> Dict[Union[str, int], Any]:
        """
        Get metadata about the database and driver.

        Returns
        -------
        dict
            Info values keyed by a readable name for known info codes, or
            by the raw integer code otherwise.

        Notes
        -----
        This is an extension and not part of the DBAPI standard.
        """
        handle = _blocking_call(self._conn.get_info, (), {}, self._conn.cancel)
        reader = pyarrow.RecordBatchReader._import_from_c(handle.address)
        table = _blocking_call(reader.read_all, (), {}, self._conn.cancel)
        return {
            _KNOWN_INFO_VALUES.get(row["info_name"], row["info_name"]): row[
                "info_value"
            ]
            for row in table.to_pylist()
        }

    def adbc_get_objects(
        self,
        *,
        depth: Literal["all", "catalogs", "db_schemas", "tables", "columns"] = "all",
        catalog_filter: Optional[str] = None,
        db_schema_filter: Optional[str] = None,
        table_name_filter: Optional[str] = None,
        table_types_filter: Optional[List[str]] = None,
        column_name_filter: Optional[str] = None,
    ) -> pyarrow.RecordBatchReader:
        """
        List catalogs, schemas, tables, etc. in the database.

        Parameters
        ----------
        depth
            What objects to return info on.
        catalog_filter
            An optional filter on the catalog names returned.
        db_schema_filter
            An optional filter on the database schema names returned.
        table_name_filter
            An optional filter on the table names returned.
        table_types_filter
            An optional list of types of tables returned.
        column_name_filter
            An optional filter on the column names returned.

        Raises
        ------
        ValueError
            If ``depth`` is not one of the accepted values.

        Notes
        -----
        This is an extension and not part of the DBAPI standard.
        """
        # "columns" is the deepest level, so it maps to the same value as "all"
        if depth in ("all", "columns"):
            c_depth = _lib.GetObjectsDepth.ALL
        elif depth == "catalogs":
            c_depth = _lib.GetObjectsDepth.CATALOGS
        elif depth == "db_schemas":
            c_depth = _lib.GetObjectsDepth.DB_SCHEMAS
        elif depth == "tables":
            c_depth = _lib.GetObjectsDepth.TABLES
        else:
            raise ValueError(f"Invalid value for 'depth': {depth}")
        handle = _blocking_call(
            self._conn.get_objects,
            (c_depth,),
            dict(
                catalog=catalog_filter,
                db_schema=db_schema_filter,
                table_name=table_name_filter,
                table_types=table_types_filter,
                column_name=column_name_filter,
            ),
            self._conn.cancel,
        )
        return pyarrow.RecordBatchReader._import_from_c(handle.address)

    def adbc_get_table_schema(
        self,
        table_name: str,
        *,
        catalog_filter: Optional[str] = None,
        db_schema_filter: Optional[str] = None,
    ) -> pyarrow.Schema:
        """
        Get the Arrow schema of a table by name.

        Parameters
        ----------
        table_name
            The table to get the schema of.
        catalog_filter
            An optional filter on the catalog name of the table.
        db_schema_filter
            An optional filter on the database schema name of the table.

        Notes
        -----
        This is an extension and not part of the DBAPI standard.
        """
        handle = _blocking_call(
            self._conn.get_table_schema,
            (
                catalog_filter,
                db_schema_filter,
                table_name,
            ),
            {},
            self._conn.cancel,
        )
        return pyarrow.Schema._import_from_c(handle.address)

    def adbc_get_table_types(self) -> List[str]:
        """
        List the types of tables that the server knows about.

        Notes
        -----
        This is an extension and not part of the DBAPI standard.
        """
        handle = _blocking_call(
            self._conn.get_table_types,
            (),
            {},
            self._conn.cancel,
        )
        reader = pyarrow.RecordBatchReader._import_from_c(handle.address)
        # Read through _blocking_call as adbc_get_info() does, so the read
        # gets the same cancel callback as the initial request.
        table = _blocking_call(reader.read_all, (), {}, self._conn.cancel)
        return table[0].to_pylist()

    @property
    def adbc_connection(self) -> _lib.AdbcConnection:
        """
        Get the underlying ADBC connection.

        Notes
        -----
        This is an extension and not part of the DBAPI standard.
        """
        return self._conn

    @property
    def adbc_current_catalog(self) -> str:
        """
        The name of the current catalog.

        Notes
        -----
        This is an extension and not part of the DBAPI standard.
        """
        key = adbc_driver_manager.ConnectionOptions.CURRENT_CATALOG.value
        return self._conn.get_option(key)

    @adbc_current_catalog.setter
    def adbc_current_catalog(self, catalog: str) -> None:
        key = adbc_driver_manager.ConnectionOptions.CURRENT_CATALOG.value
        self._conn.set_options(**{key: catalog})

    @property
    def adbc_current_db_schema(self) -> str:
        """
        The name of the current schema.

        Notes
        -----
        This is an extension and not part of the DBAPI standard.
        """
        key = adbc_driver_manager.ConnectionOptions.CURRENT_DB_SCHEMA.value
        return self._conn.get_option(key)

    @adbc_current_db_schema.setter
    def adbc_current_db_schema(self, db_schema: str) -> None:
        key = adbc_driver_manager.ConnectionOptions.CURRENT_DB_SCHEMA.value
        self._conn.set_options(**{key: db_schema})

    @property
    def adbc_database(self) -> _lib.AdbcDatabase:
        """
        Get the underlying ADBC database.

        Notes
        -----
        This is an extension and not part of the DBAPI standard.
        """
        return self._db._db


class Cursor(_Closeable):
    """
    A DB-API 2.0 (PEP 249) cursor.

    Do not create this object directly; use Connection.cursor().
    """

    def __init__(self, conn: Connection) -> None:
        """Create a cursor (and a new ADBC statement) on ``conn``."""
        # Must be at top in case __init__ is interrupted and then __del__ is called
        self._closed = True
        self._conn = conn

        # Initialize all plain state before the cursor is marked open, so that
        # close() never sees a missing attribute such as ``_results``.
        self._last_query: Optional[Union[str, bytes]] = None
        self._results: Optional["_RowIterator"] = None
        self._arraysize = 1
        self._rowcount = -1

        self._stmt = _lib.AdbcStatement(conn._conn)
        # Only now is there a statement handle that close() must release.
        self._closed = False

    @property
    def arraysize(self) -> int:
        """The number of rows to fetch at a time with fetchmany().

        Starts at 1 for a new cursor; fetchmany() uses it when called
        without an explicit size.
        """
        return self._arraysize

    @arraysize.setter
    def arraysize(self, size: int) -> None:
        # Stored as given; no validation is performed here.
        self._arraysize = size

    @property
    def connection(self) -> Connection:
        """
        Get the connection associated with this cursor.

        This is the Connection object the cursor was created from.

        This is an optional DB-API extension.
        """
        return self._conn

    @property
    def description(self) -> Optional[List[tuple]]:
        """The schema of the current result set, or None if there is none."""
        results = self._results
        return None if results is None else results.description

    @property
    def rowcount(self) -> int:
        """
        Get the row count of the result set, or -1 if not known.

        This is -1 on a new cursor and is updated by the execute methods.
        """
        return self._rowcount

    @property
    def rownumber(self) -> Optional[int]:
        """The current row number in the result set, or None without one."""
        results = self._results
        if results is None:
            return None
        return results.rownumber

    def callproc(self, procname, parameters):
        """Call a stored procedure (not supported).

        Always raises NotSupportedError; both arguments are ignored.
        """
        raise NotSupportedError("Cursor.callproc")

    def close(self):
        """Close the cursor and free resources.

        Closes the current result set (if any) and then the underlying
        statement.  Calling this on an already-closed cursor does nothing.
        """
        if self._closed:
            return

        # Detach the result set first so it is never closed twice.
        results, self._results = self._results, None
        try:
            if results is not None:
                results.close()
        finally:
            # Release the statement even if closing the result set failed;
            # otherwise the handle would leak until garbage collection.
            self._stmt.close()
            self._closed = True

    def _bind(self, parameters) -> None:
        """Bind Arrow data to the statement as query parameters.

        Objects exposing ``__arrow_c_array__`` or ``__arrow_c_stream__``
        are passed to the statement directly.  A pyarrow RecordBatch is
        exported as a single array; anything else is exported as a stream
        (a pyarrow Table is converted to a reader first).
        """
        if hasattr(parameters, "__arrow_c_array__"):
            self._stmt.bind(parameters)
            return

        if hasattr(parameters, "__arrow_c_stream__"):
            self._stmt.bind_stream(parameters)
            return

        if isinstance(parameters, pyarrow.RecordBatch):
            array_handle = _lib.ArrowArrayHandle()
            schema_handle = _lib.ArrowSchemaHandle()
            parameters._export_to_c(array_handle.address, schema_handle.address)
            self._stmt.bind(array_handle, schema_handle)
            return

        # Stream path: whatever is left must offer ``_export_to_c``.
        source = parameters
        if isinstance(source, pyarrow.Table):
            source = source.to_reader()
        handle = _lib.ArrowArrayStreamHandle()
        source._export_to_c(handle.address)
        self._stmt.bind_stream(handle)

    def _prepare_execute(self, operation, parameters=None) -> None:
        """Set up the statement for ``operation`` and bind ``parameters``.

        The query is only set and prepared when it differs from the last
        one handled by this cursor.  Drivers that do not support prepare
        are tolerated.

        Parameters
        ----------
        operation : bytes or str
            A SQL query (str) or a serialized Substrait plan (bytes).
        parameters
            Arrow data to bind as-is, or a Python sequence that is turned
            into a one-row record batch.  Falsy values bind nothing.
        """
        self._results = None
        if operation != self._last_query:
            # Forget the cached query before touching the statement.  If
            # setting or preparing the new query fails, the next call must
            # redo the setup instead of assuming it already happened.
            self._last_query = None
            if isinstance(operation, bytes):
                # Serialized Substrait plan
                self._stmt.set_substrait_plan(operation)
            else:
                self._stmt.set_sql_query(operation)
            try:
                _blocking_call(self._stmt.prepare, (), {}, self._stmt.cancel)
            except NotSupportedError:
                # Not all drivers support it
                pass
            self._last_query = operation

        if _is_arrow_data(parameters):
            self._bind(parameters)
        elif parameters:
            # One column per parameter, named by position, with a single row
            rb = pyarrow.record_batch(
                [[param_value] for param_value in parameters],
                names=[str(i) for i in range(len(parameters))],
            )
            self._bind(rb)

    def execute(self, operation: Union[bytes, str], parameters=None) -> None:
        """
        Run a query and make its result set available for fetching.

        Parameters
        ----------
        operation : bytes or str
            The query to execute.  Pass SQL queries as strings,
            (serialized) Substrait plans as bytes.
        parameters
            Parameters to bind.  Can be a Python sequence (to provide
            a single set of parameters), or an Arrow record batch,
            table, or record batch reader (to provide multiple
            parameters, which will each be bound in turn).
        """
        self._prepare_execute(operation, parameters)

        # execute_query yields the stream handle and the affected row count.
        stream_handle, self._rowcount = _blocking_call(
            self._stmt.execute_query, (), {}, self._stmt.cancel
        )
        reader = _reader.AdbcRecordBatchReader._import_from_c(stream_handle.address)
        self._results = _RowIterator(self._stmt, reader)

    def executemany(self, operation: Union[bytes, str], seq_of_parameters) -> None:
        """
        Execute a query with multiple parameter sets.

        This method does not generate a result set.

        Parameters
        ----------
        operation : bytes or str
            The query to execute.  Pass SQL queries as strings,
            (serialized) Substrait plans as bytes.
        seq_of_parameters
            Parameters to bind.  Can be a list of Python sequences, or
            an Arrow record batch, table, or record batch reader.  If
            None, then the query will be executed once, else it will
            be executed once per row.
        """
        self._results = None
        if operation != self._last_query:
            # Forget the cached query first so that a failure below cannot
            # leave the cache claiming the statement is already set up.
            self._last_query = None
            if isinstance(operation, bytes):
                # Serialized Substrait plan, as promised by the signature
                self._stmt.set_substrait_plan(operation)
            else:
                self._stmt.set_sql_query(operation)
            # Unlike execute(), a driver without prepare support is an error
            # here: NotSupportedError is deliberately not swallowed.
            _blocking_call(self._stmt.prepare, (), {}, self._stmt.cancel)
            self._last_query = operation

        if _is_arrow_data(seq_of_parameters):
            arrow_parameters = seq_of_parameters
        elif seq_of_parameters:
            # Transpose the rows into one Arrow column per parameter position
            arrow_parameters = pyarrow.RecordBatch.from_pydict(
                {
                    str(col_idx): pyarrow.array(x)
                    for col_idx, x in enumerate(map(list, zip(*seq_of_parameters)))
                },
            )
        else:
            # No parameters: bind an empty batch.  The schema is spelled out
            # because there are no columns to take names from.
            arrow_parameters = pyarrow.record_batch([], schema=pyarrow.schema([]))

        self._bind(arrow_parameters)
        self._rowcount = _blocking_call(
            self._stmt.execute_update, (), {}, self._stmt.cancel
        )

    def fetchone(self) -> Optional[tuple]:
        """Return the next row of the result set.

        Raises ProgrammingError if no query has produced a result set.
        """
        results = self._results
        if results is not None:
            return results.fetchone()
        raise ProgrammingError(
            "Cannot fetchone() before execute()",
            status_code=_lib.AdbcStatusCode.INVALID_STATE,
        )

    def fetchmany(self, size: Optional[int] = None) -> List[tuple]:
        """Return up to ``size`` rows; ``size`` defaults to ``arraysize``.

        Raises ProgrammingError if no query has produced a result set.
        """
        results = self._results
        if results is None:
            raise ProgrammingError(
                "Cannot fetchmany() before execute()",
                status_code=_lib.AdbcStatusCode.INVALID_STATE,
            )
        batch_size = self.arraysize if size is None else size
        return results.fetchmany(batch_size)

    def fetchall(self) -> List[tuple]:
        """Return all remaining rows of the result set.

        Raises ProgrammingError if no query has produced a result set.
        """
        results = self._results
        if results is not None:
            return results.fetchall()
        raise ProgrammingError(
            "Cannot fetchall() before execute()",
            status_code=_lib.AdbcStatusCode.INVALID_STATE,
        )

    def next(self):
        """Return the next row; raise StopIteration when rows run out."""
        row = self.fetchone()
        if row is not None:
            return row
        raise StopIteration

    def nextset(self):
        """Move to the next available result set (not supported).

        Always raises NotSupportedError.
        """
        raise NotSupportedError("Cursor.nextset")

    def setinputsizes(self, sizes):
        """Preallocate memory for the parameters (no-op).

        Accepted for DB-API compatibility; ``sizes`` is ignored.
        """
        pass

    def setoutputsize(self, size, column=None):
        """Preallocate memory for the result set (no-op).

        Accepted for DB-API compatibility; both arguments are ignored.
        """
        pass

    def __del__(self) -> None:
        # Fallback cleanup for cursors that were never closed explicitly.
        if self._closed:
            return
        self.close()
        _warn_unclosed("adbc_driver_manager.dbapi.Cursor")

    def __iter__(self):
        # The cursor is its own iterator; rows come from __next__().
        return self

    def __next__(self):
        # Iterator protocol hook; delegates to the DB-API style next().
        return self.next()

    # ------------------------------------------------------------
    # API Extensions
    # ------------------------------------------------------------

    def adbc_cancel(self) -> None:
        """
        Cancel any ongoing operations on this statement.

        This forwards to ``cancel()`` on the cursor's underlying statement.

        Notes
        -----
        This is an extension and not part of the DBAPI standard.
        """
        self._stmt.cancel()

    def adbc_ingest(
        self,
        table_name: str,
        data: Union[pyarrow.RecordBatch, pyarrow.Table, pyarrow.RecordBatchReader],
        mode: Literal["append", "create", "replace", "create_append"] = "create",
        *,
        catalog_name: Optional[str] = None,
        db_schema_name: Optional[str] = None,
        temporary: bool = False,
    ) -> int:
        """Ingest Arrow data into a database table.

        Depending on the driver, this can avoid per-row overhead that
        would result from a typical prepare-bind-insert loop.

        Parameters
        ----------
        table_name
            The table to insert into.
        data
            The Arrow data to insert. This can be a pyarrow RecordBatch, Table
            or RecordBatchReader, or any Arrow-compatible data that implements
            the Arrow PyCapsule Protocol (i.e. has an ``__arrow_c_array__``
            or ``__arrow_c_stream__`` method).  A pyarrow Dataset or Scanner
            is also accepted when ``pyarrow.dataset`` is available.
        mode
            How to deal with existing data:

            - 'append': append to a table (error if table does not exist)
            - 'create': create a table and insert (error if table exists)
            - 'create_append': create a table (if not exists) and insert
            - 'replace': drop existing table (if any), then same as 'create'
        catalog_name
            If given, the catalog to create/locate the table in.
            **This API is EXPERIMENTAL.**
        db_schema_name
            If given, the schema to create/locate the table in.
            **This API is EXPERIMENTAL.**
        temporary
            Whether to ingest to a temporary table or not.  Most drivers will
            not support setting this along with catalog_name and/or
            db_schema_name.
            **This API is EXPERIMENTAL.**

        Returns
        -------
        int
            The number of rows inserted, or -1 if the driver cannot
            provide this information.

        Raises
        ------
        ValueError
            If ``mode`` is not one of the accepted values.

        Notes
        -----
        This is an extension and not part of the DBAPI standard.

        """
        if mode == "append":
            c_mode = _lib.INGEST_OPTION_MODE_APPEND
        elif mode == "create":
            c_mode = _lib.INGEST_OPTION_MODE_CREATE
        elif mode == "create_append":
            c_mode = _lib.INGEST_OPTION_MODE_CREATE_APPEND
        elif mode == "replace":
            c_mode = _lib.INGEST_OPTION_MODE_REPLACE
        else:
            raise ValueError(f"Invalid value for 'mode': {mode}")

        options = {
            _lib.INGEST_OPTION_TARGET_TABLE: table_name,
            _lib.INGEST_OPTION_MODE: c_mode,
        }
        if catalog_name is not None:
            options[
                adbc_driver_manager.StatementOptions.INGEST_TARGET_CATALOG.value
            ] = catalog_name
        if db_schema_name is not None:
            options[
                adbc_driver_manager.StatementOptions.INGEST_TARGET_DB_SCHEMA.value
            ] = db_schema_name
        self._stmt.set_options(**options)

        if temporary:
            self._stmt.set_options(
                **{
                    adbc_driver_manager.StatementOptions.INGEST_TEMPORARY.value: "true",
                }
            )
        else:
            # Need to explicitly clear it, but not all drivers support this
            options = {
                adbc_driver_manager.StatementOptions.INGEST_TEMPORARY.value: "false",
            }
            try:
                self._stmt.set_options(**options)
            except NotSupportedError:
                pass

        if hasattr(data, "__arrow_c_array__"):
            self._stmt.bind(data)
        elif hasattr(data, "__arrow_c_stream__"):
            self._stmt.bind_stream(data)
        elif isinstance(data, pyarrow.RecordBatch):
            array = _lib.ArrowArrayHandle()
            schema = _lib.ArrowSchemaHandle()
            data._export_to_c(array.address, schema.address)
            self._stmt.bind(array, schema)
        else:
            # Use the guarded tuples rather than ``pyarrow.dataset.*``: they
            # are empty when pyarrow.dataset could not be imported, in which
            # case isinstance() is simply False instead of raising.
            if isinstance(data, pyarrow.Table):
                data = data.to_reader()
            elif isinstance(data, _pya_dataset):
                data = data.scanner().to_reader()
            elif isinstance(data, _pya_scanner):
                data = data.to_reader()
            elif not hasattr(data, "_export_to_c"):
                # Last resort: treat the input as an iterable of batches
                data = pyarrow.Table.from_batches(data)
                data = data.to_reader()
            handle = _lib.ArrowArrayStreamHandle()
            data._export_to_c(handle.address)
            self._stmt.bind_stream(handle)

        # The statement no longer holds the last query, so force the next
        # execute() to set it again.
        self._last_query = None
        return _blocking_call(self._stmt.execute_update, (), {}, self._stmt.cancel)

    def adbc_execute_partitions(
        self,
        operation,
        parameters=None,
    ) -> Tuple[List[bytes], pyarrow.Schema]:
        """
        Run a query and return the partitions of its distributed result set.

        Returns
        -------
        partitions : list of byte
            A list of partition descriptors, which can be read with
            read_partition.
        schema : pyarrow.Schema or None
            The schema of the result set.  May be None if incremental query
            execution is enabled and the server has not returned a schema.

        Notes
        -----
        This is an extension and not part of the DBAPI standard.
        """
        self._prepare_execute(operation, parameters)
        partitions, schema_handle, self._rowcount = _blocking_call(
            self._stmt.execute_partitions, (), {}, self._stmt.cancel
        )
        # A missing handle or a null address means no schema was returned.
        if not schema_handle or not schema_handle.address:
            return partitions, None
        return partitions, pyarrow.Schema._import_from_c(schema_handle.address)

    def adbc_execute_schema(self, operation, parameters=None) -> pyarrow.Schema:
        """
        Return the result-set schema of a query without executing it.

        Returns
        -------
        pyarrow.Schema
            The schema of the result set.

        Notes
        -----
        This is an extension and not part of the DBAPI standard.
        """
        self._prepare_execute(operation, parameters)
        schema_handle = _blocking_call(
            self._stmt.execute_schema, (), {}, self._stmt.cancel
        )
        return pyarrow.Schema._import_from_c(schema_handle.address)

    def adbc_prepare(self, operation: Union[bytes, str]) -> Optional[pyarrow.Schema]:
        """
        Prepare a query without executing it.

        To execute the query afterwards, call :meth:`execute` or
        :meth:`executemany` with the same query.  This will not
        prepare the query a second time.

        Returns
        -------
        pyarrow.Schema or None
            The schema of the bind parameters, or None if the schema
            could not be determined.

        Notes
        -----
        This is an extension and not part of the DBAPI standard.
        """
        self._prepare_execute(operation)

        try:
            schema_handle = _blocking_call(
                self._stmt.get_parameter_schema, (), {}, self._stmt.cancel
            )
        except NotSupportedError:
            # The driver cannot report the parameter schema.
            return None
        else:
            return pyarrow.Schema._import_from_c(schema_handle.address)

    def adbc_read_partition(self, partition: bytes) -> None:
        """
        Read a partition of a distributed result set.

        The previous result set (if any) is dropped, the partition is read
        through the underlying connection, and the resulting stream becomes
        the current result set.  The row count is reset to -1.

        Parameters
        ----------
        partition : bytes
            The partition descriptor, passed through to the connection's
            ``read_partition``.

        Notes
        -----
        This is an extension and not part of the DBAPI standard.
        """
        # Forget the old results first so a failed read does not leave them
        # in place.
        self._results = None
        connection = self._conn._conn
        stream = _blocking_call(
            connection.read_partition, (partition,), {}, self._stmt.cancel
        )
        self._rowcount = -1
        reader = pyarrow.RecordBatchReader._import_from_c(stream.address)
        self._results = _RowIterator(self._stmt, reader)

    @property
    def adbc_statement(self) -> _lib.AdbcStatement:
        """
        The underlying ADBC statement wrapped by this object.

        Returns
        -------
        _lib.AdbcStatement
            The statement stored in ``self._stmt``; no copy is made.

        Notes
        -----
        This is an extension and not part of the DBAPI standard.
        """
        statement = self._stmt
        return statement

    def executescript(self, operation: str) -> None:
        """
        Execute multiple statements.

        When the connection is not in autocommit mode, the connection is
        committed first.  Any previous result set and remembered query are
        discarded before the script is run.

        Parameters
        ----------
        operation : str
            The script text, set as the statement's SQL query.

        Notes
        -----
        This is an extension and not part of the DBAPI standard.
        """
        connection = self._conn
        if not connection._autocommit:
            connection.commit()

        # Reset per-query state: nothing to fetch, no query remembered.
        self._results = None
        self._last_query = None

        statement = self._stmt
        statement.set_sql_query(operation)
        # Run as an update; the value returned by the call is not used.
        _blocking_call(statement.execute_update, (), {}, statement.cancel)

    def fetchallarrow(self) -> pyarrow.Table:
        """
        Fetch all rows of the result as a PyArrow Table.

        This implements a similar API as turbodbc.  It simply delegates to
        ``fetch_arrow_table``.

        Returns
        -------
        pyarrow.Table
            Whatever ``fetch_arrow_table`` returns.

        Notes
        -----
        This is an extension and not part of the DBAPI standard.
        """
        table = self.fetch_arrow_table()
        return table

    def fetch_arrow_table(self) -> pyarrow.Table:
        """
        Fetch all rows of the result as a PyArrow Table.

        This implements a similar API as DuckDB.

        Returns
        -------
        pyarrow.Table
            The table produced by the current result set.

        Raises
        ------
        ProgrammingError
            If there is no current result set (``self._results`` is None).

        Notes
        -----
        This is an extension and not part of the DBAPI standard.
        """
        results = self._results
        if results is not None:
            return results.fetch_arrow_table()
        raise ProgrammingError(
            "Cannot fetch_arrow_table() before execute()",
            status_code=_lib.AdbcStatusCode.INVALID_STATE,
        )

    def fetch_df(self) -> "pandas.DataFrame":
        """
        Fetch all rows of the result as a Pandas DataFrame.

        This implements a similar API as DuckDB.

        Returns
        -------
        pandas.DataFrame
            The data frame produced by the current result set.

        Raises
        ------
        ProgrammingError
            If there is no current result set (``self._results`` is None).

        Notes
        -----
        This is an extension and not part of the DBAPI standard.
        """
        results = self._results
        if results is not None:
            return results.fetch_df()
        raise ProgrammingError(
            "Cannot fetch_df() before execute()",
            status_code=_lib.AdbcStatusCode.INVALID_STATE,
        )

    def fetch_record_batch(self) -> pyarrow.RecordBatchReader:
        """
        Fetch the result as a PyArrow RecordBatchReader.

        This implements a similar API as DuckDB:
        https://duckdb.org/docs/guides/python/export_arrow.html#export-as-a-recordbatchreader

        Returns
        -------
        pyarrow.RecordBatchReader
            The reader behind the current result set.  If the stored reader
            is a wrapper exposing the real PyArrow reader as ``_reader``,
            the wrapped reader is returned; otherwise the stored reader is
            returned as-is.

        Raises
        ------
        ProgrammingError
            If there is no current result set (``self._results`` is None).

        Notes
        -----
        This is an extension and not part of the DBAPI standard.
        """
        if self._results is None:
            raise ProgrammingError(
                "Cannot fetch_record_batch() before execute()",
                status_code=_lib.AdbcStatusCode.INVALID_STATE,
            )
        # XXX(https://github.com/apache/arrow-adbc/issues/1523): return the
        # "real" PyArrow reader since PyArrow may try to poke the internal C++
        # reader pointer
        reader = self._results._reader
        # adbc_read_partition() stores a plain pyarrow.RecordBatchReader,
        # which has no ``_reader`` attribute; only unwrap when there is
        # actually something to unwrap instead of raising AttributeError.
        return getattr(reader, "_reader", reader)


# ----------------------------------------------------------
# Utilities


class _RowIterator(_Closeable):
    """Row-by-row view over a record batch reader.

    Batches are pulled from the reader one at a time and handed out as
    tuples of Python values; ``rownumber`` counts the rows handed out.
    """

    def __init__(self, stmt, reader: pyarrow.RecordBatchReader) -> None:
        self._reader = reader
        # Only used here to supply cancel() as the cancellation hook.
        self._stmt = stmt
        # Iteration state: the batch being consumed, the index of the next
        # row inside it, and whether the reader has been exhausted.
        self._finished = False
        self._current_batch = None
        self._next_row = 0
        # Number of rows returned by fetchone() so far.
        self.rownumber = 0

    def close(self) -> None:
        """Close the reader when it provides a ``close`` method."""
        reader = self._reader
        if not hasattr(reader, "close"):
            # Older PyArrow readers have no close(); nothing to do.
            return
        reader.close()

    @property
    def description(self) -> List[tuple]:
        """Seven-item column descriptions: name, type, then five ``None``."""
        padding = (None,) * 5
        return [(column.name, column.type) + padding for column in self._reader.schema]

    def _advance_batch(self) -> None:
        """Load the next non-empty batch, or mark the iterator finished."""
        try:
            while True:
                self._current_batch = _blocking_call(
                    self._reader.read_next_batch, (), {}, self._stmt.cancel
                )
                # Skip over empty batches.
                if self._current_batch.num_rows > 0:
                    break
            self._next_row = 0
        except StopIteration:
            # The reader has no more batches.
            self._current_batch = None
            self._finished = True

    def fetchone(self) -> Optional[tuple]:
        """Return the next row as a tuple, or ``None`` when exhausted."""
        batch = self._current_batch
        if batch is None or self._next_row >= len(batch):
            self._advance_batch()

        if self._finished or self._current_batch is None:
            return None

        index = self._next_row
        values = tuple(column[index].as_py() for column in self._current_batch.columns)
        # Only move forward once the row was converted successfully.
        self._next_row = index + 1
        self.rownumber += 1
        return values

    def fetchmany(self, size: int) -> List[tuple]:
        """Return up to ``size`` rows, fewer if the result set runs out."""
        remaining_rows = iter(self.fetchone, None)
        # range() comes first so zip() stops after ``size`` rows without
        # pulling an extra row from fetchone().
        return [row for _, row in zip(range(size), remaining_rows)]

    def fetchall(self) -> List[tuple]:
        """Return every remaining row."""
        # iter(callable, sentinel) keeps calling fetchone() until it
        # returns None.
        return list(iter(self.fetchone, None))

    def fetch_arrow_table(self) -> pyarrow.Table:
        """Return the result of the reader's ``read_all``."""
        read_all = self._reader.read_all
        return _blocking_call(read_all, (), {}, self._stmt.cancel)

    def fetch_df(self) -> "pandas.DataFrame":
        """Return the result of the reader's ``read_pandas``."""
        read_pandas = self._reader.read_pandas
        return _blocking_call(read_pandas, (), {}, self._stmt.cancel)


# Environment variable whose mere presence enables the unclosed-resource
# warning.
_PYTEST_ENV_VAR = "PYTEST_CURRENT_TEST"
# Environment variable that enables the warning when set to exactly "1".
_ADBC_ENV_VAR = "_ADBC_DRIVER_MANAGER_WARN_UNCLOSED_RESOURCE"


def _warn_unclosed(name):
    """Emit a ResourceWarning about a resource that was never closed.

    The warning is only emitted when ``PYTEST_CURRENT_TEST`` is present in
    the environment or ``_ADBC_DRIVER_MANAGER_WARN_UNCLOSED_RESOURCE`` is
    set to ``1``; otherwise this is a no-op.

    Parameters
    ----------
    name
        Human-readable name of the resource, interpolated into the message.
    """
    environ = os.environ
    enabled = _PYTEST_ENV_VAR in environ or environ.get(_ADBC_ENV_VAR) == "1"
    if not enabled:
        return

    message = (
        f"A {name} was not explicitly close()d, which may leak "
        f"driver resources. (This warning is only emitted if "
        f"{_PYTEST_ENV_VAR} or {_ADBC_ENV_VAR} are set.)"
    )
    # stacklevel=2 attributes the warning to the caller of this helper.
    warnings.warn(message, category=ResourceWarning, stacklevel=2)


def _is_arrow_data(data):
    """Return whether ``data`` is recognised as Arrow data.

    An object qualifies if it has an ``__arrow_c_array__`` or
    ``__arrow_c_stream__`` attribute, or if it is a ``pyarrow.RecordBatch``,
    ``pyarrow.Table`` or ``pyarrow.RecordBatchReader`` instance.
    """
    # Attribute checks come first, in this order; the isinstance() check is
    # only reached when neither attribute is present.
    for protocol_method in ("__arrow_c_array__", "__arrow_c_stream__"):
        if hasattr(data, protocol_method):
            return True
    arrow_types = (pyarrow.RecordBatch, pyarrow.Table, pyarrow.RecordBatchReader)
    return isinstance(data, arrow_types)