sim.R

set.seed(1234) # fixed seed: sourcing this file resets the RNG to a known state

# === param_stddevs =============================
# Standard deviations used when drawing participant-level parameters around
# their group-level means (see draw_from_group_mean()); keys are the
# "<name>_group" entries of a parameter-settings list.
param_stddevs <- as.list(c(
  LR_group       = 0.2,
  LRs_group      = 0.2,
  inv_temp_group = 0.3,
  initQ_group    = 2,
  mu_R_group     = 2,
  sigma_R_group  = 2,
  margin_group   = 2
))

# === param_bounds =============================
# Theoretical (lower, upper) bounds for every parameter; these are meant to
# mirror the bounds declared in the Stan model. Names containing "$"
# (e.g. `initQ_group$F`) address elements of list-valued parameters.
param_bounds <- local({
  unit_interval <- c(0, 1)   # learning rates
  temp_range    <- c(0, 5)   # inverse temperatures
  rating_range  <- c(1, 10)  # quantities on the 1-10 rating scale
  upto_10       <- c(0, 10)  # non-negative parameters capped at 10

  list(
    LR_group        = unit_interval,
    LRs_group       = unit_interval,
    LR              = unit_interval,
    inv_temp_group  = temp_range,
    inv_temp        = temp_range,
    initQ_group     = rating_range,
    `initQ_group$F` = rating_range,
    `initQ_group$U` = rating_range,
    initQF          = rating_range,
    initQU          = rating_range,
    mu_R_group      = rating_range,
    sigma_R_group   = upto_10,
    margin_group    = upto_10
  )
})

# === run_std() =============================
# Runs the standard simulation (no confirmation bias). On every trial a
# participant picks option F (choice 1) or option U (choice 2) with a
# logistic choice rule on the Q-value difference, receives a rating rounded
# to an integer and truncated to [1, 10], and updates only the chosen
# option's Q value with a delta rule.
# arguments:
# - params: list of parameter settings; must contain n_part, n_trials and the
#   group means LR_group, inv_temp_group, initQ_group, mu_R_group and
#   sigma_R_group (participant values are drawn around these)
#
# returns:
# - data frame with one row per participant x trial and columns participant,
#   trial, Q_F, Q_U, choice, R, LR, inv_temp
run_std <- function(params) {
  library(truncnorm) # for drawing from truncated distribution

  # ------ initialize ------
  n_part <- params$n_part
  n_trials <- params$n_trials
  Q_F <- matrix(ncol = n_trials, nrow = n_part)
  Q_U <- matrix(ncol = n_trials, nrow = n_part)
  choice <- matrix(ncol = n_trials, nrow = n_part)
  R <- matrix(ncol = n_trials, nrow = n_part)
  LR <- c()
  inv_temp <- c()

  for (j in seq_len(n_part)) {

    # ----- draw this participant's parameters around the group means -----
    LR[j] <- draw_from_group_mean(params, "LR_group")
    inv_temp[j] <- draw_from_group_mean(params, "inv_temp_group")
    # assumes initQ_group holds the F and U means in that order -- TODO confirm
    initQ <- draw_from_group_mean(params, "initQ_group")
    Q_F[j, 1] <- initQ[1]
    Q_U[j, 1] <- initQ[2]
    mu_R <- draw_from_group_mean(params, "mu_R_group")
    sigma_R <- draw_from_group_mean(params, "sigma_R_group")

    # --------- run trials ------------
    for (t in seq_len(n_trials)) {

      # choose: probability of picking F grows with Q_F - Q_U
      p_F <- 1 / (1 + exp(-inv_temp[j] * (Q_F[j, t] - Q_U[j, t])))
      choice[j, t] <- sample(c(1, 2),
                             size = 1,
                             prob = c(p_F, 1 - p_F))

      # rate: rating around the chosen option's mean, truncated to [1, 10]
      R[j, t] <- round(rtruncnorm(n = 1, a = 1, b = 10,
                                  mean = mu_R[[choice[j, t]]],
                                  sd = sigma_R),
                       0)

      # learn: no update after the very last trial (no column t + 1)
      if (t < n_trials) {
        if (choice[j, t] == 1) {
          pred_err <- R[j, t] - Q_F[j, t]
          Q_F[j, t + 1] <- Q_F[j, t] + LR[j] * pred_err
          Q_U[j, t + 1] <- Q_U[j, t]
        } else {
          pred_err <- R[j, t] - Q_U[j, t]
          Q_U[j, t + 1] <- Q_U[j, t] + LR[j] * pred_err
          Q_F[j, t + 1] <- Q_F[j, t]
        }
      }
    }
  }

  # matrices are participant x trial; transposing before flattening gives
  # participant-major rows (all trials of participant 1, then 2, ...)
  dat <- data.frame(
    participant =   rep(seq_len(n_part), each = n_trials),
    trial =         rep(seq_len(n_trials), n_part),
    Q_F =           array(t(Q_F)),
    Q_U =           array(t(Q_U)),
    choice =        array(t(choice)),
    R =             array(t(R)),
    LR =            rep(LR, each = n_trials),
    inv_temp =      rep(inv_temp, each = n_trials)
  )
  return(dat)
}

