Writing AlexNet from Scratch in PyTorch

Updated on April 18, 2025
Author: Nouman

Introduction

This article is part of our ongoing series on implementing popular convolutional neural networks from scratch using PyTorch. In the previous installment, Writing LeNet5 from Scratch in PyTorch, we explored how to build one of the earliest CNN architectures. In this post, we move on to a more advanced and historically significant model: AlexNet, a landmark architecture in computer vision.

We’ll begin by breaking down AlexNet’s architecture to understand its key components and innovations. Next, we’ll load the CIFAR-10 dataset and apply some essential preprocessing steps. Then comes the hands-on part: building AlexNet from scratch in PyTorch and training it on the processed dataset. Finally, we’ll evaluate the trained model on unseen test data.

Prerequisites

Familiarity with neural networks will help you follow this article: the different types of layers (input, hidden, and output), activation functions, optimization algorithms (variants of gradient descent), loss functions, and so on. You should also be comfortable with Python syntax and the PyTorch library to follow the code snippets.

An understanding of CNNs is essential. This includes knowledge of convolutional layers, pooling layers, and their role in extracting features from input data. Understanding concepts like stride, padding, and the impact of kernel/filter size is also beneficial.

AlexNet

AlexNet is a deep convolutional neural network introduced in 2012 by Alex Krizhevsky, Ilya Sutskever, and Geoffrey Hinton. The paper reports results on the ImageNet LSVRC-2010 data, and a variant of the model won the ILSVRC-2012 competition, delivering groundbreaking performance and setting a new benchmark in image classification. For an in-depth look at the model, see the original research paper, "ImageNet Classification with Deep Convolutional Neural Networks".

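Let’s review the key takeaways from the AlexNet paper. First, AlexNet operated on 3-channel RGB images of size 224x224x3. (The paper states 224x224, but the reported 55x55 output of the first layer only works out for a 227x227 input, since (227 - 11)/4 + 1 = 55; this is why the code in this tutorial resizes images to 227x227.) ReLU was used as the activation function throughout, and max pooling was used for subsampling. The convolution kernels were 11x11, 5x5, or 3x3, and the max-pooling windows were 3x3 with a stride of 2 (overlapping pooling). AlexNet classified images into 1000 classes, and its training was split across two GPUs.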

Dataset

Let’s start by loading and then pre-processing the data. For our purposes, we will be using the CIFAR-10 dataset. The dataset consists of 60000 32x32 colour images in 10 classes, with 6000 images per class. There are 50000 training images and 10000 test images.

Here are the classes in the dataset, as well as 10 random sample images from each:

[Figure: the 10 CIFAR-10 classes, with 10 random sample images from each]

The classes are completely mutually exclusive, with no overlap between automobiles and trucks. “Automobile” encompasses sedans, SUVs, and similar vehicles, while “Truck” refers solely to large trucks, excluding pickup trucks.

Importing the Libraries

Let’s start by importing the required libraries and defining a device variable, so that the notebook trains the model on a GPU when one is available.

    import numpy as np
    import torch
    import torch.nn as nn
    from torchvision import datasets
    from torchvision import transforms
    from torch.utils.data.sampler import SubsetRandomSampler

    # Device configuration
    device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')

Loading the Dataset

Using torchvision (a helper library for computer vision tasks), we will load our dataset. It provides helper functions that make preprocessing straightforward. Let’s define the functions get_train_valid_loader and get_test_loader, and then call them to load and preprocess our CIFAR-10 data:

    def get_train_valid_loader(data_dir,
                               batch_size,
                               augment,
                               random_seed,
                               valid_size=0.1,
                               shuffle=True):
        normalize = transforms.Normalize(
            mean=[0.4914, 0.4822, 0.4465],
            std=[0.2023, 0.1994, 0.2010],
        )

        # define transforms
        valid_transform = transforms.Compose([
                transforms.Resize((227,227)),
                transforms.ToTensor(),
                normalize,
        ])
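        if augment:
            # Augment at CIFAR-10's native 32x32 resolution, then resize to
            # the 227x227 input size the model expects
            train_transform = transforms.Compose([
                transforms.RandomCrop(32, padding=4),
                transforms.RandomHorizontalFlip(),
                transforms.Resize((227,227)),
                transforms.ToTensor(),
                normalize,
            ])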
        else:
            train_transform = transforms.Compose([
                transforms.Resize((227,227)),
                transforms.ToTensor(),
                normalize,
            ])

        # load the dataset
        train_dataset = datasets.CIFAR10(
            root=data_dir, train=True,
            download=True, transform=train_transform,
        )

        valid_dataset = datasets.CIFAR10(
            root=data_dir, train=True,
            download=True, transform=valid_transform,
        )

        num_train = len(train_dataset)
        indices = list(range(num_train))
        split = int(np.floor(valid_size * num_train))

        if shuffle:
            np.random.seed(random_seed)
            np.random.shuffle(indices)

        train_idx, valid_idx = indices[split:], indices[:split]
        train_sampler = SubsetRandomSampler(train_idx)
        valid_sampler = SubsetRandomSampler(valid_idx)

        train_loader = torch.utils.data.DataLoader(
            train_dataset, batch_size=batch_size, sampler=train_sampler)

        valid_loader = torch.utils.data.DataLoader(
            valid_dataset, batch_size=batch_size, sampler=valid_sampler)

        return (train_loader, valid_loader)


    def get_test_loader(data_dir,
                        batch_size,
                        shuffle=True):
        # Use the same CIFAR-10 statistics as the training/validation transforms
        normalize = transforms.Normalize(
            mean=[0.4914, 0.4822, 0.4465],
            std=[0.2023, 0.1994, 0.2010],
        )

        # define transform
        transform = transforms.Compose([
            transforms.Resize((227,227)),
            transforms.ToTensor(),
            normalize,
        ])

        dataset = datasets.CIFAR10(
            root=data_dir, train=False,
            download=True, transform=transform,
        )

        data_loader = torch.utils.data.DataLoader(
            dataset, batch_size=batch_size, shuffle=shuffle
        )

        return data_loader

    # CIFAR10 dataset
    train_loader, valid_loader = get_train_valid_loader(data_dir='./data', batch_size=64,
                                                        augment=False, random_seed=1)

    test_loader = get_test_loader(data_dir='./data', batch_size=64)

Let’s break down the code:

  • We define two functions, get_train_valid_loader and get_test_loader, to load the train/validation and test sets, respectively.
  • We start by defining the variable normalize with the mean and standard deviation of each channel (red, green, and blue) in the dataset. These can be calculated manually, but are also available online since CIFAR-10 is quite popular.
  • The transforms resize the 32x32 CIFAR-10 images to 227x227, the input size our AlexNet implementation expects.
  • For the training dataset, we add the option to augment the data with random crops and horizontal flips. Because these random transforms are re-sampled every time an image is loaded, the model sees slightly different versions of each image across epochs, which makes training more robust. Note: augmentation is applied only to the training subset, not to the validation and test subsets, since those are used only for evaluation.
  • We split the 50,000 training images into training and validation sets in a 90:10 ratio (45,000 and 5,000 images), shuffling the indices with a fixed seed and using SubsetRandomSampler to draw each subset.
  • We set the batch size and shuffle the data during loading so that each batch contains a mix of labels, which keeps gradient updates from being biased toward whichever classes happen to appear together.
  • Finally, we use data loaders. Their benefit may be minimal for small datasets like CIFAR-10, but without them performance can suffer significantly on larger datasets, and they are generally regarded as best practice. Data loaders let us iterate through the data in batches, loading it incrementally rather than all at once at the beginning, which would fill up RAM.
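The per-channel statistics mentioned above can be computed in a single pass over the data. Below is a minimal sketch, assuming a loader that yields (images, labels) batches with pixel values already scaled to [0, 1] by transforms.ToTensor() and no Normalize step applied yet. The helper name channel_mean_std is our own, and random tensors stand in for CIFAR-10 so the example runs without a download. Values computed this way may not match every figure quoted online exactly, since sources differ in how they aggregate the standard deviation.

```python
import torch
from torch.utils.data import DataLoader, TensorDataset

def channel_mean_std(loader):
    """Per-channel mean and (population) std over every pixel in `loader`.

    Assumes batches of images shaped (N, C, H, W) with values in [0, 1].
    """
    n_pixels = 0
    channel_sum = None
    channel_sq_sum = None
    for images, _ in loader:
        images = images.double()  # float64 keeps the running sums accurate
        if channel_sum is None:
            channel_sum = torch.zeros(images.size(1), dtype=torch.float64)
            channel_sq_sum = torch.zeros(images.size(1), dtype=torch.float64)
        channel_sum += images.sum(dim=(0, 2, 3))
        channel_sq_sum += (images ** 2).sum(dim=(0, 2, 3))
        n_pixels += images.size(0) * images.size(2) * images.size(3)
    mean = channel_sum / n_pixels
    std = (channel_sq_sum / n_pixels - mean ** 2).sqrt()  # sqrt(E[x^2] - E[x]^2)
    return mean, std

# Random stand-in data; with CIFAR-10 you would pass a loader over
# datasets.CIFAR10(..., transform=transforms.ToTensor()) instead.
torch.manual_seed(0)
fake_images = torch.rand(100, 3, 32, 32)
loader = DataLoader(TensorDataset(fake_images, torch.zeros(100)), batch_size=16)
mean, std = channel_mean_std(loader)
print(mean, std)
```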

AlexNet from Scratch

Let’s start with the code:

    class AlexNet(nn.Module):
        def __init__(self, num_classes=10):
            super(AlexNet, self).__init__()
            self.layer1 = nn.Sequential(
                nn.Conv2d(3, 96, kernel_size=11, stride=4, padding=0),
                nn.BatchNorm2d(96),
                nn.ReLU(),
                nn.MaxPool2d(kernel_size = 3, stride = 2))
            self.layer2 = nn.Sequential(
                nn.Conv2d(96, 256, kernel_size=5, stride=1, padding=2),
                nn.BatchNorm2d(256),
                nn.ReLU(),
                nn.MaxPool2d(kernel_size = 3, stride = 2))
            self.layer3 = nn.Sequential(
                nn.Conv2d(256, 384, kernel_size=3, stride=1, padding=1),
                nn.BatchNorm2d(384),
                nn.ReLU())
            self.layer4 = nn.Sequential(
                nn.Conv2d(384, 384, kernel_size=3, stride=1, padding=1),
                nn.BatchNorm2d(384),
                nn.ReLU())
            self.layer5 = nn.Sequential(
                nn.Conv2d(384, 256, kernel_size=3, stride=1, padding=1),
                nn.BatchNorm2d(256),
                nn.ReLU(),
                nn.MaxPool2d(kernel_size = 3, stride = 2))
            self.fc = nn.Sequential(
                nn.Dropout(0.5),
                nn.Linear(9216, 4096),
                nn.ReLU())
            self.fc1 = nn.Sequential(
                nn.Dropout(0.5),
                nn.Linear(4096, 4096),
                nn.ReLU())
            self.fc2= nn.Sequential(
                nn.Linear(4096, num_classes))

        def forward(self, x):
            out = self.layer1(x)
            out = self.layer2(out)
            out = self.layer3(out)
            out = self.layer4(out)
            out = self.layer5(out)
            # Flatten the 256x6x6 feature maps into a 9216-dim vector per image
            out = out.reshape(out.size(0), -1)
            out = self.fc(out)
            out = self.fc1(out)
            out = self.fc2(out)
            return out

Defining the AlexNet Model

Let’s dive into how the above code works:

  • The first step to defining any neural network (whether a CNN or not) in PyTorch is to define a class that inherits from nn.Module, which provides many of the methods we will need later.
  • After that, there are two main steps. The first is to initialize the layers of our CNN in the __init__ method. The second is to define the order in which these layers process the image, which is done in the forward method.
  • We define the convolutional layers with nn.Conv2d, specifying the number of input and output channels, the kernel size, stride, and padding, and we add max pooling with nn.MaxPool2d. One convenience of PyTorch is that nn.Sequential lets us group a convolutional layer, its normalization, its activation function, and max pooling into a single block.
  • Unlike the original paper, which applied Local Response Normalization after the first two convolutional layers, this implementation applies batch normalization (nn.BatchNorm2d) after every convolution. Batch normalization was introduced in 2015, after AlexNet.
  • We define the fully connected layers using linear layers (nn.Linear), dropout (nn.Dropout), and the ReLU activation function (nn.ReLU), again grouping them with nn.Sequential.
  • Our final layer outputs 10 values (logits), one for each CIFAR-10 class.
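To see where the 9216 input features of the first fully connected layer come from, we can trace the spatial size through each convolution and pooling layer with the standard output-size formula, floor((size + 2 * padding - kernel) / stride) + 1. The pure-Python sketch below copies the kernel, stride, and padding values from the AlexNet class above; the helper names are illustrative, and the 256 in the final product is the number of output channels of layer5.

```python
def conv_out(size, kernel, stride=1, padding=0):
    """Output height/width of Conv2d or MaxPool2d (dilation 1, floor rounding)."""
    return (size + 2 * padding - kernel) // stride + 1

# (name, kernel, stride, padding) for each layer that affects the spatial size,
# taken from the AlexNet class above
LAYERS = [
    ("layer1 conv", 11, 4, 0), ("layer1 pool", 3, 2, 0),
    ("layer2 conv", 5, 1, 2), ("layer2 pool", 3, 2, 0),
    ("layer3 conv", 3, 1, 1),
    ("layer4 conv", 3, 1, 1),
    ("layer5 conv", 3, 1, 1), ("layer5 pool", 3, 2, 0),
]

def flattened_features(side, channels=256, verbose=False):
    """Length of the vector produced by the reshape in forward() for a side x side input."""
    for name, kernel, stride, padding in LAYERS:
        side = conv_out(side, kernel, stride, padding)
        if verbose:
            print(f"{name}: {side}x{side}")
    return channels * side * side

print(flattened_features(227, verbose=True))  # 9216, matching nn.Linear(9216, 4096)
print(flattened_features(224))                # 6400, which would not fit nn.Linear(9216, ...)
```

A 227x227 input shrinks to 55, 27, 27, 13, 13, 13, 13, and finally 6 pixels per side, while a 224x224 input ends at 5x5, which is why the paper's stated input size would not work with this model unchanged.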

Setting Hyperparameters

Before training, we set a few hyperparameters (number of epochs, batch size, and learning rate) and choose the loss function and optimizer.

    num_classes = 10
    num_epochs = 20
    batch_size = 64
    learning_rate = 0.005

    model = AlexNet(num_classes).to(device)

    # Loss and optimizer
    criterion = nn.CrossEntropyLoss()
    optimizer = torch.optim.SGD(model.parameters(), lr=learning_rate,
                                weight_decay=0.005, momentum=0.9)

    # Number of batches per epoch, used when reporting training progress
    total_step = len(train_loader)

We start by defining simple hyperparameters (epochs, batch size, and learning rate) and initializing our model using the number of classes as an argument, which in this case is 10. We then transfer the model to the appropriate device (CPU or GPU).

Next, we define the cost function as cross-entropy loss and choose stochastic gradient descent (SGD) with a momentum of 0.9 and a weight decay of 0.005 as the optimizer. Other options are available for both, but these choices generally yield good results with this model and data. Lastly, we define a variable called total_step, the number of batches per epoch, which we use to report progress during training.
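To make the optimizer settings concrete, the sketch below reproduces a few torch.optim.SGD steps by hand on a toy quadratic loss and checks that both agree. It assumes PyTorch's default SGD formulation (dampening=0, nesterov=False), in which weight decay is added to the gradient as an L2 penalty before the result is accumulated into the momentum buffer. The helper names manual_sgd_step and compare are hypothetical.

```python
import torch

def manual_sgd_step(param, grad, buf, lr, momentum, weight_decay):
    """One SGD update with momentum and (coupled) weight decay, no dampening."""
    g = grad + weight_decay * param                    # L2 penalty folded into the gradient
    buf = g.clone() if buf is None else momentum * buf + g
    return param - lr * buf, buf

def compare(steps=5, lr=0.005, momentum=0.9, weight_decay=0.005):
    torch.manual_seed(0)
    w = torch.randn(4, requires_grad=True)
    w_manual = w.detach().clone()
    opt = torch.optim.SGD([w], lr=lr, momentum=momentum, weight_decay=weight_decay)
    buf = None
    for _ in range(steps):
        opt.zero_grad()
        loss = (w ** 2).sum()                          # toy loss with gradient 2 * w
        loss.backward()
        opt.step()
        w_manual, buf = manual_sgd_step(w_manual, 2 * w_manual, buf,
                                        lr, momentum, weight_decay)
    return w.detach(), w_manual

torch_w, manual_w = compare()
print(torch.allclose(torch_w, manual_w))
```

The momentum buffer keeps a decaying sum of past gradients, which smooths the noisy per-batch updates, while the weight-decay term pulls every weight slightly toward zero at each step.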

Training

We are ready to train our model at this point:

    total_step = len(train_loader)

    for epoch in range(num_epochs):
        model.train()  # enable dropout; BatchNorm uses batch statistics
        for i, (images, labels) in enumerate(train_loader):
            # Move tensors to the configured device
            images = images.to(device)
            labels = labels.to(device)

            # Forward pass
            outputs = model(images)
            loss = criterion(outputs, labels)

            # Backward and optimize
            optimizer.zero_grad()
            loss.backward()
            optimizer.step()

        # Report the loss of the last batch in this epoch
        print('Epoch [{}/{}], Step [{}/{}], Loss: {:.4f}'
              .format(epoch+1, num_epochs, i+1, total_step, loss.item()))

        # Validation
        model.eval()  # disable dropout; BatchNorm uses its running statistics
        with torch.no_grad():
            correct = 0
            total = 0
            for images, labels in valid_loader:
                images = images.to(device)
                labels = labels.to(device)
                outputs = model(images)
                _, predicted = torch.max(outputs, 1)
                total += labels.size(0)
                correct += (predicted == labels).sum().item()
                del images, labels, outputs

            print('Accuracy of the network on the {} validation images: {} %'.format(total, 100 * correct / total))

Let’s see what the code does:

  • We start by iterating through the number of epochs, and then the batches in our training data.
  • We move the images and the labels to the device we are using, i.e., GPU or CPU.
  • In the forward pass, we make predictions using our model and calculate the loss based on those predictions and our actual labels.
  • Next, we do the backward pass and update the model's weights:
  • We first reset the gradients with optimizer.zero_grad(), because PyTorch accumulates gradients across backward() calls by default.
  • Then, we compute the new gradients with loss.backward().
  • Finally, we update the weights with optimizer.step().
  • At the end of every epoch, we also use the validation set to calculate the model's accuracy. We don't need gradients here, so we wrap the loop in torch.no_grad(), which skips gradient tracking and makes evaluation faster and lighter on memory.
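The reason for calling optimizer.zero_grad() is easy to see on a single scalar parameter: without a reset, a second backward() call adds its gradient to the one already stored. A minimal sketch, with a toy parameter w and an SGD optimizer used here only for its zero_grad() method:

```python
import torch

w = torch.tensor(1.0, requires_grad=True)
opt = torch.optim.SGD([w], lr=0.1)  # used here only for zero_grad()

(2 * w).backward()
after_one = w.grad.item()    # 2.0, since d(2w)/dw = 2
(2 * w).backward()           # no reset: the new gradient is added to the stored one
after_two = w.grad.item()    # 4.0

opt.zero_grad()              # clears w.grad (to None or zeros, depending on PyTorch version)
(2 * w).backward()
after_reset = w.grad.item()  # 2.0 again
print(after_one, after_two, after_reset)
```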

We can see the output as follows:

[Figure: Training Loss and Validation Accuracy]

The loss decreases with each epoch, indicating that the model is learning. Note that this loss is calculated on the training set, and a very low training loss on its own can be a sign of overfitting. That is why we also track accuracy on the validation set: it increases along with the falling training loss, which suggests the model is not overfitting so far. Next, let’s test the model to evaluate its performance.

Testing

Now, we see how our model performs on unseen data:

    model.eval()  # disable dropout; BatchNorm uses its running statistics
    with torch.no_grad():
        correct = 0
        total = 0
        for images, labels in test_loader:
            images = images.to(device)
            labels = labels.to(device)
            outputs = model(images)
            _, predicted = torch.max(outputs, 1)
            total += labels.size(0)
            correct += (predicted == labels).sum().item()
            del images, labels, outputs

        print('Accuracy of the network on the {} test images: {} %'.format(total, 100 * correct / total))

Note that this code is the same as the validation loop, except that it iterates over test_loader.
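Since the validation and test loops are identical, they can be factored into one helper. The sketch below is one way to do that; evaluate is a hypothetical name, and the example at the bottom uses nn.Identity on random data (so each input row acts as its own logits) purely to show the function running. The helper switches the model to evaluation mode, so call model.train() before resuming training.

```python
import torch
import torch.nn as nn
from torch.utils.data import DataLoader, TensorDataset

def evaluate(model, loader, device):
    """Return the classification accuracy (in %) of `model` on `loader`."""
    model.eval()  # disable dropout; BatchNorm uses its running statistics
    correct, total = 0, 0
    with torch.no_grad():
        for images, labels in loader:
            images, labels = images.to(device), labels.to(device)
            predicted = model(images).argmax(dim=1)
            total += labels.size(0)
            correct += (predicted == labels).sum().item()
    return 100 * correct / total

# Toy check: nn.Identity returns its input, so each row is its own logits.
torch.manual_seed(0)
inputs = torch.randn(50, 3)
right = DataLoader(TensorDataset(inputs, inputs.argmax(dim=1)), batch_size=8)
wrong = DataLoader(TensorDataset(inputs, (inputs.argmax(dim=1) + 1) % 3), batch_size=8)
device = torch.device("cpu")
print(evaluate(nn.Identity(), right, device))  # 100.0
print(evaluate(nn.Identity(), wrong, device))  # 0.0
```

With the tutorial's objects, the calls would be evaluate(model, valid_loader, device) at the end of each epoch and evaluate(model, test_loader, device) after training.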

Using this model and training for only 6 epochs, we get around 78.8% accuracy on the validation set.

[Figure: Testing Accuracy]

FAQ

What is AlexNet, and why is it significant? AlexNet is a deep learning model that revolutionized image classification by winning the 2012 ImageNet challenge.

Which dataset is used in this tutorial? The CIFAR-10 dataset is used here, which has 60,000 small images across 10 classes.

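Why are CIFAR-10 images resized to 227x227? This implementation's layer settings assume a 227x227 input: after the five convolutional blocks it produces 256 feature maps of size 6x6, which matches the 9216 inputs of the first fully connected layer. CIFAR-10's native 32x32 images are far too small for that.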

What is the purpose of Batch Normalization in this model? Batch Norm helps the model train faster and more reliably by stabilizing the activations.
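As a small illustration of what "stabilizing the activations" means, the sketch below feeds nn.BatchNorm2d arbitrary activations with a large mean and spread. In training mode, each output channel is normalized to roughly zero mean and unit variance using the current batch's statistics; in evaluation mode, the layer uses running estimates instead, which is why switching to model.eval() matters at test time.

```python
import torch
import torch.nn as nn

torch.manual_seed(0)
bn = nn.BatchNorm2d(3)                  # learnable scale=1 and shift=0 at initialization
x = torch.randn(8, 3, 5, 5) * 4 + 10    # activations with mean ~10 and std ~4

bn.train()
y_train = bn(x).detach()                # normalized with this batch's statistics
per_channel = y_train.transpose(0, 1).reshape(3, -1)
out_mean = per_channel.mean(dim=1)                                   # close to 0
out_var = ((per_channel - out_mean[:, None]) ** 2).mean(dim=1)       # close to 1
print(out_mean, out_var)

bn.eval()
y_eval = bn(x).detach()                 # normalized with running estimates instead
# After a single training batch, running_mean has moved only 10% of the way
# toward the batch mean (momentum=0.1), so the two outputs differ substantially.
print(torch.allclose(y_train, y_eval))  # False
```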

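How is the dataset preprocessed? The images are resized to 227x227, converted into tensors, and normalized with the per-channel mean and standard deviation; the training set can optionally be augmented with random crops and horizontal flips.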

What are the main components of the AlexNet architecture in PyTorch? It includes convolutional layers, ReLU, max pooling, batch norm, and fully connected layers.

What optimizer and loss function are used in training? The model is trained with CrossEntropyLoss and SGD with momentum (torch.optim.SGD with momentum=0.9 and weight_decay=0.005).

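What kind of accuracy can I expect from this AlexNet on CIFAR-10? With longer training and tuning, around 80-85% is a reasonable target, though results vary with the number of epochs, the hyperparameters, and whether augmentation is used.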

Can this AlexNet implementation be used on other datasets? Yes. Resize your images to 227x227 (or change the input size of the first fully connected layer to match), and set num_classes to the number of classes in the new dataset.
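If you would rather not resize every dataset to exactly 227x227, one option, used by torchvision's own AlexNet, is to insert nn.AdaptiveAvgPool2d((6, 6)) before flattening, so the first fully connected layer always receives 256 x 6 x 6 = 9216 features. The sketch below checks this with a copy of the tutorial's convolutional layers (batch normalization omitted, since it does not change tensor shapes); it is a modification, not part of the model built above, and flattened_size is a hypothetical helper.

```python
import torch
import torch.nn as nn

# Same kernel/stride/padding settings as the tutorial's AlexNet layers 1-5
features = nn.Sequential(
    nn.Conv2d(3, 96, kernel_size=11, stride=4), nn.ReLU(),
    nn.MaxPool2d(kernel_size=3, stride=2),
    nn.Conv2d(96, 256, kernel_size=5, padding=2), nn.ReLU(),
    nn.MaxPool2d(kernel_size=3, stride=2),
    nn.Conv2d(256, 384, kernel_size=3, padding=1), nn.ReLU(),
    nn.Conv2d(384, 384, kernel_size=3, padding=1), nn.ReLU(),
    nn.Conv2d(384, 256, kernel_size=3, padding=1), nn.ReLU(),
    nn.MaxPool2d(kernel_size=3, stride=2),
)
pool = nn.AdaptiveAvgPool2d((6, 6))  # always emits 256 x 6 x 6 feature maps

def flattened_size(side):
    """Number of features reaching the first nn.Linear for a side x side RGB input."""
    with torch.no_grad():
        x = torch.zeros(1, 3, side, side)
        return torch.flatten(pool(features(x)), 1).shape[1]

for side in (224, 227, 256):
    print(side, flattened_size(side))  # 9216 in every case
```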

Conclusion

To wrap things up, we started by breaking down the AlexNet architecture and the role of each layer. We then loaded and preprocessed the CIFAR-10 dataset using torchvision, making it ready for our model. After that, we built AlexNet from scratch in PyTorch and trained it for just a few epochs. Even with minimal training, the model showed promising performance on the test set, suggesting that AlexNet still holds up as a solid starting point for image classification tasks.

Overall, this tutorial gave us hands-on experience with building and training a deep neural network and showed how classic models like AlexNet can still be useful when adapted to modern datasets. With further tuning, data augmentation, or even transfer learning, you could push the model’s performance even further.

Thanks for learning with the DigitalOcean Community.
