# Model types trained for every synthetic data set.
models = [
    :LogisticRegression,
    :FluxModel,
    :FluxEnsemble,
]

# Optimiser shared by the gradient-based generators: plain gradient descent
# with learning rate 0.01.
opt = Flux.Descent(0.01)

# Counterfactual generators to compare; the keys are the labels that appear in
# the results (e.g. the `generator` column of the bootstrap table).
generators = Dict(
    :Greedy => GreedyGenerator(),
    :Generic => GenericGenerator(opt = opt),
    :REVISE => REVISEGenerator(opt = opt),
    :DICE => DiCEGenerator(opt = opt),
)
3 Synthetic Data
This notebook was used to run the experiments on the synthetic data sets and can be used to reproduce the results in the paper. In the following, we first run the experiments and then generate visualizations and tables.
3.1 Experiment
# Passed to `load_synthetic`; presumably caps the number of observations per
# synthetic data set — confirm against `load_synthetic`.
max_obs = 1000
catalogue = load_synthetic(max_obs)

# Synthetic data sets included in the experiments.
choices = [
    :linearly_separable,
    :overlapping,
    :circles,
    :moons,
]
# Keep only the chosen catalogue entries. `p[1]` is the entry's key
# (assumes `catalogue` is a Dict keyed by data set name — TODO confirm).
data_sets = filter(p -> p[1] in choices, catalogue)

# One experiment per data set, crossing every model with every generator.
experiments = set_up_experiments(data_sets, models, generators)
# `model_evaluation` is called below; import it here so this cell does not
# depend on an import that only happens in a later cell.
using AlgorithmicRecourseDynamics.Models: model_evaluation

# Before any recourse: plot every model against its experiment's test data,
# titled with the evaluation score, and overlay misclassified test points in red.
plts = []
for (exp_name, exp_) in experiments
    for (M_name, M) in exp_.models
        score = round(model_evaluation(M, exp_.test_data), digits=2)
        plt = plot(M, exp_.test_data, title="$exp_name;\n $M_name ($score)")
        # Errors: points whose rounded predicted probability differs from the label.
        ids = findall(vec(round.(probs(M, exp_.test_data.X)) .!= exp_.test_data.y))
        x_wrongly_labelled = exp_.test_data.X[:, ids]
        # Only the first two features are drawn (the plots are 2-D).
        scatter!(
            plt, x_wrongly_labelled[1, :], x_wrongly_labelled[2, :];
            ms=7.5, color=:red, label="",
        )
        push!(plts, plt)
    end
end

# Grid with one row per data set and one column per model. Plots' `size` is
# (width, height) in pixels, so width scales with columns and height with rows.
plt = plot(
    plts...;
    layout=(length(choices), length(models)),
    size=(length(models) * 300, length(choices) * 300),
)
savefig(plt, joinpath(www_path, "models_test_before.png"))
using AlgorithmicRecourseDynamics.Models: model_evaluation

# Before any recourse: plot every model against its experiment's training data,
# titled with the evaluation score, and overlay misclassified training points in red.
plts = []
for (exp_name, exp_) in experiments
    for (M_name, M) in exp_.models
        score = round(model_evaluation(M, exp_.train_data), digits=2)
        plt = plot(M, exp_.train_data, title="$exp_name;\n $M_name ($score)")
        # Errors: points whose rounded predicted probability differs from the label.
        ids = findall(vec(round.(probs(M, exp_.train_data.X)) .!= exp_.train_data.y))
        x_wrongly_labelled = exp_.train_data.X[:, ids]
        # Only the first two features are drawn (the plots are 2-D).
        scatter!(
            plt, x_wrongly_labelled[1, :], x_wrongly_labelled[2, :];
            ms=7.5, color=:red, label="",
        )
        push!(plts, plt)
    end
end

# Grid with one row per data set and one column per model. Plots' `size` is
# (width, height) in pixels, so width scales with columns and height with rows.
plt = plot(
    plts...;
    layout=(length(choices), length(models)),
    size=(length(models) * 300, length(choices) * 300),
)
savefig(plt, joinpath(www_path, "models_train_before.png"))
# Experiment settings.
n_evals = 5                                      # number of evaluation points per run
n_rounds = 50                                    # number of rounds per run
evaluate_every = round(Int, n_rounds / n_evals)  # rounds between evaluations (10)
n_folds = 5
T = 100

