Why Gemfury? Push, build, and install  RubyGems npm packages Python packages Maven artifacts PHP packages Go Modules Debian packages RPM packages NuGet packages

Repository URL to install this package:

Details    
getdaft / unity_catalog / unity_catalog.py
Size: Mime:
from __future__ import annotations

import dataclasses
import warnings
from typing import TYPE_CHECKING, Callable, Literal
from urllib.parse import urlparse

import unitycatalog

from daft.io import AzureConfig, IOConfig, S3Config

if TYPE_CHECKING:
    from unitycatalog.types import TableInfo


@dataclasses.dataclass(frozen=True)
class UnityCatalogTable:
    table_info: TableInfo
    table_uri: str
    io_config: IOConfig | None


class UnityCatalog:
    """Client to access the Unity Catalog.

    Unity Catalog is an open-sourced data catalog that can be self-hosted, or hosted by Databricks.

    Example of reading a dataframe from a table in Unity Catalog hosted by Databricks:

    >>> cat = UnityCatalog("https://<databricks_workspace_id>.cloud.databricks.com", token="my-token")
    >>> table = cat.load_table("my_catalog.my_schema.my_table")
    >>> df = daft.read_deltalake(table)
    """

    def __init__(self, endpoint: str, token: str | None = None):
        self._client = unitycatalog.Unitycatalog(
            base_url=endpoint.rstrip("/") + "/api/2.1/unity-catalog/",
            default_headers={"Authorization": f"Bearer {token}"},
        )

    def _paginate_to_completion(
        self,
        client_func_call: Callable[[unitycatalog.Unitycatalog, str | None], tuple[list[str] | None, str | None]],
    ) -> list[str]:
        results = []

        # Make first request
        new_results, next_page_token = client_func_call(self._client, None)
        if new_results is not None:
            results.extend(new_results)

        # Exhaust pages
        while next_page_token is not None and next_page_token != "":
            new_results, next_page_token = client_func_call(self._client, next_page_token)
            if new_results is not None:
                results.extend(new_results)

        return results

    def list_catalogs(self) -> list[str]:
        def _paginated_list_catalogs(client: unitycatalog.Unitycatalog, page_token: str | None):
            response = client.catalogs.list(page_token=page_token)
            next_page_token = response.next_page_token
            if response.catalogs is None:
                return None, next_page_token
            return [c.name for c in response.catalogs], next_page_token

        return self._paginate_to_completion(_paginated_list_catalogs)

    def list_schemas(self, catalog_name: str) -> list[str]:
        def _paginated_list_schemas(client: unitycatalog.Unitycatalog, page_token: str | None):
            response = client.schemas.list(catalog_name=catalog_name, page_token=page_token)
            next_page_token = response.next_page_token
            if response.schemas is None:
                return None, next_page_token
            return [s.full_name for s in response.schemas], next_page_token

        return self._paginate_to_completion(_paginated_list_schemas)

    def list_tables(self, schema_name: str):
        if schema_name.count(".") != 1:
            raise ValueError(
                f"Expected fully-qualified schema name with format `catalog_name`.`schema_name`, but received: {schema_name}"
            )

        catalog_name, schema_name = schema_name.split(".")

        def _paginated_list_tables(client: unitycatalog.Unitycatalog, page_token: str | None):
            response = client.tables.list(catalog_name=catalog_name, schema_name=schema_name, page_token=page_token)
            next_page_token = response.next_page_token
            if response.tables is None:
                return None, next_page_token
            return [f"{t.catalog_name}.{t.schema_name}.{t.name}" for t in response.tables], next_page_token

        return self._paginate_to_completion(_paginated_list_tables)

    def load_table(
        self,
        table_name: str,
        new_table_storage_path: str | None = None,
        operation: Literal["READ" | "READ_WRITE"] = "READ_WRITE",
    ) -> UnityCatalogTable:
        """Loads an existing Unity Catalog table. If the table is not found, and information is provided in the method to create a new table, a new table will be attempted to be registered.

        Args:
            table_name (str): Name of the table in Unity Catalog in the form of dot-separated, 3-level namespace
            new_table_storage_path (str, optional): Cloud storage path URI to register a new external table using this path. Unity Catalog will validate if the path is valid and authorized for the principal, else will raise an exception.
            operation ("READ" or "READ_WRITE", optional): The intended use of the table, which impacts authorization. Defaults to "READ_WRITE".

        Returns:
            UnityCatalogTable
        """
        # Load the table ID
        try:
            table_info = self._client.tables.retrieve(table_name)
            if new_table_storage_path:
                warnings.warn(
                    f"Table {table_name} is an existing storage table with a valid storage path. The 'new_table_storage_path' argument provided will be ignored."
                )
        except unitycatalog.NotFoundError:
            if not new_table_storage_path:
                raise ValueError(
                    f"Table {table_name} is not an existing table. If a new table needs to be created, provide 'new_table_storage_path' value."
                )
            try:
                three_part_namesplit = table_name.split(".")
                if len(three_part_namesplit) != 3 or not all(three_part_namesplit):
                    raise ValueError(
                        f"Expected table name to be in the format of 'catalog.schema.table', received: {table_name}"
                    )

                params = {
                    "catalog_name": three_part_namesplit[0],
                    "schema_name": three_part_namesplit[1],
                    "name": three_part_namesplit[2],
                    "columns": None,
                    "data_source_format": "DELTA",
                    "table_type": "EXTERNAL",
                    "storage_location": new_table_storage_path,
                    "comment": None,
                }

                table_info = self._client.tables.create(**params)
            except Exception as e:
                raise Exception(f"An error occurred while registering the table in Unity Catalog: {e}")

        table_id = table_info.table_id
        storage_location = table_info.storage_location
        # Grab credentials from Unity catalog and place it into the Table
        temp_table_credentials = self._client.temporary_table_credentials.create(operation=operation, table_id=table_id)

        scheme = urlparse(storage_location).scheme
        if scheme == "s3" or scheme == "s3a":
            aws_temp_credentials = temp_table_credentials.aws_temp_credentials
            io_config = (
                IOConfig(
                    s3=S3Config(
                        key_id=aws_temp_credentials.access_key_id,
                        access_key=aws_temp_credentials.secret_access_key,
                        session_token=aws_temp_credentials.session_token,
                    )
                )
                if aws_temp_credentials is not None
                else None
            )
        elif scheme == "gcs" or scheme == "gs":
            # TO-DO: gather GCS credential vending assets from Unity and construct 'io_config``
            warnings.warn("GCS credential vending from Unity Catalog is not yet supported.")
            io_config = None
        elif scheme == "az" or scheme == "abfs" or scheme == "abfss":
            io_config = IOConfig(
                azure=AzureConfig(sas_token=temp_table_credentials.azure_user_delegation_sas.get("sas_token"))
            )
        else:
            warnings.warn(f"Credentials for scheme {scheme} are not yet supported.")
            io_config = None

        return UnityCatalogTable(
            table_info=table_info,
            table_uri=storage_location,
            io_config=io_config,
        )