# postgis-import / pgimport.py -- package version 0.9.8
import logging
import pandas as pd
import geopandas as gpd
import psycopg2
from cachetools import cached, TTLCache
from fiona.crs import from_epsg
from geoalchemy2 import WKTElement, Geometry
from shapely.geometry import Point
import datetime
from .navinfo_service import NavInfoService
from .config import Config as cfg
logger = logging.getLogger(__name__)
class DbService(object):
    """Handles database connections and queries against the PostGIS DB."""

    def __init__(self):
        # Connection URI assembled from config; port is fixed at 5432.
        self.db_string = f'postgresql://{cfg.DB_USER}:{cfg.DB_PASSWD}@{cfg.DB_HOST}:5432/{cfg.DB_NAME}'
        self.conn = None       # psycopg2 connection, created lazily by connect()
        self.cur = None        # single shared cursor for all queries
        self.temp_tables = []
        self.logger = logging.getLogger(__name__)

    def connect(self):
        """Opens the DB connection.

        All queries are performed through only one cursor. It is safe to call
        this method several times; no new connections are opened while the
        current one is alive.

        :return:
        """
        if not self.conn or self.conn.closed:
            self.conn = psycopg2.connect(self.db_string)
        if not self.cur or self.cur.closed:
            self.cur = self.conn.cursor()

    def disconnect(self):
        """Closes open db-connections.

        Safe to call anytime, will do nothing if no connections are open.
        """
        if self.cur:
            self.cur.close()
        if self.conn:
            self.conn.close()

    def rollback(self):
        """Attempts to rollback recent changes. Use in case of emergencies."""
        try:
            if self.conn:
                self.conn.rollback()
        except psycopg2.InterfaceError:
            # Connection already gone -- nothing left to roll back.
            pass

    def commit(self):
        """Attempts to commit recent changes."""
        try:
            if self.conn:
                self.conn.commit()
        except psycopg2.InterfaceError:
            # Connection already gone -- nothing left to commit.
            pass

    @cached(cache=TTLCache(maxsize=1024, ttl=1 * 60))
    def get_platform_id(self, pf_shortname):
        """Queries DB for platform id. Results are cached for one minute.

        NOTE(review): the cache key includes ``self``, so entries are
        per-instance; a new DbService always misses the cache.

        :param pf_shortname: platform short name to look up
        :return: platform id, or None when the short name is unknown
        """
        self.connect()
        try:
            self.cur.execute(f"SELECT id FROM {cfg.DB_SCHEMA}.platforms WHERE shortname=%s", (pf_shortname,))
            result = self.cur.fetchone()
            return result[0] if result else None
        finally:
            # Always release the connection -- on success and on error alike.
            self.disconnect()

    def _get_all_tables(self):
        """Returns list of tables found in DB. Used to clear tmp tables."""
        the_query = f"""SELECT tablename FROM pg_catalog.pg_tables
            WHERE schemaname = '{cfg.DB_SCHEMA}';
        """
        self.connect()
        self.cur.execute(the_query)
        result = self.cur.fetchall()
        self.disconnect()
        return [r[0] for r in result]

    def get_tmp_tables(self):
        """Returns names of all tmp tables found in DB."""
        tables = self._get_all_tables()
        return [t for t in tables if '_tmp_' in t]

    def drop_temp_table(self, tmp_tbl_name):
        """Drops a table from the database. Used to drop temporary tables after they have been merged.

        :param tmp_tbl_name: table name; must contain '_tmp_'
        :raises ValueError: if the name does not look like a temp table
        """
        # Raise instead of assert: asserts are stripped under ``python -O``,
        # which would silently allow dropping non-temp tables.
        if '_tmp_' not in tmp_tbl_name:
            raise ValueError(
                f'Use this to drop *temporary* tables only! Table {tmp_tbl_name} does not sound like a temp table :/')
        the_query = f"""
        DROP TABLE IF EXISTS {cfg.DB_SCHEMA}.{tmp_tbl_name};
        """
        self.cur.execute(the_query)

    def remove_all_tmp_tables(self):
        """Clears all tmp tables from db. For occasional housekeeping, *not for casual use*!"""
        tmp_tables = self.get_tmp_tables()
        self.connect()
        for tbl in tmp_tables:
            self.drop_temp_table(tbl)
        self.commit()
        self.disconnect()
