diff --git a/NEWS.md b/NEWS.md
index 89199522..fdc98378 100644
--- a/NEWS.md
+++ b/NEWS.md
@@ -1,5 +1,9 @@
# serodynamics (development version)
+* Added support for overlaying stratified marginal curves in `plot_predicted_curve()`:
+ When `ids = "newperson"` and `overlay_strata = TRUE`, the function now overlays
+ median marginal predictions for each stratum in the same panel, with different
+ strata distinguished by color (#73).
* Added dev container configuration for persistent, cached development environment
that includes R, JAGS, and all dependencies preinstalled, making Copilot
Workspace sessions much faster.
diff --git a/R/plot_predicted_curve.R b/R/plot_predicted_curve.R
index 3a0d4654..79c017ab 100644
--- a/R/plot_predicted_curve.R
+++ b/R/plot_predicted_curve.R
@@ -5,11 +5,13 @@
#' Optionally overlays observed data,
#' applies logarithmic spacing on the y- and x-axes,
#' and shows all individual
-#' sampled curves.
+#' sampled curves. When `ids = "newperson"` and `overlay_strata = TRUE`,
+#' overlays marginal predictions for each stratum in the same panel.
#'
#' @param model An `sr_model` object (returned by [run_mod()]) containing
#' samples from the posterior distribution of the model parameters.
#' @param ids The participant IDs to plot; for example, `"sees_npl_128"`.
+#' Use `"newperson"` to plot marginal ("new person") predictions.
#' @param antigen_iso The antigen isotype to plot; for example, "HlyE_IgA" or
#' "HlyE_IgG".
#' @param dataset (Optional) A [dplyr::tbl_df] with observed antibody response
@@ -19,6 +21,9 @@
#' - `value`
#' - `id`
#' - `antigen_iso`
+#' @param overlay_strata [logical]; if [TRUE], overlays marginal predictions
+#' for each stratum in the same panel when `ids = "newperson"`.
+#' Different strata are distinguished by color. Defaults to [FALSE].
#' @param legend_obs Label for observed data in the legend.
#' @param legend_median Label for the median prediction line.
#' @param show_quantiles [logical]; if [TRUE] (default), plots the 2.5%, 50%,
@@ -49,6 +54,7 @@ plot_predicted_curve <- function(model,
ids,
antigen_iso,
dataset = NULL,
+ overlay_strata = FALSE,
legend_obs = "Observed data",
legend_median = "Median prediction",
show_quantiles = TRUE,
@@ -68,16 +74,24 @@ plot_predicted_curve <- function(model,
.data$Iso_type == antigen_iso # e.g. "HlyE_IgA"
)
+ # Check if Stratification column exists in the model
+ has_stratification <- "Stratification" %in% names(sr_model_sub)
+
+ # Determine if we're in stratification overlay mode
+ is_stratified_overlay <- overlay_strata &&
+ ids[1] == "newperson" &&
+ has_stratification
+
+ # Select columns for pivoting (including Stratification if present)
+ select_cols <- c("Chain", "Iteration", "Iso_type", "Parameter",
+ "value", "Subject")
+ if (has_stratification) {
+ select_cols <- c(select_cols, "Stratification")
+ }
+
# Pivot to wide format: one row per iteration/chain
param_medians_wide <- sr_model_sub |>
- dplyr::select(
- all_of(c("Chain",
- "Iteration",
- "Iso_type",
- "Parameter",
- "value",
- "Subject"))
- ) |>
+ dplyr::select(all_of(select_cols)) |>
tidyr::pivot_wider(
names_from = c("Parameter"),
values_from = c("value")
@@ -90,6 +104,12 @@ plot_predicted_curve <- function(model,
) |>
dplyr::select(-c("Iso_type", "Subject"))
+ # Add stratum factor if in stratified overlay mode
+ if (is_stratified_overlay) {
+ param_medians_wide <- param_medians_wide |>
+ dplyr::mutate(stratum = as.factor(.data$Stratification))
+ }
+
# Add sample_id if not present (to identify individual samples)
if (!"sample_id" %in% names(param_medians_wide)) {
param_medians_wide <- param_medians_wide |>
@@ -135,37 +155,73 @@ plot_predicted_curve <- function(model,
# If show_all_curves is TRUE, overlay all individual sampled curves.
if (show_all_curves) {
- p <- p +
- ggplot2::geom_line(data = serocourse_all1,
- ggplot2::aes(x = .data$t,
- y = .data$res,
- group = .data$sample_id,
- color = "samples"),
- alpha = alpha_samples)
+ if (is_stratified_overlay) {
+ # For stratified overlay, color by stratum
+ p <- p +
+ ggplot2::geom_line(data = serocourse_all1,
+ ggplot2::aes(x = .data$t,
+ y = .data$res,
+ group = .data$sample_id,
+ color = .data$stratum),
+ alpha = alpha_samples)
+ } else {
+ # For non-stratified, use single "samples" color
+ p <- p +
+ ggplot2::geom_line(data = serocourse_all1,
+ ggplot2::aes(x = .data$t,
+ y = .data$res,
+ group = .data$sample_id,
+ color = "samples"),
+ alpha = alpha_samples)
+ }
}
# --- Summarize & Plot Model 1 (Median + 95% Ribbon) ---
if (show_quantiles) {
- sum1 <- serocourse_all1 |>
- dplyr::summarise(
- .by = all_of(c("id", "t")),
- res.med = stats::quantile(.data$res, probs = 0.50, na.rm = TRUE),
- res.low = stats::quantile(.data$res, probs = 0.025, na.rm = TRUE),
- res.high = stats::quantile(.data$res, probs = 0.975, na.rm = TRUE)
- )
-
- p <- p +
- ggplot2::geom_ribbon(data = sum1,
+ # Determine grouping variables based on stratification mode
+ if (is_stratified_overlay) {
+ sum1 <- serocourse_all1 |>
+ dplyr::summarise(
+ .by = all_of(c("stratum", "t")),
+ res.med = stats::quantile(.data$res, probs = 0.50, na.rm = TRUE),
+ res.low = stats::quantile(.data$res, probs = 0.025, na.rm = TRUE),
+ res.high = stats::quantile(.data$res, probs = 0.975, na.rm = TRUE)
+ )
+
+ p <- p +
+ ggplot2::geom_ribbon(data = sum1,
+ ggplot2::aes(x = .data$t,
+ ymin = .data$res.low,
+ ymax = .data$res.high,
+ fill = .data$stratum),
+ alpha = 0.2, inherit.aes = FALSE) +
+ ggplot2::geom_line(data = sum1,
ggplot2::aes(x = .data$t,
- ymin = .data$res.low,
- ymax = .data$res.high,
- fill = "ci"),
- alpha = 0.2, inherit.aes = FALSE) +
- ggplot2::geom_line(data = sum1,
- ggplot2::aes(x = .data$t,
- y = .data$res.med,
- color = "median"),
- linewidth = 1, inherit.aes = FALSE)
+ y = .data$res.med,
+ color = .data$stratum),
+ linewidth = 1, inherit.aes = FALSE)
+ } else {
+ sum1 <- serocourse_all1 |>
+ dplyr::summarise(
+ .by = all_of(c("id", "t")),
+ res.med = stats::quantile(.data$res, probs = 0.50, na.rm = TRUE),
+ res.low = stats::quantile(.data$res, probs = 0.025, na.rm = TRUE),
+ res.high = stats::quantile(.data$res, probs = 0.975, na.rm = TRUE)
+ )
+
+ p <- p +
+ ggplot2::geom_ribbon(data = sum1,
+ ggplot2::aes(x = .data$t,
+ ymin = .data$res.low,
+ ymax = .data$res.high,
+ fill = "ci"),
+ alpha = 0.2, inherit.aes = FALSE) +
+ ggplot2::geom_line(data = sum1,
+ ggplot2::aes(x = .data$t,
+ y = .data$res.med,
+ color = "median"),
+ linewidth = 1, inherit.aes = FALSE)
+ }
}
# --- Overlay Observed Data (if provided) ---
@@ -199,37 +255,54 @@ plot_predicted_curve <- function(model,
}
# --- Construct Unified Legend ---
- color_vals <- c("median" = "red")
- color_labels <- c("median" = legend_median)
- fill_vals <- c("ci" = "red")
- fill_labels <- c("ci" = "95% credible interval")
-
- if (show_all_curves) {
- color_vals["samples"] <- "gray"
- color_labels["samples"] <- "Posterior samples"
- }
-
- if (!is.null(dataset)) {
- color_vals["observed"] <- "blue"
- color_labels["observed"] <- legend_obs
+ if (is_stratified_overlay) {
+ # For stratified overlay, use stratum colors directly
+ strata_levels <- unique(param_medians_wide$stratum)
+ n_strata <- length(strata_levels)
+
+ # Use ggplot2 default colors or a color scale
+ strata_colors <- scales::hue_pal()(n_strata)
+ names(strata_colors) <- strata_levels
+
+ p <- p +
+ ggplot2::scale_color_discrete(name = "Stratum") +
+ ggplot2::scale_fill_discrete(name = "Stratum")
+
+ } else {
+ # Original legend construction for non-stratified plots
+ color_vals <- c("median" = "red")
+ color_labels <- c("median" = legend_median)
+ fill_vals <- c("ci" = "red")
+ fill_labels <- c("ci" = "95% credible interval")
+
+ if (show_all_curves) {
+ color_vals["samples"] <- "gray"
+ color_labels["samples"] <- "Posterior samples"
+ }
+
+ if (!is.null(dataset)) {
+ color_vals["observed"] <- "blue"
+ color_labels["observed"] <- legend_obs
+ }
+
+ p <- p +
+ ggplot2::scale_color_manual(
+ name = "",
+ values = color_vals,
+ labels = color_labels,
+ guide = ggplot2::guide_legend(override.aes = list(shape = NA))
+ ) +
+ ggplot2::scale_fill_manual(
+ name = "",
+ values = fill_vals,
+ labels = fill_labels,
+ guide = ggplot2::guide_legend(override.aes = list(color = NA))
+ )
}
- p <- p +
- ggplot2::scale_color_manual(
- name = "",
- values = color_vals,
- labels = color_labels,
- guide = ggplot2::guide_legend(override.aes = list(shape = NA))
- ) +
- ggplot2::scale_fill_manual(
- name = "",
- values = fill_vals,
- labels = fill_labels,
- guide = ggplot2::guide_legend(override.aes = list(color = NA))
- )
-
# --- Optionally facet by ID ---
- if (facet_by_id) {
+ # Don't facet when in stratified overlay mode
+ if (facet_by_id && !is_stratified_overlay) {
if (is.null(ncol)) {
n_ids <- length(unique(param_medians_wide$id))
ncol <- if (n_ids == 1) {
diff --git a/inst/examples/examples-plot_predicted_curve.R b/inst/examples/examples-plot_predicted_curve.R
index 582b49c9..b0f229e7 100644
--- a/inst/examples/examples-plot_predicted_curve.R
+++ b/inst/examples/examples-plot_predicted_curve.R
@@ -51,3 +51,25 @@ p4 <- plot_predicted_curve(
facet_by_id = TRUE
)
print(p4)
+
+# Stratified marginal overlay plot (newperson with multiple strata):
+p5 <- plot_predicted_curve(
+ model = sees_model,
+ ids = "newperson",
+ antigen_iso = "HlyE_IgA",
+ overlay_strata = TRUE,
+ show_quantiles = TRUE,
+ log_y = FALSE
+)
+print(p5)
+
+# Stratified marginal overlay with log y-axis:
+p6 <- plot_predicted_curve(
+ model = sees_model,
+ ids = "newperson",
+ antigen_iso = "HlyE_IgA",
+ overlay_strata = TRUE,
+ show_quantiles = TRUE,
+ log_y = TRUE
+)
+print(p6)
diff --git a/man/plot_predicted_curve.Rd b/man/plot_predicted_curve.Rd
index 7e2fabde..169f86a2 100644
--- a/man/plot_predicted_curve.Rd
+++ b/man/plot_predicted_curve.Rd
@@ -9,6 +9,7 @@ plot_predicted_curve(
ids,
antigen_iso,
dataset = NULL,
+ overlay_strata = FALSE,
legend_obs = "Observed data",
legend_median = "Median prediction",
show_quantiles = TRUE,
@@ -26,7 +27,8 @@ plot_predicted_curve(
\item{model}{An \code{sr_model} object (returned by \code{\link[=run_mod]{run_mod()}}) containing
samples from the posterior distribution of the model parameters.}
-\item{ids}{The participant IDs to plot; for example, \code{"sees_npl_128"}.}
+\item{ids}{The participant IDs to plot; for example, \code{"sees_npl_128"}.
+Use \code{"newperson"} to plot marginal ("new person") predictions.}
\item{antigen_iso}{The antigen isotype to plot; for example, "HlyE_IgA" or
"HlyE_IgG".}
@@ -41,6 +43,10 @@ Must contain:
\item \code{antigen_iso}
}}
+\item{overlay_strata}{\link{logical}; if \link{TRUE}, overlays marginal predictions
+for each stratum in the same panel when \code{ids = "newperson"}.
+Different strata are distinguished by color. Defaults to \link{FALSE}.}
+
\item{legend_obs}{Label for observed data in the legend.}
\item{legend_median}{Label for the median prediction line.}
@@ -82,7 +88,8 @@ ribbon, using MCMC samples from the posterior distribution.
Optionally overlays observed data,
applies logarithmic spacing on the y- and x-axes,
and shows all individual
-sampled curves.
+sampled curves. When \code{ids = "newperson"} and \code{overlay_strata = TRUE},
+overlays marginal predictions for each stratum in the same panel.
}
\examples{
sees_model <- serodynamics::nepal_sees_jags_output
@@ -138,4 +145,26 @@ p4 <- plot_predicted_curve(
facet_by_id = TRUE
)
print(p4)
+
+# Stratified marginal overlay plot (newperson with multiple strata):
+p5 <- plot_predicted_curve(
+ model = sees_model,
+ ids = "newperson",
+ antigen_iso = "HlyE_IgA",
+ overlay_strata = TRUE,
+ show_quantiles = TRUE,
+ log_y = FALSE
+)
+print(p5)
+
+# Stratified marginal overlay with log y-axis:
+p6 <- plot_predicted_curve(
+ model = sees_model,
+ ids = "newperson",
+ antigen_iso = "HlyE_IgA",
+ overlay_strata = TRUE,
+ show_quantiles = TRUE,
+ log_y = TRUE
+)
+print(p6)
}
diff --git a/tests/testthat/_snaps/plot_predicted_curve/predicted-curve-stratified-all-curves.svg b/tests/testthat/_snaps/plot_predicted_curve/predicted-curve-stratified-all-curves.svg
new file mode 100644
index 00000000..6bfb431f
--- /dev/null
+++ b/tests/testthat/_snaps/plot_predicted_curve/predicted-curve-stratified-all-curves.svg
@@ -0,0 +1,2080 @@
+
+
diff --git a/tests/testthat/_snaps/plot_predicted_curve/predicted-curve-stratified-log.svg b/tests/testthat/_snaps/plot_predicted_curve/predicted-curve-stratified-log.svg
new file mode 100644
index 00000000..7c5d707b
--- /dev/null
+++ b/tests/testthat/_snaps/plot_predicted_curve/predicted-curve-stratified-log.svg
@@ -0,0 +1,76 @@
+
+
diff --git a/tests/testthat/_snaps/plot_predicted_curve/predicted-curve-stratified-overlay.svg b/tests/testthat/_snaps/plot_predicted_curve/predicted-curve-stratified-overlay.svg
new file mode 100644
index 00000000..e987ef6e
--- /dev/null
+++ b/tests/testthat/_snaps/plot_predicted_curve/predicted-curve-stratified-overlay.svg
@@ -0,0 +1,76 @@
+
+
diff --git a/tests/testthat/test-plot_predicted_curve.R b/tests/testthat/test-plot_predicted_curve.R
index d0631294..5f031477 100644
--- a/tests/testthat/test-plot_predicted_curve.R
+++ b/tests/testthat/test-plot_predicted_curve.R
@@ -110,3 +110,55 @@ testthat::test_that(
vdiffr::expect_doppelganger("predicted-curve-multi-id-4", plot_multi)
}
)
+
+testthat::test_that(
+ "plot_predicted_curve() works with stratified overlay for newperson",
+ {
+ plot_strat <- plot_predicted_curve(
+ model = serodynamics::nepal_sees_jags_output,
+ ids = "newperson",
+ antigen_iso = "HlyE_IgA",
+ overlay_strata = TRUE,
+ show_quantiles = TRUE,
+ log_y = FALSE
+ )
+ vdiffr::expect_doppelganger("predicted-curve-stratified-overlay", plot_strat)
+ }
+)
+
+testthat::test_that(
+ "plot_predicted_curve() stratified overlay works with show_all_curves",
+ {
+ plot_strat_curves <- plot_predicted_curve(
+ model = serodynamics::nepal_sees_jags_output,
+ ids = "newperson",
+ antigen_iso = "HlyE_IgA",
+ overlay_strata = TRUE,
+ show_all_curves = TRUE,
+ alpha_samples = 0.1,
+ log_y = FALSE
+ )
+ vdiffr::expect_doppelganger(
+ "predicted-curve-stratified-all-curves",
+ plot_strat_curves
+ )
+ }
+)
+
+testthat::test_that(
+ "plot_predicted_curve() stratified overlay works with log_y",
+ {
+ plot_strat_log <- plot_predicted_curve(
+ model = serodynamics::nepal_sees_jags_output,
+ ids = "newperson",
+ antigen_iso = "HlyE_IgA",
+ overlay_strata = TRUE,
+ show_quantiles = TRUE,
+ log_y = TRUE
+ )
+ vdiffr::expect_doppelganger(
+ "predicted-curve-stratified-log",
+ plot_strat_log
+ )
+ }
+)