Why Gemfury? Push, build, and install  RubyGems npm packages Python packages Maven artifacts PHP packages Go Modules Debian packages RPM packages NuGet packages

Repository URL to install this package:

Details    
namara-python / client.py
Size: Mime:
import requests, os
from dotenv import load_dotenv
from urllib import parse

from namara_python.query import Query
from namara_python.version import __version__
import namara_python.cache as cache

from namara_python.utils import to_camel_case
from rpc.catalog.catalog_twirp import CatalogServiceClient
from rpc.category.category_twirp import CategoryServiceClient
from rpc.dataspec.dataspec_twirp import DataspecV2ServiceClient
from rpc.exports.exports_twirp import ExportsServiceClient
from rpc.grants.grants_twirp import GrantsServiceClient
from rpc.organizations.organizations_twirp import OrganizationsServiceClient
from rpc.query.query_twirp import QueryServiceClient
from rpc.references.references_twirp import ReferencesServiceClient
from rpc.sources.sources_twirp import SourcesServiceClient
from rpc.warehouses.warehouses_twirp import WarehousesServiceClient

# must import these so that they are available to the symbol database which
# registers all the rpc objects, which is required before method generation begins
import rpc.catalog.catalog_pb2 as catalog_pb2
import rpc.category.category_pb2 as category_pb2
import rpc.dataspec.dataspec_pb2 as dataspec_pb2
import rpc.exports.exports_pb2 as exports_pb2
import rpc.grants.grants_pb2 as grants_pb2
import rpc.organizations.organizations_pb2 as organizations_pb2
import rpc.query.query_pb2 as query_pb2
import rpc.references.references_pb2 as references_pb2
import rpc.sources.sources_pb2 as sources_pb2
import rpc.warehouses.warehouses_pb2 as warehouses_pb2

import inspect
import json
import requests
from time import sleep
from inflection import underscore
from io import StringIO
from pandas import read_csv
from mimetypes import guess_type
from tempfile import NamedTemporaryFile

from twirp.context import Context

from typing import Callable, Any
from google.protobuf.json_format import MessageToDict
from google.protobuf import symbol_database as _symbol_database
from google.protobuf.timestamp_pb2 import Timestamp

# Default protobuf symbol database. _serialize_input looks request message
# classes up here by their full proto name, which is why the *_pb2 modules are
# imported above before any RPC method is generated.
_sym_db = _symbol_database.Default()

# Maps each RPC service descriptor (from the generated *_pb2 modules) to the
# twirp client class implementing it. Client uses this both to instantiate one
# client per service (_generate_clients) and to enumerate the RPC methods it
# generates on itself (_generate_client_methods).
SERVICE_DESCRIPTORS = {
    organizations_pb2._ORGANIZATIONSSERVICE: OrganizationsServiceClient,
    catalog_pb2._CATALOGSERVICE: CatalogServiceClient,
    category_pb2._CATEGORYSERVICE: CategoryServiceClient,
    dataspec_pb2._DATASPECV2SERVICE: DataspecV2ServiceClient,
    exports_pb2._EXPORTSSERVICE: ExportsServiceClient,
    grants_pb2._GRANTSSERVICE: GrantsServiceClient,
    query_pb2._QUERYSERVICE: QueryServiceClient,
    references_pb2._REFERENCESSERVICE: ReferencesServiceClient,
    sources_pb2._SOURCESSERVICE: SourcesServiceClient,
    warehouses_pb2._WAREHOUSESSERVICE: WarehousesServiceClient
}

# Filled on Client creation with initialized RPC service clients (i.e.
# instances of the values of SERVICE_DESCRIPTORS), keyed by client class name.
# NOTE(review): this cache is module-global, so each new Client overwrites the
# entries created by earlier Clients (including the api_url they were built
# with).
RPC_CLIENTS_CACHE = {}

# When generating methods on Client, the RPC whose full name is a key here is
# exposed under the given value instead (i.e. 'ListDatasets' on the Category
# service becomes 'list_topic_datasets' rather than 'list_datasets').
# This avoids known conflicting names across services (in this example,
# 'list_datasets' also exists on the Catalog service).
RPC_METHOD_CONFLICTS_RESOLVER = {
    'category.CategoryService.ListDatasets': 'ListTopicDatasets'
}