# === draw_from_group_mean() =======================
# Draws a participant-level value of parameter p from a normal distribution
# centred on its group mean, truncated to the parameter's theoretical bounds.
# arguments:
# - param_settings: the list of parameter settings (i.e., group means)
# - p: name of the parameter; must be a key of param_settings, param_bounds
#   and param_stddevs
#
# returns:
# - truncated-normal draw with mean param_settings[[p]], standard deviation
#   param_stddevs[[p]] and bounds param_bounds[[p]]
draw_from_group_mean <- function(param_settings, p) {
  # fail early with a readable message instead of an obscure error from
  # rtruncnorm() when one of the lookups returns NULL
  missing_in <- c(
    param_settings = is.null(param_settings[[p]]),
    param_bounds = is.null(param_bounds[[p]]),
    param_stddevs = is.null(param_stddevs[[p]])
  )
  if (any(missing_in)) {
    stop("No entry for parameter '", p, "' in: ",
         paste(names(missing_in)[missing_in], collapse = ", "),
         call. = FALSE)
  }

  library(truncnorm) # for drawing from truncated distribution
  draw <- rtruncnorm(n = 1,
                     a = param_bounds[[p]][1],
                     b = param_bounds[[p]][2],
                     mean = param_settings[[p]],
                     sd = param_stddevs[[p]])
  return(draw)
}

# === save_sim_dat() =======================
# Writes two Stan-style JSON files with cmdstanr::write_stan_json():
# - the parameter settings, extended with the per-participant values read
#   from trial 1 of sim_dat (LR, inv_temp, initQF, initQU), to the file
#   obtained by replacing "sim_dat" with "sim_param_settings" in dat_file_name
# - the simulated data (n_part, n_trials and the participant x trial matrices
#   choice and R) to dat_file_name
# arguments:
# - params: list of parameter settings (must contain n_part and n_trials)
# - sim_dat: data frame of simulated data with columns trial, LR, inv_temp,
#   Q_F, Q_U, choice and R in participant-major row order (as produced by
#   run_std())
# - dat_file_name: file to save data to; must contain "sim_dat"
#
# returns: nothing meaningful; stops with an error if dat_file_name does not
# contain "sim_dat" (the data would otherwise overwrite the parameter file)
save_sim_dat <- function(params, sim_dat, dat_file_name) {
  library(cmdstanr) # contains function write_stan_json()

  # derive the parameter file name first and refuse to write anything if it
  # would coincide with the data file
  param_file_name <- stringr::str_replace(dat_file_name, "sim_dat", "sim_param_settings")
  if (identical(param_file_name, dat_file_name)) {
    stop("dat_file_name must contain 'sim_dat' so that the parameter file ",
         "gets a distinct name: ", dat_file_name, call. = FALSE)
  }

  # parameter settings, extended with the per-participant values
  first_trial <- which(sim_dat$trial == 1)
  params$LR <- round(sim_dat$LR[first_trial], 4)
  params$inv_temp <- round(sim_dat$inv_temp[first_trial], 4)
  params$initQF <- round(sim_dat$Q_F[first_trial], 4)
  params$initQU <- round(sim_dat$Q_U[first_trial], 4)
  write_stan_json(params, file = param_file_name)

  # data: rows of sim_dat are participant-major, so filling by row yields
  # participant x trial matrices
  n_part <- params$n_part
  n_trials <- params$n_trials
  choice <- matrix(sim_dat$choice, nrow = n_part, ncol = n_trials, byrow = TRUE)
  R <- matrix(sim_dat$R, nrow = n_part, ncol = n_trials, byrow = TRUE)
  list_dat <- list(n_part = n_part, n_trials = n_trials, choice = choice, R = R)
  write_stan_json(list_dat, file = dat_file_name)
}

