Why Gemfury? Push, build, and install  RubyGems npm packages Python packages Maven artifacts PHP packages Go Modules Debian packages RPM packages NuGet packages

Repository URL to install this package:

Details    
namara-python / client.py
Size: Mime:
import requests, os
from dotenv import load_dotenv
from urllib import parse

from namara_python.query import Query

from namara_python.utils import gen_request_decorator
from namara_python.utils import to_camel_case

from namara_python.rpc.catalog_pb2_twirp import CatalogServiceClient
from namara_python.rpc.category_pb2_twirp import CategoryServiceClient
from namara_python.rpc.dataspec_pb2_twirp import DataspecServiceClient
from namara_python.rpc.exports_pb2_twirp import ExportsServiceClient
from namara_python.rpc.grants_pb2_twirp import GrantsServiceClient
from namara_python.rpc.organizations_pb2_twirp import OrganizationsServiceClient
from namara_python.rpc.query_pb2_twirp import QueryServiceClient
from namara_python.rpc.references_pb2_twirp import ReferencesServiceClient
from namara_python.rpc.sources_pb2_twirp import SourcesServiceClient
from namara_python.rpc.warehouses_pb2_twirp import WarehousesServiceClient

# must import these so that they are available to the symbol database which
# registers all the rpc objects, which is required before method generation begins
import namara_python.rpc.catalog_pb2
import namara_python.rpc.category_pb2
import namara_python.rpc.dataspec_pb2
import namara_python.rpc.exports_pb2
import namara_python.rpc.grants_pb2
import namara_python.rpc.organizations_pb2
import namara_python.rpc.query_pb2
import namara_python.rpc.references_pb2
import namara_python.rpc.sources_pb2
import namara_python.rpc.warehouses_pb2

import inspect
import json
import requests
from time import sleep
from io import StringIO
from pandas import read_csv
from mimetypes import guess_type
from tempfile import NamedTemporaryFile

from typing import Callable, Any
from google.protobuf.json_format import MessageToDict
from google.protobuf import symbol_database as _symbol_database

_sym_db = _symbol_database.Default()

# Twirp service clients whose public RPCs are exposed as methods on `Client`
RPC_CLIENTS = [
    CatalogServiceClient,
    CategoryServiceClient,
    DataspecServiceClient,
    ExportsServiceClient,
    GrantsServiceClient,
    OrganizationsServiceClient,
    QueryServiceClient,
    ReferencesServiceClient,
    SourcesServiceClient,
    WarehousesServiceClient,
]

# Gets used on client creation to store initialized RPC service clients (i.e.
# instances of the classes in RPC_CLIENTS), keyed by class name.
# NOTE(review): this cache is module-global, so constructing a second Client
# replaces the service clients (and the token-bound request decorator they were
# built with) used by every existing Client instance.
RPC_CLIENTS_CACHE = {}

# Renames an RPC method when generating it on Client (e.g. 'list_datasets' of the
# Category service is exposed as 'list_topic_datasets' instead). Keys are the
# '<client class>.<method>' qualified names of the Twirp methods. This avoids known
# conflicting names between services (here 'list_datasets' also exists on the
# Catalog service).
RPC_METHOD_CONFLICTS_RESOLVER = {
    "CategoryServiceClient.list_datasets": "list_topic_datasets"
}

# Type of the Twirp client methods and of the wrappers generated around them;
# the generated wrappers accept arbitrary (keyword) arguments, hence `...`.
TWIRP_FUNCTION = Callable[..., Any]

# Load variables from a local .env file (if present) so the NAMARA_* settings
# read by Client can be configured there.
load_dotenv()

# Path, relative to the auth base URL, of the endpoint that exchanges an API key for a JWT
LOGIN_PATH = 'users/jwt.json'

