#' Define a Markov Model
#'
#' Combine information on parameters, transition matrix and
#' states defined through \code{\link{define_parameters}},
#' \code{\link{define_matrix}} and
#' \code{\link{define_state}} respectively.
#'
#' This function checks whether the objects are compatible
#' in the same model (same state names...).
#'
#' @param parameters Optional. An object generated by
#'   \code{\link{define_parameters}}.
#' @param transition_matrix An object generated by
#'   \code{\link{define_matrix}}.
#' @param ... Objects generated by
#'   \code{\link{define_state}}, one per model state.
#' @param states List of states, only used by
#'   \code{define_model_} to avoid using \code{...}.
#'
#' @return An object of class \code{uneval_model} (a list
#'   containing the unevaluated parameters, matrix and
#'   states).
#'
#' @export
#'
#' @example inst/examples/example_define_model.R
define_model <- function(...,
                         parameters = define_parameters(),
                         transition_matrix = define_matrix()) {
  # Collect the state definitions supplied through `...` into a
  # state list and hand everything to the standard-evaluation
  # variant, which performs the compatibility checks.
  define_model_(
    parameters = parameters,
    transition_matrix = transition_matrix,
    states = define_state_list_(list(...))
  )
}

#' @export
#' @rdname define_model
define_model_ <- function(parameters, transition_matrix, states) {
  
  n_states <- get_state_number(states)
  
  stopifnot(
    # Either no states at all, or exactly one state per
    # row/column of the transition matrix. Scalar `||` (not the
    # vectorised `|`): the matrix order is only needed when
    # states are present.
    n_states == 0 ||
      n_states == get_matrix_order(transition_matrix),
    
    # Parameter names and state value names must not overlap.
    length(
      intersect(
        get_parameter_names(parameters),
        get_state_value_names(states)
      )
    ) == 0,
    
    # States and transition matrix must use the same set of
    # state names (order-insensitive comparison).
    identical(
      sort(get_state_names(states)),
      sort(get_state_names(transition_matrix))
    )
  )
  
  structure(
    list(
      parameters = parameters,
      transition_matrix = transition_matrix,
      states = states
    ),
    class = "uneval_model"
  )
}

#' @export
print.uneval_model <- function(x, ...) {
  # Summarise an unevaluated model by counting its parameters,
  # states and state values.
  states <- get_states(x)
  n_parm <- length(get_parameter_names(get_parameters(x)))
  n_states <- get_state_number(states)
  n_state_values <- length(get_state_value_names(states))
  
  cat(sprintf(
    "An unevaluated Markov model:

    %i parameter%s,
    %i state%s,
    %i state value%s\n",
    n_parm,
    plur(n_parm),
    n_states,
    plur(n_states),
    n_state_values,
    plur(n_state_values)
  ))
  
  # print methods return their argument invisibly so that
  # `y <- print(x)` and piping keep working.
  invisible(x)
}

#' Get Markov Model Parameters
#'
#' Works on both unevaluated and evaluated models.
#'
#' For internal use.
#'
#' @param x An \code{uneval_model} or \code{eval_model}
#'   object.
#'
#' @return An \code{uneval_parameters} or
#'   \code{eval_parameters} object, i.e. the
#'   \code{parameters} element of \code{x}.
get_parameters <- function(x) {
  UseMethod("get_parameters")
}

# Fallback method: model objects keep their parameters in the
# `parameters` list element.
get_parameters.default <- function(x) x$parameters

#' Get Markov Model Transition Matrix
#' 
#' Works on both unevaluated and evaluated models.
#' 
#' For internal use.
#' 
#' @param x An \code{uneval_model} or \code{eval_model}
#'   object.
#'   
#' @return An \code{uneval_matrix} or \code{eval_matrix}
#'   object (the \code{transition_matrix} element of
#'   \code{x}).
get_matrix <- function(x){
  UseMethod("get_matrix")
}