# === did_sim_dat_change() =================
# Compares the choice and R values stored in a JSON data file with those in
# a simulated data frame.
# arguments:
# - data_file: file path for the JSON file
# - sim_dat: data frame to compare to file
#
# returns:
#  - TRUE if sim_dat changed compared to saved JSON file, else FALSE
did_sim_dat_change <- function(data_file, sim_dat) {
  json_data <- rjson::fromJSON(file = data_file)
  # presumably the file was written by save_sim_dat() (participant x trial
  # matrices), so flattening gives the participant-major order of sim_dat
  # rows -- confirm if other writers are used
  same_choice <- identical(unlist(json_data$choice), sim_dat$choice)
  same_R <- identical(unlist(json_data$R), sim_dat$R)
  # scalar conditions: short-circuiting && rather than elementwise &
  !(same_choice && same_R)
}

# === run_LRN_discr() =================
# Simulates learners with a discrete confirmation bias: after each choice the
# chosen option's Q value is updated with one of two learning rates ("conf"
# or "disconf"), selected by LR_function from the rating, the current belief
# about that option and a margin.
# arguments:
# - params: list of parameter settings (n_part, n_trials and the group means
#   LRs_group, inv_temp_group, initQ_group, mu_R_group, sigma_R_group,
#   margin_group; participant values are drawn around these)
# - LR_function: which LR function to use (LR_approx or LR_geq); must return
#   "conf" or "disconf"
# - belief_type: "stat" compares the rating with the chosen option's initial
#   Q value; any other value ("dyn") with its Q value at the start of the
#   previous trial
#
# returns:
# - data frame with one row per participant x trial; LR is the learning rate
#   applied on that trial (NA on the last trial, where no update happens)
run_LRN_discr <- function(params, LR_function, belief_type) {
  library(truncnorm) # for drawing from truncated distribution

  n_part <- params$n_part
  n_trials <- params$n_trials
  per_participant <- vector("list", n_part)

  for (j in seq_len(n_part)) {

    # ------ init data frames & vectors -----
    Q <- data.frame(
      F = rep(NA, n_trials),
      U = rep(NA, n_trials)
    )
    P_F <- numeric(n_trials)
    choice <- numeric(n_trials)
    R <- numeric(n_trials)
    pred_err <- numeric(n_trials)
    LR <- rep(NA_real_, n_trials)

    # ----- draw this participant's parameters around the group means -----
    LRs <- draw_from_group_mean(params, "LRs_group")
    names(LRs) <- c("conf", "disconf")
    inv_temp <- draw_from_group_mean(params, "inv_temp_group")
    initQ <- draw_from_group_mean(params, "initQ_group")
    Q$F[1] <- initQ[1]
    Q$U[1] <- initQ[2]
    mu_R <- draw_from_group_mean(params, "mu_R_group")
    sigma_R <- draw_from_group_mean(params, "sigma_R_group")
    margin <- draw_from_group_mean(params, "margin_group")

    # --------- run trials ------------
    for (t in seq_len(n_trials)) {

      # choose
      P_F[t] <- 1 / (1 + exp(-inv_temp * (Q$F[t] - Q$U[t])))
      choice[t] <- sample(c(1, 2),
                          size = 1,
                          prob = c(P_F[t], 1 - P_F[t]))

      # rate
      R[t] <- round(rtruncnorm(n = 1, a = 1, b = 10,
                               mean = mu_R[[choice[t]]],
                               sd = sigma_R),
                    0)

      # learn
      pred_err[t] <- R[t] - Q[t, choice[t]]

      if (t < n_trials) {   # no updating Qs in the very last trial
        chosen <- choice[t]   # column of Q: 1 = F, 2 = U
        other <- 3 - chosen
        # max(t - 1, 1): trial 1 has no previous trial, so use the initial Q
        belief <- if (belief_type == "stat") Q[1, chosen] else Q[max(t - 1, 1), chosen]
        LR[t] <- LRs[[ LR_function(R[t], belief, margin) ]]
        Q[t + 1, chosen] <- Q[t, chosen] + LR[t] * pred_err[t]
        Q[t + 1, other] <- Q[t, other]
      }
    }

    per_participant[[j]] <- data.frame(
      participant = rep(j, n_trials),
      trial =       seq_len(n_trials),
      Q_F =         Q$F,
      Q_U =         Q$U,
      P_F =         P_F,
      LR =          LR,
      choice =      choice,
      R =           R,
      pred_err =    pred_err
    )
  }

  if (length(per_participant) == 0) {
    return(data.frame())
  }
  do.call(rbind, per_participant)
}

