Why Gemfury? Push, build, and install  RubyGems npm packages Python packages Maven artifacts PHP packages Go Modules Debian packages RPM packages NuGet packages

Repository URL to install this package:

Details    
snowflake-connector-python / src / snowflake / connector / file_transfer_agent.py
Size: Mime:
#!/usr/bin/env python
#
# Copyright (c) 2012-2021 Snowflake Computing Inc. All rights reserved.
#

from __future__ import annotations

import binascii
import glob
import mimetypes
import os
import sys
import threading
from concurrent.futures.thread import ThreadPoolExecutor
from dataclasses import dataclass
from functools import partial
from logging import getLogger
from time import time
from typing import IO, TYPE_CHECKING, Any, Callable, TypeVar

from .azure_storage_client import SnowflakeAzureRestClient
from .compat import GET_CWD, IS_WINDOWS
from .constants import (
    AZURE_FS,
    CMD_TYPE_DOWNLOAD,
    CMD_TYPE_UPLOAD,
    GCS_FS,
    LOCAL_FS,
    S3_FS,
    ResultStatus,
    megabyte,
)
from .converter_snowsql import SnowflakeConverterSnowSQL
from .errorcode import (
    ER_COMPRESSION_NOT_SUPPORTED,
    ER_FAILED_TO_DOWNLOAD_FROM_STAGE,
    ER_FAILED_TO_UPLOAD_TO_STAGE,
    ER_FILE_NOT_EXISTS,
    ER_INTERNAL_NOT_MATCH_ENCRYPT_MATERIAL,
    ER_INVALID_STAGE_FS,
    ER_INVALID_STAGE_LOCATION,
    ER_LOCAL_PATH_NOT_DIRECTORY,
)
from .errors import (
    DatabaseError,
    Error,
    InternalError,
    OperationalError,
    ProgrammingError,
)
from .file_compression_type import CompressionTypes, lookup_by_mime_sub_type
from .gcs_storage_client import SnowflakeGCSRestClient
from .local_storage_client import SnowflakeLocalStorageClient
from .s3_storage_client import SnowflakeS3RestClient
from .storage_client import SnowflakeFileEncryptionMaterial, SnowflakeStorageClient

if TYPE_CHECKING:  # pragma: no cover
    from .connection import SnowflakeConnection
    from .cursor import SnowflakeCursor
    from .file_compression_type import CompressionType

VALID_STORAGE = [LOCAL_FS, S3_FS, AZURE_FS, GCS_FS]

INJECT_WAIT_IN_PUT = 0

logger = getLogger(__name__)


def result_text_column_desc(name):
    return {
        "name": name,
        "type": "text",
        "length": 16777216,
        "precision": None,
        "scale": None,
        "nullable": False,
    }


def result_fixed_column_desc(name):
    return {
        "name": name,
        "type": "fixed",
        "length": 5,
        "precision": 0,
        "scale": 0,
        "nullable": False,
    }


# TODO: rewrite, we use this class to store information about file transfers
#  It'd make more sense to define a new object, like FileTransferMeta that then has
#  more FileMetas inside of it. This would help in some cases, where for example
#  consider the case where we run into an unrecoverable error for the whole transfer
#  job and we need to convey an error to the main thread and that error needs to be
#  raised by the main thread. Where should this go? Currently the answer could be
#  all of the current FileMetas. Hmmm...
@dataclass
class SnowflakeFileMeta:
    """Class to keep track of information necessary for file operations."""

    name: str
    src_file_name: str
    stage_location_type: str
    result_status: ResultStatus | None = None

    self: SnowflakeFileTransferAgent | None = None
    put_callback: type[SnowflakeProgressPercentage] | None = None
    put_azure_callback: type[SnowflakeProgressPercentage] | None = None
    put_callback_output_stream: IO[str] | None = None
    get_callback: type[SnowflakeProgressPercentage] | None = None
    get_azure_callback: type[SnowflakeProgressPercentage] | None = None
    get_callback_output_stream: IO[str] | None = None
    show_progress_bar: bool = False
    multipart_threshold: int = 67108864  # Historical value
    presigned_url: str | None = None
    overwrite: bool = False
    sha256_digest: str | None = None
    upload_size: int | None = None
    real_src_file_name: str | None = None
    error_details: Exception | None = None
    last_error: Exception | None = None
    no_sleeping_time: bool = False
    gcs_file_header_digest: str | None = None
    gcs_file_header_content_length: int | None = None
    gcs_file_header_encryption_metadata: dict[str, Any] | None = None

    encryption_material: SnowflakeFileEncryptionMaterial | None = None
    # Specific to Uploads only
    src_file_size: int = 0
    src_compression_type: CompressionType | None = None
    dst_compression_type: CompressionType = None
    require_compress: bool = False
    dst_file_name: str | None = None
    dst_file_size: int = -1
    intermediate_stream: IO[bytes] | None = None
    src_stream: IO[bytes] | None = None
    # Specific to Downloads only
    local_location: str | None = None


