In this R markdown I use code from the following files: sim.R, plot_utils.R, utils.R.
For the full simulation code, see run_std().
# NOTE: "climate-RL.stan" is read through a relative path further down, so this
# document is meant to be run from the stan_models directory:
# setwd("~/research/climate-RL-mod/stan_models")
# (The workspace is deliberately not cleared with rm(list = ls()); every object
# used below is assigned explicitly.)
main_dir <- "~/research/climate-RL-mod/"
sim_dir <- paste0(main_dir, "R_simulation/")

# Load the simulation code into its own environment so that its functions are
# reached as sim$fun() instead of being dumped into the global workspace.
sim <- new.env()
source(paste0(sim_dir, "sim.R"), local = sim)

# Parameter settings handed to sim$run_std().
params <- list(
  n_part = 50,
  n_trials = 30,
  LR_group = 0.4,
  inv_temp_group = 0.5,
  initQ_group = list(F = 8, U = 2),
  mu_R_group = list(F = 5, U = 5),
  sigma_R_group = 2
)

# Simulate one data set with these settings.
sim_dat <- sim$run_std(params)
cat(paste0("PARAMETER SETTINGS:"), capture.output(dplyr::glimpse(params)), sep = "\n")
PARAMETER SETTINGS:
List of 7
$ n_part : num 50
$ n_trials : num 30
$ LR_group : num 0.4
$ inv_temp_group: num 0.5
$ initQ_group :List of 2
..$ F: num 8
..$ U: num 2
$ mu_R_group :List of 2
..$ F: num 5
..$ U: num 5
$ sigma_R_group : num 2
SIMULATED DATA:
Rows: 1,500
Columns: 8
$ participant <int> 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1…
$ trial <int> 1, 2, 3, 4, 5, 6, 7, 8, 9, 10, 11, 12, 13, 14, 15, 16, 17,…
$ Q_F <dbl> 3.308605, 3.576838, 3.961119, 3.967285, 3.967285, 3.967285…
$ Q_U <dbl> 2.858249, 2.858249, 2.858249, 2.858249, 3.039316, 2.874494…
$ choice <dbl> 1, 1, 1, 2, 2, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1…
$ R <dbl> 5, 6, 4, 4, 2, 6, 7, 8, 6, 6, 6, 5, 6, 6, 7, 6, 7, 5, 7, 5…
$ LR <dbl> 0.1585869, 0.1585869, 0.1585869, 0.1585869, 0.1585869, 0.1…
$ inv_temp <dbl> 0.5832288, 0.5832288, 0.5832288, 0.5832288, 0.5832288, 0.5…
I show the Stan model code bit by bit to explain what is happening. The full code can be found in climate-RL.stan.
data block
The data we supply to the model consist of:
- n_part;
- n_trials;
- choice[n_part, n_trials];
- R[n_part, n_trials].