# === LR_approx() =================
# Classifies a rating as confirming when it lies within `margin` of the
# belief in either direction, and as disconfirming otherwise.
# arguments:
# - R: the rating of this trial
# - belief: the value the rating is compared to
# - margin: largest absolute distance still counted as confirming
# returns: "conf" or "disconf"
LR_approx <- function(R, belief, margin) {
  distance <- abs(R - belief)
  if (distance > margin) "disconf" else "conf"
}

# === LR_geq() =================
# Classifies a rating as confirming unless it falls more than `margin` below
# the belief (i.e. confirming when R + margin >= belief).
# arguments:
# - R: the rating of this trial
# - belief: the value the rating is compared to
# - margin: how far below the belief a rating may fall and still confirm
# returns: "conf" or "disconf"
LR_geq <- function(R, belief, margin) {
  falls_short <- R + margin < belief
  if (falls_short) "disconf" else "conf"
}

# === run_LRN_cont() =================
# Simulates learners with a continuous confirmation bias: the learning rate
# for the chosen option is computed by LR_cont() from the rating and the
# current belief about that option. All participants share the same
# parameter values (nothing is drawn per participant).
# arguments:
# - params: list of parameter settings (n_part, n_trials, w_LR, inv_temp,
#   initQ with elements F and U, mu_R, sigma_R)
# - belief_type: "stat" compares the rating with the chosen option's initial
#   Q value; any other value ("dyn") with its Q value at the start of the
#   previous trial
# returns:
# - data frame with one row per participant x trial; LR is the learning rate
#   applied on that trial (NA on the last trial, where no update happens)
run_LRN_cont <- function(params, belief_type) {
  library(truncnorm) # for drawing from truncated distribution

  n_part <- params$n_part
  n_trials <- params$n_trials
  per_participant <- vector("list", n_part)

  for (j in seq_len(n_part)) {

    # ------ init data frames & vectors -----
    Q <- data.frame(
      F = rep(NA, n_trials),
      U = rep(NA, n_trials)
    )
    P_F <- numeric(n_trials)
    choice <- numeric(n_trials)
    R <- numeric(n_trials)
    LR <- rep(NA_real_, n_trials)
    pred_err <- numeric(n_trials)

    # ----- parameters (identical for every participant) -----
    w_LR <- params$w_LR
    inv_temp <- params$inv_temp
    Q$F[1] <- params$initQ$F
    Q$U[1] <- params$initQ$U
    mu_R <- params$mu_R
    sigma_R <- params$sigma_R

    # --------- run trials ------------
    for (t in seq_len(n_trials)) {

      # choose
      P_F[t] <- 1 / (1 + exp(-inv_temp * (Q$F[t] - Q$U[t])))
      choice[t] <- sample(c(1, 2),
                          size = 1,
                          prob = c(P_F[t], 1 - P_F[t]))

      # rate
      R[t] <- round(rtruncnorm(n = 1, a = 1, b = 10,
                               mean = mu_R[[choice[t]]],
                               sd = sigma_R),
                    0)

      # learn
      pred_err[t] <- R[t] - Q[t, choice[t]]

      if (t < n_trials) {   # no updating Qs in the very last trial
        chosen <- choice[t]   # column of Q: 1 = F, 2 = U
        other <- 3 - chosen
        # max(t - 1, 1) since t = 0 doesn't exist: trial 1 uses the initial Q
        belief <- if (belief_type == "stat") Q[1, chosen] else Q[max(t - 1, 1), chosen]
        LR[t] <- LR_cont(R[t], belief, w_LR)
        Q[t + 1, chosen] <- Q[t, chosen] + LR[t] * pred_err[t]
        Q[t + 1, other] <- Q[t, other]
      }
    }

    per_participant[[j]] <- data.frame(
      participant = rep(j, n_trials),
      trial =       seq_len(n_trials),
      Q_F =         Q$F,
      Q_U =         Q$U,
      P_F =         P_F,
      LR =          LR,
      choice =      choice,
      R =           R,
      pred_err =    pred_err
    )
  }

  if (length(per_participant) == 0) {
    return(data.frame())
  }
  do.call(rbind, per_participant)
}

