How to Make Machine Learning Models Faster and Lighter

Pruning

= delete (or zero-out) parts of a model that contribute little → fewer FLOPs, fewer parameters, faster inference.

Quantization

= store & compute with fewer bits (e.g., 8-bit instead of 32-bit floats) → smaller memory, higher cache hits, faster CPU ops.

Both aim to fit tight edge budgets (CPU-only, small RAM) while keeping accuracy good enough for real-time control.

Where they act in a network

Pruning acts on the weights themselves (zeroing individual connections or removing whole channels/neurons), while quantization changes the numeric format used to store weights and run the arithmetic for those same layers (e.g., int8 instead of float32).

Two flavors of pruning

Unstructured pruning zeros individual weights (high sparsity, but real speedups need sparse kernels); structured pruning removes whole channels, neurons, or heads, leaving smaller dense tensors that ordinary CPU kernels run faster.

Quantization pipeline (typical PTQ)

Train in FP32 → calibrate on representative data with observers → convert weights/activations to int8.

PTQ (Post-Training Quantization): No training required. Fastest path to int8.

QAT (Quantization-Aware Training): Train with fake-quant modules → better accuracy at int8, especially for CNNs.

Why robots care

Latency (ms-level) & consistency (low jitter) matter for control loops.

Size matters (RAM/flash budgets).

Robustness to noise: simpler models + calibrated quantization + structured pruning → fewer surprises.

Metrics to watch

Accuracy (task metric)

Latency (avg & p95/p99)

Model size (MB)

Sparsity (% zeros) & MACs/FLOPs

Energy (optional but relevant on battery)
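
To make the latency metric concrete, here is a minimal timing helper (not from the original post; predict_fn and sample are placeholders for your model and a single input):

# latency_report.py (hypothetical helper, a minimal sketch)
import time
import numpy as np

def latency_report_ms(predict_fn, sample, n_warm=10, n_runs=200):
    """Time predict_fn(sample) repeatedly and report avg/p95/p99 latency in ms."""
    for _ in range(n_warm):                       # warm-up: caches, lazy allocations
        predict_fn(sample)
    times = []
    for _ in range(n_runs):
        t0 = time.perf_counter()
        predict_fn(sample)
        times.append((time.perf_counter() - t0) * 1000.0)
    times = np.asarray(times)
    return {"avg": times.mean(),
            "p95": np.percentile(times, 95),
            "p99": np.percentile(times, 99)}

# usage (any model with a predict-like callable):
# print(latency_report_ms(lambda x: clf.predict(x), X_test[:1]))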

ML EXAMPLES (scikit-learn)

We’ll show:

Decision Tree pruning (cost-complexity)

L1 “pruning” of linear/logistic models (drives coefficients to zero)

Note: Classic scikit-learn doesn’t do int8 quantization of models end-to-end; for edge you often export to ONNX + use runtimes that quantize, or you choose small models + pruning/L1.
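
Following up on that note, here is a minimal sketch of the ONNX route. It is not from the original post: it assumes the optional skl2onnx and onnxruntime packages are installed, and which ops actually get quantized depends on the exported graph.

# sketch: export a small scikit-learn model to ONNX, then int8-quantize it with ONNX Runtime
import numpy as np
from sklearn.datasets import load_iris
from sklearn.neural_network import MLPClassifier
from skl2onnx import convert_sklearn
from skl2onnx.common.data_types import FloatTensorType
from onnxruntime.quantization import quantize_dynamic, QuantType
import onnxruntime as ort

X, y = load_iris(return_X_y=True)
clf = MLPClassifier(hidden_layer_sizes=(32,), max_iter=500, random_state=42).fit(X, y)

# 1) Export the fitted model to an FP32 ONNX graph
onx = convert_sklearn(clf, initial_types=[("input", FloatTensorType([None, X.shape[1]]))])
with open("model_fp32.onnx", "wb") as f:
    f.write(onx.SerializeToString())

# 2) Weight-only dynamic quantization to int8
quantize_dynamic("model_fp32.onnx", "model_int8.onnx", weight_type=QuantType.QInt8)

# 3) Run the quantized model on CPU
sess = ort.InferenceSession("model_int8.onnx", providers=["CPUExecutionProvider"])
labels = sess.run(None, {sess.get_inputs()[0].name: X[:5].astype(np.float32)})[0]
print(labels)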

1) Decision Tree – cost-complexity pruning

# ml_pruning_tree.py
import time
import numpy as np
from sklearn.datasets import load_iris
from sklearn.model_selection import train_test_split
from sklearn.tree import DecisionTreeClassifier
from sklearn.tree import plot_tree  # optional (for visualization)
from sklearn.metrics import accuracy_score

X, y = load_iris(return_X_y=True)
Xtr, Xte, ytr, yte = train_test_split(X, y, test_size=0.2, random_state=42)

# Baseline (unpruned)
base = DecisionTreeClassifier(random_state=42)
base.fit(Xtr, ytr)

# Cost complexity pruning path gives candidate ccp_alphas
path = base.cost_complexity_pruning_path(Xtr, ytr)
ccp_alphas = path.ccp_alphas

best = None
best_stats = None

for ccp in ccp_alphas:
    clf = DecisionTreeClassifier(random_state=42, ccp_alpha=ccp)
    clf.fit(Xtr, ytr)
    # measure latency per sample (simple timing)
    t0 = time.time()
    ypred = clf.predict(Xte)
    t1 = time.time()
    acc = accuracy_score(yte, ypred)
    latency_ms = (t1 - t0) / len(Xte) * 1000.0
    n_nodes = clf.tree_.node_count
    stats = (acc, latency_ms, n_nodes, ccp)
    if best is None or (acc > best_stats[0]) or (acc == best_stats[0] and latency_ms < best_stats[1]):
        best, best_stats = clf, stats

print(f"Best pruned tree:")
print(f"  accuracy = {best_stats[0]:.3f}")
print(f"  latency  = {best_stats[1]:.3f} ms/sample")
print(f"  nodes    = {best_stats[2]} (ccp_alpha={best_stats[3]:.6f})")


What this gives: a smaller tree → fewer branches → faster, more stable inference on CPU.
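
To see how much the pruning actually shrank the tree, scikit-learn exposes depth, leaf count, and a text dump; a small optional check reusing base and best from the script above:

# optional: inspect baseline vs. pruned tree structure (reuses base/best from above)
from sklearn.tree import export_text

print("baseline: depth =", base.get_depth(), " leaves =", base.get_n_leaves())
print("pruned  : depth =", best.get_depth(), " leaves =", best.get_n_leaves())
print(export_text(best, feature_names=list(load_iris().feature_names)))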

2) L1 “pruning” (sparse coefficients)

# ml_pruning_l1.py
import time
import numpy as np
from sklearn.datasets import load_iris
from sklearn.model_selection import train_test_split
from sklearn.linear_model import LogisticRegression
from sklearn.metrics import accuracy_score

X, y = load_iris(return_X_y=True)
Xtr, Xte, ytr, yte = train_test_split(X, y, test_size=0.2, random_state=42)

# L1 penalty drives many weights to zero (like pruning)
clf = LogisticRegression(penalty='l1', solver='saga', C=0.5, max_iter=500)
clf.fit(Xtr, ytr)

t0 = time.time()
yp = clf.predict(Xte)
t1 = time.time()

acc = accuracy_score(yte, yp)
latency_ms = (t1 - t0) / len(Xte) * 1000.0
sparsity = np.mean(clf.coef_ == 0.0)

print(f"Accuracy      : {acc:.3f}")
print(f"Latency       : {latency_ms:.3f} ms/sample")
print(f"Weight zeros  : {sparsity*100:.1f}%")



Takeaway: Smaller effective feature set → cache-friendly, consistent latency.
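
If you want that sparsity to translate into a physically smaller input, you can drop the features whose weights were driven to zero and refit. A small optional follow-up with SelectFromModel, reusing clf, Xtr, Xte, ytr, yte from the script above (the 1e-6 threshold is an arbitrary choice for illustration):

# optional: physically drop features whose L1-penalized weights are (near) zero
from sklearn.feature_selection import SelectFromModel

selector = SelectFromModel(clf, prefit=True, threshold=1e-6)  # keep features with non-trivial weight
Xtr_small = selector.transform(Xtr)
Xte_small = selector.transform(Xte)

small = LogisticRegression(max_iter=500).fit(Xtr_small, ytr)  # refit on the reduced feature set
print("features kept   :", Xtr_small.shape[1], "of", Xtr.shape[1])
print("accuracy (small):", accuracy_score(yte, small.predict(Xte_small)))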

DL EXAMPLES (PyTorch)

We’ll show:

Unstructured & structured pruning via torch.nn.utils.prune

Dynamic PTQ (int8) for Linear/LSTM

Static PTQ (FX graph mode) for a tiny CNN (with calibration)

QAT sketch for best accuracy at int8

These run on CPU and illustrate what to change. Replace synthetic data with your sensor features.

Helper: latency + size + sparsity utilities
# utils_perf.py
import time
import torch
import os

def measure_latency_ms(model, inp, n_warm=10, n_runs=50):
    model.eval()
    with torch.no_grad():
        for _ in range(n_warm):
            _ = model(inp)
        t0 = time.time()
        for _ in range(n_runs):
            _ = model(inp)
        t1 = time.time()
    return (t1 - t0) / n_runs * 1000.0

def count_nonzero_params(model):
    nz = 0
    tot = 0
    for p in model.parameters():
        tot += p.numel()
        nz += (p != 0).sum().item()
    return nz, tot, 1.0 - nz / tot

def save_size_mb(model, path="temp.pth"):
    torch.save(model.state_dict(), path)
    sz = os.path.getsize(path) / (1024*1024)
    os.remove(path)
    return sz

A. PRUNING (PyTorch)
A1) Unstructured pruning (magnitude)

# dl_prune_unstructured.py
import torch
import torch.nn as nn
import torch.nn.utils.prune as prune
from utils_perf import measure_latency_ms, count_nonzero_params, save_size_mb

# Tiny MLP for demonstration
class MLP(nn.Module):
    def __init__(self, d_in=64, d_hidden=128, d_out=10):
        super().__init__()
        self.net = nn.Sequential(
            nn.Linear(d_in, d_hidden),
            nn.ReLU(),
            nn.Linear(d_hidden, d_out)
        )
    def forward(self, x):
        return self.net(x)

model = MLP()
inp = torch.randn(1, 64)

lat0 = measure_latency_ms(model, inp)
nz0, tot0, sparsity0 = count_nonzero_params(model)
size0 = save_size_mb(model)

# Apply 80% unstructured pruning on Linear weights
for m in model.modules():
    if isinstance(m, nn.Linear):
        prune.l1_unstructured(m, name="weight", amount=0.8)

# Remove reparametrization (make pruning permanent)
for m in model.modules():
    if isinstance(m, nn.Linear) and hasattr(m, "weight_orig"):
        prune.remove(m, "weight")

lat1 = measure_latency_ms(model, inp)
nz1, tot1, sparsity1 = count_nonzero_params(model)
size1 = save_size_mb(model)

print(f"Latency (ms)   before/after: {lat0:.3f} / {lat1:.3f}")
print(f"Sparsity       before/after: {sparsity0*100:.1f}% / {sparsity1*100:.1f}%")
print(f"Model size (MB)before/after: {size0:.3f} / {size1:.3f}")


Note: Unstructured zeros may not speed up on vanilla kernels; you get memory benefits and speedups only if using sparse backends. For deterministic speed on CPU, prefer structured pruning.
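
One way to actually bank the memory benefit of those zeros is to store the pruned Linear weights in a sparse layout. A minimal sketch appended to the script above (pure PyTorch, storage only; speedups would still require a sparse-aware backend):

# optional: store pruned Linear weights as CSR to bank the memory saving
for name, m in model.named_modules():
    if isinstance(m, nn.Linear):
        dense = m.weight.detach()
        csr = dense.to_sparse_csr()                      # keeps only non-zero values + indices
        dense_bytes = dense.numel() * dense.element_size()
        csr_bytes = (csr.values().numel() * csr.values().element_size()
                     + csr.col_indices().numel() * csr.col_indices().element_size()
                     + csr.crow_indices().numel() * csr.crow_indices().element_size())
        print(f"{name}: dense {dense_bytes} B -> sparse ~{csr_bytes} B")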

A2) Structured channel pruning (conv channels)

# dl_prune_structured.py
import torch
import torch.nn as nn
import torch.nn.utils.prune as prune
from utils_perf import measure_latency_ms, count_nonzero_params, save_size_mb

class TinyCNN(nn.Module):
    def __init__(self):
        super().__init__()
        self.conv1 = nn.Conv2d(3, 16, 3, padding=1)
        self.relu = nn.ReLU()
        self.conv2 = nn.Conv2d(16, 32, 3, padding=1)
        self.pool = nn.AdaptiveAvgPool2d(1)
        self.fc = nn.Linear(32, 10)
    def forward(self, x):
        x = self.relu(self.conv1(x))
        x = self.relu(self.conv2(x))
        x = self.pool(x).flatten(1)
        return self.fc(x)

model = TinyCNN()
inp = torch.randn(1, 3, 64, 64)

lat0 = measure_latency_ms(model, inp)
nz0, tot0, s0 = count_nonzero_params(model)
size0 = save_size_mb(model)

# Zero out entire output channels of conv2 (e.g., 50%)
# amount = fraction of output channels to zero; dim=0 selects output channels; n=2 ranks them by L2 norm
prune.ln_structured(model.conv2, name="weight", amount=0.5, n=2, dim=0)
prune.remove(model.conv2, "weight")

lat1 = measure_latency_ms(model, inp)
nz1, tot1, s1 = count_nonzero_params(model)
size1 = save_size_mb(model)

print(f"Latency (ms)   before/after: {lat0:.3f} / {lat1:.3f}")
print(f"Sparsity       before/after: {s0*100:.1f}% / {s1*100:.1f}%")
print(f"Model size (MB)before/after: {size0:.3f} / {size1:.3f}")


Why structured helps: ln_structured only zeroes whole channels (tensor shapes stay the same), so the real win comes once those zeroed channels are physically removed and subsequent ops shrink → actual FLOP reduction and stable CPU speedups. A minimal surgery sketch follows.
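
The sketch below, appended to the script above, keeps only conv2's surviving output channels and slices the matching columns out of fc. Real projects often use a dedicated library (e.g., torch-pruning) for this rewiring; this is just to show the mechanics.

# optional: physically remove conv2's zeroed output channels (simple model surgery)
with torch.no_grad():
    w2 = model.conv2.weight                              # [out_ch, in_ch, k, k]
    keep = w2.abs().sum(dim=(1, 2, 3)) > 0               # channels that survived pruning
    idx = keep.nonzero(as_tuple=True)[0]

    new_conv2 = nn.Conv2d(w2.shape[1], idx.numel(), 3, padding=1)
    new_conv2.weight.copy_(w2[idx])
    new_conv2.bias.copy_(model.conv2.bias[idx])

    new_fc = nn.Linear(idx.numel(), model.fc.out_features)
    new_fc.weight.copy_(model.fc.weight[:, idx])         # drop the matching input columns
    new_fc.bias.copy_(model.fc.bias)

model.conv2, model.fc = new_conv2, new_fc

lat2 = measure_latency_ms(model, inp)
size2 = save_size_mb(model)
print(f"After channel removal: latency {lat2:.3f} ms, size {size2:.3f} MB")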

B. QUANTIZATION (PyTorch)
B1) Dynamic quantization (fastest path; great for Linear/LSTM on CPU)

# dl_quant_dynamic.py
import torch
import torch.nn as nn
from utils_perf import measure_latency_ms, save_size_mb

class TinyRNN(nn.Module):
    def __init__(self, d_in=32, d_hidden=64, d_out=6):
        super().__init__()
        self.rnn = nn.LSTM(d_in, d_hidden, num_layers=1, batch_first=True)
        self.fc = nn.Linear(d_hidden, d_out)
    def forward(self, x):
        # x: [B, T, d_in]
        y, _ = self.rnn(x)
        return self.fc(y[:, -1, :])

model_fp32 = TinyRNN().eval()
inp = torch.randn(1, 20, 32)

lat0 = measure_latency_ms(model_fp32, inp)
size0 = save_size_mb(model_fp32)

# Dynamic quantize only supported layer types (Linear, LSTM)
model_int8 = torch.ao.quantization.quantize_dynamic(
    model_fp32, {nn.Linear, nn.LSTM}, dtype=torch.qint8
).eval()

lat1 = measure_latency_ms(model_int8, inp)
size1 = save_size_mb(model_int8)

print(f"Latency (ms)   FP32 / INT8(dynamic): {lat0:.3f} / {lat1:.3f}")
print(f"Model size (MB)FP32 / INT8(dynamic): {size0:.3f} / {size1:.3f}")



Use when: CPU edge device, MLP/RNN control heads, quick wins without retraining.

B2) Static PTQ (FX graph mode) for a tiny CNN

# dl_quant_static_ptq.py
import torch
import torch.nn as nn
from torch.ao.quantization import get_default_qconfig_mapping
from torch.ao.quantization.quantize_fx import prepare_fx, convert_fx
from utils_perf import measure_latency_ms, save_size_mb

class TinyCNN(nn.Module):
    def __init__(self):
        super().__init__()
        self.seq = nn.Sequential(
            nn.Conv2d(3, 16, 3, stride=1, padding=1),
            nn.ReLU(),
            nn.Conv2d(16, 16, 3, stride=1, padding=1),
            nn.ReLU(),
            nn.AdaptiveAvgPool2d(1),
            nn.Flatten(),
            nn.Linear(16, 10)
        )
    def forward(self, x):
        return self.seq(x)

model = TinyCNN().eval()
example = torch.randn(1, 3, 64, 64)

lat0 = measure_latency_ms(model, example)
size0 = save_size_mb(model)

# 1) Choose backend/qconfig
backend = "qnnpack"  # good for ARM/Android; "fbgemm" for x86
torch.backends.quantized.engine = backend
qconfig_mapping = get_default_qconfig_mapping(backend)

# 2) Prepare FX graph (example inputs are required for symbolic tracing)
prepared = prepare_fx(model, qconfig_mapping, example_inputs=(example,))

# 3) Calibrate with a small representative set
prepared(torch.randn(1,3,64,64))
prepared(torch.randn(1,3,64,64))
prepared(torch.randn(1,3,64,64))

# 4) Convert to int8
int8_model = convert_fx(prepared).eval()

lat1 = measure_latency_ms(int8_model, example)
size1 = save_size_mb(int8_model)

print(f"Latency (ms)   FP32 / INT8(static): {lat0:.3f} / {lat1:.3f}")
print(f"Model size (MB)FP32 / INT8(static): {size0:.3f} / {size1:.3f}")

Calibrate carefully: use a few hundred real sensor samples for best accuracy.
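
In code, calibration is just a forward loop over representative data before conversion. A minimal sketch of steps 3 and 4 with a larger calibration set, using the same prepared model as above (the random tensors stand in for real sensor frames):

# sketch: calibrate over a representative set, then convert
from torch.utils.data import DataLoader, TensorDataset

calib_ds = TensorDataset(torch.randn(256, 3, 64, 64), torch.zeros(256, dtype=torch.long))
calib_loader = DataLoader(calib_ds, batch_size=32)

prepared.eval()
with torch.no_grad():
    for frames, _ in calib_loader:
        prepared(frames)                  # observers record activation ranges
int8_model = convert_fx(prepared).eval()  # convert only after calibration is done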

B3) QAT sketch (best accuracy @ int8 for CNNs)

# dl_qat_skeleton.py
import torch
import torch.nn as nn
import torch.optim as optim
from torch.ao.quantization import get_default_qat_qconfig_mapping
from torch.ao.quantization.quantize_fx import prepare_qat_fx, convert_fx

class TinyCNN(nn.Module):
    def __init__(self):
        super().__init__()
        self.seq = nn.Sequential(
            nn.Conv2d(3, 16, 3, padding=1), nn.ReLU(),
            nn.Conv2d(16,16,3,padding=1),   nn.ReLU(),
            nn.AdaptiveAvgPool2d(1), nn.Flatten(), nn.Linear(16, 10)
        )
    def forward(self, x): return self.seq(x)

model = TinyCNN().train()
backend = "qnnpack"
torch.backends.quantized.engine = backend
qconfig_mapping = get_default_qat_qconfig_mapping(backend)

# Insert fake-quant observers (example inputs are required for FX tracing)
example_inputs = (torch.randn(1, 3, 64, 64),)
model_qat = prepare_qat_fx(model, qconfig_mapping, example_inputs=example_inputs).train()

opt = optim.Adam(model_qat.parameters(), lr=1e-3)

# Train as usual (fake-quant active)
for step in range(200):  # demo loop
    x = torch.randn(16,3,64,64)
    y = torch.randint(0,10,(16,))
    logits = model_qat(x)
    loss = nn.CrossEntropyLoss()(logits, y)
    opt.zero_grad(); loss.backward(); opt.step()

# Convert to real int8
model_int8 = convert_fx(model_qat.eval()).eval()
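
A quick sanity check and save step can be appended to the sketch above (the file name is only an example):

# sanity check: the converted int8 model runs on CPU and emits [batch, 10] logits
with torch.no_grad():
    probe = torch.randn(4, 3, 64, 64)
    print(model_int8(probe).shape)                          # expect torch.Size([4, 10])
torch.save(model_int8.state_dict(), "tinycnn_int8_qat.pt")  # example file name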


Putting it together: which stages benefit most?

Perception & control networks (Conv/MLP/RNN heads) → Quantization (int8) + Structured pruning (channels) for real CPU gains.

Feature encoders (heavy backbones) → QAT (keep accuracy) + careful structured pruning of later blocks.

Classical ML pieces (trees, linear heads) → prune (cost-complexity / L1) and/or replace with smaller models; quantization usually handled by deployment runtime if needed.

Practical workflow (ASCII diagram)

Train FP32 ---> Profile (lat/size/acc) ---> Choose targets
                                              |
                                              +--> Structured prune (channels/heads) -> Fine-tune
                                              +--> PTQ (dynamic/static) or QAT
                                              +--> Re-profile (p50/p95 latency, acc, size)
                                              +--> Iterate until SLA met

Which Stages Benefit Most from Quantization or Pruning?

Best Stage to Apply Quantization/Pruning

The Model / Inference Stage
This is where matrix multiplications and decision logic happen.

❌ Not Useful to Quantize/Prune

The surrounding pipeline stages (sensor I/O, preprocessing such as feature scaling, and post-processing or control glue code) contain little heavy numeric computation, so compressing them buys almost nothing.

So the parts that benefit the most are:

Neural network layers (Linear, Conv, GRU/LSTM)

Decision trees / Random Forests (by pruning depth or removing weak nodes)

Large linear models (by pruning small-magnitude weights)

Part 1: Classical ML Example

We will:

Train Logistic Regression

Prune small weights

Quantize model weights to float16 to reduce memory + improve speed

# ===== classical_prune_quant.py =====
import numpy as np
from sklearn.datasets import load_iris
from sklearn.preprocessing import StandardScaler
from sklearn.linear_model import LogisticRegression
from sklearn.pipeline import Pipeline
from sklearn.metrics import accuracy_score

# Load dataset
data = load_iris()
X, y = data.data, data.target

# Build pipeline
pipe = Pipeline([
    ("scaler", StandardScaler()),
    ("clf", LogisticRegression(max_iter=500, multi_class='multinomial'))
])

pipe.fit(X, y)
clf = pipe.named_steps["clf"]

print("Original weights shape:", clf.coef_.shape)

# ----- PRUNING: remove small weights -----
threshold = np.percentile(np.abs(clf.coef_), 20)   # prune 20% smallest weights
mask = np.abs(clf.coef_) > threshold
clf.coef_ = clf.coef_ * mask

print("Pruned weights mean magnitude:", np.mean(np.abs(clf.coef_)))

# ----- QUANTIZATION: convert to float16 -----
clf.coef_  = clf.coef_.astype(np.float16)
clf.intercept_ = clf.intercept_.astype(np.float16)

print("After quantization dtype:", clf.coef_.dtype)

# Check accuracy
y_pred = pipe.predict(X)
print("Accuracy after prune + quant:", accuracy_score(y, y_pred))

Result:

Model is lighter, smaller, and usually faster on edge CPUs.

Accuracy remains similar (because small weights typically don’t matter).
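
You can check the memory claim directly by comparing the raw byte sizes of the coefficient array (a small optional check reusing clf from the script above):

# optional: raw memory footprint of the weights before/after the float16 cast
w16 = clf.coef_                          # float16 after the cast above
w32 = w16.astype(np.float32)             # what the weights would occupy in float32
print("float32 weights:", w32.nbytes, "bytes")
print("float16 weights:", w16.nbytes, "bytes (about half)")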

Part 2: Deep Learning Example (PyTorch)

We will:

Train a tiny MLP

Prune weights with torch.nn.utils.prune

Quantize using PyTorch dynamic quantization

Compare latency

# ===== dl_prune_quant.py =====
import torch
import torch.nn as nn
import torch.nn.utils.prune as prune
import time
import numpy as np

# Create dummy dataset
X = torch.randn(300, 16)
y = torch.randint(0, 3, (300,))

# Tiny MLP model
class MLP(nn.Module):
    def __init__(self):
        super().__init__()
        self.fc1 = nn.Linear(16, 32)
        self.fc2 = nn.Linear(32, 3)
    def forward(self, x):
        return self.fc2(torch.relu(self.fc1(x)))

model = MLP()
criterion = nn.CrossEntropyLoss()
optimizer = torch.optim.Adam(model.parameters(), lr=1e-2)

# Train briefly
for _ in range(200):
    optimizer.zero_grad()
    loss = criterion(model(X), y)
    loss.backward()
    optimizer.step()

# -------- PRUNING --------
prune.l1_unstructured(model.fc1, name="weight", amount=0.3)  # prune 30% of weights
prune.remove(model.fc1, 'weight')  # finalize mask

# -------- QUANTIZATION --------
quantized_model = torch.ao.quantization.quantize_dynamic(
    model, {nn.Linear}, dtype=torch.qint8
)

# -------- LATENCY TEST --------
def bench(model):
    model.eval()
    times=[]
    for _ in range(500):
        inp = torch.randn(1,16)
        t0=time.time()
        _=model(inp)
        times.append((time.time()-t0)*1000)
    return np.mean(times), np.percentile(times,95)

fp32_mean, fp32_p95 = bench(model)
int8_mean, int8_p95 = bench(quantized_model)

print("FP32 latency avg:", fp32_mean, "ms  p95:", fp32_p95)
print("INT8 latency avg:", int8_mean, "ms  p95:", int8_p95)

Full Code — ML (scikit-learn + NumPy int8)

This example shows:

Train a Logistic Regression on the digits dataset.

Prune small coefficients by magnitude.

Quantize weights and inputs to int8 with scale/zero-point.

Evaluate accuracy and average latency.

# ===== ML PRUNING + INT8 QUANTIZATION (SCIKIT + NUMPY) =====
import numpy as np
import time
from sklearn.datasets import load_digits
from sklearn.model_selection import train_test_split
from sklearn.preprocessing import StandardScaler
from sklearn.linear_model import LogisticRegression
from sklearn.metrics import accuracy_score

# --- Data ---
digits = load_digits()
X = digits.data.astype(np.float32)  # shape (n, 64)
y = digits.target
X_train, X_test, y_train, y_test = train_test_split(X, y, test_size=0.2, random_state=42, stratify=y)

# Standardize inputs (stable normalization for noisy sensors)
scaler = StandardScaler()
X_train = scaler.fit_transform(X_train)
X_test  = scaler.transform(X_test)

# --- Baseline model ---
clf = LogisticRegression(max_iter=200, solver="saga", penalty="l2", n_jobs=-1)
clf.fit(X_train, y_train)
base_pred = clf.predict(X_test)
base_acc = accuracy_score(y_test, base_pred)

# --- PRUNING: zero out small coefficients by magnitude ---
def prune_coef(clf, sparsity=0.3):
    # sparsity = fraction of weights to zero by magnitude (per class)
    W = clf.coef_.copy()  # [n_classes, n_features]
    B = clf.intercept_.copy()

    pruned_W = W.copy()
    for c in range(W.shape[0]):
        w = W[c]
        k = int(np.floor(sparsity * w.size))
        if k <= 0: 
            continue
        thresh = np.partition(np.abs(w), k)[k]
        mask = np.abs(w) >= thresh
        pruned_W[c] = w * mask  # zero-out small weights

    return pruned_W, B

pruned_W, pruned_B = prune_coef(clf, sparsity=0.3)

# --- QUANTIZATION: per-tensor int8 for weights and inputs ---
# Helper: affine quantization (symmetric zero-point for simplicity)
def quantize_per_tensor(x, num_bits=8):
    qmin, qmax = -128, 127
    max_abs = np.max(np.abs(x)) + 1e-8
    scale = max_abs / qmax
    zp = 0
    x_q = np.clip(np.round(x / scale), qmin, qmax).astype(np.int8)
    return x_q, scale, zp

def dequantize_per_tensor(x_q, scale, zp=0):
    return (x_q.astype(np.float32) - zp) * scale

# Quantize weights (one scale per class to keep it simple)
W_q_list, W_scales = [], []
for c in range(pruned_W.shape[0]):
    W_q, s, _ = quantize_per_tensor(pruned_W[c])
    W_q_list.append(W_q)
    W_scales.append(s)
W_q = np.stack(W_q_list, axis=0).astype(np.int8)   # [C, F]
W_scales = np.array(W_scales, dtype=np.float32)    # [C]
B_q, B_scale, _ = quantize_per_tensor(pruned_B)    # bias int8; small accuracy hit but shows the idea

# Quantize inputs with a single global scale (simple & fast)
X_scale = np.max(np.abs(X_train)) / 127.0 + 1e-8

def predict_int8(Xf):
    # Quantize input
    X_q = np.clip(np.round(Xf / X_scale), -128, 127).astype(np.int8)  # [N, F]

    # Matmul in int32: logits_q = W_q @ X_q.T + B_q
    # We'll dequantize to float for soft argmax (largest logit).
    # Effective scale per class = (W_scale[c] * X_scale)
    logits = []
    X_q_T = X_q.T.astype(np.int32)  # [F, N]
    for c in range(W_q.shape[0]):
        # int8 dot → int32
        dot_int32 = (W_q[c].astype(np.int32) @ X_q_T)  # [N]
        # dequantize weights*inputs
        fl = dot_int32.astype(np.float32) * (W_scales[c] * X_scale)
        # add (dequantized) bias
        fl += dequantize_per_tensor(B_q[c], B_scale)
        logits.append(fl)
    logits = np.stack(logits, axis=1)  # [N, C]
    return np.argmax(logits, axis=1)

# --- Evaluate accuracy and latency ---
start = time.time()
y_pred_int8 = predict_int8(X_test)
lat = (time.time() - start) / len(X_test) * 1000.0  # ms/sample

int8_acc = accuracy_score(y_test, y_pred_int8)

print(f"Baseline float32 Accuracy: {base_acc*100:.2f}%")
print(f"Pruned+INT8 Accuracy     : {int8_acc*100:.2f}%")
print(f"Avg INT8 Latency/sample  : {lat:.3f} ms")
print(f"Sparsity used            : 30% of weights zeroed")

What this shows

You get a compact model (sparser + int8) and lower per-sample latency on CPU.

You can tune sparsity and choose per-channel scales for better accuracy.

On edge devices, use an int8-friendly runtime (e.g., ONNX Runtime / TFLite) for further speedups—this code demonstrates the mechanics in pure NumPy.

Full Code — DL (PyTorch pruning + quantization)

We’ll:

Train a tiny MLP on the digits dataset.

Unstructured pruning (L1) + optional structured pruning (remove neurons).

Dynamic quantization to int8 for Linear layers (CPU-friendly).

Compare accuracy & latency.

# ===== DL PRUNING + DYNAMIC QUANTIZATION (PYTORCH) =====
import time
import numpy as np
from sklearn.datasets import load_digits
from sklearn.model_selection import train_test_split
from sklearn.preprocessing import StandardScaler
from sklearn.metrics import accuracy_score

import torch
import torch.nn as nn
import torch.nn.functional as F
from torch.utils.data import TensorDataset, DataLoader
from torch.nn.utils import prune

# --- Data ---
digits = load_digits()
X = digits.data.astype(np.float32)  # (n, 64)
y = digits.target.astype(np.int64)

scaler = StandardScaler()
X = scaler.fit_transform(X)

X_train, X_test, y_train, y_test = train_test_split(
    X, y, test_size=0.2, random_state=42, stratify=y
)

train_ds = TensorDataset(torch.from_numpy(X_train), torch.from_numpy(y_train))
test_ds  = TensorDataset(torch.from_numpy(X_test),  torch.from_numpy(y_test))
train_loader = DataLoader(train_ds, batch_size=128, shuffle=True)
test_loader  = DataLoader(test_ds,  batch_size=256, shuffle=False)

device = torch.device("cpu")

# --- Model ---
class MLP(nn.Module):
    def __init__(self, in_dim=64, h=128, out_dim=10):
        super().__init__()
        self.fc1 = nn.Linear(in_dim, h)
        self.fc2 = nn.Linear(h, out_dim)

    def forward(self, x):
        x = F.relu(self.fc1(x))
        x = self.fc2(x)
        return x

model = MLP().to(device)

# --- Train (quick) ---
def train(model, loader, epochs=8, lr=3e-3):
    opt = torch.optim.Adam(model.parameters(), lr=lr)
    for ep in range(epochs):
        model.train()
        total = 0
        for xb, yb in loader:
            xb, yb = xb.to(device), yb.to(device)
            opt.zero_grad()
            logits = model(xb)
            loss = F.cross_entropy(logits, yb)
            loss.backward()
            opt.step()
            total += loss.item() * xb.size(0)
        # print(f"Epoch {ep+1}: loss={(total/len(loader.dataset)):.4f}")

def evaluate(model, loader):
    model.eval()
    preds, labels = [], []
    with torch.no_grad():
        for xb, yb in loader:
            logits = model(xb.to(device))
            preds.append(torch.argmax(logits, dim=1).cpu().numpy())
            labels.append(yb.numpy())
    y_pred = np.concatenate(preds)
    y_true = np.concatenate(labels)
    return accuracy_score(y_true, y_pred)

train(model, train_loader, epochs=8)
base_acc = evaluate(model, test_loader)

# --- PRUNING: Unstructured L1 on fully-connected layers (30%) ---
amount = 0.3
prune.l1_unstructured(model.fc1, name="weight", amount=amount)
prune.l1_unstructured(model.fc2, name="weight", amount=amount)

# Optional: remove pruning reparam to make zeros permanent
prune.remove(model.fc1, "weight")
prune.remove(model.fc2, "weight")

pruned_acc = evaluate(model, test_loader)

# --- DYNAMIC QUANTIZATION (int8) for Linear layers ---
# Works great on CPU for Linear/LSTM; no calibration needed.
quantized_model = torch.ao.quantization.quantize_dynamic(
    model, {nn.Linear}, dtype=torch.qint8
)

quant_acc = evaluate(quantized_model, test_loader)

# --- Latency (per-sample) ---
def avg_latency_ms(m, loader, repeat=50):
    m.eval()
    # warmup
    with torch.no_grad():
        for xb, _ in loader:
            m(xb.to(device)); break
    # measure
    N = 0
    t0 = time.time()
    with torch.no_grad():
        for _ in range(repeat):
            for xb, _ in loader:
                m(xb.to(device))
                N += xb.size(0)
    t1 = time.time()
    return (t1 - t0) / N * 1000.0

# Note: "model" above already holds the pruned weights. To measure a true
# pre-pruning baseline, train a separate MLP() instance and benchmark it as well.

lat_pruned = avg_latency_ms(model, test_loader, repeat=30)
lat_quant  = avg_latency_ms(quantized_model, test_loader, repeat=30)

print(f"Baseline (trained) Accuracy : {base_acc*100:.2f}%")
print(f"Pruned (30%) Accuracy       : {pruned_acc*100:.2f}%")
print(f"Quantized int8 Accuracy     : {quant_acc*100:.2f}%")
print(f"Latency pruned (ms/sample)  : {lat_pruned:.4f}")
print(f"Latency quant  (ms/sample)  : {lat_quant:.4f}")
