Why Gemfury? Push, build, and install  RubyGems npm packages Python packages Maven artifacts PHP packages Go Modules Debian packages RPM packages NuGet packages

Repository URL to install this package:

Details    
Size: Mime:
#!/usr/bin/env python
#
# Copyright (c) 2012-2021 Snowflake Computing Inc. All rights reserved.
#

from __future__ import annotations

import logging
import re
import signal
import sys
import time
import uuid
from enum import Enum
from logging import getLogger
from threading import Lock, Timer
from typing import (
    IO,
    TYPE_CHECKING,
    Any,
    Callable,
    Generator,
    Iterator,
    NamedTuple,
    NoReturn,
    Sequence,
)

from snowflake.connector.result_batch import create_batches_from_response
from snowflake.connector.result_set import ResultSet

from .bind_upload_agent import BindUploadAgent, BindUploadError
from .compat import BASE_EXCEPTION_CLASS
from .constants import (
    FIELD_NAME_TO_ID,
    PARAMETER_PYTHON_CONNECTOR_QUERY_RESULT_FORMAT,
    FileTransferType,
    QueryStatus,
)
from .description import CLIENT_NAME
from .errorcode import (
    ER_CURSOR_IS_CLOSED,
    ER_FAILED_PROCESSING_PYFORMAT,
    ER_FAILED_TO_REWRITE_MULTI_ROW_INSERT,
    ER_INVALID_VALUE,
    ER_NO_ARROW_RESULT,
    ER_NO_PYARROW,
    ER_NO_PYARROW_SNOWSQL,
    ER_NOT_POSITIVE_SIZE,
    ER_UNSUPPORTED_METHOD,
)
from .errors import (
    DatabaseError,
    Error,
    IntegrityError,
    InterfaceError,
    NotSupportedError,
    ProgrammingError,
)
from .file_transfer_agent import SnowflakeFileTransferAgent
from .options import installed_pandas, pandas
from .sqlstate import SQLSTATE_FEATURE_NOT_SUPPORTED
from .telemetry import TelemetryData, TelemetryField
from .time_util import get_time_millis

if TYPE_CHECKING:  # pragma: no cover
    from .connection import SnowflakeConnection
    from .file_transfer_agent import SnowflakeProgressPercentage
    from .result_batch import ResultBatch


logger = getLogger(__name__)

if installed_pandas:
    from pyarrow import Table
else:
    logger.debug("Failed to import pyarrow. Cannot use pandas fetch API")
    Table = None

try:
    from .arrow_iterator import PyArrowIterator  # NOQA

    CAN_USE_ARROW_RESULT_FORMAT = True
except ImportError as e:  # pragma: no cover
    logger.debug(
        f"Failed to import ArrowResult. No Apache Arrow result set format can be used. ImportError: {e}",
    )
    CAN_USE_ARROW_RESULT_FORMAT = False

STATEMENT_TYPE_ID_DML = 0x3000
STATEMENT_TYPE_ID_INSERT = STATEMENT_TYPE_ID_DML + 0x100
STATEMENT_TYPE_ID_UPDATE = STATEMENT_TYPE_ID_DML + 0x200
STATEMENT_TYPE_ID_DELETE = STATEMENT_TYPE_ID_DML + 0x300
STATEMENT_TYPE_ID_MERGE = STATEMENT_TYPE_ID_DML + 0x400
STATEMENT_TYPE_ID_MULTI_TABLE_INSERT = STATEMENT_TYPE_ID_DML + 0x500

STATEMENT_TYPE_ID_DML_SET = frozenset(
    [
        STATEMENT_TYPE_ID_DML,
        STATEMENT_TYPE_ID_INSERT,
        STATEMENT_TYPE_ID_UPDATE,
        STATEMENT_TYPE_ID_DELETE,
        STATEMENT_TYPE_ID_MERGE,
        STATEMENT_TYPE_ID_MULTI_TABLE_INSERT,
    ]
)

DESC_TABLE_RE = re.compile(r"desc(?:ribe)?\s+([\w_]+)\s*;?\s*$", flags=re.IGNORECASE)

LOG_MAX_QUERY_LENGTH = 80

ASYNC_NO_DATA_MAX_RETRY = 24
ASYNC_RETRY_PATTERN = [1, 1, 2, 3, 4, 8, 10]


class ResultMetadata(NamedTuple):
    name: str
    type_code: int
    display_size: int
    internal_size: int
    precision: int
    scale: int
    is_nullable: bool

    @classmethod
    def from_column(cls, col: dict[str, Any]):
        """Initializes a ResultMetadata object from the column description in the query response."""
        return cls(
            col["name"],
            FIELD_NAME_TO_ID[
                col["extTypeName"].upper()
                if "extTypeName" in col
                else col["type"].upper()
            ],
            None,
            col["length"],
            col["precision"],
            col["scale"],
            col["nullable"],
        )


def exit_handler(*_) -> NoReturn:
    """Handler for signal. When called, it will raise SystemExit with exit code FORCE_EXIT."""
    print("\nForce exit")
    logger.info("Force exit")
    sys.exit(1)


class ResultState(Enum):
    DEFAULT = 1
    VALID = 2
    RESET = 3