class Client:
    ''' Python client for the Namara API.

    Construction exchanges an API key for a JWT, creates one Twirp client per
    service listed in ``RPC_CLIENTS`` and attaches a keyword-argument method for
    every public RPC to this class. ``get`` and ``put`` are higher-level helpers
    for moving data between Namara and pandas DataFrames.

    Settings come from keyword arguments, falling back to the ``NAMARA_API_KEY``,
    ``NAMARA_AUTH_URL`` and ``NAMARA_API_URL`` environment variables.
    '''

    def __init__(self, *args, **kwargs):
        ''' Authenticate against Namara and prepare the RPC methods.

        Keyword Args:
            api_key: Namara API key (default: ``NAMARA_API_KEY``).
            auth_url: base URL of the auth service (default: ``NAMARA_AUTH_URL``).
            api_url: base URL of the API service (default: ``NAMARA_API_URL``).

        Raises:
            AttributeError: if the API key or either base URL is missing.
            requests.HTTPError: if authentication fails, the authentication
                response carries no JWT, or the profile lookup fails.
        '''
        self.api_key = kwargs.get('api_key', os.getenv('NAMARA_API_KEY'))

        if not self.api_key:
            raise AttributeError('Missing api_key')

        self.base_auth_url = kwargs.get('auth_url', os.getenv('NAMARA_AUTH_URL'))
        self.base_api_url = kwargs.get('api_url', os.getenv('NAMARA_API_URL'))

        if not self.base_auth_url or not self.base_api_url:
            raise AttributeError('Missing NAMARA_AUTH_URL or NAMARA_API_URL environment variable')

        url = self._construct_auth_url()
        response = requests.post(url, headers={'X-API-Key': self.api_key})

        if response.status_code != requests.codes.ok:
            response.raise_for_status()

        token = response.json().get('jwt')

        if not token:
            # A successful response without a token is still an authentication
            # failure; raise explicitly instead of faking a 401 on the response.
            raise requests.HTTPError(
                '401 Client Error: authentication response contained no JWT for url: %s' % url,
                response=response,
            )

        self.token = token

        self.generate_clients(token)
        self.generate_client_methods()

        profile_response = requests.get(
            '%s/users/profile.json' % self.base_auth_url,
            headers={'X-API-Key': self.api_key, 'credentials': 'include'},
        )
        profile_response.raise_for_status()
        self.profile = profile_response.json()

    def generate_clients(self, token):
        ''' Create new RPC clients bound to ``token`` and cache them for later use.

        The clients are stored in the module-level ``RPC_CLIENTS_CACHE`` keyed by
        class name, which is where the generated RPC methods look them up.
        '''
        request_decorator = gen_request_decorator(token)
        for client in RPC_CLIENTS:
            RPC_CLIENTS_CACHE[client.__name__] = client(self.base_api_url, request_decorator)

    def generate_client_methods(self):
        ''' Dynamically generate a method on this class for every public RPC.

        Methods are set on the class (not the instance), so all Client instances
        share them. Names listed in ``RPC_METHOD_CONFLICTS_RESOLVER`` are renamed.
        '''
        for rpc_client in RPC_CLIENTS:
            for func_name, twirp_func in rpc_client.__dict__.items():
                if func_name.startswith('_'):
                    continue

                # We override this method to parse the results in a more user friendly way
                if func_name == "query" and rpc_client == QueryServiceClient:
                    continue

                func_name = RPC_METHOD_CONFLICTS_RESOLVER.get(twirp_func.__qualname__, func_name)

                # create decorators around the existing twirp functions to handle input
                # and output in a pythonic way (ie. using python objects rather than Twirp objects)
                new_func = deserialize_output(serialize_input(twirp_func))
                setattr(self.__class__, func_name, new_func)

    def query(self, *args, **kwargs):
        ''' Delegate to ``Query(self).query`` with the same arguments. '''
        return Query(self).query(*args, **kwargs)

    def get(self, query: str, poll_interval: float = 2):
        ''' Export the result of a NiQL query and return it as a pandas DataFrame.

        Export is asynchronous, so Namara is polled every ``poll_interval``
        seconds until the export is finished. The exported file is then read
        directly into a DataFrame, without writing it to disk.

        Raises:
            requests.HTTPError: if downloading the exported file fails.
        '''
        export = self.create_export(query=query)
        finished_export_details = self._wait_for_export(export['export']['id'], poll_interval)

        res = requests.get(finished_export_details['file_url'], allow_redirects=True)
        res.raise_for_status()

        return read_csv(StringIO(res.content.decode('utf-8')))

    def _wait_for_export(self, export_id, poll_interval):
        ''' Poll ``get_export`` until the export state is 'finished'; return that response. '''
        while True:
            details = self.get_export(id=export_id)
            if details is not None:
                # deserialize_output unwraps single-key responses, so a response
                # holding only `export` arrives as the export itself.
                export = details.get('export', details)
                if export is not None and export.get('state') == 'finished':
                    return details
            # NOTE(review): no failure state or timeout is checked, so an export
            # that never finishes keeps this loop polling forever.
            sleep(poll_interval)

    def put(self, data_frame, dataset_name: str, organization_id: str, warehouse_name: str = None, reference_id: str = None, poll_interval: float = 1):
        ''' Upload a DataFrame to Namara and ingest it into the data warehouse.

        The DataFrame is written to a temporary CSV and uploaded to the Namara
        Google bucket through a signed URL. From there it is ingested into the
        Namara data warehouse.

        If ``reference_id`` is given, that reference is re-ingested. Otherwise a
        new reference (and dataset) is created in ``warehouse_name``, or in the
        first warehouse returned by ``list_warehouses`` when no name is given.
        The warehouse is only relevant when no ``reference_id`` is provided,
        since an existing reference ingests into its own warehouse.

        Ingest is asynchronous, so the reference is polled every
        ``poll_interval`` seconds until it is 'imported' or 'failed'.

        Returns:
            The created or updated reference, in its final polled state.

        Raises:
            requests.HTTPError: if the signed-URL upload fails.
            ValueError: if a warehouse must be chosen but none is available.
        '''
        sign_url_resp = self._upload_dataframe(data_frame, dataset_name, organization_id)

        # if we have a reference id retrieve the reference, otherwise create a new reference
        if reference_id is not None:
            ref = self.get_reference(organization_id=organization_id, id=reference_id)
        else:
            ref = self.create_reference(reference={
                'organization_id': organization_id,
                'user_id': self.profile.get('id'),
                'source_id': sign_url_resp.get('source_id'),
                'url': sign_url_resp.get('data_reference_url'),
                'warehouse': self._resolve_warehouse(warehouse_name),
            })

        # begin the ingest reference process
        self.ingest_reference(id=ref.get('id'), organization_id=ref.get('organization_id'))

        # Always re-fetch after ingesting: the state of `ref` predates the ingest
        # (an existing reference may already be 'imported') and must not end polling.
        return self._wait_for_reference(organization_id, ref.get('id'), poll_interval)

    def _upload_dataframe(self, data_frame, dataset_name, organization_id):
        ''' Write the DataFrame to a temporary CSV and upload it via a signed URL; return the signed-URL response. '''
        with NamedTemporaryFile(suffix=".csv") as tf:
            data_frame.to_csv(tf.name, index=False)
            size = str(os.stat(tf.name).st_size)
            mime_type = guess_type(tf.name)[0]

            # here we create a google signed url through the sources service
            sign_url_resp = self.get_import_lite_signed_url(filename=dataset_name, content_length=size, content_type=mime_type, organization_id=organization_id)

            goog_headers = {"x-goog-meta-original-file": dataset_name, "x-goog-resumable": "start"}
            post_res = requests.post(url=sign_url_resp.get('signed_url'), headers=goog_headers)
            post_res.raise_for_status()

            # Upload in binary mode so the bytes sent match the declared content_length
            with open(tf.name, 'rb') as f:
                put_response = requests.put(post_res.headers.get('location'), data=f)
            put_response.raise_for_status()

        return sign_url_resp

    def _resolve_warehouse(self, warehouse_name):
        ''' Return ``warehouse_name``, or the name of the first listed warehouse when it is None. '''
        if warehouse_name is not None:
            return warehouse_name

        warehouses = self.list_warehouses()
        if not warehouses:
            raise ValueError('No warehouse available; pass warehouse_name explicitly')
        return warehouses[0].get('name')

    def _wait_for_reference(self, organization_id, reference_id, poll_interval):
        ''' Poll the reference until it is 'imported' or 'failed' and return it. '''
        # NOTE(review): if the server has not yet moved an already-imported
        # reference out of 'imported' right after ingest, this returns early — confirm.
        while True:
            ref = self.get_reference(organization_id=organization_id, id=reference_id)
            print('Reference state: %s' % ref.get('state'))
            if ref.get('state') in ('imported', 'failed'):
                return ref
            sleep(poll_interval)

    def _construct_auth_url(self):
        ''' Return the URL of the API-key → JWT login endpoint. '''
        return parse.urljoin(self.base_auth_url, LOGIN_PATH)


