Why Gemfury? Push, build, and install  RubyGems npm packages Python packages Maven artifacts PHP packages Go Modules Debian packages RPM packages NuGet packages

Repository URL to install this package:

Details    
namara-python / client.py
Size: Mime:
import requests, os
from dotenv import load_dotenv
from urllib import parse

from namara_python.query import Query
from namara_python.version import __version__
import namara_python.cache as cache

from namara_python.utils import to_camel_case

from namara_python.rpc.catalog_twirp import CatalogServiceClient
from namara_python.rpc.category_twirp import CategoryServiceClient
from namara_python.rpc.dataspec_twirp import DataspecServiceClient
from namara_python.rpc.exports_twirp import ExportsServiceClient
from namara_python.rpc.grants_twirp import GrantsServiceClient
from namara_python.rpc.organizations_twirp import OrganizationsServiceClient
from namara_python.rpc.query_twirp import QueryServiceClient
from namara_python.rpc.references_twirp import ReferencesServiceClient
from namara_python.rpc.sources_twirp import SourcesServiceClient
from namara_python.rpc.warehouses_twirp import WarehousesServiceClient

# must import these so that they are available to the symbol database which
# registers all the rpc objects, which is required before method generation begins
import namara_python.rpc.catalog_pb2 as catalog_pb2
import namara_python.rpc.category_pb2 as category_pb2
import namara_python.rpc.dataspec_pb2 as dataspec_pb2
import namara_python.rpc.exports_pb2 as exports_pb2
import namara_python.rpc.grants_pb2 as grants_pb2
import namara_python.rpc.organizations_pb2 as organizations_pb2
import namara_python.rpc.query_pb2 as query_pb2
import namara_python.rpc.references_pb2 as references_pb2
import namara_python.rpc.sources_pb2 as sources_pb2
import namara_python.rpc.warehouses_pb2 as warehouses_pb2

import inspect
import json
import requests
from time import sleep
from inflection import underscore
from io import StringIO
from pandas import read_csv
from mimetypes import guess_type
from tempfile import NamedTemporaryFile

from twirp.context import Context

from typing import Callable, Any
from google.protobuf.json_format import MessageToDict
from google.protobuf import symbol_database as _symbol_database

# Default protobuf symbol database; used to look up request message classes
# by their full proto name when the RPC wrappers build request objects.
_sym_db = _symbol_database.Default()

# Maps each protobuf service descriptor to the Twirp client class that talks
# to that service. Client iterates this mapping both to instantiate the RPC
# clients and to generate one Python method per RPC method.
SERVICE_DESCRIPTORS = {
    organizations_pb2._ORGANIZATIONSSERVICE: OrganizationsServiceClient,
    catalog_pb2._CATALOGSERVICE: CatalogServiceClient,
    category_pb2._CATEGORYSERVICE: CategoryServiceClient,
    dataspec_pb2._DATASPECSERVICE: DataspecServiceClient,
    exports_pb2._EXPORTSSERVICE: ExportsServiceClient,
    grants_pb2._GRANTSSERVICE: GrantsServiceClient,
    query_pb2._QUERYSERVICE: QueryServiceClient,
    references_pb2._REFERENCESSERVICE: ReferencesServiceClient,
    sources_pb2._SOURCESSERVICE: SourcesServiceClient,
    warehouses_pb2._WAREHOUSESSERVICE: WarehousesServiceClient
}

# Filled on Client creation with initialized RPC service clients (instances of
# the classes in SERVICE_DESCRIPTORS), keyed by the client class name.
# NOTE: this is module-level state, so it is shared by every Client instance.
RPC_CLIENTS_CACHE = {}

# Will change the rpc method name to the value given here when generating the methods on
# Client (ie. 'list_datasets' for Category service will be called 'list_topic_datasets' instead).
# This is to avoid known conflicting names from other services (in this example
# 'list_datasets' exists for Catalog service as well). Keys are the full
# method names of the protobuf method descriptors.
RPC_METHOD_CONFLICTS_RESOLVER = {
    'category.CategoryService.ListDatasets': 'ListTopicDatasets'
}

