Introduction
In this post, I will demonstrate how to use pruning to significantly reduce a model's size and latency with minimal accuracy loss. In the example, we cut the parameter count by more than 90% and make inference about 5.5x faster, while matching the original model's accuracy.
We will begin with a brief explanation of what pruning is and why it is important. Then, I’ll provide a hands-on demonstration of applying pruning to a PyTorch model.
Overview of Pruning
Neural network pruning involves removing less important weights, channels, or neurons from a neural network to make it smaller and faster. The goal is to reduce computational costs (such as latency and memory usage) without significantly affecting model accuracy.
Deep neural networks often contain a lot of redundancy. This redundancy arises because models are typically overparameterized to ensure high accuracy and generalization. During training, many parameters become co-dependent or have little impact on the final output. For example, multiple neurons may learn similar features, or certain filters may remain underutilized. This redundancy makes models robust but also bloated. Pruning helps streamline these models by eliminating parts that contribute the least to the output, resulting in a more efficient network that is easier to deploy on edge devices or in latency-sensitive applications.
There are two main types of pruning:
- Unstructured Pruning: Removes individual weights regardless of their position. While it can achieve high sparsity, it often requires specialized hardware or libraries to exploit that sparsity. Zeroing out individual weights typically does not improve latency, because standard deep learning libraries run dense matrix multiplications no matter how many entries are zero. To benefit, the model must be converted into a sparse format, which is often poorly supported by commodity hardware; such sparse representations can even be slower than dense operations due to less optimized memory access patterns and a lack of hardware acceleration. As a result, unstructured pruning offers theoretical compression but not always practical speedups unless it is carefully integrated into the deployment pipeline.
- Structured Pruning: Removes entire filters, channels, or layers, leading to real improvements in inference speed on standard hardware. Unlike unstructured pruning, which retains the original dense structure and thus doesn't alter compute patterns, structured pruning directly reduces the dimensionality of tensors and layers. This means fewer floating-point operations (FLOPs) and less memory access, since the matrices involved in convolutions and linear operations are physically smaller. As a result, inference is faster and more efficient on standard hardware using optimized dense kernels, with no need for specialized sparse computation support.
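To make the difference concrete, here is a minimal sketch in plain PyTorch (not Torch-Pruning) on a single toy `nn.Linear` layer; the layer sizes and the 50% ratio are arbitrary choices for illustration. Unstructured pruning zeroes the smallest-magnitude weights but leaves the weight matrix, and therefore the matmul cost, at full size. Structured pruning drops whole output neurons ranked by L2 norm and copies the survivors into a smaller layer.

```python
import torch
import torch.nn as nn

torch.manual_seed(0)
layer = nn.Linear(8, 4)  # toy layer: 4 output neurons, 8 inputs each

# Unstructured: zero the 50% of individual weights with the smallest magnitude
unstructured = nn.Linear(8, 4)
unstructured.load_state_dict(layer.state_dict())
with torch.no_grad():
    w = unstructured.weight
    k = w.numel() // 2
    threshold = w.abs().flatten().kthvalue(k).values  # k-th smallest |w|
    w[w.abs() <= threshold] = 0.0
print(unstructured.weight.shape)  # torch.Size([4, 8]): same dense matmul as before

# Structured: keep the 2 output neurons (rows) with the largest L2 norm
# and copy them into a physically smaller layer
keep = layer.weight.norm(p=2, dim=1).topk(2).indices.sort().values
structured = nn.Linear(8, 2)
with torch.no_grad():
    structured.weight.copy_(layer.weight[keep])
    structured.bias.copy_(layer.bias[keep])
print(structured.weight.shape)  # torch.Size([2, 8]): half the multiply-adds
```

In a real network, the next layer's input dimension would also have to shrink from 4 to 2 to match.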
PyTorch Pruning API
PyTorch provides a built-in pruning utility under torch.nn.utils.prune. This API supports both unstructured pruning (zeroing individual weights by magnitude or custom metrics) and structured pruning (zeroing entire channels or neurons along a chosen dimension). The PyTorch pruning tutorial offers a solid introduction using iterative magnitude pruning. However, the PyTorch pruning API does not deliver real inference speedups out of the box, because it zeroes weights with masks rather than removing them. For unstructured pruning, it does not convert the model to a sparse representation, which is needed to realize computational gains. For structured pruning, it does not modify the architecture to drop the masked channels or filters, so tensor shapes and the computational graph stay unchanged.
That said, the PyTorch pruning API is a flexible and useful tool for experimenting with pruning strategies. It provides a simple interface to apply custom pruning criteria, evaluate sparsity effects, and implement iterative pruning and retraining loops. It is especially helpful for research and prototyping where exact hardware efficiency is less critical than functional model behavior.
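The following sketch shows this behavior with the built-in API. It uses `prune.l1_unstructured`, `prune.ln_structured`, and `prune.remove` from `torch.nn.utils.prune` on toy layers whose sizes are arbitrary. Note how the mask-based reparameterization (a `weight_orig` parameter plus a `weight_mask` buffer) leaves every weight tensor at its original shape, even after half the output channels have been zeroed.

```python
import torch
import torch.nn as nn
import torch.nn.utils.prune as prune

torch.manual_seed(0)

# Unstructured: zero the 50% smallest-magnitude (L1) weights of a Linear layer
fc = nn.Linear(16, 8)
prune.l1_unstructured(fc, name="weight", amount=0.5)
# The original weights move to `weight_orig`, a binary `weight_mask` buffer is added,
# and `fc.weight` is recomputed as weight_orig * weight_mask before each forward pass
print(sorted(name for name, _ in fc.named_parameters()))  # ['bias', 'weight_orig']
print([name for name, _ in fc.named_buffers()])           # ['weight_mask']
print(fc.weight.shape, (fc.weight == 0).float().mean().item())

# Structured: zero half of a Conv2d's output channels (dim=0), ranked by L2 norm (n=2)
conv = nn.Conv2d(3, 8, kernel_size=3)
prune.ln_structured(conv, name="weight", amount=0.5, n=2, dim=0)
prune.remove(conv, "weight")  # make the zeros permanent and drop the mask
zero_channels = (conv.weight.abs().sum(dim=(1, 2, 3)) == 0).sum().item()
out = conv(torch.randn(1, 3, 16, 16))
# The layer still has 8 output channels, 4 of them all-zero, so compute is unchanged
print(zero_channels, conv.weight.shape, out.shape)
```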
Why Use Torch-Pruning?
Structured pruning isn't trivial, because layers are coupled. If you prune an output channel from a convolutional layer, every layer that consumes that output, such as the following batch normalization layer or the next convolution, must be updated to match the new shape. Residual connections tie even more layers together, since tensors that are added must keep identical shapes. Managing these changes by hand across many layers is tedious and error-prone. Torch-Pruning solves this with a graph-based algorithm called DepGraph, which analyzes the model's computation graph, identifies these dependencies, and groups coupled layers so they are pruned together consistently.
Practical Usage Example: Pruning ResNet-18 in PyTorch
Let's walk through pruning a ResNet-18 model step by step using torch-pruning. We'll do this in Google Colab, so you can run the code yourself and follow along. This example is adapted from the official README of Torch-Pruning.
Setup
First, install the required library:
```
!pip install torch-pruning
```
Then, define the required imports:
```python
import os
import time
import copy
import numpy as np
import torch
import torch.nn as nn
import torch.nn.functional as F
from torchvision import datasets, transforms, models
from torch.utils.data import DataLoader
import torch_pruning as tp

device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
print(f"{device=}")
```
Get CIFAR-10 Train and Test Sets
```python
transform = transforms.Compose([
    transforms.Resize(32),  # no-op for CIFAR-10's native 32x32 images
    transforms.ToTensor(),
    transforms.Normalize((0.5, 0.5, 0.5), (0.5, 0.5, 0.5)),  # map [0, 1] to [-1, 1] per channel
])

train_loader = DataLoader(
    datasets.CIFAR10(root="./data", train=True, download=True, transform=transform),
    batch_size=128, shuffle=True,
)
test_loader = DataLoader(
    datasets.CIFAR10(root="./data", train=False, download=True, transform=transform),
    batch_size=256,
)
```
Adjust ResNet-18 Network for CIFAR-10 Dataset
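```python
def get_resnet18_for_cifar10():
    model = models.resnet18(weights=None, num_classes=10)
    # Replace the ImageNet stem (7x7 stride-2 conv + max-pool) with a 3x3 stride-1 conv,
    # so 32x32 CIFAR-10 images are not downsampled 4x before the first residual stage
    model.conv1 = nn.Conv2d(3, 64, kernel_size=3, stride=1, padding=1, bias=False)
    model.maxpool = nn.Identity()
    return model.to(device)

full_model = get_resnet18_for_cifar10()
```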
Define Train and Evaluate Functions
```python
def train(model, loader, epochs, lr=0.01, save_path="model.pth", silent=False):
    # Reuse a cached checkpoint if one exists (delete the file to force retraining)
    if os.path.exists(save_path):
        if not silent:
            print(f"Model already trained. Loading from {save_path}")
        model.load_state_dict(torch.load(save_path, map_location=device))
        return
    criterion = nn.CrossEntropyLoss()
    optimizer = torch.optim.SGD(model.parameters(), lr=lr, momentum=0.9)
    model.train()
    for epoch in range(epochs):
        running_loss, seen = 0.0, 0
        for x, y in loader:
            x, y = x.to(device), y.to(device)
            optimizer.zero_grad()
            loss = criterion(model(x), y)
            loss.backward()
            optimizer.step()
            running_loss += loss.item() * x.size(0)
            seen += x.size(0)
        if not silent:
            print(f"Epoch {epoch+1}: mean loss={running_loss / seen:.4f}")
    torch.save(model.state_dict(), save_path)
    if not silent:
        print(f"Training complete. Model saved to {save_path}")

def evaluate(model):
    """Top-1 accuracy on the CIFAR-10 test set (uses the global test_loader)."""
    model.eval()
    correct = total = 0
    with torch.no_grad():
        for x, y in test_loader:
            x, y = x.to(device), y.to(device)
            preds = model(x).argmax(1)
            correct += (preds == y).sum().item()
            total += y.size(0)
    return correct / total
```
Define Helper Functions to Measure Latency
```python
class Timer:
    """Times GPU work with CUDA events, CPU work with time.perf_counter."""
    def __init__(self):
        self.use_cuda = torch.cuda.is_available()
        if self.use_cuda:
            self.starter = torch.cuda.Event(enable_timing=True)
            self.ender = torch.cuda.Event(enable_timing=True)

    def start(self):
        if self.use_cuda:
            self.starter.record()
        else:
            self.start_time = time.perf_counter()

    def stop(self):
        """Return the elapsed time in milliseconds."""
        if self.use_cuda:
            self.ender.record()
            torch.cuda.synchronize()  # wait until the queued GPU work has finished
            return self.starter.elapsed_time(self.ender)
        return (time.perf_counter() - self.start_time) * 1000

def estimate_latency(model, example_inputs, repetitions=50):
    model.eval()
    timer = Timer()
    timings = np.zeros(repetitions)
    with torch.no_grad():
        # Warm-up: absorb one-time costs (memory allocation, kernel selection)
        for _ in range(5):
            _ = model(example_inputs)
        for rep in range(repetitions):
            timer.start()
            _ = model(example_inputs)
            timings[rep] = timer.stop()
    return np.mean(timings), np.std(timings)
```
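If you prefer not to maintain a custom timer, PyTorch ships `torch.utils.benchmark.Timer`, which, according to its documentation, handles warm-up, thread settings, and CUDA synchronization for you. A minimal sketch on a hypothetical toy CNN (in practice you would pass the ResNet and `example_input` from above):

```python
import torch
import torch.nn as nn
from torch.utils import benchmark

# Hypothetical toy CNN standing in for the pruned ResNet-18
model = nn.Sequential(
    nn.Conv2d(3, 16, kernel_size=3, padding=1),
    nn.ReLU(),
    nn.AdaptiveAvgPool2d(1),
    nn.Flatten(),
    nn.Linear(16, 10),
).eval()
x = torch.randn(8, 3, 32, 32)

def infer(model, x):
    # Inference only: no autograd bookkeeping
    with torch.no_grad():
        return model(x)

timer = benchmark.Timer(
    stmt="infer(model, x)",
    globals={"infer": infer, "model": model, "x": x},
)
measurement = timer.timeit(50)  # run the statement 50 times
print(f"{measurement.mean * 1e3:.3f} ms per forward pass")
```

`timeit` returns a `Measurement` whose `mean` is in seconds, hence the conversion to milliseconds to match `estimate_latency`.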
Train and Evaluate the Full Model
```python
train(full_model, train_loader, epochs=10, save_path="full_model.pth")
accuracy_full = evaluate(full_model)

# Fixed batch of 128 random images used for MAC counting and latency measurement
example_input = torch.rand(128, 3, 32, 32).to(device)
macs, parameters = tp.utils.count_ops_and_params(full_model, example_input)
latency_mu, latency_std = estimate_latency(full_model, example_input)
print(f"[full model] \t\tMACs: {macs/1e9:.2f} G, \tParameters: {parameters/1e6:.2f} M, \tLatency: {latency_mu:.2f} ± {latency_std:.2f} ms \tAccuracy: {accuracy_full*100:.2f}%")
```
To save you some time, here are the results for the fully trained model (latency is measured per forward pass on a batch of 128 images):
```
[full model] MACs: 0.56 G, Parameters: 11.17 M, Latency: 16.52 ± 0.03 ms Accuracy: 76.85%
```
Prune by L2 Magnitude
```python
# Clone the full model so the original stays untouched
pruned_model = copy.deepcopy(full_model).to(device)

# Layers to exclude from pruning: keep the final classifier's 10 outputs intact
ignored_layers = []
for m in pruned_model.modules():
    if isinstance(m, torch.nn.Linear) and m.out_features == 10:
        ignored_layers.append(m)

# Iterative pruning: reach the target ratio in equal increments over 20 steps
iterative_steps = 20
target_pruning_ratio = 1.0
pruner = tp.pruner.MagnitudePruner(
    model=pruned_model,
    example_inputs=example_input,
    importance=tp.importance.MagnitudeImportance(p=2),  # rank channels by L2 norm
    pruning_ratio=target_pruning_ratio,
    iterative_steps=iterative_steps,
    ignored_layers=ignored_layers,
    round_to=2,  # keep channel counts multiples of 2
)

for step in range(iterative_steps):
    # Remove the next slice of low-importance channels
    pruner.step()
    # Accuracy right after pruning, before any recovery
    acc_before = evaluate(pruned_model)
    # Fine-tune the smaller model for one epoch
    train(pruned_model, train_loader, epochs=1, save_path=f"pruned_model_{step}.pth", silent=True)
    # Accuracy after fine-tuning
    acc_after = evaluate(pruned_model)
    # Cost of the current architecture
    macs, parameters = tp.utils.count_ops_and_params(pruned_model, example_input)
    latency_mu, latency_std = estimate_latency(pruned_model, example_input)
    current_pruning_ratio = target_pruning_ratio * (step + 1) / iterative_steps
    print(f"[pruned model] \tPruning ratio: {current_pruning_ratio:.2f}, \tMACs: {macs/1e9:.2f} G, \tParameters: {parameters/1e6:.2f} M, \tLatency: {latency_mu:.2f} ± {latency_std:.2f} ms \tAccuracy pruned: {acc_before*100:.2f}%\tFinetuned: {acc_after*100:.2f}%")
```
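One practical consequence of structured pruning: the pruned network no longer matches the layer shapes of a freshly constructed ResNet-18, so its `state_dict` cannot be loaded into one. (The `train` helper above avoids this because it always reloads into the same pruned instance.) The plain-PyTorch sketch below uses a hypothetical `make_model(width)` toy network to stand in for the original and pruned architectures, and shows both the failure and one workaround, pickling the whole module:

```python
import io
import torch
import torch.nn as nn

def make_model(width):
    """Hypothetical toy CNN; `width` plays the role of a prunable channel count."""
    return nn.Sequential(
        nn.Conv2d(3, width, kernel_size=3, padding=1),
        nn.ReLU(),
        nn.AdaptiveAvgPool2d(1),
        nn.Flatten(),
        nn.Linear(width, 10),
    )

original_arch = make_model(width=64)  # freshly built, unpruned architecture
pruned = make_model(width=16)         # stands in for a structurally pruned copy

# 1. The pruned weights do not fit the original layer shapes
try:
    original_arch.load_state_dict(pruned.state_dict())
    mismatch = False
except RuntimeError as err:
    mismatch = True
    print("load_state_dict failed:", str(err).splitlines()[0])

# 2. Pickling the whole module stores the pruned architecture together with its weights
buffer = io.BytesIO()
torch.save(pruned, buffer)
buffer.seek(0)
restored = torch.load(buffer, weights_only=False)  # full unpickling: trusted files only
print(restored[0].out_channels)  # 16
```

Because `weights_only=False` unpickles arbitrary objects, only load such files from sources you trust.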
The pruning results show the model’s accuracy immediately after pruning and again after fine-tuning the smaller, pruned model. While accuracy initially drops following pruning, it recovers significantly after just one epoch of fine-tuning.
```
[pruned model] Pruning ratio: 0.05, MACs: 0.49 G, Parameters: 10.03 M, Latency: 17.64 ± 0.04 ms Accuracy pruned: 63.60% Finetuned: 72.17%
[pruned model] Pruning ratio: 0.10, MACs: 0.44 G, Parameters: 9.00 M, Latency: 16.12 ± 0.04 ms Accuracy pruned: 44.51% Finetuned: 76.51%
[pruned model] Pruning ratio: 0.15, MACs: 0.40 G, Parameters: 8.01 M, Latency: 16.40 ± 0.04 ms Accuracy pruned: 66.98% Finetuned: 75.18%
[pruned model] Pruning ratio: 0.20, MACs: 0.35 G, Parameters: 7.09 M, Latency: 16.33 ± 0.04 ms Accuracy pruned: 51.83% Finetuned: 74.64%
[pruned model] Pruning ratio: 0.25, MACs: 0.31 G, Parameters: 6.29 M, Latency: 14.40 ± 0.05 ms Accuracy pruned: 63.51% Finetuned: 76.73%
[pruned model] Pruning ratio: 0.30, MACs: 0.27 G, Parameters: 5.44 M, Latency: 14.07 ± 0.03 ms Accuracy pruned: 49.36% Finetuned: 74.64%
[pruned model] Pruning ratio: 0.35, MACs: 0.23 G, Parameters: 4.69 M, Latency: 12.27 ± 0.03 ms Accuracy pruned: 58.74% Finetuned: 77.56%
[pruned model] Pruning ratio: 0.40, MACs: 0.20 G, Parameters: 3.98 M, Latency: 12.28 ± 0.03 ms Accuracy pruned: 63.98% Finetuned: 78.29%
[pruned model] Pruning ratio: 0.45, MACs: 0.16 G, Parameters: 3.34 M, Latency: 11.41 ± 0.02 ms Accuracy pruned: 45.66% Finetuned: 78.58%
[pruned model] Pruning ratio: 0.50, MACs: 0.14 G, Parameters: 2.80 M, Latency: 7.06 ± 0.03 ms Accuracy pruned: 49.91% Finetuned: 72.77%
[pruned model] Pruning ratio: 0.55, MACs: 0.11 G, Parameters: 2.24 M, Latency: 6.82 ± 0.05 ms Accuracy pruned: 38.72% Finetuned: 76.13%
[pruned model] Pruning ratio: 0.60, MACs: 0.09 G, Parameters: 1.77 M, Latency: 5.96 ± 0.05 ms Accuracy pruned: 42.84% Finetuned: 79.09%
[pruned model] Pruning ratio: 0.65, MACs: 0.07 G, Parameters: 1.34 M, Latency: 4.88 ± 0.09 ms Accuracy pruned: 33.88% Finetuned: 75.54%
[pruned model] Pruning ratio: 0.70, MACs: 0.05 G, Parameters: 0.99 M, Latency: 4.17 ± 0.01 ms Accuracy pruned: 22.50% Finetuned: 75.60%
[pruned model] Pruning ratio: 0.75, MACs: 0.04 G, Parameters: 0.70 M, Latency: 2.96 ± 0.08 ms Accuracy pruned: 34.23% Finetuned: 78.91%
[pruned model] Pruning ratio: 0.80, MACs: 0.02 G, Parameters: 0.44 M, Latency: 2.70 ± 0.02 ms Accuracy pruned: 15.91% Finetuned: 75.55%
[pruned model] Pruning ratio: 0.85, MACs: 0.01 G, Parameters: 0.25 M, Latency: 2.69 ± 0.04 ms Accuracy pruned: 14.16% Finetuned: 75.01%
[pruned model] Pruning ratio: 0.90, MACs: 0.01 G, Parameters: 0.11 M, Latency: 2.63 ± 0.01 ms Accuracy pruned: 10.00% Finetuned: 68.87%
[pruned model] Pruning ratio: 0.95, MACs: 0.00 G, Parameters: 0.03 M, Latency: 2.59 ± 0.02 ms Accuracy pruned: 10.00% Finetuned: 53.36%
[pruned model] Pruning ratio: 1.00, MACs: 0.00 G, Parameters: 0.03 M, Latency: 2.57 ± 0.01 ms Accuracy pruned: 53.36% Finetuned: 54.91%
```
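Note that the model at pruning ratio 0.75 reaches 78.91% accuracy, slightly higher than the full model's 76.85%, while having nearly 16x fewer parameters (0.70 M vs. 11.17 M) and running about 5.5x faster (2.96 ms vs. 16.52 ms per batch). Bear in mind that by this step the pruned model has also received 15 epochs of fine-tuning on top of the original 10 epochs of training, which likely accounts for part of the accuracy gain.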
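The headline ratios follow directly from the logged numbers for the 0.75 pruning-ratio step; the short calculation below simply reproduces them from the values in the output above:

```python
# Values copied from the logged results
full = {"params_m": 11.17, "latency_ms": 16.52, "acc": 76.85}
pruned_075 = {"params_m": 0.70, "latency_ms": 2.96, "acc": 78.91}

compression = full["params_m"] / pruned_075["params_m"]
size_reduction = 1 - pruned_075["params_m"] / full["params_m"]
speedup = full["latency_ms"] / pruned_075["latency_ms"]

print(f"{compression:.1f}x fewer parameters ({size_reduction:.1%} smaller)")
# 16.0x fewer parameters (93.7% smaller)
print(f"{speedup:.2f}x lower latency, accuracy change {pruned_075['acc'] - full['acc']:+.2f} pts")
# 5.58x lower latency, accuracy change +2.06 pts
```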
Summary
Pruning is a powerful technique to make deep networks lighter and faster. In this blog post, we:
- Explored what pruning is and why it matters
- Compared the native PyTorch pruning API with Torch-Pruning
- Used torch-pruning to prune a ResNet-18 model in PyTorch
- Evaluated model size, inference latency, and top-1 prediction accuracy using CIFAR-10 data
By applying structured pruning, you can make your models more efficient with minimal impact on performance, a valuable step in any model optimization workflow.