3  Synthetic Data

This notebook was used to run the experiments on the synthetic data sets and can be used to reproduce the corresponding results in the paper. In the following, we first run the experiments and then generate visualizations and tables.

3.1 Experiment

# Models that are trained on every selected synthetic data set.
models = [:LogisticRegression, :FluxModel, :FluxEnsemble]

# Counterfactual generators. The Generic, REVISE and DiCE generators share a single
# optimiser: plain gradient descent with learning rate 0.01.
opt = Flux.Descent(0.01)
generators = Dict(
    :Greedy => GreedyGenerator(),
    :Generic => GenericGenerator(opt = opt),
    :REVISE => REVISEGenerator(opt = opt),
    :DICE => DiCEGenerator(opt = opt),
)

# Load the synthetic catalogue and keep only the data sets listed in `choices`.
# `max_obs` is forwarded to `load_synthetic`; presumably it caps the number of
# observations per data set — confirm.
max_obs = 1000
catalogue = load_synthetic(max_obs)
choices = [:linearly_separable, :overlapping, :circles, :moons]
data_sets = filter(p -> first(p) ∈ choices, catalogue)

# One experiment per selected data set, crossing all models with all generators.
experiments = set_up_experiments(data_sets, models, generators)
# Model fits on the *test* set before any recourse has been implemented.
# One panel per (data set, model), titled with the model's evaluation score;
# misclassified test points are overlaid as large red markers.
# `model_evaluation` is used below, so bring it into scope before its first use.
using AlgorithmicRecourseDynamics.Models: model_evaluation
plts = []
for (exp_name, exp_) in experiments
    for (M_name, M) in exp_.models
        score = round(model_evaluation(M, exp_.test_data), digits=2)
        plt = plot(M, exp_.test_data, title="$exp_name;\n $M_name ($score)")
        # Errors: points whose rounded predicted probability differs from the label.
        ids = findall(vec(round.(probs(M, exp_.test_data.X)) .!= exp_.test_data.y))
        x_wrongly_labelled = exp_.test_data.X[:, ids]
        scatter!(
            plt, x_wrongly_labelled[1, :], x_wrongly_labelled[2, :];
            ms=7.5, color=:red, label="",
        )
        # Append in place rather than re-splatting the whole vector each iteration.
        push!(plts, plt)
    end
end
# Plots.jl fills a grid layout row by row: rows = data sets, columns = models.
# `size` is (width, height) in pixels, so width scales with the number of models.
plt = plot(
    plts...;
    layout=(length(choices), length(models)),
    size=(length(models) * 300, length(choices) * 300),
)
savefig(plt, joinpath(www_path, "models_test_before.png"))
using AlgorithmicRecourseDynamics.Models: model_evaluation
# Model fits on the *training* set before any recourse has been implemented.
# One panel per (data set, model), titled with the model's evaluation score;
# misclassified training points are overlaid as large red markers.
plts = []
for (exp_name, exp_) in experiments
    for (M_name, M) in exp_.models
        score = round(model_evaluation(M, exp_.train_data), digits=2)
        plt = plot(M, exp_.train_data, title="$exp_name;\n $M_name ($score)")
        # Errors: points whose rounded predicted probability differs from the label.
        ids = findall(vec(round.(probs(M, exp_.train_data.X)) .!= exp_.train_data.y))
        x_wrongly_labelled = exp_.train_data.X[:, ids]
        scatter!(
            plt, x_wrongly_labelled[1, :], x_wrongly_labelled[2, :];
            ms=7.5, color=:red, label="",
        )
        # Append in place rather than re-splatting the whole vector each iteration.
        push!(plts, plt)
    end
end
# Plots.jl fills a grid layout row by row: rows = data sets, columns = models.
# `size` is (width, height) in pixels, so width scales with the number of models.
plt = plot(
    plts...;
    layout=(length(choices), length(models)),
    size=(length(models) * 300, length(choices) * 300),
)
savefig(plt, joinpath(www_path, "models_train_before.png"))
# Experiment configuration: `n_rounds` rounds in total, with an evaluation every
# `evaluate_every` rounds so that roughly `n_evals` evaluations take place.
n_rounds, n_evals = 50, 5
evaluate_every = round(Int, n_rounds / n_evals)
n_folds = 5
# `T` is forwarded to `run_experiments` unchanged; its meaning is defined there.
T = 100

