# getdaft / delta_lake / delta_lake_scan.py

from __future__ import annotations

import logging
import os
from typing import TYPE_CHECKING
from urllib.parse import urlparse

from deltalake.table import DeltaTable

import daft
import daft.exceptions
from daft.daft import (
    FileFormatConfig,
    ParquetSourceConfig,
    Pushdowns,
    S3Config,
    ScanTask,
    StorageConfig,
)
from daft.io.aws_config import boto3_client_from_s3_config
from daft.io.object_store_options import io_config_to_storage_options
from daft.io.scan import PartitionField, ScanOperator
from daft.logical.schema import Schema

if TYPE_CHECKING:
    from collections.abc import Iterator
    from datetime import datetime

logger = logging.getLogger(__name__)


class DeltaLakeScanOperator(ScanOperator):
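    """Daft ScanOperator for reading Delta Lake tables.

    Wraps a `deltalake.DeltaTable` and translates its transaction-log metadata (data files,
    partition values, and column statistics) into Daft scan tasks over the underlying
    Parquet files.
    """
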
    def __init__(
        self, table_uri: str, storage_config: StorageConfig, version: int | str | datetime | None = None
    ) -> None:
        super().__init__()

        # Unfortunately, delta-rs does not do a good job of inferring S3 credentials, so Daft's current
        # behavior of passing `None` for credentials causes issues when instantiating the DeltaTable.
        #
        # If we don't detect any available credentials, we therefore attempt to resolve them from the
        # environment using Daft's own credentials chain.
        #
        # See: https://github.com/delta-io/delta-rs/issues/2117
        deltalake_sdk_io_config = storage_config.io_config
        scheme = urlparse(table_uri).scheme
        if scheme == "s3" or scheme == "s3a":
            # Try to get region from boto3
            if deltalake_sdk_io_config.s3.region_name is None:
                from botocore.exceptions import BotoCoreError

                try:
                    client = boto3_client_from_s3_config("s3", deltalake_sdk_io_config.s3)
                    response = client.get_bucket_location(Bucket=urlparse(table_uri).netloc)
                except BotoCoreError as e:
                    logger.warning(
                        "Failed to get the S3 bucket region using existing storage config, will attempt to get it from the environment instead. Error from boto3: %s",
                        e,
                    )
                else:
                    deltalake_sdk_io_config = deltalake_sdk_io_config.replace(
                        s3=deltalake_sdk_io_config.s3.replace(region_name=response["LocationConstraint"])
                    )

            # Try to get config from the environment
            if any([deltalake_sdk_io_config.s3.key_id is None, deltalake_sdk_io_config.s3.region_name is None]):
                try:
                    s3_config_from_env = S3Config.from_env()
                # Sometimes S3Config.from_env throws an error, for example on CI machines with weird metadata servers.
                except daft.exceptions.DaftCoreException:
                    pass
                else:
                    if (
                        deltalake_sdk_io_config.s3.key_id is None
                        and deltalake_sdk_io_config.s3.access_key is None
                        and deltalake_sdk_io_config.s3.session_token is None
                    ):
                        deltalake_sdk_io_config = deltalake_sdk_io_config.replace(
                            s3=deltalake_sdk_io_config.s3.replace(
                                key_id=s3_config_from_env.key_id,
                                access_key=s3_config_from_env.access_key,
                                session_token=s3_config_from_env.session_token,
                            )
                        )
                    if deltalake_sdk_io_config.s3.region_name is None:
                        deltalake_sdk_io_config = deltalake_sdk_io_config.replace(
                            s3=deltalake_sdk_io_config.s3.replace(
                                region_name=s3_config_from_env.region_name,
                            )
                        )
        elif scheme == "gcs" or scheme == "gs":
            # TO-DO: Handle any key-value replacements in `io_config` if there are missing elements
            pass
        elif scheme == "az" or scheme == "abfs" or scheme == "abfss":
            # TO-DO: Handle any key-value replacements in `io_config` if there are missing elements
            pass

        self._table = DeltaTable(
            table_uri, storage_options=io_config_to_storage_options(deltalake_sdk_io_config, table_uri)
        )

        if version is not None:
            self._table.load_as_version(version)

        self._storage_config = storage_config
        self._schema = Schema.from_pyarrow_schema(self._table.schema().to_pyarrow())
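        # Expose the table's Delta Lake partition columns as Daft partition fields so that
        # partition pruning can be applied during planning.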
        partition_columns = set(self._table.metadata().partition_columns)
        self._partition_keys = [
            PartitionField(field._field) for field in self._schema if field.name in partition_columns
        ]

    def schema(self) -> Schema:
        return self._schema

    def name(self) -> str:
        return "DeltaLakeScanOperator"

    def display_name(self) -> str:
        return f"DeltaLakeScanOperator({self._table.metadata().name})"

    def partitioning_keys(self) -> list[PartitionField]:
        return self._partition_keys

    def multiline_display(self) -> list[str]:
        return [
            self.display_name(),
            f"Schema = {self._schema}",
            f"Partitioning keys = {self.partitioning_keys()}",
            # TODO(Clark): Improve repr of storage config here.
            f"Storage config = {self._storage_config}",
        ]

    def to_scan_tasks(self, pushdowns: Pushdowns) -> Iterator[ScanTask]:
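        """Convert the table's Delta `add` actions into Daft `ScanTask`s.

        Each `add` action (one data file) becomes at most one scan task, carrying the file's
        partition values and min/max column statistics so that files can be pruned against
        the provided pushdowns.
        """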
        import pyarrow as pa

        # TODO(Clark): Push limit and filter expressions into deltalake action fetch, to prune the files returned.
        # Issue: https://github.com/Eventual-Inc/Daft/issues/1953
        add_actions: pa.RecordBatch = self._table.get_add_actions()

        if len(self.partitioning_keys()) > 0 and pushdowns.partition_filters is None:
            logger.warning(
                "%s has partitioning keys = %s, but no partition filter was specified. This will result in a full table scan.",
                self.display_name(),
                self.partitioning_keys(),
            )

        # TODO(Clark): Add support for deletion vectors.
        # Issue: https://github.com/Eventual-Inc/Daft/issues/1954
        if "deletionVector" in add_actions.schema.names:
            raise NotImplementedError(
                "Delta Lake deletion vectors are not yet supported; please let the Daft team know if you'd like to see this feature!\n"
                "Deletion records can be dropped from this table to allow it to be read with Daft: https://docs.delta.io/latest/delta-drop-feature.html"
            )

        # TODO(Clark): Add support for column mappings.
        # Issue: https://github.com/Eventual-Inc/Daft/issues/1955

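        # A limit can only be used to prune whole files up front when no filters are present,
        # since filters change how many rows each file actually contributes.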
        limit_files = pushdowns.limit is not None and pushdowns.filters is None and pushdowns.partition_filters is None
        rows_left = pushdowns.limit if pushdowns.limit is not None else 0
        scan_tasks = []
        is_partitioned = (
            "partition_values" in add_actions.schema.names
            and add_actions.schema.field("partition_values").type.num_fields > 0
        )
        for task_idx in range(add_actions.num_rows):
            if limit_files and rows_left <= 0:
                break

            # NOTE: Paths in the transaction log are relative to the table URI, so join them back onto it here.
            path = os.path.join(self._table.table_uri, add_actions["path"][task_idx].as_py())

            try:
                record_count = add_actions["num_records"][task_idx].as_py()
            except KeyError:
                record_count = None

            try:
                size_bytes = add_actions["size_bytes"][task_idx].as_py()
            except KeyError:
                size_bytes = None
            file_format_config = FileFormatConfig.from_parquet_config(ParquetSourceConfig())

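            # Materialize this file's partition values into a single-row Daft RecordBatch so that
            # the file can be pruned against pushed-down partition filters.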
            if is_partitioned:
                dtype = add_actions.schema.field("partition_values").type
                part_values = add_actions["partition_values"][task_idx]
                arrays = {}
                for field_idx in range(dtype.num_fields):
                    field_name = dtype.field(field_idx).name
                    try:
                        arrow_arr = pa.array([part_values[field_name]], type=dtype.field(field_idx).type)
                    except (pa.ArrowInvalid, pa.ArrowTypeError, pa.ArrowNotImplementedError):
                        # pyarrow < 13.0.0 doesn't accept pyarrow scalars in the array constructor.
                        arrow_arr = pa.array([part_values[field_name].as_py()], type=dtype.field(field_idx).type)
                    arrays[field_name] = daft.Series.from_arrow(arrow_arr, field_name)
                partition_values = daft.recordbatch.RecordBatch.from_pydict(arrays)._table
            else:
                partition_values = None

            # Populate scan task with column-wise stats.
            schema_names = add_actions.schema.names
            if "min" in schema_names and "max" in schema_names:
                dtype = add_actions.schema.field("min").type
                min_values = add_actions["min"][task_idx]
                max_values = add_actions["max"][task_idx]
                # TODO(Clark): Add support for tracking null counts in column stats.
                # null_counts = add_actions["null_count"][task_idx]
                arrays = {}
                for field_idx in range(dtype.num_fields):
                    field_name = dtype.field(field_idx).name
                    try:
                        arrow_arr = pa.array(
                            [min_values[field_name], max_values[field_name]], type=dtype.field(field_idx).type
                        )
                    except (pa.ArrowInvalid, pa.ArrowTypeError, pa.ArrowNotImplementedError):
                        # pyarrow < 13.0.0 doesn't accept pyarrow scalars in the array constructor.
                        arrow_arr = pa.array(
                            [min_values[field_name].as_py(), max_values[field_name].as_py()],
                            type=dtype.field(field_idx).type,
                        )
                    arrays[field_name] = daft.Series.from_arrow(arrow_arr, field_name)
                stats = daft.recordbatch.RecordBatch.from_pydict(arrays)
            else:
                stats = None
            st = ScanTask.catalog_scan_task(
                file=path,
                file_format=file_format_config,
                schema=self._schema._schema,
                num_rows=record_count,
                storage_config=self._storage_config,
                size_bytes=size_bytes,
                iceberg_delete_files=None,
                pushdowns=pushdowns,
                partition_values=partition_values,
                stats=stats._table if stats is not None else None,
            )
            if st is None:
                continue
            if record_count is not None:
                rows_left -= record_count
            scan_tasks.append(st)
        return iter(scan_tasks)

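    # Capability flags consulted by Daft's planner: this scan does not absorb filter or limit
    # pushdowns (Daft applies those after the scan), but it can absorb column selection.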
    def can_absorb_filter(self) -> bool:
        return False

    def can_absorb_limit(self) -> bool:
        return False

    def can_absorb_select(self) -> bool:
        return True
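

# ---------------------------------------------------------------------------
# Usage sketch (illustration only, not part of the operator). In normal use this
# scan operator is constructed for you by Daft's Delta Lake reader; the snippet
# below assumes `daft.read_deltalake` is that public entrypoint and uses a
# hypothetical table URI.
# ---------------------------------------------------------------------------
if __name__ == "__main__":
    # Hypothetical table location; replace with a real Delta Lake table URI.
    example_table_uri = "s3://my-bucket/path/to/delta-table"

    # Lazily read the table; scan tasks (one per Delta `add` action) are produced by
    # DeltaLakeScanOperator.to_scan_tasks during query planning.
    df = daft.read_deltalake(example_table_uri)

    # Column selection can be absorbed by the scan (see can_absorb_select above), so this
    # projection can be pushed down into the Parquet reads.
    df.select(*df.column_names[:1]).show()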