Transfer Learning from Supervised and Self-Supervised Pretraining using PyTorch Lightning
Credit to the original author, William Falcon, and to Alfredo Canziani for posting the video presentation "Supervised and self-supervised transfer learning (with PyTorch Lightning)".
In the video presentation, they compare transfer learning from models pretrained with:
- supervised
- self-supervised
However, I would like to point out that the comparison is not entirely fair for the case of supervised pretraining. The reason is that they do not replace the last fully-connected layer of the supervised pretrained backbone model with the new finetuning layer. Instead, they stack the new finetuning layer on top of the pretrained model (including its last fully connected layer).
This is a clear disadvantage for the supervised pretrained model because:
- all its expressive power is contained in the output of the penultimate layer
- and it was already used by the last fully-connected layer to predict 1,000 classes
When the finetuning layer is stacked on top of it, the new layer has to perform the 10-class classification using only the output of the 1,000-class classification layer.
In contrast, if we replace the backbone's last fully connected layer with the new finetuning layer, the new layer can perform the 10-class classification using the full expressive power of the features output by the penultimate layer.
In this notebook I show that if we replace the last fully connected layer with the new finetuning layer, both supervised and self-supervised approaches give comparable results.
!pip install pytorch-lightning
!pip install pytorch-lightning-bolts==0.2.5rc1
import pytorch_lightning as pl
import pl_bolts
# Report which library releases this run used; results depend on them.
for _label, _module in (("pl", pl), ("pl_bolts", pl_bolts)):
    print(f"{_label} version: {_module.__version__}")
import torch
from torchvision import models
# Load torchvision's ResNet-50 with pretrained weights (pretrained=True) to
# explore the supervised backbone; the original 1,000-way `fc` head is still attached.
resnet50 = models.resnet50(pretrained=True)
from torchvision.datasets import CIFAR10
from torchvision import transforms

# Per-channel mean/std given on a 0-255 scale and rescaled to 0-1 so they match
# the output range of ToTensor (presumably the CIFAR-10 channel statistics).
normalize = transforms.Normalize(
    mean=[x / 255.0 for x in [125.3, 123.0, 113.9]],
    std=[x / 255.0 for x in [63.0, 62.1, 66.7]],
)
# PIL image -> float tensor (C, H, W) in [0, 1], then per-channel normalization.
# ToTensor is required: later cells call `.numpy()` on samples and batch them
# with a DataLoader, both of which need tensors.
cf10_transforms = transforms.Compose([
    transforms.ToTensor(),
    normalize,
])
# Training split, downloaded into the current directory if not already present.
cifar_10 = CIFAR10('.', train=True, download=True, transform=cf10_transforms)
from matplotlib import pyplot as plt

# Inspect the first training sample and its integer class label.
image, label = cifar_10[0]
print(f"LABEL: {label}")
# The sample has been normalized, so undo `normalize` to get pixels back into
# [0, 1]; otherwise imshow clips the values and the picture looks wrong.
img_mean = image.new_tensor(normalize.mean).view(-1, 1, 1)
img_std = image.new_tensor(normalize.std).view(-1, 1, 1)
# imshow expects channels last: (C, H, W) -> (H, W, C).
plt_img = (image * img_std + img_mean).clamp(0, 1).numpy().transpose(1, 2, 0)
plt.imshow(plt_img)
plt.show()
from torch.utils.data import DataLoader

# Shuffled mini-batches of 32 CIFAR-10 training samples.
train_loader = DataLoader(cifar_10, batch_size=32, shuffle=True)
# Peek at the first batch only. Without the `break`, this would iterate over
# the whole training set just to print the same shapes repeatedly.
for batch in train_loader:
    x, y = batch
    print(x.shape, y.shape)
    break
import torch
from torchvision import models

# Fresh pretrained ResNet-50 with every existing weight frozen
# (requires_grad_(False) clears requires_grad on all of its parameters).
resnet50 = models.resnet50(pretrained=True)
resnet50.requires_grad_(False)

# Replace (rather than stack on top of) the 1,000-way output layer with a new
# 10-way head. It is created after the freeze, so it is the only trainable
# part; hand resnet50.fc.parameters() to the optimizer.
num_classes = 10
resnet50.fc = torch.nn.Linear(
    in_features=resnet50.fc.in_features,
    out_features=num_classes,
)
import torch
from torch.nn.functional import softmax

# Sanity-check the (untrained) model on one batch.
# eval() makes BatchNorm use its stored running statistics instead of
# overwriting them from this batch, and no_grad() avoids building an autograd
# graph (the new head requires grad) that would never be used.
resnet50.eval()
x, y = next(iter(train_loader))
with torch.no_grad():
    # Class probabilities, one row per sample in the batch.
    preds = softmax(resnet50(x), dim=-1)
