Why Gemfury? Push, build, and install  RubyGems npm packages Python packages Maven artifacts PHP packages Go Modules Debian packages RPM packages NuGet packages

Repository URL to install this package:

Details    
namara-er / client.py
Size: Mime:
from time import time, sleep
from pandas import DataFrame, Series
import pandas as pd
from urllib.error import URLError, HTTPError
from rpc.er_service_twirp import ErServiceClient
from rpc.er_service_pb2 import (
    GetEntitiesRequest, GetBatchJobRequest, 
    LinkEntitiesRequest, LinkEntitiesResponse, GetEntitiesResponse
)
from namara_er.utils import gen_request_decorator
from typing import Any, List
from uuid import uuid4
from twirp.context import Context

# BATCH_SIZE indicates the max number of entities sent per request
# if len(entities) > BATCH_SIZE, we divide the request into batches
BATCH_SIZE = 100000

# default service url, used when the Client is created without `api_url`
API_URL = "https://er.namara.io/twirp/er_service.ErService/"
# entity types accepted by `link_entities` and `append_codes`
SUPPORTED_ENTITY_TYPES = ["default", "business", "number", "address", "people"]
# value passed to `sleep` between batch job polls and between retried calls
POLLING_INTERVAL = 6
# when True, seeding requests are sent as batch jobs and polled until done
POLLING = True
# max attempts for a batch job call that fails with URLError/HTTPError
RETRIES = 3

# NOTE: to be implemented
RETRY_BATCH = False


