Why Gemfury? Push, build, and install  RubyGems npm packages Python packages Maven artifacts PHP packages Go Modules Debian packages RPM packages NuGet packages

Repository URL to install this package:

Details    
postgis-import / pgimport.py
Size: Mime:
import logging

import pandas as pd
import geopandas as gpd
import psycopg2
from cachetools import cached, TTLCache
from fiona.crs import from_epsg
from geoalchemy2 import WKTElement, Geometry
from shapely.geometry import Point
import datetime

from .navinfo_service import NavInfoService
from .config import Config as cfg

logger = logging.getLogger(__name__)


class DbService(object):
    """Handles database connections and requests to the PostGIS DB.

    A single psycopg2 connection and cursor are opened lazily via
    :meth:`connect` and reused for all queries.
    """

    def __init__(self):
        # connection string built from the project configuration (port fixed at 5432)
        self.db_string = f'postgresql://{cfg.DB_USER}:{cfg.DB_PASSWD}@{cfg.DB_HOST}:5432/{cfg.DB_NAME}'

        self.conn = None  # lazily opened psycopg2 connection
        self.cur = None  # single shared cursor used for all queries

        self.temp_tables = []
        self.logger = logging.getLogger(__name__)

    def connect(self):
        """Opens the DB connection.

        All queries are performed through only one cursor. It is safe to call
        this method several times; no new connection is opened while the
        current one is still alive.
        :return:
        """
        if not self.conn or self.conn.closed:
            self.conn = psycopg2.connect(self.db_string)
        if not self.cur or self.cur.closed:
            self.cur = self.conn.cursor()

    def disconnect(self):
        """Closes open db-connections.

        Safe to call anytime, will do nothing if no connections are open.
        """
        if self.cur:
            self.cur.close()
        if self.conn:
            self.conn.close()

    def rollback(self):
        """Attempts to rollback recent changes. Use in case of emergencies."""
        try:
            if self.conn:
                self.conn.rollback()
        except psycopg2.InterfaceError:
            # connection already dead -- nothing left to roll back
            pass

    def commit(self):
        """Attempts to commit recent changes."""
        try:
            if self.conn:
                self.conn.commit()
        except psycopg2.InterfaceError:
            # connection already dead -- nothing to commit
            pass

    @cached(cache=TTLCache(maxsize=1024, ttl=1 * 60))
    def get_platform_id(self, pf_shortname):
        """Queries the DB for a platform id; returns None when unknown.

        Results are cached for one minute (NOTE: the cachetools key includes
        ``self``, so the cache keeps instances alive for the TTL).
        """
        self.connect()
        try:
            self.cur.execute(f"SELECT id FROM {cfg.DB_SCHEMA}.platforms WHERE shortname=%s", (pf_shortname,))
            result = self.cur.fetchone()
            return result[0] if result else None
        finally:
            # previously disconnect() ran twice on error (once in except, once
            # in finally); a single finally is sufficient
            self.disconnect()

    def _get_all_tables(self):
        """Returns list of tables found in the configured schema. Used to clear tmp tables."""
        the_query = f"""SELECT tablename FROM pg_catalog.pg_tables
                        WHERE schemaname = '{cfg.DB_SCHEMA}';
                    """
        self.connect()

        self.cur.execute(the_query)
        result = self.cur.fetchall()

        self.disconnect()

        return [r[0] for r in result]

    def get_tmp_tables(self):
        """Returns names of all tmp tables found in DB."""
        tables = self._get_all_tables()
        return [t for t in tables if '_tmp_' in t]

    def drop_temp_table(self, tmp_tbl_name):
        """Drops a table from the database.

        Used to drop temporary tables after they have been merged.
        NOTE: expects an open cursor -- callers must connect() first.
        """
        assert '_tmp_' in tmp_tbl_name, \
            f'Use this to drop *temporary* tables only! Table {tmp_tbl_name} does not sound like a temp table :/'
        the_query = f"""
            DROP TABLE IF EXISTS {cfg.DB_SCHEMA}.{tmp_tbl_name};
            """
        self.cur.execute(the_query)

    def remove_all_tmp_tables(self):
        """Clears all tmp tables from db. For occasional housekeeping, *not for casual use*!"""
        tmp_tables = self.get_tmp_tables()
        self.connect()
        for tbl in tmp_tables:
            self.drop_temp_table(tbl)
        self.commit()
        self.disconnect()


def process_data(df, db_service):
    """Prepares raw position data for DB import.

    Fills missing columns, fixes dtypes, builds the GeoDataFrame, resolves
    platform ids from the DB, de-duplicates, rejects positions stamped in the
    future and finally drops every column the DB table does not store.
    """
    prepared = _correct_dtypes(_add_missing_columns(df))
    located = _get_pf_ids(_make_geodataframe(prepared), db_service)
    cleaned = _reject_future_positions(_de_dup_data(located))
    return _prune_data(cleaned)

