Why Gemfury? Push, build, and install  RubyGems npm packages Python packages Maven artifacts PHP packages Go Modules Debian packages RPM packages NuGet packages

Repository URL to install this package:

Details    
tdw-catalog / tdw_catalog / _client.py
Size: Mime:
import os
from typing import Any, Callable, Optional
from urllib import parse

import requests
from dotenv import load_dotenv
from google.protobuf import symbol_database as _symbol_database
from google.protobuf.json_format import MessageToDict
from inflection import underscore
from twirp.context import Context

# must import these so that they are available to the symbol database which
# registers all the rpc objects, which is required before method generation begins
import tdw_catalog.rpc.catalog.catalog_pb2 as catalog_pb2
import tdw_catalog.rpc.category.category_pb2 as category_pb2
import tdw_catalog.rpc.dataspec.dataspec_pb2 as dataspec_pb2
import tdw_catalog.rpc.exports.exports_pb2 as exports_pb2
import tdw_catalog.rpc.grants.grants_pb2 as grants_pb2
import tdw_catalog.rpc.organizations.organizations_pb2 as organizations_pb2
import tdw_catalog.rpc.query.query_pb2 as query_pb2
import tdw_catalog.rpc.references.references_pb2 as references_pb2
import tdw_catalog.rpc.sources.sources_pb2 as sources_pb2
from tdw_catalog.rpc.uploads import uploads_pb2
from tdw_catalog.rpc.uploads.uploads_twirp import UploadsServiceClient
import tdw_catalog.rpc.warehouses.warehouses_pb2 as warehouses_pb2
from tdw_catalog.errors import (CatalogException,
                                CatalogUnauthenticatedException,
                                _convert_error)
from tdw_catalog.rpc.catalog.catalog_twirp import CatalogServiceClient
from tdw_catalog.rpc.category.category_twirp import CategoryServiceClient
from tdw_catalog.rpc.dataspec.dataspec_twirp import DataspecV2ServiceClient
from tdw_catalog.rpc.exports.exports_twirp import ExportsServiceClient
from tdw_catalog.rpc.grants.grants_twirp import GrantsServiceClient
from tdw_catalog.rpc.organizations.organizations_twirp import OrganizationsServiceClient
from tdw_catalog.rpc.query.query_twirp import QueryServiceClient
from tdw_catalog.rpc.references.references_twirp import ReferencesServiceClient
from tdw_catalog.rpc.sources.sources_twirp import SourcesServiceClient
from tdw_catalog.rpc.warehouses.warehouses_twirp import WarehousesServiceClient
from tdw_catalog.utils import (LegacyFilter)

# Default protobuf symbol database; used to look up request message classes by
# their full proto name when serializing RPC input.
_sym_db = _symbol_database.Default()

# Maps each protobuf service descriptor to the twirp client class that serves
# it. Iterated on client creation, both to instantiate the RPC clients and to
# generate one method per RPC on _Client.
SERVICE_DESCRIPTORS = {
    organizations_pb2._ORGANIZATIONSSERVICE: OrganizationsServiceClient,
    catalog_pb2._CATALOGSERVICE: CatalogServiceClient,
    category_pb2._CATEGORYSERVICE: CategoryServiceClient,
    dataspec_pb2._DATASPECV2SERVICE: DataspecV2ServiceClient,
    exports_pb2._EXPORTSSERVICE: ExportsServiceClient,
    grants_pb2._GRANTSSERVICE: GrantsServiceClient,
    query_pb2._QUERYSERVICE: QueryServiceClient,
    references_pb2._REFERENCESSERVICE: ReferencesServiceClient,
    sources_pb2._SOURCESSERVICE: SourcesServiceClient,
    warehouses_pb2._WAREHOUSESSERVICE: WarehousesServiceClient,
    uploads_pb2._UPLOADSSERVICE: UploadsServiceClient
}

# Gets used on client creation to store initialized RPC Service clients (ie.
# instances of the values of SERVICE_DESCRIPTORS), keyed by the client class
# name (ie. 'CatalogServiceClient').
# NOTE: this is module-level state, so it is shared by every _Client instance.
RPC_CLIENTS_CACHE = {}

# Will change the rpc method name to the value given here when generating the
# methods on Client. Keys are the full name of the method descriptor, values
# are the replacement rpc name, which is then converted the same way as every
# other method name (ie. 'ListDatasets' for the Category service is generated
# as '_list_topic_datasets' instead of '_list_datasets').
# This is to avoid known conflicting names from other services (in this example
# 'ListDatasets' exists for the Catalog service as well).
RPC_METHOD_CONFLICTS_RESOLVER = {
    'category.CategoryService.ListDatasets': 'ListTopicDatasets',
    'sources.SourcesService.ListWarehouses': 'SourcesListWarehouses',
}

# Type alias used to annotate twirp client methods and the wrappers built
# around them.
TWIRP_FUNCTION = Callable[[Any, Any], Any]