results = run_experiments(
    experiments;
    save_path=output_path,
    evaluate_every=evaluate_every,
    n_rounds=n_rounds,
    n_folds=n_folds,
    T=T,
)

# Persist the results so the plotting sections can be re-run without
# recomputing the experiments.
Serialization.serialize(joinpath(output_path, "results.jls"), results)
# After recourse: for the first fold only, plot each recourse system's model
# against the test data, collecting one figure per generator.
fold = 1
plot_dict = Dict(key => Dict() for (key, val) in results)
for (name, res) in results
    exp_ = res.experiment
    plot_dict[name] = Dict(key => [] for (key, val) in exp_.generators)
    rec_sys = exp_.recourse_systems[fold]
    # NOTE(review): assumes `sys_ids[m]` is the (model, generator) pair that
    # identifies `rec_sys[m]` — confirm the two collections share an ordering.
    sys_ids = collect(exp_.system_identifiers)
    # Loop bound kept separate from `M`, which holds the model inside the loop.
    n_systems = length(rec_sys)
    for m in 1:n_systems
        model_name, generator_name = sys_ids[m]
        M = rec_sys[m].model
        score = round(model_evaluation(M, exp_.test_data), digits=2)
        plt = plot(M, exp_.test_data, title="$name;\n $model_name ($score)")
        # Errors: points whose rounded predicted probability differs from the label.
        ids = findall(vec(round.(probs(M, exp_.test_data.X)) .!= exp_.test_data.y))
        x_wrongly_labelled = exp_.test_data.X[:, ids]
        scatter!(
            plt, x_wrongly_labelled[1, :], x_wrongly_labelled[2, :];
            ms=7.5, color=:red, label="",
        )
        push!(plot_dict[name][generator_name], plt)
    end
end

# Regroup by generator: concatenate each data set's plots for that generator
# (same data set order for every generator, so rows line up across figures).
plot_dict = Dict(
    key => reduce(vcat, [plots[key] for plots in values(plot_dict)])
    for (key, value) in generators
)
for (name, plts) in plot_dict
    # One row per data set, one column per model; `size` is (width, height).
    plt = plot(
        plts...;
        layout=(length(choices), length(models)),
        size=(length(models) * 300, length(choices) * 300),
    )
    savefig(plt, joinpath(www_path, "models_test_after_$(name).png"))
end
using AlgorithmicRecourseDynamics.Models: model_evaluation

# After recourse: for the first fold only, plot each recourse system's model
# against that system's own data, collecting one figure per generator.
fold = 1
plot_dict = Dict(key => Dict() for (key, val) in results)
for (name, res) in results
    exp_ = res.experiment
    plot_dict[name] = Dict(key => [] for (key, val) in exp_.generators)
    rec_sys = exp_.recourse_systems[fold]
    # NOTE(review): assumes `sys_ids[m]` is the (model, generator) pair that
    # identifies `rec_sys[m]` — confirm the two collections share an ordering.
    sys_ids = collect(exp_.system_identifiers)
    # Loop bound kept separate from `M`, which holds the model inside the loop.
    n_systems = length(rec_sys)
    for m in 1:n_systems
        model_name, generator_name = sys_ids[m]
        M = rec_sys[m].model
        # Presumably the training data as modified by recourse — confirm.
        data = rec_sys[m].data
        score = round(model_evaluation(M, data), digits=2)
        plt = plot(M, data, title="$name;\n $model_name ($score)")
        # Errors: points whose rounded predicted probability differs from the label.
        ids = findall(vec(round.(probs(M, data.X)) .!= data.y))
        x_wrongly_labelled = data.X[:, ids]
        scatter!(
            plt, x_wrongly_labelled[1, :], x_wrongly_labelled[2, :];
            ms=7.5, color=:red, label="",
        )
        push!(plot_dict[name][generator_name], plt)
    end
end

