Helper files

In this R markdown I use code from the following files: sim.R, plot_utils.R, utils.R.

1 The data

1.1 Running the simulation

For the full simulation code, see run_std().

# NOTE: the former rm(list = ls()) was removed on purpose. Wiping the global
# workspace is a side effect on whoever runs or sources this document, and it
# does not give a clean session anyway (attached packages and options stay).
# setwd("~/research/climate-RL-mod/stan_models")

# Project paths. Both keep their trailing slash because file names are
# appended to them with paste0() throughout this document.
main_dir <- "~/research/climate-RL-mod/"
sim_dir <- paste0(main_dir, "R_simulation/")

# Load the simulation code into its own environment so that its functions are
# reached as sim$fun() instead of living in the global workspace.
sim <- new.env()
source(paste0(sim_dir, "sim.R"), local = sim) # access functions using sim$fun()

# Settings handed to sim$run_std(). F and U label the two options
# (friendly / unfriendly) that the initial Q values and reward means refer to.
params <- list(
  n_part = 50,
  n_trials = 30,
  LR_group = 0.4,
  inv_temp_group = 0.5,
  initQ_group = list(F = 8, U = 2),
  mu_R_group = list(F = 5, U = 5),
  sigma_R_group = 2
)

sim_dat <- sim$run_std(params)

# Header line followed by the captured structure listing of the settings.
params_overview <- capture.output(dplyr::glimpse(params))
cat("PARAMETER SETTINGS:", params_overview, sep = "\n")
PARAMETER SETTINGS:
List of 7
 $ n_part        : num 50
 $ n_trials      : num 30
 $ LR_group      : num 0.4
 $ inv_temp_group: num 0.5
 $ initQ_group   :List of 2
  ..$ F: num 8
  ..$ U: num 2
 $ mu_R_group    :List of 2
  ..$ F: num 5
  ..$ U: num 5
 $ sigma_R_group : num 2
# Header line followed by the captured column overview of the simulated data.
sim_dat_overview <- capture.output(dplyr::glimpse(sim_dat))
cat("SIMULATED DATA:", sim_dat_overview, sep = "\n")
SIMULATED DATA:
Rows: 1,500
Columns: 8
$ participant <int> 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1…
$ trial       <int> 1, 2, 3, 4, 5, 6, 7, 8, 9, 10, 11, 12, 13, 14, 15, 16, 17,…
$ Q_F         <dbl> 3.308605, 3.576838, 3.961119, 3.967285, 3.967285, 3.967285…
$ Q_U         <dbl> 2.858249, 2.858249, 2.858249, 2.858249, 3.039316, 2.874494…
$ choice      <dbl> 1, 1, 1, 2, 2, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1…
$ R           <dbl> 5, 6, 4, 4, 2, 6, 7, 8, 6, 6, 6, 5, 6, 6, 7, 6, 7, 5, 7, 5…
$ LR          <dbl> 0.1585869, 0.1585869, 0.1585869, 0.1585869, 0.1585869, 0.1…
$ inv_temp    <dbl> 0.5832288, 0.5832288, 0.5832288, 0.5832288, 0.5832288, 0.5…

1.2 Inspecting the data

Just so I know what to expect of our model, I include plots of the simulated data: Q values and choices over time.

# Plotting helpers get a dedicated environment; call them as plot$fun().
plot <- new.env()
source(file = paste0(main_dir, "plot_utils.R"), local = plot)

# Draw the overview plots for the simulated data set.
plot$sim_plots(sim_dat, params)

2 Explaining the model

I show the Stan model code bit by bit to explain what is happening. The full code can be found in climate-RL.stan.

2.1 data block

The data we supply to the model consist of:

  1. the number of participants: n_part;
  2. the number of trials per participant: n_trials;
  3. an array of choices that a participant made: choice[n_part, n_trials];
  4. an array of ratings that a participant gave: R[n_part, n_trials]. These are analogous to the array outcomes[n_trials] in the 2arm_bandit_example (see Stan code for this example).
# Read the Stan program once; mod_code is reused by the later excerpt chunks.
mod_code <- readLines("climate-RL.stan")

# "data {" also matches inside "transformed data {", so keep the first hit only.
start <- which(grepl("data \\{", mod_code))[1]
# Stop two lines above the header of the next block.
end <- which(grepl("transformed data \\{", mod_code)) - 2
cat(mod_code[start:end], sep = "\n")
data {
  int<lower=1> n_part;
  int<lower=1> n_trials;
  array[n_part, n_trials] int<lower=1, upper=2> choice;
  array[n_part, n_trials] int<lower=1, upper=10> R;
}

