#' @template surv_learner
#' @templateVar title Gradient Boosting for Generalized Additive Models
#' @templateVar fullname LearnerSurvMboost
#' @templateVar caller [mboost::mboost()]
#' @templateVar distr by [mboost::survFit()] which assumes a PH fit with a Breslow estimator
#' @templateVar lp by [mboost::predict.mboost()]
#'
#' @template learner_boost
#' @description
#' The only difference between [LearnerSurvGamboost] and [LearnerSurvMboost] is that the former
#' allows one to specify default degrees of freedom for smooth effects specified via
#' \code{baselearner = "bbs"}. In all other cases, degrees of freedom need to be set manually via a
#' specific definition of the corresponding base-learner.
#'
#' @references
#' \cite{mlr3proba}{buehlmann_2003}
#'
#' \cite{mlr3proba}{buehlmann_2007}
#'
#' \cite{mlr3proba}{kneib_2008}
#'
#' \cite{mlr3proba}{schmid_2008}
#'
#' \cite{mlr3proba}{hothorn_2010}
#'
#' \cite{mlr3proba}{hofner_2012}
#'
#' @export
#' @examples
#' library(mlr3)
#' task = tgen("simsurv")$generate(20)
#' learner = lrn("surv.mboost")
#' learner$param_set$values = mlr3misc::insert_named(
#'   learner$param_set$values,
#'   list(center = TRUE, baselearner = "bols"))
#' resampling = rsmp("cv", folds = 2)
#' resample(task, learner, resampling)
LearnerSurvMboost = R6Class("LearnerSurvMboost",
  inherit = LearnerSurv,
  public = list(
    #' @description
    #' Creates a new instance of this [R6][R6::R6Class] class.
    initialize = function() {
      ps = ParamSet$new(
        params = list(
          ParamFct$new(
            id = "family", default = "coxph",
            levels = c(
              "coxph", "weibull", "loglog", "lognormal", "gehan", "cindex",
              "custom"), tags = c("train", "family")),
          ParamUty$new(id = "custom.family", tags = c("train", "family")),
          ParamUty$new(id = "nuirange", default = c(0, 100), tags = c("train", "aft")),
          ParamUty$new(id = "offset", tags = "train"),
          ParamLgl$new(id = "center", default = TRUE, tags = "train"),
          ParamInt$new(id = "mstop", default = 100L, lower = 0L, tags = "train"),
          ParamDbl$new(id = "nu", default = 0.1, lower = 0, upper = 1, tags = "train"),
          ParamFct$new(id = "risk", levels = c("inbag", "oobag", "none"), tags = "train"),
          ParamLgl$new(id = "stopintern", default = FALSE, tags = "train"),
          ParamLgl$new(id = "trace", default = FALSE, tags = "train"),
          ParamUty$new(id = "oobweights", tags = "train"),
          ParamFct$new(
            id = "baselearner", default = "bbs",
            levels = c("bbs", "bols", "btree"), tags = "train"),
          ParamDbl$new(
            id = "sigma", default = 0.1, lower = 0, upper = 1,
            tags = c("train", "cindex")),
          ParamUty$new(id = "ipcw", default = 1, tags = c("train", "cindex"))
        )
      )

      ps$values = list(family = "coxph")
      # sigma and ipcw only apply to the Cindex family
      ps$add_dep("sigma", "family", CondEqual$new("cindex"))
      ps$add_dep("ipcw", "family", CondEqual$new("cindex"))

      super$initialize(
        id = "surv.mboost",
        param_set = ps,
        feature_types = c("integer", "numeric", "factor", "logical"),
        predict_types = c("distr", "crank", "lp", "response"),
        properties = c("weights", "importance", "selected_features"),
        packages = c("mboost", "distr6", "survival")
      )
    },

    #' @description
    #' The importance scores are extracted with the function [mboost::varimp()] with the
    #' default arguments.
    #' @return Named `numeric()`.
    importance = function() {
      if (is.null(self$model)) {
        stopf("No model stored")
      }

      vimp = as.numeric(mboost::varimp(self$model))
      names(vimp) = unname(variable.names(self$model))

      sort(vimp, decreasing = TRUE)
    },

    #' @description
    #' Selected features are extracted with the function [mboost::variable.names.mboost()], with
    #' `usedonly = TRUE`.
    #' @return `character()`.
    selected_features = function() {
      if (is.null(self$model)) {
        stopf("No model stored")
      }

      unname(variable.names(self$model, usedonly = TRUE))
    }
  ),

  private = list(
    # Fits mboost::mboost() on all task features.
    # - Parameters tagged "train" whose names are accepted by mboost::boost_control() are
    #   bundled into a single `control` argument.
    # - The mboost family object is built from `family` together with the "aft"- or
    #   "cindex"-tagged parameters (or taken from `custom.family`).
    # - Everything else is forwarded to mboost() unchanged.
    .train = function(task) {

      pars = self$param_set$get_values(tags = "train")

      if ("weights" %in% task$properties) {
        pars$weights = task$weights$weight
      }

      # The default boost_control() output is only used to look up which parameter
      # names it accepts.
      ctrl_names = names(mboost::boost_control())
      is_ctrl_pars = names(pars) %in% ctrl_names

      # Split first, then attach `control`. Appending `control` before subsetting with
      # the shorter logical index would recycle that index onto `control` and could drop it.
      if (any(is_ctrl_pars)) {
        ctrl_pars = pars[is_ctrl_pars]
        pars = pars[!is_ctrl_pars]
        pars$control = do.call(mboost::boost_control, ctrl_pars)
      }

      family = switch(pars$family,
        coxph = mboost::CoxPH(),
        weibull = mlr3misc::invoke(mboost::Weibull,
          .args = self$param_set$get_values(tags = "aft")),
        loglog = mlr3misc::invoke(mboost::Loglog,
          .args = self$param_set$get_values(tags = "aft")),
        lognormal = mlr3misc::invoke(mboost::Lognormal,
          .args = self$param_set$get_values(tags = "aft")),
        gehan = mboost::Gehan(),
        cindex = mlr3misc::invoke(mboost::Cindex,
          .args = self$param_set$get_values(tags = "cindex")),
        custom = pars$custom.family
      )

      # The family-building parameters were consumed above. mboost() receives the
      # `family` object instead, so drop them by name. Filtering by value would also
      # remove unrelated parameters that happen to share a value.
      consumed = unlist(lapply(c("aft", "cindex", "family"), function(tag) {
        names(self$param_set$get_values(tags = tag))
      }))
      pars = pars[!(names(pars) %in% consumed)]

      with_package("mboost", {
        mlr3misc::invoke(mboost::mboost,
          formula = task$formula(task$feature_names),
          data = task$data(), family = family, .args = pars)
      })
    },

    # Predicts lp (link scale), a distr6 WeightedDiscrete survival distribution per
    # observation built from mboost::survFit(), and, for selected families, a response.
    .predict = function(task) {

      newdata = task$data(cols = task$feature_names)
      # predict linear predictor
      lp = as.numeric(mlr3misc::invoke(predict, self$model, newdata = newdata, type = "link"))

      # predict survival; the result's survival matrix is read column-wise, one column
      # per row of newdata (see the loop below)
      surv = mlr3misc::invoke(mboost::survFit, self$model, newdata = newdata)
      surv$cdf = 1 - surv$surv

      # define WeightedDiscrete distr6 object from predicted survival function;
      # every observation shares the survFit time grid `surv$time`
      x = rep(list(data = data.frame(x = surv$time, cdf = 0)), task$nrow)
      for (i in seq_len(task$nrow)) {
        x[[i]]$cdf = surv$cdf[, i]
      }

      distr = distr6::VectorDistribution$new(
        distribution = "WeightedDiscrete", params = x,
        decorators = c("CoreStatistics", "ExoticStatistics"))

      # response is only returned for these families, as exp(lp)
      response = NULL
      family = self$param_set$values$family
      if (!is.null(family) && family %in% c("weibull", "loglog", "lognormal", "gehan")) {
        response = exp(lp)
      }

      # NOTE(review): crank is set to lp for every family. For the families above, a
      # larger lp means a larger predicted response (exp(lp)), so confirm this ranking
      # orientation matches what the crank-based measures expect.
      PredictionSurv$new(task = task, crank = lp, distr = distr, lp = lp, response = response)
    }
  )
)