# Loose type alias for the twirp client methods wrapped by _serialize_input and
# _deserialize_output.
TWIRP_FUNCTION = Callable[[Any, Any], Any]

# Import-time side effects: load variables from a .env file into the process
# environment (so the NAMARA_* settings read by Client are visible to
# os.getenv), and load the dataset cache index (presumably the index consulted
# by cache.dataset_from_cache in Client.get; verify in namara_python.cache).
load_dotenv()
cache.load_cache_index()

# Path joined onto the auth base URL by Client._construct_auth_url.
LOGIN_PATH = 'users/jwt.json'


class ExportFailedError(Exception):
    """Raised when a Namara export job reports that it has failed.

    The message embeds the export id so users can reference it when
    contacting an administrator.
    """

    def __init__(self, export_id):
        message = "Export {} failed: Please contact admin".format(export_id)
        super().__init__(message)


class Client:
    """Entry point for talking to Namara.

    Construction validates the API key, instantiates the twirp service
    clients, attaches one snake_case method per RPC (e.g. ``create_export``,
    ``get_reference``) to this class, and fetches the caller's profile.
    On top of those generated RPC methods it offers higher-level helpers:
    ``query``, ``export_dataset``, ``get`` (query -> DataFrame) and ``put``
    (DataFrame -> ingested dataset reference).
    """

    # Timestamp fields stripped from a reference before update_reference. They
    # are not updatable anyway, and they currently fail to serialize into the
    # protobuf NullableTimestamp type (to be fixed in platform/2558).
    _REFERENCE_TIMESTAMP_FIELDS = (
        'created_at',
        'updated_at',
        'imported_at',
        'failed_at',
        'marked_for_destruction_at',
        'next_ingest',
        'started_at',
    )

    def __init__(self, *args, **kwargs):
        """Create a client and fetch the current user's profile.

        Keyword Args:
            api_key: Namara API key; falls back to the NAMARA_API_KEY env var.
            auth_url: account service base URL; falls back to NAMARA_AUTH_URL,
                then 'https://account.ee.namara.io'.
            api_url: API base URL; falls back to NAMARA_API_URL, then
                'https://api.ee.namara.io'.

        Raises:
            AttributeError: if no API key is given or found in the environment.
            requests.HTTPError: if the profile request fails.
        """
        self.api_key = kwargs.get('api_key', os.getenv('NAMARA_API_KEY'))

        if not self.api_key:
            raise AttributeError('Missing api_key')

        self.base_auth_url = kwargs.get('auth_url', os.getenv('NAMARA_AUTH_URL', 'https://account.ee.namara.io'))
        self.base_api_url = kwargs.get('api_url', os.getenv('NAMARA_API_URL', 'https://api.ee.namara.io'))

        self._generate_clients()
        self._generate_client_methods()

        headers = {
            'X-API-Key': self.api_key,
            'x-namara-client': 'python',
            'x-namara-client-version': __version__,
            'credentials': 'include',
        }
        res = requests.get(f'{self.base_auth_url}/users/profile.json', headers=headers)
        res.raise_for_status()

        self.profile = res.json()

    def query(self, *args, **kwargs):
        """Run a query via namara_python.query.Query; arguments are forwarded unchanged."""
        return Query(self).query(*args, **kwargs)

    def export_dataset(self, query: str):
        ''' Export the result of a NiQL query from Namara and return the URL
        of the downloadable file.

        Starts an export with create_export, then polls get_export every two
        seconds until the export reports state 'finished'. Both RPC responses
        may arrive either wrapped in a top-level 'export' key or already
        unwrapped; both shapes are handled.

        Raises:
            ExportFailedError: if the export reaches state 'failed'.
        '''
        export = self.create_export(query=query)
        print("Starting export")

        # When the export was already done before, the details are nested
        # under an 'export' key right away; otherwise they are top-level.
        export_id = export.get('export', export)['id']

        while True:
            response = self.get_export(id=export_id)
            details = response.get('export', response)

            state = details['state']
            if state == 'finished':
                # Prefer a top-level file_url (sibling of 'export'), else take
                # it from the export details themselves.
                if 'file_url' in response:
                    return response['file_url']
                return details['file_url']
            if state == 'failed':
                raise ExportFailedError(export_id)

            sleep(2)
            print("Export in progress...")

    def get(self, query: str):
        ''' Given a NiQL query, return a pandas DataFrame with its result.

        A cached DataFrame for the query is returned when available. Otherwise
        the query is exported (see export_dataset), the file is downloaded and
        parsed in memory, and the raw CSV text is handed to
        cache.cache_dataset together with the response ETag and URL.

        Raises:
            requests.HTTPError: if downloading the exported file fails.
        '''
        df = cache.dataset_from_cache(query)
        if df is not None:
            return df

        url = self.export_dataset(query)
        res = requests.get(url, allow_redirects=True)
        # Fail loudly instead of caching and parsing an HTTP error body.
        res.raise_for_status()

        data = res.content.decode('utf-8')

        cache.cache_dataset(query, data, res.headers['ETag'], url)

        return read_csv(StringIO(data))

    def put(self, data_frame, dataset_name:str, organization_id:str, warehouse_name:str=None, reference_id:str=None):
        ''' Upload a DataFrame to Namara and wait for it to be ingested.

        The frame is written to a temporary CSV and uploaded to the Namara
        Google bucket through a signed URL, from which it is ingested into the
        Namara data warehouse.

        If ``reference_id`` is given, that reference is updated to point at the
        new file; otherwise a new reference (and dataset) is created in
        ``warehouse_name``, or in the first warehouse returned by
        list_warehouses when no name is given. ``warehouse_name`` is ignored
        when ``reference_id`` is provided: ingest follows the reference's own
        warehouse.

        Ingest is asynchronous, so the reference is polled until it is either
        'imported' or 'failed'.

        Returns:
            The latest dataset reference dict (check its 'state').

        Raises:
            requests.HTTPError: if the signed-URL upload fails.
        '''
        # delete=False plus an immediate close() lets pandas and open() reopen
        # the file by name on every platform (an open NamedTemporaryFile cannot
        # be reopened on Windows). The file is removed in the finally block.
        tmp = NamedTemporaryFile(suffix=".csv", delete=False)
        tmp.close()
        try:
            data_frame.to_csv(tmp.name, index=False)
            sign_url_resp = self._upload_csv(tmp.name, dataset_name, organization_id)
        finally:
            os.remove(tmp.name)

        ref = self._upsert_reference(sign_url_resp, organization_id, warehouse_name, reference_id)
        return self._ingest_and_wait(ref, organization_id)

    def _upload_csv(self, path, dataset_name, organization_id):
        """Upload the CSV at ``path`` via a signed URL and return the signed-URL RPC response."""
        size = str(os.stat(path).st_size)
        mime_type = guess_type(path)[0]

        # ask the sources service for a google signed url for this upload
        sign_url_resp = self.get_import_lite_signed_url(filename=dataset_name, content_length=size, content_type=mime_type, organization_id=organization_id)

        # start a resumable upload session on the signed url...
        goog_headers = {"x-goog-meta-original-file": dataset_name, "x-goog-resumable": "start"}
        post_res = requests.post(url=sign_url_resp.get('signed_url'), headers=goog_headers)
        post_res.raise_for_status()

        # ...then stream the file to the session URL returned in 'location'
        with open(path, mode='rb') as f:
            put_res = requests.put(post_res.headers.get('location'), f)
        put_res.raise_for_status()

        return sign_url_resp

    def _upsert_reference(self, sign_url_resp, organization_id, warehouse_name, reference_id):
        """Point an existing reference at the uploaded file, or create a new one; return the reference dict."""
        if reference_id is not None:
            ref = self.get_reference(organization_id=organization_id, id=reference_id)
            ref['source_id'] = sign_url_resp.get('source_id')
            ref['url'] = sign_url_resp.get('data_reference_url')

            for field in self._REFERENCE_TIMESTAMP_FIELDS:
                ref.pop(field, None)

            # updates the file url reference to the new file
            self.update_reference(reference=ref)
            return ref

        if warehouse_name is None:
            warehouse_name = self.list_warehouses()[0].get('name')

        return self.create_reference(reference={
            'organization_id': organization_id,
            'user_id': self.profile.get('id'),
            'source_id': sign_url_resp.get('source_id'),
            'url': sign_url_resp.get('data_reference_url'),
            'warehouse': warehouse_name,
        })

    def _ingest_and_wait(self, ref, organization_id):
        """Start ingesting ``ref`` and poll every second until it is 'imported' or 'failed'; return the last reference."""
        job = self.ingest_reference(id=ref.get('id'), organization_id=ref.get('organization_id'))

        while job.get('state') not in ('imported', 'failed'):
            job = self.get_reference(organization_id=organization_id, id=ref.get('id'))
            print('Reference state: %s' % job.get('state'))
            sleep(1)

        return job

    def _construct_auth_url(self):
        """Return LOGIN_PATH joined onto the auth base URL."""
        return parse.urljoin(self.base_auth_url, LOGIN_PATH)

    def _generate_clients(self):
        ''' Instantiate one RPC client per service against base_api_url and
        store it in RPC_CLIENTS_CACHE under its class name.

        The API key is not bound here; it is attached per call by the
        generated methods. NOTE: RPC_CLIENTS_CACHE is module-global, so this
        replaces the clients created by any earlier Client instance.
        '''
        for client in SERVICE_DESCRIPTORS.values():
            RPC_CLIENTS_CACHE[client.__name__] = client(self.base_api_url)

    def _generate_client_methods(self):
        ''' Attach a pythonic wrapper for every RPC of every service to this class.

        Each twirp method is wrapped so it takes keyword arguments and returns
        plain Python data (see _serialize_input / _deserialize_output). It is
        set under its snake_case name after any RPC_METHOD_CONFLICTS_RESOLVER
        rename. QueryService.Query is skipped because Client.query provides its
        own implementation. The methods are set on the class, so they are
        shared by every Client instance.
        '''
        for service_descriptor, service_client in SERVICE_DESCRIPTORS.items():
            for method_descriptor in service_descriptor.methods:
                rpc_name = method_descriptor.name

                # Skip generating the Query method since we have our own implementation of it
                if rpc_name == "Query" and service_descriptor.name == 'QueryService':
                    continue

                twirp_func = getattr(service_client, rpc_name)

                func_name = RPC_METHOD_CONFLICTS_RESOLVER.get(method_descriptor.full_name, rpc_name)
                func_name = underscore(func_name)

                proto_request_name = method_descriptor.input_type.full_name
                new_func = _deserialize_output(_serialize_input(twirp_func, proto_request_name))

                setattr(self.__class__, func_name, new_func)