def _reject_future_positions(df):
    """Removes rows with timestamps > now+1h"""
#   # we need to decide if to use timestamp with or w/o timezone for comparison
    # check 1st ts assuming one for all
    if df['obs_timestamp'].iloc[0].tzinfo:  # compare w/ tz-aware ts
        now = datetime.datetime.now(datetime.timezone.utc)
    else:  # compare w/ naive ts
        now = datetime.datetime.utcnow()

    to_drop = df[df['obs_timestamp'] > now+datetime.timedelta(hours=1)]
    if len(to_drop)>0:
        logger.warning(f'REJECTING POSITIONS with obs_timestamp in future: {to_drop}')

    return df.drop(to_drop.index, axis=0)

def _add_missing_columns(df):
    """Adds 'heading' and 'speed_over_ground' columns if they are missing """
    if 'heading' not in df.columns:
        df['heading'] = ""
    if 'speed_over_ground' not in df.columns:
        df['speed_over_ground'] = ""
    if 'additional_data' not in df.columns:
        df['additional_data'] = ""
    return df


def _get_pf_ids(gdf, db_service):
    """Adds column 'platform_id' queried from db using 'platform_shortname'"""
    gdf['platform_id'] = gdf['platform_shortname'].apply(lambda sn: db_service.get_platform_id(sn))
    return gdf


def _correct_dtypes(df):
    """Makes sure that all dtypes are correct for db import"""
    must_be_numeric = ['lat', 'lon', 'heading', 'speed_over_ground']
    for column in must_be_numeric:
        if not pd.api.types.is_numeric_dtype(df[column]):
            logging.warning(f'dtype of column {column} is {df[column].dtype}, changing to numeric type!')
            df[column] = pd.to_numeric(df[column], errors='coerce')
    return df


def _make_geodataframe(df):
    """Wraps the passed DataFrame into a GeoDataFrame with a WGS84 'geom' column.

    For every row a point is built from (lon, lat) and stored as a WKT
    element with SRID 4326; the frame's CRS is set to EPSG:4326 as well.

    :param df: DataFrame with 'lon' and 'lat' columns
    :return: GeoDataFrame
    """
    gdf = gpd.GeoDataFrame(df)
    gdf['geom'] = [WKTElement(Point(lon_lat).wkt, srid=4326)
                   for lon_lat in zip(gdf['lon'], gdf['lat'])]
    gdf.crs = from_epsg(4326)

    return gdf


def _prune_data(geo_df):
    """Removes all columns not handled in postgis db table"""
    to_keep = ['obs_timestamp', 'platform_id', 'geom', 'speed_over_ground', 'heading', 'additional_data']
    to_drop = [col for col in geo_df.columns if col not in to_keep]
    geo_df.drop(to_drop, axis=1, inplace=True)

    return geo_df


def _de_dup_data(df):
    """Removes duplicate lines from data frame; For every line with the same obs_timestamp and platform_id, only
       a single record is kept
    """
    if 'platform_shortname' in df.columns:
        to_drop = df[df.duplicated(subset=['obs_timestamp', 'platform_shortname'], keep='last')]
    elif 'platform_id' in df.columns:
        to_drop = df[df.duplicated(subset=['obs_timestamp', 'platform_id'], keep='last')]
    else:
        logging.warning('No platform column found! There is a problem somewhere before this.')

    return df.drop(to_drop.index, axis=0)


def _merge_positions_tables(table_a, table_b, db_service):
    """Merges tables a and b by either insert or update of b to a

    Upserts every row of ``table_b`` (a temporary import table) into
    ``table_a`` (the persistent positions table) and commits immediately.
    Expects ``db_service`` to hold an open connection and cursor.

    NOTE(review): on conflict only geom and obs_timestamp are refreshed --
    heading, speed_over_ground and additional_data of the existing row are
    kept. Confirm that this is intended.
    """
    # table/schema names come from project config and this module's own temp
    # tables, not from user input; row values are moved purely inside SQL
    the_query = f"""
    SET TIMEZONE = 'UTC';
    INSERT INTO {cfg.DB_SCHEMA}.{table_a}
        (obs_timestamp, geom, platform_id, speed_over_ground, heading, additional_data)

            SELECT new_data.obs_timestamp, new_data.geom, new_data.platform_id,  
                    new_data.speed_over_ground, new_data.heading::numeric, new_data.additional_data
            FROM {cfg.DB_SCHEMA}.{table_b} AS new_data

        ON CONFLICT (obs_timestamp, platform_id) 
            DO UPDATE SET geom = excluded.geom, obs_timestamp = excluded.obs_timestamp;
    """
    db_service.cur.execute(the_query)
    db_service.conn.commit()


