# Model types to train and explain in the real-world benchmark.
models = [
    :LogisticRegression,
    :FluxModel,
    :FluxEnsemble,
]

# Optimiser shared by the gradient-based generators below (gradient descent, learning rate 0.01).
opt = Flux.Descent(0.01)

# Counterfactual generators to compare, keyed by name. `GreedyGenerator` is built
# without an optimiser; the other three all share `opt` from above.
generators = Dict(
    :Greedy => GreedyGenerator(),
    :Generic => GenericGenerator(opt = opt),
    :REVISE => REVISEGenerator(opt = opt),
    :DICE => DiCEGenerator(opt = opt),
)
# 4 Real-World Data
# Observation cap passed to `load_real_world` (presumably per data set — confirm there).
max_obs = 5000
data_sets = load_real_world(max_obs)

# Only benchmark these data sets: keep the entries of `data_sets` whose key is listed.
choices = [
    :cal_housing,
    :credit_default,
    :gmsc,
]
data_sets = filter(p -> first(p) in choices, data_sets)
using CounterfactualExplanations.DataPreprocessing: unpack
# Default mini-batch size for `data_loader` (read each time `data_loader` is called).
bs = 500

"""
    data_loader(data::CounterfactualData; batchsize=bs)

Unpack the features `X` and targets `y` of `data` and wrap them in a
`Flux.DataLoader` that yields mini-batches of size `batchsize`
(defaults to the global `bs`).
"""
function data_loader(data::CounterfactualData; batchsize::Integer = bs)
    X, y = unpack(data)
    # Return the loader directly instead of rebinding the argument name `data`.
    return Flux.DataLoader((X, y); batchsize = batchsize)
end
# Hyperparameters forwarded to `set_up_experiments` for building the models.
model_params = (
    batch_norm = false,
    n_hidden = 64,
    n_layers = 3,
    dropout = true,
    p_dropout = 0.1,
)

# Set up the experiments from the selected data sets, models and generators.
# NOTE(review): `pre_train_models = 100` is presumably a number of pre-training
# epochs — confirm in `set_up_experiments`.
experiments = set_up_experiments(
    data_sets, models, generators;
    pre_train_models = 100,
    model_params = model_params,
    data_loader = data_loader,
)
## 4.1 Experiment
# Experiment settings.
n_evals = 5
n_rounds = 50
# Evaluate every `n_rounds / n_evals` rounds (10 with the values above).
evaluate_every = round(Int, n_rounds / n_evals)
n_folds = 5
n_samples = 10000
T = 100
generative_model_params = (epochs = 250, latent_dim = 8)

results = run_experiments(
    experiments;
    save_path = output_path,
    evaluate_every = evaluate_every,
    n_rounds = n_rounds,
    n_folds = n_folds,
    T = T,
    n_samples = n_samples,
    generative_model_params = generative_model_params,
)

# Cache the results on disk so later cells can reload them without rerunning.
Serialization.serialize(joinpath(output_path, "results.jls"), results)
### 4.1.1 Plots
# Reload the cached experiment results (only deserialize trusted, locally produced files).
results = Serialization.deserialize(joinpath(output_path, "results.jls"))
using Images
# For each data set, render a line chart and an error-bar chart, save both as PNGs
# under `www_path`, and keep the plot objects keyed by data-set name.
line_charts = Dict()
errorbar_charts = Dict()
for (data_name, res) in results
    # Metrics over the course of the experiment.
    line_plt = plot(res)
    Images.save(joinpath(www_path, "line_chart_$(data_name).png"), line_plt)
    line_charts[data_name] = line_plt

    # Metrics at the largest value of `res.output.n`.
    errorbar_plt = plot(res, maximum(res.output.n))
    Images.save(joinpath(www_path, "errorbar_chart_$(data_name).png"), errorbar_plt)
    errorbar_charts[data_name] = errorbar_plt
