Skip to content
Draft
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
4 changes: 4 additions & 0 deletions NEWS.md
Original file line number Diff line number Diff line change
@@ -1,5 +1,9 @@
# serodynamics (development version)

* Added support for overlaying stratified marginal curves in `plot_predicted_curve()`:
When `ids = "newperson"` and `overlay_strata = TRUE`, the function now overlays
median marginal predictions for each stratum in the same panel, with different
strata distinguished by color (#73).
* Added dev container configuration for persistent, cached development environment
that includes R, JAGS, and all dependencies preinstalled, making Copilot
Workspace sessions much faster.
Expand Down
199 changes: 136 additions & 63 deletions R/plot_predicted_curve.R
Original file line number Diff line number Diff line change
Expand Up @@ -5,11 +5,13 @@
#' Optionally overlays observed data,
#' applies logarithmic spacing on the y- and x-axes,
#' and shows all individual
#' sampled curves.
#' sampled curves. When `ids = "newperson"` and `overlay_strata = TRUE`,
#' overlays marginal predictions for each stratum in the same panel.
#'
#' @param model An `sr_model` object (returned by [run_mod()]) containing
#' samples from the posterior distribution of the model parameters.
#' @param ids The participant IDs to plot; for example, `"sees_npl_128"`.
#' Use `"newperson"` to plot marginal ("new person") predictions.
#' @param antigen_iso The antigen isotype to plot; for example, "HlyE_IgA" or
#' "HlyE_IgG".
#' @param dataset (Optional) A [dplyr::tbl_df] with observed antibody response
Expand All @@ -19,6 +21,9 @@
#' - `value`
#' - `id`
#' - `antigen_iso`
#' @param overlay_strata [logical]; if [TRUE], overlays marginal predictions
#' for each stratum in the same panel when `ids = "newperson"`.
#' Different strata are distinguished by color. Defaults to [FALSE].
#' @param legend_obs Label for observed data in the legend.
#' @param legend_median Label for the median prediction line.
#' @param show_quantiles [logical]; if [TRUE] (default), plots the 2.5%, 50%,
Expand Down Expand Up @@ -49,6 +54,7 @@ plot_predicted_curve <- function(model,
ids,
antigen_iso,
dataset = NULL,
overlay_strata = FALSE,
legend_obs = "Observed data",
legend_median = "Median prediction",
show_quantiles = TRUE,
Expand All @@ -68,16 +74,24 @@ plot_predicted_curve <- function(model,
.data$Iso_type == antigen_iso # e.g. "HlyE_IgA"
)

# Check if Stratification column exists in the model
has_stratification <- "Stratification" %in% names(sr_model_sub)

# Determine if we're in stratification overlay mode
is_stratified_overlay <- overlay_strata &&
ids[1] == "newperson" &&
has_stratification

# Select columns for pivoting (including Stratification if present)
select_cols <- c("Chain", "Iteration", "Iso_type", "Parameter",
"value", "Subject")
if (has_stratification) {
select_cols <- c(select_cols, "Stratification")
}

# Pivot to wide format: one row per iteration/chain
param_medians_wide <- sr_model_sub |>
dplyr::select(
all_of(c("Chain",
"Iteration",
"Iso_type",
"Parameter",
"value",
"Subject"))
) |>
dplyr::select(all_of(select_cols)) |>
tidyr::pivot_wider(
names_from = c("Parameter"),
values_from = c("value")
Expand All @@ -90,6 +104,12 @@ plot_predicted_curve <- function(model,
) |>
dplyr::select(-c("Iso_type", "Subject"))

# Add stratum factor if in stratified overlay mode
if (is_stratified_overlay) {
param_medians_wide <- param_medians_wide |>
dplyr::mutate(stratum = as.factor(.data$Stratification))
}

# Add sample_id if not present (to identify individual samples)
if (!"sample_id" %in% names(param_medians_wide)) {
param_medians_wide <- param_medians_wide |>
Expand Down Expand Up @@ -135,37 +155,73 @@ plot_predicted_curve <- function(model,

# If show_all_curves is TRUE, overlay all individual sampled curves.
if (show_all_curves) {
p <- p +
ggplot2::geom_line(data = serocourse_all1,
ggplot2::aes(x = .data$t,
y = .data$res,
group = .data$sample_id,
color = "samples"),
alpha = alpha_samples)
if (is_stratified_overlay) {
# For stratified overlay, color by stratum
p <- p +
ggplot2::geom_line(data = serocourse_all1,
ggplot2::aes(x = .data$t,
y = .data$res,
group = .data$sample_id,
color = .data$stratum),
alpha = alpha_samples)
} else {
# For non-stratified, use single "samples" color
p <- p +
ggplot2::geom_line(data = serocourse_all1,
ggplot2::aes(x = .data$t,
y = .data$res,
group = .data$sample_id,
color = "samples"),
alpha = alpha_samples)
}
}

# --- Summarize & Plot Model 1 (Median + 95% Ribbon) ---
if (show_quantiles) {
sum1 <- serocourse_all1 |>
dplyr::summarise(
.by = all_of(c("id", "t")),
res.med = stats::quantile(.data$res, probs = 0.50, na.rm = TRUE),
res.low = stats::quantile(.data$res, probs = 0.025, na.rm = TRUE),
res.high = stats::quantile(.data$res, probs = 0.975, na.rm = TRUE)
)

p <- p +
ggplot2::geom_ribbon(data = sum1,
# Determine grouping variables based on stratification mode
if (is_stratified_overlay) {
sum1 <- serocourse_all1 |>
dplyr::summarise(
.by = all_of(c("stratum", "t")),
res.med = stats::quantile(.data$res, probs = 0.50, na.rm = TRUE),
res.low = stats::quantile(.data$res, probs = 0.025, na.rm = TRUE),
res.high = stats::quantile(.data$res, probs = 0.975, na.rm = TRUE)
)

p <- p +
ggplot2::geom_ribbon(data = sum1,
ggplot2::aes(x = .data$t,
ymin = .data$res.low,
ymax = .data$res.high,
fill = .data$stratum),
alpha = 0.2, inherit.aes = FALSE) +
ggplot2::geom_line(data = sum1,
ggplot2::aes(x = .data$t,
ymin = .data$res.low,
ymax = .data$res.high,
fill = "ci"),
alpha = 0.2, inherit.aes = FALSE) +
ggplot2::geom_line(data = sum1,
ggplot2::aes(x = .data$t,
y = .data$res.med,
color = "median"),
linewidth = 1, inherit.aes = FALSE)
y = .data$res.med,
color = .data$stratum),
linewidth = 1, inherit.aes = FALSE)
} else {
sum1 <- serocourse_all1 |>
dplyr::summarise(
.by = all_of(c("id", "t")),
res.med = stats::quantile(.data$res, probs = 0.50, na.rm = TRUE),
res.low = stats::quantile(.data$res, probs = 0.025, na.rm = TRUE),
res.high = stats::quantile(.data$res, probs = 0.975, na.rm = TRUE)
)

p <- p +
ggplot2::geom_ribbon(data = sum1,
ggplot2::aes(x = .data$t,
ymin = .data$res.low,
ymax = .data$res.high,
fill = "ci"),
alpha = 0.2, inherit.aes = FALSE) +
ggplot2::geom_line(data = sum1,
ggplot2::aes(x = .data$t,
y = .data$res.med,
color = "median"),
linewidth = 1, inherit.aes = FALSE)
}
}

# --- Overlay Observed Data (if provided) ---
Expand Down Expand Up @@ -199,37 +255,54 @@ plot_predicted_curve <- function(model,
}

# --- Construct Unified Legend ---
color_vals <- c("median" = "red")
color_labels <- c("median" = legend_median)
fill_vals <- c("ci" = "red")
fill_labels <- c("ci" = "95% credible interval")

if (show_all_curves) {
color_vals["samples"] <- "gray"
color_labels["samples"] <- "Posterior samples"
}

if (!is.null(dataset)) {
color_vals["observed"] <- "blue"
color_labels["observed"] <- legend_obs
if (is_stratified_overlay) {
# For stratified overlay, use stratum colors directly
strata_levels <- unique(param_medians_wide$stratum)
n_strata <- length(strata_levels)

# Use ggplot2 default colors or a color scale
strata_colors <- scales::hue_pal()(n_strata)
names(strata_colors) <- strata_levels

p <- p +
ggplot2::scale_color_discrete(name = "Stratum") +
ggplot2::scale_fill_discrete(name = "Stratum")

} else {
# Original legend construction for non-stratified plots
color_vals <- c("median" = "red")
color_labels <- c("median" = legend_median)
fill_vals <- c("ci" = "red")
fill_labels <- c("ci" = "95% credible interval")

if (show_all_curves) {
color_vals["samples"] <- "gray"
color_labels["samples"] <- "Posterior samples"
}

if (!is.null(dataset)) {
color_vals["observed"] <- "blue"
color_labels["observed"] <- legend_obs
}

p <- p +
ggplot2::scale_color_manual(
name = "",
values = color_vals,
labels = color_labels,
guide = ggplot2::guide_legend(override.aes = list(shape = NA))
) +
ggplot2::scale_fill_manual(
name = "",
values = fill_vals,
labels = fill_labels,
guide = ggplot2::guide_legend(override.aes = list(color = NA))
)
}

p <- p +
ggplot2::scale_color_manual(
name = "",
values = color_vals,
labels = color_labels,
guide = ggplot2::guide_legend(override.aes = list(shape = NA))
) +
ggplot2::scale_fill_manual(
name = "",
values = fill_vals,
labels = fill_labels,
guide = ggplot2::guide_legend(override.aes = list(color = NA))
)

# --- Optionally facet by ID ---
if (facet_by_id) {
# Don't facet when in stratified overlay mode
if (facet_by_id && !is_stratified_overlay) {
if (is.null(ncol)) {
n_ids <- length(unique(param_medians_wide$id))
ncol <- if (n_ids == 1) {
Expand Down
22 changes: 22 additions & 0 deletions inst/examples/examples-plot_predicted_curve.R
Original file line number Diff line number Diff line change
Expand Up @@ -51,3 +51,25 @@ p4 <- plot_predicted_curve(
facet_by_id = TRUE
)
print(p4)

# Stratified marginal overlay plot (newperson with multiple strata):
p5 <- plot_predicted_curve(
model = sees_model,
ids = "newperson",
antigen_iso = "HlyE_IgA",
overlay_strata = TRUE,
show_quantiles = TRUE,
log_y = FALSE
)
print(p5)

# Stratified marginal overlay with log y-axis:
p6 <- plot_predicted_curve(
model = sees_model,
ids = "newperson",
antigen_iso = "HlyE_IgA",
overlay_strata = TRUE,
show_quantiles = TRUE,
log_y = TRUE
)
print(p6)
33 changes: 31 additions & 2 deletions man/plot_predicted_curve.Rd

Some generated files are not rendered by default. Learn more about how customized files appear on GitHub.

Loading