Training AlexNet with tips and checks on how to train CNNs: Practical CNNs in PyTorch(1)

Welcome to the first post of the ‘Practical CNNs in PyTorch’ series. Over the next 2 months I plan to cover a variety of topics in this series, from CNNs to visualizations, object detection, Neural Turing Machines and various other applications of CNNs.

Overview of this post

This post will act as a base for my future posts. I will cover:

  1. Data loading: creating data loaders, and the checks to perform to make sure everything is working correctly.
  2. ImageNet: how to download and preprocess it, and how to organize the folders for our input pipeline.
  3. Model construction checks: once the model is built, checks such as overfitting a small batch and using the loss value to verify that your implementation is correct.
  4. General guidelines that are useful when constructing models.
  5. A brief discussion of using pre-trained models.

Link to Jupyter notebook.


To train CNNs we need data. The most common options are MNIST, CIFAR and ImageNet, but you can use any dataset. I use ImageNet because it requires some preprocessing before it can be used.

Side note: I use the ImageNet validation data, i.e. 50,000 images, as my training data, and take 10 images from each class of the ImageNet training set as my validation data (the script to do so is in my Jupyter notebook). The choice of dataset is up to you. Below is the processing you have to do.

  1. Download ImageNet. You can get the data from the ImageNet site; if you have limited bandwidth this is a good option, since you can download fewer images. Alternatively, use the ImageNet Object Localization Challenge to download all the files directly (warning: 155GB).
  2. Extract the tar.gz file with tar xzvf file_name -C destination_path
  3. The Data/CLS-LOC folder contains the train, val and test image folders. The train images are already sorted into class folders, one folder per class, named by the class's WordNet ID (for example ‘n01440764’). The val images, however, are not sorted into class folders.
  4. From your terminal, run this command inside the val folder: wget -qO- | bash . It moves all the images to their respective class folders.
  5. As a general preprocessing step, we rescale every image so that its shorter side is 256 pixels (256x???). Since this operation would otherwise be repeated every time an image is loaded, I store the rescaled images on disk using ImageMagick: find . -name "*.JPEG" | xargs -I {} convert {} -resize "256^>" {}
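If ImageMagick is not available, the resizing in step 5 can be approximated in Python. The following is a minimal sketch, assuming Pillow is installed; shorter_side_size, resize_shorter_side and resize_tree are hypothetical helper names. It is meant to mirror the command above (scale the shorter side down to 256 pixels, keep the aspect ratio, leave smaller images alone), but it is not guaranteed to reproduce ImageMagick's output exactly.

```python
from pathlib import Path

from PIL import Image


def shorter_side_size(width, height, target=256):
    """Return (new_width, new_height) with the shorter side scaled to `target`."""
    scale = target / min(width, height)
    return round(width * scale), round(height * scale)


def resize_shorter_side(path, target=256):
    """Shrink one image in place so its shorter side is `target` pixels.

    Images whose shorter side is already <= target are left untouched.
    Returns True if the file was rewritten.
    """
    with Image.open(path) as img:
        if min(img.size) <= target:
            return False
        new_size = shorter_side_size(*img.size, target=target)
        # convert("RGB") also handles grayscale/CMYK JPEGs.
        resized = img.convert("RGB").resize(new_size, Image.BILINEAR)
    resized.save(path, quality=95)
    return True


def resize_tree(root, target=256, pattern="*.JPEG"):
    """Apply resize_shorter_side to every matching file under `root`.

    Returns the number of files that were rewritten.
    """
    count = 0
    for path in Path(root).rglob(pattern):
        count += resize_shorter_side(path, target)
    return count
```

Like the find/convert pipeline, this overwrites files in place, so try it on a copy of the data first. The convert("RGB") call is my addition; it also turns the grayscale and CMYK JPEGs that ImageNet contains into RGB.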

After these steps, all the images are in their class folders and every image has size 256x??? (shorter side 256).

Fig1: File structure of my data folder. Each folder contains subfolders such as ‘n01440764’, and the images for each class are placed in the corresponding subfolder.
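The side note above builds the validation set by taking 10 images from each class of the ImageNet training folder; the actual script is in the notebook. What follows is only a minimal standard-library sketch of the same idea, where the function name sample_per_class, the fixed random seed and copying (rather than moving) the files are my assumptions.

```python
import random
import shutil
from pathlib import Path


def sample_per_class(src_root, dst_root, k=10, seed=0):
    """Copy k randomly chosen images from every class folder in src_root
    into a folder with the same name under dst_root.

    Returns a dict mapping class folder name -> number of images copied.
    """
    rng = random.Random(seed)  # fixed seed so the split is reproducible
    src_root, dst_root = Path(src_root), Path(dst_root)
    copied = {}
    for class_dir in sorted(p for p in src_root.iterdir() if p.is_dir()):
        images = sorted(f for f in class_dir.iterdir() if f.is_file())
        chosen = rng.sample(images, min(k, len(images)))  # small classes: take all
        out_dir = dst_root / class_dir.name  # keep the WordNet ID as folder name
        out_dir.mkdir(parents=True, exist_ok=True)
        for f in chosen:
            shutil.copy2(f, out_dir / f.name)
        copied[class_dir.name] = len(chosen)
    return copied
```

Keeping the class folder names identical on both sides matters, because ImageFolder derives the label of every image from the name of the folder it sits in.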

Data Loading

Steps involved:

  1. Create a dataset class or use a predefined one.
  2. Choose the transforms you want to apply to the data.
  3. Create the data loaders.
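```python
import torch
from torchvision import transforms
from torchvision.datasets import ImageFolder

train_dir = '../../../Data/ILSVRC2012/train'
val_dir = '../../../Data/ILSVRC2012/val'
size = 224
batch_size = 32
num_workers = 8
normalize = transforms.Normalize(mean=[0.485, 0.456, 0.406], std=[0.229, 0.224, 0.225])
# The crop/flip choices are assumptions; ToTensor must come before Normalize.
data_transforms = {
    'train': transforms.Compose([transforms.RandomCrop(size), transforms.RandomHorizontalFlip(),
                                 transforms.ToTensor(), normalize]),
    'val': transforms.Compose([transforms.CenterCrop(size), transforms.ToTensor(), normalize]),
}
image_datasets = {
    'train': ImageFolder(train_dir, transform=data_transforms['train']),
    'val': ImageFolder(val_dir, transform=data_transforms['val']),
}
data_loader = {
    x: torch.utils.data.DataLoader(image_datasets[x], batch_size=batch_size,
                                   shuffle=(x == 'train'), num_workers=num_workers)
    for x in ['train', 'val']
}
```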

The normalization values are the per-channel mean and standard deviation precomputed on ImageNet (for pixel values scaled to [0, 1] by ToTensor), so we use those values in the normalization step.
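These numbers only apply to ImageNet. For another dataset you can compute the per-channel statistics yourself. Below is a minimal sketch, assuming a DataLoader whose images are tensors in [0, 1] (ToTensor applied, no Normalize); channel_mean_std is a hypothetical helper. It accumulates sums over all pixels, so the result does not depend on the batch size.

```python
import torch
from torch.utils.data import DataLoader, TensorDataset


def channel_mean_std(loader):
    """Per-channel mean and (population) std over every pixel in `loader`.

    Expects batches of (images, labels) with images shaped (N, C, H, W).
    """
    n_pixels = 0
    total = None
    total_sq = None
    for images, _ in loader:
        images = images.double()  # float64 keeps the running sums accurate
        if total is None:
            total = torch.zeros(images.size(1), dtype=torch.float64)
            total_sq = torch.zeros_like(total)
        total += images.sum(dim=(0, 2, 3))
        total_sq += (images ** 2).sum(dim=(0, 2, 3))
        n_pixels += images.size(0) * images.size(2) * images.size(3)
    mean = total / n_pixels
    std = (total_sq / n_pixels - mean ** 2).sqrt()  # Var = E[x^2] - E[x]^2
    return mean, std


# Usage on random stand-in data; for a real dataset you would pass a DataLoader
# over ImageFolder with a resize/crop + ToTensor transform (no Normalize).
data = torch.rand(8, 3, 4, 4)
loader = DataLoader(TensorDataset(data, torch.zeros(8)), batch_size=3)
mean, std = channel_mean_std(loader)
```

The resulting mean.tolist() and std.tolist() would then replace the ImageNet values passed to transforms.Normalize.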

Check dataloaders

After creating the input data pipeline, do a sanity check to make sure everything works as expected: plot some images.

```python
import numpy as np
import matplotlib.pyplot as plt
from torchvision.utils import make_grid

# As our images are normalized we have to denormalize them and
# convert them to numpy arrays.
def imshow(img, title=None):
    img = img.numpy().transpose((1, 2, 0))  # (C, H, W) -> (H, W, C)
    mean = np.array([0.485, 0.456, 0.406])
    std = np.array([0.229, 0.224, 0.225])
    img = std * img + mean  # undo Normalize
    img = np.clip(img, 0, 1)
    plt.imshow(img)
    if title is not None:
        plt.title(title)
    plt.pause(0.001)  # Pause is necessary to display images correctly

# labels_list is built from LOC_synset_mapping.txt in the snippet below.
images, labels = next(iter(data_loader['train']))
grid_img = make_grid(images[:4], nrow=4)
imshow(grid_img, title=[labels_list[int(x)] for x in labels[:4]])
```

One problem you will face with ImageNet data is getting the human-readable class names, since the class folders are named by WordNet IDs. The class names are in the file LOC_synset_mapping.txt.

```python
labels_dict = {}  # Get class label by folder name, e.g. 'n01440764'
labels_list = []  # Get class label by indexing with the integer label
with open("../../Data/LOC_synset_mapping.txt", "r") as f:
    for line in f:
        # Each line looks like: "n01440764 tench, Tinca tinca"
        label_id, label = line.rstrip('\n').split(' ', maxsplit=1)
        labels_dict[label_id] = label
        labels_list.append(label)
```
Model Construction

Create your model. Pre-trained models are covered at the end of the post.

Side note: one change from the original AlexNet is that we use BatchNorm instead of ‘brightness normalization’ (the local response normalization layers of the original paper).

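```python
import torch
import torch.nn as nn

class AlexNet(nn.Module):
    def __init__(self, num_classes=1000):
        super(AlexNet, self).__init__()
        # Conv layers sit at indices 0, 4, 8, 10, 12 (used for weight init below).
        self.conv_base = nn.Sequential(
            nn.Conv2d(3, 96, kernel_size=11, stride=4, padding=2, bias=False),
            nn.BatchNorm2d(96),
            nn.ReLU(inplace=True),
            nn.MaxPool2d(kernel_size=3, stride=2),

            nn.Conv2d(96, 256, kernel_size=5, stride=1, padding=2, bias=False),
            nn.BatchNorm2d(256),
            nn.ReLU(inplace=True),
            nn.MaxPool2d(kernel_size=3, stride=2),

            nn.Conv2d(256, 384, kernel_size=3, stride=1, padding=1),
            nn.ReLU(inplace=True),

            nn.Conv2d(384, 384, kernel_size=3, stride=1, padding=1),
            nn.ReLU(inplace=True),

            nn.Conv2d(384, 256, kernel_size=3, stride=1, padding=1),
            nn.ReLU(inplace=True),
            nn.MaxPool2d(kernel_size=3, stride=2),
        )
        # Linear layers sit at indices 1, 4, 6.
        self.fc_base = nn.Sequential(
            nn.Dropout(),
            nn.Linear(256*6*6, 4096),
            nn.ReLU(inplace=True),

            nn.Dropout(),
            nn.Linear(4096, 4096),
            nn.ReLU(inplace=True),

            nn.Linear(4096, num_classes),
        )

    def forward(self, x):
        x = self.conv_base(x)
        x = x.view(x.size(0), 256*6*6)  # flatten to (batch, 9216)
        x = self.fc_base(x)
        return x  # raw logits; CrossEntropyLoss applies the softmax itself
```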

Note how the model is divided into conv_base and fc_base. Splitting a model into smaller sub-models like this is a general scheme you will see in most implementations. For now we access layers by their 0-based index, but in future posts I will use named layers, which makes weight initialization easier.
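The hard-coded 256*6*6 in fc_base and forward comes from tracking the spatial size through conv_base. For a Conv2d or MaxPool2d layer with kernel size k, stride s and padding p (dilation 1), an input of width n gives an output of width floor((n + 2p - k)/s) + 1. A pure-Python sketch of that arithmetic for the layers above, assuming a 224x224 input; BatchNorm, ReLU and Dropout do not change the size, so they are left out:

```python
def out_size(n, kernel, stride=1, padding=0):
    """Output width of a Conv2d/MaxPool2d layer (dilation 1, floor mode)."""
    return (n + 2 * padding - kernel) // stride + 1


# (kernel, stride, padding) of every size-changing layer in conv_base, in order.
layers = [
    (11, 4, 2),  # conv1
    (3, 2, 0),   # maxpool
    (5, 1, 2),   # conv2
    (3, 2, 0),   # maxpool
    (3, 1, 1),   # conv3
    (3, 1, 1),   # conv4
    (3, 1, 1),   # conv5
    (3, 2, 0),   # maxpool
]

n = 224
sizes = []
for k, s, p in layers:
    n = out_size(n, k, s, p)
    sizes.append(n)
print(sizes)  # [55, 27, 27, 13, 13, 13, 13, 6]
```

So the last pooling layer outputs 256 channels of 6x6, and flattening gives 256*6*6 = 9216 features per image. Passing a dummy tensor such as torch.randn(1, 3, 224, 224) through model.conv_base is an equivalent empirical check.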

Best practices for CNN

  1. Activation function: ReLU is the default choice, but LeakyReLU is also good. Always use LeakyReLU in GANs.
  2. Weight initialization: use He initialization as the default with ReLU. PyTorch provides torch.nn.init.kaiming_uniform_ and torch.nn.init.kaiming_normal_ for this purpose.
  3. Preprocessing data: there are two common choices, scaling to [-1, 1] or standardizing with (x - mean)/std. We prefer the former when we know the different features do not relate to each other.
  4. Batch normalization: apply it before the non-linearity (ReLU). At test time, use the running averages of the mean and variance collected during training; PyTorch maintains these for you and uses them once you call model.eval(). Note: Fixup initialization, introduced in a paper under review for ICLR 2019, lets you train residual networks without BatchNorm layers.
  5. Pooling layers: apply after the non-linearity (ReLU). Different tasks call for different pooling methods; for classification, max-pooling works well.
  6. Optimizer: Adam is a good choice, and SGD with momentum and Nesterov is also good. AdamW, a variant of Adam with decoupled weight decay, was also announced recently. The choice of optimizer comes down to experimentation and the task at hand, so look at benchmarks comparing different optimizers as a reference.
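Several of these guidelines can be combined into one reusable block: a convolution without bias (BatchNorm's mean subtraction would cancel it, and BatchNorm has its own learnable shift), BatchNorm before ReLU, max-pooling after ReLU, and He initialization for the convolution weights. A minimal sketch; conv_bn_relu_pool is a hypothetical helper, not part of PyTorch.

```python
import torch
import torch.nn as nn


def conv_bn_relu_pool(in_ch, out_ch, kernel_size, stride=1, padding=0, pool=True):
    """Conv -> BatchNorm -> ReLU (-> MaxPool), following the guidelines above."""
    conv = nn.Conv2d(in_ch, out_ch, kernel_size, stride=stride,
                     padding=padding, bias=False)  # BN provides the shift
    # He initialization, matched to the ReLU that follows.
    nn.init.kaiming_normal_(conv.weight, nonlinearity='relu')
    layers = [conv, nn.BatchNorm2d(out_ch), nn.ReLU(inplace=True)]
    if pool:
        layers.append(nn.MaxPool2d(kernel_size=3, stride=2))
    return nn.Sequential(*layers)


# The first two stages of the AlexNet above, rebuilt with the helper.
stem = nn.Sequential(
    conv_bn_relu_pool(3, 96, kernel_size=11, stride=4, padding=2),
    conv_bn_relu_pool(96, 256, kernel_size=5, padding=2),
)
out = stem(torch.randn(2, 3, 224, 224))
print(out.shape)  # torch.Size([2, 256, 13, 13])
```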

Weight Initialization

Do not use this index-based method as a default; once the layers are named, this becomes much easier.

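```python
model = AlexNet()
conv_list = [0, 4, 8, 10, 12]  # indices of the Conv2d layers in conv_base
fc_list = [1, 4, 6]            # indices of the Linear layers in fc_base
for i in conv_list:
    nn.init.kaiming_normal_(model.conv_base[i].weight)
for i in fc_list:
    nn.init.kaiming_normal_(model.fc_base[i].weight)
```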
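A less brittle alternative to hard-coded indices is to walk every module and select layers by type, so inserting or removing a layer never shifts anything. This is a sketch of that idea, not the notebook's code; init_weights is a hypothetical helper, and zeroing the biases is my own choice.

```python
import torch.nn as nn


def init_weights(model):
    """He-initialize every Conv2d and Linear layer in `model`, wherever it sits."""
    for m in model.modules():
        if isinstance(m, (nn.Conv2d, nn.Linear)):
            nn.init.kaiming_normal_(m.weight, nonlinearity='relu')
            if m.bias is not None:
                nn.init.zeros_(m.bias)


# Usage on a small stand-in model; the same call works on AlexNet().
toy = nn.Sequential(
    nn.Conv2d(3, 8, kernel_size=3, padding=1), nn.BatchNorm2d(8), nn.ReLU(),
    nn.Flatten(), nn.Linear(8 * 4 * 4, 10),
)
init_weights(toy)
```

Because BatchNorm2d is neither Conv2d nor Linear, it keeps PyTorch's default initialization (weight 1, bias 0).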

Create optimizers, schedulers and loss functions

```python
import torch
import torch.nn as nn
import torch.optim as optim

device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')
model = model.to(device)  # move the model first, then create the optimizer
# Cross entropy loss takes the logits directly, so we don't need to apply softmax in our CNN
criterion = nn.CrossEntropyLoss()
optimizer = optim.Adam(model.parameters(), lr=0.001, weight_decay=0.0005)
scheduler = optim.lr_scheduler.ReduceLROnPlateau(optimizer, 'max', verbose=True)
```
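The 'max' mode tells ReduceLROnPlateau that the metric passed to scheduler.step() should increase, so this scheduler is meant to be stepped once per epoch with something like validation accuracy; passing the validation loss with 'max' would cut the learning rate for the wrong reason. A small self-contained check of that behaviour, using a dummy parameter and made-up accuracies that stop improving (patience=2 here instead of the default 10, to keep it short):

```python
import torch
import torch.optim as optim

param = torch.nn.Parameter(torch.zeros(1))
opt = optim.Adam([param], lr=0.001)
sched = optim.lr_scheduler.ReduceLROnPlateau(opt, mode='max', factor=0.1, patience=2)

# Made-up validation accuracies: improves once, then plateaus.
lrs = []
for val_acc in [0.40, 0.50, 0.50, 0.50, 0.50]:
    sched.step(val_acc)  # call once per epoch, after validation
    lrs.append(opt.param_groups[0]['lr'])
print(lrs)
```

The learning rate stays at 0.001 for the first four epochs and drops by the factor of 0.1 on the fifth, the third epoch in a row without improvement, which exceeds patience=2.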

PyTorch specific discussion

  • You have to specify the padding yourself. Check this thread for discussion on this topic. (Newer PyTorch versions accept padding='same' in nn.Conv2d, but only for stride 1.)
  • Create the optimizer after moving the model to the GPU.
  • Whether to add a softmax layer to your model depends on your loss function. With CrossEntropyLoss we do not add one, because the loss combines log-softmax and negative log-likelihood itself.
  • Do not forget to zero the gradients (optimizer.zero_grad()) on every iteration; PyTorch accumulates them across backward passes.
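Putting these points together, one pass of training or validation might look like the sketch below. run_epoch is a hypothetical helper, not the notebook's code; it assumes model, criterion, optimizer and device are set up as above and returns the average loss and accuracy.

```python
import torch


def run_epoch(model, loader, criterion, device, optimizer=None):
    """Train for one pass over `loader` if an optimizer is given, else evaluate."""
    training = optimizer is not None
    model.train(training)  # BatchNorm/Dropout behave differently in eval mode
    total_loss, correct, seen = 0.0, 0, 0
    with torch.set_grad_enabled(training):
        for images, labels in loader:
            images, labels = images.to(device), labels.to(device)
            logits = model(images)            # raw scores, no softmax layer
            loss = criterion(logits, labels)  # CrossEntropyLoss does log-softmax
            if training:
                optimizer.zero_grad()         # gradients accumulate otherwise
                loss.backward()
                optimizer.step()
            total_loss += loss.item() * labels.size(0)
            correct += (logits.argmax(dim=1) == labels).sum().item()
            seen += labels.size(0)
    return total_loss / seen, correct / seen


# Per epoch, with the objects defined earlier in the post:
# train_loss, train_acc = run_epoch(model, data_loader['train'], criterion, device, optimizer)
# val_loss, val_acc = run_epoch(model, data_loader['val'], criterion, device)
# scheduler.step(val_acc)  # mode='max' expects an accuracy-like metric
```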

How to check my model is correct?

Check 1: The first technique is to overfit a mini-batch. If the model cannot overfit a small mini-batch, then either it lacks the capacity to fit the dataset or there is a bug in your implementation. Below I overfit a batch of 32 inputs.

Things to remember

Turn off regularization such as Dropout and BatchNorm, although the results don't vary much if you leave them on. Don't use L2 regularization, i.e. set weight_decay=0 in the optimizer.

Remember to reinitialize your weights before the actual training run.
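The overfitting check can be written as a small loop that trains repeatedly on one fixed batch and records the loss. A sketch under assumptions: overfit_one_batch is a hypothetical helper, the tiny MLP and random data below only stand in for AlexNet and a real ImageNet batch, and weight_decay=0 follows the advice above.

```python
import torch
import torch.nn as nn
import torch.optim as optim


def overfit_one_batch(model, images, labels, steps=200, lr=1e-3):
    """Train on a single fixed batch; returns the loss at every step."""
    criterion = nn.CrossEntropyLoss()
    optimizer = optim.Adam(model.parameters(), lr=lr, weight_decay=0)  # no L2
    model.train()
    losses = []
    for _ in range(steps):
        optimizer.zero_grad()
        loss = criterion(model(images), labels)
        loss.backward()
        optimizer.step()
        losses.append(loss.item())
    return losses


# Stand-in batch of 32 random flattened "images" with 10 classes.
torch.manual_seed(0)
images = torch.randn(32, 20)
labels = torch.randint(0, 10, (32,))
toy_model = nn.Sequential(nn.Linear(20, 64), nn.ReLU(), nn.Linear(64, 10))
losses = overfit_one_batch(toy_model, images, labels, steps=500, lr=1e-2)
print(f"first loss {losses[0]:.3f}, last loss {losses[-1]:.4f}")
```

With AlexNet you would instead pass one real batch, e.g. from next(iter(data_loader['train'])) moved to the device, and expect the loss to approach zero. Since the loop changes the model's weights, rebuild or reinitialize the model before the real training run.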

Check 2: Double-check the loss value. With C classes and a freshly initialized model, the first-iteration cross-entropy loss should be roughly ln(C): about 2.3 for 10 classes, or about 0.69 for binary classification. A loss in that range is fine, but if you are getting a loss of 100, something is wrong.

In the above figure, you can see we got a loss value of 10.85, which is ok considering that we have 1000 classes: it is of the same order as ln(1000) ≈ 6.9, the loss of a uniform prediction. If you get strange loss values, check for sign errors.
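The reference point for this check is that a model predicting the uniform distribution over C classes has cross-entropy loss ln(C): about 0.69 for 2 classes, 2.30 for 10 and 6.91 for 1000. A quick way to confirm this with PyTorch, using all-zero logits to stand in for an untrained model's roughly uniform output (expected_initial_loss is a hypothetical helper name):

```python
import math

import torch
import torch.nn.functional as F


def expected_initial_loss(num_classes):
    """Cross-entropy of a uniform prediction over `num_classes` classes."""
    return math.log(num_classes)


for c in (2, 10, 1000):
    logits = torch.zeros(8, c)           # identical scores -> uniform softmax
    targets = torch.randint(0, c, (8,))  # any targets give the same loss here
    loss = F.cross_entropy(logits, targets).item()
    print(c, round(expected_initial_loss(c), 3), round(loss, 3))
```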

Using Pre-Trained Models

Since we are using AlexNet, we download a pre-trained AlexNet from torchvision.models and fine-tune it on the CIFAR-10 dataset.

Warning: this is just for fun. Upscaling images from 32x32 to 224x224 is not recommended.

Refer to this script for how I processed the CIFAR data after downloading it from the official site. You can also download CIFAR directly through torchvision.datasets.

PyTorch has a very good tutorial on fine-tuning torchvision models. I give a short implementation, with the rest of the code in the Jupyter notebook.


We discussed how to create data loaders and how to plot images to check that the data loaders are correct. Then we implemented AlexNet in PyTorch and discussed some important choices when working with CNNs, such as activation functions, pooling functions and weight initialization (code for He initialization was also shared). Next, we discussed some checks, such as overfitting a small batch and manually checking the loss value. We concluded by using a pre-trained AlexNet to classify CIFAR-10 images.


[1] Helped in preprocessing of the Imagenet dataset

[2] Many code references were taken from this tutorial.

[3] AlexNet paper

Deep Learning Researcher with interest in Computer Vision and Natural Language Processing