# Fallback method: model objects keep their transition matrix
# in the `transition_matrix` list element.
get_matrix.default <- function(x){
  x$transition_matrix
}

# Get the state definitions of a model (unevaluated or
# evaluated). For internal use.
get_states <- function(x) {
  UseMethod("get_states")
}

# Fallback method: model objects keep their states in the
# `states` list element.
get_states.default <- function(x) x$states

# Get the per-cycle state counts of a model. For internal use;
# only a method for evaluated models (class "eval_model") is
# defined here.
get_counts <- function(x) {
  UseMethod("get_counts")
}

# Evaluated models store their counts in the `counts` element.
get_counts.eval_model <- function(x) x$counts

#' Evaluate Markov Model
#' 
#' Given an unevaluated Markov Model, an initial number of 
#' individuals and a number of cycles to compute, returns the 
#' evaluated version of the objects and the count of 
#' individuals per state per model cycle.
#' 
#' \code{init} need not be integer. E.g. specifying a vector
#' of type c(A = 1, B = 0, C = 0, ...) returns the 
#' probabilities for an individual starting in state A to be
#' in each state, per cycle.
#' 
#' @param model An \code{uneval_model} object.
#' @param cycles positive integer. Number of Markov Cycles 
#'   to compute.
#' @param init numeric vector, same length as number of 
#'   model states. Number of individuals in each model state
#'   at the beginning.
#' @param method Counting method, a single string passed on
#'   to \code{\link{compute_counts}}.
#'   
#' @return An \code{eval_model} object (actually a list of 
#'   evaluated parameters, matrix, states, cycle counts and
#'   state values), with \code{init} and \code{cycles} stored
#'   as attributes.
#' 
#' @examples
#' 
#' \dontrun{
#' param <- define_parameters(
#'   a = markov_cycle + 1 * 2
#' )
#' 
#' mat <- define_matrix(
#'   1-1/a, 1/a,
#'   .1,    .9
#' )
#' 
#' mod <- define_model(
#'   parameters = param,
#'   transition_matrix = mat,
#'   A = define_state(cost = 10),
#'   B = define_state(cost = 2)
#' )
#' 
#' eval_model(
#'   mod,
#'   init = c(10, 5),
#'   cycles = 5,
#'   method = "end"
#' )
#' }
#' 
eval_model <- function(model, cycles, 
                       init, method) {
  
  # Length checks come before value checks so that a vector
  # `cycles` is reported as such; `method` is validated up
  # front instead of failing only after the (costly) parameter
  # and matrix evaluation.
  stopifnot(
    length(cycles) == 1,
    cycles > 0,
    is.numeric(init),
    all(init >= 0),
    is.character(method),
    length(method) == 1
  )
  
  parameters <- eval_parameters(get_parameters(model),
                                cycles = cycles)
  transition_matrix <- eval_matrix(get_matrix(model),
                                   parameters)
  states <- eval_state_list(get_states(model), parameters)
  
  count_table <- compute_counts(
    transition_matrix = transition_matrix,
    init = init,
    method = method
  )
  
  values <- compute_values(states, count_table)
  
  structure(
    list(
      parameters = parameters,
      transition_matrix = transition_matrix,
      states = states,
      counts = count_table,
      values = values
    ),
    class = "eval_model",
    init = init,
    cycles = cycles
  )
}

# Get the per-cycle state values table stored in an evaluated
# model (the `values` element). For internal use.
get_state_values <- function(x) x$values