These are analogous to the array outcomes[n_trials] in the 2arm_bandit_example (see the Stan code for this example).
mod_code <- readLines("climate-RL.stan")
# The listing starts at the first line matching "data {" and stops two lines
# above the "transformed data {" header.
start <- which(grepl("data \\{", mod_code))[1]
end <- which(grepl("transformed data \\{", mod_code)) - 2
cat(mod_code[start:end], sep = "\n")
data {
int<lower=1> n_part;
int<lower=1> n_trials;
array[n_part, n_trials] int<lower=1, upper=2> choice;
array[n_part, n_trials] int<lower=1, upper=10> R;
}
transformed data block
Currently unused.
parameters block
In the current model, we have four free parameters:
[explanation here]
# The listing starts at the first line matching "parameters {" and stops two
# lines above the "transformed parameters {" header.
start <- which(grepl("parameters \\{", mod_code))[1]
end <- which(grepl("transformed parameters \\{", mod_code)) - 2
cat(mod_code[start:end], sep = "\n")
parameters {
// group-level parameters
vector[4] means_probit;
vector<lower=0>[4] sigmas;
// participant-level parameters
vector[n_part] LR_probit;
vector[n_part] inv_temp_probit;
vector[n_part] initQF_probit;
vector[n_part] initQU_probit;
}
transformed parameters block
[explanation here]
model block
[explanation here]
# The listing starts at the "model {" line and stops two lines above the first
# line that mentions "participant loop".
start <- which(grepl("model \\{", mod_code))
end <- which(grepl("participant loop", mod_code))[1] - 2
cat(mod_code[start:end], sep = "\n")
model {
// priors
means_probit ~ normal(0, 1);
sigmas ~ normal(0, 0.2);
LR_probit ~ normal(0, 1);
inv_temp_probit ~ normal(0, 1);
initQF_probit ~ normal(0, 1);
initQU_probit ~ normal(0, 1);
For every participant I initialize an array for holding the Q values.
I also create a temporary vector Q_t that holds the two Q
values (F, U) for the current trial, because otherwise the
categorical_logit wouldn’t work, somehow. This works for
now, and maybe in the future I will find a more elegant solution.
# The listing starts one line below the "participant loop" marker and stops two
# lines above the "trial loop" marker.
start <- which(grepl("participant loop", mod_code)) + 1
end <- which(grepl("trial loop", mod_code)) - 2
cat(mod_code[start:end], sep = "\n")
for (j in 1:n_part) {
// initialization
array[n_trials, 2] real Q;
Q[1, 1] = initQF[j];
Q[1, 2] = initQU[j];
vector[2] Q_t;
real pred_err;
# The listing starts one line below the "trial loop" marker and stops two lines
# above the line that mentions "generated quantities".
start <- which(grepl("trial loop", mod_code)) + 1
end <- which(grepl("generated quantities", mod_code)) - 2
cat(mod_code[start:end], sep = "\n")
for (t in 1:n_trials) {
Q_t = to_vector(Q[t]);
// sample choice (1 is F, 2 is U) via softmax
choice[j, t] ~ categorical_logit(inv_temp[j] * Q_t);
// prediction error
pred_err = R[j, t] - Q[t, choice[j, t]];
// update value (learn)
if (t < n_trials) { // no updating in the very last trial
if (choice[j, t] == 1) {
Q[t+1, 1] = Q[t, 1] + LR[j] * pred_err;
Q[t+1, 2] = Q[t, 2];
} else {
Q[t+1, 1] = Q[t, 1];
Q[t+1, 2] = Q[t, 2] + LR[j] * pred_err;
}
}
}
}
}
generated quantities block
[explanation here]
# The listing runs from the "generated quantities" line to the end of the file.
end <- length(mod_code)
start <- which(grepl("generated quantities", mod_code))
cat(mod_code[start:end], sep = "\n")
generated quantities {
vector[4] means;
means[1] = Phi_approx(means_probit[1]);
means[2] = Phi_approx(means_probit[2]) * 5;
means[3] = Phi_approx(means_probit[3]) * 9 + 1;
means[4] = Phi_approx(means_probit[4]) * 9 + 1;
}
To save running time, I only refit the model if I changed the model
or if the simulated data have changed (i.e., if I changed parameter
settings). To this end, I compare the previously saved
sim_dat.json to the current sim_dat. After
this comparison, I can save the new sim_dat and refit the
model if necessary.
# Fit the model to the single simulated data set, reusing the saved fit when
# nothing has changed.
dat_dir <- paste0(main_dir, "stan_models/single_run_dat/")
dat_file <- paste0(dat_dir, "sim_dat.json")
fit_file <- paste0(dat_dir, "climate-RL_single_fit.rds")

# Compare against the previously saved data BEFORE the file is overwritten.
dat_changed <- sim$did_sim_dat_change(dat_file, sim_dat)
sim$save_sim_dat(params, sim_dat, dat_file)
model_changed <- FALSE # manual switch: flip to TRUE to force a refit

library(cmdstanr)
options(mc.cores = parallel::detectCores())

# `||` is the scalar, short-circuiting operator that belongs in if().
# A missing fit file also triggers a refit, so readRDS() is never called on a
# file that does not exist.
if (dat_changed || model_changed || !file.exists(fit_file)) {
  m <- cmdstan_model("climate-RL.stan")
  it <- 1000
  fit <- m$sample(
    data = dat_file,
    iter_sampling = it,
    chains = 1,
    thin = 1,
    iter_warmup = it / 2,
    refresh = it / 5,
    seed = 1234
  )
  fit$save_object(file = fit_file)
} else {
  fit <- readRDS(file = fit_file)
}
# Posterior draws as a data frame (df makes it easier to handle), with the four
# entries of `means` renamed to the parameter names used in `params`.
draws <- posterior::as_draws_df(fit$draws()) %>%
  dplyr::rename(
    LR_group = `means[1]`,
    inv_temp_group = `means[2]`,
    `initQ_group$F` = `means[3]`,
    `initQ_group$U` = `means[4]`
  )

# Parameters to compare against their simulated values.
to_inspect <- c("LR_group", "inv_temp_group", "initQ_group$F", "initQ_group$U")

# NOTE(review): `plot` is expected to be an environment holding the
# plot_utils.R helpers; it is not created in the code shown here - confirm.
plot$posterior_density(draws, to_inspect, params)