def serialize_input(twirp_func:TWIRP_FUNCTION) -> TWIRP_FUNCTION:
    ''' Wrap a Twirp client method so it can be called with plain keyword arguments.

    The returned wrapper is meant to be attached to ``Client``. When called, it
    looks up the cached Twirp client that defines ``twirp_func`` (see
    ``RPC_CLIENTS_CACHE``), builds the protobuf request object from the keyword
    arguments (after normalizing any ``filter`` argument) and calls
    ``twirp_func`` with that client and request.

    Only keyword arguments are accepted, since they are what maps onto the
    request message's fields; positional arguments raise ``TypeError`` instead
    of being silently ignored.
    '''
    def wrapper(self, *args, **kwargs):
        if args:
            raise TypeError(
                '%s accepts keyword arguments only (got %d positional)'
                % (twirp_func.__qualname__, len(args))
            )

        # `self` (the Client instance) is not used here: the Twirp client that
        # owns twirp_func is taken from the module-level RPC_CLIENTS_CACHE.
        client_name = _get_rpc_client_name(twirp_func)
        client = RPC_CLIENTS_CACHE[client_name]
        input_obj_name = _get_input_object_name(twirp_func, client)

        kwargs = _serialize_filter(kwargs)
        serialized_input_obj = _sym_db.GetSymbol(input_obj_name)(**kwargs)

        return twirp_func(client, serialized_input_obj)

    return wrapper