# Regroup by generator: concatenate each data set's plots for that generator
# (same data set order for every generator, so rows line up across figures).
plot_dict = Dict(
    key => reduce(vcat, [plots[key] for plots in values(plot_dict)])
    for (key, value) in generators
)
for (name, plts) in plot_dict
    # One row per data set, one column per model; `size` is (width, height).
    plt = plot(
        plts...;
        layout=(length(choices), length(models)),
        size=(length(models) * 300, length(choices) * 300),
    )
    savefig(plt, joinpath(www_path, "models_train_after_$(name).png"))
end
3.2 Plots
# Reload the serialized experiment results. The Serialization format is not
# guaranteed to be stable across Julia versions, so regenerate the file if the
# Julia version changes.
results = Serialization.deserialize(joinpath(output_path, "results.jls"));
using Images
# For each data set, save a line chart and an error bar chart to `www_path` and
# keep them in these dictionaries, keyed by data set name.
line_charts = Dict()
errorbar_charts = Dict()
for (data_name, res) in results
    plt = plot(res)
    Images.save(joinpath(www_path, "line_chart_$(data_name).png"), plt)
    line_charts[data_name] = plt

    # `maximum(res.output.n)` presumably selects the last evaluated round, so
    # this chart shows the metrics at the end of the experiment — confirm.
    plt = plot(res, maximum(res.output.n))
    Images.save(joinpath(www_path, "errorbar_chart_$(data_name).png"), plt)
    errorbar_charts[data_name] = plt
