library(distfreereg)

# Pull two unexported helpers out of the distfreereg namespace (this is what
# `:::` does under the hood). Binding all.equal.distfreereg in this
# environment presumably lets the plain all.equal() calls below dispatch to
# it even if the package does not register it as an S3 method.
all.equal.distfreereg <- get("all.equal.distfreereg",
                             envir = asNamespace("distfreereg"),
                             inherits = FALSE)
test_dfr_functions <- get("test_dfr_functions",
                          envir = asNamespace("distfreereg"),
                          inherits = FALSE)

# Simulated data shared by every fit below. The random draws happen in a fixed
# order after a single seed (Sig, then X, then the error vector), so adding or
# reordering random calls here changes all downstream results.
n <- 1e2
# Mean function: intercept plus one slope per column of X.
func <- function(X, theta) theta[1] + theta[2]*X[,1] + theta[3]*X[,2]
set.seed(20250516)
# Error covariance: a single n x n Wishart(df = n, Sigma = I_n) draw; [,,1]
# drops the replicate dimension of rWishart()'s 3-d array.
Sig <- rWishart(1, df = n, Sigma = diag(n))[,,1]
# Two covariate columns of Exp(1) draws.
X <- matrix(rexp(2*n, rate = 1), nrow = n)
theta <- c(2,5,1)
# Response: the mean (presumably func(X, theta), built via the unexported
# f2ftheta()) plus one error vector. The error presumably is multivariate
# normal with mean 0 and covariance Sig, with SqrtSigma a matrix square root
# of Sig. TODO confirm against the package internals.
Y <- distfreereg:::f2ftheta(f = func, X)(theta) +
  as.vector(distfreereg:::rmvnorm(n = n, reps = 1, mean = rep(0,n), SqrtSigma = distfreereg:::matsqrt(Sig)))

# Assemble the regression data frame: response `z`, covariates `x` and `y`
# (the two columns of X), and a grouping column `g` cycling through 1:10.
# rep_len() sizes `g` from `n`, so it stays aligned with the other columns if
# `n` changes. The old rep(1:10, 10) was only right for n = 100 and gives
# the same values in that case.
df_lm <- as.data.frame(cbind(Y, X, rep_len(1:10, n)))
names(df_lm) <- c("z", "x", "y", "g")
# Observation weights from their own seed; shifted by 1 so every weight is at
# least 1.
set.seed(20250516)
wt <- rexp(n) + 1
# Weighted fit from formula + data. `method_args` presumably forwards
# `weights` to the lm() fit, and return_on_error = FALSE presumably makes a
# failure stop the script rather than return a partial object.
set.seed(20250516)
dfr_form_lm <- distfreereg(test_mean = z ~ x + y, data = df_lm,
                           method_args = list(weights = wt),
                           method = "lm", verbose = FALSE,
                           control = list(return_on_error = FALSE))

dfr_form_lm

# Same formula fit without weights. The object is not referenced again below;
# presumably this only checks that the unweighted path runs.
set.seed(20250516)
dfr_form_lm_no_weights <- distfreereg(test_mean = z ~ x + y, data = df_lm,
                                      method = "lm")

# NOTE(review): the model's covariates are named `x` and `y`, but this
# newdata has columns `a` and `b`. Confirm whether test_dfr_functions() is
# meant to exercise a mismatched-newdata path, or whether the columns should
# be named x and y.
newdata_lm <- data.frame(a = rnorm(10), b = rnorm(10))
test_dfr_functions(dfr_form_lm, newdata = newdata_lm)


# Fit the same weighted model with lm() directly and pass the lm object as
# test_mean. The result must match the formula-based fit above.
m <- lm(z ~ x + y, data = df_lm, weights = wt)

set.seed(20250516)
dfr_lm <- distfreereg(test_mean = m, verbose = FALSE,
                      control = list(return_on_error = FALSE))
# Same fit without verbose = FALSE. J and fitted_values are supplied through
# `override` (copied from dfr_lm); the result must still equal dfr_lm.
set.seed(20250516)
dfr_lm_verbose <- distfreereg(test_mean = m,
                              control = list(return_on_error = FALSE),
                              override = list(J = dfr_lm[["J"]],
                                              fitted_values = dfr_lm[["fitted_values"]]))

dfr_lm
test_dfr_functions(dfr_lm, newdata = newdata_lm)

# all.equal() here presumably dispatches to all.equal.distfreereg (bound at
# the top of the script). stopifnot() halts the script and reports the
# differences if the objects are not equal.
stopifnot(all.equal(dfr_lm, dfr_form_lm))
stopifnot(all.equal(dfr_lm, dfr_lm_verbose))

# Run asymptotics() on both equivalent fits (reps = 5, presumably kept small
# so the script stays fast). asymptotics() presumably draws random samples.
# Re-seeding before each call, as this script does before every other
# stochastic call, makes each result reproducible on its own, independent of
# the RNG state left by the calls above. It also lets the two rejection
# tables below be compared side by side.
set.seed(20250516)
cdfr_form_lm <- asymptotics(dfr_form_lm, reps = 5)
set.seed(20250516)
cdfr_lm <- asymptotics(dfr_lm, reps = 5)