def deserialize_output(twirp_func:TWIRP_FUNCTION) -> TWIRP_FUNCTION:
    ''' Wrap ``twirp_func`` so that its protobuf response comes back as plain Python data.

    The response message is converted with ``MessageToDict`` (keeping the proto
    field names). If the resulting dict has exactly one top-level key, the value
    stored under that key is returned instead of the dict. Many responses wrap a
    single field, so this saves callers a lookup.
    '''
    def wrapper(self, *args, **kwargs):
        as_dict = MessageToDict(
            twirp_func(self, *args, **kwargs),
            preserving_proto_field_name=True,
        )

        if len(as_dict) != 1:
            return as_dict

        # unwrap the lone top-level field
        (only_value,) = as_dict.values()
        return only_value

    return wrapper

def _get_rpc_client_name(twirp_func):
    # __qualname__ will be of the form <rpc clien cls name>.<twirp_func>
    # ie. 'CatalogServiceClient.list_datasets'
    return twirp_func.__qualname__.split('.')[0]


def _get_input_object_name(twirp_func, rpc_client):
    ''' Return the qualified protobuf name of ``twirp_func``'s request message.

    The result has the form ``<service_name>.<RequestMessage>``,
    e.g. ``organizations.GetOrganizationRequest``.

    It relies on the *name* of the second parameter of ``twirp_func`` (the one
    after ``self``). This works because the parameter names match the protobuf
    request objects. For example, ``get_organization_request`` camelized gives
    ``GetOrganizationRequest``, the input message of ``get_organization``.

    This holds as long as the python twirp client generates function signatures
    in this manner.

    The service name is read from the client's private ``__service_name``
    attribute (name-mangled to ``_<ClassName>__service_name``). Only the part
    before the first dot is kept.
    '''
    # getfullargspec: inspect.getargspec was removed in Python 3.11
    input_param_name = inspect.getfullargspec(twirp_func).args[1]
    proper_input_name = to_camel_case(input_param_name)

    cls_name = rpc_client.__class__.__name__

    internal_var_name = '_' + cls_name + '__service_name'
    service_name = getattr(rpc_client, internal_var_name).split('.')[0]

    return service_name + '.' + proper_input_name


def _serialize_filter(params):
    if "filter" not in params:
        return params

    new_filter = {}
    f = params["filter"]

    if "limit" in f:
        new_filter["limit"] = {"value": f["limit"]}

    if "offset" in f:
        new_filter["limit"] = {"value": f["offset"]}

    if "sort_field" in f:
        new_filter["sort"] = {
            "value": f["sort_field"],
            "order": f.get("sort_order", "ASC")
        }

    if "query" in f:
        new_filter["query"] = {"value": f["query"]}

    if "user_id" in f:
        new_filter["user_id"] = {"value": f["user_id"]}

    params["filter"].update(new_filter)
    return  params