def _update_progress(
    file_name: str,
    start_time: float,
    total_size: float,
    progress: float | int,
    output_stream: IO | None = sys.stdout,
    show_progress_bar: bool | None = True,
) -> bool:
    bar_length = 10  # Modify this to change the length of the progress bar
    total_size /= megabyte
    status = ""
    elapsed_time = time() - start_time
    throughput = (total_size / elapsed_time) if elapsed_time != 0.0 else 0.0
    if isinstance(progress, int):
        progress = float(progress)
    if not isinstance(progress, float):
        progress = 0
        status = "error: progress var must be float\r\n"
    if progress < 0:
        progress = 0
        status = "Halt...\r\n"
    if progress >= 1:
        progress = 1
        status = f"Done ({elapsed_time:.3f}s, {throughput:.2f}MB/s).\r\n"
    if not status and show_progress_bar:
        status = f"({elapsed_time:.3f}s, {throughput:.2f}MB/s)"
    if status:
        block = int(round(bar_length * progress))

        text = (
            f"\r{file_name}({total_size:.2f}MB): "
            f"[{'#' * block + '-' * (bar_length - block)}] "
            f"{progress * 100.0:.2f}% {status}"
        )
        output_stream.write(text)
        output_stream.flush()
    logger.debug(
        f"filename: {file_name}, start_time: {start_time}, total_size: {total_size}, "
        f"progress: {progress}, show_progress_bar: {show_progress_bar}"
    )
    return progress == 1.0


def percent(seen_so_far: int, size: float) -> float:
    return 1.0 if seen_so_far >= size or size <= 0 else float(seen_so_far / size)


class SnowflakeProgressPercentage:
    """Built-in Progress bar for PUT commands."""

    def __init__(
        self,
        filename: str,
        filesize: int | float,
        output_stream: IO | None = sys.stdout,
        show_progress_bar: bool | None = True,
    ):
        last_pound_char = filename.rfind("#")
        if last_pound_char < 0:
            last_pound_char = len(filename)
        self._filename = os.path.basename(filename[0:last_pound_char])
        self._output_stream = output_stream
        self._show_progress_bar = show_progress_bar
        self._size = float(filesize)
        self._seen_so_far = 0
        self._done = False
        self._start_time = time()
        self._lock = threading.Lock()

    def __call__(self, bytes_amount: int):
        raise NotImplementedError


class SnowflakeS3ProgressPercentage(SnowflakeProgressPercentage):
    def __init__(
        self,
        filename: str,
        filesize: int | float,
        output_stream: IO | None = sys.stdout,
        show_progress_bar: bool | None = True,
    ):
        super().__init__(
            filename,
            filesize,
            output_stream=output_stream,
            show_progress_bar=show_progress_bar,
        )

    def __call__(self, bytes_amount: int):
        with self._lock:
            if self._output_stream:
                self._seen_so_far += bytes_amount
                percentage = percent(self._seen_so_far, self._size)
                if not self._done:
                    self._done = _update_progress(
                        self._filename,
                        self._start_time,
                        self._size,
                        percentage,
                        output_stream=self._output_stream,
                        show_progress_bar=self._show_progress_bar,
                    )


class SnowflakeAzureProgressPercentage(SnowflakeProgressPercentage):
    def __init__(
        self,
        filename: str,
        filesize: int | float,
        output_stream: IO | None = sys.stdout,
        show_progress_bar: bool | None = True,
    ):
        super().__init__(
            filename,
            filesize,
            output_stream=output_stream,
            show_progress_bar=show_progress_bar,
        )

    def __call__(self, current: int):
        with self._lock:
            if self._output_stream:
                self._seen_so_far = current
                percentage = percent(self._seen_so_far, self._size)
                if not self._done:
                    self._done = _update_progress(
                        self._filename,
                        self._start_time,
                        self._size,
                        percentage,
                        output_stream=self._output_stream,
                        show_progress_bar=self._show_progress_bar,
                    )


class StorageCredential:
    def __init__(
        self,
        credentials: dict[str, Any],
        connection: SnowflakeConnection,
        command: str,
    ):
        self.creds = credentials
        self.timestamp = time()
        self.lock = threading.Lock()
        self.connection = connection
        self._command = command

    def update(self, cur_timestamp):
        with self.lock:
            if cur_timestamp < self.timestamp:
                return
            logger.debug("Renewing expired storage token.")
            ret = self.connection.cursor()._execute_helper(self._command)
            self.creds = ret["data"]["stageInfo"]["creds"]
            self.timestamp = time()