# Index of the most probable class for each sample.
pred_labels = torch.argmax(preds, dim=-1)
from pl_bolts.datamodules import CIFAR10DataModule
# Lightning Bolts CIFAR-10 DataModule rooted at the current directory ('.');
# an alternative data source to the hand-built train_loader for the Trainer runs.
dm = CIFAR10DataModule('.')
import pytorch_lightning as pl
from pytorch_lightning.metrics.functional import accuracy
from torch.nn.functional import cross_entropy
from torch.optim import Adam
class ImageClassifier(pl.LightningModule):
    """Supervised transfer learning: frozen ResNet-50 plus a new linear head.

    The pretrained backbone's final fully connected layer is *replaced* by a
    fresh ``num_classes``-way linear layer, so the head sees the penultimate
    features. Only that new layer is optimized.

    Args:
        num_classes: Number of target classes for the new head.
        lr: Adam learning rate.
    """

    def __init__(self, num_classes=10, lr=1e-3):
        # Must run before any submodule is assigned: nn.Module attribute
        # registration depends on state created by Module.__init__.
        super().__init__()
        self.num_classes = num_classes
        self.lr = lr
        self.model = models.resnet50(pretrained=True)
        for param in self.model.parameters():
            param.requires_grad = False
        # Created after the freeze, so the new head is trainable.
        # NOTE: BatchNorm layers in the frozen backbone still update their
        # running statistics while the module is in train mode.
        self.model.fc = torch.nn.Linear(self.model.fc.in_features, num_classes)

    def training_step(self, batch, batch_idx):
        """Return the cross-entropy loss for one batch, logging loss and accuracy."""
        x, y = batch
        preds = self.model(x)
        loss = cross_entropy(preds, y)
        # Lightning detaches the logged value from the autograd graph.
        self.log('train_loss', loss)
        self.log('train_acc', accuracy(preds, y))
        return loss

    def configure_optimizers(self):
        """Optimize only the new head with Adam at the configured learning rate."""
        return Adam(self.model.fc.parameters(), lr=self.lr)
classifier = ImageClassifier()
# For Colab: refresh the progress bar every 20 batches instead of 10 to avoid freezing.
trainer = pl.Trainer(progress_bar_refresh_rate=20, gpus=1, max_epochs=2)
# Train on the Bolts DataModule; `trainer.fit(classifier, train_loader)` also works.
trainer.fit(classifier, datamodule=dm)
import pytorch_lightning as pl
from pytorch_lightning.metrics.functional import accuracy
from torch.nn.functional import cross_entropy
from torch.optim import Adam
class ImageClassifier(pl.LightningModule):
    """Supervised transfer learning in two phases: head only, then end to end.

    The pretrained ResNet-50 starts frozen, and its final fully connected layer
    is replaced by a fresh ``num_classes``-way linear head. From trainer epoch
    ``unfreeze_epoch`` onward, every backbone weight becomes trainable too.

    Args:
        num_classes: Number of target classes for the new head.
        lr: Adam learning rate.
        unfreeze_epoch: Trainer ``current_epoch`` value at which the backbone is
            unfrozen (default 10, the previously hard-coded value).
    """

    def __init__(self, num_classes=10, lr=1e-3, unfreeze_epoch=10):
        # Must run before any submodule is assigned: nn.Module attribute
        # registration depends on state created by Module.__init__.
        super().__init__()
        self.num_classes = num_classes
        self.lr = lr
        self.unfreeze_epoch = unfreeze_epoch
        self.model = models.resnet50(pretrained=True)
        for param in self.model.parameters():
            param.requires_grad = False
        # Created after the freeze, so the new head is trainable from the start.
        self.model.fc = torch.nn.Linear(self.model.fc.in_features, num_classes)
        self._backbone_unfrozen = False

    def training_step(self, batch, batch_idx):
        """Return the cross-entropy loss for one batch, unfreezing the backbone when due."""
        x, y = batch
        # `>=` also unfreezes when training resumes past the threshold epoch;
        # the flag makes the flip happen once rather than on every step.
        if not self._backbone_unfrozen and self.trainer.current_epoch >= self.unfreeze_epoch:
            for param in self.model.parameters():
                param.requires_grad = True
            self._backbone_unfrozen = True
        preds = self.model(x)
        loss = cross_entropy(preds, y)
        # Lightning detaches the logged value from the autograd graph.
        self.log('train_loss', loss)
        self.log('train_acc', accuracy(preds, y))
        return loss

    def configure_optimizers(self):
        """Adam over all model parameters.

        Frozen parameters receive no gradient and are skipped by the optimizer
        until they are unfrozen, so they only start updating in the second phase.
        """
        return Adam(self.model.parameters(), lr=self.lr)
