Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
1 change: 1 addition & 0 deletions NAMESPACE
Original file line number Diff line number Diff line change
Expand Up @@ -14,6 +14,7 @@ export(g3_data)
export(g3_fit)
export(g3_init_guess)
export(g3_iterative)
export(g3_iterative_default_grouping)
export(g3_iterative_setup)
export(g3_jitter)
export(g3_leaveout)
Expand Down
50 changes: 49 additions & 1 deletion R/g3_iterative.R
Original file line number Diff line number Diff line change
Expand Up @@ -75,7 +75,7 @@
#' @export
g3_iterative <- function(gd, wgts = 'WGTS',
model, params.in,
grouping = list(),
grouping = g3_iterative_default_grouping(params.in),
use_parscale = TRUE,
method = 'BFGS',
control = list(),
Expand Down Expand Up @@ -374,6 +374,54 @@ g3_iterative <- function(gd, wgts = 'WGTS',
return(params_final)
}

# Generate default grouping, combine all fleet likelihoods into one group
# NB: For this to work, nll_names need to be in the form (nll_source)_(nll_dist), where (nll_dist) matches one of the (nll_dist_names)
#' @param params.in Initial parameters to use with the model
#' @param nll_dist_names Character vector of postfixes (regular expressions) to consider when looking for groupings
#' @return
#' \subsection{g3_iterative_default_grouping}{
#'   A list of component groups to component names, as required by the \var{grouping} parameter.
#'   The list is empty if \var{params.in} contains no matching, non-zero likelihood component weights.
#' }
#' @details
#' \subsection{g3_iterative_default_grouping}{
#'   This assumes that your likelihood component names are of the form ``(nll_group)_(nll_dist)``,
#'   where ``(nll_dist)`` matches one of the regexes in \var{nll_dist_names}.
#'   For example, ``afleet_ldist``, ``afleet_aldist``, ``bfleet_ldist``. ``afleet`` & ``bfleet`` will be the groups used.
#'
#'   Weight parameters with a value of exactly 0 are left out of the grouping, with a warning.
#' }
#' @rdname g3_iterative
#' @export
g3_iterative_default_grouping <- function (params.in, nll_dist_names = c("ldist", "aldist", "matp", "sexdist", "SI", "len\\d+SI")) {
    # Weight parameter names look like (dist)_(function)_(nll_source)_(nll_dist)_weight
    weight_re <- paste0(
        "^",
        "(?<dist>.dist|.sparse)_",
        "(?<function>surveyindices_log|[a-z]+)_",
        "(?<nll_source>.+)_",
        "(?<nll_dist>", paste0(nll_dist_names, collapse = "|"), ")_",
        "weight$"
    )

    # Extract all likelihood component weight names from params.in
    weight_names <- grep(weight_re, rownames(params.in), value = TRUE, perl = TRUE)
    if (length(weight_names) == 0) {
        # Nothing to group; rbind()ing zero matches would not give us any columns to work with
        return(structure(list(), names = character(0)))
    }

    # Break up each name into its regex groups. Groups are picked by position
    # (1 = whole match, 2 = dist, 3 = function, 4 = nll_source, 5 = nll_dist)
    # so we don't depend on regexec() handing back the capture-group names
    weight_matches <- regmatches(weight_names, regexec(weight_re, weight_names, perl = TRUE))
    nll_source <- vapply(weight_matches, function (m) m[[4]], character(1))
    nll_dist <- vapply(weight_matches, function (m) m[[5]], character(1))

    # Remove any zero-weighted parameters
    is_zero <- params.in[weight_names, "value"] == 0
    is_zero[is.na(is_zero)] <- FALSE  # An NA weight isn't known to be zero, leave it in its group
    if (any(is_zero)) {
        warning("Parameters ", paste(weight_names[is_zero], collapse = ", ") , " have a value of 0, removing from grouping")
        nll_source <- nll_source[!is_zero]
        nll_dist <- nll_dist[!is_zero]
    }

    # Group together into a list of nll_source -> vector of (nll_source)_(nll_dist),
    # with groups in order of first appearance
    component_names <- paste0(nll_source, "_", nll_dist)
    group_names <- unique(nll_source)
    out <- lapply(group_names, function (group_name) component_names[nll_source == group_name])
    names(out) <- group_names
    out
}

#' @title Initial parameters for iterative re-weighting
#' @param lik_out A likelihood summary dataframe. The output of g3_lik_out(model, param)
#' @param grouping A list describing how to group likelihood components for iterative re-weighting
Expand Down
20 changes: 19 additions & 1 deletion man/g3_iterative.Rd

Some generated files are not rendered by default. Learn more about how customized files appear on GitHub.

49 changes: 49 additions & 0 deletions tests/test-g3_iterative-default_grouping.R
Original file line number Diff line number Diff line change
@@ -0,0 +1,49 @@
library(unittest)
library(gadgetutils)

library(unittest)

# Parse a block of whitespace-separated text (first line is the header) into a
# data.frame whose row names are taken from its "switch" column.
# Any extra arguments are handed on to read.table()
table_string <- function (text, ...) {
    parsed <- read.table(
        header = TRUE,
        stringsAsFactors = FALSE,
        blank.lines.skip = TRUE,
        text = text,
        ...)
    row.names(parsed) <- parsed$switch
    parsed
}

# Weight parameters are grouped by their nll_source ("comm", "fgn", "surv");
# only names ending in one of the nll_dist_names passed below are picked up
ok(ut_cmp_identical(g3_iterative_default_grouping(table_string('
switch value
cdist_sumofsquares_comm_ldist_weight 1
cdist_sumofsquares_comm_aldist_weight 1
cdist_sumofsquares_comm_argle_weight 1
cdist_sumofsquares_comm_matp_weight 1
cdist_sumofsquares_fgn_ldist_weight 1
cdist_sumofsquares_fgn_aldist_weight 1
cdist_surveyindices_log_surv_si_weight 1
'), nll_dist_names = c("ldist", "aldist", "matp", "si")), list(
# NB: comm_argle is missing, "argle" is not one of the nll_dist_names given above
comm = c("comm_ldist", "comm_aldist", "comm_matp"),
fgn = c("fgn_ldist", "fgn_aldist"),
# NB: parsed the awkward two-word function name "surveyindices_log", so the group is just "surv"
surv = c("surv_si")
)))

# Zero-weighted components are left out of their group, and the caller is warned about them.
# Warnings are collected rather than suppressed, so the test also fails if the
# warning goes missing or an unexpected extra one turns up
local({
    seen_warnings <- character(0)
    grouping <- withCallingHandlers(
        g3_iterative_default_grouping(table_string('
switch value
cdist_sumofsquares_comm_ldist_weight 1
cdist_sumofsquares_comm_aldist_weight 1
cdist_sumofsquares_comm_argle_weight 1
cdist_sumofsquares_comm_matp_weight 1
cdist_sumofsquares_fgn_ldist_weight 1
cdist_sumofsquares_fgn_aldist_weight 0
cdist_surveyindices_log_surv_si_weight 1
'), nll_dist_names = c("ldist", "aldist", "matp", "si")),
        warning = function (w) {
            seen_warnings <<- c(seen_warnings, conditionMessage(w))
            invokeRestart("muffleWarning")
        })

    ok(ut_cmp_identical(grouping, list(
        comm = c("comm_ldist", "comm_aldist", "comm_matp"),
        # NB: zero-weighted doesn't count
        fgn = c("fgn_ldist"),
        surv = c("surv_si")
    )), "Zero-weighted components are left out of the grouping")

    # Exactly one warning, and it names the zero-weighted parameter
    ok(ut_cmp_identical(
        grepl("cdist_sumofsquares_fgn_aldist_weight", seen_warnings, fixed = TRUE),
        TRUE), "Warned about the zero-weighted parameter")
})
Loading