
Different ways to clear autograd in a neural network

There are two different things people usually mean by this:

Clearing (zeroing) accumulated gradients on parameters

Stopping gradients from being created/propagated through some tensors

Clearing / zeroing gradients (what “clear grad” usually means)

In PyTorch, gradients accumulate into param.grad by default: each .backward() adds to the buffer.

If you don’t clear them before the next backward pass, you effectively sum gradients from multiple batches, which leads to unstable updates.
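A tiny sketch of that accumulation behaviour on a toy tensor (no model or optimizer involved):

import torch

w = torch.ones(3, requires_grad=True)

(w * 2).sum().backward()
print(w.grad)            # tensor([2., 2., 2.])

(w * 2).sum().backward()
print(w.grad)            # tensor([4., 4., 4.])  <- the second backward added to the buffer

w.grad = None            # clear it (same idea as zero_grad(set_to_none=True))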

Ways to clear:

optimizer.zero_grad(set_to_none=True) (recommended; sets param.grad = None, saving memory and a bit of compute)

model.zero_grad() (older common form; zeros tensors in-place)

Manually: for a single parameter p, set p.grad = None or call p.grad.zero_()

When needed: every iteration of your training loop before calling loss.backward().

Stopping gradients / turning off autograd

These do not “clear grads”; they prevent gradient tracking/propagation.

requires_grad_(False): marks a tensor/parameter as frozen until you flip it back; no grads will be computed for it.

Use when freezing layers (transfer learning) or for fixed embeddings, etc.

tensor.detach(): returns a tensor that shares the same data but carries no autograd history; backprop stops at that boundary.

Use to truncate the graph (e.g., stop gradient into an input, copy targets from model outputs for bootstrapping, avoid accidental second-order graphs).

with torch.no_grad(): temporarily disables autograd inside the block; no graph is built, so ops are faster and cheaper.

Use for inference, EMA updates, or metric computation.

Bonus: with torch.inference_mode(): even leaner than no_grad for inference (more aggressive optimizations), but tensors created inside it cannot be used later in autograd-recorded operations.
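A quick toy illustration of that restriction (the clone() workaround mentioned in the comment is the usual escape hatch):

import torch

w = torch.randn(3, requires_grad=True)

with torch.inference_mode():
    y = w * 2            # runs normally, but nothing is recorded
print(y.requires_grad)   # False

# tensors created under inference_mode generally cannot re-enter an
# autograd-recorded computation later; clone() them first if you need that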

requires_grad_(False)

This means: "Don’t compute gradients for this tensor anymore."
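A toy sketch of what that looks like in practice:

import torch

a = torch.randn(3, requires_grad=True)
b = torch.randn(3, requires_grad=True)
b.requires_grad_(False)       # freeze b

(a * b).sum().backward()
print(a.grad)                 # a real gradient tensor
print(b.grad)                 # None: nothing was computed for b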

detach()

This means: “Make a new tensor that looks the same, but breaks the gradient chain.”
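Again a toy sketch:

import torch

x = torch.randn(3, requires_grad=True)
y = x * 2
z = y.detach()                 # same values, shared storage, no history

print(torch.equal(y, z))                  # True
print(y.requires_grad, z.requires_grad)   # True False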

torch.no_grad()

This means: “Turn off autograd temporarily.”

Use case:

During inference (testing), when we don’t need gradients.

It saves memory and computation time.
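A toy sketch of the effect:

import torch

x = torch.randn(3, requires_grad=True)

with torch.no_grad():
    y = x * 2                  # no graph is recorded inside the block

print(y.requires_grad)         # False
print((x * 2).requires_grad)   # True: outside the block, tracking is back on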

1) Correct training loop: clear grads each step

import torch
import torch.nn as nn
import torch.optim as optim

model = nn.Linear(10, 1)
opt = optim.Adam(model.parameters(), lr=1e-3)
loss_fn = nn.BCEWithLogitsLoss()

for xb, yb in loader:                      # your DataLoader
    opt.zero_grad(set_to_none=True)        # ← CLEAR accumulated grads (A)
    logits = model(xb)                     # forward (graph is recorded)
    loss = loss_fn(logits, yb.float())
    loss.backward()                        # compute grads
    opt.step()                             # update params

Alternative ways to clear grads
# option 1: model.zero_grad()    # zeros existing tensors (older style)
# option 2: for p in model.parameters(): p.grad = None
# option 3: for p in model.parameters():
#              if p.grad is not None: p.grad.zero_()

2) Freeze some layers (no grads created)

# Freeze feature extractor; train only classifier head
for p in backbone.parameters():
    p.requires_grad_(False)                # ← stop creating grads for these params (B)

# Only the head’s params will appear in optimizer
opt = optim.SGD(head.parameters(), lr=1e-2)
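If the backbone and head are submodules of one combined model, a common alternative (just a sketch under that assumption) is to filter on requires_grad:

# assumes backbone and head are submodules of a single `model`
trainable = [p for p in model.parameters() if p.requires_grad]
opt = optim.SGD(trainable, lr=1e-2)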

3) Detach to stop gradient flow through a tensor

out = model(x)                 # requires grad
target = out.detach()          # ← gradient will not flow into 'out' via 'target'

loss = ((out - target)**2).mean()  # trivial example (bootstrapping style)
loss.backward()               # grads flow to 'out' path only, not to 'target'


Another common use is truncating backpropagation through time (BPTT) in RNNs:

# assumes: rnn_cell = nn.RNNCell(input_size, hidden), x of shape (batch, T, input_size),
# and trunc_k = the truncation length
h = torch.zeros(batch, hidden)
for t in range(T):
    h = rnn_cell(x[:, t], h)
    if (t + 1) % trunc_k == 0:
        h = h.detach()        # ← cut the graph to limit how far backprop reaches

4) no_grad / inference_mode for evaluation & EMA

# Evaluation: faster & memory-saving
model.eval()
with torch.no_grad():                     # ← no graph built inside
    for xb, yb in val_loader:
        preds = torch.sigmoid(model(xb))
        # compute metrics…

# EMA update: parameter averaging without tracking grads (ema_model: a copy of model that holds the running average)
with torch.no_grad():
    for p, p_ema in zip(model.parameters(), ema_model.parameters()):
        p_ema.mul_(0.999).add_(p, alpha=0.001)
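The same evaluation loop also works under inference_mode (a sketch; usually a drop-in swap as long as the predictions never re-enter autograd):

model.eval()
with torch.inference_mode():              # leaner than no_grad for pure inference
    for xb, yb in val_loader:
        preds = torch.sigmoid(model(xb))
        # compute metrics…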

