Explaining Black-Box Models through Counterfactuals

JuliaCon 2022

Patrick Altmeyer

Overview

  • The Problem with Black Boxes ⬛
    • What are black-box models? Why do we need explainability?
  • Enter: Counterfactual Explanations 🔮
    • What are they? What are they not?
  • Counterfactual Explanations in Julia (and beyond!) 📦
  • Goals and Ambitions 🎯
    • Future developments - where can it go?
    • Contributor’s guide

The Problem with Black Boxes ⬛

Short Lists, Pandas and Gibbons

From human to data-driven decision-making …

  • Black-box models like deep neural networks are being deployed virtually everywhere.
  • Includes safety-critical and public domains: health care, autonomous driving, finance, …
  • More likely than not that your loan or employment application is handled by an algorithm.

… where black boxes are a recipe for disaster.

  • We have no idea what exactly we’re cooking up …
    • Have you received an automated rejection email? Why didn’t you “mEet tHe sHoRtLisTiNg cRiTeRia”? 🙃
  • … but we do know that some of it is junk.

Figure 1: Adversarial attacks on deep neural networks. Source: Goodfellow, Shlens, and Szegedy (2014)

“Weapons of Math Destruction”

“You cannot appeal to (algorithms). They do not listen. Nor do they bend.”

— Cathy O’Neil in Weapons of Math Destruction, 2016

Figure 2: Cathy O’Neil. Source: Cathy O’Neil a.k.a. mathbabe.

  • If left unchallenged, these properties of black-box models can create undesirable dynamics in automated decision-making systems:
    • Human operators in charge of the system have to rely on it blindly.
    • Individuals subject to the decisions generally have no way to challenge their outcome.

Towards Trustworthy AI

Ground Truthing

Probabilistic Models

Counterfactual Reasoning

Current Standard in ML

We typically want to maximize the likelihood of observing \(\mathcal{D}_n\) under given parameters (Murphy 2022):

\[ \theta^* = \arg \max_{\theta} p(\mathcal{D}_n|\theta) \qquad(1)\]

Compute an MLE (or MAP) point estimate \(\hat\theta\) and use the plug-in approximation for prediction:

\[ p(y|x,\mathcal{D}_n) \approx p(y|x,\hat\theta) \qquad(2)\]

  • In an ideal world we could just use parsimonious and interpretable models like GLMs (Rudin 2019), for which in many cases we can rely on asymptotic properties of \(\theta\) to quantify uncertainty.
  • In practice these models often have performance limitations.
  • Black-box models like deep neural networks are popular, but they are also the very opposite of parsimonious.

Objective

. . .

[…] deep neural networks are typically very underspecified by the available data, and […] parameters [therefore] correspond to a diverse variety of compelling explanations for the data. (Wilson 2020)

In this setting it is often crucial to treat models probabilistically!

\[ p(y|x,\mathcal{D}_n) = \int p(y|x,\theta)p(\theta|\mathcal{D}_n)d\theta \qquad(3)\]
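As a rough illustration of Equation 3, the integral can be approximated by Monte Carlo sampling from the posterior. The sketch below assumes a logistic model with a Gaussian posterior over its weights; the values of `μ`, `Σ` and `x` are made up for illustration and are not taken from the talk.

```julia
using Random, Statistics, LinearAlgebra

sigmoid(a) = 1 / (1 + exp(-a))

# Hypothetical Gaussian posterior over the weights: θ ~ N(μ, Σ)
μ = [1.0, 1.0]
Σ = [2.0 0.0; 0.0 2.0]
x = [1.0, 1.0]

# Plug-in prediction (Equation 2): a single point estimate of θ
p_plugin = sigmoid(dot(μ, x))

# Monte Carlo approximation of the posterior predictive (Equation 3)
rng = MersenneTwister(42)
C = cholesky(Symmetric(Σ)).L          # Σ = C * C'
n_samples = 10_000
p_bayes = mean(sigmoid(dot(μ .+ C * randn(rng, 2), x)) for _ in 1:n_samples)
```

Under these assumptions the averaged prediction is pulled towards 0.5 relative to the plug-in estimate, which is how parameter uncertainty shows up in the prediction.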

. . .

Probabilistic models are covered only briefly today. More in my other talk on Laplace Redux …

We can now make predictions – great! But do we know how the predictions are actually being made?

Objective

With the model trained for its task, we are interested in understanding how its predictions change in response to input changes.

\[ \nabla_x p(y|x,\mathcal{D}_n;\hat\theta) \qquad(4)\]

  • Counterfactual reasoning (in this context) boils down to simple questions: what if \(x\) (factual) \(\Rightarrow\) \(x\prime\) (counterfactual)?
  • By strategically perturbing features and checking the model output, we can (begin to) understand how the model makes its decisions.
  • Counterfactual Explanations always have full fidelity by construction (as opposed to surrogate explanations, for example).
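To make the gradient in Equation 4 concrete, here is a minimal sketch for a plain logistic classifier, where the gradient with respect to the input has a closed form. The weights, the input and the helper names (`predict`, `grad_x`, `fd_grad`) are illustrative assumptions; the parameters stay fixed throughout and only the input varies.

```julia
using LinearAlgebra

sigmoid(a) = 1 / (1 + exp(-a))

# Hypothetical fitted parameters (kept constant)
w = [1.0, 1.0]
b = 0.0

predict(x) = sigmoid(dot(w, x) + b)

# Closed-form gradient of the predicted probability with respect to the input
grad_x(x) = predict(x) * (1 - predict(x)) .* w

# Central finite differences as a sanity check
function fd_grad(f, x; h=1e-6)
    g = zeros(length(x))
    for i in eachindex(x)
        e = zeros(length(x))
        e[i] = h
        g[i] = (f(x .+ e) - f(x .- e)) / (2 * h)
    end
    return g
end

x = [-1.0, -0.5]
g_exact = grad_x(x)
g_approx = fd_grad(predict, x)
```

The two gradients should agree up to numerical error, and both point along `w`: for this simple model, moving the input in the direction of the weights is what changes the prediction fastest.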

. . .

Important to realize that we are keeping \(\hat\theta\) constant!

Enter: Counterfactual Explanations 🔮

A Framework for Counterfactual Explanations

Even though […] interpretability is of great importance and should be pursued, explanations can, in principle, be offered without opening the “black box”. (Wachter, Mittelstadt, and Russell 2017)

Framework

. . .

The objective originally proposed by Wachter, Mittelstadt, and Russell (2017) is as follows:

\[ \min_{x\prime \in \mathcal{X}} h(x\prime) \ \ \ \mbox{s. t.} \ \ \ M(x\prime) = t \qquad(5)\]

where \(h\) relates to the complexity of the counterfactual and \(M\) denotes the classifier.

. . .

Typically this is approximated through regularization:

\[ x\prime = \arg \min_{x\prime} \ell(M(x\prime),t) + \lambda h(x\prime) \qquad(6)\]
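A minimal sketch of how Equation 6 can be minimized by gradient descent, assuming a logistic classifier, the logit binary cross-entropy as \(\ell\) and a squared Euclidean distance to the factual as \(h\). These choices, the function name `wachter_search` and all parameter values are assumptions made for illustration; this is not the package implementation.

```julia
using LinearAlgebra

sigmoid(a) = 1 / (1 + exp(-a))

# Gradient descent on ℓ(M(x′), t) + λ‖x′ - x‖², stopping once the target
# class is predicted with probability of at least γ.
function wachter_search(x, t; w, b, λ=0.01, η=0.1, γ=0.75, T=1000)
    x_cf = copy(x)
    for _ in 1:T
        p = sigmoid(dot(w, x_cf) + b)
        reached = t == 1.0 ? p >= γ : p <= 1 - γ
        reached && break
        grad_loss = (p - t) .* w                 # ∂ℓ/∂x′ for the logistic model
        grad_penalty = 2 * λ .* (x_cf .- x)      # ∂(λ‖x′ - x‖²)/∂x′
        x_cf .-= η .* (grad_loss .+ grad_penalty)
    end
    return x_cf
end

# Hypothetical classifier and factual (predicted class 0)
w = [1.0, 1.0]
b = 0.0
x = [-1.0, -0.5]
x_cf = wachter_search(x, 1.0; w=w, b=b)
```

A larger `λ` keeps the counterfactual closer to the factual, but if it is too large the search may stall before the target probability `γ` is reached.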

Intuition

. . .

Figure 3: A cat performing gradient descent in the feature space à la Wachter, Mittelstadt, and Russell (2017).

Counterfactuals … as in Adversarial Examples?

Yes and no!

While both are methodologically very similar, adversarial examples are meant to go undetected while CEs ought to be meaningful.

Effective counterfactuals should meet certain criteria ✅

  • closeness: the average distance between factual and counterfactual features should be small (Wachter, Mittelstadt, and Russell 2017)
  • actionability: the proposed feature perturbation should actually be actionable (Ustun, Spangher, and Liu 2019; Poyiadzi et al. 2020)
  • plausibility: the counterfactual explanation should be plausible to a human (Joshi et al. 2019)
  • unambiguity: a human should have no trouble assigning a label to the counterfactual (Schut et al. 2021)
  • sparsity: the counterfactual explanation should involve as few individual feature changes as possible (Schut et al. 2021)
  • robustness: the counterfactual explanation should be robust to domain and model shifts (Upadhyay, Joshi, and Lakkaraju 2021)
  • diversity: ideally multiple diverse counterfactual explanations should be provided (Mothilal, Sharma, and Tan 2020)
  • causality: counterfactual explanations should reflect the structural causal model underlying the data generating process (Karimi et al. 2020; Karimi, Schölkopf, and Valera 2021)
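Two of these criteria are easy to quantify. The sketch below shows one possible way to measure closeness and sparsity for a single counterfactual; the function names and the tolerance are hypothetical and other distance measures would be equally valid.

```julia
using Statistics

# Closeness: average absolute change per feature
closeness(x, x_cf) = mean(abs.(x_cf .- x))

# Sparsity: number of features that were changed at all
sparsity(x, x_cf; tol=1e-8) = count(abs.(x_cf .- x) .> tol)

x = [1.0, 2.0, 3.0]        # factual
x_cf = [1.0, 2.5, 3.0]     # counterfactual that changes one feature

c = closeness(x, x_cf)
s = sparsity(x, x_cf)
```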

Counterfactuals … as in Causal Inference?

NO!

Causal inference: counterfactuals are thought of as unobserved states of the world that we would like to observe in order to establish causality.

  • The only way to do this is by actually interfering with the state of the world: \(p(y|do(x),\theta)\).
  • In practice we can only move some individuals to the counterfactual state of the world and compare their outcomes to a control group.
  • Provided we have controlled for confounders, properly randomized, … we can estimate an average treatment effect: \(\hat\theta\).

Counterfactual Explanations: involve perturbing features after some model has been trained.

  • We end up comparing modeled outcomes \(p(y|x,\hat\phi)\) and \(p(y|x\prime,\hat\phi)\) for individuals.
  • We have not magically solved causality.

Probabilistic Methods for Counterfactual Explanations

When people say that counterfactuals should look realistic or plausible, they really mean that counterfactuals should be generated by the same Data Generating Process (DGP) as the factuals:

\[ x\prime \sim p(x) \]

But how do we estimate \(p(x)\)? Two probabilistic approaches …

Schut et al. (2021) note that by maximizing predictive probabilities \(\sigma(M(x\prime))\) for probabilistic models \(M\in\mathcal{\widetilde{M}}\) one implicitly minimizes epistemic and aleatoric uncertainty.

\[ x\prime = \arg \min_{x\prime} \ell(M(x\prime),t) \ \ \ , \ \ \ M\in\mathcal{\widetilde{M}} \qquad(7)\]
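The greedy search can be sketched roughly as follows: at each step only the feature with the largest gradient magnitude is changed, by a fixed amount `δ`, and no feature is changed more than `n` times. The sketch uses a plain (non-Bayesian) logistic classifier to stay self-contained, so it illustrates only the search strategy, not the uncertainty argument; `greedy_search` and all values are assumptions.

```julia
using LinearAlgebra

sigmoid(a) = 1 / (1 + exp(-a))

function greedy_search(x, t; w, b, δ=0.5, n=10, γ=0.75)
    x_cf = copy(x)
    times_changed = zeros(Int, length(x))
    while any(times_changed .< n)
        p = sigmoid(dot(w, x_cf) + b)
        reached = t == 1.0 ? p >= γ : p <= 1 - γ
        reached && break
        g = (p - t) .* w                     # gradient of the loss w.r.t. x′
        g[times_changed .>= n] .= 0.0        # exhausted features are frozen
        i = argmax(abs.(g))                  # most influential feature
        g[i] == 0 && break
        x_cf[i] -= δ * sign(g[i])            # fixed-size step on that feature only
        times_changed[i] += 1
    end
    return x_cf
end

# Hypothetical classifier and factual
w = [2.0, 0.5]
b = 0.0
x = [-1.0, -1.0]
x_cf = greedy_search(x, 1.0; w=w, b=b)
```

In this toy setup only the first feature is perturbed, which is the kind of sparse counterfactual the greedy strategy tends to produce.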

Figure 4: A cat performing gradient descent in the feature space à la Schut et al. (2021)

Instead of perturbing samples directly, some have proposed to instead traverse a lower-dimensional latent embedding learned through a generative model (Joshi et al. 2019).

\[ z\prime = \arg \min_{z\prime} \ell(M(dec(z\prime)),t) + \lambda h(x\prime) \qquad(8)\]

and

\[x\prime = dec(z\prime)\]

where \(dec(\cdot)\) is the decoder function.
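A minimal sketch of the latent space search in Equation 8, assuming a linear decoder so that the chain rule can be written out by hand. The decoder matrix `D`, the offset `c`, the classifier and the function `latent_search` are hypothetical; in practice the decoder would come from a trained generative model.

```julia
using LinearAlgebra

sigmoid(a) = 1 / (1 + exp(-a))

# Hypothetical linear decoder from a 1-D latent space to 2-D features
D = reshape([1.0, 0.5], 2, 1)
c = [0.0, 0.0]
dec(z) = D * z .+ c

# Gradient descent in the latent space on ℓ(M(dec(z′)), t) + λ‖dec(z′) - x‖²
function latent_search(z0, t; w, b, λ=0.01, η=0.1, γ=0.75, T=1000)
    x = dec(z0)                 # factual in feature space
    z = copy(z0)
    for _ in 1:T
        x_cf = dec(z)
        p = sigmoid(dot(w, x_cf) + b)
        reached = t == 1.0 ? p >= γ : p <= 1 - γ
        reached && break
        grad_x = (p - t) .* w .+ 2 * λ .* (x_cf .- x)
        z .-= η .* (D' * grad_x)    # chain rule through the decoder
    end
    return z
end

w = [1.0, 1.0]
b = 0.0
z0 = [-1.0]
z_cf = latent_search(z0, 1.0; w=w, b=b)
x_cf = dec(z_cf)
```

Because every candidate is decoded from the latent space, the counterfactual path in feature space stays on the set of points the decoder can produce.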

Figure 5: Counterfactual (yellow) generated through latent space search (right panel) following Joshi et al. (2019). The corresponding counterfactual path in the feature space is shown in the left panel.

Counterfactual Explanations in Julia (and beyond!)

Limited Software Availability

Work currently scattered across different GitHub repositories …

  • Only one unifying Python library: CARLA (Pawelczyk et al. 2021).
    • Comprehensive and (somewhat) extensible.
    • But not language-agnostic, and some desirable functionality is not supported.
    • Also not composable: each generator is treated as a different class/entity.
  • Both R and Julia lack any kind of implementation.

Enter: CounterfactualExplanations.jl 📦

… until now!

  • A unifying framework for generating Counterfactual Explanations.
  • Built in Julia, but essentially language agnostic:
    • Currently supporting explanations for differentiable models built in Julia (e.g. Flux) and torch (R and Python).
  • Designed to be easily extensible through dispatch.
  • Designed to be composable allowing users and developers to combine different counterfactual generators.

Photo by Denise Jans on Unsplash.

Julia has an edge with respect to Trustworthy AI: it’s open-source, uniquely transparent and interoperable 🔴🟢🟣

Package Architecture

Modular, composable, scalable!

Overview

Figure 6: Overview of package architecture. Modules are shown in red, structs in green and functions in blue.

Generators

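using CounterfactualExplanations, Plots, GraphRecipes
# `www_path` is the output directory for figures; it is assumed to be defined earlier in the session.
plt = plot(AbstractGenerator, method=:tree, fontsize=10, nodeshape=:rect, size=(1000,700))
savefig(plt, joinpath(www_path,"generators.png"))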

Figure 7: Type tree for AbstractGenerator.

Models

plt = plot(AbstractFittedModel, method=:tree, fontsize=10, nodeshape=:rect, size=(1000,700))
savefig(plt, joinpath(www_path,"models.png"))

Figure 8: Type tree for AbstractFittedModel.

Basic Usage

A simple example

  1. Load and prepare some toy data.
  2. Select a random sample.
  3. Generate counterfactuals using different approaches.

# Data:
using Random
using CounterfactualExplanations
Random.seed!(123)
N = 100
xs, ys = toy_data_linear(N)
X = hcat(xs...) # one column per observation
counterfactual_data = CounterfactualData(X,ys')

# Randomly selected factual:
x = select_factual(counterfactual_data,rand(1:size(X,2)))

Figure 9: Synthetic data.

Generic Generator

Code

. . .

We begin by instantiating the fitted model …

# Model
w = [1.0 1.0] # estimated coefficients
b = 0 # estimated bias
M = LogisticModel(w, [b])

. . .

… then based on its prediction for \(x\) we choose the opposite label as our target …

# Select target class:
y = round(probs(M, x)[1])
target = ifelse(y==1.0,0.0,1.0) # opposite label as target

. . .

… and finally generate the counterfactual.

# Counterfactual search:
generator = GenericGenerator()
counterfactual = generate_counterfactual(
  x, target, counterfactual_data, M, generator
)

Output

. . .

… et voilà!

Figure 10: Counterfactual path (left) and predicted probability (right) for GenericGenerator. The contour (left) shows the predicted probabilities of the classifier (Logistic Regression).

Greedy Generator

Code

. . .

This time we use a Bayesian classifier …

using LinearAlgebra
Σ = Symmetric(reshape(randn(9),3,3).*0.01 + UniformScaling(1)) # MAP covariance matrix
μ = hcat(b, w)
M = BayesianLogisticModel(μ, Σ)
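For intuition, predictions from a Bayesian logistic model with a Gaussian posterior are often computed with the probit approximation. Whether `BayesianLogisticModel` does exactly this internally is an assumption here; the helper `predict_probit` is hypothetical and follows the same layout as above, with the bias as the first element of the mean vector.

```julia
using LinearAlgebra

sigmoid(a) = 1 / (1 + exp(-a))

# Probit approximation to ∫ σ(θ'x) N(θ | μ, Σ) dθ
function predict_probit(μ, Σ, x)
    x_aug = vcat(1.0, x)                 # prepend constant for the bias
    m = dot(μ, x_aug)                    # mean of the logit
    v = dot(x_aug, Σ * x_aug)            # variance of the logit
    κ = 1 / sqrt(1 + π * v / 8)          # scaling factor
    return sigmoid(κ * m)
end

μ = [0.0, 1.0, 1.0]                      # bias first, then coefficients
Σ = Matrix(1.0I, 3, 3)                   # hypothetical posterior covariance
x = [1.0, 1.0]

p_bayes = predict_probit(μ, Σ, x)
p_plugin = predict_probit(μ, zeros(3, 3), x)   # zero uncertainty recovers the plug-in prediction
```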

. . .

… and once again choose our target label as before …

# Select target class:
y = round(probs(M, x)[1])
target = ifelse(y==1.0,0.0,1.0) # opposite label as target

. . .

… to then finally use greedy search to find a counterfactual.

# Counterfactual search:
params = GreedyGeneratorParams(
  δ = 0.5,
  n = 10
)
generator = GreedyGenerator(;params=params)
counterfactual = generate_counterfactual(
  x, target, counterfactual_data, M, generator
)

Output

. . .

In this case the Bayesian approach yields a similar outcome.

Figure 11: Counterfactual path (left) and predicted probability (right) for GreedyGenerator. The contour (left) shows the predicted probabilities of the classifier (Bayesian Logistic Regression).

REVISE Generator

Code

Using the same classifier as before we can either use the specific REVISEGenerator

# Counterfactual search:
generator = REVISEGenerator()
counterfactual = generate_counterfactual(
  x, target, counterfactual_data, M, generator
)

. . .

… or realize that REVISE (Joshi et al. 2019) just boils down to generic search in a latent space:

# Counterfactual search:
generator = GenericGenerator()
counterfactual = generate_counterfactual(
  x, target, counterfactual_data, M, generator,
  latent_space=true
)

Output

. . .

We have essentially combined latent search with a probabilistic classifier (as in Antorán et al. 2020).

Figure 12: Counterfactual path (left) and predicted probability (right) for REVISEGenerator.

Customization

Custom Models - Deep Ensemble

Loading the pre-trained deep ensemble …

ensemble = mnist_ensemble() # deep ensemble

Step 1: add composite type as subtype of AbstractFittedModel.

struct FittedEnsemble <: Models.AbstractFittedModel
    ensemble::AbstractArray
end

Step 2: dispatch logits and probs methods for new model type.

using Flux # provides Flux.stack and softmax
using Statistics
import CounterfactualExplanations.Models: logits, probs
# Stack the outputs of all ensemble members along a third dimension and average over it:
logits(M::FittedEnsemble, X::AbstractArray) = mean(Flux.stack([nn(X) for nn in M.ensemble],3), dims=3)
probs(M::FittedEnsemble, X::AbstractArray) = mean(Flux.stack([softmax(nn(X)) for nn in M.ensemble],3),dims=3)
M = FittedEnsemble(ensemble)
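The averaging behind these two methods can be sketched without any dependencies. The two "networks" below are hypothetical linear maps standing in for the pre-trained ensemble members, and `softmax_vec`, `ensemble_logits` and `ensemble_probs` are illustrative names, not part of the package.

```julia
# Numerically stable softmax for a vector of logits
function softmax_vec(a)
    e = exp.(a .- maximum(a))
    return e ./ sum(e)
end

# Two hypothetical ensemble members mapping 2 features to 3 class logits
W1 = [1.0 0.0; 0.0 1.0; 0.5 0.5]
W2 = [0.8 0.1; 0.2 0.9; 0.4 0.6]
ensemble = [x -> W1 * x, x -> W2 * x]

# Ensemble prediction: average member outputs
ensemble_logits(ensemble, x) = sum(nn(x) for nn in ensemble) ./ length(ensemble)
ensemble_probs(ensemble, x) = sum(softmax_vec(nn(x)) for nn in ensemble) ./ length(ensemble)

x = [1.0, -1.0]
p = ensemble_probs(ensemble, x)
```

Note that the probabilities are averaged after applying the softmax to each member, which generally differs from applying the softmax to the averaged logits.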

Results for a simple deep ensemble also look convincing!

Figure 16: Turning a nine (9) into a four (4) using generic (Wachter) and greedy search for MLP and deep ensemble.

Custom Models - Interoperability

Adding support for torch models was easy! Here’s how I implemented it for torch classifiers trained in R.

Source code

. . .

Step 1: add composite type as subtype of AbstractFittedModel

Implemented here.

Step 2: dispatch logits and probs methods for new model type.

Implemented here.

. . .

Step 3: add gradient access.

Implemented here.

Unchanged API

. . .

M = RTorchModel(model)
# Select target class:
y = round(probs(M, x)[1])
target = ifelse(y==1.0,0.0,1.0) # opposite label as target
# Define generator:
generator = GenericGenerator()
# Generate recourse:
counterfactual = generate_counterfactual(
  x, target, counterfactual_data, M, generator
)

Figure 17: Counterfactual path (left) and predicted probability (right) for GenericGenerator and RTorchModel.

Custom Generators

Idea 💡: let’s implement a generic generator with dropout!

Dispatch

. . .

Step 1: create a subtype of AbstractGradientBasedGenerator (adhering to some basic rules).

# Type definitions:
abstract type AbstractDropoutGenerator <: AbstractGradientBasedGenerator end
struct DropoutGenerator <: AbstractDropoutGenerator
    loss::Symbol # loss function
    complexity::Function # complexity function
    mutability::Union{Nothing,Vector{Symbol}} # mutability constraints
    λ::AbstractFloat # strength of penalty
    ϵ::AbstractFloat # step size
    τ::AbstractFloat # tolerance for convergence
    p_dropout::AbstractFloat # dropout rate
end

. . .

Step 2: implement logic for generating perturbations.

import CounterfactualExplanations.Generators: generate_perturbations, ∇
using StatsBase
function generate_perturbations(generator::AbstractDropoutGenerator, counterfactual_state::State)
    𝐠ₜ = ∇(generator, counterfactual_state.M, counterfactual_state) # gradient
    # Dropout: set a random subset of gradient entries to zero
    set_to_zero = sample(1:length(𝐠ₜ),round(Int, generator.p_dropout*length(𝐠ₜ)),replace=false)
    𝐠ₜ[set_to_zero] .= 0
    Δx′ = - (generator.ϵ .* 𝐠ₜ) # gradient step
    return Δx′
end
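The dropout logic can be tried in isolation on a plain gradient vector. This standalone sketch uses `randperm` from the standard library instead of `StatsBase.sample`; the function `dropout_step` and the input values are hypothetical.

```julia
using Random

# Zero out a random fraction of gradient entries, then take a gradient step
function dropout_step(g, ϵ, p_dropout; rng=Random.default_rng())
    g = copy(g)                                    # leave the input untouched
    k = round(Int, p_dropout * length(g))          # number of entries to drop
    set_to_zero = randperm(rng, length(g))[1:k]    # k distinct random indices
    g[set_to_zero] .= 0
    return -ϵ .* g
end

g = ones(10)
Δx = dropout_step(g, 0.1, 0.5; rng=MersenneTwister(1))
```

With a dropout rate of 0.5, half of the features receive no update in a given step, so different features are perturbed at different iterations.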

Unchanged API

. . .

# Instantiate:
using LinearAlgebra
generator = DropoutGenerator(
    :logitbinarycrossentropy,
    norm,
    nothing,
    0.1,
    0.1,
    1e-5,
    0.5
)
counterfactual = generate_counterfactual(
  x, target, counterfactual_data, M, generator
)

Figure 18: Counterfactual path (left) and predicted probability (right) for custom DropoutGenerator and RTorchModel.

Goals and Ambitions 🎯

JuliaCon 2022 and beyond

To JuliaCon …

Develop package, register and submit to JuliaCon 2022.

Native support for deep learning models (Flux, torch).

Add latent space search.

… and beyond

. . .

  • Add support for more models:
    • MLJ, GLM, …
    • Non-differentiable

. . .

  • Enhance preprocessing functionality.

. . .

  • Extend functionality to regression problems.

. . .

  • Use Flux optimizers.

. . .

Photo by Ivan Diaz on Unsplash

More Resources 📚

Read on …

… or get involved! 🤗

References

Antorán, Javier, Umang Bhatt, Tameem Adel, Adrian Weller, and José Miguel Hernández-Lobato. 2020. “Getting a Clue: A Method for Explaining Uncertainty Estimates.” https://arxiv.org/abs/2006.06848.
Goodfellow, Ian J, Jonathon Shlens, and Christian Szegedy. 2014. “Explaining and Harnessing Adversarial Examples.” https://arxiv.org/abs/1412.6572.
Joshi, Shalmali, Oluwasanmi Koyejo, Warut Vijitbenjaronk, Been Kim, and Joydeep Ghosh. 2019. “Towards Realistic Individual Recourse and Actionable Explanations in Black-Box Decision Making Systems.” https://arxiv.org/abs/1907.09615.
Karimi, Amir-Hossein, Bernhard Schölkopf, and Isabel Valera. 2021. “Algorithmic Recourse: From Counterfactual Explanations to Interventions.” In Proceedings of the 2021 ACM Conference on Fairness, Accountability, and Transparency, 353–62.
Karimi, Amir-Hossein, Julius Von Kügelgen, Bernhard Schölkopf, and Isabel Valera. 2020. “Algorithmic Recourse Under Imperfect Causal Knowledge: A Probabilistic Approach.” https://arxiv.org/abs/2006.06831.
Kaur, Harmanpreet, Harsha Nori, Samuel Jenkins, Rich Caruana, Hanna Wallach, and Jennifer Wortman Vaughan. 2020. “Interpreting Interpretability: Understanding Data Scientists’ Use of Interpretability Tools for Machine Learning.” In Proceedings of the 2020 CHI Conference on Human Factors in Computing Systems, 1–14. https://doi.org/10.1145/3313831.3376219.
Mothilal, Ramaravind K, Amit Sharma, and Chenhao Tan. 2020. “Explaining Machine Learning Classifiers Through Diverse Counterfactual Explanations.” In Proceedings of the 2020 Conference on Fairness, Accountability, and Transparency, 607–17. https://doi.org/10.1145/3351095.3372850.
Murphy, Kevin P. 2022. Probabilistic Machine Learning: An Introduction. MIT Press.
Pawelczyk, Martin, Sascha Bielawski, Johannes van den Heuvel, Tobias Richter, and Gjergji Kasneci. 2021. “CARLA: A Python Library to Benchmark Algorithmic Recourse and Counterfactual Explanation Algorithms.” https://arxiv.org/abs/2108.00783.
Poyiadzi, Rafael, Kacper Sokol, Raul Santos-Rodriguez, Tijl De Bie, and Peter Flach. 2020. “FACE: Feasible and Actionable Counterfactual Explanations.” In Proceedings of the AAAI/ACM Conference on AI, Ethics, and Society, 344–50.
Rudin, Cynthia. 2019. “Stop Explaining Black Box Machine Learning Models for High Stakes Decisions and Use Interpretable Models Instead.” Nature Machine Intelligence 1 (5): 206–15. https://doi.org/10.1038/s42256-019-0048-x.
Schut, Lisa, Oscar Key, Rory Mc Grath, Luca Costabello, Bogdan Sacaleanu, Yarin Gal, et al. 2021. “Generating Interpretable Counterfactual Explanations By Implicit Minimisation of Epistemic and Aleatoric Uncertainties.” In International Conference on Artificial Intelligence and Statistics, 1756–64. PMLR.
Slack, Dylan, Sophie Hilgard, Emily Jia, Sameer Singh, and Himabindu Lakkaraju. 2020. “Fooling Lime and Shap: Adversarial Attacks on Post Hoc Explanation Methods.” In Proceedings of the AAAI/ACM Conference on AI, Ethics, and Society, 180–86.
Upadhyay, Sohini, Shalmali Joshi, and Himabindu Lakkaraju. 2021. “Towards Robust and Reliable Algorithmic Recourse.” Advances in Neural Information Processing Systems 34: 16926–37.
Ustun, Berk, Alexander Spangher, and Yang Liu. 2019. “Actionable Recourse in Linear Classification.” In Proceedings of the Conference on Fairness, Accountability, and Transparency, 10–19. https://doi.org/10.1145/3287560.3287566.
Wachter, Sandra, Brent Mittelstadt, and Chris Russell. 2017. “Counterfactual Explanations Without Opening the Black Box: Automated Decisions and the GDPR.” Harv. JL & Tech. 31: 841. https://doi.org/10.2139/ssrn.3063289.
Wilson, Andrew Gordon. 2020. “The Case for Bayesian Deep Learning.” https://arxiv.org/abs/2001.10995.