Why Gemfury? Push, build, and install RubyGems, npm packages, Python packages, Maven artifacts, PHP packages, Go Modules, Debian packages, RPM packages, and NuGet packages.

Repository URL to install this package:

Details    
sarus_dataset / sarus_dataset / read_dataset.py
Size: Mime:
from urllib.parse import urlparse
import pandas as pd
import numpy as np
import pyarrow.dataset as ds
import tensorflow as tf
import gcsfs
from .column_name_encoding import decode


def dataset_to_pandas(parquet_dataset, rows_number=None, convert_nan=True):
    """
    Takes a pyarrow.ParquetDataset pointing to a dataset in Sarus format and
    loads it as a pandas DataFrame.

    Pieces are read in order, and reading stops as soon as ``rows_number``
    rows have been collected, so later pieces are never loaded.

    Args:
        parquet_dataset (pyarrow.ParquetDataset): a pyarrow.ParquetDataset
            exposing a ``pieces`` attribute (legacy ParquetDataset API).
        rows_number (int): limits the number of rows returned. ``None``
            (the default) returns every row.
        convert_nan (bool): if True, int16/int32/int64 columns are converted
            to the nullable ``Int64`` dtype, and values <= the int64 minimum
            are replaced by ``pd.NA``.

    Returns:
        pd.DataFrame: the dataset in pandas. The ``id`` column is removed,
        column names are decoded, and the result has a fresh 0..n-1 index.
        An empty DataFrame is returned when the dataset has no pieces.
    """
    frames = []
    n_rows = 0
    for piece in parquet_dataset.pieces:
        frame = piece.read().to_pandas()
        frames.append(frame)
        n_rows += len(frame)
        # ">=" stops immediately when a piece boundary lands exactly on
        # rows_number, instead of loading one more piece only to discard it.
        if rows_number is not None and n_rows >= rows_number:
            break

    if not frames:
        # No pieces: there is no "id" column to drop, so return early
        # instead of failing with a KeyError.
        return pd.DataFrame()

    # Concatenate once. Repeated DataFrame.append is quadratic and was removed
    # in pandas 2.0. ignore_index avoids duplicated per-piece index labels.
    df = pd.concat(frames, ignore_index=True)
    if rows_number is not None:
        df = df.iloc[:rows_number]

    # Remove the id column and decode the column names. Reassign the result
    # rather than using inplace, because df may be a slice from iloc above.
    df = df.drop(columns=["id"])
    df = df.rename(columns={c: decode(c) for c in df.columns})

    # Replace min values (~-9.2e+18, the int64 minimum) by pd.NA. This
    # requires switching the column to the nullable Int64 dtype first.
    if convert_nan:
        int_dtypes = (
            np.dtype("int64"),
            np.dtype("int32"),
            np.dtype("int16"),
        )
        for column in df.columns:
            if df[column].dtype in int_dtypes:
                df[column] = df[column].astype("Int64")
                df.loc[df[column] <= np.iinfo(np.int64).min, column] = pd.NA
    return df


def get_path_fs(source):
    """
    Split a dataset URI into the path and filesystem that pyarrow should use.

    Args:
        source (str): URI or local path of the dataset.

    Returns:
        tuple: ``(path, fs)``.
            For ``gs://`` and ``gcs://`` URIs, ``path`` is ``netloc + path``
            (bucket and object prefix, without the scheme) and ``fs`` is a
            new ``gcsfs.GCSFileSystem``.
            For anything else, ``source`` is returned untouched with
            ``fs=None``, so pyarrow infers the filesystem itself.
    """
    parsed = urlparse(source)
    if parsed.scheme not in ("gs", "gcs"):
        # Let arrow parse the URI and pick the filesystem on its own.
        return source, None

    filesystem = gcsfs.GCSFileSystem()
    bucket_path = f"{parsed.netloc}{parsed.path}"
    return bucket_path, filesystem


def make_tf_dataset(source, sarus_schema):
    """
    Creates a tensorflow Dataset consuming the Parquet data.

    Args:
        source (str): URI pointing to the root directory of the Parquet
            Dataset. It is resolved to a path/filesystem pair via
            ``get_path_fs``.
        sarus_schema: Sarus schema of the data. Its
            ``transform(add_id=True).get_tf_schema()`` result provides the
            per-column TensorSpecs. The schema is also used by the generator
            to convert non-id columns.

    Returns:
        tf.data.Dataset: for a non-empty dataset, an infinite dataset looping
        over the whole dataset sequentially. Each element is a dict of
        tensors built from one Arrow record batch.
    """
    path, fs = get_path_fs(source)
    parquet_ds = ds.dataset(path, filesystem=fs, format="parquet")
    generator_factory = get_generator_factory_from_dataset(parquet_ds, sarus_schema)

    tf_schema = sarus_schema.transform(add_id=True).get_tf_schema()
    output_signature = {}
    for name, signature in tf_schema.items():
        # The leading None dimension is the number of rows in each yielded
        # record batch, which is not fixed.
        batched_shape = (None, *signature.shape)
        output_signature[name] = tf.TensorSpec(
            shape=batched_shape, dtype=signature.dtype
        )

    return tf.data.Dataset.from_generator(
        generator_factory, output_signature=output_signature
    )


def get_generator_factory_from_dataset(dataset, sarus_schema):
    """
    Build a zero-argument generator function that streams ``dataset`` repeatedly.

    Each yielded item is a dict mapping column name to the values of one
    record batch. The ``id`` column is passed through as the raw numpy array.
    Every other column is converted with ``sarus_schema.features[name].raw_np_to_tf``.

    Args:
        dataset (pyarrow.dataset.Dataset): the Parquet data to stream. Only
            ``to_batches()`` is used.
        sarus_schema: schema whose ``features`` mapping has an entry for
            every non-``id`` column of ``dataset``. A missing entry raises
            ``KeyError``.

    Returns:
        callable: a generator function, e.g. for
        ``tf.data.Dataset.from_generator``. It restarts from the first batch
        after each full pass. If a full pass yields no batch at all (empty
        dataset), it stops instead of looping forever.
    """

    def generator():
        while True:
            produced_any = False
            for batch in dataset.to_batches():
                produced_any = True
                # convert the batch into a pydict
                pydict = {}
                for col_name, column in zip(batch.schema.names, batch.columns):
                    raw_np_value = column.to_numpy(zero_copy_only=False)

                    if col_name == "id":
                        # id column is not in the Sarus Schema
                        pydict[col_name] = raw_np_value
                    else:
                        feature = sarus_schema.features[col_name]
                        pydict[col_name] = feature.raw_np_to_tf(raw_np_value)

                yield pydict

            if not produced_any:
                # With no batches, "while True" would spin forever without
                # ever yielding, hanging the consumer. End the stream instead.
                return

    return generator