diff --git a/DESCRIPTION b/DESCRIPTION index 98484c2..1f6db52 100644 --- a/DESCRIPTION +++ b/DESCRIPTION @@ -25,3 +25,5 @@ Imports: tidyselect, modelr, parallel +Suggests: + unittest (>= 1.4) diff --git a/R/g3_fit.R b/R/g3_fit.R index fd1f203..dcadfaf 100644 --- a/R/g3_fit.R +++ b/R/g3_fit.R @@ -14,22 +14,21 @@ g3_fit <- function(model, printatstart = 1){ ## Checks - stopifnot(is.list(params)) + stopifnot(is.data.frame(params)) # NB: Only accept table params, not list. The list form will likely disappear eventually if (!(printatstart %in% c(0,1))){ stop("The printatstart argument must be '0' or '1' (class numeric)") } if (inherits(model, "g3_r")) { - if (is.data.frame(params)) params <- params$value - if ("report_detail" %in% names(params) && printatstart == 1) { - params$report_detail <- 1L - tmp <- attributes(model(params)) - data_env <- environment(model) - } else { - tmp <- NULL - } + if ("report_detail" %in% rownames(params) && printatstart == 1) { + params["report_detail", "value"] <- 1L + tmp <- attributes(model(params)) + data_env <- environment(model) + } else { + tmp <- NULL + } } else if (inherits(model, "g3_cpp")) { - if (is.data.frame(params) && "report_detail" %in% params$switch && printatstart == 1) { + if ("report_detail" %in% params$switch && printatstart == 1) { params['report_detail', 'value'] <- 1L obj_fun <- gadget3::g3_tmb_adfun(model, params, type = 'Fun') tmp <- obj_fun$report(gadget3::g3_tmb_par(params)) @@ -69,9 +68,9 @@ g3_fit <- function(model, model <- gadget3::g3_to_r(re_actions) ## Run model - if (is.data.frame(params)) params <- params$value - params$report_detail <- 1L - tmp <- attributes(model(params)) + r_params <- if (is.data.frame(params)) params$value else params + r_params$report_detail <- 1L + tmp <- attributes(model(r_params)) data_env <- environment(model) } diff --git a/tests/test-g3_fit-modelformats.R b/tests/test-g3_fit-modelformats.R new file mode 100644 index 0000000..c43f706 --- /dev/null +++ b/tests/test-g3_fit-modelformats.R @@ -0,0 +1,113 @@ +library(unittest) +library(gadget3) +library(gadgetutils) + +library(unittest) + +options('gadget3.tmb.compile_flags', c("-O0", "-g")) + +# Replace function with new one, optionally returning to normal after expr +mock_functions <- function(ns, new_funcs, expr) { + assign_list <- function (ns, replacements) { + for (k in names(replacements)) { + assignInNamespace(k, replacements[[k]], ns) + } + } + + # Replace temporarily, put the old ones back again + old_funcs <- structure( + lapply(names(new_funcs), function(n) getFromNamespace(n, ns)), + names = names(new_funcs)) + tryCatch({ + assign_list(ns, new_funcs) + expr + }, finally = { + assign_list(ns, old_funcs) + }) +} + +# We don't actually care what g3_fit_inner does (and it can't process such a minimal model), +# we just want to see what the tmp parameter looks like +mock_g3_fit <- function (...) { + tmp <- mock_functions('gadgetutils', list( + g3_fit_inner = function (...) list(...) ), g3_fit(...))[[1]] + return(tmp) +} + +stocks <- list( + imm = g3_stock(c(species = "fish", "imm"), 1:10 * 10) |> g3s_age(1, 5) ) + +actions <- list( + g3a_time(2000, 2001), + g3a_age(stocks$imm), + g3a_initialconditions_normalcv(stocks$imm), + NULL ) + +# Generate model with some parameters to start with +model_cpp <- g3_to_tmb(actions) +attr(model_cpp, "parameter_template") |> + g3_init_val("*.init.scalar", 10, optimise = FALSE) |> + g3_init_val("*.init.#", 10, lower = 0.001, upper = 1e6) |> + g3_init_val("*.M.#", 0.5, lower = 0.1, upper = 0.8) |> + g3_init_val("init.F", 0.5, lower = 0.1, upper = 10) |> + g3_init_val("*_imm.Linf", 144.645) |> + g3_init_val("*.K", 0.3, lower = 0.04, upper = 1.2) |> + g3_init_val("*.t0", -0.8, optimise = FALSE) |> + g3_init_val("*.walpha", 0.01, optimise = FALSE) |> + g3_init_val("*.wbeta", 3, optimise = FALSE) |> + identity() -> params.def + +ok_group("R model, with reporting", local({ + model <- g3_to_r(c(actions, list( + g3a_report_detail(actions) ))) + params.in <- params.def + params.in["report_detail", "switch"] <- "report_detail" + params.in["report_detail", "value"] <- 1 + attr(params.in, "summary") <- data.frame(a = runif(1)) # Generated by g3_optim() usually + + ok(ut_cmp_error(mock_g3_fit(model, params.in$value), "params"), "Parameter list isn't allowed") + + tmp <- mock_g3_fit(model, params.in) + ok(is.data.frame(tmp$model_params), "model_params returned as data.frame") + ok("report_detail" %in% rownames(tmp$model_params), "model_params has report_detail") + ok(ut_cmp_equal(tmp$model_params, params.in), "model_params match params.in") + ok(ut_cmp_equal(attr(params.in, "summary"), attr(tmp$model_params, "summary")), "summary attribute preserved") +})) + +ok_group("R model, without reporting", local({ + model <- g3_to_r(c(actions, list( ))) + params.in <- params.def + attr(params.in, "summary") <- data.frame(a = runif(1)) # Generated by g3_optim() usually + + ok(ut_cmp_error(mock_g3_fit(model, params.in$value), "params"), "Parameter list isn't allowed") + + tmp <- mock_g3_fit(model, params.in) + ok(is.data.frame(tmp$model_params), "model_params returned as data.frame") + ok(ut_cmp_equal(tmp$model_params, params.in), "model_params match params.in (i.e. report_detail hasn't leaked in)") + ok(ut_cmp_equal(attr(params.in, "summary"), attr(tmp$model_params, "summary")), "summary attribute preserved") +})) + +ok_group("tmb model, with reporting", local({ + model <- g3_to_tmb(c(actions, list( + g3a_report_detail(actions) ))) + params.in <- attr(model, "parameter_template") + for (n in params.def$switch) params.in[n,] <- params.def[n,] + attr(params.in, "summary") <- data.frame(a = runif(1)) # Generated by g3_optim() usually + + tmp <- mock_g3_fit(model, params.in) + ok(is.data.frame(tmp$model_params), "model_params returned as data.frame") + ok("report_detail" %in% rownames(tmp$model_params), "model_params has report_detail") + ok(ut_cmp_equal(tmp$model_params, params.in), "model_params match params.in") + ok(ut_cmp_equal(attr(params.in, "summary"), attr(tmp$model_params, "summary")), "summary attribute preserved") +})) + +ok_group("tmb model, without reporting", local({ + model <- g3_to_tmb(c(actions, list( ))) + params.in <- params.def + attr(params.in, "summary") <- data.frame(a = runif(1)) # Generated by g3_optim() usually + + tmp <- mock_g3_fit(model, params.in) + ok(is.data.frame(tmp$model_params), "model_params returned as data.frame") + ok(ut_cmp_equal(tmp$model_params, params.in), "model_params match params.in (i.e. report_detail hasn't leaked in)") + ok(ut_cmp_equal(attr(params.in, "summary"), attr(tmp$model_params, "summary")), "summary attribute preserved") +}))