Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
2 changes: 2 additions & 0 deletions DESCRIPTION
Original file line number Diff line number Diff line change
Expand Up @@ -25,3 +25,5 @@ Imports:
tidyselect,
modelr,
parallel
Suggests:
unittest (>= 1.4)
25 changes: 12 additions & 13 deletions R/g3_fit.R
Original file line number Diff line number Diff line change
Expand Up @@ -14,22 +14,21 @@ g3_fit <- function(model,
printatstart = 1){

## Checks
stopifnot(is.list(params))
stopifnot(is.data.frame(params)) # NB: Only accept table params, not list. The list form will likely disappear eventually
if (!(printatstart %in% c(0,1))){
stop("The printatstart argument must be '0' or '1' (class numeric)")
}

if (inherits(model, "g3_r")) {
if (is.data.frame(params)) params <- params$value
if ("report_detail" %in% names(params) && printatstart == 1) {
params$report_detail <- 1L
tmp <- attributes(model(params))
data_env <- environment(model)
} else {
tmp <- NULL
}
if ("report_detail" %in% rownames(params) && printatstart == 1) {
params["report_detail", "value"] <- 1L
tmp <- attributes(model(params))
data_env <- environment(model)
} else {
tmp <- NULL
}
} else if (inherits(model, "g3_cpp")) {
if (is.data.frame(params) && "report_detail" %in% params$switch && printatstart == 1) {
if ("report_detail" %in% params$switch && printatstart == 1) {
params['report_detail', 'value'] <- 1L
obj_fun <- gadget3::g3_tmb_adfun(model, params, type = 'Fun')
tmp <- obj_fun$report(gadget3::g3_tmb_par(params))
Expand Down Expand Up @@ -69,9 +68,9 @@ g3_fit <- function(model,
model <- gadget3::g3_to_r(re_actions)

## Run model
if (is.data.frame(params)) params <- params$value
params$report_detail <- 1L
tmp <- attributes(model(params))
r_params <- if (is.data.frame(params)) params$value else params
r_params$report_detail <- 1L
tmp <- attributes(model(r_params))
data_env <- environment(model)
}

Expand Down
113 changes: 113 additions & 0 deletions tests/test-g3_fit-modelformats.R
Original file line number Diff line number Diff line change
@@ -0,0 +1,113 @@
library(unittest)
library(gadget3)
library(gadgetutils)

library(unittest)

options('gadget3.tmb.compile_flags', c("-O0", "-g"))

# Replace function with new one, optionally returning to normal after expr
mock_functions <- function(ns, new_funcs, expr) {
assign_list <- function (ns, replacements) {
for (k in names(replacements)) {
assignInNamespace(k, replacements[[k]], ns)
}
}

# Replace temporarily, put the old ones back again
old_funcs <- structure(
lapply(names(new_funcs), function(n) getFromNamespace(n, ns)),
names = names(new_funcs))
tryCatch({
assign_list(ns, new_funcs)
expr
}, finally = {
assign_list(ns, old_funcs)
})
}

# We don't actually care what g3_fit_inner does (and it can't process such a minimal model),
# we just want to see what the tmp parameter looks like
mock_g3_fit <- function (...) {
tmp <- mock_functions('gadgetutils', list(
g3_fit_inner = function (...) list(...) ), g3_fit(...))[[1]]
return(tmp)
}

stocks <- list(
imm = g3_stock(c(species = "fish", "imm"), 1:10 * 10) |> g3s_age(1, 5) )

actions <- list(
g3a_time(2000, 2001),
g3a_age(stocks$imm),
g3a_initialconditions_normalcv(stocks$imm),
NULL )

# Generate model with some parameters to start with
model_cpp <- g3_to_tmb(actions)
attr(model_cpp, "parameter_template") |>
g3_init_val("*.init.scalar", 10, optimise = FALSE) |>
g3_init_val("*.init.#", 10, lower = 0.001, upper = 1e6) |>
g3_init_val("*.M.#", 0.5, lower = 0.1, upper = 0.8) |>
g3_init_val("init.F", 0.5, lower = 0.1, upper = 10) |>
g3_init_val("*_imm.Linf", 144.645) |>
g3_init_val("*.K", 0.3, lower = 0.04, upper = 1.2) |>
g3_init_val("*.t0", -0.8, optimise = FALSE) |>
g3_init_val("*.walpha", 0.01, optimise = FALSE) |>
g3_init_val("*.wbeta", 3, optimise = FALSE) |>
identity() -> params.def

ok_group("R model, with reporting", local({
model <- g3_to_r(c(actions, list(
g3a_report_detail(actions) )))
params.in <- params.def
params.in["report_detail", "switch"] <- "report_detail"
params.in["report_detail", "value"] <- 1
attr(params.in, "summary") <- data.frame(a = runif(1)) # Generated by g3_optim() usually

ok(ut_cmp_error(mock_g3_fit(model, params.in$value), "params"), "Parameter list isn't allowed")

tmp <- mock_g3_fit(model, params.in)
ok(is.data.frame(tmp$model_params), "model_params returned as data.frame")
ok("report_detail" %in% rownames(tmp$model_params), "model_params has report_detail")
ok(ut_cmp_equal(tmp$model_params, params.in), "model_params match params.in")
ok(ut_cmp_equal(attr(params.in, "summary"), attr(tmp$model_params, "summary")), "summary attribute preserved")
}))

ok_group("R model, without reporting", local({
model <- g3_to_r(c(actions, list( )))
params.in <- params.def
attr(params.in, "summary") <- data.frame(a = runif(1)) # Generated by g3_optim() usually

ok(ut_cmp_error(mock_g3_fit(model, params.in$value), "params"), "Parameter list isn't allowed")

tmp <- mock_g3_fit(model, params.in)
ok(is.data.frame(tmp$model_params), "model_params returned as data.frame")
ok(ut_cmp_equal(tmp$model_params, params.in), "model_params match params.in (i.e. report_detail hasn't leaked in)")
ok(ut_cmp_equal(attr(params.in, "summary"), attr(tmp$model_params, "summary")), "summary attribute preserved")
}))

ok_group("tmb model, with reporting", local({
model <- g3_to_tmb(c(actions, list(
g3a_report_detail(actions) )))
params.in <- attr(model, "parameter_template")
for (n in params.def$switch) params.in[n,] <- params.def[n,]
attr(params.in, "summary") <- data.frame(a = runif(1)) # Generated by g3_optim() usually

tmp <- mock_g3_fit(model, params.in)
ok(is.data.frame(tmp$model_params), "model_params returned as data.frame")
ok("report_detail" %in% rownames(tmp$model_params), "model_params has report_detail")
ok(ut_cmp_equal(tmp$model_params, params.in), "model_params match params.in")
ok(ut_cmp_equal(attr(params.in, "summary"), attr(tmp$model_params, "summary")), "summary attribute preserved")
}))

ok_group("tmb model, without reporting", local({
model <- g3_to_tmb(c(actions, list( )))
params.in <- params.def
attr(params.in, "summary") <- data.frame(a = runif(1)) # Generated by g3_optim() usually

tmp <- mock_g3_fit(model, params.in)
ok(is.data.frame(tmp$model_params), "model_params returned as data.frame")
ok(ut_cmp_equal(tmp$model_params, params.in), "model_params match params.in (i.e. report_detail hasn't leaked in)")
ok(ut_cmp_equal(attr(params.in, "summary"), attr(tmp$model_params, "summary")), "summary attribute preserved")
}))
Loading