Image Classification with PyTorch

Image classification is a common task in computer vision that involves assigning labels or categories to images. It has many real-world applications such as facial recognition, photo organization, medical imaging analysis, self-driving cars, and more. In this post, we will walk through how to build and train an image classifier using the PyTorch deep learning framework.

Setting Up the Environment

We will first import PyTorch and the related libraries we need:

import torch
import torch.nn as nn
import torch.nn.functional as F
import torch.optim as optim
from torch.utils.data import Dataset, DataLoader
import torchvision
from torchvision import datasets, models, transforms

We will use the MNIST dataset which contains 70,000 grayscale images of handwritten digits (0-9).

[Figure: sample MNIST handwritten digit images]

The images are 28x28 pixels. We will use PyTorch's built-in MNIST dataset and transform the images to tensors:

train_dataset = datasets.MNIST(root='data', 
                               train=True, 
                               transform=transforms.ToTensor(),
                               download=True)

test_dataset = datasets.MNIST(root='data', 
                              train=False, 
                              transform=transforms.ToTensor())

We define train and test data loaders to generate batches of data during training and evaluation:

batch_size = 32

train_loader = DataLoader(dataset=train_dataset, 
                          batch_size=batch_size,
                          shuffle=True)
                          
test_loader = DataLoader(dataset=test_dataset,
                         batch_size=batch_size,
                         shuffle=False)
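
If you want to sanity-check the pipeline, you can pull a single batch from the loader and inspect its shape; each batch should contain 32 images of shape 1x28x28 (channels, height, width) and 32 labels. This is just a quick optional check:

images, labels = next(iter(train_loader))
print(images.shape)   # torch.Size([32, 1, 28, 28])
print(labels.shape)   # torch.Size([32])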

Building the Model

For the model architecture, we will use a simple convolutional neural network (CNN) with two convolutional layers and two fully connected layers. PyTorch provides pre-defined CNN layers that make this step straightforward:

# CNN Model
class Net(nn.Module):
    def __init__(self):
        super(Net, self).__init__()
        
        self.conv1 = nn.Conv2d(1, 32, 5)   # 1 input channel, 32 filters, 5x5 kernel
        self.conv2 = nn.Conv2d(32, 64, 5)  # 32 -> 64 channels, 5x5 kernel
        self.fc1 = nn.Linear(1024, 512)    # 64 channels * 4 * 4 spatial positions = 1024
        self.fc2 = nn.Linear(512, 10)      # 10 output classes (digits 0-9)

    def forward(self, x):
        x = F.relu(F.max_pool2d(self.conv1(x), 2))
        x = F.relu(F.max_pool2d(self.conv2(x), 2))
        x = x.view(-1, 1024)               # flatten the feature maps
        x = F.relu(self.fc1(x))
        x = self.fc2(x)
        return F.log_softmax(x, dim=1)     # log-probabilities over the 10 classes

model = Net()

This CNN applies convolution and pooling operations to extract features from the input images. The output is flattened and passed through fully connected layers to produce log-probabilities for the 10 digit classes.
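
To see where the 1024 input features of fc1 come from, we can trace a dummy batch through the convolution and pooling stages: each 5x5 convolution shrinks the spatial size by 4 and each pooling step halves it, so 28x28 becomes 24x24, 12x12, 8x8, and finally 4x4, giving 64 channels * 4 * 4 = 1024 features. A quick optional check:

dummy = torch.zeros(1, 1, 28, 28)           # [batch, channels, height, width]
out = F.max_pool2d(model.conv1(dummy), 2)   # -> [1, 32, 12, 12]
out = F.max_pool2d(model.conv2(out), 2)     # -> [1, 64, 4, 4]
print(out.view(1, -1).shape)                # torch.Size([1, 1024])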

Because the model already applies log_softmax, we use the negative log-likelihood loss (NLLLoss), which together with log_softmax is equivalent to cross-entropy loss, along with an SGD optimizer:

criterion = nn.NLLLoss()
optimizer = optim.SGD(model.parameters(), lr=0.01)
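
If the NLLLoss choice looks surprising, note that log_softmax followed by NLLLoss computes the same value as nn.CrossEntropyLoss applied to raw scores. A small standalone sketch (the tensors here are made up purely for illustration):

scores = torch.randn(3, 10)            # raw scores for 3 samples, 10 classes
targets = torch.tensor([2, 7, 0])      # arbitrary target classes

loss_nll = nn.NLLLoss()(F.log_softmax(scores, dim=1), targets)
loss_ce = nn.CrossEntropyLoss()(scores, targets)
print(torch.allclose(loss_nll, loss_ce))   # True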

Training the Model

We train the model by iterating through the training dataset in batches:

num_epochs = 5
for epoch in range(num_epochs):
    for images, labels in train_loader:
        # Forward pass and loss (the CNN expects [N, 1, 28, 28] inputs,
        # so we pass the image batch through unflattened)
        outputs = model(images)
        loss = criterion(outputs, labels)

        # Backward pass and update
        optimizer.zero_grad()
        loss.backward()
        optimizer.step()

    print('Epoch [{}/{}], Loss: {:.4f}'
          .format(epoch+1, num_epochs, loss.item()))

For each batch, we pass the inputs through the model to get predictions, compute the loss against the ground-truth labels, backpropagate it, and update the model parameters with the optimizer.

Evaluating the Model

After training, we check the model's performance on the test set:

model.eval()  # switch to evaluation mode
with torch.no_grad():
    correct = 0
    total = 0
    for images, labels in test_loader:
        outputs = model(images)
        _, predicted = torch.max(outputs, 1)
        total += labels.size(0)
        correct += (predicted == labels).sum().item()

    print('Accuracy of the model on the test images: {} %'.format(100 * correct / total))

This iterates through the test set, compares predictions to labels, and calculates the model's accuracy.
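
To classify a single image with the trained model, add a batch dimension and take the class with the highest log-probability. A minimal sketch (index 0 is an arbitrary choice):

image, label = test_dataset[0]               # image: [1, 28, 28] tensor
with torch.no_grad():
    log_probs = model(image.unsqueeze(0))    # add batch dimension -> [1, 1, 28, 28]
    prediction = log_probs.argmax(dim=1).item()
print('Predicted: {}, Actual: {}'.format(prediction, label))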

And that's it! We have built and trained an image classifier in PyTorch. This can serve as a template to build more complex CNNs for image tasks.