# Runs at import time, before any _Client reads its CATALOG_* settings through
# os.getenv, so those settings may be supplied by a .env file.
load_dotenv()

# Path of the login endpoint, joined onto the auth base URL.
LOGIN_PATH = 'users/jwt.json'


class _Client:
    """
    Low-level Catalog API client.

    On construction it resolves the API key and the base URLs, fetches the
    profile that belongs to the API key, instantiates one twirp client per
    entry of SERVICE_DESCRIPTORS and generates one private method per RPC
    (ie. ``_list_datasets``) that accepts keyword arguments and returns plain
    Python objects.
    """
    _org_context: Optional[str]

    def __init__(self, *args, **kwargs):
        """
        Build a client. Positional arguments are accepted but ignored.

        Keyword Args:
            api_key: API key; falls back to the ``CATALOG_API_KEY``
                environment variable when omitted or falsy.
            auth_url: base URL of the auth service; falls back to
                ``CATALOG_AUTH_URL``, then to the hosted default.
            api_url: base URL of the API; falls back to ``CATALOG_API_URL``,
                then to the hosted default.
            org_context: optional organization id sent along with requests.

        Raises:
            AttributeError: when no API key can be resolved.
        """
        # Resolve the key the same way as the URLs below, so that an explicit
        # ``api_key=None`` still picks up the environment variable.
        self.api_key = kwargs.get('api_key') or os.getenv('CATALOG_API_KEY')

        if not self.api_key:
            raise AttributeError('Missing api_key')

        # Trailing slashes are stripped so the URLs can be used as join bases.
        self.base_auth_url = (kwargs.get('auth_url') or os.getenv(
            'CATALOG_AUTH_URL', 'https://account.ee.namara.io')).rstrip('/')
        self.base_api_url = (kwargs.get('api_url') or os.getenv(
            'CATALOG_API_URL', 'https://api.ee.namara.io')).rstrip('/')
        self._org_context = kwargs.get('org_context')

        self.__set_client_profile()

        self.__generate_clients()
        self.__generate_client_methods()

    def __clone_with_org_context(self, org_id: str) -> '_Client':
        """
        Return a new client with the same key and URLs, scoped to ``org_id``.
        """
        return _Client(api_key=self.api_key,
                       auth_url=self.base_auth_url,
                       api_url=self.base_api_url,
                       org_context=org_id)

    # TODO revisit before release. should involve an Account/Profile class?
    def __set_client_profile(self):
        """
        Fetch the profile for the configured API key and store the decoded
        JSON body on ``self.profile``.

        A 401 response is reported as CatalogUnauthenticatedException and any
        other non-200 response as CatalogException. Every exception raised
        here, including those two, is passed through ``_convert_error``
        before being raised to the caller.
        """
        try:
            user_profile_url = parse.urljoin(self.base_auth_url,
                                             '/users/profile.json')
            res = self._make_catalog_request(user_profile_url)
            if res.status_code == 401:
                raise CatalogUnauthenticatedException(
                    message=
                    f'Unable to fetch client profile for given API key: {res.reason}'
                )
            if res.status_code != 200:
                raise CatalogException(
                    message=
                    f'Unable to fetch client profile for given API key: {res.reason}'
                )
            self.profile = res.json()
        except Exception as e:
            raise _convert_error(e)

    def _make_catalog_request(self, url, timeout=60):
        """
        Issue an authenticated GET request against ``url``.

        Args:
            url: absolute URL to fetch.
            timeout: seconds handed to ``requests`` as the request timeout.
                Defaults to 60, the same limit the generated RPC methods use
                for non-Query calls. Without a timeout the request could
                block indefinitely on an unresponsive server.

        Returns:
            The ``requests`` response object, unmodified.
        """
        headers = {
            'X-API-Key': self.api_key,
            'X-NAMARA-ACTIVITY-CHANNEL': 'catalog-python',
            'X-NAMARA-ACTIVITY-FLAVOR': 'USER',
            'credentials': 'include'
        }
        if self._org_context is not None:
            headers['X-NAMARA-ORGANIZATION-CONTEXT'] = self._org_context
        return requests.get(url, headers=headers, timeout=timeout)

    def __construct_auth_url(self):
        """
        Return the login URL: LOGIN_PATH joined onto the auth base URL.
        """
        return parse.urljoin(self.base_auth_url, LOGIN_PATH)

    def __generate_clients(self):
        """
        Creates new RPC clients pointed at this client's API base URL and caches
        them, keyed by client class name, for later use
        """
        # NOTE(review): RPC_CLIENTS_CACHE is module-level, so clients created
        # with a different api_url replace the ones cached by earlier instances.
        for client in SERVICE_DESCRIPTORS.values():
            RPC_CLIENTS_CACHE[client.__name__] = client(self.base_api_url)

    def __generate_client_methods(self):
        """
        Dynamically generates RPC methods for each service on to this object
        """
        for service_descriptor, service_client in SERVICE_DESCRIPTORS.items():
            for method_descriptor in service_descriptor.methods:
                func_name = method_descriptor.name
                twirp_func = getattr(service_client, func_name)

                # rename rpcs whose name clashes with another service, then
                # convert to a private, underscored method name
                func_name = RPC_METHOD_CONFLICTS_RESOLVER.get(
                    method_descriptor.full_name, func_name)
                func_name = '_' + underscore(func_name)

                proto_request_name = method_descriptor.input_type.full_name
                # create decorators around the existing twirp functions to handle input
                # and output in a pythonic way (ie. using python objects rather than Twirp objects)
                new_func = _deserialize_output(
                    _serialize_input(twirp_func, proto_request_name,
                                     self._org_context))

                # NOTE(review): the method is installed on the class, not on
                # this instance, so it is regenerated for every instance each
                # time a client is constructed, and the org context passed
                # above is that of the most recently constructed client.
                setattr(self.__class__, func_name, new_func)