def _serialize_input(twirp_func:TWIRP_FUNCTION, proto_request_name: str) -> TWIRP_FUNCTION:
    ''' Wrap a twirp client method so it can be called with keyword arguments.

    The returned function is meant to be set as a method on Client. When
    called it:
      * looks up the RPC client instance owning ``twirp_func`` in
        RPC_CLIENTS_CACHE,
      * normalizes ``filter`` and nested ``timestamp`` keyword arguments,
      * builds the protobuf request message registered under
        ``proto_request_name`` from the keyword arguments,
      * sends the API key and python-client identification headers through a
        twirp Context,
      * and returns whatever ``twirp_func`` returns.

    Raises:
        TypeError: if the wrapper is called with positional arguments; request
            fields must be passed by keyword.
    '''
    def wrapper(self, *args, **kwargs):
        # Request fields are matched to protobuf fields by name only, so a
        # positional argument could never reach the request. Fail loudly
        # instead of sending a request that silently lacks it.
        if args:
            raise TypeError(
                'RPC methods accept keyword arguments only; got %d positional '
                'argument(s) for %s' % (len(args), proto_request_name)
            )

        client_name = _get_rpc_client_name(twirp_func)
        client = RPC_CLIENTS_CACHE[client_name]

        kwargs = _serialize_filter(kwargs)
        kwargs = _serialize_timestamps(kwargs)

        serialized_input_obj = _sym_db.GetSymbol(proto_request_name)(**kwargs)

        # twirp Context objects allow header overrides
        ctx = Context()
        ctx.set_header('x-api-key', self.api_key)
        ctx.set_header('x-namara-client', 'python')
        ctx.set_header('x-namara-client-version', __version__)

        return twirp_func(client, ctx=ctx, request=serialized_input_obj)

    return wrapper


