From 7164f46bc771fed700ce686e4701b0f9c933fad6 Mon Sep 17 00:00:00 2001 From: Nick Golding Date: Sun, 27 Nov 2022 21:36:32 +0000 Subject: [PATCH 1/5] fix warning on unnamed arguments to iterate_dynamic_function --- R/iterate_dynamic_function.R | 6 ++++++ 1 file changed, 6 insertions(+) diff --git a/R/iterate_dynamic_function.R b/R/iterate_dynamic_function.R index bca1f5c..216de7c 100644 --- a/R/iterate_dynamic_function.R +++ b/R/iterate_dynamic_function.R @@ -115,6 +115,12 @@ iterate_dynamic_function <- function( ) } + if (length(dots) > 1 && is.null(names(dots))) { + stop("all arguments passed to the transition function ", + "must be explicitly named", + call. = FALSE) + } + # handle time-varying parameters, sending only a slice to the function when # converting to TF for (name in parameter_is_time_varying) { From 94f54e08730c09bb986f54eb0140f9cda52ef544 Mon Sep 17 00:00:00 2001 From: Nick Golding Date: Fri, 25 Nov 2022 11:38:32 +0000 Subject: [PATCH 2/5] fix tests --- R/iterate_dynamic_function.R | 7 ++++++- 1 file changed, 6 insertions(+), 1 deletion(-) diff --git a/R/iterate_dynamic_function.R b/R/iterate_dynamic_function.R index 216de7c..62299b7 100644 --- a/R/iterate_dynamic_function.R +++ b/R/iterate_dynamic_function.R @@ -124,7 +124,12 @@ iterate_dynamic_function <- function( # handle time-varying parameters, sending only a slice to the function when # converting to TF for (name in parameter_is_time_varying) { - dots[[name]] <- slice_first_dim(dots[[name]], 1) + res <- slice_first_dim(dots[[name]], 1) + # if the array is 2d, transpose it so it's a column vector not a row vector + if (length(dim(res)) == 2) { + res <- t(res) + } + dots[[name]] <- res } # get index to time-varying parameters in a list From 59eb35caa6392e6e8530fd2582cd5bf6eb08d481 Mon Sep 17 00:00:00 2001 From: Nick Golding Date: Fri, 25 Nov 2022 11:16:48 +0000 Subject: [PATCH 3/5] some changes from a cherry pick commit --- R/iterate_dynamic_function.R | 12 +++++++++++- tests/testthat/test_iterate_dynamic_function.R | 16 +++++++++++++++- 2 files changed, 26 insertions(+), 2 deletions(-) diff --git a/R/iterate_dynamic_function.R b/R/iterate_dynamic_function.R index 62299b7..adac90c 100644 --- a/R/iterate_dynamic_function.R +++ b/R/iterate_dynamic_function.R @@ -280,6 +280,14 @@ tf_iterate_dynamic_function <- function (state, batch_size, envir = greta::.internals$greta_stash) + # slice up the relevant parameters dots as needed + tf_dots <- environment(tf_transition_function)$tf_dots + for(index in parameter_is_time_varying_index) { + tf_dots[[index]] <- tf_slice_first_dim(tf_dots[[index]], iter) + } + assign("tf_dots", tf_dots, + environment(tf_transition_function)) + # evaluate function to get the new state (dots have been inserted into its # environment, since TF while loops are treacherous things) new_state <- tf_transition_function(old_state, iter) @@ -411,7 +419,6 @@ tf_extract_stable_population <- function (results) { # given a greta array, tensor, or array, extract the 'element'th element on # the first dimmension, preserving all other dimensions slice_first_dim <- function(x, element) { - # if it's a vector, just return like this if (is.vector(x)) { return(x[element]) @@ -460,3 +467,6 @@ tf_slice_first_dim <- function(x, element) { x_out } + +# drop is ignored when element is a tensor. Use an alternate slicing interface +# for the tensorflow version? diff --git a/tests/testthat/test_iterate_dynamic_function.R b/tests/testthat/test_iterate_dynamic_function.R index f7acecd..18bb340 100644 --- a/tests/testthat/test_iterate_dynamic_function.R +++ b/tests/testthat/test_iterate_dynamic_function.R @@ -78,7 +78,6 @@ test_that("iteration works with time-varying parameters", { fun <- function(state, iter, x) { # make fecundity a Ricker-like function of the total population, with random - # fluctuations on each state Nt <- sum(state) K <- 100 ratio <- exp(1 - Nt / K) @@ -97,6 +96,7 @@ test_that("iteration works with time-varying parameters", { parameter_is_time_varying = "x" ) + # target_stable <- r_iterates$stable_state target_states <- r_iterates$all_states # greta version @@ -109,10 +109,24 @@ test_that("iteration works with time-varying parameters", { parameter_is_time_varying = "x" ) + # states <- iterates$all_states + stable <- iterates$stable_population states <- iterates$all_states + converged <- iterates$converged + iterations <- iterates$iterations + + greta_stable <- calculate(stable)[[1]] + difference <- abs(greta_stable - target_stable) + expect_true(all(difference < test_tol)) greta_states <- calculate(states)[[1]] difference <- abs(greta_states - target_states) expect_true(all(difference < test_tol)) + greta_converged <- calculate(converged)[[1]] + expect_true(greta_converged == 1) + + greta_iterations <- calculate(iterations)[[1]] + expect_lt(greta_iterations, niter) + }) From 17ee6d30927efb3aeb7c3a46e178e901188b8245 Mon Sep 17 00:00:00 2001 From: Nick Golding Date: Thu, 24 Nov 2022 16:31:45 +0000 Subject: [PATCH 4/5] various from cherry picked commits --- R/iterate_dynamic_function.R | 5 ++-- tests/testthat/helpers.R | 25 +++++++++++++++++++ .../testthat/test_iterate_dynamic_function.R | 3 --- 3 files changed, 28 insertions(+), 5 deletions(-) diff --git a/R/iterate_dynamic_function.R b/R/iterate_dynamic_function.R index adac90c..02a570a 100644 --- a/R/iterate_dynamic_function.R +++ b/R/iterate_dynamic_function.R @@ -78,7 +78,7 @@ iterate_dynamic_function <- function( ..., parameter_is_time_varying = c(), state_limits = c(-Inf, Inf) -) { + ) { # generalise checking of inputs from iterate_matrix into functions niter <- as.integer(niter) @@ -96,6 +96,7 @@ iterate_dynamic_function <- function( "{.var initial_state} must be either a column vector, or a 3D array \\ with final dimension 1" ) + } # if this is multisite @@ -216,6 +217,7 @@ as_tf_transition_function <- function (transition_function, state, iter, dots) { # tf_dots will have been added to this environment by # tf_iterate_dynamic_function + # tf_iterate_dynamic_matrix args <- list(state = state, iter = iter) do.call(tf_fun, c(args, tf_dots)) @@ -287,7 +289,6 @@ tf_iterate_dynamic_function <- function (state, } assign("tf_dots", tf_dots, environment(tf_transition_function)) - # evaluate function to get the new state (dots have been inserted into its # environment, since TF while loops are treacherous things) new_state <- tf_transition_function(old_state, iter) diff --git a/tests/testthat/helpers.R b/tests/testthat/helpers.R index 0031ac2..ef7dea1 100644 --- a/tests/testthat/helpers.R +++ b/tests/testthat/helpers.R @@ -118,6 +118,31 @@ r_iterate_dynamic_function <- function(transition_function, max_iter = i) } +r_iterate_dynamic_function <- function(transition_function, initial_state, niter = 100, tol = 1e-6, ...) { + + states <- list(initial_state) + + i <- 0L + diff <- Inf + + while(i < niter & diff > tol) { + i <- i + 1L + states[[i + 1]] <- transition_function(states[[i]], i, ...) + growth <- states[[i + 1]] / states[[i]] + diffs <- growth - 1 + diff <- max(abs(diffs)) + } + + all_states <- matrix(0, length(states[[1]]), niter) + states_keep <- states[-1] + all_states[, seq_along(states_keep)] <- t(do.call(rbind, states_keep)) + + list(stable_state = states[[i]], + all_states = all_states, + converged = as.integer(diff < tol), + max_iter = i) +} + # a midpoint solver for use in deSolve, from the vignette p8 rk_midpoint <- deSolve::rkMethod( ID = "midpoint", diff --git a/tests/testthat/test_iterate_dynamic_function.R b/tests/testthat/test_iterate_dynamic_function.R index 18bb340..cb0a0be 100644 --- a/tests/testthat/test_iterate_dynamic_function.R +++ b/tests/testthat/test_iterate_dynamic_function.R @@ -1,7 +1,6 @@ test_that("single iteration works", { skip_if_not(check_tf_version()) set.seed(2017 - 05 - 01) - n <- 4 init <- rep(1, n) niter <- 100 @@ -60,8 +59,6 @@ test_that("single iteration works", { }) - - test_that("iteration works with time-varying parameters", { skip_if_not(check_tf_version()) set.seed(2017 - 05 - 01) From 29b09738554b1659bad65e050af8257bb9018ef0 Mon Sep 17 00:00:00 2001 From: njtierney Date: Mon, 9 Dec 2024 16:34:19 +1100 Subject: [PATCH 5/5] remove duplicated/overloaded r_iterate_dynamic_function --- tests/testthat/helpers.R | 54 +++++++++++++++++++++------------------- 1 file changed, 29 insertions(+), 25 deletions(-) diff --git a/tests/testthat/helpers.R b/tests/testthat/helpers.R index ef7dea1..1055b67 100644 --- a/tests/testthat/helpers.R +++ b/tests/testthat/helpers.R @@ -117,31 +117,35 @@ r_iterate_dynamic_function <- function(transition_function, converged = as.integer(diff < tol), max_iter = i) } - -r_iterate_dynamic_function <- function(transition_function, initial_state, niter = 100, tol = 1e-6, ...) { - - states <- list(initial_state) - - i <- 0L - diff <- Inf - - while(i < niter & diff > tol) { - i <- i + 1L - states[[i + 1]] <- transition_function(states[[i]], i, ...) - growth <- states[[i + 1]] / states[[i]] - diffs <- growth - 1 - diff <- max(abs(diffs)) - } - - all_states <- matrix(0, length(states[[1]]), niter) - states_keep <- states[-1] - all_states[, seq_along(states_keep)] <- t(do.call(rbind, states_keep)) - - list(stable_state = states[[i]], - all_states = all_states, - converged = as.integer(diff < tol), - max_iter = i) -} +# +# r_iterate_dynamic_function <- function(transition_function, +# initial_state, +# niter = 100, +# tol = 1e-6, +# ...) { +# +# states <- list(initial_state) +# +# i <- 0L +# diff <- Inf +# +# while(i < niter & diff > tol) { +# i <- i + 1L +# states[[i + 1]] <- transition_function(states[[i]], i, ...) +# growth <- states[[i + 1]] / states[[i]] +# diffs <- growth - 1 +# diff <- max(abs(diffs)) +# } +# +# all_states <- matrix(0, length(states[[1]]), niter) +# states_keep <- states[-1] +# all_states[, seq_along(states_keep)] <- t(do.call(rbind, states_keep)) +# +# list(stable_state = states[[i]], +# all_states = all_states, +# converged = as.integer(diff < tol), +# max_iter = i) +# } # a midpoint solver for use in deSolve, from the vignette p8 rk_midpoint <- deSolve::rkMethod(