classifier = ImageClassifier()
# Only 20 batches per epoch keeps the 20-epoch, two-phase run short.
trainer = pl.Trainer(progress_bar_refresh_rate=5, gpus=1, limit_train_batches=20, max_epochs=20)
# Train on the Bolts DataModule; `trainer.fit(classifier, train_loader)` also works.
trainer.fit(classifier, datamodule=dm)
Self-Supervised Pretraining
import pytorch_lightning as pl
from pytorch_lightning.metrics.functional import accuracy
from torch.nn.functional import cross_entropy
from torch.optim import Adam
from pl_bolts.models.self_supervised import SwAV

# Pretrained SwAV checkpoint published with Lightning Bolts. An empty path
# cannot be loaded, so point at the released weights.
weight_path = 'https://pl-bolts-weights.s3.us-east-2.amazonaws.com/swav/swav_imagenet/swav_imagenet.pth.tar'
swav = SwAV.load_from_checkpoint(weight_path, strict=True)

# Alternative self-supervised backbone (SimCLR), also released with Bolts:
# from pl_bolts.models.self_supervised import SimCLR
# weight_path = 'https://pl-bolts-weights.s3.us-east-2.amazonaws.com/simclr/bolts_simclr_imagenet/simclr_imagenet.ckpt'
# simclr = SimCLR.load_from_checkpoint(weight_path, strict=False)
class ImageClassifier(pl.LightningModule):
    """Self-supervised transfer learning: SwAV backbone plus a new linear head.

    The backbone starts frozen. From trainer epoch ``unfreeze_epoch`` onward it
    is fine-tuned together with the head.

    Args:
        num_classes: Number of target classes for the new head.
        lr: Adam learning rate.
        unfreeze_epoch: Trainer ``current_epoch`` value at which the backbone is
            unfrozen (default 10, the previously hard-coded value).
    """

    def __init__(self, num_classes=10, lr=1e-3, unfreeze_epoch=10):
        import copy

        # Must run before any submodule is assigned: nn.Module attribute
        # registration depends on state created by Module.__init__.
        super().__init__()
        self.num_classes = num_classes
        self.lr = lr
        self.unfreeze_epoch = unfreeze_epoch
        # Deep-copy the globally loaded SwAV network. Otherwise fine-tuning
        # mutates `swav.model` itself, and re-creating the classifier would
        # start from already fine-tuned weights. (For SimCLR, copy `simclr` instead.)
        self.backbone = copy.deepcopy(swav.model)
        for param in self.backbone.parameters():
            param.requires_grad = False
        # The head consumes the backbone's second output, which must be
        # 3000-wide (presumably SwAV prototype scores; confirm against the checkpoint).
        self.finetune_layer = torch.nn.Linear(3000, num_classes)
        self._backbone_unfrozen = False

    def training_step(self, batch, batch_idx):
        """Return the cross-entropy loss for one batch, unfreezing the backbone when due."""
        x, y = batch
        # `>=` also unfreezes when training resumes past the threshold epoch;
        # the flag makes the flip happen once rather than on every step.
        if not self._backbone_unfrozen and self.trainer.current_epoch >= self.unfreeze_epoch:
            for param in self.backbone.parameters():
                param.requires_grad = True
            self._backbone_unfrozen = True
        # The SwAV backbone returns a pair; only the second element feeds the head.
        # (For a SimCLR backbone use `features = self.backbone(x)` instead.)
        _, features = self.backbone(x)
        preds = self.finetune_layer(features)
        loss = cross_entropy(preds, y)
        # Lightning detaches the logged value from the autograd graph.
        self.log('train_loss', loss)
        self.log('train_acc', accuracy(preds, y))
        return loss

    def configure_optimizers(self):
        """Adam over backbone and head.

        Frozen backbone parameters receive no gradient and are skipped by the
        optimizer until they are unfrozen.
        """
        return Adam(self.parameters(), lr=self.lr)
classifier = ImageClassifier()
# Same schedule as the supervised two-phase run: 20 batches/epoch for 20 epochs.
trainer = pl.Trainer(progress_bar_refresh_rate=5, gpus=1, limit_train_batches=20, max_epochs=20)
# Train on the Bolts DataModule; `trainer.fit(classifier, train_loader)` also works.
trainer.fit(classifier, datamodule=dm)
# Open TensorBoard inline to compare the train_loss / train_acc curves logged
# by the runs above (Lightning's default logger writes under lightning_logs/).
%reload_ext tensorboard
%tensorboard --logdir lightning_logs/