#' Compute Count of Individuals in Each State per Cycle
#' 
#' Given an initial number of individuals and an evaluated 
#' transition matrix, returns the number of individuals per 
#' state per cycle.
#' 
#' Use the \code{method} argument to specify if transitions
#' are supposed to happen at the beginning or the end of
#' each cycle. Alternatively linear interpolation between 
#' cycles can be performed.
#' 
#' @param transition_matrix An \code{eval_matrix} object.
#' @param init numeric vector, same length as number of 
#'   model states. Number of individuals in each model state
#'   at the beginning.
#' @param method Counting method, one of \code{"beginning"},
#'   \code{"end"}, \code{"half-cycle"},
#'   \code{"spread-half-cycle"} or \code{"life-table"}.
#'   \code{"cycle-tree"} is recognised but not implemented;
#'   any other value is an error.
#'   
#' @return A \code{cycle_counts} object: a table with one
#'   column per state (named after the states of the
#'   transition matrix) and one row per step of the
#'   transition matrix list.
#' 
compute_counts <- function(
  transition_matrix, init,
  method
) {
  
  stopifnot(
    length(init) == get_matrix_order(transition_matrix)
  )
  
  # Left fold: init %*% M1, (init %*% M1) %*% M2, ...
  # accumulate = TRUE keeps `init` as the first element,
  # followed by the count row vector after every step.
  # NOTE(review): assumes transition_matrix is a list of
  # per-cycle square matrices -- confirm against eval_matrix().
  list_counts <- Reduce(
    "%*%",
    transition_matrix,
    init,
    accumulate = TRUE
  )
  
  # Stack the accumulated vectors as rows: row 1 is `init`,
  # row k + 1 holds the counts after the k-th step.
  res <- dplyr::as.tbl(
    as.data.frame(
      matrix(
        unlist(list_counts),
        byrow = TRUE,
        ncol = get_matrix_order(transition_matrix)
      )
    )
  )
  
  colnames(res) <- get_state_names(transition_matrix)
  
  # n0: counts at the start of each step (last row dropped);
  # n1: counts at the end of each step (the `init` row dropped).
  n0 <- res[- nrow(res), ]
  n1 <- res[-1, ]
  
  switch(
    method,
    "beginning" = {
      out <- n1
    },
    "end" = {
      out <- n0
    },
    "cycle-tree" = {
      stop("Unimplemented")
    },
    "half-cycle" = {
      # End-of-step counts, plus half of `init` on the first row
      # and half of the (current) last row added to the last row.
      # NOTE(review): with a single row the first and last rows
      # coincide, so the last-row step also scales the init / 2
      # just added -- confirm this is the intended correction.
      out <- n1
      out[1, ] <- out[1, ] + init / 2
      out[nrow(out), ] <- out[nrow(out), ] + out[nrow(out), ] / 2
    },
    "spread-half-cycle" = {
      # Per state, half of (initial + final counts) is spread
      # over the cycles in proportion to each cycle's share of
      # that state's total count (prop.table over columns).
      # NOTE(review): relies on tibble column subsetting and
      # arithmetic semantics for n1[, i] and to_add[, i].
      to_add <- (init + n1[nrow(n1), ]) / 2
      weights <- prop.table(as.matrix(n1), 2)
      out <- n1
      for (i in seq_len(ncol(n1))) {
        out[, i] <- n1[, i] + weights[, i] * to_add[, i]
      }
    },
    "life-table" = {
      # Linear interpolation: mean of start and end counts.
      out <- (n0 + n1) / 2
    },
    {
      stop(sprintf("Unknown counting method, '%s'.", method))
    }
  )
  
  # Prepend "cycle_counts" while keeping the table classes.
  structure(out, class = c("cycle_counts", class(out)))
  
}


#' Compute State Values per Cycle
#' 
#' Given states and counts, computes the total state values
#' per cycle: for each state value, the sum over states of the
#' state count times that state's value.
#' 
#' @param states An object of class \code{eval_state_list}.
#' @param counts An object of class \code{cycle_counts}.
#'   
#' @return A data.frame of state values, one column per
#'   state value and one row per cycle, preceded by a
#'   \code{markov_cycle} column taken from the first state.
#'   
compute_values <- function(states, counts) {
  
  states_names <- get_state_names(states)
  state_values_names <- get_state_value_names(states)
  
  res <- data.frame(
    markov_cycle = states[[1]]$markov_cycle
  )
  
  # Accumulate on plain numeric vectors: the previous version did
  # data.frame arithmetic (Ops.data.frame) and `[<-.data.frame`
  # on every inner iteration, which made this the bottleneck.
  # Summation order over states is unchanged.
  for (state_value in state_values_names) {
    total <- 0
    for (state in states_names) {
      total <- total +
        counts[[state]] * states[[state]][[state_value]]
    }
    res[[state_value]] <- total
  }
  res
}

