# cache.R

#' Apply a database caching layer to an arbitrary R function.
#'
#' Requesting data from an API or performing queries can sometimes yield
#' dataframe outputs that are easily cached according to some primary key.
#' This function makes it possible to cache the output along the primary
#' key, while only using the uncached function on those records that
#' have not been computed before.
#'
#' @param uncached_function function. The function to cache.
#' @param key character. A character vector of primary keys. If \code{key} is unnamed,
#'   the user guarantees that \code{uncached_function} has these as formal arguments
#'   and that it returns a data.frame containing columns with at least those
#'   names. For example, if we are caching a function that looks like
#'   \code{function(author) { ... }}, we expect its output to be data.frames
#'   containing an \code{"author"} column with one record for each author.
#'   In this situation, \code{key = "author"}. Otherwise, if \code{key} is
#'   a named vector of length 1, the name must match the key argument of
#'   \code{uncached_function}, and the value must match at least one of the
#'   columns of the returned data.frame.
#' @param salt character. The names of the formal arguments of \code{uncached_function}
#'   for which a unique value at calltime should use a different database
#'   table. In other words, if \code{uncached_function} has arguments
#'   \code{id, x, y}, but different kinds of data.frames (i.e., ones with
#'   different types and/or column names) will be returned depending
#'   on the value of \code{x} or \code{y}, then we can set
#'   \code{salt = c("x", "y")} to use a different database table for
#'   each combination of values of \code{x} and \code{y}. For example,
#'   if \code{x} and \code{y} are only allowed to be \code{TRUE} or
#'   \code{FALSE}, with potentially four different kinds of data.frame
#'   outputs, then up to four tables would be created.
#' @param con SQLConnection or character. Database connection object, \emph{or}
#'   character path to database.yml file. In the latter case, you will have to
#'   specify an \code{env} parameter that determines the environment used for
#'   the database.yml file.
#' @param prefix character. Database table prefix. A different prefix should
#'   be used for each cached function so that there are no table collisions.
#'   Optional, but highly recommended. By default, the deparse of the
#'   \code{uncached_function} argument.
#' @param env character. The environment of the database connection if \code{con}
#'   is a yaml configuration file. By default, \code{"cache"}.
#' @param batch_size integer. Usually, the uncached operation is slow
#'   (or we would not have to cache it!). However, fetching data from the
#'   database is fast. To handle this dichotomy, the \code{batch_size}
#'   parameter gives the ability to control the chunks in which to compute
#'   and cache the uncached operation. This makes it more robust to failures,
#'   and ensures fetching of uncached data is partially stored even when
#'   errors occur midway through the process. The default is \code{100}.
#'
#'   Note that the \href{http://github.com/peterhurford/batchman}{batchman}
#'   package should be installed for batching to take effect.
#' @param safe_columns logical or function. If \code{safe_columns = TRUE}
#'   and a caching call would add additional columns to an already existing
#'   cache with already existing columns, the function will throw an error
#'   instead. If \code{safe_columns} is a function, that function will be
#'   called with no arguments and must return \code{TRUE}; this is mainly
#'   so you can write your own error message. If \code{safe_columns} is
#'   \code{FALSE}, the additional columns will be added. Defaults to
#'   \code{FALSE}.
#' @return A function with a caching layer that does not call
#'   \code{uncached_function} with already computed records, but retrieves
#'   those results from an underlying database table.
#' @export
#' @examples
#' \dontrun{
#' # These examples assume you have a database connection object
#' # (as specified in the DBI package) in a local variable `con`.
#'
#' # Imagine we have a function that returns a data.frame of information
#' # about IMDB titles through their API. It takes an integer vector of
#' # IDs and returns a data.frame with an "id" column, with one row for
#' # each title. (for example, 111161 would correspond to
#' # http://www.imdb.com/title/tt111161/ which is The Shawshank Redemption).
#' amazon_info <- function(id) {
#'   # Call external API.
#' }
#'
#' # Sending HTTP requests to Amazon and waiting for the response is
#' # computationally intensive, so if we ask for some IDs that have
#' # already been computed in the past, it would be useful to not
#' # make additional HTTP requests for those records. For example,
#' # we may want to do some processing on all Amazon titles. However,
#' # new records are created each day. Instead of parsing all
#' # the historical records on each execution, we would like to only
#' # parse new records; old records would be retrieved from a database
#' # table that had the same column names as a typical output data.frame
#' # of the `amazon_info` function.
#' cached_amazon_info <- cachemeifyoucan::cache(amazon_info, key = 'id', con = con)
#'
#' # By using the `cache` function, we are asking for the following:
#' #   (1) If we call `cached_amazon_info` with a vector of integer IDs,
#' #       take the subset of IDs that have already been returned from
#' #       a previous call to `cached_amazon_info`. Retrieve the data.frame
#' #       for these records from an underlying database table.
#' #   (2) The remaining IDs (those we have never passed to `cached_amazon_info`)
#' #       should be fed to the base `amazon_info` function as if we had
#' #       called it with this subset. This will yield another data.frame that
#' #       was computed using live HTTP requests.
#' # The `cached_amazon_info` function will return the union (rbind) of these
#' # two data sets as one single data set, as if we had called `amazon_info`
#' # by itself. It will also cache the second data set so another identical
#' # call to `cached_amazon_info` will not trigger any additional HTTP requests.
#'
#' ###
#' # Salts
#' ###
#'
#' # Imagine our `amazon_info` function is slightly more complicated:
#' # instead of always returning the same information about film titles,
#' # it has an additional parameter `type` that controls whether we
#' # want info about the filmography or about the reviews. The output
#' # of this function will still be data.frames with an `id` column
#' # and one row for each title, but the other columns can be different
#' # now depending on the `type` parameter.
#' amazon_info2 <- function(id, type = 'filmography') {
#'   if (identical(type, 'filmography')) { return(amazon_info(id)) }
#'   else { return(review_amazon_info(id)) } # Assume we have this other function
#' }
#'
#' # If we wish to cache `amazon_info2`, we need to use different underlying
#' # database tables depending on the given `type`. One table may have
#' # columns like `num_actors` or `film_length` and the other may have
#' # columns such as `num_reviews` and `avg_rating`.
#' cached_amazon_info2 <- cachemeifyoucan::cache(amazon_info2, key = 'id',
#'   salt = 'type', con = con)
#'
#' # We have told the caching layer to use the `type` parameter as the "salt".
#' # This means different values of `type` will use different underlying
#' # database tables for caching. It is up to the user to construct a
#' # function like `amazon_info2` well so that it always returns a data.frame
#' # with exactly the same column names if the `type` parameter is held fixed.
#' # The salt should usually consist of a collection of parameters (typically
#' # only one, `type` as in this example) that have a small number of possible
#' # values; otherwise, many database tables would be created for different
#' # values of the salt. Consider the following example.
#'
#' bad_amazon_filmography <- function(id, actor_id) {
#'   # Given a single actor_id and a vector of title IDs,
#'   # return information about that actor's role in the film.
#' }
#' bad_cached_amazon_filmography <-
#'   cachemeifyoucan::cache(bad_amazon_filmography, key = 'id',
#'     salt = 'actor_id', con = con)
#'
#' # We will now be creating a separate table each time we call
#' # `bad_amazon_filmography` for a different actor!
#'
#' ###
#' # Prefixes
#' ###
#'
#' # It is very important to give the function you are caching a prefix:
#' # when it is stored in the database, its table name will be the prefix
#' # combined with some string derived from the values in the salt.
#'
#' cached_review_amazon_info <- cachemeifyoucan::cache(review_amazon_info,
#'   key = 'id', con = con)
#'
#' # Remember our `review_amazon_info` function from an earlier example?
#' # If we attempted to cache it without a prefix while also caching
#' # the vanilla `amazon_info` function, the same database table would be
#' # used for both functions! Since function representation in R is complex
#' # and there is no good way in general to determine whether two functions
#' # are identical, it is up to the user to determine a good prefix for
#' # their function (usually the function's name) so that it does not clash
#' # with other database tables.
#'
#' cached_amazon_info <- cachemeifyoucan::cache(amazon_info,
#'   prefix = 'amazon_info', key = 'id', con = con)
#' cached_review_amazon_info <- cachemeifyoucan::cache(review_amazon_info,
#'   prefix = 'review_amazon_info', key = 'id', con = con)
#'
#' # We will now use different database tables for these two functions.
#'
#' ###
#' # force.
#' ###
#'
#' # `force.` is a reserved argument for the to-be-cached function. If
#' # it is specified to be `TRUE`, the caching layer will forcibly
#' # repopulate the database tables for the given ids. The default value
#' # is `FALSE`.
#'
#' cached_amazon_info <- cachemeifyoucan::cache(amazon_info,
#'   prefix = 'amazon_info', key = 'id', con = con)
#' cached_amazon_info(c(10, 20), force. = TRUE) # Will forcibly repopulate.
#'
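#' ###
#' # dry.
#' ###
#'
#' # `dry.` is the other reserved argument. A sketch of its behavior, based
#' # on the caching layer's body below: when `TRUE`, the call performs no
#' # computation and no cache writes, and instead returns the cache metadata
#' # (cached keys, uncached keys, shard names, table name) gathered by the
#' # internal `debug_info` helper.
#' cached_amazon_info(c(10, 20), dry. = TRUE)
#'
#' ###
#' # safe_columns
#' ###
#'
#' # A sketch of a custom `safe_columns` handler: passing a function lets
#' # you raise your own error when a call would add new columns to an
#' # existing cache.
#' cached_amazon_info <- cachemeifyoucan::cache(amazon_info,
#'   prefix = 'amazon_info', key = 'id', con = con,
#'   safe_columns = function() {
#'     stop("The amazon_info cache schema changed; wipe the cache and start over.")
#'   })
#'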
#' ###
#' # Advanced features
#' ###
#'
#' # We can use multiple primary keys and salts.
#' grab_sql_table <- function(table_name, year, month, dbname = 'default') {
#'   # Imagine we have some function that given a table name
#'   # and a database name returns a data.frame with aggregate
#'   # information about records created in that table from a
#'   # given year and month (e.g., ensuring each table has a
#'   # created_at column). This function will return a data.frame
#'   # with one record for each year-month pair, with at least
#'   # the columns "year" and "month".
#' }
#'
#' cached_sql_table <- cachemeifyoucan::cache(grab_sql_table,
#'   key = c('year', 'month'), salt = c('table_name', 'dbname'), con = con,
#'   prefix = 'sql_table')
#'
#' # We would like to use a separate table to cache each combination of
#' # table_name and dbname. Note that the character vectors passed into
#' # the `key` and `salt` parameters have to exactly match the names of
#' # formal arguments of the initial function, and the key must also name
#' # columns of the returned data.frame. If these do not agree,
#' # you can wrap your function. For example, if the data.frame returned
#' # has 'mth' and 'yr' columns, you could instead cache the wrapper:
#' wrap_sql_table <- function(table_name, yr, mth, dbname = 'default') {
#'   grab_sql_table(table_name = table_name, year = yr, month = mth, dbname = dbname)
#' }
#'
#' ###
#' # Debugging option `cachemeifyoucan.debug`
#' ###
#'
#' # Sometimes it might be interesting to take a look at the underlying
#' # database tables for debugging purposes. However, the contents of the
#' # database are somewhat obfuscated. If you set the `cachemeifyoucan.debug`
#' # option to `TRUE`, every time you execute a cached function you will see
#' # some additional metadata printed out, helping you navigate the database.
#' # An example output looks like this:
#' #
#' #   Using table name: amazon_data_c3204c0a47beb9238a787058d4f03834
#' #   Shard dimensions:
#' #     shard1_f8e8e2b41ac5c783d0954ce588f220fc: 45 rows * 308 columns
#' #   11 cached keys
#' #   5 uncached keys
#'
#' }
cache <- function(uncached_function, key, salt, con, prefix = deparse(uncached_function),
                  env = "cache", batch_size = 100, safe_columns = FALSE) {
  stopifnot(is.function(uncached_function),
    is.character(prefix), length(prefix) == 1,
    is.character(key), length(key) > 0,
    is.atomic(salt) || is.list(salt), is.numeric(batch_size),
    (is.logical(safe_columns) || is.function(safe_columns)))

  cached_function <- new("function")

  # Retain the same formal arguments as the base function.
  formals(cached_function) <- formals(uncached_function)

  # Check for "force." and "dry." name collisions.
  lapply(c("force.", "dry."), function(x) {
    if (x %in% names(formals(cached_function))) {
      stop(sQuote(x), " is a reserved argument in the cachemeifyoucan caching ",
        "layer and collides with a formal argument of the function being cached.",
        call. = FALSE)
    }
  })

  # Default the force. argument to FALSE.
  formals(cached_function)$force. <- FALSE

  # Default the dry. argument to FALSE.
  formals(cached_function)$dry. <- FALSE

  # Inject some values we will need in the body of the caching layer.
  environment(cached_function) <-
    list2env(list(`_prefix` = prefix, `_key` = key, `_salt` = salt
      , `_uncached_function` = uncached_function, `_con` = NULL
      , `_con_build` = c(list(con), if (!missing(env)) list(env))
      , `_env` = if (!missing(env)) env
      , `_batch_size` = batch_size
      ),
      parent = environment(uncached_function))

  build_cached_function(cached_function, safe_columns)
}

#' Fetch the uncached function
#'
#' If applied to a regular function it returns this function.
#'
#' @param fn function. The function that you want to uncache.
#' @export
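#' @examples
#' \dontrun{
#' # A minimal sketch; `amazon_info` and `con` are as in the examples for
#' # \code{\link{cache}}.
#' cached_amazon_info <- cachemeifyoucan::cache(amazon_info, key = 'id', con = con)
#' identical(uncached(cached_amazon_info), amazon_info) # TRUE
#' uncached(amazon_info) # Non-cached functions are returned as-is.
#' }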
uncached <- function(fn) {
  stopifnot(is.function(fn))
  if (is(fn, "cached_function")) {
    environment(fn)$`_uncached_function`
  } else {
    fn
  }
}

build_cached_function <- function(cached_function, safe_columns) {
  # All cached functions will have the same body.
  body(cached_function) <- quote({
    # If a user calls the uncached function with, e.g.,
    #   fn <- function(x, y) { ... }
    #   fn(1:2), fn(x = 1:2), fn(y = 5, 1:2), fn(y = 5, x = 1:2),
    # then `call` will be a list with names "x" and "y" in all situations.
    raw_call <- match.call()
    # Strip the function name but retain the arguments.
    call     <- as.list(raw_call[-1])

    # Strip away the dry. and force. parameters, which are reserved.
    is_force <- isTRUE(eval.parent(call$force.))
    is_dry <- isTRUE(eval.parent(call$dry.))
    call$force. <- NULL
    call$dry. <- NULL

    # Evaluate the call parameters in the calling environment.
    for (name in names(call)) {
      call[[name]] <- eval.parent(call[[name]])
    }

    # Only apply the salt to values provided at calltime.
    true_salt <- call[intersect(names(call), `_salt`)]

    # Since the values in `call` might be expressions, evaluate them in the
    # calling environment to get their actual values.
    for (name in names(true_salt)) {
      true_salt[[name]] <- eval.parent(true_salt[[name]])
    }

    # The database table to use is determined by the prefix and the values
    # of the salted parameters at calltime.
    tbl_name <- cachemeifyoucan:::table_name(`_prefix`, true_salt)

    # Check the database connection and reconnect if necessary.
    if (is.null(`_con`) || !dbtest::is_db_connected(`_con`)) {
      if (!is.null(`_con_build`[[1]])) {
        `_con` <<- do.call(dbtest::build_connection, `_con_build`)
      } else {
        stop("Cannot re-establish database connection (caching layer)!")
      }
    }

    # Check whether force. was set.
    if (is_force) {
      force. <- TRUE
      message("`force.` detected. Overwriting cache...\n")
    } else force. <- FALSE

    fcn_call <- cachemeifyoucan:::cached_function_call(`_uncached_function`, call,
        parent.frame(), tbl_name, `_key`, `_con`, force., `_batch_size`)

    # Grab all the keys.
    keys <- fcn_call$call[[fcn_call$key]]

    # Log cache metadata if in debug mode.
    status <- cachemeifyoucan:::debug_info(fcn_call, keys)

    if (!is_dry) cachemeifyoucan:::execute(fcn_call, keys, safe_columns) else status
  })

  class(cached_function) <- append("cached_function", class(cached_function))
  environment(cached_function)$safe_columns <- safe_columns
  cached_function
}

# A helper function to execute a cached function call.
execute <- function(fcn_call, keys, safe_columns) {

  # If some keys were populated by another process, we keep track of those
  # so that we do not have to duplicate the caching effort.
  intercepted_keys <- list2env(list(keys = integer(0)))

  compute_and_cache_data <- function(keys, overwrite = FALSE) {

    if (isTRUE(overwrite)) {
      remove_old_key(fcn_call$con, fcn_call$table, keys, fcn_call$output_key)
    }

    # Re-query which keys are not cached, since another process could have
    # populated them in parallel (if another user requested the same IDs).
    uncached_keys <- get_new_key(fcn_call$con, fcn_call$table, keys, fcn_call$output_key)
    intercepted_keys$keys <- c(intercepted_keys$keys, setdiff(keys, uncached_keys))
    keys <- uncached_keys
    if (!length(keys)) return(data.frame())
    uncached_data <- compute_uncached_data(fcn_call, keys)
    write_data_safely(fcn_call$con, fcn_call$table,
      uncached_data, fcn_call$output_key, safe_columns)
    uncached_data
  }

  if (fcn_call$force) {
    uncached_keys <- keys
  } else {
    uncached_keys <- get_new_key(fcn_call$con, fcn_call$table, keys, fcn_call$output_key)
  }

  if (length(uncached_keys) > fcn_call$batch_size &&
      requireNamespace("batchman", quietly = TRUE)) {
    batched_fn <- batchman::batch(
      compute_and_cache_data, "keys",
      size = fcn_call$batch_size,
      combination_strategy = plyr::rbind.fill,
      batchman.verbose = verbose(),
      retry = 3,
      stop = TRUE
    )
    uncached_data <- batched_fn(uncached_keys, fcn_call$force)
  } else {
    uncached_data <- compute_and_cache_data(uncached_keys, fcn_call$force)
  }

  # Since computing and caching data may take a long time, and some of the
  # keys may have been populated by a different R process (in the case of
  # parallel cache requests), we need to query now which keys are cached.
  cached_keys <- Reduce(setdiff, list(keys, uncached_keys, intercepted_keys$keys))

  # Fetch the data for the already-cached keys from the database.
  cached_data <- compute_cached_data(fcn_call, cached_keys)

  data <- unique(plyr::rbind.fill(uncached_data, cached_data))
  if (fcn_call$force) {
    # Restore column names using the existing cache columns.
    old_columns <- get_column_names_from_table(fcn_call)
    tmp_df <- setNames(data.frame(matrix(ncol = length(old_columns), nrow = 0)), old_columns)
    # rbind.fill with a 0-row data.frame will set the missing columns to NA,
    # which is just what we want.
    data <- plyr::rbind.fill(data, tmp_df)
  }
  # Have to sort to conform with the order of the requested keys; skipping
  # this has caused bugs.
  data[order(match(data[[fcn_call$output_key]], keys), na.last = NA), , drop = FALSE]
}

debug_info <- function(fcn_call, keys) {
  uncached_keys <- get_new_key(fcn_call$con, fcn_call$table, keys, fcn_call$output_key)
  cached_keys <- setdiff(keys, uncached_keys)

  shard_names <- get_shards_for_table(fcn_call$con, fcn_call$table)
  shard_info <- lapply(shard_names, function(name) {
    if (DBI::dbExistsTable(fcn_call$con, name)) {
      num_rows <- DBI::dbGetQuery(fcn_call$con, paste0("SELECT count(*) from ", name))[1, 1]
      query <- paste0("select count(column_name) from information_schema.columns where table_name='", name, "'")
      num_cols <- DBI::dbGetQuery(fcn_call$con, query)[1, 1]
      paste0('  ', name, ': ', num_rows, ' rows * ', num_cols, ' columns')
    } else {
      paste0('  ', name, ': new shard')
    }
  })

  if (isTRUE(getOption('cachemeifyoucan.debug'))) {
    msg <- paste(
      c(
        paste0("Using table name: ", fcn_call$table),
        "Shard dimensions:",
        shard_info,
        paste0(length(cached_keys)  , " cached keys"),
        paste0(length(uncached_keys), " uncached keys")
      ),
      collapse = "\n"
    )
    message(msg)
  }

  list(
    cached_keys = cached_keys,
    uncached_keys = uncached_keys,
    shard_names = shard_names,
    table_name = fcn_call$table
  )
}

compute_uncached_data <- function(fcn_call, uncached_keys) {
  error_fn(data_injector(fcn_call, uncached_keys, FALSE))
}

get_column_names_from_table <- function(fcn_call) {
  # Fetch one row from each corresponding shard, omit the id column, and
  # return a vector of column names.
  shards <- get_shards_for_table(fcn_call$con, fcn_call$table)
  lst <- lapply(shards, function(shard) {
    df <- if (DBI::dbExistsTable(fcn_call$con, shard))
      DBI::dbGetQuery(fcn_call$con, paste0("SELECT * from ", shard, " LIMIT 1"))
    else data.frame()
    as.character(setdiff(colnames(df), fcn_call$output_key))
  })
  # We don't really have to take the unique, but better safe than sorry!
  unique(c(fcn_call$output_key, translate_column_names(unlist(lst), fcn_call$con)))
}

compute_cached_data <- function(fcn_call, cached_keys) {
  error_fn(data_injector(fcn_call, cached_keys, TRUE))
}

cached_function_call <- function(fn, call, context, table, key, con, force, batch_size) {
  # TODO: (RK) Handle keys of length more than 1
  if (is.null(names(key))) {
    output_key <- key
  } else {
    output_key <- unname(key)
    key <- names(key)
  }
  structure(list(fn = fn, call = call, context = context, table = table, key = key,
                 output_key = output_key, con = con, force = force, batch_size = batch_size),
    class = 'cached_function_call')
}

data_injector <- function(fcn_call, keys, cached) {
  if (length(keys) == 0) {
    return(data.frame())
  } else if (isTRUE(cached)) {
    data_injector_cached(fcn_call, keys)
  } else {
    data_injector_uncached(fcn_call, keys)
  }
}

data_injector_uncached <- function(fcn_call, keys) {
  fcn_call$call[[fcn_call$key]] <- keys
  eval(as.call(append(fcn_call$fn, fcn_call$call)), envir = fcn_call$context)
}

data_injector_cached <- function(fcn_call, keys) {
  # Find all the shards that correspond to this function call and read the
  # data from each shard into a list of data.frames. We then merge all the
  # data.frames in the list into one and return it. Note that these
  # data.frames have different column names (that's the whole point of our
  # columnar sharding), except for the key by which we query.
  shards <- get_shards_for_table(fcn_call$con, fcn_call$table)
  lst <- lapply(shards, function(shard) read_df_from_a_shard(fcn_call, keys, shard))
  if (length(unique(vapply(lst, NROW, integer(1)))) > 1) {
    warning("cachemeifyoucan detected an integrity error: All shards should ",
      "have the same number of rows. If this is ",
      "not an error you understand, please report it to the ",
      "cachemeifyoucan developers at github.com/robertzk/cachemeifyoucan",
      call. = FALSE)
  }
  merge2(lst, fcn_call$output_key)
}

read_df_from_a_shard <- function(fcn_call, keys, shard) {
  sql <- paste("SELECT * FROM", shard, "WHERE", fcn_call$output_key, "IN (",
               paste(sanitize_sql(keys), collapse = ', '), ")")
  db2df(DBI::dbGetQuery(fcn_call$con, sql),
        fcn_call$con, fcn_call$output_key)
}

# Gotta love some method dispatch here in cachemeifyoucan.
sanitize_sql <- function(x) { UseMethod("sanitize_sql") }
sanitize_sql.numeric <- function(x) { x }
sanitize_sql.character <- function(x) {
  paste0("'", gsub("'", "\\'", x, fixed = TRUE), "'")
}
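
# For example, sanitize_sql(c(10, 20)) passes numeric keys through
# unquoted, while sanitize_sql("it's") returns "'it\\'s'" (the quote
# backslash-escaped, then the whole value single-quoted).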

#' Stop if a cached function did not return a data.frame.
#'
#' @name error_fn
#' @param data data.frame. The output of the cached function.
error_fn <- function(data) {
  if (!is.data.frame(data)) {
    stop("Function cached with cachemeifyoucan ",
         "package must return data.frame outputs", call. = FALSE)
  }
  data
}

# cachemeifyoucan-package.R

#' One of the most frustrating parts about being a data scientist
#' is waiting for data or other large downloads. This package offers a caching
#' layer for arbitrary functions that relies on a database backend.
#'
#' @name cachemeifyoucan
#' @docType package
#' @import berdie RPostgres DBI digest plyr dbtest
NULL

# db.R

#' Database table name for a given prefix and salt.
#'
#' @param prefix character. Prefix.
#' @param salt list. Salt for the table name.
#' @return the table name. This will just be the \code{prefix} followed by
#'   an underscore and the MD5 digest of the \code{salt}.
table_name <- function(prefix, salt) {
  tolower(paste0(prefix, "_", digest::digest(salt)))
}
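
# For example, table_name("amazon_info", list(type = "filmography")) yields
# "amazon_info_" followed by the 32-character MD5 digest of the salt.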

#' Fetch the map of column names.
#'
#' @param dbconn SQLConnection. A database connection.
column_names_map <- function(dbconn) {
  DBI::dbGetQuery(dbconn, "SELECT * FROM column_names")
}

#' Fetch all the shards for the given table name.
#'
#' @param dbconn SQLConnection. A database connection.
#' @param tbl_name character. The calculated table name for the function.
#' @return one or many names of the shard tables.
get_shards_for_table <- function(dbconn, tbl_name) {
  if (!DBI::dbExistsTable(dbconn, 'table_shard_map')) create_shards_table(dbconn, 'table_shard_map')
  DBI::dbGetQuery(dbconn, paste0("SELECT shard_name FROM table_shard_map where table_name='", tbl_name, "'"))$shard_name
}

#' Create the table <=> shards map.
#'
#' @rdname create_table
#' @param dbconn SQLConnection. A database connection.
#' @param tblname character. The table to be created.
create_shards_table <- function(dbconn, tblname) {
  if (DBI::dbExistsTable(dbconn, tblname)) return(TRUE)
  sql <- paste0("CREATE TABLE ", tblname, " (table_name varchar(255) NOT NULL, shard_name varchar(255) NOT NULL);")
  DBI::dbGetQuery(dbconn, sql)
  TRUE
}

#' MD5 digest of column names.
#'
#' @param raw_names character. A character vector of column names.
#' @return the character vector of hashed names.
get_hashed_names <- function(raw_names) {
  paste0('c', vapply(raw_names, digest::digest, character(1)))
}
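
# For example, get_hashed_names(c("id", "film_length")) prepends "c" to the
# MD5 digest of each column name.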

#' Translate column names using the column_names table from MD5 to raw.
#'
#' @param names character. A character vector of column names.
#' @param dbconn SQLConnection. A database connection.
translate_column_names <- function(names, dbconn) {
  name_map <- column_names_map(dbconn)
  name_map <- setNames(as.list(name_map$raw_name), name_map$hashed_name)
  vapply(names, function(name) name_map[[name]] %||% name, character(1))
}

#' Convert the raw fetched database table to a readable data frame.
#'
#' @param df data.frame. The raw fetched database table.
#' @param dbconn SQLConnection. A database connection.
#' @param key character. The identifier column of the database table.
db2df <- function(df, dbconn, key) {
  df[[key]] <- NULL
  colnames(df) <- translate_column_names(colnames(df), dbconn)
  df
}

#' Create an index on a table.
#'
#' @param dbconn SQLConnection. A database connection.
#' @param tblname character. Database table name.
#' @param key character. The column to index on.
#' @param idx_name character. The name of the index to create.
add_index <- function(dbconn, tblname, key, idx_name) {
  if (!tolower(substring(idx_name, 1, 1)) %in% letters) {
    stop(sprintf("Invalid index name '%s': must begin with an alphabetic character",
                 idx_name))
  }
  DBI::dbGetQuery(dbconn, paste0('CREATE INDEX ', idx_name, ' ON ', tblname, '(', key, ')'))
  TRUE
}

#' Try dbWriteTable repeatedly until the write succeeds.
#'
#' @param dbconn SQLConnection. A database connection.
#' @param tblname character. Database table name.
#' @param df data.frame. The data frame to insert.
#' @param append logical. Whether to append to an existing table. Defaults to \code{FALSE}.
#' @param row.names logical. Passed through to \code{DBI::dbWriteTable}.
dbWriteTableUntilSuccess <- function(dbconn, tblname, df, append = FALSE, row.names = NA) {
  if (DBI::dbExistsTable(dbconn, tblname) && !isTRUE(append)) {
    DBI::dbRemoveTable(dbconn, tblname)
  }
  if (any(is.na(df))) {
    df[, vapply(df, function(x) all(is.na(x)), logical(1))] <- as.character(NA)
  }

  repeat {
    class_map <- list(integer = 'bigint', numeric = 'double precision', factor = 'text',
                      double = 'double precision', character = 'text', logical = 'text')
    field_types <- sapply(sapply(df, class), function(klass) class_map[[klass]])
    DBI::dbWriteTable(dbconn, tblname, df, append = append,
                      row.names = row.names, field.types = field_types)
    #TODO(kirill): repeat maximum of N times
    if (!isTRUE(append)) {
      num_rows <- DBI::dbGetQuery(dbconn, paste0('SELECT COUNT(*) FROM ', tblname))
      if (num_rows == nrow(df)) break
    } else break
  }
}

# Helper utility for safe IO of a data.frame to a database connection.
#
# This function is mindful of three problems: non-existent columns, long
# column names, and sharding data.frames with too many columns.
#
# Since this is meant to be used as a helper function for caching data, we
# must take a few precautions. If certain variables are not available for
# older data but are introduced for newer data, we must be careful to
# create those columns first.
#
# Furthermore, certain column names may be longer than PostgreSQL supports.
# To circumvent this problem, this helper function stores an MD5 digest of
# each column name and maps them using the column_names helper table.
#
# By default, this function assumes any data to be written is not already
# present in the table and should be appended. If the table does not
# exist, it will be created.

#' Write data.frames to the DB, addressing the pitfalls described above.
#'
#' @param dbconn PostgreSQLConnection. The database connection.
#' @param tblname character. The table name to write the data into.
#' @param df data.frame. The data to write.
#' @param key character. The identifier column name.
#' @inheritParams cache
write_data_safely <- function(dbconn, tblname, df, key, safe_columns) {
  if (is.null(df)) return(FALSE)
  if (!is.data.frame(df)) return(FALSE)
  if (nrow(df) == 0) return(FALSE)

  if (missing(key)) {
    id_cols <- grep('(_|^)id$', colnames(df), value = TRUE)
    if (length(id_cols) == 0)
      stop("The data you are writing to the database must contain at least ",
           "one column named 'id' or ending with '_id'")
  } else {
    id_cols <- key
    if (!is.integer(df[[key]]) && is.numeric(df[[key]])) {
      # TODO: (RK) Check if coercion is possible.
      df[[key]] <- as.integer(df[[key]])
    }
  }

  write_column_names_map <- function(raw_names) {
    hashed_names <- get_hashed_names(raw_names)
    column_map <- data.frame(raw_name = raw_names, hashed_name = hashed_names)
    column_map <- column_map[!duplicated(column_map), ]

    # If we don't do this, we get really weird bugs with numeric values
    # stored as character: for example, a row with ID 100000 would be
    # stored as 1e+05, which is wrong.
    old_options <- options(scipen = 20, digits = 20)
    on.exit(options(old_options))

    # Store the map of raw to MD5'ed column names in the column_names table.
    if (!DBI::dbExistsTable(dbconn, 'column_names'))
      dbWriteTableUntilSuccess(dbconn, 'column_names', column_map, append = FALSE)
    else {
      raw_names <- DBI::dbGetQuery(dbconn, 'SELECT raw_name FROM column_names')[[1]]
      column_map <- column_map[!is.element(column_map$raw_name, raw_names), ]
      if (NROW(column_map) > 0) {
        if (isTRUE(safe_columns)) {
          stop("Safe Columns Error: Your function call is adding additional ",
            "columns to a cache that already has pre-existing columns. This ",
            "would suggest your cache is invalid and you should wipe the cache ",
            "and start over.")
        } else if (is.function(safe_columns)) {
          safe_columns()
        } else { # Write additional columns
          DBI::dbWriteTable(dbconn, 'column_names', column_map, append = TRUE, row.names = FALSE)
        }
      }
    }
    TRUE
  }

  # Input: table name and calculated shard names.
  write_table_shard_map <- function(tblname, shard_names) {
    # Example:
    #     table_name shard_name
    #   1 tblname_1  shard_1
    #   2 tblname_1  shard_2
    #   3 tblname_2  shard_3
    table_shard_map <- data.frame(table_name = rep(tblname, length(shard_names)), shard_name = shard_names)
    # As above, avoid scientific notation so numeric values stored as
    # character are written correctly.
    old_options <- options(scipen = 20, digits = 20)
    on.exit(options(old_options))

    # Store the map of logical table names to physical shards in the
    # table_shard_map table.
    if (!DBI::dbExistsTable(dbconn, 'table_shard_map')) {
      dbWriteTableUntilSuccess(dbconn, 'table_shard_map', table_shard_map, append = FALSE)
    } else {
      shards <- get_shards_for_table(dbconn, tblname)
      if (length(shards) > 0) {
        table_shard_map <- table_shard_map[table_shard_map$shard_name %nin% shards, ]
      }
      if (NROW(table_shard_map) > 0) {
        DBI::dbWriteTable(dbconn, 'table_shard_map', table_shard_map, append = TRUE, row.names = FALSE)
      }
    }
    TRUE
  }

  get_shard_names <- function(df, tblname) {
    # Two cases: the shards already exist, or they don't.

    # Fetch the existing shards.
    shards <- get_shards_for_table(dbconn, tblname)

    # Come up with new shards if needed.
    numcols <- NCOL(df)
    if (numcols == 0) return(NULL)
    numshards <- ceiling(numcols / MAX_COLUMNS_PER_SHARD)
    # All data-containing tables will start with the prefix shard<n>_.
    newshards <- paste0("shard", seq(numshards), "_", digest::digest(tblname))
    if (length(shards) > 0) {
      # Only generate new shard names for shards that don't exist!
      unique(c(shards, newshards[-seq(length(shards))]))
    } else newshards
  }

  df2shards <- function(dbconn, df, shard_names, key) {
    # Here comes the hard part: sharding strategies!
    #
    # Here is how we're going to do it: we keep the shard names in numeric
    # order, so that the first shard is the biggest in size. This way,
    # appending to a shard is trivial: if the data.frame has any columns
    # that are not yet stored in the cache, we just append them to the last
    # shard. Since we've calculated the number of shards beforehand, we
    # don't even have to worry about creating new shards if something won't
    # fit. Because it will.

    # Make sure we don't store the key in used_columns! We need it in every
    # data.frame.
    used_columns <- c()

    # We want to process our shards in numeric order prior to writing.
    # Unfortunately, sorting the names lexicographically gives
    # c("shard1", "shard10", "shard11", "shard2", ...), which is not what
    # we want, so we reconstruct the names in numeric order instead.
    suffix <- strsplit(shard_names[1], '_')[[1]][2]
    lapply(paste0('shard', seq(length(shard_names)), '_', suffix), function (shard, last, key) {
      # We need to create a map of the form
      # list(df = dataframe, shard_name = shard_name), where the dataframe
      # is a subset of the original data.frame that contains fewer columns
      # than MAX_COLUMNS_PER_SHARD. For each shard:
      #   1. Determine which columns are already stored in the shard.
      #   2. Take the subset of the data.frame that has these columns and
      #      assign it to the shard.
      #   3. See which columns are left unsaved, and add those to the last
      #      shard.
      if (shard == last) {
        # Write out the rest of the data.frame into the last shard.
        list(df = df[setdiff(colnames(df), used_columns)], shard_name = shard)
      } else {
        # If the shard is empty, write the first N columns of the
        # data.frame; otherwise, only write out those columns that already
        # exist in this shard.
        shard_exists <- DBI::dbExistsTable(dbconn, shard)
        if (isTRUE(shard_exists)) {
          one_row <- DBI::dbGetQuery(dbconn, paste0("SELECT * FROM ", shard, " LIMIT 1"))
        } else one_row <- NULL
        # Here we abuse the fact that NROW(NULL) == 0.
        if (NROW(one_row) == 0 || NCOL(one_row) == 2) {
          # This is very hacky: if we see only two columns in a shard, it
          # means we only stored the id and the hashed id, so the shard is
          # useless. In this case we should drop it and pretend this table
          # doesn't exist.
          if (NCOL(one_row) == 2) {
            DBI::dbGetQuery(dbconn, paste0("DROP TABLE ", shard))
          }
          columns <- colnames(df)
          columns <- columns[columns != key]
          columns <- setdiff(columns, used_columns)
          columns <- c(columns[1:(MAX_COLUMNS_PER_SHARD - 1)], key)
          used_columns <<- append(used_columns, columns[columns != key])
          list(df = df[columns], shard_name = shard)
        } else {
          columns <- unique(translate_column_names(colnames(one_row), dbconn))
          used_columns <<- append(used_columns, columns[columns != key])
          list(df = df[colnames(df) %in% columns], shard_name = shard)
        }
      }
    }, last = shard_names[length(shard_names)], key = key)
  }

  write_column_hashed_data <- function(df, tblname, append = TRUE) {
    # Create the mapping between original column names and their MD5
    # companions.
    write_column_names_map(colnames(df))

    # Store a copy of the ID columns (ending with '_id').
    id_cols_ix <- which(is.element(colnames(df), id_cols))
    colnames(df) <- get_hashed_names(colnames(df))
    df[, id_cols] <- df[, id_cols_ix]

    # Convert some types to character so they go into the DB properly.
    to_chars <- unname(vapply(df, function(x) is.factor(x) || is.ordered(x) || is.logical(x), logical(1)))
    df[to_chars] <- lapply(df[to_chars], as.character)

    # Write out to Postgres.
    dbWriteTableUntilSuccess(dbconn, tblname, df, row.names = FALSE, append = append)
  }

  # Use transactions!
  DBI::dbGetQuery(dbconn, 'BEGIN')
  tryCatch({
    # Find the appropriate shards for this data.frame and table name.
    shard_names <- get_shard_names(df, tblname)
    # Create references for these shards if needed.
    write_table_shard_map(tblname, shard_names)
    # Split the data.frame into the appropriate shards.
    df_shard_map <- df2shards(dbconn, df, shard_names, key)

    # Actually write the data to the database.
    lapply(df_shard_map, function(lst) {
      tblname <- lst$shard_name
      df <- lst$df
      if (!DBI::dbExistsTable(dbconn, tblname)) {
        # The shard doesn't exist yet. Let's create it and index it by key!
        write_column_hashed_data(df, tblname, append = FALSE)
        add_index(dbconn, tblname, key, paste0("idx_", digest::digest(tblname)))
        return(invisible(TRUE))
      }

      one_row <- if (DBI::dbExistsTable(dbconn, tblname)) {
        DBI::dbGetQuery(dbconn, paste("SELECT * FROM ", tblname, " LIMIT 1"))
      } else NULL
      if (NROW(one_row) == 0) {
        # The shard is empty! Delete it, then write to it. It's also a
        # great opportunity to enforce indexes on this table!
        if (DBI::dbExistsTable(dbconn, tblname))
          DBI::dbRemoveTable(dbconn, tblname)
        write_column_hashed_data(df, tblname, append = FALSE)
        add_index(dbconn, tblname, key, paste0("i", digest::digest(tblname)))
        return(invisible(TRUE))
      }

      # Columns that are missing in the database need to be created.
      new_names <- get_hashed_names(colnames(df))
      # We also keep non-hashed versions of the ID columns around for
      # convenience.
      new_names <- c(new_names, id_cols)
      missing_cols <- !is.element(new_names, colnames(one_row))
      # TODO: (RK) Check reverse, that we're not missing any already-present columns
      class_map <- list(integer = 'bigint', numeric = 'double precision', factor = 'text',
                        double = 'double precision', character = 'text', logical = 'text')
      removes <- integer(0)
      for (index in which(missing_cols)) {
        col <- new_names[index]
        if (!all(vapply(col, nchar, integer(1)) > 0))
          stop("Failed to retrieve MD5 hashed column names in write_data_safely")
        # TODO: (RK) Figure out how to filter all NA columns without wrecking
        # the tables.
        if (index > length(df)) index <- col
        sql <- paste0("ALTER TABLE ", tblname, " ADD COLUMN ",
                         col, " ", class_map[[class(df[[index]])[1]]])
        suppressWarnings(DBI::dbGetQuery(dbconn, sql))
      }

      # Columns that are missing in the data need to be set to NA.
      missing_cols <- !is.element(colnames(one_row), new_names)
      if (sum(missing_cols) > 0) {
        raw_names <- translate_column_names(colnames(one_row)[missing_cols], dbconn)
        stopifnot(is.character(raw_names))
        df[, raw_names] <- lapply(sapply(one_row[, missing_cols], class), as, object = NA)
      }

      write_column_hashed_data(df, tblname, append = TRUE)
      DBI::dbGetQuery(dbconn, 'COMMIT')
    })
    },
    warning = function(w) {
      message("An warning occured:", w)
      message("Rollback!")
      DBI::dbRollback(dbconn)
      DBI::dbGetQuery(dbconn, 'COMMIT')
    },
    error = function(e) {
      if (grepl("Safe Columns Error", conditionMessage(e))) {
        stop(e)
      } else {
        message("An error occured:", e)
        message("Rollback!")
        DBI::dbRollback(dbconn)
        DBI::dbGetQuery(dbconn, 'COMMIT')
      }
    }
  )
  invisible(TRUE)
}

#' Setdiff the given ids against those already in the database table.
#'
#' @param dbconn SQLConnection. The database connection.
#' @param tbl_name character. Database table name.
#' @param ids vector. A vector of ids.
#' @param key character. Identifier of database table.
get_new_key <- function(dbconn, tbl_name, ids, key) {
  if (length(ids) == 0) return(integer(0))
  shards <- get_shards_for_table(dbconn, tbl_name)
  # If there are no existing shards, then nothing is cached yet.
  if (length(shards) == 0) return(ids)

  if (!DBI::dbExistsTable(dbconn, shards[1])) return(ids)
  id_column_name <- get_hashed_names(key)
  # We can check only the first shard because all shards have the same keys.
  present_ids <- DBI::dbGetQuery(dbconn, paste0(
    "SELECT ", id_column_name, " FROM ", shards[1]))
  # If the table is empty, a 0-by-0 data.frame will be returned, so we must
  # be careful.
  present_ids <- if (NROW(present_ids)) present_ids[[1]] else integer(0)
  setdiff(ids, present_ids)
}

#' Remove old keys to maintain uniqueness of the key column when force-pushing.
#'
#' @param dbconn SQLConnection. The database connection.
#' @param tbl_name character. Database table name.
#' @param ids vector. A vector of ids.
#' @param key character. Identifier of database table.
remove_old_key <- function(dbconn, tbl_name, ids, key) {
  if (length(ids) == 0) return(invisible(NULL))
  id_column_name <- get_hashed_names(key)
  shards <- get_shards_for_table(dbconn, tbl_name)
  if (length(shards) == 0) return(invisible(NULL))
  # In this case, though, we need to delete from all shards to keep them
  # consistent.
  sapply(shards, function(shard) {
    DBI::dbGetQuery(dbconn, paste0(
      "DELETE FROM ", shard, " WHERE ", id_column_name, " IN (",
      paste(sanitize_sql(ids), collapse = ","), ")"))
  })
  invisible(NULL)
}

# migrations.R

#' Table migrate
#'
#' @param cached_fn cached_function. The cached function whose cache tables
#'   should be migrated.
#' @param table_list list. A list of the form
#'   \code{list("old_table_name_1" = "new_table_name_1", ...,
#'   "old_table_name_n" = "new_table_name_n")}.
#'
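#' @examples
#' \dontrun{
#' # A sketch, assuming the old and new table names have been looked up as
#' # described in the comments inside this function:
#' table_migrate(cached_fn, list("old_table_name" = "new_table_name"))
#' }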
#' @export
table_migrate <- function(cached_fn, table_list) {
  stopifnot(is(cached_fn, 'cached_function'))
  # Imagine that in the future you add one more parameter to your function
  # that you want to make part of your salt. Unfortunately, this would mean
  # that the cache table name changes and your old cache would not be used.
  #
  # Enter migrations. Since everything is stored in shards anyway, all we
  # need to do is point the new table at the old shards. The current
  # migration implementation requires you to fetch the correct table names
  # by hand, i.e., by doing the following:
  #
  #   debugonce(cachemeifyoucan:::execute)
  #   old_cached_function(...)  # Call your old cached function.
  #   fcn_call$table            # This is the old table name.
  #
  # After repeating the same steps for the new function, the inputs to this
  # migrator are the cached function itself (for now only used for getting
  # the right db connection) and a list of the form
  # list("old_table_name_1" = "new_table_name_1", ...,
  #      "old_table_name_n" = "new_table_name_n").
  dbconn <- environment(cached_fn)[['_con']]
  if (is.null(dbconn)) {
    stop('Please execute the cached function at least once in the current R session to initialize the cache db connection')
  }
  stopifnot(DBI::dbExistsTable(dbconn, 'table_shard_map'))  # Gotta have some shards first
  # For every record in the table list...
  for (tblname in names(table_list)) {
    # Get the shards corresponding to the old table.
    shards <- DBI::dbGetQuery(dbconn, paste0("SELECT shard_name FROM table_shard_map WHERE table_name='", tblname, "'"))
    # Do stuff only if some shards exist.
    if (NROW(shards) > 0) {
      shards <- shards[[1]]
      # For every corresponding shard...
      for (shard in shards) {
        # Insert a mapping from the new table to the shard.
        DBI::dbGetQuery(dbconn,
          paste0("INSERT INTO table_shard_map values ('", table_list[[tblname]], "', '", shard, "')"))
      }
    }
  }
}

# utils.R

`%||%` <- function(x, y) if (is.null(x)) y else x

`%nin%` <- Negate(`%in%`)

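# Split a vector into consecutive chunks of at most n elements; e.g.,
# slice(1:7, 3) returns list(`0` = 1:3, `1` = 4:6, `2` = 7L).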
slice <- function(x, n) split(x, as.integer((seq_along(x) - 1) / n))

verbose <- function() { isTRUE(getOption("cachemeifyoucan.verbose", FALSE)) }

# This constant determines how many columns a table can have, thus
# affecting sharding.
MAX_COLUMNS_PER_SHARD <- 550

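# Merge a list of data.frames that share only the `id_name` column; e.g.,
#   merge2(list(data.frame(id = 1:2, a = 3:4), data.frame(id = 1:2, b = 5:6)), "id")
# yields a single data.frame with columns id, a, and b.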
merge2 <- function(list_of_dataframes, id_name) {
  list_of_dataframes <- list_of_dataframes[
    order(vapply(list_of_dataframes, NROW, integer(1)), decreasing = TRUE)
  ]
  Reduce(function(x, y) {
    merge(x[c(id_name, setdiff(colnames(x), colnames(y)))], y, by = id_name, all.x = TRUE)
  }, list_of_dataframes)
}