torch.compile(): A simple example
Load the necessary modules; we'll discuss the strange torch._dynamo bit a little later.
import torch
import torch.nn as nn
from torchvision.models import resnet
import torch._dynamo
device = torch.device("cuda") if torch.cuda.is_available() else "cpu"
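Note that torch.compile() only ships with PyTorch 2.0 and later, so if the torch._dynamo import fails, it is worth confirming the installed version first:
print(torch.__version__)  # torch.compile() requires PyTorch >= 2.0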
Define a very simple NN, compile it with torch.compile(), and compare the compiled output with the original model's output to convince yourself that compilation won't break your code.
class MLP(nn.Module):
    def __init__(self):
        super().__init__()
        self.fc1 = nn.Linear(32, 64)

    def forward(self, x):
        x = self.fc1(x)
        x = torch.nn.functional.relu(x)
        return x
model = MLP()
input = torch.randn(8, 32)

torch._dynamo.reset()  # Only needed if you call this cell repeatedly
compiled_model = torch.compile(model)
# Alternatively, you can pass the backend explicitly
compiled_model = torch.compile(model, backend="inductor")

output = model(input)
# Triggers compilation of the forward graph on the first call
output_compiled = compiled_model(input)
torch.all(output == output_compiled)
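Strict elementwise equality happens to hold here, but it is not guaranteed in general: compiled backends may fuse or reorder floating-point operations and produce results that differ in the last bits. If the check above ever fails, a tolerance-based comparison is the safer test; here is a minimal sketch using torch.allclose (the tolerances are illustrative choices, not official recommendations):
# Compiled kernels may fuse/reorder float ops, so compare with a tolerance
print((output - output_compiled).abs().max())  # inspect the largest deviation
assert torch.allclose(output, output_compiled, rtol=1e-5, atol=1e-6)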
Here is the output; congratulations, you ran your first torch.compile() compilation task!
tensor(True)
Now that you're convinced that the compilation doesn't change the output, let's proceed with benchmarking!
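As a first taste, here is a minimal sketch of such a benchmark using torch.utils.benchmark.Timer from core PyTorch; unlike a hand-rolled time.time() loop, the Timer takes care of warmup and, on CUDA, of synchronizing around the measurement (the iteration count of 100 is an arbitrary choice for illustration):
from torch.utils import benchmark

# Time the eager and the compiled model on the same input
t_eager = benchmark.Timer(stmt="model(input)",
                          globals={"model": model, "input": input})
t_compiled = benchmark.Timer(stmt="compiled_model(input)",
                             globals={"compiled_model": compiled_model, "input": input})

print(t_eager.timeit(100))     # eager baseline
print(t_compiled.timeit(100))  # compiled; the first call above already paid the compile cost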