# State value names of an unevaluated model, delegated to its
# state list.
get_state_value_names.uneval_model <- function(x) {
  model_states <- get_states(x)
  get_state_value_names(model_states)
}

# State names of an unevaluated model, delegated to its state
# list.
get_state_names.uneval_model <- function(x, ...) {
  model_states <- get_states(x)
  get_state_names(model_states)
}

#' Plot Results of Markov Model
#' 
#' Various plots for Markov models.
#' 
#' \code{type = "counts"} represents state memberships 
#' (corrected) by cycle, \code{type = "ce"} plots models on
#' the cost-effectiveness plane with the efficiency frontier.
#' 
#' @param x Result from \code{\link{run_models}}.
#' @param type Type of plot, see details.
#' @param model Name or position of model of interest.
#' @param ... Additional arguments, currently ignored.
#'   
#' @return A \code{ggplot2} object.
#' @export
#' 
plot.eval_model_list <- function(x, type = c("counts", "ce"), model = 1, ...) {
  type <- match.arg(type)
  
  switch(
    type,
    counts = {
      # Number the cycles with base R: an unqualified row_number()
      # inside dplyr::mutate() only resolves when dplyr is attached.
      tab_counts <- as.data.frame(
        get_counts(attr(x, "eval_model_list")[[model]])
      )
      tab_counts$markov_cycle <- seq_len(nrow(tab_counts))
      
      pos_cycle <- pretty(seq_len(nrow(tab_counts)),
                          n = min(nrow(tab_counts), 10))
      
      # Long format: one row per (cycle, state). The columns to
      # stack are selected positionally through `...`; naming that
      # argument `...` may be interpreted by tidyselect as a rename.
      tab_counts <- tidyr::gather(
        tab_counts,
        key = "key",
        value = "value",
        - markov_cycle
      )
      
      y_max <- max(attr(x, "init"), tab_counts$value)
      ggplot2::ggplot(tab_counts, ggplot2::aes(markov_cycle, value, colour = key)) +
        ggplot2::geom_line() +
        ggplot2::geom_point() +
        ggplot2::scale_x_continuous(breaks = pos_cycle) +
        ggplot2::xlab("Markov cycle") +
        ggplot2::ylab("Count") +
        ggplot2::scale_colour_hue(name = "State") +
        ggplot2::ylim(0, y_max)
    },
    ce = {
      tab_ce <- normalize_ce(x)
      ef <- get_frontier(x)
      
      # The line joins only the models on the efficiency frontier.
      ggplot2::ggplot(tab_ce,
                      ggplot2::aes(x = .effect, y = .cost, label = .model_names)) +
        ggplot2::geom_line(data = tab_ce[tab_ce$.model_names %in% ef, ]) +
        ggplot2::geom_point() +
        ggplot2::geom_text(hjust = 1) +
        ggplot2::xlab("Effect") +
        ggplot2::ylab("Cost")
    },
    stop(sprintf("Unknown type: '%s'.", type))
  )
}
# Names used through non-standard evaluation above (dplyr, tidyr,
# ggplot2 aes); declaring them silences R CMD check's "no visible
# binding for global variable" notes. The version guard skips the
# call on R older than 2.15.1.
if (getRversion() >= "2.15.1") {
  utils::globalVariables(c(
    "row_number", "markov_cycle", "value", "key",
    ".cost", ".effect", ".model_names"
  ))
}