# === LR_cont() =================
# Continuous learning rate: w_LR when the rating is at or above the belief,
# shrinking linearly (by w_LR / 9 per rating point) as the rating falls
# below it.
# arguments:
# - R: the rating of this trial
# - belief: the value the rating is compared to
# - w_LR: the learning rate weight (maximal learning rate)
# returns: learning rate for this trial
LR_cont <- function(R, belief, w_LR) {
  scaled <- 1/9 * (R - belief) + 1
  w_LR * min(1, scaled)
}

plot_utils.R

# ---------------------
#        GENERAL
# ---------------------
# Project root. Defaults to the original location; set the environment
# variable CLIMATE_RL_MOD_DIR to run from another checkout. A trailing "/"
# is enforced because paths are built with paste0().
main_dir <- Sys.getenv("CLIMATE_RL_MOD_DIR")
if (!nzchar(main_dir)) {
  main_dir <- "~/research/climate-RL-mod/"
}
if (!endsWith(main_dir, "/")) {
  main_dir <- paste0(main_dir, "/")
}
sim_dir <- paste0(main_dir, "R_simulation/")
library(tidyverse)
# Colour palette: F-related entries are teal and U-related entries pink;
# learning-rate entries are light blue and inverse-temperature entries
# dark blue.
my_teal <- "#008080"
my_pink <- "#ff00dd"
my_blue <- "#00aadd"
my_dark_blue <- "#001199"
my_colors <- c(rep(c(my_teal, my_pink), times = 3),
               rep(c(my_blue, my_dark_blue), times = 2))
my_param_colors <- setNames(
  my_colors,
  c("F", "U",
    "initQF", "initQU",
    "initQ_group$F", "initQ_group$U",
    "LR", "inv_temp",
    "LR_group", "inv_temp_group")
)

# Shared text settings for both plot themes: bold enlarged titles, larger
# axis/legend/strip text and no legend title.
.text_sizes <- theme(
  plot.title = element_text(size = 22, face = "bold"),
  axis.text = element_text(size = 16),
  axis.title = element_text(size = 18),
  legend.title = element_blank(),
  legend.text = element_text(size = 16),
  strip.text = element_text(size = 18, face = "bold")
)

my_theme <- theme_bw() + .text_sizes              # bordered variant
my_theme_classic <- theme_classic() + .text_sizes # classic (axis lines only) variant
rm(.text_sizes)

# --------------------------------------
#        PLOTS FOR SIMULATED DATA
# --------------------------------------

# === Q() ========================
# Smoothed trajectories of both options' Q values across trials.
# arguments:
# - sim_dat: data frame of simulated data (columns trial, Q_F, Q_U, choice)
#
# returns:
# - ggplot with one smoothed line and ribbon per option (F = "Friendly",
#   U = "Unfriendly"), y axis limited to [1, 10]
Q <- function(sim_dat) {
  option_labels <- c("Friendly", "Unfriendly")

  # one row per trial x option, option taken from the column suffix
  long_dat <- pivot_longer(sim_dat, c(Q_F, Q_U),
                           names_prefix = "Q_",
                           names_to = "option",
                           values_to = "Q")
  long_dat <- mutate(long_dat,
                     option = factor(option),
                     choice = factor(choice))

  plt <- ggplot(long_dat, aes(x = trial, y = Q, color = option)) +
    geom_smooth(aes(fill = option)) +
    ylim(c(1, 10)) +
    labs(x = "Trial") +
    scale_color_manual(values = my_param_colors, labels = option_labels) +
    scale_fill_manual(values = my_param_colors, labels = option_labels) +
    my_theme +
    theme(legend.position = "inside",
          legend.position.inside = c(0.83, 0.91))
  plt
}

