Preface: this notebook is part 1 in a series of tutorials discussing gradients (manipulation, stopping, etc.) in PyTorch. The series covers the following network architectures:

1) Single-headed simple architecture

2) Single-headed complex architecture

3) Multi-headed architecture

The notebook for this tutorial can be found on Google Colab: gradient_flow_1

Note: For the purposes of this discussion, we define a module to be a single layer or a collection of layers in a neural network.

1) Motivation

The motivation behind this post was 3-fold:

1.1) Familiarizing myself with PyTorch.

PyTorch is easy to prototype in, but I don’t fully understand the PyTorch computation graph.

1.2) Playing with gradient stopping and propagating.

This is useful if we have a frozen layer that we want to avoid training. This problem is simple if we have a simple module that looks like

Simple module

but what happens if the shared module is an intermediate component of the model?

Complicated Module

1.3) Working on some funky graph shapes.

What happens if we have a network that looks like

DDPG

NOTE: Image sourced from IntelLabs: DDPG

where we have two primary modules: the actor and the critic. We see that the critic (the bottom module) accepts the actor’s output, but unless we stop the gradient flow, the computation graph will backpropagate critic updates through the actor, which we do not want.

2) Contents:

We focus on 5 methods that we categorize into High-Level, where we use built-in methods, and Low-Level, where we manually access the gradients.

2.1) High-Level

All the methods listed below are only pertinent to stopping gradients.

detach, which returns a tensor with the same values and properties but detached from the graph; the original tensor and its graph are untouched. Note that the returned tensor shares the original's underlying data rather than copying it.

no_grad, which is a context manager that disables gradient calculation. All tensors created inside its scope have requires_grad set to False.

inference_mode, which stops gradients entirely downstream as well as upstream. This is a relatively new method (Sept 14, 2021), so it is worth discussing.

2.2) Low-Level

Since we have direct access to the gradients, we can not only stop gradients but also manipulate them based on our needs.

2.3) eval misconception

When I first started using PyTorch, I incorrectly assumed that eval would turn off the computation graph, but this is not the case: eval only switches layers such as Dropout and BatchNorm to their evaluation behavior, and gradient tracking continues as usual.
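A minimal sketch of the misconception (the model shape here is arbitrary): eval() changes layer behavior, but leaves gradient tracking on.

```python
import torch as T
import torch.nn as nn

# eval() switches layers such as Dropout and BatchNorm to evaluation
# behavior, but it does NOT disable autograd
model = nn.Sequential(nn.Linear(4, 4), nn.Dropout(0.5), nn.Linear(4, 1))
model.eval()

y = model(T.ones(1, 4))

# The computation graph is still constructed after eval()
print(y.requires_grad)  # True
```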

2.4) Making the right choice

At the end of the day, each of the methods above comes with various tradeoffs. We will discuss those tradeoffs below, but ultimately you will have to decide what is best for your application.

3) Problem setup

We have the following graph:

Simple Graph

where we want to only update the network’s output head (L2). What are the various ways we can accomplish this?

I highly recommend having the colab notebook open as you work through this. I made it a point to plot the resulting computation graph for each setting, making it easier to understand what is happening.

4) High-Level

4.1) detach

detach detaches upstream values from the graph, so we only calculate the gradient backward up to the first detach. Our current graph setup is too simple to illustrate this phenomenon, but the computation graph in the follow-up post will work well.
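As a minimal sketch (the values are arbitrary), gradients only flow through the non-detached path:

```python
import torch as T

x = T.tensor(2.0, requires_grad=True)
y = x ** 2          # dy/dx = 2x = 4
z = y.detach() * 3  # detached: no gradient flows back through this path
(y + z).backward()

# Only the non-detached path contributes to the gradient
print(x.grad)  # tensor(4.)
```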

4.1.1) Observations

Notice 2 things from the cells:

4.1.2) Usecase

Torch tensors keep track of metadata such as the computation graph. By detaching a tensor, we drop the computation graph of all upstream operations up to that variable.

Trying to convert directly to numpy errors out (rightfully so) because numpy does not keep track of the computation graph. It is safer to have a clear distinction between numpy arrays and torch tensors.

import torch as T

a = T.tensor(1.0, requires_grad=True)
b = a + a
# b.numpy() raises a RuntimeError because b is part of the computation graph;
# detaching first works:
b.detach().numpy()

4.2) no_grad

4.2.1) no_grad in action

It can be used as such:

#!pip install -q torchviz
import torch as T
from torchviz import make_dot

# requires_grad=True so that the graph is constructed
x = T.ones(10, requires_grad=True)
with T.no_grad():
    pass
y = x ** 2
z = x ** 3
r = (y + z).sum()

make_dot(
    r,
    params={"y": y, "z": z, "r": r, "x": x},
    show_attrs=True,
)

Uncomment the first line if you do not already have torchviz. Then, play around with moving y or z into the T.no_grad() context.

4.2.2) Observations

4.2.3) Usecase

no_grad tells PyTorch not to track any operations inside the context, which means the computation graph is not created for them.
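A minimal sketch of this behavior (the tensor shapes are arbitrary):

```python
import torch as T

x = T.ones(3, requires_grad=True)

with T.no_grad():
    y = x * 2  # no graph is recorded for this operation
z = x * 2      # recorded as usual

print(y.requires_grad)  # False
print(z.requires_grad)  # True
```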

Furthermore, no_grad is faster than detach: detach has to construct a new (graph-free) tensor for each variable, whereas no_grad simply never records the operations inside its scope.

With detach, keeping both the attached and detached tensors around might not be your intention, and you might accidentally operate on the wrong variable; no_grad sidesteps this.

4.3) inference

4.3.1) Observations

We discuss two observations for this code section:

Cell 1: without_grad

Viewing the computation graph, we see that no values are tracked (hence the single empty block).

Solution: If we want to allow downstream calculations that are not themselves in inference mode, we must make a clone of the tensor. We display the relevant sections in 4.3.2) Relevant Code.
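A minimal sketch of that workaround (the values are arbitrary):

```python
import torch as T

with T.inference_mode():
    t = T.ones(3) * 2  # t is an inference tensor

# Cloning outside the context yields a normal tensor that downstream
# autograd operations can use
t2 = t.clone()

print(t.is_inference())   # True
print(t2.is_inference())  # False
```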

Cell 2: with_grad

We see that this method produces the same computation graph as the detach and no_grad settings. Like no_grad, inference_mode is a context manager. With no_grad and detach, upstream values were not tracked in the computation graph; with inference_mode, even downstream values are not tracked.

PyTorch C++ InferenceMode docs

4.3.2) Relevant Code

We generated the two graphs with the setup from this official Twitter post in mind:

def _inference_forward(self, X):
  # First var is an inference-mode tensor
  with T.inference_mode():
    tmp = self.l1(X)
  try:
    # Try a non-inference forward pass
    return self.l2(tmp)
  except RuntimeError:
    print("Trying to use an intermediate inference_mode tensor outside the inference_mode context manager")

    # Pure-inference forward pass
    with T.inference_mode():
      grad_disabled = self.l2(tmp)
    # Clone the inference-mode tensor so that we can
    # do a normal forward pass
    new_tmp = T.clone(tmp)
    grad_enabled = self.l2(new_tmp)
    return grad_disabled, grad_enabled

4.3.3) Usecase

Gradient Propagation It is possible to use this method to stop gradients, but there are easier ways to accomplish this.

Inference Speed While no_grad stops operation tracking, inference disables two other autograd features: version counting and metadata tracking.

5) Low-Level

In the following methods, we work directly with the computed gradients instead of detaching variables or telling PyTorch to ignore blocks. This low-level manipulation is helpful if we want to make complex modifications to our gradients (it won’t be relevant here, but it is worth mentioning ahead of time).

Furthermore, whereas the methods in the High-Level section stopped all gradients from flowing upstream, both of the Low-Level methods allow us to skip modules.

Note: The gradients are stored in the model parameters when we call loss.backward. The only thing our optimizer.step call does is apply the gradients. This means that using the optimizer method is more or less equivalent to the manual manipulation method.
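A minimal sketch of this split (the layer sizes are arbitrary): backward populates the parameters' .grad fields, and step merely applies them.

```python
import torch as T
import torch.nn as nn
import torch.optim as optim

layer = nn.Linear(2, 1)
opt = optim.SGD(layer.parameters(), lr=0.1)

layer(T.ones(1, 2)).sum().backward()

# backward() stored the gradients on the parameters...
print(layer.weight.grad is not None)  # True

# ...and step() only applies them
w_before = layer.weight.clone()
opt.step()
print(T.equal(layer.weight, w_before))  # False
```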

5.1) Common observations

Unlike the resulting computation graphs in the High-Level section, we see that all variables here are tracked:

These methods can consume far more memory, as the entire computation graph has to be constructed.

5.2) optim.Optimizer

We modify our optimizer such that instead of doing something like optim.SomeOptimizer(model.parameters()),

we instead do optim.SomeOptimizer(model.l2.parameters()), which tells our optimizer to apply gradients only to the L2 parameters.
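As a sketch, assuming a hypothetical two-layer model with submodules l1 and l2 (the layer sizes are made up):

```python
import torch as T
import torch.nn as nn
import torch.optim as optim

# Hypothetical model mirroring the L1 -> L2 setup
class Net(nn.Module):
    def __init__(self):
        super().__init__()
        self.l1 = nn.Linear(4, 4)
        self.l2 = nn.Linear(4, 1)

    def forward(self, x):
        return self.l2(self.l1(x))

model = Net()
# Hand only l2's parameters to the optimizer: l1 is effectively frozen
opt = optim.SGD(model.l2.parameters(), lr=0.1)

l1_before = model.l1.weight.clone()
opt.zero_grad()
model(T.ones(2, 4)).sum().backward()  # gradients still flow through l1...
opt.step()                            # ...but only l2 is updated

print(T.equal(model.l1.weight, l1_before))  # True
```

Note that gradients are still computed for l1; the optimizer simply never applies them.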

5.2.1) Usecase

As in the above methods, we can “freeze” a layer by using this method.

We can specify per-module hyperparameters (e.g., a different learning rate per parameter group).

However, we do not have fine-grained control.
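For instance, per-module hyperparameters can be sketched with parameter groups (the layer sizes and learning rates here are arbitrary):

```python
import torch.nn as nn
import torch.optim as optim

model = nn.Sequential(nn.Linear(4, 4), nn.Linear(4, 1))

# One parameter group per module, each with its own learning rate
opt = optim.SGD([
    {"params": model[0].parameters(), "lr": 1e-3},
    {"params": model[1].parameters(), "lr": 1e-1},
])
print([g["lr"] for g in opt.param_groups])  # [0.001, 0.1]
```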

5.3) Manual manipulation

While the above section had the optimizer apply our gradients, we manually apply the gradient here.
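A minimal sketch of a manual update, assuming a hypothetical two-layer model where only the output head is updated (layer sizes and learning rate are made up):

```python
import torch as T
import torch.nn as nn

model = nn.Sequential(nn.Linear(4, 4), nn.Linear(4, 1))
model(T.ones(2, 4)).sum().backward()

lr = 0.1
l0_before = model[0].weight.clone()
with T.no_grad():
    # Manually apply the stored gradients, but only to the output head
    for p in model[1].parameters():
        p -= lr * p.grad

# The first layer is untouched, even though its .grad is populated
print(T.equal(model[0].weight, l0_before))  # True
```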

5.3.1) Usecase

The only use-case I see for this method over every other method listed above is custom gradient application: for example, zeroing out gradients every other step, or scaling gradients when certain conditions are met.
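As a sketch of such a custom rule (the threshold and scaling factor are made up):

```python
import torch as T
import torch.nn as nn

model = nn.Linear(4, 1)
(model(T.ones(8, 4)).sum() * 100).backward()  # deliberately large gradients

# Hypothetical rule: halve any gradient whose norm exceeds a threshold
threshold = 1.0
with T.no_grad():
    for p in model.parameters():
        if p.grad is not None and p.grad.norm() > threshold:
            p.grad.mul_(0.5)

print(model.weight.grad.norm().item())  # 800.0
```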

6) Conclusion

6.1) Gradient Stopping

The “simple” methods are a lot easier to pull off and should be preferred if all you need to do is stop gradients from flowing upstream.

My recommendation is to use no_grad wherever possible, as it is faster than detach. As with most style preferences this is subjective, but I feel that no_grad is also better because it makes clear that you are excluding a block of computations that will be used further down. When you detach a variable, you end up with both the attached and detached versions of the tensor, and it is easy to operate on the wrong one.

I recommend avoiding inference for gradient manipulation unless you’re absolutely sure that you have a good reason. I do not see a scenario where you might prefer doing inference and then copying the variable when you can use no_grad directly.

6.2) Gradient Manipulation

If possible, use the optimizer approach as there’s less room for error. However, the Manual manipulation method is ideal if you need to apply custom operations.

One such use-case for manual manipulation is to scale only particular layers if specific conditions are met or if you want to zero out gradients every other step.

7) Thank you!

Thank you for taking the time to read this! If you ever want to contact me, feel free to email me at firstname@website URL. You can also reach me on LinkedIn: ianq. If we don't know each other, please attach a note to your invitation or send me an email along with it; I tend to ignore requests otherwise.