Skip to content
Closed
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
28 changes: 25 additions & 3 deletions R/iterate_dynamic_function.R
Original file line number Diff line number Diff line change
Expand Up @@ -78,7 +78,7 @@ iterate_dynamic_function <- function(
...,
parameter_is_time_varying = c(),
state_limits = c(-Inf, Inf)
) {
) {

# generalise checking of inputs from iterate_matrix into functions
niter <- as.integer(niter)
Expand All @@ -96,6 +96,7 @@ iterate_dynamic_function <- function(
"{.var initial_state} must be either a column vector, or a 3D array \\
with final dimension 1"
)

}

# if this is multisite
Expand All @@ -115,10 +116,21 @@ iterate_dynamic_function <- function(
)
}

if (length(dots) > 1 && is.null(names(dots))) {
stop("all arguments passed to the transition function ",
"must be explicitly named",
call. = FALSE)
}

# handle time-varying parameters, sending only a slice to the function when
# converting to TF
for (name in parameter_is_time_varying) {
dots[[name]] <- slice_first_dim(dots[[name]], 1)
res <- slice_first_dim(dots[[name]], 1)
# if the array is 2d, transpose it so it's a column vector not a row vector
if (length(dim(res)) == 2) {
res <- t(res)
}
dots[[name]] <- res
}

# get index to time-varying parameters in a list
Expand Down Expand Up @@ -205,6 +217,7 @@ as_tf_transition_function <- function (transition_function, state, iter, dots) {

# tf_dots will have been added to this environment by
# tf_iterate_dynamic_function
# tf_iterate_dynamic_matrix
args <- list(state = state, iter = iter)
do.call(tf_fun, c(args, tf_dots))

Expand Down Expand Up @@ -269,6 +282,13 @@ tf_iterate_dynamic_function <- function (state,
batch_size,
envir = greta::.internals$greta_stash)

# slice up the relevant parameters dots as needed
tf_dots <- environment(tf_transition_function)$tf_dots
for(index in parameter_is_time_varying_index) {
tf_dots[[index]] <- tf_slice_first_dim(tf_dots[[index]], iter)
}
assign("tf_dots", tf_dots,
environment(tf_transition_function))
# evaluate function to get the new state (dots have been inserted into its
# environment, since TF while loops are treacherous things)
new_state <- tf_transition_function(old_state, iter)
Expand Down Expand Up @@ -400,7 +420,6 @@ tf_extract_stable_population <- function (results) {
# given a greta array, tensor, or array, extract the 'element'th element on
# the first dimmension, preserving all other dimensions
slice_first_dim <- function(x, element) {

# if it's a vector, just return like this
if (is.vector(x)) {
return(x[element])
Expand Down Expand Up @@ -449,3 +468,6 @@ tf_slice_first_dim <- function(x, element) {
x_out

}

# drop is ignored when element is a tensor. Use an alternate slicing interface
# for the tensorflow version?
29 changes: 29 additions & 0 deletions tests/testthat/helpers.R
Original file line number Diff line number Diff line change
Expand Up @@ -117,6 +117,35 @@ r_iterate_dynamic_function <- function(transition_function,
converged = as.integer(diff < tol),
max_iter = i)
}
#
# r_iterate_dynamic_function <- function(transition_function,
# initial_state,
# niter = 100,
# tol = 1e-6,
# ...) {
#
# states <- list(initial_state)
#
# i <- 0L
# diff <- Inf
#
# while(i < niter & diff > tol) {
# i <- i + 1L
# states[[i + 1]] <- transition_function(states[[i]], i, ...)
# growth <- states[[i + 1]] / states[[i]]
# diffs <- growth - 1
# diff <- max(abs(diffs))
# }
#
# all_states <- matrix(0, length(states[[1]]), niter)
# states_keep <- states[-1]
# all_states[, seq_along(states_keep)] <- t(do.call(rbind, states_keep))
#
# list(stable_state = states[[i]],
# all_states = all_states,
# converged = as.integer(diff < tol),
# max_iter = i)
# }

# a midpoint solver for use in deSolve, from the vignette p8
rk_midpoint <- deSolve::rkMethod(
Expand Down
19 changes: 15 additions & 4 deletions tests/testthat/test_iterate_dynamic_function.R
Original file line number Diff line number Diff line change
@@ -1,7 +1,6 @@
test_that("single iteration works", {
skip_if_not(check_tf_version())
set.seed(2017 - 05 - 01)

n <- 4
init <- rep(1, n)
niter <- 100
Expand Down Expand Up @@ -60,8 +59,6 @@ test_that("single iteration works", {

})



test_that("iteration works with time-varying parameters", {
skip_if_not(check_tf_version())
set.seed(2017 - 05 - 01)
Expand All @@ -78,7 +75,6 @@ test_that("iteration works with time-varying parameters", {
fun <- function(state, iter, x) {

# make fecundity a Ricker-like function of the total population, with random
# fluctuations on each state
Nt <- sum(state)
K <- 100
ratio <- exp(1 - Nt / K)
Expand All @@ -97,6 +93,7 @@ test_that("iteration works with time-varying parameters", {
parameter_is_time_varying = "x"
)

# target_stable <- r_iterates$stable_state
target_states <- r_iterates$all_states

# greta version
Expand All @@ -109,10 +106,24 @@ test_that("iteration works with time-varying parameters", {
parameter_is_time_varying = "x"
)

# states <- iterates$all_states
stable <- iterates$stable_population
states <- iterates$all_states
converged <- iterates$converged
iterations <- iterates$iterations

greta_stable <- calculate(stable)[[1]]
difference <- abs(greta_stable - target_stable)
expect_true(all(difference < test_tol))

greta_states <- calculate(states)[[1]]
difference <- abs(greta_states - target_states)
expect_true(all(difference < test_tol))

greta_converged <- calculate(converged)[[1]]
expect_true(greta_converged == 1)

greta_iterations <- calculate(iterations)[[1]]
expect_lt(greta_iterations, niter)

})
Loading