1. Neural Network Regularization
Core Techniques
L1/L2 Regularization
- Adds a penalty term to the loss: Loss = Original_Loss + λ * Regularization_term
- L1: Sum of absolute weights $$ L_1 = \lambda \sum_{i} |w_i| $$
- L2: Sum of squared weights $$ L_2 = \lambda \sum_{i} w_i^2 $$
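A minimal PyTorch sketch of adding both penalties to a task loss (the helper name `regularized_loss` and the λ values are illustrative, not from the source):

```python
import torch

lambda_l1, lambda_l2 = 1e-5, 1e-4   # illustrative λ values

def regularized_loss(model, task_loss):
    # Sum the penalties over all trainable parameters
    l1 = sum(p.abs().sum() for p in model.parameters())
    l2 = sum((p ** 2).sum() for p in model.parameters())
    return task_loss + lambda_l1 * l1 + lambda_l2 * l2

# L2 alone is also available as weight decay in the optimizer, e.g.
# torch.optim.SGD(model.parameters(), lr=0.1, weight_decay=lambda_l2)
```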
Dropout
- Randomly “turns off” neurons during training
- Typical dropout rates: 0.2-0.5
- Acts as an implicit ensemble over many thinned sub-networks
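A minimal sketch of dropout in a small PyTorch MLP (layer sizes are illustrative):

```python
import torch.nn as nn

model = nn.Sequential(
    nn.Linear(784, 256),
    nn.ReLU(),
    nn.Dropout(p=0.5),   # zero 50% of activations during training
    nn.Linear(256, 10),
)

model.train()   # dropout active; surviving activations are rescaled by 1/(1-p)
model.eval()    # dropout disabled at inference time
```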
Early Stopping
- Monitors validation performance
- Stops training once validation error starts increasing (typically after a patience window)
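A minimal early-stopping sketch; `model`, `train_one_epoch`, `evaluate`, `train_loader`, and `val_loader` are assumed helpers, not from the source:

```python
import torch

best_val, patience, wait = float("inf"), 5, 0

for epoch in range(100):
    train_one_epoch(model, train_loader)       # assumed training helper
    val_loss = evaluate(model, val_loader)     # assumed validation helper
    if val_loss < best_val:
        best_val, wait = val_loss, 0
        torch.save(model.state_dict(), "best.pt")   # keep the best checkpoint
    else:
        wait += 1
        if wait >= patience:   # no improvement for `patience` epochs
            break
```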
Data Augmentation
- Applies label-preserving transformations to expand the training data
- Common in image tasks: rotation, scaling, flipping
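A minimal torchvision sketch of the transforms listed above (parameter values are illustrative):

```python
from torchvision import transforms

train_transform = transforms.Compose([
    transforms.RandomHorizontalFlip(p=0.5),                 # flipping
    transforms.RandomRotation(degrees=15),                  # rotation
    transforms.RandomResizedCrop(224, scale=(0.8, 1.0)),    # scaling + crop
    transforms.ToTensor(),
])
```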
Batch Normalization $$ \hat{x}^{(k)} = \frac{x^{(k)} - \mu_B}{\sqrt{\sigma_B^2 + \epsilon}} $$
- Normalizes layer inputs
- Reduces internal covariate shift (its original motivation) and adds a mild regularizing effect via mini-batch noise
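A minimal sketch of batch normalization between a linear layer and its activation; `nn.BatchNorm1d` applies the normalization above per feature and adds a learnable scale and shift:

```python
import torch.nn as nn

block = nn.Sequential(
    nn.Linear(128, 64),
    nn.BatchNorm1d(64),   # per-feature normalization over the batch (eps = 1e-5 by default)
    nn.ReLU(),
)
```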
Benefits
Prevents Overfitting
- Limits model complexity
- Discourages fitting noise in the training data
Improves Generalization
- Better performance on new data
- Reduces test error
Enhances Robustness
- Reduces noise sensitivity
- Improves model stability
2. Connection with Lipschitz Continuity
Mathematical Foundation
Gradient Clipping and Lipschitz Constraint
- For an L-Lipschitz function f: $$ \|\nabla f(x)\| \leq L $$
```python
import torch

# Gradient clipping implementation
max_grad_norm = L   # L is the desired Lipschitz bound
torch.nn.utils.clip_grad_norm_(model.parameters(), max_grad_norm)
```
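In a training loop the clipping call sits between the backward pass and the optimizer step; a minimal sketch assuming `model`, `optimizer`, and `loss` already exist:

```python
optimizer.zero_grad()
loss.backward()
# Rescale all gradients so their global norm is at most max_grad_norm
torch.nn.utils.clip_grad_norm_(model.parameters(), max_norm=max_grad_norm)
optimizer.step()
```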
Wasserstein GAN Application $$ W(p_r, p_g) = \sup_{\|f\|_L \leq 1} \mathbb{E}_{x \sim p_r}[f(x)] - \mathbb{E}_{x \sim p_g}[f(x)] $$
- Requires a 1-Lipschitz continuous discriminator (critic)
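A crude way to enforce the constraint, used in the original WGAN paper, is to clip the critic's weights after each optimizer step; a sketch assuming a `critic` module is already defined:

```python
# Clamp every critic parameter to [-c, c]; c = 0.01 is the value used in the WGAN paper.
c = 0.01
for p in critic.parameters():
    p.data.clamp_(-c, c)

# Softer alternatives: a gradient penalty (WGAN-GP) or spectral normalization,
# e.g. torch.nn.utils.spectral_norm(torch.nn.Linear(128, 1)).
```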
Robustness Guarantee $$ |f(x + \delta) - f(x)| \leq L \, \|\delta\| $$
- L bounds how much the output can change under an input perturbation δ
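A quick numerical check of this bound for a single linear map whose weight is rescaled to norm L (all values illustrative):

```python
import torch
import torch.nn.functional as F

torch.manual_seed(0)
L = 2.0
W = torch.randn(1, 10)
W = W * (L / torch.norm(W, p=2))   # rescale the weight so its norm is exactly L

x = torch.randn(10)
delta = 0.01 * torch.randn(10)

lhs = torch.abs(F.linear(x + delta, W) - F.linear(x, W)).item()
rhs = (L * torch.norm(delta)).item()
print(f"|f(x+d) - f(x)| = {lhs:.6f} <= L*||d|| = {rhs:.6f}")
```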
Practical Implications
Optimization Convergence
- For gradient descent on a function whose gradient is L-Lipschitz (an L-smooth function), convergence requires the learning rate condition: $$ \eta < \frac{2}{L} $$
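A tiny numerical illustration on the quadratic f(x) = (L/2)·x², whose gradient L·x is L-Lipschitz; the step sizes are illustrative:

```python
L = 10.0

def run_gd(eta, steps=50, x0=1.0):
    x = x0
    for _ in range(steps):
        x = x - eta * L * x   # gradient step with grad f(x) = L * x
    return abs(x)

print(run_gd(eta=0.19))   # eta < 2/L = 0.2: iterates shrink toward 0
print(run_gd(eta=0.21))   # eta > 2/L: iterates blow up
```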
Implementation Example
- A linear layer that rescales its weight so its norm stays below a chosen Lipschitz bound L:

```python
import torch
import torch.nn as nn
import torch.nn.functional as F

class LipschitzLayer(nn.Module):
    def __init__(self, in_features, out_features, L):
        super().__init__()
        self.L = L
        self.weight = nn.Parameter(torch.empty(out_features, in_features))
        nn.init.kaiming_uniform_(self.weight)

    def forward(self, x):
        W = self.get_normalized_weights()
        return F.linear(x, W)

    def get_normalized_weights(self):
        # Rescale so the weight norm (Frobenius norm, which upper-bounds the
        # spectral norm) does not exceed L; kept in the graph so gradients flow.
        W = self.weight
        W_norm = torch.norm(W, p=2)
        if W_norm > self.L:
            W = W * (self.L / W_norm)
        return W
```
Theoretical Bounds
- For Lipschitz-bounded model classes, a typical generalization error bound scales as $$ \text{Generalization Error} \leq O\left(\frac{L}{\sqrt{n}}\right) $$ where n is the training sample size