util <- new.env()
source(paste0(main_dir, "utils.R"), local = util) # access functions using util$fun()
util$print_posterior_table(draws, params, to_inspect)
| Parameter | Simulated value | Median [95% credibility interval] |
|---|---|---|
| LR_group | 0.4 | 0.39 [0.28, 0.50] |
| inv_temp_group | 0.5 | 0.51 [0.37, 0.70] |
| initQ_group$F | 8.0 | 6.45 [4.79, 8.27] |
| initQ_group$U | 2.0 | 2.64 [1.43, 3.97] |
To check parameter recovery across a range of parameter values, I run the model a bunch of times with different parameter settings. More specifically, each parameter value is chosen randomly from its whole theoretical range.
# Settings for the repeated simulate-and-fit runs.
free_params <- c("LR_group", "inv_temp_group", "initQ_group$F", "initQ_group$U")
n_runs <- 10
it <- 1000
dat_dir <- paste0(main_dir, "stan_models/many_runs_dat/")
m <- cmdstan_model("climate-RL.stan")
#' Randomize the free parameters of a parameter list
#'
#' For every name in the global `free_params`, one value is drawn uniformly
#' between the two bounds stored under that name in `sim$param_bounds`.
#' Names of the form "outer$inner" (e.g. "initQ_group$F") refer to an element
#' of a nested list and are written to `params$outer$inner`, not to a new
#' top-level element literally called "outer$inner".
#'
#' @param params Named list of parameter settings.
#' @return `params` with every free parameter replaced by a random draw.
randomize_free_params <- function(params) {
  for (p in free_params) {
    bounds <- sim$param_bounds[[p]]
    value <- runif(1, min = bounds[1], max = bounds[2])
    if (grepl("$", p, fixed = TRUE)) { # parameter is part of a list
      parts <- strsplit(p, "$", fixed = TRUE)[[1]]
      params[[parts[1]]][[parts[2]]] <- value
    } else {
      params[[p]] <- value
    }
  }
  params
}
# Simulate and fit `n_runs` data sets, each with freshly randomized free
# parameters. Disabled by default because it is slow; set `rerun_fits` to TRUE
# to regenerate the files in `dat_dir`.
rerun_fits <- FALSE
if (rerun_fits) {
  for (k in seq_len(n_runs)) {
    run_id <- sprintf("%02d", k) # zero-padded run number used in file names
    # simulate
    params <- randomize_free_params(params)
    sim_dat <- sim$run_std(params)
    dat_file <- paste0(dat_dir, "sim_dat_", run_id, ".json")
    sim$save_sim_dat(params, sim_dat, dat_file)
    # fit
    fit <- m$sample(
      data = dat_file,
      iter_sampling = it,
      chains = 1,
      thin = 1,
      iter_warmup = it / 2,
      refresh = it / 5,
      seed = 1234
    )
    fit_file <- paste0(dat_dir, "climate-RL_fit_", run_id, ".rds")
    fit$save_object(file = fit_file)
  }
}
# run time: around 40 seconds per sim, 6-7 min total

# Collect, for every run, the simulated value and the posterior median of each
# free parameter (one row per run, one column per parameter).
sim_params <- data.frame(k = seq_len(n_runs))
fit_params <- data.frame(k = seq_len(n_runs))
for (p in free_params) {
  sim_params[[p]] <- rep(NA_real_, n_runs)
  fit_params[[p]] <- rep(NA_real_, n_runs)
}

for (k in seq_len(n_runs)) {
  run_id <- sprintf("%02d", k)
  # NOTE(review): presumably written by sim$save_sim_dat() next to the
  # simulated data - confirm.
  sim_file <- paste0(dat_dir, "sim_param_settings_", run_id, ".json")
  # Kept under its own name: assigning the settings to `sim` would overwrite
  # the environment that holds the simulation functions.
  sim_settings <- rjson::fromJSON(file = sim_file)
  fit_file <- paste0(dat_dir, "climate-RL_fit_", run_id, ".rds")
  fit <- readRDS(fit_file)
  draws <- posterior::as_draws_df(fit$draws()) %>%
    dplyr::rename(
      LR_group = `means[1]`,
      inv_temp_group = `means[2]`,
      `initQ_group$F` = `means[3]`,
      `initQ_group$U` = `means[4]`
    )
  for (p in free_params) {
    if (grepl("\\$", p)) { # parameter is part of a list
      split_name <- strsplit(p, "\\$")[[1]]
      # Position 2 for "U", position 1 otherwise (F, U order as in `params`).
      idx <- (split_name[2] == "U") + 1
      # `[[` extracts the bare value, so the column stays a plain vector.
      sim_params[[p]][k] <- sim_settings[[split_name[1]]][[idx]]
    } else {
      sim_params[[p]][k] <- sim_settings[[p]]
    }
    fit_params[[p]][k] <- median(draws[[p]])
  }
}

# NOTE(review): `plot` must already be an environment at this point; it is not
# created in the code shown here - confirm.
source(paste0(main_dir, "plot_utils.R"), local = plot)
plot$many_runs_param_fit(sim_params, fit_params, free_params)