# Type alias used to annotate the Twirp functions and the wrappers built around them.
TWIRP_FUNCTION = Callable[[Any, Any], Any]

# Import-time side effects: load_dotenv() runs first so that settings from a
# .env file are presumably visible to the os.getenv lookups made by Client;
# the cache module then loads its index (semantics live in namara_python.cache).
load_dotenv()
cache.load_cache_index()

# Path of the login endpoint, joined onto the auth base url.
LOGIN_PATH = 'users/jwt.json'


class ExportFailedError(Exception):
    """
    Signals that an export ended in a failed state.

    The exception message embeds the id of the export that failed.
    """
    def __init__(self, export_id):
        message = "Export {} failed: Please contact admin".format(export_id)
        super().__init__(message)


class Client:
    ''' Entry point for talking to the Namara API.

    Constructing a Client instantiates one Twirp RPC client per service listed
    in SERVICE_DESCRIPTORS, attaches a snake_case method for every RPC method
    to the class, and fetches the profile of the user that owns the api key.
    On top of the generated RPC methods it offers `query`, `get` and `put`.
    '''

    def __init__(self, *args, **kwargs):
        ''' Configure the client and validate the api key against the auth service.

        Keyword arguments (each falls back to an environment variable):
            api_key:  Namara api key (env NAMARA_API_KEY). Required.
            auth_url: base url of the auth service (env NAMARA_AUTH_URL).
            api_url:  base url of the RPC api (env NAMARA_API_URL).

        Positional arguments are accepted but ignored.

        Raises:
            AttributeError: when no api key is available.
            requests.HTTPError: when the profile request is rejected.
        '''
        self.api_key = kwargs.get('api_key', os.getenv('NAMARA_API_KEY'))

        if not self.api_key:
            raise AttributeError('Missing api_key')

        self.base_auth_url = kwargs.get('auth_url', os.getenv('NAMARA_AUTH_URL', 'https://account.ee.namara.io'))
        self.base_api_url = kwargs.get('api_url', os.getenv('NAMARA_API_URL', 'https://api.ee.namara.io'))

        self._generate_clients()
        self._generate_client_methods()

        headers = {
            'X-API-Key': self.api_key,
            'x-namara-client': 'python',
            'x-namara-client-version': __version__,
            'credentials': 'include',
        }
        res = requests.get(f'{self.base_auth_url}/users/profile.json', headers=headers)
        res.raise_for_status()

        self.profile = res.json()

    def query(self, *args, **kwargs):
        ''' Run a query through a fresh Query helper bound to this client. '''
        return Query(self).query(*args, **kwargs)

    @staticmethod
    def _export_field(payload, key):
        ''' Read `key` from an export payload that may or may not be nested.

        Export responses are handled in two shapes: the export record itself,
        or a dict holding that record under an 'export' key. The nested record
        is consulted first, then the top level of the payload.

        Raises:
            KeyError: when the key is present in neither place.
        '''
        nested = payload.get('export')
        if isinstance(nested, dict) and key in nested:
            return nested[key]
        return payload[key]

    def get(self, query:str):
        ''' Given a niql query, export the result from Namara and return a
        Pandas dataframe with the loaded data.

        A cached dataset for the query is returned directly when the cache has
        one. Otherwise, since export is an async operation, we have to poll
        Namara to see when the exported file content is available for
        download. The file content is then stored in the cache and read
        directly into a Dataframe (without this method writing it to disk).

        Raises:
            ExportFailedError: when the export ends in the 'failed' state.
            requests.HTTPError: when downloading the exported file fails.
        '''
        df = cache.dataset_from_cache(query)
        if df is not None:
            return df

        export = self.create_export(query=query)
        print("Starting export")

        # when the export has already been done before, the details come back
        # nested under an 'export' key; _export_field copes with both shapes
        export_id = self._export_field(export, 'id')

        # poll for when the export is done
        while True:
            finished_export_details = self.get_export(id=export_id)
            state = self._export_field(finished_export_details, 'state')

            if state == 'finished':
                break
            if state == 'failed':
                raise ExportFailedError(self._export_field(finished_export_details, 'id'))

            sleep(2)
            print("Export in progress...")

        url = self._export_field(finished_export_details, 'file_url')
        res = requests.get(url, allow_redirects=True)
        # never cache or parse an error body as if it were the dataset
        res.raise_for_status()

        data = str(res.content, 'utf-8')

        cache.cache_dataset(query, data, res.headers['ETag'], url)

        return read_csv(StringIO(data))

    def put(self, data_frame, dataset_name:str, organization_id:str, warehouse_name:str=None, reference_id:str=None):
        ''' given a dataframe and an organization_id, attempt to upload
        the file to the Namara google bucket, from which it will be ingested
        into the Namara data warehouse.

        if a warehouse name is provided it will be used otherwise the first warehouse
        from a users warehouse list will be used. This will only be useful if a reference
        id has not been provided, otherwise the ingest will run according to
        the references warehouse.

        if a reference is provided, the dataset for the given reference will be
        updated to the current file, otherwise a new reference and dataset will be
        created

        Since ingest is an async operation we will have to poll it until the file
        has finished ingesting onto namara platform

        return the created or updated dataset reference (the last polled
        state, which is either 'imported' or 'failed')

        Raises:
            requests.HTTPError: when opening the upload against the signed url
                or uploading the file is rejected.
        '''

        # we load the dataframe into a temporary csv file and upload that file.
        # The file is only needed until the upload is done; the context
        # managers remove it and close the read handle even when a request fails.
        with NamedTemporaryFile(suffix=".csv", mode='w', encoding='utf-8') as tf:
            data_frame.to_csv(tf.name, index=False)
            size = str(os.stat(tf.name).st_size)
            mime_type = guess_type(tf.name)[0]

            # here we create a google signed url through the sources service
            sign_url_resp = self.get_import_lite_signed_url(filename=dataset_name, content_length=size, content_type=mime_type, organization_id=organization_id)

            # open the upload against the signed url; the url to send the file
            # content to comes back in the 'location' response header
            goog_headers = {"x-goog-meta-original-file": dataset_name, "x-goog-resumable": "start"}
            post_res = requests.post(url=sign_url_resp.get('signed_url'), headers=goog_headers)
            post_res.raise_for_status()
            location = post_res.headers.get('location')

            # now we actually load the file into the google bucket
            with open(tf.name, mode='rb') as f:
                upload_res = requests.put(location, f)
            # do not create or update a reference for a file that never arrived
            upload_res.raise_for_status()

        # if we have a reference id retrieve the reference, otherwise create a new reference
        if reference_id is not None:
            ref = self.get_reference(organization_id=organization_id, id=reference_id)

            # updates the file url reference to the new file
            self.update_reference(reference={
                'id': reference_id,
                'organization_id': organization_id,
                'source_id': sign_url_resp.get('source_id'),
                'url': sign_url_resp.get('data_reference_url'),
                'user_id': self.profile.get('id'),
                'warehouse': ref.get('warehouse')
            })
        else:
            if warehouse_name is not None:
                wh = warehouse_name
            else:
                whs = self.list_warehouses()
                wh = whs[0].get('name')

            ref = {
                'organization_id': organization_id,
                'user_id': self.profile.get('id'),
                'source_id': sign_url_resp.get('source_id'),
                'url': sign_url_resp.get('data_reference_url'),
                'warehouse': wh
            }

            ref = self.create_reference(reference=ref)

        # begin the ingest reference process
        job = self.ingest_reference(id=ref.get('id'), organization_id=ref.get('organization_id'))

        # poll the reference until it is either imported or failed
        while job.get('state') != 'imported' and job.get('state') != 'failed':
            job = self.get_reference(organization_id=organization_id, id=ref.get('id'))
            print('Reference state: %s' %job.get('state'))
            sleep(1)

        # returns latest dataset reference object
        return job

    def _construct_auth_url(self):
        ''' Join the login path onto the auth base url. '''
        return parse.urljoin(self.base_auth_url, LOGIN_PATH)

    def _generate_clients(self):
        ''' Creates one RPC client per service, pointed at this Client's api
        url, and stores it in RPC_CLIENTS_CACHE under the client class name.

        The cache is module-level, so entries are overwritten each time a
        Client is constructed.
        '''
        for client in SERVICE_DESCRIPTORS.values():
            RPC_CLIENTS_CACHE[client.__name__] = client(self.base_api_url)

    def _generate_client_methods(self):
        ''' Dynamically generates RPC methods for each service on to this object

        Each RPC method is exposed under its snake_case name (after applying
        RPC_METHOD_CONFLICTS_RESOLVER). The methods are set on the class, not
        on the instance.
        '''
        for service_descriptor, service_client in SERVICE_DESCRIPTORS.items():
            for method_descriptor in service_descriptor.methods:
                func_name = method_descriptor.name

                # Skip generating the Query method since we have our own implementation of it
                if func_name == "Query" and service_descriptor.name == 'QueryService':
                    continue

                twirp_func = getattr(service_client, func_name)

                func_name = RPC_METHOD_CONFLICTS_RESOLVER.get(method_descriptor.full_name, func_name)
                func_name = underscore(func_name)

                proto_request_name = method_descriptor.input_type.full_name
                # create decorators around the existing twirp functions to handle input
                # and output in a pythonic way (ie. using python objects rather than Twirp objects)
                new_func = _deserialize_output(_serialize_input(twirp_func, proto_request_name))

                setattr(self.__class__, func_name, new_func)


