Why Gemfury? Push, build, and install RubyGems, npm packages, Python packages, Maven artifacts, PHP packages, Go Modules, Debian packages, RPM packages, and NuGet packages.

Repository URL to install this package:

Details    
tdw-catalog / tdw_catalog / utils.py
Size: Mime:
import collections
import gzip
import math
import sys
import typing
import zlib
from dataclasses import dataclass
from datetime import datetime, timedelta, timezone
from enum import Enum, IntEnum
from io import BytesIO
from sys import getsizeof
from typing import BinaryIO, Dict, Optional, TYPE_CHECKING, Callable, List, Union

import dateutil
import dateutil.parser
from aiohttp import ClientResponse, ClientSession

if sys.version_info >= (3, 11):
    from enum import StrEnum
else:
    from backports.strenum import StrEnum

from tdw_catalog.errors import CatalogException

if TYPE_CHECKING:
    from urllib.request import Request


class _ExportFormat(IntEnum):
    CSV = 0
    PARQUET = 1
    CSV_GZIP = 2


class FilterSortOrder(Enum):
    """
    Sort direction for a :class:`.FilterSort`. :class:`.SortableFilter`
    serializes it as the string ``"ASC"`` or ``"DESC"``.
    """
    ASC = 1  # ascending
    DESC = 2  # descending


class ConnectionPortalType(StrEnum):
    GS = 'Gs'
    S3 = 'S3'
    UNITY = 'Unity'
    FTP = 'Ftp'
    SFTP = 'Sftp'
    EXTERNAL = 'External'
    NULL = 'Null'
    IMPORT_LITE = 'ImportLite'
    HTTP = 'Http'
    CATALOG = 'Namara'


class ConnectionMode(IntEnum):
    """
    Connection mode used to narrow connection listings; serialized as its
    integer value by :class:`.ListConnectionsFilter`.
    """
    SOURCE = 0
    DESTINATION = 1
    # Not contiguous with the values above.
    ALL = 10


# TODO remove and replace with constants from namara-go generated code
class MetadataFieldType(IntEnum):
    """
    The different possible data types for values stored in MetadataFields and default values stored in MetadataTemplateFields
    """
    # Plain int literals: a trailing comma would turn each value into a
    # one-element tuple, which only happens to work because IntEnum unpacks it.
    FT_STRING = 0
    FT_INTEGER = 1
    FT_DECIMAL = 2
    FT_DATE = 3
    FT_DATETIME = 4
    FT_DATASET = 5
    FT_URL = 6
    FT_USER = 7
    FT_ATTACHMENT = 8
    FT_LIST = 9
    FT_CURRENCY = 10
    FT_TEAM = 11
    FT_ALIAS = 12


class ColumnType(StrEnum):
    """
    The different possible data types for :class:`.Column`\ s within a :class:`.DataDictionary`
    """
    BOOLEAN = 'boolean'
    DATE = 'date'
    DATETIME = 'datetime'
    INTEGER = 'integer'
    DECIMAL = 'decimal'
    PERCENT = 'percent'
    CURRENCY = 'currency'
    STRING = 'string'
    TEXT = 'text'
    GEOMETRY = 'geometry'
    GEOJSON = 'geojson'


class ImportState(StrEnum):
    """
    The different possible states an imported dataset might occupy. Virtualized datasets
    will always show state ``IMPORTED``.
    """
    IMPORTED = 'imported'
    IMPORTING = 'importing'
    QUEUED = 'queued'
    FAILED = 'failed'


class IngestPipeline(IntEnum):
    """
    Identifies which ingest pipeline a reference is loaded through.
    """
    JABBA = 0
    EXTERNAL_WAREHOUSE = 1
    CREATE_AS_QUERY = 2
    CREATE_AS_QUERY_BATCH = 3
    SENZING_ER = 4


@dataclass
class CurrencyFieldValue:
    """
    :class:`.CurrencyFieldValue` pairs a monetary amount with its currency.

    Attributes
    ----------
    value : float
        The numeric amount
    currency : str
        The currency in which the amount is expressed
    """
    value: float
    currency: str


@dataclass
class FilterSort:
    """
    :class:`.FilterSort` names the field to sort results by and the direction.

    Attributes
    ----------
    field : str
        Name of the field to sort on
    order : FilterSortOrder, optional
        Sort direction; defaults to `FilterSortOrder.ASC`
    """
    field: str
    order: FilterSortOrder = FilterSortOrder.ASC