def _deserialize_output(twirp_func:TWIRP_FUNCTION) -> TWIRP_FUNCTION:
    """Wrap ``twirp_func`` so its protobuf response comes back as plain Python data.

    The response is converted with MessageToDict, keeping the field names as
    written in the .proto files. Many RPC responses wrap their payload in a
    single top-level field, so when the resulting dict has exactly one key,
    that key's value is returned instead of the dict.
    """
    def unwrapping_call(self, *args, **kwargs):
        message = twirp_func(self, *args, **kwargs)
        as_dict = MessageToDict(message, preserving_proto_field_name=True)

        if len(as_dict) != 1:
            return as_dict

        (only_value,) = as_dict.values()
        return only_value

    return unwrapping_call

def _get_rpc_client_name(twirp_func):
    """Return the name of the RPC client class that defines ``twirp_func``.

    Relies on ``__qualname__`` having the form ``<client class>.<method>``,
    e.g. ``'CatalogServiceClient.list_datasets'`` -> ``'CatalogServiceClient'``.
    """
    owner, _, _ = twirp_func.__qualname__.partition('.')
    return owner

def _serialize_filter(params):
    ''' Wrap plain filter options in the value-wrapper shape of the RPC filter.

    For example, a ``params["filter"]`` of ``{"limit": 10, "sort_field": "name"}``
    yields a filter that additionally contains
    ``{"limit": {"value": 10}, "sort": {"value": "name", "order": "ASC"}}``.
    Handled keys: limit, offset, query, user_id, and sort_field with an
    optional sort_order (default "ASC"). All other keys, including sort_field
    and sort_order themselves, are kept as given.

    The caller's dicts are never modified; a shallow copy of ``params`` with a
    new filter dict is returned. Modifying them in place made a reused filter
    dict get wrapped twice on the next call. ``params`` without a dict-valued
    ``filter`` is returned unchanged.
    '''
    if "filter" not in params:
        return params

    raw = params["filter"]
    if not isinstance(raw, dict):
        return params

    wrapped = {}

    if "limit" in raw:
        wrapped["limit"] = {"value": raw["limit"]}

    if "offset" in raw:
        wrapped["offset"] = {"value": raw["offset"]}

    if "sort_field" in raw:
        wrapped["sort"] = {
            "value": raw["sort_field"],
            "order": raw.get("sort_order", "ASC")
        }

    if "query" in raw:
        wrapped["query"] = {"value": raw["query"]}

    if "user_id" in raw:
        wrapped["user_id"] = {"value": raw["user_id"]}

    result = dict(params)
    result["filter"] = {**raw, **wrapped}
    return result