class Client:
    def __init__(self, *args, **kwargs) -> None:
        """
        builds a client from keyword arguments; positional arguments are
        accepted but ignored.

        :param api_key: required, the key used to authenticate requests
        :param api_url: optional service url, defaults to API_URL
        :param auth_url: optional, stored but not used yet
        :raises ApiKeyRequiredError: when `api_key` is missing or empty
        """
        self.api_key = kwargs.get('api_key')
        if not self.api_key:
            # the exception takes (message, errors); raising it without
            # arguments would surface as a TypeError instead
            raise ApiKeyRequiredError(
                "an 'api_key' is required to create a Client", None
            )
        self.api_url = kwargs.get('api_url', API_URL)
        self.auth_url = kwargs.get('auth_url') # to be supported in the future
        # NOTE(review): timeout is presumably in seconds - confirm against the twirp client
        self.rpc_client = ErServiceClient(self.api_url, timeout=3800)

    def get_entities(
        self,
        entity_type: str,
        entities: List[str],
        threshold: float = None,
        seed: bool = False,
        dataset_id: str = None,
        use_master_db: bool = True
    ) -> Any:
        """
        returns the entities matched against the input entries.
        if entries are not matched against any entities, these are going to
        be created.

        the entries are sent in batches of at most BATCH_SIZE. when `seed`
        is set and POLLING is enabled, each batch is submitted as a batch job
        and polled until its status is "finished"; otherwise the entities are
        requested directly.

        :param entity_type: type of entity of the entries
        :param entities: entries to resolve
        :param threshold: only included in the request when truthy
        :param seed: forwarded in the request, also selects the batch job path
        :param dataset_id: generated internally when empty or None
        :param use_master_db: forwarded in the request
        :return: dict with "entities" (one reformatted dict per returned
            entity, across all batches) plus "dataset_id" and "entity_type"
            taken from the last response
        :raises InvalidEntryError: when an entry is 'nan'
        :raises SeedingFailed: when a batch job reports status "failed"
        """
        ctx = Context()
        ctx.set_header('x-api-key', self.api_key)
        self._validate(entities)

        def call_with_retries(rpc_method, request):
            # only URLError/HTTPError are retried; the failure of the
            # final attempt is re-raised to the caller
            for attempt in range(RETRIES):
                try:
                    return rpc_method(ctx=ctx, request=request)
                except (URLError, HTTPError):
                    if attempt >= RETRIES - 1:
                        raise
                    sleep(POLLING_INTERVAL)

        if not dataset_id:
            dataset_id = self._generate_er_dataset_id()

        response = {"entities": []}
        total = len(entities)

        for batch_no in range(self._get_batch_number(total)):
            start = batch_no * BATCH_SIZE
            stop = start + BATCH_SIZE

            request_fields = {
                "dataset_id": dataset_id,
                "entity_type": entity_type,
                "queries": entities[start:stop],
                "seed": seed,
                "use_master_db": use_master_db,
                # only the batch that reaches the end of the input is the last one
                "last_batch": stop >= total,
            }
            if threshold:
                request_fields["threshold"] = threshold
            entities_request = GetEntitiesRequest(**request_fields)

            if seed and POLLING:
                # Async Seeding: create the job, then poll until it settles
                job = call_with_retries(
                    self.rpc_client.CreateBatchJob, entities_request
                )
                batch_job_id = job.batch_job_id
                while True:
                    sleep(POLLING_INTERVAL)
                    resp = call_with_retries(
                        self.rpc_client.GetBatchJob,
                        GetBatchJobRequest(batch_job_id=batch_job_id),
                    )
                    if resp.status == "finished":
                        break
                    if resp.status == "failed":
                        raise SeedingFailed(
                            f"batch id: {batch_job_id} failed to seed", None
                        )
            else:
                resp = self.rpc_client.GetEntities(ctx=ctx, request=entities_request)

            response['dataset_id'] = resp.dataset_id
            response['entity_type'] = resp.entity_type
            response['entities'].extend(
                self._reformat_entity(ent) for ent in resp.entities
            )

        return response


    def link_entities(
        self,
        master_df,
        batch_df,
        master_index_col:str,
        batch_index_col:str,
        entity_type: str,
        threshold: float = None,
        dataset_id: str = None,
        feature: str = None,
        merge_direction: str = 'left'
        ) -> Any:
        """
        This method receives 2 dataframes along with the target columns, and performs er linking on them
        Returns master df with codes, batch df with codes, entity_pairs, and merged_df.

        :param master_df: pandas DataFrame holding the master entities
        :param batch_df: pandas DataFrame holding the batch entities
        :param master_index_col: column of master_df with the values to link
        :param batch_index_col: column of batch_df with the values to link
        :param entity_type: one of SUPPORTED_ENTITY_TYPES
        :param threshold: forwarded in the request
        :param dataset_id: generated internally when empty or None
        :param feature: forwarded in the request
        :param merge_direction: 'left' keeps every batch row in merged_df,
            'right' keeps every master row
        :return: tuple (master_er_df, batch_er_df, entity_pairs, merged_df)
        :raises InvalidEntryError: on an unsupported entity type, a non
            DataFrame input, an unknown merge_direction or a 'nan' entry
        """
        # validate everything up front so that no request is sent for a call
        # that cannot produce a result
        if entity_type not in SUPPORTED_ENTITY_TYPES:
            raise InvalidEntryError(f"Entity type not supported: '{entity_type}'.", None)
        if not isinstance(master_df, DataFrame):
            raise InvalidEntryError(f"Type {type(master_df)} not supported. Please enter a valid pandas DataFrame.", None)
        if not isinstance(batch_df, DataFrame):
            raise InvalidEntryError(f"Type {type(batch_df)} not supported. Please enter a valid pandas DataFrame.", None)
        if merge_direction not in ('left', 'right'):
            raise InvalidEntryError("merge_direction must be either 'left' or 'right'", None)

        master_entities = master_df[master_index_col]
        batch_entities = batch_df[batch_index_col]
        self._validate(master_entities)
        self._validate(batch_entities)

        if not dataset_id:
            dataset_id = self._generate_er_dataset_id()

        ctx = Context()
        ctx.set_header('x-api-key', self.api_key)

        entities_request = LinkEntitiesRequest(
                dataset_id=dataset_id,
                entity_type=entity_type,
                master_entities=master_entities,
                batch_entities=batch_entities,
                threshold=threshold,
                # everything is sent in a single request
                last_batch=True,
                feature=feature,
        )
        resp = self.rpc_client.LinkEntities(ctx=ctx, request=entities_request)

        master_er_df = pd.DataFrame(
            [self._reformat_entity(ent) for ent in resp.master_entities]
        )
        batch_er_df = pd.DataFrame(
            [self._reformat_entity(ent) for ent in resp.batch_entities]
        )
        entity_pairs = pd.DataFrame(
            [self._reformat_entity_pairs(pair) for pair in resp.entity_pairs]
        )

        # Merge er dfs with orig dfs to get the rest of the columns
        # NOTE(review): the join is done on the index, which assumes the
        # returned entities line up with the index of the input frames - confirm
        batch_er_df = batch_er_df.merge(batch_df, left_index=True, right_index=True).drop(columns=['input_value'])
        master_er_df = master_er_df.merge(master_df, left_index=True, right_index=True).drop(columns=['input_value'])

        # the side we want to keep completely is joined to the pairs with
        # 'left', the other side with 'inner'
        if merge_direction == 'left':
            master_how, batch_how = 'inner', 'left'
        else:
            # When merge_direction is 'right' we want to keep everything from master table
            master_how, batch_how = 'left', 'inner'

        master_pairs = master_er_df.merge(
            entity_pairs, left_on=master_index_col, right_on='entity_value', how=master_how
        ).drop(columns=['similarity_score', 'code'])
        batch_pairs = batch_er_df.merge(
            entity_pairs, left_on=batch_index_col, right_on='matched_entity_value', how=batch_how
        )
        merged_df = batch_pairs.merge(
            master_pairs,
            on=['entity_value', 'matched_entity_value'],
            how=merge_direction,
            suffixes=['_batch', '_master'],
        )
        merged_df.drop(columns=['entity_value', 'matched_entity_value'], inplace=True)

        return master_er_df, batch_er_df, entity_pairs, merged_df


    def _reformat_entity(self, entity_response):
        """
        In case service side can't find codes for a certain entry, it won't return those fields
        this method simply adds these fields with `None` values for convenience.
        See the warnings: [] for a possible reason this entity isn't getting linked.
        """
        if not hasattr(entity_response, 'code'):
            entity_response.code = None
        if not hasattr(entity_response, 'input_value'):
            entity_response.input_value = None
        return {
            "input_value": entity_response.input_value,
            "code": entity_response.code,
        }

    def _reformat_entity_pairs(self, entity_pairs_response):
        """
        In case service side can't find codes for a certain entry, it won't return those fields
        this method simply adds these fields with `None` values for convenience.
        See the warnings: [] for a possible reason this entity isn't getting linked.
        """
        if not hasattr(entity_pairs_response, 'entity_value'):
            entity_pairs_response.entity_value = None
        if not hasattr(entity_pairs_response, 'matched_entity_value'):
            entity_pairs_response.matched_entity_value = None
        if not hasattr(entity_pairs_response, 'similarity_score'):
            entity_pairs_response.similarity_score = None
        return {
            "entity_value": entity_pairs_response.entity_value,
            "matched_entity_value": entity_pairs_response.matched_entity_value,
            "similarity_score": entity_pairs_response.similarity_score,
        }

    def _get_batch_number(self, n_entities):
        """
        returns how many batches of BATCH_SIZE are needed for `n_entities`:
        the number of full batches, plus one when there is a remainder
        """
        full_batches = int(n_entities / BATCH_SIZE)
        has_partial_batch = n_entities % BATCH_SIZE >= 1
        return full_batches + 1 if has_partial_batch else full_batches

    def _validate(self, entries: List[str]):
        """
        given a list of entries List[str], loops through them and
        raises in case a invalid entry is found
        """
        for idx, e in enumerate(entries):
            if str(e) == 'nan':
                raise InvalidEntryError("found invalid entry: 'nan' in index: {}".format(idx), None)

    def _generate_er_dataset_id(self):
        """
        generate a random UUID prefixed with 'ER_' to show that
        the dataset_id was generated internally
        """
        dataset_id = f"ER_{uuid4()}"
        return dataset_id

    def append_codes(
        self,
        df,
        col_name: str,
        entity_type: str = "business",
        threshold: float = None,
        seed: bool = False,
        dataset_id: str = None,
        use_master_db: bool = True
        ) -> Any:
        """
        Resolves the entries of one column through `get_entities` and
        returns a DataFrame with the resulting codes in a new column
        named '{col_name}_id'.

        The returned DataFrame is the result of `df.fillna('')`, so missing
        values in every column come back as empty strings.

        :param df: input DataFrame
        :param col_name: Name of column containing the entries
        :param entity_type: Type of entity of entries in col_name, defaults to "business"
        :param threshold: forwarded to `get_entities`
        :param seed: forwarded to `get_entities`
        :param dataset_id: forwarded to `get_entities`
        :param use_master_db: forwarded to `get_entities`
        :return: DataFrame with the appended column of codes
        :raises InvalidEntryError: when df is not a DataFrame or the entity
            type is not supported
        """
        if not isinstance(df, DataFrame):
            raise InvalidEntryError(f"Type {type(df)} not supported. Please enter a valid pandas DataFrame.", None)
        if entity_type not in SUPPORTED_ENTITY_TYPES:
            raise InvalidEntryError(f"Entity type not supported: '{entity_type}'.", None)

        filled = df.fillna('')

        resolved = self.get_entities(
            entity_type=entity_type,
            entities=filled[col_name].tolist(),
            threshold=threshold,
            seed=seed,
            dataset_id=dataset_id,
            use_master_db=use_master_db
        )
        codes = [entity["code"] for entity in resolved["entities"]]

        # the codes go into a column of name '{col_name}_id'
        filled[col_name + "_id"] = Series(codes).values
        return filled


class ApiKeyRequiredError(Exception):
    """
    Error raised when a client is created without an api key

    both arguments are optional so the error can be raised without any
    arguments; passing `message` and `errors` explicitly still works.
    """
    def __init__(self, message="an 'api_key' is required to create a Client", errors=None):
        super().__init__(message)
        self.errors = errors


class InvalidEntryError(Exception):
    """
    custom error class used for handling errors in the client input

    `errors` is kept on the instance as given; it is optional so callers
    with no extra details don't have to pass a placeholder.
    """
    def __init__(self, message, errors=None):
        super().__init__(message)
        self.errors = errors


class SeedingFailed(Exception):
    """
    Error raised when the API fails running some seeding batch

    `errors` is kept on the instance as given; it is optional so callers
    with no extra details don't have to pass a placeholder.
    """
    def __init__(self, message, errors=None):
        super().__init__(message)
        self.errors = errors