# Unit tests for makeModelMultiplexer() and makeModelMultiplexerParamSet().
context("ModelMultiplexer")


test_that("makeModelMultiplexerParamSet works", {
  # Multiplexer over two base learners; their ids appear as prefixes in the
  # hand-written reference parameter set below.
  base.learners = list(
    makeLearner("classif.ksvm"),
    makeLearner("classif.randomForest")
  )
  mm = makeModelMultiplexer(base.learners)

  # Reference: the parameter set spelled out by hand, with a discrete
  # "selected.learner" switch and each learner parameter prefixed by the
  # learner id and guarded by a `requires` expression.
  ps.manual = makeParamSet(
    makeDiscreteParam("selected.learner", values = extractSubList(base.learners, "id")),
    makeNumericParam(
      "classif.ksvm.sigma",
      lower = -10, upper = 10, trafo = function(x) 2^x,
      requires = quote(selected.learner == "classif.ksvm")
    ),
    makeIntegerParam(
      "classif.randomForest.ntree",
      lower = 1L, upper = 500L,
      requires = quote(selected.learner == "classif.randomForest")
    )
  )

  # Variant A: bare params passed positionally, without naming their learner.
  ps.bare = makeModelMultiplexerParamSet(
    mm,
    makeNumericParam("sigma", lower = -10, upper = 10, trafo = function(x) 2^x),
    makeIntegerParam("ntree", lower = 1L, upper = 500L)
  )

  # Variant B: one ParamSet per learner, passed as arguments named by learner id.
  ksvm.ps = makeParamSet(
    makeNumericParam("sigma", lower = -10, upper = 10, trafo = function(x) 2^x)
  )
  forest.ps = makeParamSet(
    makeIntegerParam("ntree", lower = 1L, upper = 500L)
  )
  ps.named = makeModelMultiplexerParamSet(
    mm,
    classif.ksvm = ksvm.ps,
    classif.randomForest = forest.ps
  )

  # All three constructions must describe the same parameter set.
  expect_equal(ps.bare, ps.named)
  expect_equal(ps.named, ps.manual)
  expect_equal(ps.bare, ps.manual)
})

# This is mostly a test of BaseEnsemble: setting and removing hyperparameters and setting the predict type.
test_that("ModelMultiplexer basic stuff works", {
  mm = makeModelMultiplexer(c("classif.lda", "classif.rpart"))
  expect_equal(class(mm), c("ModelMultiplexer", "BaseEnsemble", "Learner"))

  # Setting hyperparameters: select rpart and give it a prefixed minsplit value.
  mm.rpart = setHyperPars(mm, selected.learner = "classif.rpart", classif.rpart.minsplit = 10000L)
  hyper.pars = getHyperPars(mm.rpart)
  names.after.set = c("selected.learner", "classif.rpart.minsplit", "classif.rpart.xval")
  expect_true(setequal(names(hyper.pars), names.after.set))
  expect_equal(hyper.pars[["classif.rpart.minsplit"]], 10000L)

  # The value must show up in the control settings of the nested fitted model.
  fit.rpart = train(mm.rpart, task = binaryclass.task)
  expect_equal(fit.rpart$learner.model$learner.model$control$minsplit, 10000L)

  # Removing a hyperparameter drops only that entry.
  mm.removed = removeHyperPars(mm.rpart, "classif.rpart.minsplit")
  hyper.pars = getHyperPars(mm.removed)
  names.after.removal = c("selected.learner", "classif.rpart.xval")
  expect_true(setequal(names(hyper.pars), names.after.removal))

  # Predict type: train and predict with "prob" on the unmodified multiplexer.
  mm.prob = setPredictType(mm, "prob")
  fit.prob = train(mm.prob, task = binaryclass.task)
  pred = predict(fit.prob, task = binaryclass.task)
  # No expectation is attached to this call; presumably it serves as a smoke
  # check that probabilities can be extracted from the prediction.
  getProbabilities(pred)
})


test_that("ModelMultiplexer tuning", {
  lrn = makeModelMultiplexer(c("classif.knn", "classif.rpart"))

  rdesc = makeResampleDesc("CV", iters = 2L)

  tune.ps = makeModelMultiplexerParamSet(lrn,
    makeIntegerParam("minsplit", lower = 1, upper = 50))

  # Checks shared by every tuner exercised below: the result has the expected
  # classes and every performance value on the optimization path is usable.
  expectValidTuneResult = function(res) {
    expect_true(setequal(class(res), c("TuneResult", "OptResult")))
    y = getOptPathY(res$opt.path)
    # y can hold more than one value (several configurations are evaluated),
    # so reduce each condition with all() before combining with scalar `&&`.
    # `&&` on a longer vector only looked at the first element in old R
    # versions and is an error from R 4.3 on.
    expect_true(all(!is.na(y)) && all(is.finite(y)))
  }

  # tune with random
  ctrl = makeTuneControlRandom(maxit = 4L)
  res = tuneParams(lrn, binaryclass.task, rdesc, par.set = tune.ps, control = ctrl)
  expectValidTuneResult(res)

  # tune with irace, on a subset of the task rows
  task = subsetTask(binaryclass.task, subset = c(1:20, 150:170))
  ctrl = makeTuneControlIrace(maxExperiments = 40L, nbIterations = 2L, minNbSurvival = 1L)
  res = tuneParams(lrn, task, rdesc, par.set = tune.ps, control = ctrl)
  expectValidTuneResult(res)
})