def _serialize_input(twirp_func: TWIRP_FUNCTION,
                     proto_request_name: str,
                     org_context: Optional[str] = None) -> TWIRP_FUNCTION:
    """
    Wrap a twirp client method so it can be called with keyword arguments.

    The returned function is meant to be installed as a method of the
    catalog client: it builds the protobuf request named
    ``proto_request_name`` from its keyword arguments, attaches the
    authentication and activity headers, and invokes ``twirp_func`` on the
    cached RPC client of the owning service.

    Args:
        twirp_func: unbound method of a twirp service client class.
        proto_request_name: full proto name of the request message type.
        org_context: organization id to send when the calling instance does
            not carry an ``_org_context`` attribute of its own.

    Returns:
        A function ``wrapper(self, *args, **kwargs)`` returning whatever
        ``twirp_func`` returns. Positional arguments are ignored.
    """
    # Both values depend only on the service that owns the method, so they
    # are computed once here instead of on every call.
    client_name = _get_rpc_client_name(twirp_func)
    # timeout all RPC calls after 60 seconds, or 60 minutes for Query
    timeout = (60 * 60) if client_name == "QueryServiceClient" else 60

    def wrapper(self, *args, **kwargs):
        client = RPC_CLIENTS_CACHE[client_name]

        kwargs = _serialize_filter(kwargs)
        serialized_input_obj = _sym_db.GetSymbol(proto_request_name)(**kwargs)

        # The generated methods live on the class and are regenerated each
        # time a client is constructed, so the closure value belongs to the
        # most recently created client. Read the context from the instance
        # making the call, otherwise clients with different organization
        # contexts would leak their context into each other's requests.
        effective_org_context = getattr(self, '_org_context', org_context)

        # twirp Context objects allow header overrides
        ctx = Context()
        ctx.set_header('x-api-key', self.api_key)
        ctx.set_header('X-NAMARA-ACTIVITY-CHANNEL', 'SDK')
        ctx.set_header('X-NAMARA-ACTIVITY-FLAVOR', 'USER')
        if effective_org_context is not None:
            ctx.set_header('X-NAMARA-ORGANIZATION-CONTEXT',
                           effective_org_context)
        return twirp_func(client,
                          ctx=ctx,
                          request=serialized_input_obj,
                          timeout=timeout)

    return wrapper


def _deserialize_output(twirp_func: TWIRP_FUNCTION) -> TWIRP_FUNCTION:
    """
    Wrap an RPC function so that it returns plain Python objects.

    The protobuf response is converted with ``MessageToDict`` (keeping the
    proto field names). When the resulting dict has exactly one top level
    key, only the value stored under that key is returned.

    A CatalogException raised along the way is re-raised as is; any other
    exception is passed through ``_convert_error`` before being raised.
    """

    def wrapper(self, *args, **kwargs):
        try:
            response = twirp_func(self, *args, **kwargs)
            payload = MessageToDict(response,
                                    preserving_proto_field_name=True)

            if len(payload) != 1:
                return payload

            # a lot of rpc responses have a single top level key, so for
            # convenience sake, return just its value
            (only_value, ) = payload.values()
            return only_value
        except Exception as err:
            if isinstance(err, CatalogException):
                raise
            raise _convert_error(err)

    return wrapper


def _get_rpc_client_name(twirp_func):
    """
    Return the name of the RPC client class that ``twirp_func`` belongs to.

    The name is taken from ``__qualname__``, which is of the form
    <rpc client cls name>.<twirp_func> (ie. 'CatalogServiceClient.ListDatasets');
    everything before the first dot is returned.
    """
    owner_name, _dot, _method_name = twirp_func.__qualname__.partition('.')
    return owner_name


def _serialize_filter(params):
    """
    Replace a LegacyFilter stored under the "filter" key with its serialized form.

    ``params`` is modified in place and also returned. It is left untouched
    when it has no "filter" key or when the value is not a LegacyFilter.
    """
    if "filter" in params:
        candidate = params["filter"]
        if isinstance(candidate, LegacyFilter):
            params["filter"] = candidate.serialize()
    return params