Effortless Bayesian Deep Learning through Laplace Redux

JuliaCon 2022

Patrick Altmeyer

Overview

  • The Case for Bayesian Deep Learning
  • Laplace Redux in Julia 📦
    • From Bayesian Logistic Regression …
    • … to Bayesian Neural Networks.
  • Goals and Ambitions 🎯

The Case for Bayesian Deep Learning

Bayesian Model Averaging

Don’t put all your 🥚 in one 🧺.

  • In Deep Learning we typically optimise highly non-convex objectives full of local optima and saddle points.
  • There may be many \(\hat\theta_1, ..., \hat\theta_m\) that are slightly different, but yield similar performance.

[…] parameters correspond to a diverse variety of compelling explanations for the data. (Wilson 2020)

\(\theta\) is a random variable. Shouldn’t we treat it that way?

\[ p(y|x,\mathcal{D}) = \int p(y|x,\theta)p(\theta|\mathcal{D})d\theta \qquad(1)\]

Intractable!

In practice we typically rely on a plugin approximation (Murphy 2022).

\[ p(y|x,\mathcal{D}) = \int p(y|x,\theta)p(\theta|\mathcal{D})d\theta \approx p(y|x,\hat\theta) \qquad(2)\]

Yes, “plugin” is literal … can we do better?
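To make the distinction concrete, here is a toy sketch in Julia contrasting the plugin estimate (Eq. 2) with a Monte Carlo estimate of Eq. 1; the scalar logistic model and the stand-in posterior samples are purely illustrative.

using Statistics

# Toy 1D logistic model (illustrative only):
predictive(x, θ) = 1 / (1 + exp(-θ * x))

x = 1.0
θ_map = 0.8                                # point estimate (MAP/MLE)
θs = θ_map .+ 0.5 .* randn(10_000)         # stand-in for posterior samples

p_plugin = predictive(x, θ_map)                      # Eq. 2
p_bma    = mean(predictive(x, θ) for θ in θs)        # MC estimate of Eq. 1

Because the sigmoid is nonlinear, p_bma is pulled towards 0.5 relative to p_plugin, reflecting the parameter uncertainty that the plugin estimate ignores.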

Enter: Bayesian Deep Learning 🔮

Yes, we can!

MCMC (see Turing)

Variational Inference (Blundell et al. 2015)

Monte Carlo Dropout (Gal and Ghahramani 2016)

Laplace Redux (Immer, Korzepa, and Bauer 2020; Daxberger et al. 2021)

. . .

Figure 1: Pierre-Simon Laplace as chancellor of the Senate under the First French Empire. Source: Wikipedia
Figure 2: Simulation of changing posterior predictive distribution. Image by author.

Laplace Approximation

We first need to estimate the weight posterior \(p(\theta|\mathcal{D})\).

Idea 💡: Taylor approximation at the mode.

  • Going through the maths we find that this yields a Gaussian posterior centered around the MAP estimate \(\hat\theta\) (see pp. 148–149 in Murphy (2022)); a short sketch follows below.
  • The covariance corresponds to the inverse Hessian at the mode (in practice we may have to rely on approximations).
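In symbols, a second-order Taylor expansion of the log-posterior around the mode \(\hat\theta\) (where the gradient vanishes) gives:

\[ \log p(\theta|\mathcal{D}) \approx \log p(\hat\theta|\mathcal{D}) - \frac{1}{2} (\theta-\hat\theta)^\mathsf{T} \mathbf{H} (\theta-\hat\theta) \quad \Rightarrow \quad p(\theta|\mathcal{D}) \approx \mathcal{N} \left( \theta \,|\, \hat\theta, \mathbf{H}^{-1} \right), \]

where \(\mathbf{H}\) is the Hessian of the negative log-posterior evaluated at \(\hat\theta\).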

Unnormalized log-posterior and corresponding Laplace Approximation. Source: Murphy (2022).

Now we can rely on Monte Carlo sampling or the probit approximation to compute the posterior predictive (classification).
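For binary classification, the probit approximation has a closed form; here is a sketch in standard notation, with \(\mu_*\) and \(\sigma_*^2\) denoting the mean and variance of the (linearized) logit under the Laplace posterior:

\[ p(y=1|x,\mathcal{D}) \approx \sigma \left( \frac{\mu_*}{\sqrt{1 + \pi \sigma_*^2 / 8}} \right) \]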

Laplace Redux in Julia

LaplaceRedux.jl - a small package 📦


What started out as my first coding project in Julia …

  • Big fan of learning by coding, so after reading the first chapters of Murphy (2022) I decided to code up Bayesian Logistic Regression from scratch.
  • I also wanted to learn Julia at the time, so I tried to kill two birds with one stone.
  • Outcome: (1) this blog post; (2) I have been hooked on Julia ever since.

… has turned into a small package 📦 with great potential.

  • When coming across the NeurIPS 2021 paper on Laplace Redux for deep learning (Daxberger et al. 2021), I figured I could step it up a notch.
  • Outcome: LaplaceRedux.jl and another blog post.

So let’s add the package …

using Pkg
Pkg.add(url="https://github.com/juliatrustworthyai/LaplaceRedux.jl")

… and use it.

using LaplaceRedux

From Bayesian Logistic Regression …

From maths …

. . .

We assume a Gaussian prior for our weights … \[ p(\theta) = \mathcal{N} \left( \theta \,|\, \mathbf{0}, \lambda^{-1} \mathbf{I} \right) = \mathcal{N} \left( \theta \,|\, \mathbf{0}, \mathbf{H}_0^{-1} \right) \qquad(3)\]

. . .

… which corresponds to the logit binary cross-entropy loss with weight decay:

\[ \ell(\theta) = - \sum_{n=1}^N [y_n \log \mu_n + (1-y_n)\log (1-\mu_n)] + \frac{1}{2} (\theta-\theta_0)^\mathsf{T}\mathbf{H}_0(\theta-\theta_0) \qquad(4)\]

. . .

For Logistic Regression the Hessian is available in closed form (p. 338 in Murphy (2022)):

\[ \nabla_{\theta}\nabla_{\theta}^\mathsf{T}\ell(\theta) = \sum_{n=1}^N \mu_n(1-\mu_n)\mathbf{x}_n\mathbf{x}_n^\mathsf{T} + \mathbf{H}_0 \qquad(5)\]

… to code

. . .

# Hessian (Eq. 5): sum of per-observation outer products plus prior precision
function ∇∇𝓁(θ, θ_0, H_0, X, y)
    N = length(y)
    μ = sigmoid(θ, X)
    H = sum(μ[n] * (1 - μ[n]) * X[n,:] * X[n,:]' for n = 1:N)
    return H + H_0
end
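To make the snippet above runnable end to end, here is a minimal sketch of the assumed `sigmoid` helper together with the step from the Hessian at the mode to the Laplace covariance; the toy data and variable names are illustrative and not part of the package.

using LinearAlgebra

# Assumed helper: element-wise σ(Xθ), i.e. predicted probabilities per observation.
sigmoid(θ, X) = 1 ./ (1 .+ exp.(-X * θ))

# Toy data (illustrative only):
N, D = 100, 2
X = randn(N, D)
y = rand(N) .< 0.5

λ = 1.0
θ_0 = zeros(D)
H_0 = λ * Matrix(1.0I, D, D)      # prior precision (Eq. 3)
θ_map = zeros(D)                  # stand-in for the MAP estimate

H = ∇∇𝓁(θ_map, θ_0, H_0, X, y)    # Hessian of Eq. 4 at the mode (Eq. 5)
Σ = inv(Symmetric(H))             # Laplace posterior covariance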

Gotta love Julia ❤️💜💚

. . .

Logistic Regression can be done in Flux

using Flux
# Initializing weights as zeros only for illustrative purposes:
nn = Chain(Dense(zeros(1,2),zeros(1))) 

. . .

… but now we get autodiff for free! This is leveraged in LaplaceRedux.

la = Laplace(nn, λ=λ)
fit!(la, data)
Figure 3: Posterior predictive distribution of Logistic regression in the 2D feature space using plugin estimator (left) and Laplace approximation (right). Image by author.
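With the approximation fitted, predictions might be obtained roughly as follows; this is a hedged sketch that assumes a `predict`-style function returning posterior predictive probabilities, so the exact name, keyword arguments, and input layout in LaplaceRedux.jl may differ.

# Hypothetical usage sketch (function name and signature assumed, not verified):
X_new = randn(2, 100)         # new points in the 2D feature space (features × observations)
p_hat = predict(la, X_new)    # Laplace posterior predictive probabilities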

… to Bayesian Neural Networks

Code

. . .

An actual MLP …

# Build MLP:
n_hidden = 32
D = size(X)[1]
nn = Chain(
    Dense(
      randn(n_hidden,D)./10,
      zeros(n_hidden), σ
    ),
    Dense(
      randn(1,n_hidden)./10,
      zeros(1)
    )
)  

. . .

… same API call:

la = Laplace(
  nn, λ=λ, 
  subset_of_weights=:last_layer
)
fit!(la, data)

. . .

Results

. . .

Figure 4: Posterior predictive distribution of MLP in the 2D feature space using plugin estimator (left) and Laplace approximation (right). Image by author.

A quick note on the prior

Low prior uncertainty \(\rightarrow\) posterior dominated by prior. High prior uncertainty \(\rightarrow\) posterior approaches MLE.
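A hedged sketch of how one might reproduce this effect with the API shown above by sweeping the prior precision λ; the values are illustrative, and `nn` and `data` are reused from the previous slides.

using LaplaceRedux

# Large λ ⇒ low prior uncertainty; small λ ⇒ high prior uncertainty.
fits = map([100.0, 1.0, 0.01]) do λ
    la = Laplace(nn, λ=λ, subset_of_weights=:last_layer)
    fit!(la, data)
    la
end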

Logistic Regression

Figure 5: Prior uncertainty increases from left to right (Logistic Regression). Image by author.

MLP

Figure 6: Prior uncertainty increases from left to right (MLP). Image by author.

A crucial detail I skipped

We have really been using linearized neural networks …

MC fails

  • We could use Monte Carlo sampling for the true BNN predictive, but this performs poorly when the Hessian is approximated.
  • Instead we rely on a linear expansion of the predictive around the mode (Immer, Korzepa, and Bauer 2020).
  • Intuition: if the Hessian approximation involves linearization, then so should the predictive.

. . .

Applying the GGN approximation […] turns the underlying probabilistic model locally from a BNN into a GLM […] Because we have effectively done inference in the GGN-linearized model, we should instead predict using these modified features. — Immer, Korzepa, and Bauer (2020)
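Concretely, the linearized (GLM) predictive replaces the network with its first-order expansion around \(\hat\theta\): with Jacobian \(\mathcal{J}(x) = \nabla_{\theta} f(x;\theta) \big|_{\hat\theta}\) and Laplace covariance \(\Sigma\), the function outputs are Gaussian,

\[ f(x;\theta) \approx f(x;\hat\theta) + \mathcal{J}(x)(\theta - \hat\theta) \quad \Rightarrow \quad f(x;\theta) \sim \mathcal{N} \left( f(x;\hat\theta), \; \mathcal{J}(x) \Sigma \mathcal{J}(x)^\mathsf{T} \right), \]

and the probit approximation from before can then be applied to these Gaussian logits.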

Figure 7: MC samples from the Laplace posterior (Lawrence 2001).

Goals and Ambitions 🎯

JuliaCon 2022 and beyond

To JuliaCon …

Learn about Laplace Redux by implementing it in Julia.

Turn code into a small package.

Submit to JuliaCon 2022 and share the idea.

… and beyond

. . .

Package is bare-bones at this point and needs a lot of work.

  • Goal: reach the same level of maturity as its Python counterpart (beautiful work, by the way!).
  • Problem: limited capacity and fairly new to Julia.
  • Solution: find contributors 🤗.

Photo by Ivan Diaz on Unsplash

Specific Goals

Easy

  • Still missing: support for multi-class classification and regression.
  • Due diligence: peer review and unit testing.

Harder

  • Hessian approximations are still quadratically large in the number of parameters: use factorizations.
  • Hyperparameter tuning: what about that prior?
  • Scaling things up: subnetwork inference.
  • Early stopping: do we really end up at the mode?

Source: Giphy

More Resources 📚

Read on …

… or even better: get involved! 🤗

References

Blundell, Charles, Julien Cornebise, Koray Kavukcuoglu, and Daan Wierstra. 2015. “Weight Uncertainty in Neural Network.” In International Conference on Machine Learning, 1613–22. PMLR.
Daxberger, Erik, Agustinus Kristiadi, Alexander Immer, Runa Eschenhagen, Matthias Bauer, and Philipp Hennig. 2021. “Laplace Redux-Effortless Bayesian Deep Learning.” Advances in Neural Information Processing Systems 34.
Gal, Yarin, and Zoubin Ghahramani. 2016. “Dropout as a Bayesian Approximation: Representing Model Uncertainty in Deep Learning.” In International Conference on Machine Learning, 1050–59. PMLR.
Immer, Alexander, Maciej Korzepa, and Matthias Bauer. 2020. “Improving Predictions of Bayesian Neural Networks via Local Linearization.” https://arxiv.org/abs/2008.08400.
Lakshminarayanan, Balaji, Alexander Pritzel, and Charles Blundell. 2017. “Simple and Scalable Predictive Uncertainty Estimation Using Deep Ensembles.” Advances in Neural Information Processing Systems 30.
Lawrence, Neil David. 2001. “Variational Inference in Probabilistic Models.” PhD thesis, University of Cambridge.
Murphy, Kevin P. 2022. Probabilistic Machine Learning: An Introduction. MIT Press.
Wilson, Andrew Gordon. 2020. “The Case for Bayesian Deep Learning.” https://arxiv.org/abs/2001.10995.