diff --git a/R/iterate_dynamic_function.R b/R/iterate_dynamic_function.R index bca1f5c..02a570a 100644 --- a/R/iterate_dynamic_function.R +++ b/R/iterate_dynamic_function.R @@ -78,7 +78,7 @@ iterate_dynamic_function <- function( ..., parameter_is_time_varying = c(), state_limits = c(-Inf, Inf) -) { + ) { # generalise checking of inputs from iterate_matrix into functions niter <- as.integer(niter) @@ -96,6 +96,7 @@ iterate_dynamic_function <- function( "{.var initial_state} must be either a column vector, or a 3D array \\ with final dimension 1" ) + } # if this is multisite @@ -115,10 +116,21 @@ iterate_dynamic_function <- function( ) } + if (length(dots) > 1 && is.null(names(dots))) { + stop("all arguments passed to the transition function ", + "must be explicitly named", + call. = FALSE) + } + # handle time-varying parameters, sending only a slice to the function when # converting to TF for (name in parameter_is_time_varying) { - dots[[name]] <- slice_first_dim(dots[[name]], 1) + res <- slice_first_dim(dots[[name]], 1) + # if the array is 2d, transpose it so it's a column vector not a row vector + if (length(dim(res)) == 2) { + res <- t(res) + } + dots[[name]] <- res } # get index to time-varying parameters in a list @@ -205,6 +217,7 @@ as_tf_transition_function <- function (transition_function, state, iter, dots) { # tf_dots will have been added to this environment by # tf_iterate_dynamic_function + # tf_iterate_dynamic_matrix args <- list(state = state, iter = iter) do.call(tf_fun, c(args, tf_dots)) @@ -269,6 +282,13 @@ tf_iterate_dynamic_function <- function (state, batch_size, envir = greta::.internals$greta_stash) + # slice up the relevant parameters dots as needed + tf_dots <- environment(tf_transition_function)$tf_dots + for(index in parameter_is_time_varying_index) { + tf_dots[[index]] <- tf_slice_first_dim(tf_dots[[index]], iter) + } + assign("tf_dots", tf_dots, + environment(tf_transition_function)) # evaluate function to get the new state (dots have been inserted into its # environment, since TF while loops are treacherous things) new_state <- tf_transition_function(old_state, iter) @@ -400,7 +420,6 @@ tf_extract_stable_population <- function (results) { # given a greta array, tensor, or array, extract the 'element'th element on # the first dimmension, preserving all other dimensions slice_first_dim <- function(x, element) { - # if it's a vector, just return like this if (is.vector(x)) { return(x[element]) @@ -449,3 +468,6 @@ tf_slice_first_dim <- function(x, element) { x_out } + +# drop is ignored when element is a tensor. Use an alternate slicing interface +# for the tensorflow version? diff --git a/tests/testthat/helpers.R b/tests/testthat/helpers.R index 0031ac2..1055b67 100644 --- a/tests/testthat/helpers.R +++ b/tests/testthat/helpers.R @@ -117,6 +117,35 @@ r_iterate_dynamic_function <- function(transition_function, converged = as.integer(diff < tol), max_iter = i) } +# +# r_iterate_dynamic_function <- function(transition_function, +# initial_state, +# niter = 100, +# tol = 1e-6, +# ...) { +# +# states <- list(initial_state) +# +# i <- 0L +# diff <- Inf +# +# while(i < niter & diff > tol) { +# i <- i + 1L +# states[[i + 1]] <- transition_function(states[[i]], i, ...) +# growth <- states[[i + 1]] / states[[i]] +# diffs <- growth - 1 +# diff <- max(abs(diffs)) +# } +# +# all_states <- matrix(0, length(states[[1]]), niter) +# states_keep <- states[-1] +# all_states[, seq_along(states_keep)] <- t(do.call(rbind, states_keep)) +# +# list(stable_state = states[[i]], +# all_states = all_states, +# converged = as.integer(diff < tol), +# max_iter = i) +# } # a midpoint solver for use in deSolve, from the vignette p8 rk_midpoint <- deSolve::rkMethod( diff --git a/tests/testthat/test_iterate_dynamic_function.R b/tests/testthat/test_iterate_dynamic_function.R index f7acecd..cb0a0be 100644 --- a/tests/testthat/test_iterate_dynamic_function.R +++ b/tests/testthat/test_iterate_dynamic_function.R @@ -1,7 +1,6 @@ test_that("single iteration works", { skip_if_not(check_tf_version()) set.seed(2017 - 05 - 01) - n <- 4 init <- rep(1, n) niter <- 100 @@ -60,8 +59,6 @@ test_that("single iteration works", { }) - - test_that("iteration works with time-varying parameters", { skip_if_not(check_tf_version()) set.seed(2017 - 05 - 01) @@ -78,7 +75,6 @@ test_that("iteration works with time-varying parameters", { fun <- function(state, iter, x) { # make fecundity a Ricker-like function of the total population, with random - # fluctuations on each state Nt <- sum(state) K <- 100 ratio <- exp(1 - Nt / K) @@ -97,6 +93,7 @@ test_that("iteration works with time-varying parameters", { parameter_is_time_varying = "x" ) + # target_stable <- r_iterates$stable_state target_states <- r_iterates$all_states # greta version @@ -109,10 +106,24 @@ test_that("iteration works with time-varying parameters", { parameter_is_time_varying = "x" ) + # states <- iterates$all_states + stable <- iterates$stable_population states <- iterates$all_states + converged <- iterates$converged + iterations <- iterates$iterations + + greta_stable <- calculate(stable)[[1]] + difference <- abs(greta_stable - target_stable) + expect_true(all(difference < test_tol)) greta_states <- calculate(states)[[1]] difference <- abs(greta_states - target_states) expect_true(all(difference < test_tol)) + greta_converged <- calculate(converged)[[1]] + expect_true(greta_converged == 1) + + greta_iterations <- calculate(iterations)[[1]] + expect_lt(greta_iterations, niter) + })