# flake8: noqa
# fmt: off
# __import_lightning_begin__
import math
import torch
import pytorch_lightning as pl
from filelock import FileLock
from torch.utils.data import DataLoader, random_split
from torch.nn import functional as F
from torchvision.datasets import MNIST
from torchvision import transforms
import os
# __import_lightning_end__
# __import_tune_begin__
from pytorch_lightning.loggers import TensorBoardLogger
from ray import air, tune
from ray.tune import CLIReporter
from ray.tune.schedulers import ASHAScheduler, PopulationBasedTraining
from ray.tune.integration.pytorch_lightning import (
    TuneReportCallback,
    TuneReportCheckpointCallback,
)
# __import_tune_end__
# __lightning_begin__
class LightningMNISTClassifier(pl.LightningModule):
"""
This has been adapted from
https://towardsdatascience.com/from-pytorch-to-pytorch-lightning-a-gentle-introduction-b371b7caaf09
"""
def __init__(self, config, data_dir=None):
        super().__init__()
self.data_dir = data_dir or os.getcwd()
self.layer_1_size = config["layer_1_size"]
self.layer_2_size = config["layer_2_size"]
self.lr = config["lr"]
self.batch_size = config["batch_size"]
        # MNIST images are (1, 28, 28) (channels, height, width)
self.layer_1 = torch.nn.Linear(28 * 28, self.layer_1_size)
self.layer_2 = torch.nn.Linear(self.layer_1_size, self.layer_2_size)
self.layer_3 = torch.nn.Linear(self.layer_2_size, 10)
def forward(self, x):
        batch_size, channels, height, width = x.size()
x = x.view(batch_size, -1)
x = self.layer_1(x)
x = torch.relu(x)
x = self.layer_2(x)
x = torch.relu(x)
x = self.layer_3(x)
x = torch.log_softmax(x, dim=1)
return x
def cross_entropy_loss(self, logits, labels):
return F.nll_loss(logits, labels)
def accuracy(self, logits, labels):
        _, predicted = torch.max(logits, dim=1)
correct = (predicted == labels).sum().item()
accuracy = correct / len(labels)
return torch.tensor(accuracy)
def training_step(self, train_batch, batch_idx):
x, y = train_batch
logits = self.forward(x)
loss = self.cross_entropy_loss(logits, y)
accuracy = self.accuracy(logits, y)
self.log("ptl/train_loss", loss)
self.log("ptl/train_accuracy", accuracy)
return loss
def validation_step(self, val_batch, batch_idx):
x, y = val_batch
logits = self.forward(x)
loss = self.cross_entropy_loss(logits, y)
accuracy = self.accuracy(logits, y)
return {"val_loss": loss, "val_accuracy": accuracy}
def validation_epoch_end(self, outputs):
avg_loss = torch.stack([x["val_loss"] for x in outputs]).mean()
avg_acc = torch.stack([x["val_accuracy"] for x in outputs]).mean()
self.log("ptl/val_loss", avg_loss)
self.log("ptl/val_accuracy", avg_acc)
@staticmethod
def download_data(data_dir):
transform = transforms.Compose([
transforms.ToTensor(),
            transforms.Normalize((0.1307,), (0.3081,))
])
with FileLock(os.path.expanduser("~/.data.lock")):
return MNIST(data_dir, train=True, download=True, transform=transform)
def prepare_data(self):
mnist_train = self.download_data(self.data_dir)
self.mnist_train, self.mnist_val = random_split(
mnist_train, [55000, 5000])
def train_dataloader(self):
return DataLoader(self.mnist_train, batch_size=int(self.batch_size))
def val_dataloader(self):
return DataLoader(self.mnist_val, batch_size=int(self.batch_size))
def configure_optimizers(self):
optimizer = torch.optim.Adam(self.parameters(), lr=self.lr)
return optimizer
def train_mnist(config):
model = LightningMNISTClassifier(config)
trainer = pl.Trainer(max_epochs=10, enable_progress_bar=False)
trainer.fit(model)
# __lightning_end__
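# A quick sanity check of the model's forward pass (a sketch; the config
# values below are arbitrary examples, not tuned settings):
#
#   model = LightningMNISTClassifier(
#       {"layer_1_size": 64, "layer_2_size": 128, "lr": 1e-3, "batch_size": 64})
#   logits = model(torch.zeros(4, 1, 28, 28))  # dummy batch of 4 MNIST-sized images
#   assert logits.shape == (4, 10)             # one log-probability per digit class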
# __no_tune_train_begin__
def train_mnist_no_tune():
config = {
"layer_1_size": 128,
"layer_2_size": 256,
"lr": 1e-3,
"batch_size": 64
}
train_mnist(config)
# __no_tune_train_end__
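# `train_mnist_no_tune` above is the plain baseline: one training run with
# fixed hyperparameters. The variants below wrap the same training loop so
# that Ray Tune can search over the hyperparameters instead.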
# __tune_train_begin__
def train_mnist_tune(config, num_epochs=10, num_gpus=0, data_dir="~/data"):
data_dir = os.path.expanduser(data_dir)
model = LightningMNISTClassifier(config, data_dir)
trainer = pl.Trainer(
max_epochs=num_epochs,
        # Tune may assign fractional GPUs; the Trainer expects an integer count, so round up.
gpus=math.ceil(num_gpus),
logger=TensorBoardLogger(
save_dir=os.getcwd(), name="", version="."),
enable_progress_bar=False,
callbacks=[
TuneReportCallback(
{
"loss": "ptl/val_loss",
"mean_accuracy": "ptl/val_accuracy"
},
on="validation_end")
])
trainer.fit(model)
# __tune_train_end__
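# How metrics reach Tune: `validation_epoch_end` logs "ptl/val_loss" and
# "ptl/val_accuracy" via `self.log`, and `TuneReportCallback` forwards them
# to Tune after every validation epoch under the names "loss" and
# "mean_accuracy" used by the schedulers and reporters below.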
# __tune_train_checkpoint_begin__
def train_mnist_tune_checkpoint(config,
checkpoint_dir=None,
num_epochs=10,
num_gpus=0,
data_dir="~/data"):
data_dir = os.path.expanduser(data_dir)
kwargs = {
"max_epochs": num_epochs,
        # Tune may assign fractional GPUs; the Trainer expects an integer count, so round up.
"gpus": math.ceil(num_gpus),
"logger": TensorBoardLogger(
save_dir=os.getcwd(), name="", version="."),
"enable_progress_bar": False,
"callbacks": [
TuneReportCheckpointCallback(
metrics={
"loss": "ptl/val_loss",
"mean_accuracy": "ptl/val_accuracy"
},
filename="checkpoint",
on="validation_end")
]
}
if checkpoint_dir:
kwargs["resume_from_checkpoint"] = os.path.join(
checkpoint_dir, "checkpoint")
model = LightningMNISTClassifier(config=config, data_dir=data_dir)
trainer = pl.Trainer(**kwargs)
trainer.fit(model)
# __tune_train_checkpoint_end__
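# Checkpoint flow: `TuneReportCheckpointCallback` reports the same metrics as
# above and additionally saves a Trainer checkpoint named "checkpoint" after
# each validation epoch. When Tune restarts a trial (e.g. after a PBT exploit
# step), it passes the checkpoint directory back in, and
# `resume_from_checkpoint` restores training from that file.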
# __tune_asha_begin__
def tune_mnist_asha(num_samples=10, num_epochs=10, gpus_per_trial=0, data_dir="~/data"):
config = {
"layer_1_size": tune.choice([32, 64, 128]),
"layer_2_size": tune.choice([64, 128, 256]),
"lr": tune.loguniform(1e-4, 1e-1),
"batch_size": tune.choice([32, 64, 128]),
}
scheduler = ASHAScheduler(
max_t=num_epochs,
grace_period=1,
reduction_factor=2)
reporter = CLIReporter(
parameter_columns=["layer_1_size", "layer_2_size", "lr", "batch_size"],
metric_columns=["loss", "mean_accuracy", "training_iteration"])
train_fn_with_parameters = tune.with_parameters(train_mnist_tune,
num_epochs=num_epochs,
num_gpus=gpus_per_trial,
data_dir=data_dir)
resources_per_trial = {"cpu": 1, "gpu": gpus_per_trial}
    tuner = tune.Tuner(
        tune.with_resources(
            train_fn_with_parameters,
            resources=resources_per_trial),
        tune_config=tune.TuneConfig(
            metric="loss",
            mode="min",
            num_samples=num_samples,
            scheduler=scheduler),
        run_config=air.RunConfig(
            name="tune_mnist_asha",
            progress_reporter=reporter),
        param_space=config,
    )
results = tuner.fit()
print("Best hyperparameters found were: ", results.get_best_result().config)
# __tune_asha_end__
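# With the settings above, ASHA lets every trial train for at least
# `grace_period=1` epoch, then repeatedly keeps only the better half of
# trials (`reduction_factor=2`) until `max_t=num_epochs` is reached, so
# poorly performing configurations are stopped early.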
# __tune_pbt_begin__
def tune_mnist_pbt(num_samples=10, num_epochs=10, gpus_per_trial=0, data_dir="~/data"):
config = {
"layer_1_size": tune.choice([32, 64, 128]),
"layer_2_size": tune.choice([64, 128, 256]),
"lr": 1e-3,
"batch_size": 64,
}
scheduler = PopulationBasedTraining(
perturbation_interval=4,
hyperparam_mutations={
"lr": tune.loguniform(1e-4, 1e-1),
"batch_size": [32, 64, 128]
})
reporter = CLIReporter(
parameter_columns=["layer_1_size", "layer_2_size", "lr", "batch_size"],
metric_columns=["loss", "mean_accuracy", "training_iteration"])
tuner = tune.Tuner(
tune.with_resources(
tune.with_parameters(
train_mnist_tune_checkpoint,
num_epochs=num_epochs,
num_gpus=gpus_per_trial,
data_dir=data_dir),
resources={
"cpu": 1,
"gpu": gpus_per_trial
}),
        tune_config=tune.TuneConfig(
            metric="loss",
            mode="min",
            num_samples=num_samples,
            scheduler=scheduler),
        run_config=air.RunConfig(
            name="tune_mnist_pbt",
            progress_reporter=reporter),
param_space=config,
)
results = tuner.fit()
print("Best hyperparameters found were: ", results.get_best_result().config)
# __tune_pbt_end__
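# Unlike ASHA, PBT keeps all trials running: every `perturbation_interval=4`
# iterations, underperforming trials load the checkpoint of a stronger trial
# and mutate "lr" and "batch_size". The mutated keys get fixed starting values
# in `config`, and checkpointing is required, which is why this variant uses
# `train_mnist_tune_checkpoint`.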
if __name__ == "__main__":
import argparse
parser = argparse.ArgumentParser()
parser.add_argument(
"--smoke-test", action="store_true", help="Finish quickly for testing")
parser.add_argument(
"--data-dir",
type=str,
default="~/data/",
help="Set the path of the dataset."
)
args, _ = parser.parse_known_args()
data_dir = args.data_dir
if args.smoke_test:
tune_mnist_asha(num_samples=1, num_epochs=6, gpus_per_trial=0, data_dir=data_dir)
tune_mnist_pbt(num_samples=1, num_epochs=6, gpus_per_trial=0, data_dir=data_dir)
else:
# ASHA scheduler
tune_mnist_asha(num_samples=10, num_epochs=10, gpus_per_trial=0, data_dir=data_dir)
# Population based training
tune_mnist_pbt(num_samples=10, num_epochs=10, gpus_per_trial=0, data_dir=data_dir)
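# Example invocations, assuming this script is saved as
# `mnist_pytorch_lightning.py` (the filename is illustrative):
#
#   python mnist_pytorch_lightning.py --smoke-test
#   python mnist_pytorch_lightning.py --data-dir ~/data/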