2.2 transformed data block

Currently unused.

# Excerpt: from the "transformed data {" header up to two lines above the
# first "parameters {" header.
start <- which(grepl("transformed data \\{", mod_code))
# "parameters {" also matches "transformed parameters {"; [1] keeps the first.
end <- which(grepl("parameters \\{", mod_code))[1] - 2
cat(mod_code[start:end], sep = "\n")

2.3 parameters block

In the current model, we have four free parameters:

  1. learning rate;
  2. inverse temperature;
  3. initial Q for friendly;
  4. initial Q for unfriendly.

[explanation here]

# Excerpt: the parameters block. The pattern also matches
# "transformed parameters {", so only the first hit marks the start.
start <- which(grepl("parameters \\{", mod_code))[1]
# Stop two lines above the "transformed parameters {" header.
end <- which(grepl("transformed parameters \\{", mod_code)) - 2
cat(mod_code[start:end], sep = "\n")
parameters {
  // group-level parameters
  vector[4] means_probit;
  vector<lower=0>[4] sigmas;

  // participant-level parameters
  vector[n_part] LR_probit;
  vector[n_part] inv_temp_probit;
  vector[n_part] initQF_probit;
  vector[n_part] initQU_probit;
}

2.4 transformed parameters block

[explanation here]

# Excerpt: from the "transformed parameters {" header up to two lines above
# the "model {" header.
start <- which(grepl("transformed parameters \\{", mod_code))
end <- which(grepl("model \\{", mod_code)) - 2
cat(mod_code[start:end], sep = "\n")

2.5 model block

2.5.1 The priors

[explanation here]

# Excerpt: the opening of the model block (the priors), ending two lines
# above the first line that mentions "participant loop".
start <- which(grepl("model \\{", mod_code))
end <- which(grepl("participant loop", mod_code))[1] - 2
cat(mod_code[start:end], sep = "\n")
model {
  // priors
  means_probit ~ normal(0, 1);
  sigmas ~ normal(0, 0.2);
  LR_probit ~ normal(0, 1);
  inv_temp_probit ~ normal(0, 1);
  initQF_probit ~ normal(0, 1); 
  initQU_probit ~ normal(0, 1); 

2.5.2 Participant loop

For every participant I initialize an array for holding the Q values. I also create a temporary vector Q_t that holds the two Q values (F, U) for the current trial: categorical_logit expects a vector argument, whereas a row of the Q array (Q[t]) is a real array, so it has to be converted with to_vector first. This works for now, and maybe in the future I will find a more elegant solution.

# Excerpt: the per-participant initialisation, i.e. everything between the
# "participant loop" marker line and two lines above the "trial loop" marker.
# Only the first match of each marker is used: the priors excerpt already
# indexes the "participant loop" match with [1], and a multi-element start or
# end would make start:end operate on a vector.
start <- grep("participant loop", mod_code)[1] + 1
end <- grep("trial loop", mod_code)[1] - 2
cat(mod_code[start:end], sep = "\n")
  for (j in 1:n_part) {
    
    // initialization
    array[n_trials, 2] real Q;
    Q[1, 1] = initQF[j];
    Q[1, 2] = initQU[j];
    vector[2] Q_t;
    real pred_err;

2.5.3 Trial loop

# Excerpt: the trial loop, from the line after the "trial loop" marker up to
# two lines above the line that mentions "generated quantities".
start <- which(grepl("trial loop", mod_code)) + 1
end <- which(grepl("generated quantities", mod_code)) - 2
cat(mod_code[start:end], sep = "\n")
    for (t in 1:n_trials) {
      Q_t = to_vector(Q[t]);

      // sample choice (1 is F, 2 is U) via softmax
      choice[j, t] ~ categorical_logit(inv_temp[j] * Q_t);

      // prediction error
      pred_err = R[j, t] - Q[t, choice[j, t]];

      // update value (learn)
      if (t < n_trials) {    // no updating in the very last trial
        if (choice[j, t] == 1) {
          Q[t+1, 1] = Q[t, 1] + LR[j] * pred_err;
          Q[t+1, 2] = Q[t, 2];
        } else {
          Q[t+1, 1] = Q[t, 1];
          Q[t+1, 2] = Q[t, 2] + LR[j] * pred_err;
        }
      }
    }
  }
}

2.6 generated quantities block

[explanation here]

