4  Real-World Data

# Model types handed to `set_up_experiments`.
models = [:LogisticRegression, :FluxModel, :FluxEnsemble]

# `Flux.Descent` optimiser (constructed with 0.01) that is passed to the Generic,
# REVISE and DiCE generators; the Greedy generator is built with its defaults.
opt = Flux.Descent(0.01)
generators = Dict(
    :Greedy => GreedyGenerator(),
    :Generic => GenericGenerator(; opt=opt),
    :REVISE => REVISEGenerator(; opt=opt),
    :DICE => DiCEGenerator(; opt=opt),
)

# Load the real-world data sets (`max_obs` is forwarded to `load_real_world`) and keep
# only the entries whose first element (the data set name) is listed in `choices`.
max_obs = 5000
data_sets = load_real_world(max_obs)
choices = [:cal_housing, :credit_default, :gmsc]
data_sets = filter(p -> first(p) in choices, data_sets)
using CounterfactualExplanations.DataPreprocessing: unpack
# Batch size used by `data_loader` below.
bs = 500

"""
    data_loader(data::CounterfactualData)

Return a `Flux.DataLoader` over the `(X, y)` pair obtained from `unpack(data)`, using
the global `bs` as `batchsize`.
"""
function data_loader(data::CounterfactualData)
    features, targets = unpack(data)
    return Flux.DataLoader((features, targets); batchsize=bs)
end
# Settings forwarded to `set_up_experiments` through its `model_params` keyword.
model_params = (
    batch_norm=false,
    n_hidden=64,
    n_layers=3,
    dropout=true,
    p_dropout=0.1,
)

# Build the experiments from the selected data sets, models and generators; the custom
# `data_loader` and `pre_train_models=100` are passed through as keywords.
experiments = set_up_experiments(
    data_sets,
    models,
    generators;
    pre_train_models=100,
    model_params=model_params,
    data_loader=data_loader,
)

4.1 Experiment

# Experiment schedule: `evaluate_every` is the number of rounds divided by the number
# of evaluations, rounded to the nearest integer (50 / 5 = 10 here).
n_evals = 5
n_rounds = 50
evaluate_every = round(Int, n_rounds / n_evals)

# Remaining settings passed to `run_experiments` as keywords of the same name.
n_folds = 5
n_samples = 10_000
T = 100
generative_model_params = (epochs=250, latent_dim=8)

results = run_experiments(
    experiments;
    save_path=output_path,
    evaluate_every=evaluate_every,
    n_rounds=n_rounds,
    n_folds=n_folds,
    T=T,
    n_samples=n_samples,
    generative_model_params=generative_model_params,
)

# Persist the results so that the sections below can reload them without re-running.
Serialization.serialize(joinpath(output_path, "results.jls"), results)

4.1.1 Plots

# Reload the serialized results written by the experiment run.
results = Serialization.deserialize(joinpath(output_path, "results.jls"))
using Images

# Charts are collected per data set name and also written to `www_path` as PNG files.
line_charts = Dict()
errorbar_charts = Dict()
for (data_name, res) in results
    # Chart produced by `plot(res)`.
    plt = plot(res)
    Images.save(joinpath(www_path, "line_chart_$(data_name).png"), plt)
    line_charts[data_name] = plt

    # Chart produced by `plot(res, n)` for the largest value in `res.output.n`
    # (presumably the final evaluation round -- confirm against `plot`'s method).
    largest_n = maximum(res.output.n)
    plt = plot(res, largest_n)
    Images.save(joinpath(www_path, "errorbar_chart_$(data_name).png"), plt)
    errorbar_charts[data_name] = plt
end

4.1.2 Line Charts

Figure 4.1 shows the evolution of the evaluation metrics over the course of the experiment.

# Display every image in `www_artifact_path` whose file name contains "line_chart".
# The directory is listed once and filtered directly: listing it twice (once for the
# values, once for the boolean mask) is redundant I/O and can fail or select the wrong
# files if the directory content changes between the two calls.
img_files = filter(contains("line_chart"), readdir(www_artifact_path))
# The match is done on bare file names; the directory is prepended only afterwards.
img_files = joinpath.(www_artifact_path, img_files)
for img in img_files
    display(load(img))