# === choice() ========================
# Smoothed proportion of trials on which each option was chosen.
# arguments:
# - sim_dat: data frame of simulated data (columns trial, choice; choice 1 = F)
#
# returns:
# - ggplot with one smoothed line per option (F and U colours), y in [0, 1]
choice <- function(sim_dat) {
  # 0/1 indicators of which option was chosen on each trial
  indicator_dat <- mutate(sim_dat,
                          choice_is_F = as.numeric(choice == 1),
                          choice_is_U = 1 - choice_is_F)

  col_F <- my_param_colors[["F"]]
  col_U <- my_param_colors[["U"]]

  plt <- ggplot(indicator_dat, aes(x = trial)) +
    geom_smooth(aes(y = choice_is_F), color = col_F, fill = col_F) +
    geom_smooth(aes(y = choice_is_U), color = col_U, fill = col_U) +
    ylim(c(0, 1)) +
    labs(x = "Trial", y = "Proportion chosen") +
    my_theme
  plt
}

# === param_annotation() ======================
# Builds a right-aligned text grob listing the parameter settings, one
# "name = value" line per value. Elements of multi-valued settings are
# labelled "name$element" when named and "name[i]" otherwise; settings of
# length 0 are skipped.
# arguments:
# - params: named list of parameter settings
#
# returns:
# - textGrob of parameter settings list
param_annotation <- function(params) {
  library(grid)
  lines <- character(0)
  for (p in names(params)) {
    value <- params[[p]]
    if (length(value) == 1) {
      lines <- c(lines, paste0(p, " = ", value))
    } else if (length(value) > 1) {
      element_names <- names(value)
      for (i in seq_along(value)) {
        has_name <- !is.null(element_names) && !is.na(element_names[i]) &&
          nzchar(element_names[i])
        label <- if (has_name) paste0(p, "$", element_names[i]) else paste0(p, "[", i, "]")
        lines <- c(lines, paste0(label, " = ", value[i]))
      }
    }
  }
  # NULL (no label) when there is nothing to show
  full_text <- if (length(lines) > 0) paste(lines, collapse = "\n") else NULL
  g <- textGrob(label = full_text, x = unit(1, "npc"), y = unit(0.98, "npc"), just = c("right", "top"))
  return(g)
}

# === sim_plots() ======================
# Draws the Q-value plot, the choice plot and the parameter annotation side
# by side, optionally under a title.
# arguments:
# - sim_dat: simulated data;
# - params: list of parameter settings;
# - plot_title: string providing the title of the plot, or NA/NULL for no title
#
# returns: whatever gridExtra::grid.arrange() returns; called for its drawing
# side effect
sim_plots <- function(sim_dat, params, plot_title = NA) {
  annotation <- param_annotation(params)
  # grid.arrange() takes NULL for "no title"; NA is not a grob and must not
  # be passed through
  no_title <- is.null(plot_title) || isTRUE(is.na(plot_title))
  title <- if (no_title) {
    NULL
  } else {
    grid::textGrob(plot_title, gp = grid::gpar(fontsize = 20, font = 2))
  }

  gridExtra::grid.arrange(Q(sim_dat),
                          choice(sim_dat),
                          annotation,
                          ncol = 3,
                          widths = grid::unit.c(grid::unit(1, "null"), # fill space evenly
                                                grid::unit(1, "null"),
                                                grid::grobWidth(annotation) + grid::unit(2, "mm")),
                          top = title
                         )
}

# --------------------------------
#        PLOTS FOR MODELING
# --------------------------------