results = run_experiments(
    experiments;
    save_path = output_path,
    evaluate_every = evaluate_every,
    n_rounds = n_rounds,
    n_folds = n_folds,
    T = T,
)

# Persist the results so the plotting sections can be re-run without repeating
# the (expensive) experiments.
Serialization.serialize(joinpath(output_path, "results.jls"), results)
# Model fits on the *test* set after recourse has been implemented, for a single
# cross-validation fold. Panels are first collected per data set and generator,
# then regrouped so that one figure is written per generator.
plot_dict = Dict(key => Dict() for (key,val) in results)
fold = 1
for (name, res) in results
    exp_ = res.experiment
    plot_dict[name] = Dict(key => [] for (key,val) in exp_.generators)
    rec_sys = exp_.recourse_systems[fold]
    sys_ids = collect(exp_.system_identifiers)
    # NOTE(review): assumes `rec_sys[m]` belongs to the (model, generator) pair
    # `sys_ids[m]` — confirm the ordering guarantee in AlgorithmicRecourseDynamics.
    # The count gets its own name so it is not shadowed by the model `M` below.
    n_sys = length(rec_sys)
    for m in 1:n_sys
        model_name, generator_name = sys_ids[m]
        M = rec_sys[m].model
        score = round(model_evaluation(M, exp_.test_data), digits=2)
        plt = plot(M, exp_.test_data, title="$name;\n $model_name ($score)")
        # Errors: points whose rounded predicted probability differs from the label.
        ids = findall(vec(round.(probs(M, exp_.test_data.X)) .!= exp_.test_data.y))
        x_wrongly_labelled = exp_.test_data.X[:, ids]
        scatter!(
            plt, x_wrongly_labelled[1, :], x_wrongly_labelled[2, :];
            ms=7.5, color=:red, label="",
        )
        push!(plot_dict[name][generator_name], plt)
    end
end
# Regroup by generator. Dict iteration order is unspecified, so sort the data-set
# names to get a deterministic row order in every figure.
data_names = sort!(collect(keys(plot_dict)))
plot_dict = Dict(
    key => reduce(vcat, [plot_dict[d][key] for d in data_names]) for (key, value) in generators
)
for (name, plts) in plot_dict
    # Plots.jl fills a grid layout row by row: rows = data sets, columns = models.
    # `size` is (width, height) in pixels, so width scales with the number of models.
    plt = plot(
        plts...;
        layout=(length(choices), length(models)),
        size=(length(models) * 300, length(choices) * 300),
    )
    savefig(plt, joinpath(www_path, "models_test_after_$(name).png"))
end
using AlgorithmicRecourseDynamics.Models: model_evaluation
# Model fits on the data held by each recourse system after recourse has been
# implemented (i.e. the shifted training data), for a single cross-validation fold.
# Panels are first collected per data set and generator, then regrouped so that
# one figure is written per generator.
plot_dict = Dict(key => Dict() for (key,val) in results)
fold = 1
for (name, res) in results
    exp_ = res.experiment
    plot_dict[name] = Dict(key => [] for (key,val) in exp_.generators)
    rec_sys = exp_.recourse_systems[fold]
    sys_ids = collect(exp_.system_identifiers)
    # NOTE(review): assumes `rec_sys[m]` belongs to the (model, generator) pair
    # `sys_ids[m]` — confirm the ordering guarantee in AlgorithmicRecourseDynamics.
    # The count gets its own name so it is not shadowed by the model `M` below.
    n_sys = length(rec_sys)
    for m in 1:n_sys
        model_name, generator_name = sys_ids[m]
        M = rec_sys[m].model
        data = rec_sys[m].data
        score = round(model_evaluation(M, data), digits=2)
        plt = plot(M, data, title="$name;\n $model_name ($score)")
        # Errors: points whose rounded predicted probability differs from the label.
        ids = findall(vec(round.(probs(M, data.X)) .!= data.y))
        x_wrongly_labelled = data.X[:, ids]
        scatter!(
            plt, x_wrongly_labelled[1, :], x_wrongly_labelled[2, :];
            ms=7.5, color=:red, label="",
        )
        push!(plot_dict[name][generator_name], plt)
    end