end

(a) California Housing

(b) Credit Default

(c) GMSC

Figure 4.1: Line Charts

4.1.3 Error Bar Charts

Figure 4.2 shows the evaluation metrics at the end of the experiments.

# Display every image in `www_artifact_path` whose file name contains "errorbar_chart".
# The directory is listed once and filtered directly: listing it twice (once for the
# values, once for the boolean mask) is redundant I/O and can fail or select the wrong
# files if the directory content changes between the two calls.
img_files = filter(contains("errorbar_chart"), readdir(www_artifact_path))
# The match is done on bare file names; the directory is prepended only afterwards.
img_files = joinpath.(www_artifact_path, img_files)
for img in img_files
    display(load(img))
end

(a) California Housing

(b) Credit Default

(c) GMSC

Figure 4.2: Error Bar Charts

4.2 Bootstrap

# `n_bootstrap` is passed to `run_bootstrap` together with the experiment results; the
# `filename` keyword points at `bootstrap.csv` inside `output_path`. The returned `df`
# is what the table below displays (one `p_value_mean` per metric name, scope, data
# set, model and generator).
n_bootstrap = 100
df = run_bootstrap(
    results, n_bootstrap; filename=joinpath(output_path, "bootstrap.csv")
)
┌──────────┬─────────┬────────────────┬────────────────────┬───────────┬──────────────┐
│     name │   scope │           data │              model │ generator │ p_value_mean │
│ String31 │ String7 │       String15 │           String31 │   String7 │      Float64 │
├──────────┼─────────┼────────────────┼────────────────────┼───────────┼──────────────┤
│      mmd │  domain │ credit_default │ LogisticRegression │   Generic │          1.0 │
│      mmd │  domain │ credit_default │ LogisticRegression │    REVISE │          1.0 │
│      mmd │  domain │ credit_default │ LogisticRegression │    Greedy │          1.0 │
│      mmd │  domain │ credit_default │ LogisticRegression │      DICE │          1.0 │
│      mmd │  domain │ credit_default │          FluxModel │   Generic │          1.0 │
│      mmd │  domain │ credit_default │          FluxModel │    REVISE │          0.0 │
│      mmd │  domain │ credit_default │          FluxModel │    Greedy │          1.0 │
│      mmd │  domain │ credit_default │          FluxModel │      DICE │          1.0 │
│      mmd │  domain │ credit_default │       FluxEnsemble │   Generic │          1.0 │
│      mmd │  domain │ credit_default │       FluxEnsemble │    REVISE │          0.0 │
│      mmd │  domain │ credit_default │       FluxEnsemble │    Greedy │          1.0 │
│      mmd │  domain │ credit_default │       FluxEnsemble │      DICE │          1.0 │
│      mmd │  domain │    cal_housing │ LogisticRegression │   Generic │          0.0 │
│      mmd │  domain │    cal_housing │ LogisticRegression │    REVISE │          0.0 │
│      mmd │  domain │    cal_housing │ LogisticRegression │    Greedy │          0.0 │
│      mmd │  domain │    cal_housing │ LogisticRegression │      DICE │          0.0 │
│      mmd │  domain │    cal_housing │          FluxModel │   Generic │          0.0 │
│      mmd │  domain │    cal_housing │          FluxModel │    REVISE │          0.0 │
│      mmd │  domain │    cal_housing │          FluxModel │    Greedy │          0.0 │
│      mmd │  domain │    cal_housing │          FluxModel │      DICE │          0.0 │
│      mmd │  domain │    cal_housing │       FluxEnsemble │   Generic │          0.0 │
│      mmd │  domain │    cal_housing │       FluxEnsemble │    REVISE │          0.0 │
│      mmd │  domain │    cal_housing │       FluxEnsemble │    Greedy │          0.0 │
│      mmd │  domain │    cal_housing │       FluxEnsemble │      DICE │          0.0 │
│      mmd │  domain │           gmsc │ LogisticRegression │   Generic │        0.278 │
│      mmd │  domain │           gmsc │ LogisticRegression │    REVISE │          0.0 │
│      mmd │  domain │           gmsc │ LogisticRegression │    Greedy │        0.006 │
│      mmd │  domain │           gmsc │ LogisticRegression │      DICE │         0.51 │
│      mmd │  domain │           gmsc │          FluxModel │   Generic │        0.128 │
│      mmd │  domain │           gmsc │          FluxModel │    REVISE │          0.0 │
│      mmd │  domain │           gmsc │          FluxModel │    Greedy │          0.0 │
│      mmd │  domain │           gmsc │          FluxModel │      DICE │        0.338 │
│      mmd │  domain │           gmsc │       FluxEnsemble │   Generic │        0.306 │
│      mmd │  domain │           gmsc │       FluxEnsemble │    REVISE │          0.0 │
│      mmd │  domain │           gmsc │       FluxEnsemble │    Greedy │        0.032 │
│      mmd │  domain │           gmsc │       FluxEnsemble │      DICE │        0.082 │
│      mmd │   model │ credit_default │ LogisticRegression │   Generic │          0.0 │
│      mmd │   model │ credit_default │ LogisticRegression │    REVISE │        0.436 │
│      mmd │   model │ credit_default │ LogisticRegression │    Greedy │        0.044 │
│      mmd │   model │ credit_default │ LogisticRegression │      DICE │          0.0 │
│      mmd │   model │ credit_default │          FluxModel │   Generic │          0.0 │
│      mmd │   model │ credit_default │          FluxModel │    REVISE │          0.0 │
│      mmd │   model │ credit_default │          FluxModel │    Greedy │          0.0 │
│      mmd │   model │ credit_default │          FluxModel │      DICE │          0.0 │
│      mmd │   model │ credit_default │       FluxEnsemble │   Generic │          0.0 │
│      mmd │   model │ credit_default │       FluxEnsemble │    REVISE │          0.0 │
│      mmd │   model │ credit_default │       FluxEnsemble │    Greedy │          0.0 │
│      mmd │   model │ credit_default │       FluxEnsemble │      DICE │          0.0 │
│      mmd │   model │    cal_housing │ LogisticRegression │   Generic │          0.0 │
│      mmd │   model │    cal_housing │ LogisticRegression │    REVISE │          0.0 │
│      mmd │   model │    cal_housing │ LogisticRegression │    Greedy │          0.0 │
│      mmd │   model │    cal_housing │ LogisticRegression │      DICE │          0.0 │
│      mmd │   model │    cal_housing │          FluxModel │   Generic │          0.0 │
│      mmd │   model │    cal_housing │          FluxModel │    REVISE │          0.0 │
│      mmd │   model │    cal_housing │          FluxModel │    Greedy │          0.0 │
│      mmd │   model │    cal_housing │          FluxModel │      DICE │          0.0 │
│      mmd │   model │    cal_housing │       FluxEnsemble │   Generic │          0.0 │
│      mmd │   model │    cal_housing │       FluxEnsemble │    REVISE │          0.0 │
│      mmd │   model │    cal_housing │       FluxEnsemble │    Greedy │          0.0 │
│      mmd │   model │    cal_housing │       FluxEnsemble │      DICE │          0.0 │
│      mmd │   model │           gmsc │ LogisticRegression │   Generic │          0.0 │
│      mmd │   model │           gmsc │ LogisticRegression │    REVISE │          0.0 │
│      mmd │   model │           gmsc │ LogisticRegression │    Greedy │          0.0 │
│      mmd │   model │           gmsc │ LogisticRegression │      DICE │          0.0 │
│      mmd │   model │           gmsc │          FluxModel │   Generic │          0.0 │
│      mmd │   model │           gmsc │          FluxModel │    REVISE │          0.0 │
│      mmd │   model │           gmsc │          FluxModel │    Greedy │          0.0 │
│      mmd │   model │           gmsc │          FluxModel │      DICE │          0.0 │
│      mmd │   model │           gmsc │       FluxEnsemble │   Generic │        0.018 │
│      mmd │   model │           gmsc │       FluxEnsemble │    REVISE │        0.008 │
│      mmd │   model │           gmsc │       FluxEnsemble │    Greedy │         0.02 │
│      mmd │   model │           gmsc │       FluxEnsemble │      DICE │        0.032 │
│ mmd_grid │   model │ credit_default │ LogisticRegression │   Generic │          0.0 │
│ mmd_grid │   model │ credit_default │ LogisticRegression │    REVISE │        0.044 │
│ mmd_grid │   model │ credit_default │ LogisticRegression │    Greedy │          0.0 │
│ mmd_grid │   model │ credit_default │ LogisticRegression │      DICE │          0.0 │
│ mmd_grid │   model │ credit_default │          FluxModel │   Generic │          0.0 │
│ mmd_grid │   model │ credit_default │          FluxModel │    REVISE │          0.0 │
│ mmd_grid │   model │ credit_default │          FluxModel │    Greedy │          0.0 │
│ mmd_grid │   model │ credit_default │          FluxModel │      DICE │          0.0 │
│ mmd_grid │   model │ credit_default │       FluxEnsemble │   Generic │          0.0 │
│ mmd_grid │   model │ credit_default │       FluxEnsemble │    REVISE │          0.0 │
│ mmd_grid │   model │ credit_default │       FluxEnsemble │    Greedy │        0.164 │
│ mmd_grid │   model │ credit_default │       FluxEnsemble │      DICE │          0.0 │
│ mmd_grid │   model │    cal_housing │ LogisticRegression │   Generic │          0.0 │
│ mmd_grid │   model │    cal_housing │ LogisticRegression │    REVISE │         0.01 │
│ mmd_grid │   model │    cal_housing │ LogisticRegression │    Greedy │          0.0 │
│ mmd_grid │   model │    cal_housing │ LogisticRegression │      DICE │          0.0 │
│ mmd_grid │   model │    cal_housing │          FluxModel │   Generic │        0.004 │
│ mmd_grid │   model │    cal_housing │          FluxModel │    REVISE │        0.026 │
│ mmd_grid │   model │    cal_housing │          FluxModel │    Greedy │          0.0 │
│ mmd_grid │   model │    cal_housing │          FluxModel │      DICE │          0.0 │
│ mmd_grid │   model │    cal_housing │       FluxEnsemble │   Generic │          0.0 │
│ mmd_grid │   model │    cal_housing │       FluxEnsemble │    REVISE │        0.006 │
│ mmd_grid │   model │    cal_housing │       FluxEnsemble │    Greedy │          0.0 │
│ mmd_grid │   model │    cal_housing │       FluxEnsemble │      DICE │          0.0 │
│ mmd_grid │   model │           gmsc │ LogisticRegression │   Generic │          0.0 │
│ mmd_grid │   model │           gmsc │ LogisticRegression │    REVISE │          0.0 │
│ mmd_grid │   model │           gmsc │ LogisticRegression │    Greedy │          0.0 │
│ mmd_grid │   model │           gmsc │ LogisticRegression │      DICE │          0.0 │
│ mmd_grid │   model │           gmsc │          FluxModel │   Generic │          0.0 │
│ mmd_grid │   model │           gmsc │          FluxModel │    REVISE │         0.03 │
│ mmd_grid │   model │           gmsc │          FluxModel │    Greedy │          0.0 │
│ mmd_grid │   model │           gmsc │          FluxModel │      DICE │        0.004 │
│ mmd_grid │   model │           gmsc │       FluxEnsemble │   Generic │        0.002 │
│ mmd_grid │   model │           gmsc │       FluxEnsemble │    REVISE │          0.0 │
│ mmd_grid │   model │           gmsc │       FluxEnsemble │    Greedy │          0.0 │
│ mmd_grid │   model │           gmsc │       FluxEnsemble │      DICE │          0.0 │
└──────────┴─────────┴────────────────┴────────────────────┴───────────┴──────────────┘

4.2.1 Chart in paper

Figure 4.3 shows the chart that went into the paper.

# Load the pre-rendered chart from the artifact directory so that it is shown as the
# value of this cell.
joinpath(www_artifact_path, "paper_real_world_results.png") |> Images.load

Figure 4.3: Chart in paper