Inspecting torch.compile()

In this section we’ll define a very simple function and analyze what’s happening under the hood.

import torch
import math
import os
import matplotlib.pyplot as plt
from torch import optim
import torch._dynamo
from torchvision import models
from torch.profiler import profile, record_function, ProfilerActivity

# Convenience alias for the constant pi.
pi = math.pi
# Always build a torch.device (previously the CPU fallback was the bare string
# "cpu"), so attributes such as `device.type` work on every machine.
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")

Define a simple function, $\sin^2(x) + \cos^2(x)$. As you know, this equals 1 for all real values of $x$.

def fn(x):
    """Return sin(x)**2 + cos(x)**2 computed element-wise on tensor ``x``.

    Mathematically the result is 1 for every real input; numerically it is 1
    up to floating-point rounding.
    """
    sin_squared = torch.sin(x) ** 2
    cos_squared = torch.cos(x) ** 2
    return sin_squared + cos_squared

Now let’s define a one-dimensional tensor with one million elements to pass to our function:

# Fix the seed so the sampled input is reproducible.
torch.manual_seed(0)
# Sample first, move to the target device, and only then turn on gradient
# tracking. Calling `.to(device)` on a tensor that already requires grad yields
# a non-leaf copy whenever the device actually changes (e.g. CPU -> CUDA), and
# `.grad` is never populated on a non-leaf tensor.
x = torch.rand(1000000).to(device).requires_grad_()

out = fn(x)
# sin^2 + cos^2 should be 1 everywhere, so the distance from 1 must be tiny.
# This bare expression is the value that gets displayed as the snippet's output.
torch.linalg.norm(out-1) <= 1e-4

Sure enough, you should see the following output (on a machine without a GPU, the `device=...` part is omitted):

tensor(True, device='cuda:0')

Now let’s write a simple function to intercept what the compiler sees. To do that, we create a fake compiler backend and pass it to torch.compile(). This lets us look at the first intermediate representation that our high-level code is lowered to.

# Re-seed so this snippet draws the same values every time it is run.
torch.manual_seed(0)
# Enable gradients only after the move to `device`: `.to(device)` on a tensor
# that already requires grad returns a non-leaf copy when the device changes,
# and a non-leaf tensor never gets its `.grad` populated.
x = torch.rand(1000000).to(device).requires_grad_()

def inspect_backend(gm, sample_inputs):
    """Print the graph module received from the compiler and run it unchanged.

    Args:
        gm: The captured graph module; its readable form is printed.
        sample_inputs: Example inputs passed along by the compiler (unused).

    Returns:
        ``gm.forward``, so the captured graph is executed as-is with no
        further transformation.
    """
    gm.print_readable()
    passthrough = gm.forward
    return passthrough

# Reset TorchDynamo's state so results compiled earlier are not reused and the
# backend below is invoked for this run.
torch._dynamo.reset()
# Compile `fn`, handing the captured graph to `inspect_backend` instead of a
# real compiler backend.
compiled_model = torch.compile(fn, backend=inspect_backend)

# Compilation happens on this first call, which is when the backend prints
# the graph.
out = compiled_model(x)

Output:

import torch._dynamo
from torch.fx.passes.graph_drawer import FxGraphDrawer
from torch._functorch.aot_autograd import aot_module_simplified

def inspect_backend(gm, sample_inputs):
    """Backend that routes the graph through AOTAutograd and dumps both passes.

    The forward and backward graphs handed to the compiler callbacks are each
    printed in readable form and rendered to an SVG file ("forward.svg" and
    "backward.svg" in the current working directory), then executed unchanged.

    Args:
        gm: The graph module captured by the compiler.
        sample_inputs: Example inputs for ``gm``.

    Returns:
        Whatever ``aot_module_simplified`` returns for ``gm``.
    """

    def make_dumping_compiler(svg_path):
        # Build a compiler callback that prints the graph, writes its diagram
        # to `svg_path`, and returns the graph's own forward without changes.
        def dump_and_passthrough(graph_module, example_inputs):
            graph_module.print_readable()
            drawer = FxGraphDrawer(graph_module, 'fn')
            with open(svg_path, "wb") as svg_file:
                dot_graph = drawer.get_dot_graph()
                svg_file.write(dot_graph.create_svg())
            return graph_module.forward

        return dump_and_passthrough

    # Invoke AOTAutograd
    return aot_module_simplified(
        gm,
        sample_inputs,
        fw_compiler=make_dumping_compiler("forward.svg"),
        bw_compiler=make_dumping_compiler("backward.svg"),
    )

# Reset TorchDynamo's state so `fn` is compiled again with the backend
# given below rather than reusing an earlier result.
torch._dynamo.reset()
compiled_model = torch.compile(fn, backend=inspect_backend)

# Keep the forward result in `out`. `Tensor.backward()` returns None, so
# chaining it into the assignment would leave `out` bound to None.
out = compiled_model(x)
# Run the backward pass as well so the backward graph is exercised too.
out.sum().backward()

Output:

# Reset TorchDynamo's state so `fn` is compiled again, this time with the
# Inductor backend.
torch._dynamo.reset()
# The `trace.*` options are Inductor debug settings; they are passed through
# unchanged (presumably enabling the debug trace and its graph diagram; see
# the torch._inductor config for their exact meaning).
compiled_model = torch.compile(
    fn,
    backend='inductor',
    options={
        'trace.enabled': True,
        'trace.graph_diagram': True,
    },
)

# Keep the forward result in `out`. `Tensor.backward()` returns None, so
# chaining it into the assignment would leave `out` bound to None.
out = compiled_model(x)
# Run the backward pass as well so the backward graph is exercised too.
out.sum().backward()