end
### 4.1.2 Line Charts
Figure 4.1 shows how the evaluation metrics evolve over the course of the experiment.
# Display every line-chart image found in `www_artifact_path`.
# NOTE(review): the plotting cell saves to `www_path`, but this reads from
# `www_artifact_path` — confirm both refer to the same artifacts.
artifact_files = readdir(www_artifact_path)  # list the directory once
img_files = joinpath.(www_artifact_path, filter(f -> contains(f, "line_chart"), artifact_files))
for img in img_files
    display(load(img))
end
### 4.1.3 Error Bar Charts
Figure 4.2 shows the evaluation metrics at the end of the experiment.
# Display every error-bar-chart image found in `www_artifact_path`.
# NOTE(review): the plotting cell saves to `www_path`, but this reads from
# `www_artifact_path` — confirm both refer to the same artifacts.
artifact_files = readdir(www_artifact_path)  # list the directory once
img_files = joinpath.(www_artifact_path, filter(f -> contains(f, "errorbar_chart"), artifact_files))
for img in img_files
    display(load(img))
end
## 4.2 Bootstrap
# Bootstrap the experiment results (`n_bootstrap` presumably = number of resamples);
# `filename` points `run_bootstrap` at `bootstrap.csv` in `output_path`.
n_bootstrap = 100
df = run_bootstrap(results, n_bootstrap; filename = joinpath(output_path, "bootstrap.csv"))
┌──────────┬─────────┬────────────────┬────────────────────┬───────────┬──────────────┐
│ name │ scope │ data │ model │ generator │ p_value_mean │
│ String31 │ String7 │ String15 │ String31 │ String7 │ Float64 │
├──────────┼─────────┼────────────────┼────────────────────┼───────────┼──────────────┤
│ mmd │ domain │ credit_default │ LogisticRegression │ Generic │ 1.0 │
│ mmd │ domain │ credit_default │ LogisticRegression │ REVISE │ 1.0 │
│ mmd │ domain │ credit_default │ LogisticRegression │ Greedy │ 1.0 │
│ mmd │ domain │ credit_default │ LogisticRegression │ DICE │ 1.0 │
│ mmd │ domain │ credit_default │ FluxModel │ Generic │ 1.0 │
│ mmd │ domain │ credit_default │ FluxModel │ REVISE │ 0.0 │
│ mmd │ domain │ credit_default │ FluxModel │ Greedy │ 1.0 │
│ mmd │ domain │ credit_default │ FluxModel │ DICE │ 1.0 │
│ mmd │ domain │ credit_default │ FluxEnsemble │ Generic │ 1.0 │
│ mmd │ domain │ credit_default │ FluxEnsemble │ REVISE │ 0.0 │
│ mmd │ domain │ credit_default │ FluxEnsemble │ Greedy │ 1.0 │
│ mmd │ domain │ credit_default │ FluxEnsemble │ DICE │ 1.0 │
│ mmd │ domain │ cal_housing │ LogisticRegression │ Generic │ 0.0 │
│ mmd │ domain │ cal_housing │ LogisticRegression │ REVISE │ 0.0 │
│ mmd │ domain │ cal_housing │ LogisticRegression │ Greedy │ 0.0 │
│ mmd │ domain │ cal_housing │ LogisticRegression │ DICE │ 0.0 │
│ mmd │ domain │ cal_housing │ FluxModel │ Generic │ 0.0 │
│ mmd │ domain │ cal_housing │ FluxModel │ REVISE │ 0.0 │
│ mmd │ domain │ cal_housing │ FluxModel │ Greedy │ 0.0 │
│ mmd │ domain │ cal_housing │ FluxModel │ DICE │ 0.0 │
│ mmd │ domain │ cal_housing │ FluxEnsemble │ Generic │ 0.0 │
│ mmd │ domain │ cal_housing │ FluxEnsemble │ REVISE │ 0.0 │
│ mmd │ domain │ cal_housing │ FluxEnsemble │ Greedy │ 0.0 │
│ mmd │ domain │ cal_housing │ FluxEnsemble │ DICE │ 0.0 │
│ mmd │ domain │ gmsc │ LogisticRegression │ Generic │ 0.278 │
│ mmd │ domain │ gmsc │ LogisticRegression │ REVISE │ 0.0 │
│ mmd │ domain │ gmsc │ LogisticRegression │ Greedy │ 0.006 │
│ mmd │ domain │ gmsc │ LogisticRegression │ DICE │ 0.51 │
│ mmd │ domain │ gmsc │ FluxModel │ Generic │ 0.128 │
│ mmd │ domain │ gmsc │ FluxModel │ REVISE │ 0.0 │
│ mmd │ domain │ gmsc │ FluxModel │ Greedy │ 0.0 │
│ mmd │ domain │ gmsc │ FluxModel │ DICE │ 0.338 │
│ mmd │ domain │ gmsc │ FluxEnsemble │ Generic │ 0.306 │
│ mmd │ domain │ gmsc │ FluxEnsemble │ REVISE │ 0.0 │
│ mmd │ domain │ gmsc │ FluxEnsemble │ Greedy │ 0.032 │
│ mmd │ domain │ gmsc │ FluxEnsemble │ DICE │ 0.082 │
│ mmd │ model │ credit_default │ LogisticRegression │ Generic │ 0.0 │
│ mmd │ model │ credit_default │ LogisticRegression │ REVISE │ 0.436 │
│ mmd │ model │ credit_default │ LogisticRegression │ Greedy │ 0.044 │
│ mmd │ model │ credit_default │ LogisticRegression │ DICE │ 0.0 │
│ mmd │ model │ credit_default │ FluxModel │ Generic │ 0.0 │
│ mmd │ model │ credit_default │ FluxModel │ REVISE │ 0.0 │
│ mmd │ model │ credit_default │ FluxModel │ Greedy │ 0.0 │
│ mmd │ model │ credit_default │ FluxModel │ DICE │ 0.0 │
│ mmd │ model │ credit_default │ FluxEnsemble │ Generic │ 0.0 │
│ mmd │ model │ credit_default │ FluxEnsemble │ REVISE │ 0.0 │
│ mmd │ model │ credit_default │ FluxEnsemble │ Greedy │ 0.0 │
│ mmd │ model │ credit_default │ FluxEnsemble │ DICE │ 0.0 │
│ mmd │ model │ cal_housing │ LogisticRegression │ Generic │ 0.0 │
│ mmd │ model │ cal_housing │ LogisticRegression │ REVISE │ 0.0 │
│ mmd │ model │ cal_housing │ LogisticRegression │ Greedy │ 0.0 │
│ mmd │ model │ cal_housing │ LogisticRegression │ DICE │ 0.0 │
│ mmd │ model │ cal_housing │ FluxModel │ Generic │ 0.0 │
│ mmd │ model │ cal_housing │ FluxModel │ REVISE │ 0.0 │
│ mmd │ model │ cal_housing │ FluxModel │ Greedy │ 0.0 │
│ mmd │ model │ cal_housing │ FluxModel │ DICE │ 0.0 │
│ mmd │ model │ cal_housing │ FluxEnsemble │ Generic │ 0.0 │
│ mmd │ model │ cal_housing │ FluxEnsemble │ REVISE │ 0.0 │
│ mmd │ model │ cal_housing │ FluxEnsemble │ Greedy │ 0.0 │
│ mmd │ model │ cal_housing │ FluxEnsemble │ DICE │ 0.0 │
│ mmd │ model │ gmsc │ LogisticRegression │ Generic │ 0.0 │
│ mmd │ model │ gmsc │ LogisticRegression │ REVISE │ 0.0 │
│ mmd │ model │ gmsc │ LogisticRegression │ Greedy │ 0.0 │
│ mmd │ model │ gmsc │ LogisticRegression │ DICE │ 0.0 │
│ mmd │ model │ gmsc │ FluxModel │ Generic │ 0.0 │
│ mmd │ model │ gmsc │ FluxModel │ REVISE │ 0.0 │
│ mmd │ model │ gmsc │ FluxModel │ Greedy │ 0.0 │
│ mmd │ model │ gmsc │ FluxModel │ DICE │ 0.0 │
│ mmd │ model │ gmsc │ FluxEnsemble │ Generic │ 0.018 │
│ mmd │ model │ gmsc │ FluxEnsemble │ REVISE │ 0.008 │
│ mmd │ model │ gmsc │ FluxEnsemble │ Greedy │ 0.02 │
│ mmd │ model │ gmsc │ FluxEnsemble │ DICE │ 0.032 │
│ mmd_grid │ model │ credit_default │ LogisticRegression │ Generic │ 0.0 │
│ mmd_grid │ model │ credit_default │ LogisticRegression │ REVISE │ 0.044 │
│ mmd_grid │ model │ credit_default │ LogisticRegression │ Greedy │ 0.0 │
│ mmd_grid │ model │ credit_default │ LogisticRegression │ DICE │ 0.0 │
│ mmd_grid │ model │ credit_default │ FluxModel │ Generic │ 0.0 │
│ mmd_grid │ model │ credit_default │ FluxModel │ REVISE │ 0.0 │
│ mmd_grid │ model │ credit_default │ FluxModel │ Greedy │ 0.0 │
│ mmd_grid │ model │ credit_default │ FluxModel │ DICE │ 0.0 │
│ mmd_grid │ model │ credit_default │ FluxEnsemble │ Generic │ 0.0 │
│ mmd_grid │ model │ credit_default │ FluxEnsemble │ REVISE │ 0.0 │
│ mmd_grid │ model │ credit_default │ FluxEnsemble │ Greedy │ 0.164 │
│ mmd_grid │ model │ credit_default │ FluxEnsemble │ DICE │ 0.0 │
│ mmd_grid │ model │ cal_housing │ LogisticRegression │ Generic │ 0.0 │
│ mmd_grid │ model │ cal_housing │ LogisticRegression │ REVISE │ 0.01 │
│ mmd_grid │ model │ cal_housing │ LogisticRegression │ Greedy │ 0.0 │
│ mmd_grid │ model │ cal_housing │ LogisticRegression │ DICE │ 0.0 │
│ mmd_grid │ model │ cal_housing │ FluxModel │ Generic │ 0.004 │
│ mmd_grid │ model │ cal_housing │ FluxModel │ REVISE │ 0.026 │
│ mmd_grid │ model │ cal_housing │ FluxModel │ Greedy │ 0.0 │
│ mmd_grid │ model │ cal_housing │ FluxModel │ DICE │ 0.0 │
│ mmd_grid │ model │ cal_housing │ FluxEnsemble │ Generic │ 0.0 │
│ mmd_grid │ model │ cal_housing │ FluxEnsemble │ REVISE │ 0.006 │
│ mmd_grid │ model │ cal_housing │ FluxEnsemble │ Greedy │ 0.0 │
│ mmd_grid │ model │ cal_housing │ FluxEnsemble │ DICE │ 0.0 │
│ mmd_grid │ model │ gmsc │ LogisticRegression │ Generic │ 0.0 │
│ mmd_grid │ model │ gmsc │ LogisticRegression │ REVISE │ 0.0 │
│ mmd_grid │ model │ gmsc │ LogisticRegression │ Greedy │ 0.0 │
│ mmd_grid │ model │ gmsc │ LogisticRegression │ DICE │ 0.0 │
│ mmd_grid │ model │ gmsc │ FluxModel │ Generic │ 0.0 │
│ mmd_grid │ model │ gmsc │ FluxModel │ REVISE │ 0.03 │
│ mmd_grid │ model │ gmsc │ FluxModel │ Greedy │ 0.0 │
│ mmd_grid │ model │ gmsc │ FluxModel │ DICE │ 0.004 │
│ mmd_grid │ model │ gmsc │ FluxEnsemble │ Generic │ 0.002 │
│ mmd_grid │ model │ gmsc │ FluxEnsemble │ REVISE │ 0.0 │
│ mmd_grid │ model │ gmsc │ FluxEnsemble │ Greedy │ 0.0 │
│ mmd_grid │ model │ gmsc │ FluxEnsemble │ DICE │ 0.0 │
└──────────┴─────────┴────────────────┴────────────────────┴───────────┴──────────────┘
### 4.2.1 Chart in Paper
Figure 4.3 shows the chart used in the paper.
# Show the real-world results figure used in the paper.
Images.load(joinpath(www_artifact_path, "paper_real_world_results.png"))