@dataclass
class LegacyFilter:
    """
    :class:`.LegacyFilter` describes the ways in which results should be filtered and/or
    paginated

    Attributes
    ----------
    limit : int, optional
        Limits the number of results. Useful for pagination. (`None` by default)
    offset : int, optional
        Offsets the result list by the given number of results. Useful for pagination. (`None` by default)
    """
    # Both default to None (= "not set"), so the annotations must allow None.
    limit: Optional[int] = None
    offset: Optional[int] = None

    def serialize(self) -> dict:
        """
        Return the filter as a request dict, wrapping each set value as
        ``{"value": n}``. Fields left as ``None`` are omitted; ``0`` is kept.
        """
        new_filter = {}
        if self.limit is not None:
            new_filter["limit"] = {"value": self.limit}

        if self.offset is not None:
            new_filter["offset"] = {"value": self.offset}
        return new_filter


@dataclass
class Filter(LegacyFilter):
    """
    :class:`.Filter` describes the ways in which results should be filtered and/or
    paginated. It is serialized like :class:`.LegacyFilter`, except that the offset
    is sent as ``{"offset": {"offset": n}}`` instead of ``{"offset": {"value": n}}``.

    Attributes
    ----------
    limit : int, optional
        Limits the number of results. Useful for pagination. (`None` by default)
    offset : int, optional
        Offsets the result list by the given number of results. Useful for pagination. (`None` by default)
    """

    def serialize(self):
        """
        Return the filter as a request dict. Unset (``None``) fields are
        omitted, matching :meth:`LegacyFilter.serialize`.
        """
        new_filter = super().serialize()
        # Only re-wrap an offset that was actually set; previously an unset
        # offset was still emitted as {"offset": {"offset": None}}.
        if self.offset is not None:
            new_filter["offset"] = {"offset": self.offset}
        return new_filter


@dataclass
class SortableFilter(LegacyFilter):
    """
    :class:`.SortableFilter` extends :class:`.LegacyFilter` with an optional sort.

    Attributes
    ----------
    limit : int, optional
        Maximum number of results to return, for pagination. (`None` by default)
    offset : int, optional
        Number of results to skip, for pagination. (`None` by default)
    sort : FilterSort, optional
        Field and direction to sort results by (`None` by default).
    """
    sort: Optional[FilterSort] = None

    def serialize(self):
        payload = super().serialize()
        if self.sort is None:
            return payload
        # Anything other than ASC is sent as descending.
        direction = "ASC" if self.sort.order == FilterSortOrder.ASC else "DESC"
        payload["sort"] = {"value": self.sort.field, "order": direction}
        return payload


@dataclass
class ListOrganizationsFilter(LegacyFilter):
    """
    :class:`.ListOrganizationsFilter` restricts :class:`.Organization` results to the given ids

    Attributes
    ----------
    organization_ids : List[str], optional
        Only organizations whose id appears in this list are returned
    """
    organization_ids: Optional[List[str]] = None

    def serialize(self):
        payload = super().serialize()
        ids = self.organization_ids
        if ids is not None:
            payload["organization_ids"] = ids
        return payload


@dataclass
class ListSourcesFilter(LegacyFilter):
    """
    :class:`.ListSourcesFilter` filters results according to :class:`.Source` fields

    Attributes
    ----------
    labels : Optional[Union[str, List[str]]]
        Filters results by label.  This will match label substrings. A single
        string is treated as one label; an iterable of strings as several.
    """
    labels: Optional[Union[str, List[str]]] = None

    def serialize(self):
        """
        Return the filter as a request dict; ``labels`` is always sent as a list.
        """
        new_filter = super().serialize()
        if self.labels is not None:
            if isinstance(self.labels, str):
                # list("abc") would split one label into ['a', 'b', 'c'].
                new_filter["labels"] = [self.labels]
            else:
                new_filter["labels"] = list(self.labels)

        return new_filter


@dataclass
class ListConnectionsFilter(LegacyFilter):
    """
    :class:`.ListConnectionsFilter` narrows connection listings by the fields below.
    Fields left as `None` are not sent.

    Attributes
    ----------
    organization_id : Optional[str]
        Only connections with this `organization_id`
    source_ids : Optional[List[str]]
        Only connections belonging to one of these `source_id` values
    portals : Optional[List[ConnectionPortalType]]
        Only connections using one of these :class:`.ConnectionPortalType` values
    mode : Optional[ConnectionMode]
        Connection mode (SOURCE, DESTINATION, or ALL); sent as its integer value
    """
    organization_id: Optional[str] = None
    source_ids: Optional[List[str]] = None
    portals: Optional[List[ConnectionPortalType]] = None
    mode: Optional[ConnectionMode] = None

    def serialize(self):
        payload = super().serialize()
        # These three are copied through unchanged; order matches key insertion order.
        passthrough = (
            ("organization_id", self.organization_id),
            ("source_ids", self.source_ids),
            ("portals", self.portals),
        )
        payload.update(
            {key: value for key, value in passthrough if value is not None})
        if self.mode is not None:
            payload["mode"] = int(self.mode)
        return payload


