#' Predict choices
#'
#' @description
#' This function predicts the discrete choice behaviour.
#'
#' @details
#' Each choice is predicted as the alternative with the highest predicted
#' choice probability. The prediction is based on the unrounded
#' probabilities; \code{digits} only affects the returned probabilities.
#' Ties are resolved in favour of the first alternative.
#'
#' See [the vignette on choice prediction](https://loelschlaeger.de/RprobitB/articles/v05_choice_prediction.html)
#' for a demonstration of how to visualize the model's sensitivity and
#' specificity by means of a receiver operating characteristic (ROC) curve.
#'
#' @param object
#' An object of class \code{RprobitB_fit}.
#'
#' @param data
#' Either
#' \itemize{
#'   \item \code{NULL}, using the data in \code{object},
#'   \item an object of class \code{RprobitB_data}, for example the test part
#'         generated by \code{\link{train_test}},
#'   \item or a data frame of custom choice characteristics. It must have the
#'         same structure as `choice_data` used in \code{\link{prepare_data}}.
#'         Missing columns or \code{NA} values are set to 0. Columns that do
#'         not correspond to a covariate of the model (or to \code{id}) are
#'         ignored.
#' }
#'
#' @param overview \[`logical(1)`\]\cr
#' Summarize the prediction in a confusion matrix? If the true choices are
#' not available, a frequency table of the predicted choices is returned
#' instead.
#'
#' @param digits \[`integer(1)`\]\cr
#' The number of decimal places to which the returned choice probabilities are
#' rounded. Does not affect the prediction itself.
#'
#' @param ...
#' Currently not used.
#'
#' @return
#' If \code{overview = TRUE}, a table: the confusion matrix of true versus
#' predicted choices, or the frequencies of the predicted choices if the true
#' choices are not available. Otherwise, a data frame with the (rounded) choice
#' probabilities, the predicted choice and, if available, the true choice and
#' whether the prediction is correct.
#'
#' @examples
#' set.seed(1)
#' data <- simulate_choices(form = choice ~ cov, N = 10, T = 10, J = 2)
#' data <- train_test(data, test_proportion = 0.5)
#' model <- fit_model(data$train)
#'
#' predict(model)
#' predict(model, overview = FALSE)
#' predict(model, data = data$test)
#' predict(
#'   model,
#'   data = data.frame("cov_A" = c(1, 1, NA, NA), "cov_B" = c(1, NA, 1, NA)),
#'   overview = FALSE
#' )
#'
#' @export

predict.RprobitB_fit <- function(
    object, data = NULL, overview = TRUE, digits = 2, ...
  ) {

  ### check inputs
  oeli::input_check_response(
    check = checkmate::check_flag(overview),
    var_name = "overview"
  )
  oeli::input_check_response(
    check = checkmate::check_int(digits),
    var_name = "digits"
  )

  ### choose data
  if (is.null(data)) {
    data <- object$data
  } else if (is.data.frame(data)) {
    ## build choice data with one row per decider (row-wise ids) and one column
    ## per model covariate; unmatched entries stay NA and are imputed with zero
    cov <- object$data$res_var_names$cov
    data_build <- matrix(NA_real_, nrow = nrow(data), ncol = 1 + length(cov))
    colnames(data_build) <- c("id", cov)
    data_build[, "id"] <- seq_len(nrow(data))
    ## copy matching columns; a user-supplied 'id' column overrides the
    ## row-wise ids. '[[' yields a plain vector also for tibble input.
    for (col in intersect(colnames(data), colnames(data_build))) {
      data_build[, col] <- data[[col]]
    }
    data <- prepare_data(
      form = object$data$form,
      choice_data = as.data.frame(data_build),
      re = object$data$re,
      alternatives = object$data$alternatives,
      id = "id",
      idc = NULL,
      standardize = NULL,
      impute = "zero"
    )
  }
  oeli::input_check_response(
    check = checkmate::check_class(data, "RprobitB_data"),
    var_name = "data"
  )

  ### compute choice probabilities
  choice_probs <- as.data.frame(choice_probabilities(object, data = data))

  ### predict from the unrounded probabilities, so that rounding to few
  ### 'digits' cannot create artificial ties that change the prediction;
  ### ties.method = "first" matches 'which.max' per row
  prob_matrix <- as.matrix(choice_probs[data$alternatives])
  prediction <- data$alternatives[max.col(prob_matrix, ties.method = "first")]
  prediction <- factor(prediction, levels = data$alternatives)

  ### round choice probabilities for the output
  choice_probs[data$alternatives] <- round(
    choice_probs[data$alternatives], digits = digits
  )

  ### true choices (if available), matched by value to the alternatives so
  ### that alternatives that are never chosen still form a (empty) level
  if (data$choice_available) {
    true_choices <- factor(
      data$choice_data[[data$res_var_names$choice]],
      levels = data$alternatives
    )
  }

  ### create and return output
  if (overview) {
    if (data$choice_available) {
      table(true_choices, prediction, dnn = c("true", "predicted"))
    } else {
      table(prediction, dnn = c("prediction"))
    }
  } else if (data$choice_available) {
    cbind(
      choice_probs,
      "true" = true_choices, "predicted" = prediction,
      "correct" = (true_choices == prediction)
    )
  } else {
    cbind(choice_probs, "prediction" = prediction)
  }
}
