# Package-index metadata (repository install page), version 2.0.0
from functools import partialmethod
from pathlib import Path
from typing import List, Union, Optional
import torch
from .tensor_dataset import TensorDataset
from ..transforms.data_processors import DefaultDataProcessor
from ..transforms.normalizers import UnitGaussianNormalizer
class PTDataset:
    """PTDataset is a base Dataset class for our library.

    PTDatasets contain input-output pairs a(x), u(x) and may also
    contain additional information, e.g. function parameters,
    input geometry or output query points.

    datasets may implement a download flag at init, which provides
    access to a number of premade datasets for sample problems provided
    in our Zenodo archive.

    Parameters
    ----------
    root_dir : Union[Path, str]
        root at which to download data files
    dataset_name : str
        prefix of pt data files to store/access
    n_train : int
        number of train instances
    n_tests : List[int]
        number of test instances per test dataset
    batch_size : int
        batch size of training set
    test_batch_sizes : List[int]
        batch size of test sets
    train_resolution : int
        resolution of data for training set
    test_resolutions : List[int]
        resolution of data for testing sets
    encode_input : bool, optional
        whether to normalize inputs in provided DataProcessor,
        by default False
    encode_output : bool, optional
        whether to normalize outputs in provided DataProcessor,
        by default True
    encoding : str, optional
        parameter for input/output normalization. Whether
        to normalize by channel ("channel-wise") or
        by pixel ("pixel-wise"), default "channel-wise"
    input_subsampling_rate : int or List[int], optional
        rate at which to subsample each input dimension, by default None
    output_subsampling_rate : int or List[int], optional
        rate at which to subsample each output dimension, by default None
    channel_dim : int, optional
        dimension of saved tensors to index data channels, by default 1
    channels_squeezed : bool, optional
        If the channels dim is 1, whether that is explicitly kept in the saved tensor.
        If not, we need to unsqueeze it to explicitly have a channel dim.
        Only applies when there is only one data channel, as in our example problems
        Defaults to True

    All datasets are required to expose the following attributes after init:

    train_db: torch.utils.data.Dataset of training examples
    test_db:  torch.utils.data.Dataset of test examples
    data_processor: neuralop.data.transforms.DataProcessor to process data examples
        optional, default is None
    """

    def __init__(
        self,
        root_dir: Union[Path, str],
        dataset_name: str,
        n_train: int,
        n_tests: List[int],
        batch_size: int,
        test_batch_sizes: List[int],
        train_resolution: int,
        test_resolutions: List[int],
        encode_input: bool = False,
        encode_output: bool = True,
        encoding="channel-wise",
        input_subsampling_rate=None,
        output_subsampling_rate=None,
        channel_dim=1,
        channels_squeezed=True,
    ):
        """Initialize the PTDataset.

        See class docstring for detailed parameter descriptions.
        """
        if isinstance(root_dir, str):
            root_dir = Path(root_dir)
        self.root_dir = root_dir

        # save dataloader properties for later
        self.batch_size = batch_size
        self.test_resolutions = test_resolutions
        self.test_batch_sizes = test_batch_sizes

        # Load train data
        data = torch.load(
            root_dir.joinpath(f"{dataset_name}_train_{train_resolution}.pt").as_posix()
        )

        x_train = data["x"].type(torch.float32).clone()
        if channels_squeezed:
            # saved tensor has no explicit channel dim: add one at channel_dim
            x_train = x_train.unsqueeze(channel_dim)

        ## Input subsampling
        # FIX: count spatial dims on the tensor *after* the channel dim is
        # guaranteed to exist. The original used data["x"].ndim - 2, which
        # under-counted by one when channels_squeezed=True and silently left
        # the last spatial dim unsubsampled.
        input_data_dims = x_train.ndim - 2  # all dims except batch and channel
        input_subsampling_rate = self._expand_rate(
            input_subsampling_rate, input_data_dims, "input_subsampling_rate"
        )
        x_train = x_train[
            self._subsample_indices(n_train, input_subsampling_rate, channel_dim)
        ]

        y_train = data["y"].clone()
        if channels_squeezed:
            y_train = y_train.unsqueeze(channel_dim)

        ## Output subsampling
        output_data_dims = y_train.ndim - 2
        # FIX: the original tested `input_subsampling_rate` here (copy-paste
        # bug), so output_subsampling_rate=0 was never normalized to 1 and
        # later crashed with "slice step cannot be zero".
        output_subsampling_rate = self._expand_rate(
            output_subsampling_rate, output_data_dims, "output_subsampling_rate"
        )
        y_train = y_train[
            self._subsample_indices(n_train, output_subsampling_rate, channel_dim)
        ]

        del data

        # Fit optional encoders to train data.
        # Actual encoding happens within the DataProcessor.
        input_encoder = (
            self._fit_encoder(x_train, encoding, channel_dim) if encode_input else None
        )
        output_encoder = (
            self._fit_encoder(y_train, encoding, channel_dim) if encode_output else None
        )

        # Save train dataset
        self._train_db = TensorDataset(
            x_train,
            y_train,
        )

        # create DataProcessor
        self._data_processor = DefaultDataProcessor(
            in_normalizer=input_encoder, out_normalizer=output_encoder
        )

        # load test data: one db per (resolution, n_test) pair, subsampled
        # at the same (already-normalized) rates as the train data
        self._test_dbs = {}
        for res, n_test in zip(test_resolutions, n_tests):
            print(f"Loading test db for resolution {res} with {n_test} samples ")
            data = torch.load(
                root_dir.joinpath(f"{dataset_name}_test_{res}.pt").as_posix()
            )

            x_test = data["x"].type(torch.float32).clone()
            if channels_squeezed:
                x_test = x_test.unsqueeze(channel_dim)
            x_test = x_test[
                self._subsample_indices(n_test, input_subsampling_rate, channel_dim)
            ]

            y_test = data["y"].clone()
            if channels_squeezed:
                y_test = y_test.unsqueeze(channel_dim)
            y_test = y_test[
                self._subsample_indices(n_test, output_subsampling_rate, channel_dim)
            ]

            del data
            self._test_dbs[res] = TensorDataset(
                x_test,
                y_test,
            )

    @staticmethod
    def _expand_rate(rate, n_dims, name):
        """Normalize a subsampling rate into one int per spatial dim.

        ``None``/``0`` mean "no subsampling" (rate 1); a scalar is broadcast
        across all ``n_dims`` spatial dims; a list must already have length
        ``n_dims``. ``name`` is used only for the error message.
        """
        if not rate:
            rate = 1
        if not isinstance(rate, list):
            # expand subsampling rate along dims if one per dim is not provided
            rate = [rate] * n_dims
        # make sure there is one subsampling rate per data dim
        # (FIX: original message called len() on an int and named the wrong
        # parameter for the output rate)
        assert len(rate) == n_dims, (
            f"Error: length mismatch between {name} and dimensions of data. "
            f"{name} must be one int shared across all dims, or an iterable of "
            f"length {n_dims}, got {rate}"
        )
        return rate

    @staticmethod
    def _subsample_indices(n_samples, rates, channel_dim):
        """Build the tuple of slices selecting the first ``n_samples``
        examples, every channel, and every ``rate``-th point per spatial dim.
        """
        indices = [slice(0, n_samples, None)] + [
            slice(None, None, rate) for rate in rates
        ]
        indices.insert(channel_dim, slice(None))
        return tuple(indices)

    @staticmethod
    def _fit_encoder(tensor, encoding, channel_dim):
        """Fit and return a UnitGaussianNormalizer for ``tensor``.

        "channel-wise" reduces over every dim except ``channel_dim`` (one
        statistic per channel); "pixel-wise" reduces over the batch dim only.
        """
        if encoding == "channel-wise":
            reduce_dims = list(range(tensor.ndim))
            # preserve mean for each channel
            reduce_dims.pop(channel_dim)
        elif encoding == "pixel-wise":
            reduce_dims = [0]
        else:
            # FIX: the original fell through with reduce_dims unbound and
            # raised an opaque UnboundLocalError for unknown encodings
            raise ValueError(
                f"Got encoding={encoding}, expected 'channel-wise' or 'pixel-wise'"
            )
        encoder = UnitGaussianNormalizer(dim=reduce_dims)
        encoder.fit(tensor)
        return encoder

    @property
    def data_processor(self):
        # DataProcessor wrapping the fitted input/output normalizers (or None)
        return self._data_processor

    @property
    def train_db(self):
        # torch.utils.data.Dataset of training examples
        return self._train_db

    @property
    def test_dbs(self):
        # dict mapping test resolution -> Dataset of test examples
        return self._test_dbs