In this section, we'll define a very simple function, compile it, and analyze what's happening under the hood.
import torch
import math
import os
import matplotlib.pyplot as plt
from torch import optim
import torch._dynamo
from torchvision import models
from torch.profiler import profile, record_function, ProfilerActivity
pi = math.pi
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
Define a simple function, sin^2(x) + cos^2(x). As you know, this equals 1 for all real values of x:
def fn(x):
    return torch.sin(x)**2 + torch.cos(x)**2
Now let's define a one-million-element tensor to pass to our function:
torch.manual_seed(0)
x = torch.rand(1000000, requires_grad=True).to(device)
out = fn(x)
torch.linalg.norm(out - 1) <= 1e-4  # residual norm ~0, i.e. every element is ~1
Sure enough, you should see the following output:
tensor(True, device='cuda:0')
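Before we dig into the compiler internals, it's worth checking that compilation actually pays off on this function. Here's a minimal benchmark sketch (the bench helper and the iteration count are illustrative choices; absolute timings will vary with your GPU and PyTorch version):

import time

compiled_fn = torch.compile(fn)
compiled_fn(x)  # warm-up: the first call pays the one-time compilation cost

def bench(f, iters=100):
    # Synchronize so we time the GPU work, not just kernel launches
    if x.is_cuda:
        torch.cuda.synchronize()
    start = time.perf_counter()
    for _ in range(iters):
        f(x)
    if x.is_cuda:
        torch.cuda.synchronize()
    return time.perf_counter() - start

print(f"eager:    {bench(fn):.4f} s")
print(f"compiled: {bench(compiled_fn):.4f} s")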
Now let's write a simple function to intercept what the compiler sees. To do that, we create a fake compiler backend and pass it to torch.compile(). This lets us look at the intermediate representation that TorchDynamo generates from our high-level code:
torch.manual_seed(0)
x = torch.rand(1000000, requires_grad=True).to(device)

def inspect_backend(gm, sample_inputs):
    gm.print_readable()
    return gm.forward

torch._dynamo.reset()
compiled_model = torch.compile(fn, backend=inspect_backend)
out = compiled_model(x)
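An aside: a backend is just a callable that takes the captured FX GraphModule plus example inputs and returns a callable to run in its place. TorchDynamo also ships a number of built-in backends, which you can list (the exact set varies by PyTorch version and installed dependencies):

import torch._dynamo

# Built-in compiler backends registered with TorchDynamo
print(torch._dynamo.list_backends())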
This prints the FX graph that TorchDynamo captured from fn: a flat sequence of sin, cos, pow, and add operations. We can go one level deeper with AOTAutograd, which splits the captured graph into separate forward and backward graphs; below we render each one to an SVG with FxGraphDrawer:
import torch._dynamo
from torch.fx.passes.graph_drawer import FxGraphDrawer
from torch._functorch.aot_autograd import aot_module_simplified

def inspect_backend(gm, sample_inputs):
    # Render the forward graph that AOTAutograd produces
    def fw(gm, sample_inputs):
        gm.print_readable()
        g = FxGraphDrawer(gm, 'fn')
        with open("forward.svg", "wb") as f:
            f.write(g.get_dot_graph().create_svg())
        return gm.forward

    # Render the backward graph
    def bw(gm, sample_inputs):
        gm.print_readable()
        g = FxGraphDrawer(gm, 'fn')
        with open("backward.svg", "wb") as f:
            f.write(g.get_dot_graph().create_svg())
        return gm.forward

    # Invoke AOTAutograd with our forward/backward compilers
    return aot_module_simplified(
        gm,
        sample_inputs,
        fw_compiler=fw,
        bw_compiler=bw
    )

torch._dynamo.reset()
compiled_model = torch.compile(fn, backend=inspect_backend)
compiled_model(x).sum().backward()
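Note that FxGraphDrawer renders through pydot, which in turn needs a system Graphviz installation. If you want the script to degrade gracefully when those are missing, a minimal guard like this sketch works:

# Optional: skip SVG rendering when pydot/Graphviz are unavailable.
try:
    import pydot  # noqa: F401  # used internally by FxGraphDrawer
    HAS_PYDOT = True
except ImportError:
    HAS_PYDOT = False
    print("pydot not found; install pydot and Graphviz to render the graphs")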
Output: two graph diagrams, the forward graph (saved as forward.svg) and the backward graph (saved as backward.svg).
We don't have to hand-roll any of this, though: the default inductor backend can emit its own debug trace, including graph diagrams, through the options argument:

torch._dynamo.reset()
compiled_model = torch.compile(fn, backend='inductor',
                               options={'trace.enabled': True,
                                        'trace.graph_diagram': True})
compiled_model(x).sum().backward()
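If all you want is the Triton/C++ code that Inductor generates, a lighter-weight option is PyTorch's artifact logging. This is a minimal sketch assuming PyTorch 2.1+, where torch._logging.set_logs is available (the logging knobs have moved between releases):

import torch._logging

# Print Inductor's generated output code to the console
# instead of writing a full debug trace to disk.
torch._logging.set_logs(output_code=True)

torch._dynamo.reset()
compiled_fn = torch.compile(fn)  # 'inductor' is the default backend
compiled_fn(x).sum().backward()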