def process_data(df, db_service):
    """Prepare raw position data for DB import.

    Converts data types, resolves platform ids from the DB, drops unnecessary
    rows and columns, removes positions from the future, de-duplicates the
    data and builds a GeoDataframe.
    """
    pipeline = (
        _add_missing_columns,
        _correct_dtypes,
        _make_geodataframe,
        lambda frame: _get_pf_ids(frame, db_service),
        _de_dup_data,
        _reject_future_positions,
        _prune_data,
    )
    for step in pipeline:
        df = step(df)
    return df
def _reject_future_positions(df):
"""Removes rows with timestamps > now+1h"""
# # we need to decide if to use timestamp with or w/o timezone for comparison
# check 1st ts assuming one for all
if df['obs_timestamp'].iloc[0].tzinfo: # compare w/ tz-aware ts
now = datetime.datetime.now(datetime.timezone.utc)
else: # compare w/ naive ts
now = datetime.datetime.utcnow()
to_drop = df[df['obs_timestamp'] > now+datetime.timedelta(hours=1)]
if len(to_drop)>0:
logger.warning(f'REJECTING POSITIONS with obs_timestamp in future: {to_drop}')
return df.drop(to_drop.index, axis=0)
def _add_missing_columns(df):
"""Adds 'heading' and 'speed_over_ground' columns if they are missing """
if 'heading' not in df.columns:
df['heading'] = ""
if 'speed_over_ground' not in df.columns:
df['speed_over_ground'] = ""
if 'additional_data' not in df.columns:
df['additional_data'] = ""
return df
def _get_pf_ids(gdf, db_service):
"""Adds column 'platform_id' queried from db using 'platform_shortname'"""
gdf['platform_id'] = gdf['platform_shortname'].apply(lambda sn: db_service.get_platform_id(sn))
return gdf
def _correct_dtypes(df):
"""Makes sure that all dtypes are correct for db import"""
must_be_numeric = ['lat', 'lon', 'heading', 'speed_over_ground']
for column in must_be_numeric:
if not pd.api.types.is_numeric_dtype(df[column]):
logging.warning(f'dtype of column {column} is {df[column].dtype}, changing to numeric type!')
df[column] = pd.to_numeric(df[column], errors='coerce')
return df
def _make_geodataframe(df):
    """
    Creates a new GeoDataFrame from the passed pandas Dataframe.

    Builds a 'geom' column of SRID-4326 WKT elements from the 'lon'/'lat'
    columns and tags the frame with the matching CRS.

    :param df:
    :return:
    """
    gdf = gpd.GeoDataFrame(df)
    gdf['geom'] = [
        WKTElement(Point(lon, lat).wkt, srid=4326)
        for lon, lat in zip(gdf['lon'], gdf['lat'])
    ]
    gdf.crs = from_epsg(4326)
    return gdf
def _prune_data(geo_df):
"""Removes all columns not handled in postgis db table"""
to_keep = ['obs_timestamp', 'platform_id', 'geom', 'speed_over_ground', 'heading', 'additional_data']
to_drop = [col for col in geo_df.columns if col not in to_keep]
geo_df.drop(to_drop, axis=1, inplace=True)
return geo_df
def _de_dup_data(df):
"""Removes duplicate lines from data frame; For every line with the same obs_timestamp and platform_id, only
a single record is kept
"""
if 'platform_shortname' in df.columns:
to_drop = df[df.duplicated(subset=['obs_timestamp', 'platform_shortname'], keep='last')]
elif 'platform_id' in df.columns:
to_drop = df[df.duplicated(subset=['obs_timestamp', 'platform_id'], keep='last')]
else:
logging.warning('No platform column found! There is a problem somewhere before this.')
return df.drop(to_drop.index, axis=0)
def _merge_positions_tables(table_a, table_b, db_service):
    """Merges tables a and b by either insert or update of b to a.

    Rows from *table_b* are inserted into *table_a*; on a conflict over
    (obs_timestamp, platform_id) the existing row's geometry is updated
    instead. Commits immediately on the service's open connection.

    :param table_a: target positions table name
    :param table_b: source (temporary) table name
    :param db_service: DbService with an already-open connection and cursor
    """
    # Session timezone is pinned to UTC before the insert so timestamp
    # comparisons in the conflict clause are unambiguous.
    # NOTE(review): heading is cast to ::numeric here -- presumably the tmp
    # table may carry it as text; confirm against the import path.
    the_query = f"""
    SET TIMEZONE = 'UTC';
    INSERT INTO {cfg.DB_SCHEMA}.{table_a}
    (obs_timestamp, geom, platform_id, speed_over_ground, heading, additional_data)
    SELECT new_data.obs_timestamp, new_data.geom, new_data.platform_id,
    new_data.speed_over_ground, new_data.heading::numeric, new_data.additional_data
    FROM {cfg.DB_SCHEMA}.{table_b} AS new_data
    ON CONFLICT (obs_timestamp, platform_id)
    DO UPDATE SET geom = excluded.geom, obs_timestamp = excluded.obs_timestamp;
    """
    db_service.cur.execute(the_query)
    db_service.conn.commit()