@dataclass
class TransferMetadata:
    num_files_started: int = 0
    num_files_completed: int = 0
    chunks_in_queue: int = 0


class SnowflakeFileTransferAgent:
    """Snowflake File Transfer Agent provides cloud provider independent implementation for putting/getting files."""

    def __init__(
        self,
        cursor: SnowflakeCursor,
        command: str,
        ret: dict[str, Any],
        put_callback: type[SnowflakeProgressPercentage] | None = None,
        put_azure_callback: type[SnowflakeProgressPercentage] | None = None,
        put_callback_output_stream: IO[str] = sys.stdout,
        get_callback: type[SnowflakeProgressPercentage] | None = None,
        get_azure_callback: type[SnowflakeProgressPercentage] | None = None,
        get_callback_output_stream: IO[str] = sys.stdout,
        show_progress_bar: bool = True,
        raise_put_get_error: bool = True,
        force_put_overwrite: bool = True,
        multipart_threshold: int | None = None,
        source_from_stream: IO[bytes] | None = None,
        use_s3_regional_url: bool = False,
    ):
        self._cursor = cursor
        self._command = command
        self._ret = ret
        self._put_callback = put_callback
        self._put_azure_callback = (
            put_azure_callback if put_azure_callback else put_callback
        )
        self._put_callback_output_stream = put_callback_output_stream
        self._get_callback = get_callback
        self._get_azure_callback = (
            get_azure_callback if get_azure_callback else get_callback
        )
        self._get_callback_output_stream = get_callback_output_stream
        self._use_accelerate_endpoint = False
        self._raise_put_get_error = raise_put_get_error
        self._show_progress_bar = show_progress_bar
        self._force_put_overwrite = force_put_overwrite
        self._source_from_stream = source_from_stream
        # The list of self-sufficient file metas that are sent to
        # remote storage clients to get operated on.
        self._file_metadata: list[SnowflakeFileMeta] = []
        self._results: list[SnowflakeFileMeta] = []
        self._multipart_threshold = multipart_threshold or 67108864  # Historical value
        self._use_s3_regional_url = use_s3_regional_url
        self._credentials: StorageCredential | None = None

    def execute(self) -> None:
        self._parse_command()
        self._init_file_metadata()

        if self._command_type == CMD_TYPE_UPLOAD:
            self._process_file_compression_type()

        for m in self._file_metadata:
            m.self = self

        self._transfer_accelerate_config()

        if self._command_type == CMD_TYPE_DOWNLOAD:
            if not os.path.isdir(self._local_location):
                os.makedirs(self._local_location)

        if self._stage_location_type == LOCAL_FS:
            if not os.path.isdir(self._stage_info["location"]):
                os.makedirs(self._stage_info["location"])

        for m in self._file_metadata:
            m.overwrite = self._overwrite
            m.self = self
            if self._stage_location_type != LOCAL_FS:
                m.put_callback = self._put_callback
                m.put_azure_callback = self._put_azure_callback
                m.put_callback_output_stream = self._put_callback_output_stream
                m.get_callback = self._get_callback
                m.get_azure_callback = self._get_azure_callback
                m.get_callback_output_stream = self._get_callback_output_stream
                m.show_progress_bar = self._show_progress_bar

                # multichunk threshold
                m.multipart_threshold = self._multipart_threshold

        logger.debug(f"parallel=[{self._parallel}]")
        if self._raise_put_get_error and not self._file_metadata:
            Error.errorhandler_wrapper(
                self._cursor.connection,
                self._cursor,
                OperationalError,
                {
                    "msg": "While getting file(s) there was an error: "
                    "the file does not exist.",
                    "errno": ER_FILE_NOT_EXISTS,
                },
            )
        self.transfer(self._file_metadata)

        # turn enum to string, in order to have backward compatible interface
        for result in self._results:
            result.result_status = result.result_status.value

    def transfer(self, metas: list[SnowflakeFileMeta]) -> None:
        max_concurrency = self._parallel
        network_tpe = ThreadPoolExecutor(max_concurrency)
        preprocess_tpe = ThreadPoolExecutor(min(len(metas), os.cpu_count()))
        postprocess_tpe = ThreadPoolExecutor(min(len(metas), os.cpu_count()))
        logger.debug(f"Chunk ThreadPoolExecutor size: {max_concurrency}")
        cv_main_thread = threading.Condition()  # to signal the main thread
        cv_chunk_process = (
            threading.Condition()
        )  # to get more chunks into the chunk_tpe
        files = [self._create_file_transfer_client(m) for m in metas]
        num_total_files = len(metas)
        transfer_metadata = TransferMetadata()  # this is protected by cv_chunk_process
        is_upload = self._command_type == CMD_TYPE_UPLOAD
        exception_caught_in_callback: Exception | None = None

        def notify_file_completed():
            # Increment the number of completed files, then notify the main thread.
            with cv_main_thread:
                transfer_metadata.num_files_completed += 1
                cv_main_thread.notify()

        def preprocess_done_cb(
            success: bool,
            result: Any,
            file_meta: SnowflakeFileMeta,
            done_client: SnowflakeStorageClient,
        ):
            if not success:
                logger.debug(f"Failed to prepare {done_client.meta.name}.")
                if is_upload:
                    done_client.finish_upload()
                    done_client.delete_client_data()
                else:
                    done_client.finish_download()
                notify_file_completed()
            elif done_client.meta.result_status == ResultStatus.SKIPPED:
                # this case applies to upload only
                notify_file_completed()
            else:
                logger.debug(f"Finished preparing file {done_client.meta.name}")
                with cv_chunk_process:
                    while transfer_metadata.chunks_in_queue > 2 * max_concurrency:
                        logger.debug(
                            "Chunk queue busy, waiting in file done callback..."
                        )
                        cv_chunk_process.wait()
                    for _chunk_id in range(done_client.num_of_chunks):
                        _callback = partial(
                            transfer_done_cb,
                            done_client=done_client,
                            chunk_id=_chunk_id,
                        )
                        if is_upload:
                            network_tpe.submit(
                                function_and_callback_wrapper,
                                # Work fn
                                done_client.upload_chunk,
                                # Callback fn
                                _callback,
                                file_meta,
                                # Arguments for work fn
                                _chunk_id,
                            )
                        else:
                            network_tpe.submit(
                                function_and_callback_wrapper,
                                # Work fn
                                done_client.download_chunk,
                                # Callback fn
                                _callback,
                                file_meta,
                                # Arguments for work fn
                                _chunk_id,
                            )
                        transfer_metadata.chunks_in_queue += 1
                    cv_chunk_process.notify()

        def transfer_done_cb(
            success: bool,
            result: Any,
            file_meta: SnowflakeFileMeta,
            done_client: SnowflakeStorageClient,
            chunk_id: int,
        ):
            # Note: chunk_id is 0 based while num_of_chunks is count
            logger.debug(
                f"Chunk {chunk_id}/{done_client.num_of_chunks} of file {done_client.meta.name} reached callback"
            )
            with cv_chunk_process:
                transfer_metadata.chunks_in_queue -= 1
                cv_chunk_process.notify()

            with done_client.lock:
                if not success:
                    # TODO: Cancel other chunks?
                    done_client.failed_transfers += 1
                    logger.debug(
                        f"Chunk {chunk_id} of file {done_client.meta.name} failed to transfer for unexpected exception {result}"
                    )
                else:
                    done_client.successful_transfers += 1
                logger.debug(
                    f"Chunk progress: {done_client.meta.name}: completed: {done_client.successful_transfers} failed: {done_client.failed_transfers} total: {done_client.num_of_chunks}"
                )
                if (
                    done_client.successful_transfers + done_client.failed_transfers
                    == done_client.num_of_chunks
                ):
                    if is_upload:
                        done_client.finish_upload()
                        done_client.delete_client_data()
                        notify_file_completed()
                    else:
                        postprocess_tpe.submit(
                            function_and_callback_wrapper,
                            # Work fn
                            done_client.finish_download,
                            # Callback fn
                            partial(postprocess_done_cb, done_client=done_client),
                            transfer_metadata,
                        )
                        logger.debug(
                            f"submitting {done_client.meta.name} to done_postprocess"
                        )

        def postprocess_done_cb(
            success: bool,
            result: Any,
            file_meta: SnowflakeFileMeta,
            done_client: SnowflakeStorageClient,
        ):
            logger.debug(f"File {done_client.meta.name} reached postprocess callback")

            with done_client.lock:
                if not success:
                    done_client.failed_transfers += 1
                    logger.debug(
                        f"File {done_client.meta.name} failed to transfer for unexpected exception {result}"
                    )
                # Whether there was an exception or not, we're done the file.
                notify_file_completed()

        _T = TypeVar("_T")

        def function_and_callback_wrapper(
            work: Callable[..., _T],
            _callback: Callable[[bool, _T | Exception, SnowflakeFileMeta], None],
            file_meta: SnowflakeFileMeta,
            *args: Any,
            **kwargs: Any,
        ) -> None:
            """This wrapper makes sure that callbacks are called from the TPEs.

            If the main thread adds a callback to a future that has already been
            fulfilled then the callback is executed by the main thread. This can
            lead to unexpected slowdowns and behavior.
            """
            try:
                result: tuple[bool, _T | Exception] = (
                    True,
                    work(*args, **kwargs),
                )
            except Exception as e:
                logger.error(f"An exception was raised in {repr(work)}", exc_info=True)
                file_meta.error_details = e
                result = (False, e)
            try:
                _callback(*result, file_meta)
            except Exception as e:
                # TODO: if an exception happens in a callback, the exception will not
                #  propagate to the main thread. We need to save these Exceptions
                #  somewhere and then re-raise by the main thread. For now let's log
                #  this exception, but for a long term solution see my
                #  TODO comment for SnowflakeFileMeta
                with cv_main_thread:
                    nonlocal exception_caught_in_callback
                    exception_caught_in_callback = e
                    cv_main_thread.notify()
                if not result[0]:
                    # Re-raising the exception from the work function, it would already
                    #  be logged at this point
                    logger.error(
                        f"An exception was raised in {repr(callback)}", exc_info=True
                    )

        for file_client in files:
            callback = partial(preprocess_done_cb, done_client=file_client)
            if is_upload:
                preprocess_tpe.submit(
                    function_and_callback_wrapper,
                    # Work fn
                    file_client.prepare_upload,
                    # Callback fn
                    callback,
                    file_client.meta,
                )
            else:
                preprocess_tpe.submit(
                    function_and_callback_wrapper,
                    # Work fn
                    file_client.prepare_download,
                    # Callback fn
                    callback,
                    file_client.meta,
                )
            transfer_metadata.num_files_started += 1  # TODO: do we need this?

        with cv_main_thread:
            while transfer_metadata.num_files_completed < num_total_files:
                cv_main_thread.wait()
                if exception_caught_in_callback is not None:
                    raise exception_caught_in_callback

        self._results = metas

    def _create_file_transfer_client(
        self, meta: SnowflakeFileMeta
    ) -> SnowflakeStorageClient:
        from .constants import AZURE_CHUNK_SIZE, S3_CHUNK_SIZE

        if self._stage_location_type == LOCAL_FS:
            return SnowflakeLocalStorageClient(
                meta,
                self._stage_info,
                4 * megabyte,
                use_s3_regional_url=self._use_s3_regional_url,
            )
        elif self._stage_location_type == AZURE_FS:
            return SnowflakeAzureRestClient(
                meta,
                self._credentials,
                AZURE_CHUNK_SIZE,
                self._stage_info,
                use_s3_regional_url=self._use_s3_regional_url,
            )
        elif self._stage_location_type == S3_FS:
            return SnowflakeS3RestClient(
                meta,
                self._credentials,
                self._stage_info,
                S3_CHUNK_SIZE,
                use_s3_regional_url=self._use_s3_regional_url,
            )
        elif self._stage_location_type == GCS_FS:
            return SnowflakeGCSRestClient(
                meta,
                self._credentials,
                self._stage_info,
                self._cursor._connection,
                self._command,
                use_s3_regional_url=self._use_s3_regional_url,
            )
        raise Exception(f"{self._stage_location_type} is an unknown stage type")

    def _transfer_accelerate_config(self) -> None:
        if self._stage_location_type == S3_FS and self._file_metadata:
            client = self._create_file_transfer_client(self._file_metadata[0])
            self._use_accelerate_endpoint = client.transfer_accelerate_config()

    def result(self):
        converter_class = self._cursor._connection.converter_class
        rowset = []
        if self._command_type == CMD_TYPE_UPLOAD:
            if hasattr(self, "_results"):
                for meta in self._results:
                    if meta.src_compression_type is not None:
                        src_compression_type = meta.src_compression_type.name
                    else:
                        src_compression_type = "NONE"

                    if meta.dst_compression_type is not None:
                        dst_compression_type = meta.dst_compression_type.name
                    else:
                        dst_compression_type = "NONE"

                    error_details: str = (
                        repr(meta.error_details)
                        if meta.error_details is not None
                        else ""
                    )

                    src_file_size = (
                        meta.src_file_size
                        if converter_class != SnowflakeConverterSnowSQL
                        else str(meta.src_file_size)
                    )

                    dst_file_size = (
                        meta.dst_file_size
                        if converter_class != SnowflakeConverterSnowSQL
                        else str(meta.dst_file_size)
                    )

                    logger.debug(
                        f"raise_put_get_error: {self._raise_put_get_error}, "
                        f"{meta.result_status}, {type(meta.result_status)}, "
                        f"{ResultStatus.ERROR}, {type(ResultStatus.ERROR)}",
                    )
                    if self._raise_put_get_error and error_details:
                        Error.errorhandler_wrapper(
                            self._cursor.connection,
                            self._cursor,
                            OperationalError,
                            {
                                "msg": "While putting file(s) there was an error: "
                                f"'{error_details}', this might be caused by "
                                f"your access to the blob storage provider, "
                                f"or by Snowflake.",
                                "errno": ER_FAILED_TO_UPLOAD_TO_STAGE,
                            },
                        )
                    rowset.append(
                        [
                            meta.name,
                            meta.dst_file_name,
                            src_file_size,
                            dst_file_size,
                            src_compression_type,
                            dst_compression_type,
                            meta.result_status,
                            error_details,
                        ]
                    )
            return {
                "rowtype": [
                    result_text_column_desc("source"),
                    result_text_column_desc("target"),
                    result_fixed_column_desc("source_size"),
                    result_fixed_column_desc("target_size"),
                    result_text_column_desc("source_compression"),
                    result_text_column_desc("target_compression"),
                    result_text_column_desc("status"),
                    result_text_column_desc("message"),
                ],
                "rowset": sorted(rowset),
            }
        else:  # DOWNLOAD
            if hasattr(self, "_results"):
                for meta in self._results:
                    dst_file_size = (
                        meta.dst_file_size
                        if converter_class != SnowflakeConverterSnowSQL
                        else str(meta.dst_file_size)
                    )

                    error_details: str = (
                        repr(meta.error_details)
                        if meta.error_details is not None
                        else ""
                    )

                    if self._raise_put_get_error and error_details:
                        Error.errorhandler_wrapper(
                            self._cursor.connection,
                            self._cursor,
                            OperationalError,
                            {
                                "msg": "While getting file(s) there was an error: "
                                f"'{error_details}', this might be caused by "
                                f"your access to the blob storage provider, "
                                f"or by Snowflake.",
                                "errno": ER_FAILED_TO_DOWNLOAD_FROM_STAGE,
                            },
                        )

                    rowset.append(
                        [
                            meta.dst_file_name,
                            dst_file_size,
                            meta.result_status,
                            error_details,
                        ]
                    )
            return {
                "rowtype": [
                    result_text_column_desc("file"),
                    result_fixed_column_desc("size"),
                    result_text_column_desc("status"),
                    result_text_column_desc("message"),
                ],
                "rowset": sorted(rowset),
            }

    def _expand_filenames(self, locations):
        canonical_locations = []
        for file_name in locations:
            if self._command_type == CMD_TYPE_UPLOAD:
                file_name = os.path.expanduser(file_name)
                if not os.path.isabs(file_name):
                    file_name = os.path.join(GET_CWD(), file_name)
                if (
                    IS_WINDOWS
                    and len(file_name) > 2
                    and file_name[0] == "/"
                    and file_name[2] == ":"
                ):
                    # Windows path: /C:/data/file1.txt where it starts with slash
                    # followed by a drive letter and colon.
                    file_name = file_name[1:]
                files = glob.glob(file_name)
                canonical_locations += files
            else:
                canonical_locations.append(file_name)

        return canonical_locations

    def _init_encryption_material(self) -> None:
        self._encryption_material = []
        if self._ret["data"].get("encryptionMaterial") is not None:
            root_node = self._ret["data"]["encryptionMaterial"]
            logger.debug(self._command_type)

            if self._command_type == CMD_TYPE_UPLOAD:
                self._encryption_material.append(
                    SnowflakeFileEncryptionMaterial(
                        query_stage_master_key=root_node["queryStageMasterKey"],
                        query_id=root_node["queryId"],
                        smk_id=root_node["smkId"],
                    )
                )
            else:
                for elem in root_node:
                    if elem is not None:
                        self._encryption_material.append(
                            SnowflakeFileEncryptionMaterial(
                                query_stage_master_key=elem["queryStageMasterKey"],
                                query_id=elem["queryId"],
                                smk_id=elem["smkId"],
                            )
                        )

    def _parse_command(self) -> None:

        if "data" not in self._ret:
            Error.errorhandler_wrapper(
                self._cursor.connection,
                self._cursor,
                DatabaseError,
                {
                    "msg": "Failed to parse server's response",
                    "errno": ER_INVALID_STAGE_LOCATION,
                },
            )

        response = self._ret["data"]

        self._command_type = response["command"]

        self._init_encryption_material()

        if not isinstance(response.get("src_locations"), list):
            Error.errorhandler_wrapper(
                self._cursor.connection,
                self._cursor,
                DatabaseError,
                {
                    "msg": "Failed to parse the location",
                    "errno": ER_INVALID_STAGE_LOCATION,
                },
            )

        self._src_locations = response["src_locations"]

        self._parallel = response.get("parallel", 1)
        self._overwrite = self._force_put_overwrite or response.get("overwrite", False)
        self._stage_location_type = response["stageInfo"]["locationType"].upper()

        if self._stage_location_type not in VALID_STORAGE:
            Error.errorhandler_wrapper(
                self._cursor.connection,
                self._cursor,
                OperationalError,
                {
                    "msg": f"Destination location type is not valid: {self._stage_location_type}",
                    "errno": ER_INVALID_STAGE_FS,
                },
            )

        self._stage_location = response["stageInfo"]["location"]
        self._stage_info = response["stageInfo"]
        self._credentials = StorageCredential(
            self._stage_info["creds"], self._cursor.connection, self._command
        )
        self._presigned_urls = self._ret["data"].get("presignedUrls")

        if self._command_type == CMD_TYPE_UPLOAD:
            if self._source_from_stream:
                self._src_files = self._src_locations
            else:
                self._src_files = list(self._expand_filenames(self._src_locations))
            self._auto_compress = (
                "autoCompress" not in response or response["autoCompress"]
            )
            self._source_compression = (
                response["sourceCompression"].lower()
                if "sourceCompression" in response
                else ""
            )
        else:
            self._src_files = list(self._src_locations)
            self._src_file_to_encryption_material = {}
            if len(response["src_locations"]) == len(self._encryption_material):
                for idx, src_file in enumerate(self._src_files):
                    logger.debug(src_file)
                    self._src_file_to_encryption_material[
                        src_file
                    ] = self._encryption_material[idx]
            elif len(self._encryption_material) != 0:
                # some encryption material exists. Zero means no encryption
                Error.errorhandler_wrapper(
                    self._cursor.connection,
                    self._cursor,
                    InternalError,
                    {
                        "msg": (
                            "The number of downloading files doesn't match "
                            f"the encryption materials: files={len(response['src_locations'])}, "
                            f"encmat={len(self._encryption_material)}"
                        ),
                        "errno": ER_INTERNAL_NOT_MATCH_ENCRYPT_MATERIAL,
                    },
                )

            self._local_location = os.path.expanduser(response["localLocation"])
            if not os.path.isdir(self._local_location):
                # NOTE: isdir follows the symlink
                Error.errorhandler_wrapper(
                    self._cursor.connection,
                    self._cursor,
                    ProgrammingError,
                    {
                        "msg": f"The local path is not a directory: {self._local_location}",
                        "errno": ER_LOCAL_PATH_NOT_DIRECTORY,
                    },
                )

    def _init_file_metadata(self) -> None:
        logger.debug(f"command type: {self._command_type}")

        if self._command_type == CMD_TYPE_UPLOAD:
            if not self._src_files:
                file_name = (
                    self._ret["data"]["src_locations"]
                    if "data" in self._ret and "src_locations" in self._ret["data"]
                    else "None"
                )
                Error.errorhandler_wrapper(
                    self._cursor.connection,
                    self._cursor,
                    ProgrammingError,
                    {
                        "msg": f"File doesn't exist: {file_name}",
                        "errno": ER_FILE_NOT_EXISTS,
                    },
                )
            if self._source_from_stream:
                self._file_metadata.append(
                    SnowflakeFileMeta(
                        name=os.path.basename(self._src_files[0]),
                        src_file_name=self._src_files[0],
                        intermediate_stream=self._source_from_stream,
                        src_file_size=self._source_from_stream.seek(0, os.SEEK_END),
                        stage_location_type=self._stage_location_type,
                        encryption_material=self._encryption_material[0]
                        if len(self._encryption_material) > 0
                        else None,
                    )
                )
                self._source_from_stream.seek(0)
            else:
                for file_name in self._src_files:
                    if not os.path.exists(file_name):
                        Error.errorhandler_wrapper(
                            self._cursor.connection,
                            self._cursor,
                            ProgrammingError,
                            {
                                "msg": f"File doesn't exist: {file_name}",
                                "errno": ER_FILE_NOT_EXISTS,
                            },
                        )
                    elif os.path.isdir(file_name):
                        Error.errorhandler_wrapper(
                            self._cursor.connection,
                            self._cursor,
                            ProgrammingError,
                            {
                                "msg": f"Not a file but a directory: {file_name}",
                                "errno": ER_FILE_NOT_EXISTS,
                            },
                        )
                    statinfo = os.stat(file_name)
                    self._file_metadata.append(
                        SnowflakeFileMeta(
                            name=os.path.basename(file_name),
                            src_file_name=file_name,
                            src_file_size=statinfo.st_size,
                            stage_location_type=self._stage_location_type,
                            encryption_material=self._encryption_material[0]
                            if len(self._encryption_material) > 0
                            else None,
                        )
                    )
        elif self._command_type == CMD_TYPE_DOWNLOAD:
            for idx, file_name in enumerate(self._src_files):
                if not file_name:
                    continue
                first_path_sep = file_name.find("/")
                dst_file_name = (
                    file_name[first_path_sep + 1 :]
                    if first_path_sep >= 0
                    else file_name
                )
                url = None
                if self._presigned_urls and idx < len(self._presigned_urls):
                    url = self._presigned_urls[idx]
                self._file_metadata.append(
                    SnowflakeFileMeta(
                        name=os.path.basename(file_name),
                        src_file_name=file_name,
                        dst_file_name=dst_file_name,
                        stage_location_type=self._stage_location_type,
                        local_location=self._local_location,
                        presigned_url=url,
                        encryption_material=self._src_file_to_encryption_material[
                            file_name
                        ]
                        if file_name in self._src_file_to_encryption_material
                        else None,
                    )
                )

    def _process_file_compression_type(self) -> None:
        user_specified_source_compression = None
        if self._source_compression == "auto_detect":
            auto_detect = True
        elif self._source_compression == "none":
            auto_detect = False
        else:
            user_specified_source_compression = lookup_by_mime_sub_type(
                self._source_compression
            )
            if (
                user_specified_source_compression is None
                or not user_specified_source_compression.is_supported
            ):
                Error.errorhandler_wrapper(
                    self._cursor.connection,
                    self._cursor,
                    ProgrammingError,
                    {
                        "msg": f"Feature is not supported: {user_specified_source_compression}",
                        "errno": ER_COMPRESSION_NOT_SUPPORTED,
                    },
                )

            auto_detect = False

        for m in self._file_metadata:
            file_name = m.src_file_name

            current_file_compression_type: CompressionType | None = None
            if auto_detect:
                mimetypes.init()
                _, encoding = mimetypes.guess_type(file_name)

                if encoding is None:
                    test = None
                    if not self._source_from_stream:
                        with open(file_name, "rb") as f:
                            test = f.read(4)
                    else:
                        test = self._source_from_stream.read(4)
                        self._source_from_stream.seek(0)
                    if file_name.endswith(".br"):
                        encoding = "br"
                    elif test and test[:3] == b"ORC":
                        encoding = "orc"
                    elif test and test == b"PAR1":
                        encoding = "parquet"
                    elif test and (int(binascii.hexlify(test), 16) == 0x28B52FFD):
                        encoding = "zstd"

                if encoding is not None:
                    logger.debug(
                        f"detected the encoding {encoding}: file={file_name}",
                    )
                    current_file_compression_type = lookup_by_mime_sub_type(encoding)
                else:
                    logger.debug(f"no file encoding was detected: file={file_name}")

                if (
                    current_file_compression_type is not None
                    and not current_file_compression_type.is_supported
                ):
                    Error.errorhandler_wrapper(
                        self._cursor.connection,
                        self._cursor,
                        ProgrammingError,
                        {
                            "msg": f"Feature is not supported: {current_file_compression_type}",
                            "errno": ER_COMPRESSION_NOT_SUPPORTED,
                        },
                    )
            else:
                current_file_compression_type = user_specified_source_compression

            if current_file_compression_type is not None:
                m.src_compression_type = current_file_compression_type
                if current_file_compression_type.is_supported:
                    m.dst_compression_type = current_file_compression_type
                    m.require_compress = False
                    m.dst_file_name = m.name
                else:
                    Error.errorhandler_wrapper(
                        self._cursor.connection,
                        self._cursor,
                        ProgrammingError,
                        {
                            "msg": f"Feature is not supported: {current_file_compression_type}",
                            "errno": ER_COMPRESSION_NOT_SUPPORTED,
                        },
                    )
            else:
                # src is not compressed but the destination want to be
                # compressed unless the users disable it
                m.require_compress = self._auto_compress
                m.src_compression_type = None
                if self._auto_compress:
                    m.dst_file_name = m.name + CompressionTypes["GZIP"].file_extension
                    m.dst_compression_type = CompressionTypes["GZIP"]
                else:
                    m.dst_file_name = m.name
                    m.dst_compression_type = None