#' Adaptive Rejection Sampling
#' \code{ars}
#'
#' Adaptive rejection sampling (ARS) is a sampling algorithm for log-concave univariate densities. Compared with plain rejection sampling, ARS is less computationally expensive because it exploits log-concavity to build envelope and squeeze functions that are refined as sampling proceeds.
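#' @details
#' The sampler works with h(x) = log f(x). Tangent lines to h at a set of
#' abscissae form a piecewise-linear upper hull (the envelope), and chords
#' between adjacent abscissae form a lower hull (the squeeze). Candidates are
#' drawn from the exponentiated envelope; the squeeze test gives a cheap first
#' acceptance check, and each candidate is then added to the abscissae so that
#' both hulls tighten around h as sampling proceeds.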
#'
#'
#' @author Kristoffer Hernandez, Leon Weingartner, Nikita Mehandru
#' @usage ars(f, N = 1000, bounds = c(-Inf, Inf), ...)
#'
#' @param f function to sample from (must fit log-concave properties)
#' @param N number of observations outputted
#' @param bounds vector of length 2 containing the lower and upper bounds for the distribution function f
#' @param ... additional arguments to pass to f (e.g. mean, sd, shape, etc.)
#'
#' @return Returns vector of samples, length N, from the distribution f(x)
#'
#' @references
#' Gilks, W. R., & Wild, P. (1992). Adaptive Rejection Sampling for Gibbs Sampling.
#' \emph{Journal of the Royal Statistical Society. Series C (Applied Statistics)}, 41(2), 337–348.
#'
#' Markou, S. (2022). \emph{Adaptive rejection sampling.} Random Walks.
#' https://random-walks.org/content/misc/ars/ars.html
#'
#' Wickham, H. (2015). \emph{R packages.} O'Reilly Media.
#' https://r-pkgs.org/testing-basics.html
#'
#'
#' @examples
#' library(ars)
#'
#' # Example 1: Sample 1500 values from a gamma distribution (Gamma(2, 5))
#' N <- 1500
#' samples <- ars(dgamma, N, bounds=c(0,Inf), shape = 2, rate = 5)
#'
#' # plot a histogram and overlay the true density to check for similarity
#' hist(samples, prob = TRUE, breaks = 50)
#' curve(dgamma(x, shape = 2, rate = 5),
#'       col = "darkblue", lwd = 2, add = TRUE, yaxt = "n")
#'
#'
#' # Example 2: Sample 1000 values from a beta distribution (Beta(2,6))
#' f <- dbeta
#' samples <- ars(f, bounds=c(0,1), shape1 = 2, shape2 = 6)
#'
#' # plot a histogram and overlay the true density to check for similarity
#' hist(samples, prob = TRUE)
#' curve(dbeta(x, shape1 = 2, shape2 = 6),
#'       col = "darkblue", lwd = 2, add = TRUE, yaxt = "n")
#'
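#'
#' # Example 3: Sample 1000 values from a standard normal distribution,
#' # relying on the default unbounded support bounds = c(-Inf, Inf)
#' samples <- ars(dnorm, mean = 0, sd = 1)
#'
#' hist(samples, prob = TRUE, breaks = 50)
#' curve(dnorm(x), col = "darkblue", lwd = 2, add = TRUE)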
#' @importFrom stats runif
#' @importFrom stats optimize
#' @importFrom numDeriv grad
#' @importFrom assertthat assert_that
#'
#' @export
ars <- function(f, N = 1000, bounds = c(-Inf, Inf), ...){

  ########## PARAMETER CHECKS ############
  assertthat::assert_that(is.numeric(bounds), is.vector(bounds))

  assertthat::assert_that(is.numeric(N), length(N) == 1, N >= 1,
                          msg = "Error: N must be a single positive number!")

  assertthat::assert_that(is.function(f),
                          msg = "Error: parameter f is not a valid function!")

  # validate the length of bounds before indexing into it
  if (length(bounds) != 2 || bounds[1] >= bounds[2]) {
    stop("ERROR: invalid bounds")
  }
  ########################################


  # h(x) = log f(x): ARS operates on the log of the target density
  h <- function(x) {
    return(log(f(x, ...)))
  }

  # numerical derivative of h, used to build the tangent-line envelope
  h_prime <- function(x) {
    return(numDeriv::grad(h, x))
  }


  samples <- c()

  # choose the initial abscissae within the bounds and evaluate h and h' there
  vars <- set_abcissa(f, h, h_prime, bounds, ...)

  xs <- vars$xs               # sorted abscissae x_1 < ... < x_k
  hs <- vars$hs               # h evaluated at the abscissae
  h_primes <- vars$h_primes   # h' evaluated at the abscissae



  while (length(samples) < N) {

    # SAMPLE STEP: draw a candidate x from the piecewise-exponential envelope
    x <- sample_step(xs, hs, h_primes, bounds)

    # upper hull (envelope) and lower hull (squeeze) evaluated at x
    ux <- u_k(x, xs, hs, h_primes, bounds)
    lx <- l_k(x, xs, hs)

    u <- runif(1)

    # h(x) is needed for the second test and for the hull update below
    hx <- h(x)

    if (u <= exp(lx - ux)) {
      # 1st rejection step (squeeze test): accept using only the hulls
      samples <- append(samples, x)
    } else if (u <= exp(hx - ux)) {
      # 2nd rejection step: accept by comparing against the true log density
      samples <- append(samples, x)
    }
    # otherwise the candidate is rejected

    # update xs, hs and h_primes so the hulls tighten around h
    xs <- append(xs, x)
    xs <- sort(xs)
    i <- which(xs == x)
    hs <- append(hs, hx, after = i - 1)
    h_primes <- append(h_primes, h_prime(x), after = i - 1)
  }


  return(samples)
}
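

# ---------------------------------------------------------------------------
# Illustrative sketch (not exported): one way the piecewise-linear upper hull
# u_k(x) of Gilks & Wild (1992) can be evaluated from the current abscissae.
# The package's actual helpers (set_abcissa, sample_step, u_k, l_k) are
# defined elsewhere and may differ; this sketch only assumes their shared
# inputs (sorted xs, hs = h(xs), h_primes = h'(xs)) and at least two abscissae.
upper_hull_sketch <- function(x, xs, hs, h_primes) {
  k <- length(xs)

  # z[j]: intersection of the tangent lines to h at xs[j] and xs[j + 1];
  # for a log-concave h these intersection points are increasing in j
  z <- (hs[-1] - hs[-k] - xs[-1] * h_primes[-1] + xs[-k] * h_primes[-k]) /
    (h_primes[-k] - h_primes[-1])

  # on (z[j - 1], z[j]) the hull is the tangent line at xs[j]
  j <- findInterval(x, z) + 1
  hs[j] + (x - xs[j]) * h_primes[j]
}
# The envelope density used in the sample step is proportional to
# exp(upper_hull_sketch(x, ...)) on each segment, which is why a candidate can
# be drawn segment-by-segment from exponential pieces (inverse-CDF sampling).
# ---------------------------------------------------------------------------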