class SnowflakeCursor:
    """Implementation of Cursor object that is returned from Connection.cursor() method.

    Attributes:
        description: A list of namedtuples about metadata for all columns.
        rowcount: The number of records updated or selected. If not clear, -1 is returned.
        rownumber: The current 0-based index of the cursor in the result set or None if the index cannot be
            determined.
        sfqid: Snowflake query id in UUID form. Include this in the problem report to the customer support.
        sqlstate: Snowflake SQL State code.
        timestamp_output_format: Snowflake timestamp_output_format for timestamps.
        timestamp_ltz_output_format: Snowflake output format for LTZ timestamps.
        timestamp_tz_output_format: Snowflake output format for TZ timestamps.
        timestamp_ntz_output_format: Snowflake output format for NTZ timestamps.
        date_output_format: Snowflake output format for dates.
        time_output_format: Snowflake output format for times.
        timezone: Snowflake timezone.
        binary_output_format: Snowflake output format for binary fields.
        arraysize: The default number of rows fetched by fetchmany.
        connection: The connection object by which the cursor was created.
        errorhandle: The class that handles error handling.
        is_file_transfer: Whether, or not the current command is a put, or get.

    TODO:
        Most of these attributes have no reason to be properties, we could just store them in public variables.
        Calling a function is expensive in Python and most of these getters are unnecessary.
    """

    PUT_SQL_RE = re.compile(r"^(?:/\*.*\*/\s*)*put\s+", flags=re.IGNORECASE)
    GET_SQL_RE = re.compile(r"^(?:/\*.*\*/\s*)*get\s+", flags=re.IGNORECASE)
    INSERT_SQL_RE = re.compile(r"^insert\s+into", flags=re.IGNORECASE)
    COMMENT_SQL_RE = re.compile(r"/\*.*\*/")
    INSERT_SQL_VALUES_RE = re.compile(
        r".*VALUES\s*(\(.*\)).*", re.IGNORECASE | re.MULTILINE | re.DOTALL
    )
    ALTER_SESSION_RE = re.compile(
        r"alter\s+session\s+set\s+(.*)=\'?([^\']+)\'?\s*;",
        flags=re.IGNORECASE | re.MULTILINE | re.DOTALL,
    )

    @staticmethod
    def get_file_transfer_type(sql: str) -> FileTransferType | None:
        """Decide whether a SQL is a file transfer and return its type.

        None is returned if the SQL isn't a file transfer so that this function can be
        used in an if-statement.
        """
        if SnowflakeCursor.PUT_SQL_RE.match(sql):
            return FileTransferType.PUT
        elif SnowflakeCursor.GET_SQL_RE.match(sql):
            return FileTransferType.GET
        return None

    def __init__(
        self,
        connection: SnowflakeConnection,
        use_dict_result: bool = False,
    ) -> None:
        """Inits a SnowflakeCursor with a connection.

        Args:
            connection: The connection that created this cursor.
            use_dict_result: Decides whether to use dict result or not.
        """
        self._connection: SnowflakeConnection = connection

        self._errorhandler: Callable[
            [SnowflakeConnection, SnowflakeCursor, type[Error], dict[str, str]],
            None,
        ] = Error.default_errorhandler
        self.messages: list[
            tuple[type[Error] | type[Exception], dict[str, str | bool]]
        ] = []
        self._timebomb: Timer | None = None  # must be here for abort_exit method
        self._description: list[ResultMetadata] | None = None
        self._column_idx_to_name = None
        self._sfqid = None
        self._sqlstate = None
        self._total_rowcount = -1
        self._sequence_counter = -1
        self._request_id = None
        self._is_file_transfer = False

        self._timestamp_output_format = None
        self._timestamp_ltz_output_format = None
        self._timestamp_ntz_output_format = None
        self._timestamp_tz_output_format = None
        self._date_output_format = None
        self._time_output_format = None
        self._timezone = None
        self._binary_output_format = None
        self._result: Iterator[tuple] | Iterator[dict] | None = None
        self._result_set: ResultSet | None = None
        self._result_state: ResultState = ResultState.DEFAULT
        self._use_dict_result = use_dict_result
        self.query: str | None = None
        # TODO: self._query_result_format could be defined as an enum
        self._query_result_format: str | None = None

        self._arraysize = 1  # PEP-0249: defaults to 1

        self._lock_canceling = Lock()

        self._first_chunk_time = None

        self._log_max_query_length = connection.log_max_query_length
        self._inner_cursor: SnowflakeCursor | None = None
        self._prefetch_hook = None
        self._rownumber: int | None = None

        self.reset()

    def __del__(self) -> None:  # pragma: no cover
        try:
            self.close()
        except BASE_EXCEPTION_CLASS as e:
            if logger.getEffectiveLevel() <= logging.INFO:
                logger.info(e)

    @property
    def description(self) -> list[ResultMetadata]:
        return self._description

    @property
    def rowcount(self):
        return self._total_rowcount if self._total_rowcount >= 0 else None

    @property
    def rownumber(self):
        return self._rownumber if self._rownumber >= 0 else None

    @property
    def sfqid(self):
        return self._sfqid

    @property
    def sqlstate(self):
        return self._sqlstate

    @property
    def timestamp_output_format(self):
        return self._timestamp_output_format

    @property
    def timestamp_ltz_output_format(self):
        return (
            self._timestamp_ltz_output_format
            if self._timestamp_ltz_output_format
            else self._timestamp_output_format
        )

    @property
    def timestamp_tz_output_format(self):
        return (
            self._timestamp_tz_output_format
            if self._timestamp_tz_output_format
            else self._timestamp_output_format
        )

    @property
    def timestamp_ntz_output_format(self):
        return (
            self._timestamp_ntz_output_format
            if self._timestamp_ntz_output_format
            else self._timestamp_output_format
        )

    @property
    def date_output_format(self):
        return self._date_output_format

    @property
    def time_output_format(self):
        return self._time_output_format

    @property
    def timezone(self):
        return self._timezone

    @property
    def binary_output_format(self):
        return self._binary_output_format

    @property
    def arraysize(self):
        return self._arraysize

    @arraysize.setter
    def arraysize(self, value):
        self._arraysize = int(value)

    @property
    def connection(self):
        return self._connection

    @property
    def errorhandler(self):
        return self._errorhandler

    @errorhandler.setter
    def errorhandler(self, value):
        logger.debug("setting errorhandler: %s", value)
        if value is None:
            raise ProgrammingError("Invalid errorhandler is specified")
        self._errorhandler = value

    @property
    def is_file_transfer(self):
        """Whether the command is PUT or GET."""
        return hasattr(self, "_is_file_transfer") and self._is_file_transfer

    def callproc(self, procname, args=()):
        """Not supported."""
        Error.errorhandler_wrapper(
            self.connection,
            self,
            NotSupportedError,
            {
                "msg": "callproc is not supported.",
                "errno": ER_UNSUPPORTED_METHOD,
                "sqlstate": SQLSTATE_FEATURE_NOT_SUPPORTED,
            },
        )

    def close(self) -> bool | None:
        """Closes the cursor object.

        Returns whether the cursor was closed during this call.
        """
        try:
            if self.is_closed():
                return False

            with self._lock_canceling:
                self.reset()
                self._connection = None
                del self.messages[:]
                return True
        except Exception:
            pass

    def is_closed(self):
        return self._connection is None or self._connection.is_closed()

    def _execute_helper(
        self,
        query: str,
        timeout: int = 0,
        statement_params: dict[str, str] | None = None,
        binding_params: tuple | dict[str, dict[str, str]] = None,
        binding_stage: str | None = None,
        is_internal: bool = False,
        describe_only: bool = False,
        _no_results: bool = False,
        _is_put_get=None,
    ):
        del self.messages[:]

        if statement_params is not None and not isinstance(statement_params, dict):
            Error.errorhandler_wrapper(
                self.connection,
                self,
                ProgrammingError,
                {
                    "msg": "The data type of statement params is invalid. It must be dict.",
                    "errno": ER_INVALID_VALUE,
                },
            )

        # check if current installation include arrow extension or not,
        # if not, we set statement level query result format to be JSON
        if not CAN_USE_ARROW_RESULT_FORMAT:
            logger.debug("Cannot use arrow result format, fallback to json format")
            if statement_params is None:
                statement_params = {
                    PARAMETER_PYTHON_CONNECTOR_QUERY_RESULT_FORMAT: "JSON"
                }
            else:
                result_format_val = statement_params.get(
                    PARAMETER_PYTHON_CONNECTOR_QUERY_RESULT_FORMAT
                )
                if str(result_format_val).upper() == "ARROW":
                    self.check_can_use_arrow_resultset()
                elif result_format_val is None:
                    statement_params[
                        PARAMETER_PYTHON_CONNECTOR_QUERY_RESULT_FORMAT
                    ] = "JSON"

        self._sequence_counter = self._connection._next_sequence_counter()
        self._request_id = uuid.uuid4()

        logger.debug(f"Request id: {self._request_id}")

        if logger.getEffectiveLevel() <= logging.DEBUG:
            logger.debug("running query [%s]", self._format_query_for_log(query))
        if _is_put_get is not None:
            # if told the query is PUT or GET, use the information
            self._is_file_transfer = _is_put_get
        else:
            # or detect it.
            self._is_file_transfer = self.PUT_SQL_RE.match(
                query
            ) or self.GET_SQL_RE.match(query)
        logger.debug("is_file_transfer: %s", self._is_file_transfer is not None)

        real_timeout = (
            timeout if timeout and timeout > 0 else self._connection.network_timeout
        )

        if real_timeout is not None:
            self._timebomb = Timer(real_timeout, self.__cancel_query, [query])
            self._timebomb.start()
            logger.debug("started timebomb in %ss", real_timeout)
        else:
            self._timebomb = None

        original_sigint = signal.getsignal(signal.SIGINT)

        def interrupt_handler(*_):  # pragma: no cover
            try:
                signal.signal(signal.SIGINT, exit_handler)
            except (ValueError, TypeError):
                # ignore failures
                pass
            try:
                if self._timebomb is not None:
                    self._timebomb.cancel()
                    logger.debug("cancelled timebomb in finally")
                    self._timebomb = None
                self.__cancel_query(query)
            finally:
                if original_sigint:
                    try:
                        signal.signal(signal.SIGINT, original_sigint)
                    except (ValueError, TypeError):
                        # ignore failures
                        pass
            raise KeyboardInterrupt

        try:
            if not original_sigint == exit_handler:
                signal.signal(signal.SIGINT, interrupt_handler)
        except ValueError:  # pragma: no cover
            logger.debug(
                "Failed to set SIGINT handler. " "Not in main thread. Ignored..."
            )
        ret = {"data": {}}
        try:
            ret = self._connection.cmd_query(
                query,
                self._sequence_counter,
                self._request_id,
                binding_params=binding_params,
                binding_stage=binding_stage,
                is_file_transfer=bool(self._is_file_transfer),
                statement_params=statement_params,
                is_internal=is_internal,
                describe_only=describe_only,
                _no_results=_no_results,
            )
        finally:
            try:
                if original_sigint:
                    signal.signal(signal.SIGINT, original_sigint)
            except (ValueError, TypeError):  # pragma: no cover
                logger.debug(
                    "Failed to reset SIGINT handler. Not in main " "thread. Ignored..."
                )
            if self._timebomb is not None:
                self._timebomb.cancel()
                logger.debug("cancelled timebomb in finally")

        if "data" in ret and "parameters" in ret["data"]:
            parameters = ret["data"]["parameters"]
            # Set session parameters for cursor object
            for kv in parameters:
                if "TIMESTAMP_OUTPUT_FORMAT" in kv["name"]:
                    self._timestamp_output_format = kv["value"]
                if "TIMESTAMP_NTZ_OUTPUT_FORMAT" in kv["name"]:
                    self._timestamp_ntz_output_format = kv["value"]
                if "TIMESTAMP_LTZ_OUTPUT_FORMAT" in kv["name"]:
                    self._timestamp_ltz_output_format = kv["value"]
                if "TIMESTAMP_TZ_OUTPUT_FORMAT" in kv["name"]:
                    self._timestamp_tz_output_format = kv["value"]
                if "DATE_OUTPUT_FORMAT" in kv["name"]:
                    self._date_output_format = kv["value"]
                if "TIME_OUTPUT_FORMAT" in kv["name"]:
                    self._time_output_format = kv["value"]
                if "TIMEZONE" in kv["name"]:
                    self._timezone = kv["value"]
                if "BINARY_OUTPUT_FORMAT" in kv["name"]:
                    self._binary_output_format = kv["value"]
            # Set session parameters for connection object
            self._connection._update_parameters(
                {p["name"]: p["value"] for p in parameters}
            )

        self.query = query
        self._sequence_counter = -1
        return ret

    def execute(
        self,
        command: str,
        params: Sequence[Any] | dict[Any, Any] | None = None,
        _bind_stage: str | None = None,
        timeout: int | None = None,
        _exec_async: bool = False,
        _do_reset: bool = True,
        _put_callback: SnowflakeProgressPercentage = None,
        _put_azure_callback: SnowflakeProgressPercentage = None,
        _put_callback_output_stream: IO[str] = sys.stdout,
        _get_callback: SnowflakeProgressPercentage = None,
        _get_azure_callback: SnowflakeProgressPercentage = None,
        _get_callback_output_stream: IO[str] = sys.stdout,
        _show_progress_bar: bool = True,
        _statement_params: dict[str, str] | None = None,
        _is_internal: bool = False,
        _describe_only: bool = False,
        _no_results: bool = False,
        _use_ijson: bool = False,
        _is_put_get: bool | None = None,
        _raise_put_get_error: bool = True,
        _force_put_overwrite: bool = False,
        file_stream: IO[bytes] | None = None,
    ) -> SnowflakeCursor | None:
        """Executes a command/query.

        Args:
            command: The SQL command to be executed.
            params: Parameters to be bound into the SQL statement.
            _bind_stage: Path in temporary stage where binding parameters are uploaded as CSV files.
            timeout: Number of seconds after which to abort the query.
            _exec_async: Whether to execute this query asynchronously.
            _do_reset: Whether or not the result set needs to be reset before executing query.
            _put_callback: Function to which GET command should call back to.
            _put_azure_callback: Function to which an Azure GET command should call back to.
            _put_callback_output_stream: The output stream a PUT command's callback should report on.
            _get_callback: Function to which GET command should call back to.
            _get_azure_callback: Function to which an Azure GET command should call back to.
            _get_callback_output_stream: The output stream a GET command's callback should report on.
            _show_progress_bar: Whether or not to show progress bar.
            _statement_params: Extra information that should be sent to Snowflake with query.
            _is_internal: This flag indicates whether the query is issued internally by the connector.
            _describe_only: If true, the query will not be executed but return the schema/description of this query.
            _no_results: This flag tells the back-end to not return the result, just fire the query and return the
                response returned by Snowflake's server.
            _use_ijson: This flag doesn't do anything as ijson support has ended.
            _is_put_get: Force decision of this SQL query being a PUT, or GET command. This is detected otherwise.
            _raise_put_get_error: Whether to raise PUT and GET errors.
            _force_put_overwrite: If the SQL query is a PUT, then this flag can force overwriting of an already
                existing file on stage.
            file_stream: File-like object to be uploaded with PUT

        Returns:
            The cursor itself, or None if some error happened, or the response returned
            by Snowflake if the _no_results flag is on.
        """
        if _exec_async:
            _no_results = True
        logger.debug("executing SQL/command")
        if self.is_closed():
            Error.errorhandler_wrapper(
                self.connection,
                self,
                InterfaceError,
                {"msg": "Cursor is closed in execute.", "errno": ER_CURSOR_IS_CLOSED},
            )

        if _do_reset:
            self.reset()
        command = command.strip(" \t\n\r") if command else None
        if not command:
            logger.warning("execute: no query is given to execute")
            return

        kwargs = {
            "timeout": timeout,
            "statement_params": _statement_params,
            "is_internal": _is_internal,
            "describe_only": _describe_only,
            "_no_results": _no_results,
            "_is_put_get": _is_put_get,
        }

        if self._connection.is_pyformat:
            # pyformat/format paramstyle
            # client side binding
            processed_params = self._connection._process_params_pyformat(params, self)
            # SNOW-513061 collect telemetry for empty sequence usage before we make the breaking change announcement
            if params is not None and len(params) == 0:
                self._log_telemetry_job_data(
                    TelemetryField.EMPTY_SEQ_INTERPOLATION,
                    TelemetryData.TRUE
                    if self.connection._interpolate_empty_sequences
                    else TelemetryData.FALSE,
                )
            if logger.getEffectiveLevel() <= logging.DEBUG:
                logger.debug(
                    f"binding: [{self._format_query_for_log(command)}] "
                    f"with input=[{params}], "
                    f"processed=[{processed_params}]",
                )
            if (
                self.connection._interpolate_empty_sequences
                and processed_params is not None
            ) or (
                not self.connection._interpolate_empty_sequences
                and len(processed_params) > 0
            ):
                query = command % processed_params
            else:
                query = command
        else:
            # qmark and numeric paramstyle
            query = command
            if _bind_stage:
                kwargs["binding_stage"] = _bind_stage
            else:
                if params is not None and not isinstance(params, (list, tuple)):
                    errorvalue = {
                        "msg": f"Binding parameters must be a list: {params}",
                        "errno": ER_FAILED_PROCESSING_PYFORMAT,
                    }
                    Error.errorhandler_wrapper(
                        self.connection, self, ProgrammingError, errorvalue
                    )

                kwargs["binding_params"] = self._connection._process_params_qmarks(
                    params, self
                )

        m = DESC_TABLE_RE.match(query)
        if m:
            query1 = f"describe table {m.group(1)}"
            if logger.getEffectiveLevel() <= logging.WARNING:
                logger.info(
                    "query was rewritten: org=%s, new=%s",
                    " ".join(line.strip() for line in query.split("\n")),
                    query1,
                )
            query = query1

        if logger.getEffectiveLevel() <= logging.INFO:
            logger.info("query: [%s]", self._format_query_for_log(query))
        ret = self._execute_helper(query, **kwargs)
        self._sfqid = (
            ret["data"]["queryId"]
            if "data" in ret and "queryId" in ret["data"]
            else None
        )
        self._sqlstate = (
            ret["data"]["sqlState"]
            if "data" in ret and "sqlState" in ret["data"]
            else None
        )
        self._first_chunk_time = get_time_millis()

        # if server gives a send time, log the time it took to arrive
        if "data" in ret and "sendResultTime" in ret["data"]:
            time_consume_first_result = (
                self._first_chunk_time - ret["data"]["sendResultTime"]
            )
            self._log_telemetry_job_data(
                TelemetryField.TIME_CONSUME_FIRST_RESULT, time_consume_first_result
            )
        logger.debug("sfqid: %s", self.sfqid)

        logger.info("query execution done")
        if ret["success"]:
            logger.debug("SUCCESS")
            data = ret["data"]

            logger.debug("PUT OR GET: %s", self.is_file_transfer)
            if self.is_file_transfer:
                # Decide whether to use the old, or new code path
                sf_file_transfer_agent = SnowflakeFileTransferAgent(
                    self,
                    query,
                    ret,
                    put_callback=_put_callback,
                    put_azure_callback=_put_azure_callback,
                    put_callback_output_stream=_put_callback_output_stream,
                    get_callback=_get_callback,
                    get_azure_callback=_get_azure_callback,
                    get_callback_output_stream=_get_callback_output_stream,
                    show_progress_bar=_show_progress_bar,
                    raise_put_get_error=_raise_put_get_error,
                    force_put_overwrite=_force_put_overwrite
                    or data.get("overwrite", False),
                    source_from_stream=file_stream,
                    multipart_threshold=data.get("threshold"),
                    use_s3_regional_url=self._connection.enable_stage_s3_privatelink_for_us_east_1,
                )
                sf_file_transfer_agent.execute()
                data = sf_file_transfer_agent.result()
                self._total_rowcount = len(data["rowset"]) if "rowset" in data else -1
            m = self.ALTER_SESSION_RE.match(query)
            if m:
                # session parameters
                param = m.group(1).upper()
                value = m.group(2)
                self._connection.converter.set_parameter(param, value)

            if _exec_async:
                self.connection._async_sfqids.add(self._sfqid)
            if _no_results:
                self._total_rowcount = (
                    ret["data"]["total"]
                    if "data" in ret and "total" in ret["data"]
                    else -1
                )
                return data
            self._init_result_and_meta(data)
        else:
            self._total_rowcount = (
                ret["data"]["total"] if "data" in ret and "total" in ret["data"] else -1
            )
            logger.debug(ret)
            err = ret["message"]
            code = ret.get("code", -1)
            if "data" in ret:
                err += ret["data"].get("errorMessage", "")
            errvalue = {
                "msg": err,
                "errno": int(code),
                "sqlstate": self._sqlstate,
                "sfqid": self._sfqid,
            }
            is_integrity_error = (
                code == "100072"
            )  # NULL result in a non-nullable column
            error_class = IntegrityError if is_integrity_error else ProgrammingError
            Error.errorhandler_wrapper(self.connection, self, error_class, errvalue)
        return self

    def execute_async(self, *args, **kwargs):
        """Convenience function to execute a query without waiting for results (asynchronously).

        This function takes the same arguments as execute, please refer to that function
        for documentation.
        """
        kwargs["_exec_async"] = True
        return self.execute(*args, **kwargs)

    def describe(self, *args, **kwargs) -> list[ResultMetadata]:
        """Obtain the schema of the result without executing the query.

        This function takes the same arguments as execute, please refer to that function
        for documentation.

        Returns:
            The schema of the result.
        """
        kwargs["_describe_only"] = kwargs["_is_internal"] = True
        self.execute(*args, **kwargs)
        return self._description

    def _format_query_for_log(self, query):
        return self._connection._format_query_for_log(query)

    def _is_dml(self, data):
        return (
            "statementTypeId" in data
            and int(data["statementTypeId"]) in STATEMENT_TYPE_ID_DML_SET
        )

    def _init_result_and_meta(self, data):
        is_dml = self._is_dml(data)
        self._query_result_format = data.get("queryResultFormat", "json")
        logger.debug("Query result format: %s", self._query_result_format)

        if self._total_rowcount == -1 and not is_dml and data.get("total") is not None:
            self._total_rowcount = data["total"]

        self._description: list[ResultMetadata] = [
            ResultMetadata.from_column(col) for col in data["rowtype"]
        ]

        result_chunks = create_batches_from_response(
            self, self._query_result_format, data, self._description
        )

        self._result_set = ResultSet(
            self,
            result_chunks,
            self._connection.client_prefetch_threads,
        )
        self._rownumber = -1
        self._result_state = ResultState.VALID

        # don't update the row count when the result is returned from `describe` method
        if is_dml and "rowset" in data and len(data["rowset"]) > 0:
            updated_rows = 0
            for idx, desc in enumerate(self._description):
                if (
                    desc[0]
                    in (
                        "number of rows updated",
                        "number of multi-joined rows updated",
                        "number of rows deleted",
                    )
                    or desc[0].startswith("number of rows inserted")
                ):
                    updated_rows += int(data["rowset"][0][idx])
            if self._total_rowcount == -1:
                self._total_rowcount = updated_rows
            else:
                self._total_rowcount += updated_rows

    def check_can_use_arrow_resultset(self):
        global CAN_USE_ARROW_RESULT_FORMAT

        if not CAN_USE_ARROW_RESULT_FORMAT:
            if self._connection.application == "SnowSQL":
                msg = "Currently SnowSQL doesn't support the result set in Apache Arrow format."
                errno = ER_NO_PYARROW_SNOWSQL
            else:
                msg = "The result set in Apache Arrow format is not supported for the platform."
                errno = ER_NO_ARROW_RESULT

            Error.errorhandler_wrapper(
                self.connection,
                self,
                ProgrammingError,
                {
                    "msg": msg,
                    "errno": errno,
                },
            )

    def check_can_use_pandas(self):
        if not installed_pandas:
            msg = (
                "Optional dependency: 'pyarrow' is not installed, please see the following link for install "
                "instructions: https://docs.snowflake.com/en/user-guide/python-connector-pandas.html#installation"
            )
            errno = ER_NO_PYARROW

            Error.errorhandler_wrapper(
                self.connection,
                self,
                ProgrammingError,
                {
                    "msg": msg,
                    "errno": errno,
                },
            )

    def query_result(self, qid):
        url = f"/queries/{qid}/result"
        ret = self._connection.rest.request(url=url, method="get")
        self._sfqid = (
            ret["data"]["queryId"]
            if "data" in ret and "queryId" in ret["data"]
            else None
        )
        self._sqlstate = (
            ret["data"]["sqlState"]
            if "data" in ret and "sqlState" in ret["data"]
            else None
        )
        logger.debug("sfqid=%s", self._sfqid)

        if ret.get("success"):
            data = ret.get("data")
            self._init_result_and_meta(data)
        else:
            logger.info("failed")
            logger.debug(ret)
            err = ret["message"]
            code = ret.get("code", -1)
            if "data" in ret:
                err += ret["data"].get("errorMessage", "")
            errvalue = {
                "msg": err,
                "errno": int(code),
                "sqlstate": self._sqlstate,
                "sfqid": self._sfqid,
            }
            Error.errorhandler_wrapper(
                self.connection, self, ProgrammingError, errvalue
            )
        return self

    def fetch_arrow_batches(self) -> Iterator[Table]:
        self.check_can_use_arrow_resultset()
        if self._query_result_format != "arrow":
            raise NotSupportedError
        self._log_telemetry_job_data(
            TelemetryField.ARROW_FETCH_BATCHES, TelemetryData.TRUE
        )
        return self._result_set._fetch_arrow_batches()

    def fetch_arrow_all(self) -> Table | None:
        self.check_can_use_arrow_resultset()
        if self._query_result_format != "arrow":
            raise NotSupportedError
        self._log_telemetry_job_data(TelemetryField.ARROW_FETCH_ALL, TelemetryData.TRUE)
        return self._result_set._fetch_arrow_all()

    def fetch_pandas_batches(self, **kwargs) -> Iterator[pandas.DataFrame]:
        """Fetches a single Arrow Table."""
        self.check_can_use_pandas()
        if self._prefetch_hook is not None:
            self._prefetch_hook()
            raise NotSupportedError
        self._log_telemetry_job_data(
            TelemetryField.PANDAS_FETCH_BATCHES, TelemetryData.TRUE
        )
        return self._result_set._fetch_pandas_batches(**kwargs)

    def fetch_pandas_all(self, **kwargs) -> pandas.DataFrame:
        """Fetch Pandas dataframes in batches, where 'batch' refers to Snowflake Chunk."""
        self.check_can_use_pandas()
        if self._prefetch_hook is not None:
            self._prefetch_hook()
        if self._query_result_format != "arrow":
            raise NotSupportedError
        self._log_telemetry_job_data(
            TelemetryField.PANDAS_FETCH_ALL, TelemetryData.TRUE
        )
        return self._result_set._fetch_pandas_all(**kwargs)

    def abort_query(self, qid):
        url = f"/queries/{qid}/abort-request"
        ret = self._connection.rest.request(url=url, method="post")
        return ret.get("success")

    def executemany(
        self,
        command: str,
        seqparams: Sequence[Any] | dict[str, Any],
    ) -> SnowflakeCursor:
        """Executes a command/query with the given set of parameters sequentially."""
        logger.debug("executing many SQLs/commands")
        command = command.strip(" \t\n\r") if command else None

        if not seqparams:
            return self

        if self.INSERT_SQL_RE.match(command):
            if self._connection.is_pyformat:
                logger.debug("rewriting INSERT query")
                command_wo_comments = re.sub(self.COMMENT_SQL_RE, "", command)
                m = self.INSERT_SQL_VALUES_RE.match(command_wo_comments)
                if not m:
                    Error.errorhandler_wrapper(
                        self.connection,
                        self,
                        InterfaceError,
                        {
                            "msg": "Failed to rewrite multi-row insert",
                            "errno": ER_FAILED_TO_REWRITE_MULTI_ROW_INSERT,
                        },
                    )

                fmt = m.group(1)
                values = []
                for param in seqparams:
                    logger.debug(f"parameter: {param}")
                    values.append(
                        fmt % self._connection._process_params_pyformat(param, self)
                    )
                command = command.replace(fmt, ",".join(values), 1)
                self.execute(command)
                return self
            else:
                logger.debug("bulk insert")
                # sanity check
                row_size = len(seqparams[0])
                for row in seqparams:
                    if len(row) != row_size:
                        error_value = {
                            "msg": f"Bulk data size don't match. expected: {row_size}, "
                            f"got: {len(row)}, command: {command}",
                            "errno": ER_INVALID_VALUE,
                        }
                        Error.errorhandler_wrapper(
                            self.connection, self, InterfaceError, error_value
                        )
                        return self
                bind_size = len(seqparams) * row_size
                bind_stage = None
                if (
                    bind_size
                    > self.connection._session_parameters[
                        "CLIENT_STAGE_ARRAY_BINDING_THRESHOLD"
                    ]
                    > 0
                ):
                    # bind stage optimization
                    try:
                        rows = self.connection._write_params_to_byte_rows(seqparams)
                        bind_uploader = BindUploadAgent(self, rows)
                        bind_uploader.upload()
                        bind_stage = bind_uploader.stage_path
                    except BindUploadError:
                        logger.debug(
                            "Failed to upload binds to stage, sending binds to "
                            "Snowflake instead."
                        )
                    except Exception as exc:
                        if not isinstance(exc, INCIDENT_BLACKLIST):
                            self.connection.incident.report_incident()
                        raise
                binding_param = (
                    None if bind_stage else list(map(list, zip(*seqparams)))
                )  # transpose
                self.execute(command, params=binding_param, _bind_stage=bind_stage)
                return self

        self.reset()
        for param in seqparams:
            self.execute(command, param, _do_reset=False)
        return self

    def _result_iterator(
        self,
    ) -> Generator[dict, None, None] | Generator[tuple, None, None]:
        """Yields the elements from _result and raises an exception when appropriate."""
        try:
            for _next in self._result:
                if isinstance(_next, Exception):
                    Error.errorhandler_wrapper_from_ready_exception(
                        self._connection,
                        self,
                        _next,
                    )
                self._rownumber += 1
                yield _next
        except TypeError as err:
            if self._result_state == ResultState.DEFAULT:
                raise err
            else:
                yield None

    def fetchone(self) -> dict | tuple | None:
        """Fetches one row."""
        if self._prefetch_hook is not None:
            self._prefetch_hook()
        if self._result is None and self._result_set is not None:
            self._result = iter(self._result_set)
            self._result_state = ResultState.VALID
        try:
            return next(self._result_iterator())
        except StopIteration:
            return None

    def fetchmany(self, size=None):
        """Fetches the number of specified rows."""
        if size is None:
            size = self.arraysize

        if size < 0:
            errorvalue = {
                "msg": (
                    "The number of rows is not zero or " "positive number: {}"
                ).format(size),
                "errno": ER_NOT_POSITIVE_SIZE,
            }
            Error.errorhandler_wrapper(
                self.connection, self, ProgrammingError, errorvalue
            )
        ret = []
        while size > 0:
            row = self.fetchone()
            if row is None:
                break
            ret.append(row)
            if size is not None:
                size -= 1

        return ret

    def fetchall(self) -> list[tuple] | list[dict]:
        """Fetches all of the results."""
        ret = []
        while True:
            row = self.fetchone()
            if row is None:
                break
            ret.append(row)
        return ret

    def nextset(self):
        """Not supported."""
        logger.debug("nop")
        return None

    def setinputsizes(self, _):
        """Not supported."""
        logger.debug("nop")

    def setoutputsize(self, _, column=None):
        """Not supported."""
        del column
        logger.debug("nop")

    def scroll(self, value, mode="relative"):
        Error.errorhandler_wrapper(
            self.connection,
            self,
            NotSupportedError,
            {
                "msg": "scroll is not supported.",
                "errno": ER_UNSUPPORTED_METHOD,
                "sqlstate": SQLSTATE_FEATURE_NOT_SUPPORTED,
            },
        )

    def reset(self):
        """Resets the result set."""
        self._total_rowcount = -1  # reset the rowcount
        if self._result_state != ResultState.DEFAULT:
            self._result_state = ResultState.RESET
        if self._result is not None:
            self._result = None
        if self._inner_cursor is not None:
            self._inner_cursor.reset()
            self._result = None
            self._inner_cursor = None
        self._prefetch_hook = None
        if not self.connection._reuse_results:
            self._result_set = None

    def __iter__(self) -> Iterator[dict] | Iterator[tuple]:
        """Iteration over the result set."""
        # set _result if _result_set is not None
        if self._result is None and self._result_set is not None:
            self._result = iter(self._result_set)
            self._result_state = ResultState.VALID
        return self._result_iterator()

    def __cancel_query(self, query):
        if self._sequence_counter >= 0 and not self.is_closed():
            logger.debug("canceled. %s, request_id: %s", query, self._request_id)
            with self._lock_canceling:
                self._connection._cancel_query(query, self._request_id)

    def _log_telemetry_job_data(
        self, telemetry_field: TelemetryField, value: Any
    ) -> None:
        """Builds an instance of TelemetryData with the given field and logs it."""
        obj = {
            "type": telemetry_field.value,
            "source": self._connection.application if self._connection else CLIENT_NAME,
            "query_id": self._sfqid,
            "value": int(value),
        }
        ts = get_time_millis()
        try:
            self._connection._log_telemetry(TelemetryData(obj, ts))
        except AttributeError:
            logger.warning(
                "Cursor failed to log to telemetry. Connection object may be None.",
                exc_info=True,
            )

    def __enter__(self):
        """Context manager."""
        return self

    def __exit__(self, exc_type, exc_val, exc_tb):
        """Context manager with commit or rollback."""
        self.close()

    def get_results_from_sfqid(self, sfqid: str):
        """Gets the results from previously ran query."""

        def wait_until_ready():
            """Makes sure query has finished executing and once it has retrieves results."""
            no_data_counter = 0
            retry_pattern_pos = 0
            while True:
                status = self.connection.get_query_status(sfqid)
                if not self.connection.is_still_running(status):
                    break
                if status == QueryStatus.NO_DATA:  # pragma: no cover
                    no_data_counter += 1
                    if no_data_counter > ASYNC_NO_DATA_MAX_RETRY:
                        raise DatabaseError(
                            "Cannot retrieve data on the status of this query. No information returned "
                            "from server for query '{}'"
                        )
                time.sleep(
                    0.5 * ASYNC_RETRY_PATTERN[retry_pattern_pos]
                )  # Same wait as JDBC
                # If we can advance in ASYNC_RETRY_PATTERN then do so
                if retry_pattern_pos < (len(ASYNC_RETRY_PATTERN) - 1):
                    retry_pattern_pos += 1
            if status != QueryStatus.SUCCESS:
                raise DatabaseError(
                    "Status of query '{}' is {}, results are unavailable".format(
                        sfqid, status.name
                    )
                )
            self._inner_cursor.execute(f"select * from table(result_scan('{sfqid}'))")
            self._result = self._inner_cursor._result
            self._query_result_format = self._inner_cursor._query_result_format
            self._total_rowcount = self._inner_cursor._total_rowcount
            self._description = self._inner_cursor._description
            self._result_set = self._inner_cursor._result_set
            self._result_state = ResultState.VALID
            self._rownumber = 0
            # Unset this function, so that we don't block anymore
            self._prefetch_hook = None

        self.connection.get_query_status_throw_if_error(
            sfqid
        )  # Trigger an exception if query failed
        klass = self.__class__
        self._inner_cursor = klass(self.connection)
        self._sfqid = sfqid
        self._prefetch_hook = wait_until_ready

    def get_result_batches(self) -> list[ResultBatch] | None:
        """Get the previously executed query's ``ResultBatch`` s if available.

        If they are unavailable, in case nothing has been executed yet None will
        be returned.

        For a detailed description of ``ResultBatch`` s please see the docstring of:
        ``snowflake.connector.result_batches.ResultBatch``
        """
        if self._result_set is None:
            return None
        self._log_telemetry_job_data(
            TelemetryField.GET_PARTITIONS_USED, TelemetryData.TRUE
        )
        return self._result_set.batches


class DictCursor(SnowflakeCursor):
    """Cursor returning results in a dictionary."""

    def __init__(self, connection):
        super().__init__(
            connection,
            use_dict_result=True,
        )