# === posterior_density() =============================
# arguments:
# - draws: data frame of posterior draws from model
# - to_plot: string array of parameters to plot; "name$element" addresses an
#   element of a list-valued setting
# - param_settings: named list of parameter settings; if NULL, don't show simulated value
#
# returns:
# - density plot(s) of posterior distribution(s), one colour-coded facet per
#   parameter, with the simulated value as a dashed line (omitted for
#   parameters without a simulated value)
posterior_density <- function(draws, to_plot, param_settings = NULL) {
  show_sim <- !is.null(param_settings)

  # simulated value of parameter p, or NA when param_settings has none
  sim_value_of <- function(p) {
    value <- if (grepl("\\$", p)) { # parameter is part of a list
      split_name <- strsplit(p, "\\$")[[1]]
      purrr::pluck(param_settings, split_name[1], split_name[2])
    } else {
      param_settings[[p]]
    }
    if (is.null(value)) NA_real_ else value
  }

  pieces <- lapply(to_plot, function(p) {
    dat <- data.frame(
      parameter = as.factor(p),
      estimate = draws[[p]]
    )
    if (show_sim) {
      dat$sim_value <- sim_value_of(p)
    }
    dat
  })
  plot_data <- if (length(pieces) > 0) do.call(rbind, pieces) else data.frame()

  plot <- ggplot(plot_data, aes(x = estimate, color = parameter, fill = parameter)) +
    geom_density(alpha = 0.6) +
    labs(title = "Posterior distributions", x = "Estimate", y = "Density") +
    facet_wrap(. ~ factor(parameter, to_plot), scales = "free") +
    scale_color_manual(values = my_param_colors) +
    scale_fill_manual(values = my_param_colors) +
    guides(linetype = "legend", color = "none", fill = "none") +
    my_theme

  if (show_sim && nrow(plot_data) > 0) {
    # one reference line per parameter instead of one per posterior draw
    vline_data <- unique(plot_data[c("parameter", "sim_value")])
    vline_data <- vline_data[!is.na(vline_data$sim_value), , drop = FALSE]
    plot <- plot +
      geom_vline(data = vline_data,
                 aes(xintercept = sim_value, color = parameter, linetype = "sim_value")) +
      scale_linetype_manual(values = c("sim_value" = 2), name = NULL)
  }

  return(plot)
}

# === pp_level_param_fit() =============================
# Participant-level parameter recovery: for each parameter, a scatter plot of
# the simulated value against the posterior median of "p[j]" for every
# participant j, with the identity line.
# arguments:
# - draws: data frame of posterior draws from model (columns named "p[j]")
# - to_plot: string array of parameters to plot
# - param_settings: named list of parameter settings (n_part and one vector
#   of simulated values per parameter in to_plot)
#
# returns:
# - nothing meaningful; draws the panels via gridExtra::grid.arrange()
pp_level_param_fit <- function(draws, to_plot, param_settings) {
  n_part <- param_settings$n_part
  participants <- seq_len(n_part)
  plots <- list()
  for (p in to_plot) {
    sim_values <- param_settings[[p]][participants]
    median_draws <- vapply(participants, function(j) {
      as.numeric(median(draws[[paste0(p, "[", j, "]")]]))
    }, numeric(1))
    dat <- data.frame(sim_value = sim_values,
                      fit_value = median_draws)

    # shared limits covering simulated AND fitted values: limiting y to the
    # simulated range alone silently dropped badly recovered points
    bounds <- range(c(sim_values, median_draws), na.rm = TRUE)

    plot <- ggplot(dat, aes(x = sim_value, y = fit_value)) +
      geom_point(color = my_param_colors[[p]]) +
      lims(x = bounds, y = bounds) +
      geom_abline(intercept = 0, slope = 1, linetype = 2) +
      labs(title = p, x = NULL, y = NULL) +
      my_theme + theme(plot.title = element_text(size = 18, face = "bold", hjust = 0.5))
    plots[[p]] <- plot
  }

  library(grid)
  gridExtra::grid.arrange(grobs = plots,
                          ncol = 2,
                          top = textGrob("Participant-level parameter estimations\n", x = unit(0, "npc"), just = "left", gp = gpar(fontsize = 22, font = 2)),
                          bottom = textGrob("Simulated value", gp = gpar(fontsize = 18)),
                          left = textGrob("Fitted value", rot = 90, gp = gpar(fontsize = 18))
                         )
}