end
3.2.1 Line Charts
Figure 3.1 shows the evolution of the evaluation metrics over the course of the experiment.
3.2.2 Error Bar Charts
Figure 3.2 shows the evaluation metrics at the end of the experiments.
3.3 Bootstrap
# Bootstrap the results; the resulting table is also written to bootstrap.csv.
n_bootstrap = 100
df = run_bootstrap(results, n_bootstrap; filename=joinpath(output_path, "bootstrap.csv"))
┌──────────┬─────────┬────────────────────┬────────────────────┬───────────┬──────────────┐
│ name │ scope │ data │ model │ generator │ p_value_mean │
│ String31 │ String7 │ String31 │ String31 │ String7 │ Float64 │
├──────────┼─────────┼────────────────────┼────────────────────┼───────────┼──────────────┤
│ mmd │ domain │ overlapping │ LogisticRegression │ REVISE │ 0.0 │
│ mmd │ domain │ overlapping │ LogisticRegression │ Greedy │ 0.0 │
│ mmd │ domain │ overlapping │ LogisticRegression │ DICE │ 0.0 │
│ mmd │ domain │ overlapping │ LogisticRegression │ Generic │ 0.0 │
│ mmd │ domain │ overlapping │ FluxModel │ REVISE │ 0.0 │
│ mmd │ domain │ overlapping │ FluxModel │ Greedy │ 0.0 │
│ mmd │ domain │ overlapping │ FluxModel │ DICE │ 0.0 │
│ mmd │ domain │ overlapping │ FluxModel │ Generic │ 0.0 │
│ mmd │ domain │ overlapping │ FluxEnsemble │ REVISE │ 0.0 │
│ mmd │ domain │ overlapping │ FluxEnsemble │ Greedy │ 0.0 │
│ mmd │ domain │ overlapping │ FluxEnsemble │ DICE │ 0.0 │
│ mmd │ domain │ overlapping │ FluxEnsemble │ Generic │ 0.0 │
│ mmd │ domain │ linearly_separable │ LogisticRegression │ REVISE │ 0.768 │
│ mmd │ domain │ linearly_separable │ LogisticRegression │ Greedy │ 0.0 │
│ mmd │ domain │ linearly_separable │ LogisticRegression │ DICE │ 0.0 │
│ mmd │ domain │ linearly_separable │ LogisticRegression │ Generic │ 0.0 │
│ mmd │ domain │ linearly_separable │ FluxModel │ REVISE │ 0.69 │
│ mmd │ domain │ linearly_separable │ FluxModel │ Greedy │ 0.0 │
│ mmd │ domain │ linearly_separable │ FluxModel │ DICE │ 0.0 │
│ mmd │ domain │ linearly_separable │ FluxModel │ Generic │ 0.0 │
│ mmd │ domain │ linearly_separable │ FluxEnsemble │ REVISE │ 0.748 │
│ mmd │ domain │ linearly_separable │ FluxEnsemble │ Greedy │ 0.0 │
│ mmd │ domain │ linearly_separable │ FluxEnsemble │ DICE │ 0.0 │
│ mmd │ domain │ linearly_separable │ FluxEnsemble │ Generic │ 0.0 │
│ mmd │ domain │ circles │ LogisticRegression │ REVISE │ 0.9925 │
│ mmd │ domain │ circles │ LogisticRegression │ Greedy │ 1.0 │
│ mmd │ domain │ circles │ LogisticRegression │ DICE │ 1.0 │
│ mmd │ domain │ circles │ LogisticRegression │ Generic │ 0.996 │
│ mmd │ domain │ circles │ FluxModel │ REVISE │ 1.0 │
│ mmd │ domain │ circles │ FluxModel │ Greedy │ 0.994 │
│ mmd │ domain │ circles │ FluxModel │ DICE │ 0.99 │
│ mmd │ domain │ circles │ FluxModel │ Generic │ 0.99 │
│ mmd │ domain │ circles │ FluxEnsemble │ REVISE │ 0.9975 │
│ mmd │ domain │ circles │ FluxEnsemble │ Greedy │ 0.992 │
│ mmd │ domain │ circles │ FluxEnsemble │ DICE │ 0.988 │
│ mmd │ domain │ circles │ FluxEnsemble │ Generic │ 0.996 │
│ mmd │ domain │ moons │ LogisticRegression │ REVISE │ 0.0 │
│ mmd │ domain │ moons │ LogisticRegression │ Greedy │ 0.0 │
│ mmd │ domain │ moons │ LogisticRegression │ DICE │ 0.0 │
│ mmd │ domain │ moons │ LogisticRegression │ Generic │ 0.0 │
│ mmd │ domain │ moons │ FluxModel │ REVISE │ 0.0 │
│ mmd │ domain │ moons │ FluxModel │ Greedy │ 0.0 │
│ mmd │ domain │ moons │ FluxModel │ DICE │ 0.0 │
│ mmd │ domain │ moons │ FluxModel │ Generic │ 0.0 │
│ mmd │ domain │ moons │ FluxEnsemble │ REVISE │ 0.0 │
│ mmd │ domain │ moons │ FluxEnsemble │ Greedy │ 0.0 │
│ mmd │ domain │ moons │ FluxEnsemble │ DICE │ 0.0 │
│ mmd │ domain │ moons │ FluxEnsemble │ Generic │ 0.0 │
│ mmd │ model │ overlapping │ LogisticRegression │ REVISE │ 0.012 │
│ mmd │ model │ overlapping │ LogisticRegression │ Greedy │ 0.0 │
│ mmd │ model │ overlapping │ LogisticRegression │ DICE │ 0.0 │
│ mmd │ model │ overlapping │ LogisticRegression │ Generic │ 0.0 │
│ mmd │ model │ overlapping │ FluxModel │ REVISE │ 0.034 │
│ mmd │ model │ overlapping │ FluxModel │ Greedy │ 0.004 │
│ mmd │ model │ overlapping │ FluxModel │ DICE │ 0.002 │
│ mmd │ model │ overlapping │ FluxModel │ Generic │ 0.002 │
│ mmd │ model │ overlapping │ FluxEnsemble │ REVISE │ 0.034 │
│ mmd │ model │ overlapping │ FluxEnsemble │ Greedy │ 0.002 │
│ mmd │ model │ overlapping │ FluxEnsemble │ DICE │ 0.0 │
│ mmd │ model │ overlapping │ FluxEnsemble │ Generic │ 0.004 │
│ mmd │ model │ linearly_separable │ LogisticRegression │ REVISE │ 0.46 │
│ mmd │ model │ linearly_separable │ LogisticRegression │ Greedy │ 0.0 │
│ mmd │ model │ linearly_separable │ LogisticRegression │ DICE │ 0.0 │
│ mmd │ model │ linearly_separable │ LogisticRegression │ Generic │ 0.0 │
│ mmd │ model │ linearly_separable │ FluxModel │ REVISE │ 0.852 │
│ mmd │ model │ linearly_separable │ FluxModel │ Greedy │ 0.684 │
│ mmd │ model │ linearly_separable │ FluxModel │ DICE │ 0.964 │
│ mmd │ model │ linearly_separable │ FluxModel │ Generic │ 0.944 │
│ mmd │ model │ linearly_separable │ FluxEnsemble │ REVISE │ 0.856 │
│ mmd │ model │ linearly_separable │ FluxEnsemble │ Greedy │ 0.716 │
│ mmd │ model │ linearly_separable │ FluxEnsemble │ DICE │ 0.9525 │
│ mmd │ model │ linearly_separable │ FluxEnsemble │ Generic │ 0.958 │
│ mmd │ model │ circles │ LogisticRegression │ REVISE │ 0.0 │
│ mmd │ model │ circles │ LogisticRegression │ Greedy │ 0.0 │
│ mmd │ model │ circles │ LogisticRegression │ DICE │ 0.796 │
│ mmd │ model │ circles │ LogisticRegression │ Generic │ 0.996 │
│ mmd │ model │ circles │ FluxModel │ REVISE │ 0.994 │
│ mmd │ model │ circles │ FluxModel │ Greedy │ 0.996 │
│ mmd │ model │ circles │ FluxModel │ DICE │ 0.9975 │
│ mmd │ model │ circles │ FluxModel │ Generic │ 0.992 │
│ mmd │ model │ circles │ FluxEnsemble │ REVISE │ 0.9975 │
│ mmd │ model │ circles │ FluxEnsemble │ Greedy │ 1.0 │
│ mmd │ model │ circles │ FluxEnsemble │ DICE │ 0.996 │
│ mmd │ model │ circles │ FluxEnsemble │ Generic │ 1.0 │
│ mmd │ model │ moons │ LogisticRegression │ REVISE │ 0.004 │
│ mmd │ model │ moons │ LogisticRegression │ Greedy │ 0.0 │
│ mmd │ model │ moons │ LogisticRegression │ DICE │ 0.0 │
│ mmd │ model │ moons │ LogisticRegression │ Generic │ 0.0 │
│ mmd │ model │ moons │ FluxModel │ REVISE │ 0.91 │
│ mmd │ model │ moons │ FluxModel │ Greedy │ 0.346 │
│ mmd │ model │ moons │ FluxModel │ DICE │ 0.87 │
│ mmd │ model │ moons │ FluxModel │ Generic │ 0.84 │
│ mmd │ model │ moons │ FluxEnsemble │ REVISE │ 0.902 │
│ mmd │ model │ moons │ FluxEnsemble │ Greedy │ 0.388 │
│ mmd │ model │ moons │ FluxEnsemble │ DICE │ 0.865 │
│ mmd │ model │ moons │ FluxEnsemble │ Generic │ 0.678 │
│ mmd_grid │ model │ overlapping │ LogisticRegression │ REVISE │ 0.02 │
│ mmd_grid │ model │ overlapping │ LogisticRegression │ Greedy │ 0.0 │
│ mmd_grid │ model │ overlapping │ LogisticRegression │ DICE │ 0.0 │
│ mmd_grid │ model │ overlapping │ LogisticRegression │ Generic │ 0.0 │
│ mmd_grid │ model │ overlapping │ FluxModel │ REVISE │ 0.342 │
│ mmd_grid │ model │ overlapping │ FluxModel │ Greedy │ 0.002 │
│ mmd_grid │ model │ overlapping │ FluxModel │ DICE │ 0.0 │
│ mmd_grid │ model │ overlapping │ FluxModel │ Generic │ 0.0 │
│ mmd_grid │ model │ overlapping │ FluxEnsemble │ REVISE │ 0.208 │
│ mmd_grid │ model │ overlapping │ FluxEnsemble │ Greedy │ 0.0 │
│ mmd_grid │ model │ overlapping │ FluxEnsemble │ DICE │ 0.002 │
│ mmd_grid │ model │ overlapping │ FluxEnsemble │ Generic │ 0.0 │
│ mmd_grid │ model │ linearly_separable │ LogisticRegression │ REVISE │ 0.0 │
│ mmd_grid │ model │ linearly_separable │ LogisticRegression │ Greedy │ 0.0 │
│ mmd_grid │ model │ linearly_separable │ LogisticRegression │ DICE │ 0.0 │
│ mmd_grid │ model │ linearly_separable │ LogisticRegression │ Generic │ 0.0 │
│ mmd_grid │ model │ linearly_separable │ FluxModel │ REVISE │ 0.0 │
│ mmd_grid │ model │ linearly_separable │ FluxModel │ Greedy │ 0.0 │
│ mmd_grid │ model │ linearly_separable │ FluxModel │ DICE │ 0.0 │
│ mmd_grid │ model │ linearly_separable │ FluxModel │ Generic │ 0.0 │
│ mmd_grid │ model │ linearly_separable │ FluxEnsemble │ REVISE │ 0.0 │
│ mmd_grid │ model │ linearly_separable │ FluxEnsemble │ Greedy │ 0.0 │
│ mmd_grid │ model │ linearly_separable │ FluxEnsemble │ DICE │ 0.0 │
│ mmd_grid │ model │ linearly_separable │ FluxEnsemble │ Generic │ 0.0 │
│ mmd_grid │ model │ circles │ LogisticRegression │ REVISE │ 0.0 │
│ mmd_grid │ model │ circles │ LogisticRegression │ Greedy │ 0.0 │
│ mmd_grid │ model │ circles │ LogisticRegression │ DICE │ 0.814 │
│ mmd_grid │ model │ circles │ LogisticRegression │ Generic │ 0.994 │
│ mmd_grid │ model │ circles │ FluxModel │ REVISE │ 0.996 │
│ mmd_grid │ model │ circles │ FluxModel │ Greedy │ 0.776 │
│ mmd_grid │ model │ circles │ FluxModel │ DICE │ 0.7375 │
│ mmd_grid │ model │ circles │ FluxModel │ Generic │ 0.688 │
│ mmd_grid │ model │ circles │ FluxEnsemble │ REVISE │ 1.0 │
│ mmd_grid │ model │ circles │ FluxEnsemble │ Greedy │ 0.568 │
│ mmd_grid │ model │ circles │ FluxEnsemble │ DICE │ 0.762 │
│ mmd_grid │ model │ circles │ FluxEnsemble │ Generic │ 0.89 │
│ mmd_grid │ model │ moons │ LogisticRegression │ REVISE │ 0.004 │
│ mmd_grid │ model │ moons │ LogisticRegression │ Greedy │ 0.0 │
│ mmd_grid │ model │ moons │ LogisticRegression │ DICE │ 0.0 │
│ mmd_grid │ model │ moons │ LogisticRegression │ Generic │ 0.0 │
│ mmd_grid │ model │ moons │ FluxModel │ REVISE │ 0.174 │
│ mmd_grid │ model │ moons │ FluxModel │ Greedy │ 0.0 │
│ mmd_grid │ model │ moons │ FluxModel │ DICE │ 0.01 │
│ mmd_grid │ model │ moons │ FluxModel │ Generic │ 0.02 │
│ mmd_grid │ model │ moons │ FluxEnsemble │ REVISE │ 0.114 │
│ mmd_grid │ model │ moons │ FluxEnsemble │ Greedy │ 0.006 │
│ mmd_grid │ model │ moons │ FluxEnsemble │ DICE │ 0.1225 │
│ mmd_grid │ model │ moons │ FluxEnsemble │ Generic │ 0.016 │
└──────────┴─────────┴────────────────────┴────────────────────┴───────────┴──────────────┘
3.4 Chart in paper
Figure 3.3 shows the chart used in the paper.
# Display the figure used in the paper.
Images.load(joinpath(www_artifact_path, "paper_synthetic_results.png"))
#| echo: false
# Cell option above uses Quarto's `#|` prefix; a plain `# echo: false` is an
# ordinary comment and does not hide this cell in the rendered output.
generate_artifacts(output_path)
generate_artifacts(www_path)