end
# Regroup by generator. Dict iteration order is unspecified, so sort the data-set
# names to get a deterministic row order in every figure.
data_names = sort!(collect(keys(plot_dict)))
plot_dict = Dict(
    key => reduce(vcat, [plot_dict[d][key] for d in data_names]) for (key, value) in generators
)
for (name, plts) in plot_dict
    # Plots.jl fills a grid layout row by row: rows = data sets, columns = models.
    # `size` is (width, height) in pixels, so width scales with the number of models.
    plt = plot(
        plts...;
        layout=(length(choices), length(models)),
        size=(length(models) * 300, length(choices) * 300),
    )
    savefig(plt, joinpath(www_path, "models_train_after_$(name).png"))
end

3.2 Plots

# Reload the serialized experiment results (written by the experiment cell).
results = Serialization.deserialize(joinpath(output_path, "results.jls"));
using Images
line_charts, errorbar_charts = Dict(), Dict()
for (data_name, res) in results
    last_round = maximum(res.output.n)
    # Line chart: evaluation metrics over the whole course of the experiment.
    plt = plot(res)
    line_charts[data_name] = plt
    Images.save(joinpath(www_path, "line_chart_$(data_name).png"), plt)
    # Error bar chart: evaluation metrics at the final round only.
    plt = plot(res, last_round)
    errorbar_charts[data_name] = plt
    Images.save(joinpath(www_path, "errorbar_chart_$(data_name).png"), plt)
end

3.2.1 Line Charts

Figure 3.1 shows the evolution of the evaluation metrics over the course of the experiment.

(a) Circles

(b) Linearly Separable

(c) Moons

(d) Overlapping

Figure 3.1: Line Charts

3.2.2 Error Bar Charts

Figure 3.2 shows the evaluation metrics at the end of each experiment, i.e. after the final round.

(a) Circles

(b) Linearly Separable

(c) Moons

(d) Overlapping

Figure 3.2: Error Bar Charts

3.3 Bootstrap