# Columns 2 and 3 of each rejection table at alpha = 0.1 and 0.5, rounded to
# 3 significant digits for stable printed output.
signif(rejection(cdfr_form_lm, alpha = c(0.1, 0.5))[,2:3], digits = 3)
signif(rejection(cdfr_lm, alpha = c(0.1, 0.5))[,2:3], digits = 3)


# Give the stored test_mean an unrecognised class ("wrong") so that
# asymptotics() is presumably rejected. The error is re-signalled as a
# warning so the script keeps running.
dfr_lm_but_not_lm <- dfr_lm
dfr_lm_but_not_lm[["test_mean"]] <- structure(dfr_lm[["test_mean"]],
                                              class = "wrong")
tryCatch(
  asymptotics(dfr_lm_but_not_lm, reps = 5),
  error = function(e) warning(e)
)

# jacobian_args is supplied with a placeholder value ("ignored"); presumably
# it has no effect for an lm-based fit.
vcov(dfr_lm, jacobian_args = list("ignored"))

# Orderings
#
# For each built-in ordering ("asis", "optimal", "natural"), update both the
# lm-object fit and the formula fit under the same seed and require that the
# two updated objects are all.equal().

set.seed(20250516)
dfr_lm_asis <- update(dfr_lm, ordering = "asis")
set.seed(20250516)
dfr_form_lm_asis <- update(dfr_form_lm, ordering = "asis")
stopifnot(all.equal(dfr_lm_asis, dfr_form_lm_asis))

set.seed(20250516)
dfr_lm_optimal <- update(dfr_lm, ordering = "optimal")
set.seed(20250516)
dfr_form_lm_optimal <- update(dfr_form_lm, ordering = "optimal")
stopifnot(all.equal(dfr_lm_optimal, dfr_form_lm_optimal))

set.seed(20250516)
dfr_lm_natural <- update(dfr_lm, ordering = "natural")
set.seed(20250516)
dfr_form_lm_natural <- update(dfr_form_lm, ordering = "natural")
stopifnot(all.equal(dfr_lm_natural, dfr_form_lm_natural))

# Natural ordering with group = TRUE, run verbosely. Seeded like every other
# update() in this script, so the result no longer depends on the RNG state
# left by the preceding calls.
set.seed(20250516)
dfr_lm_natural_grouped <- update(dfr_lm, ordering = "natural", group = TRUE,
                                 verbose = TRUE)

# Order by the data column "g": ordering = list("g") names a column of df_lm.
# Both the lm-object fit and the formula fit must give equal results.
set.seed(20250516)
dfr_lm_g_character <- update(dfr_lm, ordering = list("g"))

dfr_lm_g_character

set.seed(20250516)
dfr_form_lm_g <- update(dfr_form_lm, ordering = list("g"))
stopifnot(all.equal(dfr_lm_g_character, dfr_form_lm_g))

# Print g in the row order given by each fit's res_order (presumably the
# residual ordering), for visual inspection in the test output.
df_lm[dfr_lm_g_character[["res_order"]],][["g"]]
df_lm[dfr_form_lm_g[["res_order"]],][["g"]]

# Same "g" ordering with group = TRUE; both routes must still agree.
set.seed(20250516)
dfr_lm_g_character_grouped <- update(dfr_lm_g_character, group = TRUE)
set.seed(20250516)
dfr_form_lm_g_grouped <- update(dfr_form_lm_g, group = TRUE)
stopifnot(all.equal(dfr_lm_g_character_grouped, dfr_form_lm_g_grouped))



### Failures
#
# Each tryCatch() call below is meant to be rejected. The error is
# re-signalled as a warning so the script keeps going and the message still
# appears in the output.

tryCatch(update(dfr_lm, theta_init = 1), error = function(e) warning(e))
tryCatch(update(dfr_lm, ordering = list("h")), error = function(e) warning(e))
tryCatch(update(dfr_lm, ordering = list(1, "g")), error = function(e) warning(e))
tryCatch(update(dfr_lm, ordering = c(1)), error = function(e) warning(e))
tryCatch(update(dfr_lm, ordering = list(1:10)), error = function(e) warning(e))

# Override `r` with a random matrix shaped like the Jacobian J, and (via
# return_on_error = TRUE, presumably) get a partial result instead of an
# error. The matrix is sized from J itself rather than the global `n`. The
# draw is seeded like every other random call in this script, so the
# override is reproducible.
set.seed(20250516)
dfr_lm_fail <- update(dfr_lm,
                      override = list(r = matrix(rnorm(length(dfr_lm[["J"]])),
                                                 nrow = nrow(dfr_lm[["J"]]))),
                      control = list(return_on_error = TRUE))

tryCatch(update(dfr_form_lm, test_mean = z + y ~ x), error = function(e) warning(e))

tryCatch(update(dfr_form_lm, group = "hello"), error = function(e) warning(e))

