# namara-er / client.py (package version 1.0.9)
# NOTE: the original file carried a scraped package-index page header here
# (repository URL / version widget); it was not valid Python and has been
# reduced to this comment.
from time import time, sleep
from pandas import DataFrame, Series
import pandas as pd
from urllib.error import URLError, HTTPError
from rpc.er_service_twirp import ErServiceClient
from rpc.er_service_pb2 import (
GetEntitiesRequest, GetBatchJobRequest,
LinkEntitiesRequest, LinkEntitiesResponse, GetEntitiesResponse
)
from namara_er.utils import gen_request_decorator
from typing import Any, List
from uuid import uuid4
from twirp.context import Context
# BATCH_SIZE indicates the max number of entities sent per request
# if len(entries) > BATCH_SIZE, we divide the request into batches
BATCH_SIZE = 100000
# Base URL of the ER Twirp service.
API_URL = "https://er.namara.io/twirp/er_service.ErService/"
# Entity types accepted by link_entities / append_codes.
SUPPORTED_ENTITY_TYPES = ["default", "business", "number", "address", "people"]
# Seconds to sleep between polls of an async seeding job and between RPC retries.
POLLING_INTERVAL = 6
# When True, seeding requests are submitted as async batch jobs and polled.
POLLING = True
# Number of attempts for each RPC before the network error is re-raised.
RETRIES = 3
# NOTE: to be implemented
RETRY_BATCH = False
class Client:
    """
    Client for the Namara entity-resolution (ER) Twirp service.

    Wraps the generated ``ErServiceClient`` stub with convenience methods for
    resolving entities (``get_entities``), linking two datasets
    (``link_entities``) and appending resolved codes to a DataFrame
    (``append_codes``).
    """

    def __init__(self, *args, **kwargs) -> None:
        """
        :keyword api_key: required; sent as the ``x-api-key`` header on every RPC.
        :keyword api_url: ER service endpoint, defaults to ``API_URL``.
        :keyword auth_url: reserved, to be supported in the future.
        :raises ApiKeyRequiredError: if no ``api_key`` is provided.
        """
        self.api_key = kwargs.get('api_key')
        if not self.api_key:
            # BUG FIX: the original raised ApiKeyRequiredError() with no
            # arguments, but that exception requires (message, errors), so
            # the raise itself failed with TypeError.
            raise ApiKeyRequiredError("api_key is required", None)
        self.api_url = kwargs.get('api_url', API_URL)
        self.auth_url = kwargs.get('auth_url')  # to be supported in the future
        self.rpc_client = ErServiceClient(self.api_url, timeout=3800)

    def get_entities(
        self,
        entity_type: str,
        entities: List[str],
        threshold: float = None,
        seed: bool = False,
        dataset_id: str = None,
        use_master_db: bool = True
    ) -> Any:
        """
        returns a list of entities matched against the input entries.
        if entries are not matched against any entities, these are going to
        be created.

        Requests are split into batches of at most BATCH_SIZE entries. When
        ``seed`` is set (and module-level POLLING is enabled) each batch is
        submitted as an async batch job and polled until completion.

        :raises InvalidEntryError: if an entry stringifies to 'nan'.
        :raises SeedingFailed: if an async seeding batch reports "failed".
        """
        ctx = Context()
        ctx.set_header('x-api-key', self.api_key)
        self._validate(entities)
        response = {
            "entities": []
        }
        if not dataset_id:
            dataset_id = self._generate_er_dataset_id()
        n_entities = len(entities)
        for batch_idx in range(self._get_batch_number(n_entities)):
            chunk = batch_idx * BATCH_SIZE
            query_entities = entities[chunk:chunk + BATCH_SIZE]
            request_kwargs = dict(
                dataset_id=dataset_id,
                entity_type=entity_type,
                queries=query_entities,
                seed=seed,
                use_master_db=use_master_db,
                # flag the service when this is the final batch
                last_batch=(chunk + BATCH_SIZE) >= n_entities,
            )
            # Only forward threshold when explicitly set so the service can
            # apply its own default otherwise (mirrors the original if/else).
            if threshold:
                request_kwargs['threshold'] = threshold
            entities_request = GetEntitiesRequest(**request_kwargs)
            if seed and POLLING:
                # Async seeding: create a batch job and poll it to completion.
                resp = self._seed_and_poll(ctx, entities_request)
            else:
                resp = self.rpc_client.GetEntities(ctx=ctx, request=entities_request)
            response['dataset_id'] = resp.dataset_id
            response['entity_type'] = resp.entity_type
            for ent in resp.entities:
                response['entities'].append(self._reformat_entity(ent))
        return response

    def _call_with_retries(self, rpc_method, ctx, request):
        """
        Invoke ``rpc_method(ctx=..., request=...)`` with up to RETRIES
        attempts, sleeping POLLING_INTERVAL between attempts on network
        errors; the final failure is re-raised.
        """
        for attempt in range(RETRIES):
            try:
                return rpc_method(ctx=ctx, request=request)
            except (URLError, HTTPError):
                if attempt < RETRIES - 1:
                    sleep(POLLING_INTERVAL)
                else:
                    raise

    def _seed_and_poll(self, ctx, entities_request):
        """
        Submit ``entities_request`` as an async batch job and poll every
        POLLING_INTERVAL seconds until its status is "finished".

        :return: the final GetBatchJob response.
        :raises SeedingFailed: if the job reports status "failed".
        """
        create_batch_job_resp = self._call_with_retries(
            self.rpc_client.CreateBatchJob, ctx, entities_request
        )
        batch_job_id = create_batch_job_resp.batch_job_id
        # The request is loop-invariant, build it once.
        batch_job_request = GetBatchJobRequest(batch_job_id=batch_job_id)
        while True:
            sleep(POLLING_INTERVAL)
            resp = self._call_with_retries(
                self.rpc_client.GetBatchJob, ctx, batch_job_request
            )
            if resp.status == "finished":
                return resp
            if resp.status == "failed":
                raise SeedingFailed(
                    f"batch id: {batch_job_id} failed to seed", None
                )

    def link_entities(
        self,
        master_df,
        batch_df,
        master_index_col: str,
        batch_index_col: str,
        entity_type: str,
        threshold: float = None,
        dataset_id: str = None,
        feature: str = None,
        merge_direction: str = 'left'
    ) -> Any:
        """
        This method receives 2 dataframes along with the target columns, and performs er linking on them
        Returns master df with codes, batch df with codes, entity_pairs, and merged_df.

        :param merge_direction: 'left' keeps every batch row in merged_df,
            'right' keeps every master row.
        :raises InvalidEntryError: on unsupported entity_type, non-DataFrame
            inputs, NaN entries, or an invalid merge_direction.
        """
        ctx = Context()
        ctx.set_header('x-api-key', self.api_key)
        if entity_type not in SUPPORTED_ENTITY_TYPES:
            raise InvalidEntryError(f"Entity type not supported: '{entity_type}'.", None)
        # BUG FIX: the two messages below referenced an undefined name `df`,
        # turning the intended InvalidEntryError into a NameError.
        if not isinstance(master_df, DataFrame):
            raise InvalidEntryError(f"Type {type(master_df)} not supported. Please enter a valid pandas DataFrame.", None)
        if not isinstance(batch_df, DataFrame):
            raise InvalidEntryError(f"Type {type(batch_df)} not supported. Please enter a valid pandas DataFrame.", None)
        # BUG FIX: the original only printed a warning on a bad
        # merge_direction and then crashed with UnboundLocalError on
        # `merged_df` at the return; fail fast before any RPC work instead.
        if merge_direction not in ('left', 'right'):
            raise InvalidEntryError("merge_direction must be either 'left' or 'right'", None)
        master_entities = master_df[master_index_col]
        batch_entities = batch_df[batch_index_col]
        self._validate(master_entities)
        self._validate(batch_entities)
        response = {
            "master_entities": [],
            "batch_entities": [],
            "entity_pairs": []
        }
        if not dataset_id:
            dataset_id = self._generate_er_dataset_id()
        entities_request = LinkEntitiesRequest(
            dataset_id=dataset_id,
            entity_type=entity_type,
            master_entities=master_entities,
            batch_entities=batch_entities,
            threshold=threshold,
            last_batch=True,  # linking is always sent as a single batch
            feature=feature,
        )
        resp = self.rpc_client.LinkEntities(ctx=ctx, request=entities_request)
        response['dataset_id'] = resp.dataset_id
        response['entity_type'] = resp.entity_type
        for ent in resp.master_entities:
            response['master_entities'].append(self._reformat_entity(ent))
        for ent in resp.batch_entities:
            response['batch_entities'].append(self._reformat_entity(ent))
        for entity_pair in resp.entity_pairs:
            response['entity_pairs'].append(self._reformat_entity_pairs(entity_pair))
        master_er_df = pd.DataFrame(response['master_entities'])
        batch_er_df = pd.DataFrame(response['batch_entities'])
        entity_pairs = pd.DataFrame(response['entity_pairs'])
        # Merge er dfs with orig dfs to get the rest of the columns.
        batch_er_df = batch_er_df.merge(batch_df, left_index=True, right_index=True).drop(columns=['input_value'])
        master_er_df = master_er_df.merge(master_df, left_index=True, right_index=True).drop(columns=['input_value'])
        # The two directions differ only in which side is inner-joined:
        # 'left' keeps every batch row, 'right' keeps every master row.
        master_how = 'inner' if merge_direction == 'left' else 'left'
        batch_how = 'left' if merge_direction == 'left' else 'inner'
        master_pairs = master_er_df.merge(
            entity_pairs, left_on=master_index_col, right_on='entity_value', how=master_how
        ).drop(columns=['similarity_score', 'code'])
        batch_pairs = batch_er_df.merge(
            entity_pairs, left_on=batch_index_col, right_on='matched_entity_value', how=batch_how
        )
        merged_df = batch_pairs.merge(
            master_pairs, on=['entity_value', 'matched_entity_value'],
            how=merge_direction, suffixes=['_batch', '_master']
        )
        merged_df.drop(columns=['entity_value', 'matched_entity_value'], inplace=True)
        return master_er_df, batch_er_df, entity_pairs, merged_df

    def _reformat_entity(self, entity_response):
        """
        In case service side can't find codes for a certain entry, it won't
        return those fields; this method adds them with `None` values for
        convenience.
        NOTE(review): hasattr is always True for plain protobuf scalar
        fields, so the None-backfill only fires for response objects that
        can genuinely lack the attribute — confirm against the response type.
        """
        if not hasattr(entity_response, 'code'):
            entity_response.code = None
        if not hasattr(entity_response, 'input_value'):
            entity_response.input_value = None
        return {
            "input_value": entity_response.input_value,
            "code": entity_response.code,
        }

    def _reformat_entity_pairs(self, entity_pairs_response):
        """
        In case service side can't find codes for a certain entry, it won't
        return those fields; this method adds them with `None` values for
        convenience (same caveat as _reformat_entity).
        """
        if not hasattr(entity_pairs_response, 'entity_value'):
            entity_pairs_response.entity_value = None
        if not hasattr(entity_pairs_response, 'matched_entity_value'):
            entity_pairs_response.matched_entity_value = None
        if not hasattr(entity_pairs_response, 'similarity_score'):
            entity_pairs_response.similarity_score = None
        return {
            "entity_value": entity_pairs_response.entity_value,
            "matched_entity_value": entity_pairs_response.matched_entity_value,
            "similarity_score": entity_pairs_response.similarity_score,
        }

    def _get_batch_number(self, n_entities: int) -> int:
        """
        Return the number of BATCH_SIZE-sized batches needed for
        ``n_entities`` entries (ceiling division; 0 entries -> 0 batches).
        BUG FIX: the original used float division ``int(n / BATCH_SIZE)``,
        which loses precision for very large n; integer ceiling division
        is exact and equivalent for all other inputs.
        """
        return -(-n_entities // BATCH_SIZE)

    def _validate(self, entries: List[str]):
        """
        given a list of entries List[str], loops through them and
        raises in case an invalid entry is found

        :raises InvalidEntryError: when an entry stringifies to 'nan'
            (e.g. a float NaN coming from a DataFrame column).
        """
        for idx, e in enumerate(entries):
            if str(e) == 'nan':
                raise InvalidEntryError("found invalid entry: 'nan' in index: {}".format(idx), None)

    def _generate_er_dataset_id(self) -> str:
        """
        generate a random UUID prefixed with 'ER_' to show that
        the dataset_id was generated internally
        """
        dataset_id = f"ER_{uuid4()}"
        return dataset_id

    def append_codes(
        self,
        df,
        col_name: str,
        entity_type: str = "business",
        threshold: float = None,
        seed: bool = False,
        dataset_id: str = None,
        use_master_db: bool = True
    ) -> Any:
        """
        Extracts a list of entries from the given dataset to assign codes,
        then appends the column of codes to the original dataset.

        :param df: input DataFrame (a filled copy is returned; the caller's
            frame is not mutated because fillna produces a new frame)
        :param col_name: Name of column containing the entries
        :param entity_type: Type of entity of entries in col_name, defaults to "business"
        :return: DataFrame with the appended column of codes
        :raises InvalidEntryError: if df is not a DataFrame or entity_type
            is unsupported.
        """
        if not isinstance(df, DataFrame):
            raise InvalidEntryError(f"Type {type(df)} not supported. Please enter a valid pandas DataFrame.", None)
        if entity_type not in SUPPORTED_ENTITY_TYPES:
            raise InvalidEntryError(f"Entity type not supported: '{entity_type}'.", None)
        # fillna('') keeps NaN cells from tripping _validate downstream.
        df = df.fillna('')
        entities = df[col_name].tolist()
        resolved_entities = self.get_entities(
            entity_type=entity_type,
            entities=entities,
            threshold=threshold,
            seed=seed,
            dataset_id=dataset_id,
            use_master_db=use_master_db
        )
        codes = [e["code"] for e in resolved_entities["entities"]]
        # append a column of name '{col_name}_id' with codes; .values makes
        # the assignment positional, ignoring any non-default df index.
        new_col_name = col_name + "_id"
        df[new_col_name] = Series(codes).values
        return df
class ApiKeyRequiredError(Exception):
    """
    Raised when a Client is constructed without an ``api_key``.

    BUG FIX: the original required positional (message, errors) arguments,
    but Client.__init__ raised it with no arguments, which crashed with
    TypeError instead of this exception. Defaults make the bare raise work
    while remaining backward compatible with explicit arguments.
    """

    def __init__(self, message="api_key is required", errors=None):
        super().__init__(message)
        # Optional structured details, mirroring the other exceptions here.
        self.errors = errors
class InvalidEntryError(Exception):
"""
custom error class used for handling errors in the client input
"""
def __init__(self, message, errors):
super().__init__(message)
self.errors = errors
class SeedingFailed(Exception):
    """
    Raised when the API reports that an async seeding batch job failed.
    """

    def __init__(self, message, errors):
        # Keep any structured details before handing the message to Exception.
        self.errors = errors
        super().__init__(message)