def add_navinfo():
    """Calculates heading and speed (speed not implemented yet) from known
    positions and persists them to the DB.

    Only rows without info supplied by the platforms themselves are affected
    (``only_missing=True``).
    """
    dbs = DbService()
    df = NavInfoService(dbs).get_data_previous_position(only_missing=True)
    df['heading_calculated'] = df.apply(NavInfoService.calc_heading, axis=1)
    # Only fill headings that are still NULL -- never overwrite platform data.
    query = f"""
    UPDATE {cfg.DB_SCHEMA}.{cfg.DB_POSITIONS_TABLE}
    SET heading = %(heading_calculated)s
    WHERE obs_timestamp = %(obs_timestamp)s
    AND platform_id = %(platform_id)s
    AND heading IS NULL
    ;
    """
    query_to_null = f"""
    UPDATE {cfg.DB_SCHEMA}.{cfg.DB_POSITIONS_TABLE}
    SET heading = NULL WHERE heading = 'NaN';
    """
    dbs.connect()
    try:
        for _, row in df.iterrows():
            dbs.cur.execute(query, row.to_dict())
        # since heading col is of type 'float', 'NaN' is persisted to DB. Change to NULL.
        dbs.cur.execute(query_to_null)
        dbs.commit()
    except Exception:
        # bare `raise` keeps the original traceback; disconnect is handled
        # once by the finally block (the original disconnected twice here)
        dbs.rollback()
        raise
    finally:
        dbs.disconnect()
def write_to_postgis(data_df, calc_missing_navinfo=False):
    """Writes data to the postgis DB.

    Validates the input, stages the processed data in a uniquely named
    temporary table, merges it into the permanent positions table and drops
    the temporary table again.

    :param data_df: pandas DataFrame with position data
    :param calc_missing_navinfo: if True, also compute missing heading info
    """
    try:
        check_data_integrity(data_df)
    except AssertionError:
        logger.exception('Invalid input data.')
        return
    if len(data_df) < 1:
        # empty data frames can happen when devices are not currently emitting
        return
    dbs = DbService()
    dbs.connect()
    try:
        gdf = process_data(data_df, dbs)
        assert not gdf['platform_id'].isnull().any(), 'Unknown short names found!'
        # create temp table, made unique via a millisecond timestamp suffix
        tmp_tbl_name = cfg.DB_POSITIONS_TABLE + '_tmp_' + str(int(round(datetime.datetime.now().timestamp() * 1000)))
        gdf.to_sql(tmp_tbl_name, dbs.db_string, schema=cfg.DB_SCHEMA, if_exists='fail', index=False,
                   dtype={'geom': Geometry('POINT', srid=4326)})
        dbs.commit()
        dbs.connect()  # might have been closed by gdf.to_sql
        _merge_positions_tables(cfg.DB_POSITIONS_TABLE, tmp_tbl_name, dbs)
        dbs.connect()
        dbs.drop_temp_table(tmp_tbl_name)
        dbs.commit()
        if calc_missing_navinfo:
            add_navinfo()
    except Exception:
        # bare `raise` keeps the original traceback; disconnect is handled
        # once by the finally block (the original disconnected twice here)
        dbs.rollback()
        raise
    finally:
        dbs.disconnect()
def check_data_integrity(data_df):
    """Checks if passed data is understood by the importer.

    :param data_df: dataframe to validate
    :raises AssertionError: if a required column is missing, contains null
        values, or obs_timestamp is not datetime/timedelta typed
        (callers deliberately catch AssertionError)
    """
    must_haves = ['platform_shortname', 'obs_timestamp', 'lat', 'lon']
    for item in must_haves:
        assert item in data_df.columns, f'Column {item} not found!'
        assert not data_df[item].isnull().any(), f'Empty value in column {item} not allowed!'
    if len(data_df) > 0:
        # Use the public dtype API: pd.core.dtypes.common is private and
        # is_datetime_or_timedelta_dtype was removed in pandas 2.x.
        ts = data_df['obs_timestamp']
        assert pd.api.types.is_datetime64_any_dtype(ts) \
            or pd.api.types.is_timedelta64_dtype(ts), 'Expected timestamp' \
            ' data in obs_timestamp col'
    else:
        logger.warning('Passed empty dataframe to importer')