@dataclass
class ListGlossaryTermsFilter(Filter):
    """
    :class:`.ListGlossaryTermsFilter` restricts results to particular :class:`.GlossaryTerm` ids

    Attributes
    ----------
    glossary_term_ids : Optional[List[str]]
        Only glossary terms whose id appears in this list are returned
    """

    glossary_term_ids: Optional[List[str]] = None

    def serialize(self):
        payload = super().serialize()
        term_ids = self.glossary_term_ids
        if term_ids is not None:
            payload["glossary_term_ids"] = term_ids
        return payload


@dataclass
class QueryFilter(SortableFilter):
    """
    :class:`.QueryFilter` is a :class:`.SortableFilter` that also carries a NiQL query

    Attributes
    ----------
    query : str, optional
        NiQL query the results must satisfy
    """
    query: Optional[str] = None

    def serialize(self):
        payload = super().serialize()
        if self.query is None:
            return payload
        payload["query"] = {"value": self.query}
        return payload


def _parse_timestamp(timestamp: Union[str, datetime, dict]) -> datetime:
    if isinstance(timestamp, datetime):
        return timestamp
    # handle NullableTimestamp
    elif isinstance(timestamp, dict) and ('timestamp' in timestamp
                                          or 'is_null' in timestamp):
        if 'is_null' in timestamp and timestamp['is_null'] == True:
            return None
        return dateutil.parser.parse(timestamp["timestamp"])
    else:
        return dateutil.parser.parse(timestamp)


def _convert_datetime_to_nullable_timestamp(date: datetime):
    timestamp = date.timestamp()
    seconds = int(timestamp)
    # getting the trailing numbers and converting to nanoseconds
    nanos = int(((timestamp % 1) * 1000) * 10000)

    return {
        "is_null": False,
        "timestamp": {
            "seconds": seconds,
            "nanos": nanos,
        }
    }


# with help from https://stackoverflow.com/questions/56832881/check-if-a-field-is-typing-optional
def _type_is_optional(field):
    return typing.get_origin(field) is Union and \
           type(None) in typing.get_args(field)


async def _write_gunzipped(response: "ClientResponse", f_out: BinaryIO) -> None:
    """Stream the gzip-compressed body of *response* into *f_out*, decompressed."""
    # 16 + MAX_WBITS makes zlib expect (and strip) a gzip header and trailer.
    decompressor = zlib.decompressobj(16 + zlib.MAX_WBITS)
    async for chunk in response.content.iter_chunked(4096):
        f_out.write(decompressor.decompress(chunk))
    # flush() exactly once, after the last chunk: decompress() must not be
    # called after flush(), and whatever flush() returns is real output.
    f_out.write(decompressor.flush())


async def _download_export(
        download_url: str,
        format: Optional[_ExportFormat] = None,
        f_out: Optional[BinaryIO] = None) -> Optional[Union[str, BinaryIO]]:
    """
    Download an export from *download_url*, following redirects.

    Parameters
    ----------
    download_url : str
        URL of the export.
    format : _ExportFormat, optional
        ``CSV_GZIP`` bodies are gunzipped. ``PARQUET`` bodies are returned as
        raw bytes when held in memory. Anything else is treated as UTF-8 text.
    f_out : BinaryIO, optional
        If given, the body is streamed into it in 4096-byte chunks (gunzipped
        for ``CSV_GZIP``), and ``f_out`` is flushed even if streaming fails.

    Returns
    -------
    None when *f_out* is given. Otherwise a ``BytesIO`` for ``PARQUET``, or
    the body decoded as UTF-8 text (after gunzipping for ``CSV_GZIP``).
    """
    async with ClientSession() as session:
        # One request serves both paths; the in-memory path used to issue a
        # second, nested GET and leave this response unread.
        async with session.get(download_url, allow_redirects=True) as response:
            if f_out is not None:
                try:
                    if format == _ExportFormat.CSV_GZIP:
                        await _write_gunzipped(response, f_out)
                    else:
                        async for chunk in response.content.iter_chunked(4096):
                            f_out.write(chunk)
                finally:
                    f_out.flush()
                return None

            if format == _ExportFormat.PARQUET:
                return BytesIO(await response.read())
            if format == _ExportFormat.CSV_GZIP:
                return gzip.decompress(await response.read()).decode("utf-8")
            return await response.text("utf-8")