def _serialize_timestamps(params):
    ''' Convert string timestamps nested two levels deep into protobuf Timestamps.

    This assumes a structure similar to:
        {
            "dataset": {
                "created_at": {
                    "timestamp": "date string"
                }
            }
        }

    The timestamp field must be at the second level of nesting! For every dict
    value of ``params``, each dict inside it with a truthy ``timestamp`` entry
    has that entry parsed with ``Timestamp.FromJsonString`` and replaced with
    the resulting Timestamp. Entries that are already Timestamps are left
    alone. Non-dict ``params`` are returned unchanged.

    The nested dicts are updated in place, so the caller's objects change too.

    Raises:
        ValueError, AttributeError: if a timestamp cannot be parsed; the
            exception also names the offending field.
    '''
    if not isinstance(params, dict):
        return params

    for val in params.values():
        if not isinstance(val, dict):
            continue

        for key, sub_val in val.items():
            if not isinstance(sub_val, dict):
                continue

            ts = sub_val.get('timestamp')
            # nothing to convert, or already converted on an earlier pass
            if not ts or isinstance(ts, Timestamp):
                continue

            t = Timestamp()
            try:
                t.FromJsonString(ts)
            except ValueError as e:
                # FromJsonString reports malformed/non-string input this way
                raise ValueError(e, "failed on %s" % key) from e
            except AttributeError as e:
                # older protobuf releases fail like this on non-string input
                raise AttributeError(e, "failed on %s" % key) from e

            sub_val['timestamp'] = t

    return params