def add_navinfo():
    """Calculates heading and speed (speed not implemented yet) from known positions and persists them to DB.

    Only rows without info supplied by the platforms themselves are affected
    (``heading IS NULL`` guard in the update query).
    """
    dbs = DbService()
    df = NavInfoService(dbs).get_data_previous_position(only_missing=True)
    df['heading_calculated'] = df.apply(NavInfoService.calc_heading, axis=1)

    query = f"""
            UPDATE {cfg.DB_SCHEMA}.{cfg.DB_POSITIONS_TABLE}
            SET heading = %(heading_calculated)s
            WHERE obs_timestamp = %(obs_timestamp)s
                AND platform_id = %(platform_id)s
                AND heading IS NULL
            ;
            """

    # since heading col is of type 'float', 'NaN' is persisted to DB. Change to NULL.
    query_to_null = f"""
                        UPDATE {cfg.DB_SCHEMA}.{cfg.DB_POSITIONS_TABLE} 
                        SET heading = NULL WHERE heading = 'NaN';
                    """
    dbs.connect()
    try:
        for _, row in df.iterrows():
            dbs.cur.execute(query, row.to_dict())
        dbs.cur.execute(query_to_null)
        dbs.commit()
    except Exception:
        # undo the partial update before propagating; bare raise keeps the
        # original traceback
        dbs.rollback()
        raise
    finally:
        # previously disconnect() was also called in the except branch,
        # closing the connection twice; finally alone is sufficient
        dbs.disconnect()


def write_to_postgis(data_df, calc_missing_navinfo=False):
    """Writes position data to the postgis DB.

    The frame is validated, cleaned/enriched by process_data(), bulk-loaded
    into a uniquely named temporary table and then upserted into the
    persistent positions table. The temp table is dropped afterwards.

    :param data_df: DataFrame with at least platform_shortname, obs_timestamp,
                    lat and lon columns
    :param calc_missing_navinfo: when True, derive missing headings after import
    """
    try:
        check_data_integrity(data_df)
    except AssertionError:
        logger.exception('Invalid input data.')
        return

    if len(data_df) < 1:
        # empty data frames can happen when devices are not currently emitting
        return

    dbs = DbService()
    dbs.connect()
    try:
        gdf = process_data(data_df, dbs)

        assert not gdf['platform_id'].isnull().any(), 'Unknown short names found!'
        # create temp table; millisecond timestamp keeps the name unique
        tmp_tbl_name = cfg.DB_POSITIONS_TABLE + '_tmp_' + str(int(round(datetime.datetime.now().timestamp() * 1000)))
        # NOTE(review): to_sql is handed the raw connection string rather than
        # a SQLAlchemy engine -- confirm this works with the pinned pandas version
        gdf.to_sql(tmp_tbl_name, dbs.db_string, schema=cfg.DB_SCHEMA, if_exists='fail', index=False,
                   dtype={'geom': Geometry('POINT', srid=4326)})
        dbs.commit()
        dbs.connect()  # might have been closed by gdf.to_sql
        _merge_positions_tables(cfg.DB_POSITIONS_TABLE, tmp_tbl_name, dbs)
        dbs.connect()
        dbs.drop_temp_table(tmp_tbl_name)
        dbs.commit()
        if calc_missing_navinfo:
            add_navinfo()

    except Exception:
        # undo partial changes before propagating; bare raise keeps the
        # traceback, and the connection is closed exactly once (in finally,
        # previously the except branch disconnected a second time)
        dbs.rollback()
        raise
    finally:
        dbs.disconnect()


def check_data_integrity(data_df):
    """Checks if passed data is understood; raises AssertionError otherwise.

    Required columns must exist and be free of nulls, and obs_timestamp must
    carry datetime/timedelta data. Callers (write_to_postgis) deliberately
    catch AssertionError, so assert-based validation is kept here.
    """
    must_haves = ['platform_shortname', 'obs_timestamp', 'lat', 'lon']
    for item in must_haves:
        assert item in data_df.columns, f'Column {item} not found!'
        assert not data_df[item].isnull().any(), f'Empty value in column {item} not allowed!'

    if len(data_df) > 0:
        # public pd.api.types used instead of the private, deprecated
        # pd.core.dtypes.common.is_datetime_or_timedelta_dtype (consistent
        # with _correct_dtypes above)
        assert pd.api.types.is_datetime64_any_dtype(data_df['obs_timestamp']) \
               or pd.api.types.is_timedelta64_dtype(data_df['obs_timestamp']), 'Expected timestamp' \
                                                                               ' data in obs_timestamp col'
    else:
        logging.getLogger(__name__).warning('Passed empty dataframe to importer')