def _serialize_input(twirp_func:TWIRP_FUNCTION, proto_request_name: str) -> TWIRP_FUNCTION:
    ''' Wrap a Twirp function so it can be called with plain keyword arguments.

    The returned wrapper builds the request message named by
    `proto_request_name` from the keyword arguments (after the filter has been
    serialized), sets the api key and client identification headers, and
    invokes `twirp_func` on the cached RPC client for its service. Positional
    arguments passed to the wrapper are not forwarded.
    '''
    def wrapper(self, *args, **kwargs):
        rpc_client = RPC_CLIENTS_CACHE[_get_rpc_client_name(twirp_func)]

        request_fields = _serialize_filter(kwargs)
        request_cls = _sym_db.GetSymbol(proto_request_name)
        request = request_cls(**request_fields)

        # twirp Context objects allow header overrides
        ctx = Context()
        for header, value in (
            ('x-api-key', self.api_key),
            ('x-namara-client', 'python'),
            ('x-namara-client-version', __version__),
        ):
            ctx.set_header(header, value)

        return twirp_func(rpc_client, ctx=ctx, request=request)

    return wrapper


def _deserialize_output(twirp_func:TWIRP_FUNCTION) -> TWIRP_FUNCTION:
    ''' Wrap a function returning a protobuf message so it returns plain Python data.

    The message is converted with MessageToDict, keeping the proto field
    names. When the resulting dict has exactly one top level key, the value
    under that key is returned instead of the dict.
    '''
    def wrapper(self, *args, **kwargs):
        message = twirp_func(self, *args, **kwargs)
        as_dict = MessageToDict(message, preserving_proto_field_name=True)

        if len(as_dict) != 1:
            return as_dict

        # a lot of rpc responses have a single top level key, so for
        # convenience sake, hand back just its value
        (only_value,) = as_dict.values()
        return only_value

    return wrapper

def _get_rpc_client_name(twirp_func):
    # __qualname__ will be of the form <rpc clien cls name>.<twirp_func>
    # ie. 'CatalogServiceClient.list_datasets'
    return twirp_func.__qualname__.split('.')[0]

def _serialize_filter(params):
    if "filter" not in params:
        return params

    new_filter = {}
    f = params["filter"]

    if "limit" in f:
        new_filter["limit"] = {"value": f["limit"]}

    if "offset" in f:
        new_filter["offset"] = {"value": f["offset"]}

    if "sort_field" in f:
        new_filter["sort"] = {
            "value": f["sort_field"],
            "order": f.get("sort_order", "ASC")
        }

    if "query" in f:
        new_filter["query"] = {"value": f["query"]}

    if "user_id" in f:
        new_filter["user_id"] = {"value": f["user_id"]}

    params["filter"].update(new_filter)
    return  params