Debug School

rakesh kumar


How to Make Machine Learning Models Faster and Lighter

Pruning

= delete (or zero out) parts of a model that contribute little → fewer parameters and, when the removed parts are actually skipped at run time, fewer FLOPs and faster inference.

Quantization

= store & compute with fewer bits (e.g., 8-bit integers instead of 32-bit floats) → smaller memory footprint, better cache utilization, faster CPU ops when int8 kernels are available.

Both aim to fit tight edge budgets (CPU-only, small RAM) while keeping accuracy good enough for real-time control.
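A minimal sketch of the arithmetic behind int8 weight quantization, assuming per-tensor symmetric scaling (one scale for the whole tensor, zero point 0). Real toolkits also offer per-channel scales and asymmetric schemes; `quantize_int8_symmetric` and `dequantize` are illustrative names, not a library API.

```python
import numpy as np

def quantize_int8_symmetric(w):
    """Per-tensor symmetric int8 quantization: w ≈ scale * q with q in [-127, 127]."""
    scale = float(np.abs(w).max()) / 127.0
    q = np.clip(np.round(w / scale), -127, 127).astype(np.int8)
    return q, scale

def dequantize(q, scale):
    """Map int8 codes back to approximate float32 values."""
    return q.astype(np.float32) * scale

rng = np.random.default_rng(0)
w = rng.normal(size=(128, 64)).astype(np.float32)  # stand-in for a Linear weight

q, scale = quantize_int8_symmetric(w)
w_hat = dequantize(q, scale)

print("bytes fp32 / int8:", w.nbytes, "/", q.nbytes)  # 32768 / 8192
print("max abs error    :", np.abs(w - w_hat).max(), "<= scale/2 =", scale / 2)
```

Rounding to the nearest code bounds the per-weight error by half a quantization step (`scale / 2`), while storage drops by 4x.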

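Where they act in a network: both target the weight-heavy layers (Linear, Conv, LSTM/GRU). Pruning removes individual weights or whole channels/neurons; quantization lowers the numeric precision of weights and activations.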

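Two flavors of pruning: unstructured (zero out individual weights, e.g., the smallest-magnitude ones) and structured (remove whole channels, neurons, or heads).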

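Quantization pipeline (typical PTQ): train in FP32 → calibrate on representative data (static PTQ) → convert to int8 → re-check accuracy and latency.

PTQ (Post-Training Quantization): no retraining required. Fastest path to int8.

QAT (Quantization-Aware Training): train with fake-quant modules that simulate int8 rounding during the forward pass → better int8 accuracy, especially for CNNs.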
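Static PTQ needs a calibration pass to choose activation ranges. The sketch below mimics what a min/max observer does, assuming asymmetric uint8 quantization (scale and zero point derived from the observed range). The ReLU-like random arrays are a hypothetical stand-in for activations recorded from real sensor inputs, and the function names are illustrative.

```python
import numpy as np

def calibrate_minmax(batches):
    """Track running min/max over calibration batches, then derive scale/zero_point."""
    lo, hi = np.inf, -np.inf
    for a in batches:
        lo, hi = min(lo, float(a.min())), max(hi, float(a.max()))
    lo, hi = min(lo, 0.0), max(hi, 0.0)        # range must include 0 so 0.0 is exact
    scale = max((hi - lo) / 255.0, 1e-12)      # 256 uint8 codes span [lo, hi]
    zero_point = int(round(-lo / scale))
    return scale, zero_point

def quantize_uint8(x, scale, zero_point):
    return np.clip(np.round(x / scale) + zero_point, 0, 255).astype(np.uint8)

def dequantize_uint8(q, scale, zero_point):
    return (q.astype(np.float32) - zero_point) * scale

rng = np.random.default_rng(0)
# Hypothetical calibration set: 100 batches of non-negative (post-ReLU) activations
calib = [np.maximum(rng.normal(size=256), 0.0) * 3.0 for _ in range(100)]
scale, zp = calibrate_minmax(calib)

x = np.maximum(rng.normal(size=1000), 0.0) * 3.0   # "new" activations at inference
x_hat = dequantize_uint8(quantize_uint8(x, scale, zp), scale, zp)
in_range = x <= (255 - zp) * scale                  # values beyond the calibrated max get clipped
print(f"scale={scale:.5f} zero_point={zp}")
print("max abs error (in range):", np.abs(x - x_hat)[in_range].max())
```

This is why calibration data matters: values outside the range seen during calibration are clipped, and an overly wide range wastes codes and increases rounding error.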

Why robots care

Latency (ms-level) & consistency (low jitter) matter for control loops.

Size matters (RAM/flash budgets).

Robustness to noise: simpler models + calibrated quantization + structured pruning → more predictable latency and fewer runtime surprises (re-validate on real, noisy sensor data).

Metrics to watch

Accuracy (task metric)

Latency (avg & p95/p99)

Model size (MB)

Sparsity (% zeros) & MACs/FLOPs

Energy (optional but relevant on battery)
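An average latency hides the tail that control loops care about. A minimal sketch of collecting p50/p95/p99 and a simple jitter measure (p99 minus p50) with only the standard library; `latency_profile_ms` is a hypothetical helper name, and nearest-rank percentiles are one of several reasonable definitions.

```python
import math
import statistics
import time

def latency_profile_ms(fn, n_warm=20, n_runs=500):
    """Time fn() n_runs times after n_warm untimed warm-up calls; return stats in ms."""
    for _ in range(n_warm):
        fn()
    samples = []
    for _ in range(n_runs):
        t0 = time.perf_counter()          # monotonic, high-resolution clock
        fn()
        samples.append((time.perf_counter() - t0) * 1000.0)
    samples.sort()

    def pct(p):
        # Nearest-rank percentile on the sorted samples
        return samples[max(0, math.ceil(p / 100.0 * len(samples)) - 1)]

    return {
        "mean": statistics.fmean(samples),
        "p50": pct(50),
        "p95": pct(95),
        "p99": pct(99),
        "max": samples[-1],
        "jitter_p99_minus_p50": pct(99) - pct(50),
    }

# Stand-in workload; for a model you would pass e.g. lambda: model(inp)
stats = latency_profile_ms(lambda: sum(i * i for i in range(2000)))
for k, v in stats.items():
    print(f"{k:>22s}: {v:.4f} ms")
```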

ML EXAMPLES (scikit-learn)

We’ll show:

Decision Tree pruning (cost-complexity)

L1 “pruning” of linear/logistic models (drives coefficients to zero)

Note: scikit-learn has no built-in int8 quantization of models. For edge deployment you typically export to ONNX (e.g., with skl2onnx) and let a runtime such as ONNX Runtime quantize and execute it, or you choose small models plus pruning/L1.

1) Decision Tree – cost-complexity pruning

# ml_pruning_tree.py
import time
from sklearn.datasets import load_iris
from sklearn.model_selection import train_test_split, cross_val_score
from sklearn.tree import DecisionTreeClassifier
from sklearn.tree import plot_tree  # optional (for visualization)
from sklearn.metrics import accuracy_score

X, y = load_iris(return_X_y=True)
Xtr, Xte, ytr, yte = train_test_split(X, y, test_size=0.2, random_state=42)

# Baseline (unpruned)
base = DecisionTreeClassifier(random_state=42)
base.fit(Xtr, ytr)

# Cost-complexity pruning path gives candidate ccp_alphas (ascending).
# Drop the last one: it prunes the tree down to a single leaf.
path = base.cost_complexity_pruning_path(Xtr, ytr)
ccp_alphas = path.ccp_alphas[:-1]

# Choose alpha by 5-fold cross-validation on the TRAINING set only
# (choosing it on the test set would leak test data into model selection).
# On ties, ">=" keeps the larger alpha, i.e. the smaller tree.
best_alpha, best_cv = 0.0, -1.0
for ccp in ccp_alphas:
    cv_acc = cross_val_score(
        DecisionTreeClassifier(random_state=42, ccp_alpha=ccp), Xtr, ytr, cv=5
    ).mean()
    if cv_acc >= best_cv:
        best_alpha, best_cv = ccp, cv_acc

best = DecisionTreeClassifier(random_state=42, ccp_alpha=best_alpha).fit(Xtr, ytr)

def latency_ms_per_sample(model, X):
    # Batch predict time divided by batch size (amortized, not single-call latency)
    t0 = time.perf_counter()
    model.predict(X)
    return (time.perf_counter() - t0) / len(X) * 1000.0

for name, m in [("Unpruned tree", base), ("Best pruned tree", best)]:
    print(f"{name}:")
    print(f"  accuracy = {accuracy_score(yte, m.predict(Xte)):.3f}")
    print(f"  latency  = {latency_ms_per_sample(m, Xte):.4f} ms/sample")
    print(f"  nodes    = {m.tree_.node_count}")
print(f"  (ccp_alpha={best_alpha:.6f}, cv accuracy={best_cv:.3f})")


What this gives: a smaller tree → shorter root-to-leaf paths → faster, more predictable inference on CPU.
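One way to see the speed effect without noisy timers is to count how many nodes each prediction visits. A minimal sketch, assuming the same iris split as the example above; `ccp_alpha=0.02` is an arbitrary illustrative value rather than a tuned one, and `avg_nodes_visited` is a hypothetical helper built on scikit-learn's `decision_path`.

```python
from sklearn.datasets import load_iris
from sklearn.model_selection import train_test_split
from sklearn.tree import DecisionTreeClassifier

X, y = load_iris(return_X_y=True)
Xtr, Xte, ytr, yte = train_test_split(X, y, test_size=0.2, random_state=42)

full = DecisionTreeClassifier(random_state=42).fit(Xtr, ytr)
# ccp_alpha=0.02 is illustrative only, not a tuned value
pruned = DecisionTreeClassifier(random_state=42, ccp_alpha=0.02).fit(Xtr, ytr)

def avg_nodes_visited(tree, X):
    # decision_path returns a sparse (n_samples, n_nodes) indicator matrix;
    # each row marks the nodes one sample passes through from root to leaf
    return float(tree.decision_path(X).sum(axis=1).mean())

for name, t in [("full", full), ("pruned", pruned)]:
    print(f"{name:6s}: depth={t.get_depth()}, leaves={t.get_n_leaves()}, "
          f"avg nodes visited={avg_nodes_visited(t, Xte):.2f}")
```

With the same `random_state`, cost-complexity pruning only cuts back the fully grown tree, so every path in the pruned tree is a prefix of the corresponding path in the full tree and the visited-node count can only go down.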

2) L1 “pruning” (sparse coefficients)

# ml_pruning_l1.py
import time
import numpy as np
from sklearn.datasets import load_iris
from sklearn.model_selection import train_test_split
from sklearn.pipeline import make_pipeline
from sklearn.preprocessing import StandardScaler
from sklearn.linear_model import LogisticRegression
from sklearn.metrics import accuracy_score

X, y = load_iris(return_X_y=True)
Xtr, Xte, ytr, yte = train_test_split(X, y, test_size=0.2, random_state=42)

# L1 penalty drives some weights to exactly zero (like pruning).
# saga converges far faster on standardized features, hence the scaler.
pipe = make_pipeline(
    StandardScaler(),
    LogisticRegression(penalty="l1", solver="saga", C=0.5, max_iter=5000),
)
pipe.fit(Xtr, ytr)
clf = pipe[-1]  # the fitted LogisticRegression step

t0 = time.perf_counter()
yp = pipe.predict(Xte)
t1 = time.perf_counter()

acc = accuracy_score(yte, yp)
latency_ms = (t1 - t0) / len(Xte) * 1000.0  # amortized over the batch
sparsity = np.mean(clf.coef_ == 0.0)

print(f"Accuracy      : {acc:.3f}")
print(f"Latency       : {latency_ms:.4f} ms/sample")
print(f"Weight zeros  : {sparsity*100:.1f}%")



Takeaway: sparse coefficients mean fewer multiply-adds, and any feature whose weight is zero for every class can be dropped from the input pipeline entirely.
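To check which inputs an L1 model actually uses, look for coefficient columns that are zero for every class. A minimal sketch, assuming a deliberately strong penalty (`C=0.05`, an illustrative value) so that some coefficients hit exactly zero; how many features survive depends on the data and `C`.

```python
import numpy as np
from sklearn.datasets import load_iris
from sklearn.preprocessing import StandardScaler
from sklearn.linear_model import LogisticRegression

X, y = load_iris(return_X_y=True)
Xs = StandardScaler().fit_transform(X)

# Strong L1 (small C) so that some coefficients become exactly zero
clf = LogisticRegression(penalty="l1", solver="saga", C=0.05, max_iter=10000).fit(Xs, y)

# A feature is needed only if it has a nonzero weight for at least one class
used = np.any(clf.coef_ != 0.0, axis=0)
print("features kept:", np.flatnonzero(used), "of", X.shape[1])

# Scores computed from the kept features alone equal the full model's scores,
# because the dropped features only ever multiplied zero weights
reduced = Xs[:, used] @ clf.coef_[:, used].T + clf.intercept_
print("identical scores:", np.allclose(reduced, clf.decision_function(Xs)))
```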

DL EXAMPLES (PyTorch)

We’ll show:

Unstructured & structured pruning via torch.nn.utils.prune

Dynamic PTQ (int8) for Linear/LSTM

Static PTQ (FX graph mode) for a tiny CNN (with calibration)

QAT sketch for best accuracy at int8

These run on CPU and illustrate what to change. Replace synthetic data with your sensor features.

Helper: latency + size + sparsity utilities
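# utils_perf.py
import os
import tempfile
import time
import torch

def measure_latency_ms(model, inp, n_warm=10, n_runs=50):
    """Average wall-clock latency of model(inp) in milliseconds."""
    model.eval()
    with torch.no_grad():
        for _ in range(n_warm):       # warm-up: allocator, caches, lazy init
            model(inp)
        t0 = time.perf_counter()      # monotonic, high-resolution timer
        for _ in range(n_runs):
            model(inp)
        t1 = time.perf_counter()
    return (t1 - t0) / n_runs * 1000.0

def count_nonzero_params(model):
    """Return (nonzero, total, sparsity) over all parameters (weights and biases)."""
    nz, tot = 0, 0
    for p in model.parameters():
        tot += p.numel()
        nz += (p != 0).sum().item()
    return nz, tot, 1.0 - nz / max(tot, 1)

def save_size_mb(model):
    """Serialized state_dict size in MB (dense tensors: zeros still cost bytes)."""
    fd, path = tempfile.mkstemp(suffix=".pth")
    os.close(fd)
    try:
        torch.save(model.state_dict(), path)
        return os.path.getsize(path) / (1024 * 1024)
    finally:
        os.remove(path)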

A. PRUNING (PyTorch)
A1) Unstructured pruning (magnitude)

# dl_prune_unstructured.py
import torch
import torch.nn as nn
import torch.nn.utils.prune as prune
from utils_perf import measure_latency_ms, count_nonzero_params, save_size_mb

# Tiny MLP for demonstration
class MLP(nn.Module):
    def __init__(self, d_in=64, d_hidden=128, d_out=10):
        super().__init__()
        self.net = nn.Sequential(
            nn.Linear(d_in, d_hidden),
            nn.ReLU(),
            nn.Linear(d_hidden, d_out)
        )
    def forward(self, x):
        return self.net(x)

model = MLP()
inp = torch.randn(1, 64)

lat0 = measure_latency_ms(model, inp)
nz0, tot0, sparsity0 = count_nonzero_params(model)
size0 = save_size_mb(model)

# Apply 80% unstructured pruning on Linear weights
for m in model.modules():
    if isinstance(m, nn.Linear):
        prune.l1_unstructured(m, name="weight", amount=0.8)

# Remove reparametrization (make pruning permanent)
for m in model.modules():
    if isinstance(m, nn.Linear) and hasattr(m, "weight_orig"):
        prune.remove(m, "weight")

lat1 = measure_latency_ms(model, inp)
nz1, tot1, sparsity1 = count_nonzero_params(model)
size1 = save_size_mb(model)

print(f"Latency (ms)    before/after: {lat0:.3f} / {lat1:.3f}")
print(f"Sparsity        before/after: {sparsity0*100:.1f}% / {sparsity1*100:.1f}%")
print(f"Model size (MB) before/after: {size0:.3f} / {size1:.3f}")  # unchanged: zeros are stored densely


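Note: unstructured zeros usually don't speed up dense kernels, and the saved state_dict stays the same size because the zeros are still stored densely. Memory and speed benefits appear only with sparse storage formats or sparse-aware backends. For deterministic speed on CPU, prefer structured pruning.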
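To put a number on the storage side, the sketch below converts an 80%-pruned `Linear` weight to PyTorch's CSR layout and counts bytes. It assumes CSR with the default int64 indices; other formats (or smaller index types in a deployment runtime) change the trade-off, and COO would use even more index memory.

```python
import torch
import torch.nn as nn
import torch.nn.utils.prune as prune

torch.manual_seed(0)
layer = nn.Linear(64, 128)                                 # weight shape: [128, 64]
prune.l1_unstructured(layer, name="weight", amount=0.8)    # zero the 80% smallest-|w|
prune.remove(layer, "weight")                              # bake the mask into .weight

w = layer.weight.detach()
dense_bytes = w.numel() * w.element_size()

# CSR stores only the nonzero values plus column indices and row pointers
w_csr = w.to_sparse_csr()
csr_bytes = sum(t.numel() * t.element_size()
                for t in (w_csr.values(), w_csr.col_indices(), w_csr.crow_indices()))

print(f"nonzeros : {w_csr.values().numel()} / {w.numel()}")
print(f"dense    : {dense_bytes} bytes")
print(f"CSR      : {csr_bytes} bytes")
print("lossless :", torch.equal(w_csr.to_dense(), w))
```

Because each stored value also carries an 8-byte column index, even 80% sparsity saves well under 80% of the bytes, which is one reason unstructured pruning alone is a weak size lever.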

A2) Structured channel pruning (conv channels)

# dl_prune_structured.py
import torch
import torch.nn as nn
import torch.nn.utils.prune as prune
from utils_perf import measure_latency_ms, count_nonzero_params, save_size_mb

class TinyCNN(nn.Module):
    def __init__(self):
        super().__init__()
        self.conv1 = nn.Conv2d(3, 16, 3, padding=1)
        self.relu = nn.ReLU()
        self.conv2 = nn.Conv2d(16, 32, 3, padding=1)
        self.pool = nn.AdaptiveAvgPool2d(1)
        self.fc = nn.Linear(32, 10)
    def forward(self, x):
        x = self.relu(self.conv1(x))
        x = self.relu(self.conv2(x))
        x = self.pool(x).flatten(1)
        return self.fc(x)

model = TinyCNN()
inp = torch.randn(1, 3, 64, 64)

lat0 = measure_latency_ms(model, inp)
nz0, tot0, s0 = count_nonzero_params(model)
size0 = save_size_mb(model)

# Zero out 50% of conv2's output channels, ranked by L2 norm (n=2).
# dim=0 selects output channels. This masks them; tensor shapes are unchanged.
prune.ln_structured(model.conv2, name="weight", amount=0.5, n=2, dim=0)
prune.remove(model.conv2, "weight")

lat1 = measure_latency_ms(model, inp)
nz1, tot1, s1 = count_nonzero_params(model)
size1 = save_size_mb(model)

print(f"Latency (ms)    before/after: {lat0:.3f} / {lat1:.3f}")
print(f"Sparsity        before/after: {s0*100:.1f}% / {s1*100:.1f}%")
print(f"Model size (MB) before/after: {size0:.3f} / {size1:.3f}")


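Why structured helps: once the zeroed channels are physically removed, the following ops shrink too → actual FLOP reduction and stable CPU speedups. torch.nn.utils.prune on its own only masks them, so shapes (and latency) stay the same until the layers are rebuilt smaller.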
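A minimal sketch of physically removing the channels that `ln_structured` zeroed, assuming the same `TinyCNN` layout as A2 (conv2 feeding a global average pool and then `fc`). `shrink_conv2` is a hypothetical helper, not a PyTorch API. One subtlety: pruning masks only the weights, so a dropped channel still outputs `relu(bias)`, a constant after pooling; the sketch folds that constant into `fc.bias` so the smaller model's outputs match.

```python
import torch
import torch.nn as nn
import torch.nn.utils.prune as prune

class TinyCNN(nn.Module):  # same layout as A2, with conv2 width as a parameter
    def __init__(self, c2=32):
        super().__init__()
        self.conv1 = nn.Conv2d(3, 16, 3, padding=1)
        self.relu = nn.ReLU()
        self.conv2 = nn.Conv2d(16, c2, 3, padding=1)
        self.pool = nn.AdaptiveAvgPool2d(1)
        self.fc = nn.Linear(c2, 10)
    def forward(self, x):
        x = self.relu(self.conv1(x))
        x = self.relu(self.conv2(x))
        return self.fc(self.pool(x).flatten(1))

@torch.no_grad()
def shrink_conv2(model):
    """Rebuild the model without conv2 output channels whose weights are all zero."""
    w = model.conv2.weight
    keep = w.flatten(1).abs().sum(dim=1) != 0      # surviving output channels
    drop = ~keep
    small = TinyCNN(c2=int(keep.sum()))
    small.conv1.load_state_dict(model.conv1.state_dict())
    small.conv2.weight.copy_(w[keep])
    small.conv2.bias.copy_(model.conv2.bias[keep])
    small.fc.weight.copy_(model.fc.weight[:, keep])
    # A dropped channel outputs relu(bias) everywhere -> constant after pooling;
    # fold its contribution into fc.bias so outputs stay the same
    const = torch.relu(model.conv2.bias[drop])
    small.fc.bias.copy_(model.fc.bias + model.fc.weight[:, drop] @ const)
    return small.eval()

torch.manual_seed(0)
model = TinyCNN().eval()
prune.ln_structured(model.conv2, name="weight", amount=0.5, n=2, dim=0)
prune.remove(model.conv2, "weight")

small = shrink_conv2(model)
x = torch.randn(4, 3, 64, 64)
with torch.no_grad():
    print("max |diff|:", (model(x) - small(x)).abs().max().item())
print("conv2 channels:", model.conv2.out_channels, "->", small.conv2.out_channels)
```

The rebuilt conv2 and fc are genuinely smaller, so `measure_latency_ms` on `small` should reflect the FLOP reduction. Deeper networks need the same bookkeeping for every layer that consumes the pruned channels (including any BatchNorm in between).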

B. QUANTIZATION (PyTorch)
B1) Dynamic quantization (fastest path; great for Linear/LSTM on CPU)

# dl_quant_dynamic.py
import torch
import torch.nn as nn
from utils_perf import measure_latency_ms, save_size_mb

class TinyRNN(nn.Module):
    def __init__(self, d_in=32, d_hidden=64, d_out=6):
        super().__init__()
        self.rnn = nn.LSTM(d_in, d_hidden, num_layers=1, batch_first=True)
        self.fc = nn.Linear(d_hidden, d_out)
    def forward(self, x):
        # x: [B, T, d_in]
        y, _ = self.rnn(x)
        return self.fc(y[:, -1, :])

model_fp32 = TinyRNN().eval()
inp = torch.randn(1, 20, 32)

lat0 = measure_latency_ms(model_fp32, inp)
size0 = save_size_mb(model_fp32)

# Dynamically quantize only the listed layer types: weights are stored as int8,
# activations are quantized on the fly at inference time
model_int8 = torch.ao.quantization.quantize_dynamic(
    model_fp32, {nn.Linear, nn.LSTM}, dtype=torch.qint8
).eval()

lat1 = measure_latency_ms(model_int8, inp)
size1 = save_size_mb(model_int8)

print(f"Latency (ms)    FP32 / INT8(dynamic): {lat0:.3f} / {lat1:.3f}")
print(f"Model size (MB) FP32 / INT8(dynamic): {size0:.3f} / {size1:.3f}")



Use when: CPU edge device, MLP/RNN control heads, quick wins without retraining.

B2) Static PTQ (FX graph mode) for a tiny CNN

# dl_quant_static_ptq.py
import torch
import torch.nn as nn
from torch.ao.quantization import get_default_qconfig_mapping
from torch.ao.quantization.quantize_fx import prepare_fx, convert_fx
from utils_perf import measure_latency_ms, save_size_mb

class TinyCNN(nn.Module):
    def __init__(self):
        super().__init__()
        self.seq = nn.Sequential(
            nn.Conv2d(3, 16, 3, stride=1, padding=1),
            nn.ReLU(),
            nn.Conv2d(16, 16, 3, stride=1, padding=1),
            nn.ReLU(),
            nn.AdaptiveAvgPool2d(1),
            nn.Flatten(),
            nn.Linear(16, 10)
        )
    def forward(self, x):
        return self.seq(x)

model = TinyCNN().eval()
example = torch.randn(1, 3, 64, 64)

lat0 = measure_latency_ms(model, example)
size0 = save_size_mb(model)

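# 1) Choose backend/qconfig: "qnnpack" targets ARM/Android, "fbgemm" targets x86
backend = "qnnpack"
torch.backends.quantized.engine = backend  # must be in torch.backends.quantized.supported_engines
qconfig_mapping = get_default_qconfig_mapping(backend)

# 2) Prepare the FX graph: inserts observers (and fuses patterns such as Conv2d+ReLU).
#    Current PyTorch expects the QConfigMapping itself plus example inputs.
prepared = prepare_fx(model, qconfig_mapping, example_inputs=(example,))

# 3) Calibrate: run representative inputs through the observers.
#    Random tensors are placeholders; feed real sensor samples here.
with torch.no_grad():
    for _ in range(32):
        prepared(torch.randn(1, 3, 64, 64))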

# 4) Convert to int8
int8_model = convert_fx(prepared).eval()

lat1 = measure_latency_ms(int8_model, example)
size1 = save_size_mb(int8_model)

print(f"Latency (ms)    FP32 / INT8(static): {lat0:.3f} / {lat1:.3f}")
print(f"Model size (MB) FP32 / INT8(static): {size0:.3f} / {size1:.3f}")

Calibrate carefully: use a few hundred real sensor samples for best accuracy.

B3) QAT sketch (best accuracy @ int8 for CNNs)

# dl_qat_skeleton.py
import torch
import torch.nn as nn
import torch.optim as optim
from torch.ao.quantization import get_default_qat_qconfig_mapping
from torch.ao.quantization.quantize_fx import prepare_qat_fx, convert_fx

class TinyCNN(nn.Module):
    def __init__(self):
        super().__init__()
        self.seq = nn.Sequential(
            nn.Conv2d(3, 16, 3, padding=1), nn.ReLU(),
            nn.Conv2d(16,16,3,padding=1),   nn.ReLU(),
            nn.AdaptiveAvgPool2d(1), nn.Flatten(), nn.Linear(16, 10)
        )
    def forward(self, x): return self.seq(x)

model = TinyCNN().train()  # in practice, start from a trained FP32 model and fine-tune
backend = "qnnpack"
torch.backends.quantized.engine = backend
qconfig_mapping = get_default_qat_qconfig_mapping(backend)

# Insert fake-quant modules: int8 rounding is simulated in the forward pass,
# while gradients still flow in FP32
example_inputs = (torch.randn(1, 3, 64, 64),)
model_qat = prepare_qat_fx(model, qconfig_mapping, example_inputs).train()

opt = optim.Adam(model_qat.parameters(), lr=1e-3)
loss_fn = nn.CrossEntropyLoss()

# Train as usual with fake-quant active (random data is a placeholder for your dataset)
for step in range(200):  # demo loop
    x = torch.randn(16, 3, 64, 64)
    y = torch.randint(0, 10, (16,))
    loss = loss_fn(model_qat(x), y)
    opt.zero_grad(); loss.backward(); opt.step()

# Convert to a real int8 model for inference
model_int8 = convert_fx(model_qat.eval()).eval()


Putting it together: which stages benefit most?

Perception & control networks (Conv/MLP/RNN heads) → Quantization (int8) + Structured pruning (channels) for real CPU gains.

Feature encoders (heavy backbones) → QAT (keep accuracy) + careful structured pruning of later blocks.

Classical ML pieces (trees, linear heads) → prune (cost-complexity / L1) and/or replace with smaller models; quantization usually handled by deployment runtime if needed.

Practical workflow (ASCII diagram)

Train FP32  --->  Profile (lat/size/acc)  --->  Choose targets
                                         \
                                          \--> Structured prune (channels/heads) -> Fine-tune
                                              -> PTQ (dynamic/static) or QAT
                                              -> Re-profile (p50/p95 latency, acc, size)
                                              -> Iterate until SLA met

Which Stages Benefit Most from Quantization or Pruning?

Best Stage to Apply Quantization/Pruning

The Model / Inference Stage
This is where matrix multiplications and decision logic happen.

❌ Not Useful to Quantize/Prune

So the parts that benefit the most are:

Neural network layers (Linear, Conv, GRU/LSTM)

Decision trees / Random Forests (by pruning depth or removing weak nodes)

Large linear models (by pruning small-magnitude weights)

Part 1: Classical ML Example

We will:

Train Logistic Regression

Prune small weights

Store the weights as float16 to halve their memory (a simple stand-in for quantization)

# ===== classical_prune_quant.py =====
import numpy as np
from sklearn.datasets import load_iris
from sklearn.model_selection import train_test_split
from sklearn.preprocessing import StandardScaler
from sklearn.linear_model import LogisticRegression
from sklearn.pipeline import Pipeline
from sklearn.metrics import accuracy_score

# Load dataset and hold out a test split (accuracy on training data is optimistic)
X, y = load_iris(return_X_y=True)
Xtr, Xte, ytr, yte = train_test_split(X, y, test_size=0.2, random_state=42)

# Build pipeline (lbfgs fits a multinomial model for multiclass targets by default)
pipe = Pipeline([
    ("scaler", StandardScaler()),
    ("clf", LogisticRegression(max_iter=500)),
])

pipe.fit(Xtr, ytr)
clf = pipe.named_steps["clf"]

print("Original weights shape:", clf.coef_.shape)
print("Accuracy (float64):", accuracy_score(yte, pipe.predict(Xte)))

# ----- PRUNING: zero the 20% smallest-magnitude weights -----
threshold = np.percentile(np.abs(clf.coef_), 20)
mask = np.abs(clf.coef_) > threshold
clf.coef_ = clf.coef_ * mask

print("Fraction of weights zeroed:", 1.0 - mask.mean())

# ----- "QUANTIZATION": store weights as float16 -----
# Halves coefficient storage. scikit-learn still computes in float64 at
# predict time, so this alone does not speed up inference.
clf.coef_ = clf.coef_.astype(np.float16)
clf.intercept_ = clf.intercept_.astype(np.float16)

print("After quantization dtype:", clf.coef_.dtype)

# Check accuracy on the held-out split
y_pred = pipe.predict(Xte)
print("Accuracy after prune + fp16:", accuracy_score(yte, y_pred))

Result:

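The coefficients take half the memory, and about 20% of them are zero.

Speed is essentially unchanged inside scikit-learn, which still computes in float64; real speedups need a runtime with float16/int8 kernels.

Accuracy usually stays similar, because the smallest weights contribute little to the decision scores (verify on your own held-out data).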

Part 2: Deep Learning Example (PyTorch)

We will:

Train a tiny MLP

Prune weights with torch.nn.utils.prune

Quantize using PyTorch dynamic quantization

Compare latency

# ===== dl_prune_quant.py =====
import torch
import torch.nn as nn
import torch.nn.utils.prune as prune
import time
import numpy as np

# Create dummy dataset
X = torch.randn(300, 16)
y = torch.randint(0, 3, (300,))

# Tiny MLP model
class MLP(nn.Module):
    def __init__(self):
        super().__init__()
        self.fc1 = nn.Linear(16, 32)
        self.fc2 = nn.Linear(32, 3)
    def forward(self, x):
        return self.fc2(torch.relu(self.fc1(x)))

model = MLP()
criterion = nn.CrossEntropyLoss()
optimizer = torch.optim.Adam(model.parameters(), lr=1e-2)

# Train briefly
for _ in range(200):
    optimizer.zero_grad()
    loss = criterion(model(X), y)
    loss.backward()
    optimizer.step()

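# -------- PRUNING --------
prune.l1_unstructured(model.fc1, name="weight", amount=0.3)  # zero 30% smallest-|w| weights of fc1
prune.remove(model.fc1, "weight")                            # make the mask permanent

# -------- QUANTIZATION --------
model.eval()
quantized_model = torch.ao.quantization.quantize_dynamic(
    model, {nn.Linear}, dtype=torch.qint8
)

# -------- LATENCY TEST --------
# For layers this small, quantization overhead can outweigh the gains; measure on your target.
@torch.no_grad()
def bench(model, n_warm=50, n_runs=500):
    model.eval()
    inp = torch.randn(1, 16)
    for _ in range(n_warm):          # warm-up runs are not timed
        model(inp)
    times = []
    for _ in range(n_runs):
        t0 = time.perf_counter()
        model(inp)
        times.append((time.perf_counter() - t0) * 1000)
    return np.mean(times), np.percentile(times, 95)

fp32_mean, fp32_p95 = bench(model)
int8_mean, int8_p95 = bench(quantized_model)

print(f"FP32 latency avg: {fp32_mean:.4f} ms  p95: {fp32_p95:.4f} ms")
print(f"INT8 latency avg: {int8_mean:.4f} ms  p95: {int8_p95:.4f} ms")

# Sanity check: how often do the FP32 and INT8 models predict the same class?
with torch.no_grad():
    agree = (model(X).argmax(1) == quantized_model(X).argmax(1)).float().mean().item()
print(f"FP32 vs INT8 prediction agreement: {agree:.1%}")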