# Excerpt: from the "generated quantities" line through the end of the file.
start <- which(grepl("generated quantities", mod_code))
end <- length(mod_code)
cat(mod_code[start:end], sep = "\n")
generated quantities {
  vector[4] means;
  means[1] = Phi_approx(means_probit[1]);
  means[2] = Phi_approx(means_probit[2]) * 5;
  means[3] = Phi_approx(means_probit[3]) * 9 + 1;
  means[4] = Phi_approx(means_probit[4]) * 9 + 1;
}

3 Fitting the model

3.1 Single run (demo)

3.1.1 Do the run

To save running time, I only refit the model if I changed the model or if the simulated data have changed (i.e., if I changed parameter settings). To this end, I compare the previously saved sim_dat.json to the current sim_dat. After this comparison, I can save the new sim_dat and refit the model if necessary.

dat_dir <- paste0(main_dir, "stan_models/single_run_dat/")
dat_file <- paste0(dat_dir, "sim_dat.json")
fit_file <- paste0(dat_dir, "climate-RL_single_fit.rds")

# Order matters: compare against the previously saved data first, and only
# then overwrite the file with the current simulation.
dat_changed <- sim$did_sim_dat_change(dat_file, sim_dat)
sim$save_sim_dat(params, sim_dat, dat_file)
model_changed <- FALSE # flip to TRUE by hand after editing climate-RL.stan

library(cmdstanr)
options(mc.cores = parallel::detectCores())

# Refit when the data or the model changed, or when there is no cached fit to
# fall back on (otherwise readRDS() below would fail on a missing file).
# `||` is the scalar, short-circuiting operator that belongs in an if().
if (dat_changed || model_changed || !file.exists(fit_file)) {
  m <- cmdstan_model("climate-RL.stan")
  it <- 1000
  fit <- m$sample(
    data = dat_file,
    iter_sampling = it,
    chains = 1,
    thin = 1,
    iter_warmup = it / 2,
    refresh = it / 5,
    seed = 1234
  )
  fit$save_object(file = fit_file)
} else {
  fit <- readRDS(file = fit_file)
}

# Give the back-transformed group means the same names as the entries of
# `params`, so posterior draws and simulated values can be matched by name.
# Called directly instead of through %>% so that no pipe has to be attached.
draws <- dplyr::rename(
  posterior::as_draws_df(fit$draws()), # df makes it easier to handle
  LR_group = `means[1]`,
  inv_temp_group = `means[2]`,
  `initQ_group$F` = `means[3]`,
  `initQ_group$U` = `means[4]`
)

3.1.2 Group-level parameter recovery

# General helpers live in their own environment; call them as util$fun().
util <- new.env()
source(file = paste0(main_dir, "utils.R"), local = util)

# Group-level parameters to compare against their simulated values.
to_inspect <- c(
  "LR_group",
  "inv_temp_group",
  "initQ_group$F",
  "initQ_group$U"
)

plot$posterior_density(draws, to_inspect, params)
util$print_posterior_table(draws, params, to_inspect)
Posteriors for free parameters
Parameter Simulated value Median [95% credibility interval]
LR_group 0.4 0.39 [0.28, 0.50]
inv_temp_group 0.5 0.51 [0.37, 0.70]
initQ_group$F 8.0 6.45 [4.79, 8.27]
initQ_group$U 2.0 2.64 [1.43, 3.97]

3.1.3 Participant-level parameter recovery

# Read the settings file stored next to the simulated data and hand it to the
# participant-level plot together with the posterior draws.
participant_params <- rjson::fromJSON(
  file = paste0(dat_dir, "sim_param_settings.json")
)
plot$pp_level_param_fit(
  draws,
  c("LR", "inv_temp", "initQF", "initQU"),
  participant_params
)

3.2 Many runs

To check parameter recovery across a range of parameter values, I run the model a bunch of times with different parameter settings. More specifically, each parameter value is chosen randomly from its whole theoretical range.

# Settings shared by the many-runs chunks below.
free_params <- c(
  "LR_group",
  "inv_temp_group",
  "initQ_group$F",
  "initQ_group$U"
)
n_runs <- 10 # number of simulate-and-fit repetitions
it <- 1000 # sampling iterations per fit
dat_dir <- paste0(main_dir, "stan_models/many_runs_dat/")

# Compile the model once; every run reuses the same model object.
m <- cmdstan_model("climate-RL.stan")