# Bootstrap test of the evaluation metrics; the returned table is displayed below.
# The `filename` keyword presumably also writes the table to CSV — confirm.
n_bootstrap = 100
bootstrap_csv = joinpath(output_path, "bootstrap.csv")
df = run_bootstrap(results, n_bootstrap; filename = bootstrap_csv)
┌──────────┬─────────┬────────────────────┬────────────────────┬───────────┬──────────────┐
│     name │   scope │               data │              model │ generator │ p_value_mean │
│ String31 │ String7 │           String31 │           String31 │   String7 │      Float64 │
├──────────┼─────────┼────────────────────┼────────────────────┼───────────┼──────────────┤
│      mmd │  domain │        overlapping │ LogisticRegression │    REVISE │          0.0 │
│      mmd │  domain │        overlapping │ LogisticRegression │    Greedy │          0.0 │
│      mmd │  domain │        overlapping │ LogisticRegression │      DICE │          0.0 │
│      mmd │  domain │        overlapping │ LogisticRegression │   Generic │          0.0 │
│      mmd │  domain │        overlapping │          FluxModel │    REVISE │          0.0 │
│      mmd │  domain │        overlapping │          FluxModel │    Greedy │          0.0 │
│      mmd │  domain │        overlapping │          FluxModel │      DICE │          0.0 │
│      mmd │  domain │        overlapping │          FluxModel │   Generic │          0.0 │
│      mmd │  domain │        overlapping │       FluxEnsemble │    REVISE │          0.0 │
│      mmd │  domain │        overlapping │       FluxEnsemble │    Greedy │          0.0 │
│      mmd │  domain │        overlapping │       FluxEnsemble │      DICE │          0.0 │
│      mmd │  domain │        overlapping │       FluxEnsemble │   Generic │          0.0 │
│      mmd │  domain │ linearly_separable │ LogisticRegression │    REVISE │        0.768 │
│      mmd │  domain │ linearly_separable │ LogisticRegression │    Greedy │          0.0 │
│      mmd │  domain │ linearly_separable │ LogisticRegression │      DICE │          0.0 │
│      mmd │  domain │ linearly_separable │ LogisticRegression │   Generic │          0.0 │
│      mmd │  domain │ linearly_separable │          FluxModel │    REVISE │         0.69 │
│      mmd │  domain │ linearly_separable │          FluxModel │    Greedy │          0.0 │
│      mmd │  domain │ linearly_separable │          FluxModel │      DICE │          0.0 │
│      mmd │  domain │ linearly_separable │          FluxModel │   Generic │          0.0 │
│      mmd │  domain │ linearly_separable │       FluxEnsemble │    REVISE │        0.748 │
│      mmd │  domain │ linearly_separable │       FluxEnsemble │    Greedy │          0.0 │
│      mmd │  domain │ linearly_separable │       FluxEnsemble │      DICE │          0.0 │
│      mmd │  domain │ linearly_separable │       FluxEnsemble │   Generic │          0.0 │
│      mmd │  domain │            circles │ LogisticRegression │    REVISE │       0.9925 │
│      mmd │  domain │            circles │ LogisticRegression │    Greedy │          1.0 │
│      mmd │  domain │            circles │ LogisticRegression │      DICE │          1.0 │
│      mmd │  domain │            circles │ LogisticRegression │   Generic │        0.996 │
│      mmd │  domain │            circles │          FluxModel │    REVISE │          1.0 │
│      mmd │  domain │            circles │          FluxModel │    Greedy │        0.994 │
│      mmd │  domain │            circles │          FluxModel │      DICE │         0.99 │
│      mmd │  domain │            circles │          FluxModel │   Generic │         0.99 │
│      mmd │  domain │            circles │       FluxEnsemble │    REVISE │       0.9975 │
│      mmd │  domain │            circles │       FluxEnsemble │    Greedy │        0.992 │
│      mmd │  domain │            circles │       FluxEnsemble │      DICE │        0.988 │
│      mmd │  domain │            circles │       FluxEnsemble │   Generic │        0.996 │
│      mmd │  domain │              moons │ LogisticRegression │    REVISE │          0.0 │
│      mmd │  domain │              moons │ LogisticRegression │    Greedy │          0.0 │
│      mmd │  domain │              moons │ LogisticRegression │      DICE │          0.0 │
│      mmd │  domain │              moons │ LogisticRegression │   Generic │          0.0 │
│      mmd │  domain │              moons │          FluxModel │    REVISE │          0.0 │
│      mmd │  domain │              moons │          FluxModel │    Greedy │          0.0 │
│      mmd │  domain │              moons │          FluxModel │      DICE │          0.0 │
│      mmd │  domain │              moons │          FluxModel │   Generic │          0.0 │
│      mmd │  domain │              moons │       FluxEnsemble │    REVISE │          0.0 │
│      mmd │  domain │              moons │       FluxEnsemble │    Greedy │          0.0 │
│      mmd │  domain │              moons │       FluxEnsemble │      DICE │          0.0 │
│      mmd │  domain │              moons │       FluxEnsemble │   Generic │          0.0 │
│      mmd │   model │        overlapping │ LogisticRegression │    REVISE │        0.012 │
│      mmd │   model │        overlapping │ LogisticRegression │    Greedy │          0.0 │
│      mmd │   model │        overlapping │ LogisticRegression │      DICE │          0.0 │
│      mmd │   model │        overlapping │ LogisticRegression │   Generic │          0.0 │
│      mmd │   model │        overlapping │          FluxModel │    REVISE │        0.034 │
│      mmd │   model │        overlapping │          FluxModel │    Greedy │        0.004 │
│      mmd │   model │        overlapping │          FluxModel │      DICE │        0.002 │
│      mmd │   model │        overlapping │          FluxModel │   Generic │        0.002 │
│      mmd │   model │        overlapping │       FluxEnsemble │    REVISE │        0.034 │
│      mmd │   model │        overlapping │       FluxEnsemble │    Greedy │        0.002 │
│      mmd │   model │        overlapping │       FluxEnsemble │      DICE │          0.0 │
│      mmd │   model │        overlapping │       FluxEnsemble │   Generic │        0.004 │
│      mmd │   model │ linearly_separable │ LogisticRegression │    REVISE │         0.46 │
│      mmd │   model │ linearly_separable │ LogisticRegression │    Greedy │          0.0 │
│      mmd │   model │ linearly_separable │ LogisticRegression │      DICE │          0.0 │
│      mmd │   model │ linearly_separable │ LogisticRegression │   Generic │          0.0 │
│      mmd │   model │ linearly_separable │          FluxModel │    REVISE │        0.852 │
│      mmd │   model │ linearly_separable │          FluxModel │    Greedy │        0.684 │
│      mmd │   model │ linearly_separable │          FluxModel │      DICE │        0.964 │
│      mmd │   model │ linearly_separable │          FluxModel │   Generic │        0.944 │
│      mmd │   model │ linearly_separable │       FluxEnsemble │    REVISE │        0.856 │
│      mmd │   model │ linearly_separable │       FluxEnsemble │    Greedy │        0.716 │
│      mmd │   model │ linearly_separable │       FluxEnsemble │      DICE │       0.9525 │
│      mmd │   model │ linearly_separable │       FluxEnsemble │   Generic │        0.958 │
│      mmd │   model │            circles │ LogisticRegression │    REVISE │          0.0 │
│      mmd │   model │            circles │ LogisticRegression │    Greedy │          0.0 │
│      mmd │   model │            circles │ LogisticRegression │      DICE │        0.796 │
│      mmd │   model │            circles │ LogisticRegression │   Generic │        0.996 │
│      mmd │   model │            circles │          FluxModel │    REVISE │        0.994 │
│      mmd │   model │            circles │          FluxModel │    Greedy │        0.996 │
│      mmd │   model │            circles │          FluxModel │      DICE │       0.9975 │
│      mmd │   model │            circles │          FluxModel │   Generic │        0.992 │
│      mmd │   model │            circles │       FluxEnsemble │    REVISE │       0.9975 │
│      mmd │   model │            circles │       FluxEnsemble │    Greedy │          1.0 │
│      mmd │   model │            circles │       FluxEnsemble │      DICE │        0.996 │
│      mmd │   model │            circles │       FluxEnsemble │   Generic │          1.0 │
│      mmd │   model │              moons │ LogisticRegression │    REVISE │        0.004 │
│      mmd │   model │              moons │ LogisticRegression │    Greedy │          0.0 │
│      mmd │   model │              moons │ LogisticRegression │      DICE │          0.0 │
│      mmd │   model │              moons │ LogisticRegression │   Generic │          0.0 │
│      mmd │   model │              moons │          FluxModel │    REVISE │         0.91 │
│      mmd │   model │              moons │          FluxModel │    Greedy │        0.346 │
│      mmd │   model │              moons │          FluxModel │      DICE │         0.87 │
│      mmd │   model │              moons │          FluxModel │   Generic │         0.84 │
│      mmd │   model │              moons │       FluxEnsemble │    REVISE │        0.902 │
│      mmd │   model │              moons │       FluxEnsemble │    Greedy │        0.388 │
│      mmd │   model │              moons │       FluxEnsemble │      DICE │        0.865 │
│      mmd │   model │              moons │       FluxEnsemble │   Generic │        0.678 │
│ mmd_grid │   model │        overlapping │ LogisticRegression │    REVISE │         0.02 │
│ mmd_grid │   model │        overlapping │ LogisticRegression │    Greedy │          0.0 │
│ mmd_grid │   model │        overlapping │ LogisticRegression │      DICE │          0.0 │
│ mmd_grid │   model │        overlapping │ LogisticRegression │   Generic │          0.0 │
│ mmd_grid │   model │        overlapping │          FluxModel │    REVISE │        0.342 │
│ mmd_grid │   model │        overlapping │          FluxModel │    Greedy │        0.002 │
│ mmd_grid │   model │        overlapping │          FluxModel │      DICE │          0.0 │
│ mmd_grid │   model │        overlapping │          FluxModel │   Generic │          0.0 │
│ mmd_grid │   model │        overlapping │       FluxEnsemble │    REVISE │        0.208 │
│ mmd_grid │   model │        overlapping │       FluxEnsemble │    Greedy │          0.0 │
│ mmd_grid │   model │        overlapping │       FluxEnsemble │      DICE │        0.002 │
│ mmd_grid │   model │        overlapping │       FluxEnsemble │   Generic │          0.0 │
│ mmd_grid │   model │ linearly_separable │ LogisticRegression │    REVISE │          0.0 │
│ mmd_grid │   model │ linearly_separable │ LogisticRegression │    Greedy │          0.0 │
│ mmd_grid │   model │ linearly_separable │ LogisticRegression │      DICE │          0.0 │
│ mmd_grid │   model │ linearly_separable │ LogisticRegression │   Generic │          0.0 │
│ mmd_grid │   model │ linearly_separable │          FluxModel │    REVISE │          0.0 │
│ mmd_grid │   model │ linearly_separable │          FluxModel │    Greedy │          0.0 │
│ mmd_grid │   model │ linearly_separable │          FluxModel │      DICE │          0.0 │
│ mmd_grid │   model │ linearly_separable │          FluxModel │   Generic │          0.0 │
│ mmd_grid │   model │ linearly_separable │       FluxEnsemble │    REVISE │          0.0 │
│ mmd_grid │   model │ linearly_separable │       FluxEnsemble │    Greedy │          0.0 │
│ mmd_grid │   model │ linearly_separable │       FluxEnsemble │      DICE │          0.0 │
│ mmd_grid │   model │ linearly_separable │       FluxEnsemble │   Generic │          0.0 │
│ mmd_grid │   model │            circles │ LogisticRegression │    REVISE │          0.0 │
│ mmd_grid │   model │            circles │ LogisticRegression │    Greedy │          0.0 │
│ mmd_grid │   model │            circles │ LogisticRegression │      DICE │        0.814 │
│ mmd_grid │   model │            circles │ LogisticRegression │   Generic │        0.994 │
│ mmd_grid │   model │            circles │          FluxModel │    REVISE │        0.996 │
│ mmd_grid │   model │            circles │          FluxModel │    Greedy │        0.776 │
│ mmd_grid │   model │            circles │          FluxModel │      DICE │       0.7375 │
│ mmd_grid │   model │            circles │          FluxModel │   Generic │        0.688 │
│ mmd_grid │   model │            circles │       FluxEnsemble │    REVISE │          1.0 │
│ mmd_grid │   model │            circles │       FluxEnsemble │    Greedy │        0.568 │
│ mmd_grid │   model │            circles │       FluxEnsemble │      DICE │        0.762 │
│ mmd_grid │   model │            circles │       FluxEnsemble │   Generic │         0.89 │
│ mmd_grid │   model │              moons │ LogisticRegression │    REVISE │        0.004 │
│ mmd_grid │   model │              moons │ LogisticRegression │    Greedy │          0.0 │
│ mmd_grid │   model │              moons │ LogisticRegression │      DICE │          0.0 │
│ mmd_grid │   model │              moons │ LogisticRegression │   Generic │          0.0 │
│ mmd_grid │   model │              moons │          FluxModel │    REVISE │        0.174 │
│ mmd_grid │   model │              moons │          FluxModel │    Greedy │          0.0 │
│ mmd_grid │   model │              moons │          FluxModel │      DICE │         0.01 │
│ mmd_grid │   model │              moons │          FluxModel │   Generic │         0.02 │
│ mmd_grid │   model │              moons │       FluxEnsemble │    REVISE │        0.114 │
│ mmd_grid │   model │              moons │       FluxEnsemble │    Greedy │        0.006 │
│ mmd_grid │   model │              moons │       FluxEnsemble │      DICE │       0.1225 │
│ mmd_grid │   model │              moons │       FluxEnsemble │   Generic │        0.016 │
└──────────┴─────────┴────────────────────┴────────────────────┴───────────┴──────────────┘

3.4 Chart in paper

Figure 3.3 shows the chart that went into the paper.

paper_chart = joinpath(www_artifact_path, "paper_synthetic_results.png")
Images.load(paper_chart)

Figure 3.3: Chart in paper

#| echo: false
# Quarto only recognises cell options written with the `#|` prefix; a plain
# `# echo: false` is an ordinary comment, so the cell would still be shown.
# Run `generate_artifacts` on both the results directory and the web-assets directory.
generate_artifacts(output_path)
generate_artifacts(www_path)