Repository URL to install this package:
|
Version:
2.0.0 ▾
|
from math import ceil, floor
import torch
from torch.utils.data import DataLoader
from torch_harmonics.examples import ShallowWaterSolver
def load_spherical_swe(
n_train,
n_tests,
batch_size,
test_batch_sizes,
train_resolution=(256, 512),
test_resolutions=[(256, 512)],
device=torch.device("cpu"),
):
"""Load the Spherical Shallow Water equations Dataloader"""
print(f"Loading train dataloader at resolution {train_resolution} with {n_train} samples and batch-size={batch_size}")
train_dataset = SphericalSWEDataset(dims=train_resolution, num_examples=n_train, device=device)
train_loader = DataLoader(
train_dataset,
batch_size=batch_size,
shuffle=True,
num_workers=0,
persistent_workers=False,
)
test_loaders = dict()
for res, n_test, test_batch_size in zip(test_resolutions, n_tests, test_batch_sizes):
print(f"Loading test dataloader at resolution {res} with {n_test} samples and batch-size={test_batch_size}")
test_dataset = SphericalSWEDataset(dims=res, num_examples=n_test, device=device)
test_loader = DataLoader(
test_dataset,
batch_size=test_batch_size,
shuffle=True,
num_workers=0,
persistent_workers=False,
)
test_loaders[res] = test_loader
return train_loader, test_loaders
class SphericalSWEDataset(torch.utils.data.Dataset):
"""Custom Dataset class for PDE training data"""
def __init__(
self,
dt=3600,
dims=(256, 512),
initial_condition="random",
num_examples=32,
device=torch.device("cpu"),
normalize=True,
stream=None,
):
# Caution: this is a heuristic which can break and lead to diverging results
dt_min = 256 / dims[0] * 150
nsteps = int(floor(dt / dt_min))
self.num_examples = num_examples
self.device = device
self.stream = stream
self.nlat = dims[0]
self.nlon = dims[1]
# number of solver steps used to compute the target
self.nsteps = nsteps
self.normalize = normalize
lmax = ceil(self.nlat / 3)
mmax = lmax
dt_solver = dt / float(self.nsteps)
self.solver = (
ShallowWaterSolver(
self.nlat,
self.nlon,
dt_solver,
lmax=lmax,
mmax=mmax,
grid="equiangular"
)
.to(self.device).float()
)
self.set_initial_condition(ictype=initial_condition)
if self.normalize:
inp0, _ = self._get_sample()
self.inp_mean = torch.mean(inp0, dim=(-1, -2)).reshape(-1, 1, 1)
self.inp_var = torch.var(inp0, dim=(-1, -2)).reshape(-1, 1, 1)
def __len__(self):
length = self.num_examples if self.ictype == "random" else 1
return length
def set_initial_condition(self, ictype="random"):
self.ictype = ictype
def set_num_examples(self, num_examples=32):
self.num_examples = num_examples
def _get_sample(self):
if self.ictype == "random":
inp = self.solver.random_initial_condition(mach=0.2)
elif self.ictype == "galewsky":
inp = self.solver.galewsky_initial_condition()
# solve pde for n steps to return the target
tar = self.solver.timestep(inp, self.nsteps)
inp = self.solver.spec2grid(inp)
tar = self.solver.spec2grid(tar)
return inp, tar
def __getitem__(self, index):
with torch.inference_mode():
with torch.no_grad():
inp, tar = self._get_sample()
if self.normalize:
inp = (inp - self.inp_mean) / torch.sqrt(self.inp_var)
tar = (tar - self.inp_mean) / torch.sqrt(self.inp_var)
return {"x": inp.clone(), "y": tar.clone()}