# === many_runs_param_fit() =============================
# Parameter recovery across many simulation runs: one scatter panel per
# parameter (simulated vs. fitted value) with the identity line, both axes
# fixed to the parameter's theoretical bounds from sim.R's param_bounds.
# arguments:
# - sim_params: simulated values, one column per parameter plus the run
#   index k
# - fit_params: fitted values, one column per parameter
# - to_plot: string array of parameters to plot
#
# returns:
# - nothing meaningful; draws the panels via gridExtra::grid.arrange()
many_runs_param_fit <- function(sim_params, fit_params, to_plot) {
  # load the bounds once, not once per parameter
  param_bounds <- if (length(to_plot) > 0) .load_sim_param_bounds() else list()

  plots <- list()
  for (p in to_plot) {
    bounds <- param_bounds[[p]]

    plot <- ggplot(data.frame(x = sim_params[[p]], y = fit_params[[p]]), aes(x = x, y = y)) +
      geom_point(color = my_param_colors[[p]]) +
      geom_abline(intercept = 0, slope = 1, linetype = 2) +
      lims(x = bounds, y = bounds) +
      labs(title = p, x = NULL, y = NULL) +
      my_theme + theme(plot.title = element_text(size = 18, face = "bold", hjust = 0.5))
    plots[[p]] <- plot
  }

  title <- paste0("Parameter recovery for ", max(sim_params$k), " simulations")

  library(grid)
  gridExtra::grid.arrange(grobs = plots,
                          ncol = 2,
                          top = textGrob(title, x = unit(0, "npc"), just = "left", gp = gpar(fontsize = 22, font = 2)),
                          bottom = textGrob("Simulated value", gp = gpar(fontsize = 18)),
                          left = textGrob("Fitted value", rot = 90, gp = gpar(fontsize = 18))
                         )
}

# Reads param_bounds from sim.R without disturbing the caller's random number
# stream: sourcing sim.R runs its top-level set.seed(1234), so the global
# .Random.seed is saved beforehand and restored (or removed) afterwards.
.load_sim_param_bounds <- function() {
  global <- globalenv()
  had_seed <- exists(".Random.seed", envir = global, inherits = FALSE)
  if (had_seed) {
    saved_seed <- get(".Random.seed", envir = global, inherits = FALSE)
  }
  on.exit({
    if (had_seed) {
      assign(".Random.seed", saved_seed, envir = global)
    } else if (exists(".Random.seed", envir = global, inherits = FALSE)) {
      rm(".Random.seed", envir = global)
    }
  }, add = TRUE)

  sim <- new.env()
  source(paste0(sim_dir, "sim.R"), local = sim)
  sim$param_bounds
}

utils.R

# === cred_int() ========================
# Equal-tailed 95% credibility interval of a posterior sample.
# arguments:
# - posterior_dist: numeric vector of posterior draws
#
# returns:
# - unnamed numeric vector c(lower, upper) holding the 2.5% and 97.5% quantiles
cred_int <- function(posterior_dist) {
  tail_probs <- c(0.025, 0.975)
  interval <- quantile(posterior_dist, probs = tail_probs, names = FALSE)
  as.numeric(interval)
}

# === print_posterior_table() ========================
# Builds a table with, per parameter, the simulated value and the posterior
# median with its 95% credibility interval.
# arguments:
# - draws: data frame of posterior draws from model
# - param_settings: named list of parameter settings
# - to_show: string array of parameters to show in table; "name$element"
#   addresses an element of a list-valued setting
#
# returns:
# - the table from knitr::kable(), styled with kableExtra::kable_styling();
#   parameters without a simulated value get NA in that column
print_posterior_table <- function(draws, param_settings, to_show) {
  rows <- lapply(to_show, function(p) {
    if (grepl("\\$", p)) { # parameter is part of a list
      split_name <- strsplit(p, "\\$")[[1]]
      sim_value <- purrr::pluck(param_settings, split_name[1], split_name[2])
      param_name <- paste0(split_name[1], "\\$", split_name[2]) # escape dollar sign for kable
    } else {
      sim_value <- param_settings[[p]]
      param_name <- p
    }
    if (is.null(sim_value)) {
      sim_value <- NA # keep the row instead of breaking the column layout
    }
    ci <- cred_int(draws[[p]]) # computed once, used for both bounds
    data.frame(
      parameter = param_name,
      sim_value = sim_value,
      median_CI = sprintf("%.2f [%.2f, %.2f]",
                          median(draws[[p]]),
                          ci[1],
                          ci[2])
    )
  })
  table_data <- if (length(rows) > 0) do.call(rbind, rows) else data.frame()

  colnames <- c("Parameter", "Simulated value",
                "Median [95% credibility interval]")
  table <- knitr::kable(table_data,
                        col.names = colnames,
                        align = "lll",
                        caption = "Posteriors for free parameters")
  kableExtra::kable_styling(table, full_width = FALSE, position = "left")
}