# Takes the current list of params, draws a new value for every free parameter
# uniformly between its bounds, and returns the updated list.
#
# Args:
#   params: named list of simulation settings.
#   free:   character vector with the names of the parameters to randomize.
#           A name of the form "outer$inner" addresses the element `inner` of
#           the nested list `params$outer` (e.g. "initQ_group$F").
#           Defaults to the global `free_params`.
#   bounds: named list that maps every name in `free` to c(lower, upper).
#           Defaults to `sim$param_bounds`.
#
# Returns:
#   `params` with the free parameters replaced; all other entries untouched.
randomize_free_params <- function(params,
                                  free = free_params,
                                  bounds = sim$param_bounds) {
  for (p in free) {
    p_bounds <- bounds[[p]]
    if (is.null(p_bounds) || length(p_bounds) < 2) {
      stop("No (lower, upper) bounds found for parameter '", p, "'",
        call. = FALSE
      )
    }
    value <- runif(1, min = p_bounds[1], max = p_bounds[2])

    if (grepl("$", p, fixed = TRUE)) {
      # params[["initQ_group$F"]] would only add a new top-level element with
      # that literal name and leave params$initQ_group$F as it was, so nested
      # names are split and assigned into the inner list.
      path <- strsplit(p, "$", fixed = TRUE)[[1]]
      params[[path[1]]][[path[2]]] <- value
    } else {
      params[[p]] <- value
    }
  }

  params
}

# Simulate-and-fit loop for the many-runs check. It is switched off with
# if (FALSE) because it is slow (see the run time note below); the saved fits
# are written to dat_dir.
if (FALSE) {
  # seq_len() instead of 1:n_runs so that n_runs == 0 means "no runs" rather
  # than iterating over c(1, 0).
  for (k in seq_len(n_runs)) {
    run_id <- sprintf("%02d", k) # zero-padded run number used in file names

    # simulate
    params <- randomize_free_params(params)
    sim_dat <- sim$run_std(params)
    dat_file <- paste0(dat_dir, "sim_dat_", run_id, ".json")
    sim$save_sim_dat(params, sim_dat, dat_file)

    # fit
    fit <- m$sample(
      data = dat_file,
      iter_sampling = it,
      chains = 1,
      thin = 1,
      iter_warmup = it / 2,
      refresh = it / 5,
      seed = 1234
    )
    fit_file <- paste0(dat_dir, "climate-RL_fit_", run_id, ".rds")
    fit$save_object(file = fit_file)
  }
}
# run time: around 40 seconds per sim, 6-7 min total
# Collect, for every run, the simulated value and the posterior median of each
# free parameter: one row per run, one column per parameter.
sim_params <- data.frame(k = seq_len(n_runs))
fit_params <- data.frame(k = seq_len(n_runs))
# Create the columns up front as plain numeric vectors instead of relying on
# the first assigned value being recycled over all rows.
for (p in free_params) {
  sim_params[[p]] <- rep(NA_real_, n_runs)
  fit_params[[p]] <- rep(NA_real_, n_runs)
}

for (k in seq_len(n_runs)) {
  run_id <- sprintf("%02d", k)

  # Stored under its own name: assigning the settings to `sim` would overwrite
  # the environment that holds the simulation functions (sim$run_std() etc.).
  sim_file <- paste0(dat_dir, "sim_param_settings_", run_id, ".json")
  sim_settings <- rjson::fromJSON(file = sim_file)

  fit_file <- paste0(dat_dir, "climate-RL_fit_", run_id, ".rds")
  fit <- readRDS(fit_file)
  # Called directly instead of through %>% so that no pipe has to be attached.
  draws <- dplyr::rename(
    posterior::as_draws_df(fit$draws()),
    LR_group = `means[1]`,
    inv_temp_group = `means[2]`,
    `initQ_group$F` = `means[3]`,
    `initQ_group$U` = `means[4]`
  )

  for (p in free_params) {
    if (grepl("\\$", p)) { # parameter is part of a list
      split_name <- strsplit(p, "\\$")[[1]]
      # Positional lookup: first element for F, second for U. `[[` extracts
      # the bare value; single `[` on a list would return a one-element list
      # and turn the column into a list column.
      sim_value <- sim_settings[[split_name[1]]][[(split_name[2] == "U") + 1]]
    } else {
      sim_value <- sim_settings[[p]]
    }
    sim_params[[p]][k] <- sim_value
    fit_params[[p]][k] <- median(draws[[p]])
  }
}

# Re-source the plotting helpers before drawing the recovery plot.
source(paste0(main_dir, "plot_utils.R"), local = plot)
plot$many_runs_param_fit(sim_params, fit_params, free_params)