Different ways to clear autograd (gradients) in a neural network

Clearing / zeroing gradients (what “clear grad” usually means)
Stopping gradients / turning off autograd
requires_grad_(False)
tensor.detach()
with torch.no_grad()
optimizer.zero_grad()

10 practical areas / use-cases where torch.no_grad() is required (or strongly recommended):

Clearing / zeroing gradients (what “clear grad” usually means)

In PyTorch, gradients accumulate into param.grad by default: each .backward() adds to the buffer.

If you don’t clear them before the next backward, you will effectively sum gradients from multiple batches → unstable updates.

Ways to clear:

optimizer.zero_grad(set_to_none=True) (recommended, and the default in recent PyTorch versions; sets param.grad = None, saving a little memory and compute)

model.zero_grad() (older common form; zeros the existing gradient tensors in-place)

Manually, per parameter: p.grad = None or p.grad.zero_()

When needed: every iteration of your training loop, before calling loss.backward() (see the quick demo below).
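
To see why clearing matters, here is a quick, self-contained demo of gradient accumulation (the tensor and the numbers are arbitrary):

import torch

w = torch.tensor([1.0, 2.0], requires_grad=True)

# Two backward passes WITHOUT clearing: gradients add up
(w * 3).sum().backward()
print(w.grad)            # tensor([3., 3.])

(w * 3).sum().backward()
print(w.grad)            # tensor([6., 6.])  <- accumulated, not replaced

# Clear, then backward again: back to a single pass's gradient
w.grad = None
(w * 3).sum().backward()
print(w.grad)            # tensor([3., 3.])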

Stopping gradients / turning off autograd

These do not “clear grads”; they prevent gradient tracking/propagation.

requires_grad_(False): marks a tensor/parameter as frozen (until you flip it back); no grads will be computed for it.

Use when freezing layers (transfer learning) or for fixed embeddings, etc.

tensor.detach(): returns a tensor that shares the same data but carries no autograd history; backprop stops at that boundary.

Use to truncate the graph (e.g., stop gradient into an input, copy targets from model outputs for bootstrapping, avoid accidental second-order graphs).

with torch.no_grad(): temporarily disables autograd inside the block; no graph is built, so ops are faster and cheaper.

Use for inference, EMA updates, or metric computation.

Bonus: with torch.inference_mode(): even leaner than no_grad for inference (more aggressive); tensors created inside it can never be used later in grad-tracked computations, so reserve it for pure inference.
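
A minimal sketch of inference_mode (the linear model and input shapes below are just placeholders); it is used exactly like no_grad:

import torch
import torch.nn as nn

model = nn.Linear(10, 1)
model.eval()

x = torch.randn(4, 10)

with torch.inference_mode():       # like no_grad, but also skips view/version-counter tracking
    preds = model(x)

print(preds.requires_grad)         # False
# Note: tensors created here cannot be reused later in grad-tracked computations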

requires_grad_(False)

This means: "Don’t compute derivative for this variable anymore."

detach()

This means: “Make a new tensor that looks the same, but breaks the gradient chain.”

torch.no_grad()

This means: “Turn off autograd temporarily.”

Use case:

During inference (testing), where we don’t need gradients.

It saves memory and computation time.
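
Putting the three together, here is a tiny tensor-level sketch (the values are arbitrary) of how each one behaves:

import torch

w = torch.randn(3, requires_grad=True)

# 1) requires_grad_(False): freeze the tensor itself
w.requires_grad_(False)
print(w.requires_grad)             # False
w.requires_grad_(True)             # flip it back on for the next two demos

# 2) detach(): same data, but the result carries no history
y = (w * 2).detach()
print(y.requires_grad)             # False

# 3) no_grad(): nothing computed inside the block is tracked
with torch.no_grad():
    z = w * 2
print(z.requires_grad)             # False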

1) Correct training loop: clear grads each step

import torch
import torch.nn as nn
import torch.optim as optim

model = nn.Linear(10, 1)
opt = optim.Adam(model.parameters(), lr=1e-3)
loss_fn = nn.BCEWithLogitsLoss()

for xb, yb in loader:                      # your DataLoader
    opt.zero_grad(set_to_none=True)        # ← CLEAR accumulated grads (A)
    logits = model(xb)                     # forward (graph is recorded)
    loss = loss_fn(logits, yb.float())
    loss.backward()                        # compute grads
    opt.step()                             # update params

Alternatives to clear grads
# option 1: model.zero_grad()    # zeros existing tensors (older style)
# option 2: for p in model.parameters(): p.grad = None
# option 3: for p in model.parameters():
#              if p.grad is not None: p.grad.zero_()

2) Freeze some layers (no grads created)

# Freeze feature extractor; train only classifier head
for p in backbone.parameters():
    p.requires_grad_(False)                # ← stop creating grads for these params (B)

# Only the head’s params will appear in optimizer
opt = optim.SGD(head.parameters(), lr=1e-2)
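
If the frozen backbone and the trainable head live inside one model object (called model here purely as an assumption about how the network is organised, reusing the optim import from earlier), a common alternative is to build the optimizer from whatever still requires gradients:

# Alternative: collect only the parameters that are still trainable
trainable_params = [p for p in model.parameters() if p.requires_grad]
opt = optim.SGD(trainable_params, lr=1e-2)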

3) Detach to stop gradient flow through a tensor

out = model(x)                 # requires grad
target = out.detach()          # ← gradient will not flow into 'out' via 'target'

loss = ((out - target)**2).mean()  # trivial example (bootstrapping style)
loss.backward()               # grads flow to 'out' path only, not to 'target'


Another common use: truncating backprop-through-time (BPTT) in RNNs:

h = torch.zeros(batch, hidden)
for t in range(T):
    h = rnn_cell(x[:, t], h)
    if (t + 1) % trunc_k == 0:
        h = h.detach()        # ← cut graph to limit backprop length

4) no_grad / inference_mode for evaluation & EMA

# Evaluation: faster & memory-saving
model.eval()
with torch.no_grad():                     # ← no graph built inside
    for xb, yb in val_loader:
        preds = torch.sigmoid(model(xb))
        # compute metrics…

# EMA update: parameter averaging without tracking grads
with torch.no_grad():
    for p, p_ema in zip(model.parameters(), ema_model.parameters()):
        p_ema.mul_(0.999).add_(p, alpha=0.001)

10 practical areas / use-cases where torch.no_grad() is required (or strongly recommended):

  1. Model Evaluation / Inference

When you’re running your trained model on validation or test data:

model.eval()
with torch.no_grad():
    preds = model(x_val)

➡ Prevents building the autograd graph and storing intermediate activations for backprop.

🧪 2. Validation Loop (During Training)

Inside your training epoch, when evaluating after each batch/epoch:

if step % eval_interval == 0:
    with torch.no_grad():
        val_loss = criterion(model(x_val), y_val)

➡ Keeps evaluation memory-efficient.

📊 3. Computing Accuracy or Metrics

When calculating accuracy, F1-score, confusion matrix, etc.:

with torch.no_grad():
    outputs = model(x_test)
    acc = (outputs.argmax(1) == y_test).float().mean()

🧍 4. Serving / Deployment (Prediction API)

In Flask/FastAPI or other inference servers:

@app.post("/predict")
def predict(data):
    with torch.no_grad():
        result = model(data)
    return result.tolist()    # return plain Python data, not a tensor

➡ Prevents unnecessary gradient tracking in production.

🔁 5. Feature Extraction / Embedding Generation

Using pretrained networks (e.g., ResNet, BERT) for feature extraction:

with torch.no_grad():
    features = backbone(imgs)

🧩 6. Transfer Learning (Frozen Layers)

When you freeze base layers and train only a new head:

with torch.no_grad():
    base_output = base_model(x)

🎯 7. Teacher-Student / Knowledge Distillation (Teacher Forward Pass)

Teacher model shouldn’t track gradients:

with torch.no_grad():
    teacher_output = teacher_model(x)

📷 8. Image Generation or Translation

When using GAN generators for inference:

with torch.no_grad():
    fake_images = generator(noise)

📈 9. Validation in Reinforcement Learning / Simulation

During agent evaluation (no learning updates):

with torch.no_grad():
    action = policy_net(state)

🧮 10. Post-processing Outputs (e.g., Softmax, Argmax, Thresholding)

When computing softmax or argmax for predictions:

with torch.no_grad():
    probs = torch.softmax(logits, dim=1)

Real-Life Usage Examples: Clearing vs. Stopping Gradients

1) Training Loop (Use optimizer.zero_grad())

Gradients accumulate by default in PyTorch. So we clear them before computing new gradients.

for images, labels in train_loader:

    optimizer.zero_grad()        # 1. Clear previous gradients

    outputs = model(images)      # 2. Forward pass
    loss = criterion(outputs, labels)

    loss.backward()              # 3. Backprop compute gradients
    optimizer.step()             # 4. Update model weights

2) Transfer Learning (Use requires_grad_(False))

Freeze pretrained backbone, train only classifier.

# Freeze convolution layers
for param in model.features.parameters():
    param.requires_grad_(False)

# Only train classifier part
optimizer = torch.optim.Adam(model.classifier.parameters(), lr=1e-3)

3) Stop Gradient Flow in Middle (Use tensor.detach())

Useful in GANs or when reusing feature outputs.

features = model.encoder(images)
features_no_grad = features.detach()   # No gradient flows into encoder

# Use features for something else
result = another_network(features_no_grad)

4) Evaluation (Use with torch.no_grad())

We don’t train during inference, so no need to compute gradients.

model.eval()        # Set dropout/batchnorm to eval mode

with torch.no_grad():
    for images, labels in val_loader:
        outputs = model(images)
        # compute accuracy or metrics

This reduces memory usage and speeds up inference.
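
To sanity-check that claim yourself, a rough timing sketch along these lines can help (the layer sizes and iteration count below are arbitrary):

import time
import torch
import torch.nn as nn

model = nn.Sequential(nn.Linear(512, 512), nn.ReLU(), nn.Linear(512, 10))
model.eval()
x = torch.randn(256, 512)

def run(n=200):
    start = time.time()
    for _ in range(n):
        model(x)
    return time.time() - start

t_grad = run()                     # autograd graph is built on every forward pass
with torch.no_grad():
    t_no_grad = run()              # no graph, no saved activations

print(f"with autograd: {t_grad:.3f}s | with no_grad: {t_no_grad:.3f}s")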

