How to use torch.compile(): A simple example

Load the necessary modules. We’ll discuss the strange-looking torch._dynamo import a little later.

import torch
import torch.nn as nn
from torchvision.models import resnet
import torch._dynamo

# Use the GPU if one is available, otherwise fall back to the CPU
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")

Define a very simple neural network, compile it with torch.compile(), and compare its output with that of the original model to convince yourself that compilation won’t break your code.

class MLP(nn.Module):
    def __init__(self):
        super().__init__()
        self.fc1 = nn.Linear(32, 64)

    def forward(self, x):
        x = self.fc1(x)
        x = torch.nn.functional.relu(x)
        return x

model = MLP()
input = torch.randn(8, 32)

torch._dynamo.reset()  # Only needed if you run this cell repeatedly
compiled_model = torch.compile(model)

# Equivalently, name the backend explicitly ('inductor' is the default)
compiled_model = torch.compile(model, backend='inductor')

output = model(input)
# The first call triggers compilation of the forward graph
output_compiled = compiled_model(input)

# Bit-exact comparison; torch.allclose tolerates small rounding differences
torch.all(output == output_compiled)

Here is the output. Congratulations, you’ve just run your first torch.compile() compilation! Now that you’re convinced that compilation doesn’t change the output, let’s proceed with benchmarking!

    tensor(True)