#include <RcppArmadillo.h>
using namespace Rcpp;
//[[Rcpp::depends(RcppArmadillo)]]
// Search every candidate split of every covariate for the one that maximises
// the absolute weighted two-sample log-rank z-statistic (right daughter vs
// left daughter).
//
// Parameters
//   time        : follow-up time of each observation.
//   event       : nonzero marks an event time; zero marks a censored time.
//   xx_numeric  : numeric covariates, one column per variable, one row per
//                 observation. A split sends values <= cut point to the left.
//   xx_factor   : factor covariates coded as numbers, one column per variable.
//                 Every bipartition of the observed levels is tried.
//   weights     : observation weights used in all sums.
//   min_weights : minimum total weight required in each daughter node.
//   cut_type    : cut point between consecutive distinct numeric values:
//                 0 = the lower value, 2 = a uniform draw between them (uses
//                 R's RNG), 1 or any other value = their midpoint.
// NaN/NA covariate values are ignored when searching that variable. For the
// chosen variable they are reported through `to_unsure`.
//
// Returns a List with the best z-score, its chi-square and upper-tail 1-df
// p-value, the numeric cut point or the factor levels sent left, the 1-based
// index of the chosen variable (numeric columns first, then factor columns;
// 0 when no split passing `min_weights` has a nonzero statistic), the
// left/right/unsure membership of every observation in the original order,
// the total weight of each daughter, and per-variable best z-scores and
// p-values. When no split is chosen, the split-specific entries are NA.
//[[Rcpp::export]]
List find_best_split_cox_bone(
    arma::vec time, 
    arma::uvec event, 
    arma::mat xx_numeric,
    arma::mat xx_factor,
    arma::vec weights,
    double min_weights=50.0,
    int cut_type=0) {
  
  // Factor variables with more distinct levels than this are rejected. Their
  // 2^(levels-1)-1 bipartitions are infeasible to enumerate, and the level
  // bitmask used below must fit in an unsigned int.
  const int max_factor_levels=30;
  
  const arma::uword nind=time.n_elem;
  if(event.n_elem!=nind||weights.n_elem!=nind||
     xx_numeric.n_rows!=nind||xx_factor.n_rows!=nind){
    Rcpp::stop("time, event, weights, xx_numeric and xx_factor must have one entry/row per observation");
  }
  const int ndim_numeric=xx_numeric.n_cols;
  const int ndim_factor=xx_factor.n_cols;
  
  // Distinct event times in ascending order: the points at which risk sets
  // are evaluated.
  arma::vec alltime=time.elem(arma::find(event));
  alltime=arma::sort(arma::unique(alltime));
  const int ntime=alltime.n_elem;
  
  // Copies sorted by ascending time, so risk sets can be accumulated in a
  // single backward sweep.
  const arma::uvec idx_sort=arma::sort_index(time);
  const arma::vec time_sorted=time.elem(idx_sort);
  const arma::uvec event_sorted=event.elem(idx_sort);
  const arma::vec weights_sorted=weights.elem(idx_sort);
  const arma::mat xx_numeric_sorted=xx_numeric.rows(idx_sort);
  const arma::mat xx_factor_sorted=xx_factor.rows(idx_sort);
  
  // Weighted log-rank z-statistic of the `to_right` group against the rest.
  // The inputs must be sorted by ascending time. Returns 0 when the variance
  // is not positive (no events, or every risk set lies entirely on one side).
  // The statistic is undefined there, and dividing by a zero variance would
  // turn round-off in the numerator into an infinite "best" score.
  auto logrank_zscore=[&alltime,ntime](const arma::vec& time_sub,
                                        const arma::uvec& event_sub,
                                        const arma::vec& weights_sub,
                                        const arma::uvec& to_right)->double{
    arma::vec dd(ntime,arma::fill::zeros);   // weighted events at each event time
    arma::vec frac(ntime,arma::fill::zeros); // weighted share of the risk set on the right
    double observed_right=0.0;               // weighted events on the right
    double risk_all=0.0;
    double risk_right=0.0;
    int ii=static_cast<int>(time_sub.n_elem)-1;
    for(int tt=ntime-1;tt>=0;tt--){
      // Everyone with time >= alltime(tt) is at risk at alltime(tt).
      for(;ii>=0&&time_sub(ii)>=alltime(tt);ii--){
        const double w=weights_sub(ii);
        risk_all+=w;
        if(to_right(ii))risk_right+=w;
        if(event_sub(ii)){
          dd(tt)+=w;
          if(to_right(ii))observed_right+=w;
        }
      }
      frac(tt)=risk_all==0.0?0.0:risk_right/risk_all;
    }
    const double variance=arma::sum(dd%(frac-frac%frac));
    if(!(variance>0.0))return 0.0;
    return (observed_right-arma::sum(dd%frac))/std::sqrt(variance);
  };
  
  // False when either daughter's total weight is below min_weights.
  auto weights_ok=[min_weights](const arma::vec& weights_sub,
                                const arma::uvec& to_left,
                                const arma::uvec& to_right)->bool{
    const arma::vec w_left=weights_sub.elem(arma::find(to_left));
    const arma::vec w_right=weights_sub.elem(arma::find(to_right));
    if(arma::sum(w_left)<min_weights)return false;
    if(arma::sum(w_right)<min_weights)return false;
    return true;
  };
  
  double best_zscore=0.0,best_pvalue=1.0,best_chisq=0.0,best_split_numeric=0.0;
  arma::vec best_split_factor_left;
  int best_jj=-1;
  arma::vec zscore_numeric(ndim_numeric,arma::fill::zeros);
  arma::vec zscore_factor(ndim_factor,arma::fill::zeros);
  
  // Numeric variables: cut between each pair of consecutive distinct values.
  for(int jj=0;jj<ndim_numeric;jj++){
    const arma::vec column=xx_numeric_sorted.col(jj);
    // NaN (R's NA) fails every comparison, so this keeps observed values only.
    const arma::uvec observed=arma::find(column<=arma::datum::inf);
    const arma::vec xx_sub=column.elem(observed);
    const arma::vec time_sub=time_sorted.elem(observed);
    const arma::uvec event_sub=event_sorted.elem(observed);
    const arma::vec weights_sub=weights_sorted.elem(observed);
    
    const arma::vec values=arma::sort(arma::unique(xx_sub));
    const int nvalue=values.n_elem;
    for(int splitidx=0;splitidx+1<nvalue;splitidx++){
      const double lo=values(splitidx);
      const double hi=values(splitidx+1);
      double a_split;
      if(cut_type==0){
        a_split=lo;
      }else if(cut_type==2){
        a_split=R::runif(lo,hi);
      }else{
        a_split=0.5*(lo+hi);  // cut_type 1 and any unrecognised value
      }
      
      const arma::uvec to_left=(xx_sub<=a_split);
      const arma::uvec to_right=(xx_sub>a_split);
      if(!weights_ok(weights_sub,to_left,to_right))continue;
      
      const double zscore=logrank_zscore(time_sub,event_sub,weights_sub,to_right);
      if(std::abs(zscore)>std::abs(zscore_numeric(jj)))zscore_numeric(jj)=zscore;
      if(std::abs(zscore)>std::abs(best_zscore)){
        best_zscore=zscore;
        best_split_numeric=a_split;
        best_jj=jj;
      }
    }
  }
  
  // Factor variables: every bipartition of the observed levels.
  for(int jj=0;jj<ndim_factor;jj++){
    const arma::vec column=xx_factor_sorted.col(jj);
    const arma::uvec observed=arma::find(column<=arma::datum::inf);
    const arma::vec xx_sub=column.elem(observed);
    const arma::vec time_sub=time_sorted.elem(observed);
    const arma::uvec event_sub=event_sorted.elem(observed);
    const arma::vec weights_sub=weights_sorted.elem(observed);
    
    const arma::vec levels=arma::sort(arma::unique(xx_sub));
    const int nlevel=levels.n_elem;
    if(nlevel<=1)continue;
    if(nlevel>max_factor_levels){
      Rcpp::stop("a factor variable has too many distinct levels to enumerate all binary splits");
    }
    
    // Index (into `levels`) of each observation's level, so that membership
    // of a split is a single bit test.
    arma::uvec level_of(xx_sub.n_elem,arma::fill::zeros);
    for(int kk=0;kk<nlevel;kk++){
      level_of.elem(arma::find(xx_sub==levels(kk))).fill(kk);
    }
    
    // Bit kk of `mask` sends level kk left. The last level's bit is never
    // set, so each unordered bipartition is visited once. Mask 0 (empty left
    // side) is skipped.
    const unsigned int nsplit=1u<<(nlevel-1);
    for(unsigned int mask=1;mask<nsplit;mask++){
      arma::uvec to_left(level_of.n_elem);
      for(arma::uword ii=0;ii<level_of.n_elem;ii++){
        to_left(ii)=(mask>>level_of(ii))&1u;
      }
      const arma::uvec to_right=(to_left==0);
      if(!weights_ok(weights_sub,to_left,to_right))continue;
      
      const double zscore=logrank_zscore(time_sub,event_sub,weights_sub,to_right);
      if(std::abs(zscore)>std::abs(zscore_factor(jj)))zscore_factor(jj)=zscore;
      if(std::abs(zscore)>std::abs(best_zscore)){
        arma::uvec pick(nlevel);
        for(int kk=0;kk<nlevel;kk++)pick(kk)=(mask>>kk)&1u;
        best_zscore=zscore;
        best_split_factor_left=levels.elem(arma::find(pick));
        best_jj=ndim_numeric+jj;
      }
    }
  }
  
  // NA unless a split is chosen below; previously these could be returned
  // uninitialised or left over from an unrelated candidate split.
  double sum_weights_left=NumericVector::get_na();
  double sum_weights_right=NumericVector::get_na();
  LogicalVector to_left_rcpp;
  LogicalVector to_right_rcpp;
  LogicalVector to_unsure_rcpp;
  
  if(best_jj<0){
    // No split chosen: every split-specific entry is NA.
    best_split_numeric=NumericVector::get_na();
    best_split_factor_left=NumericVector::get_na();
    to_left_rcpp=LogicalVector::create(NA_LOGICAL);
    to_right_rcpp=LogicalVector::create(NA_LOGICAL);
    to_unsure_rcpp=LogicalVector::create(NA_LOGICAL);
  }else{
    // Re-apply the winning rule to the observations in their original order.
    arma::uvec observed;
    arma::uvec to_left;
    if(best_jj<ndim_numeric){
      const arma::vec column=xx_numeric.col(best_jj);
      observed=(column<=arma::datum::inf);
      to_left=observed&&(column<=best_split_numeric);
      best_split_factor_left=NumericVector::get_na();
    }else{
      const arma::vec column=xx_factor.col(best_jj-ndim_numeric);
      observed=(column<=arma::datum::inf);
      to_left.zeros(nind);
      for(arma::uword ii=0;ii<nind;ii++){
        if(!observed(ii))continue;
        to_left(ii)=arma::any(best_split_factor_left==column(ii))?1:0;
      }
      best_split_numeric=NumericVector::get_na();
    }
    arma::uvec to_unsure=(observed==0);
    arma::uvec to_right=observed&&(to_left==0);
    
    const arma::vec weights_left=weights.elem(arma::find(to_left));
    const arma::vec weights_right=weights.elem(arma::find(to_right));
    sum_weights_left=arma::sum(weights_left);
    sum_weights_right=arma::sum(weights_right);
    
    best_chisq=best_zscore*best_zscore;
    best_pvalue=R::pchisq(best_chisq,1.0,false,false);
    
    to_left_rcpp=LogicalVector(to_left.begin(),to_left.end());
    to_right_rcpp=LogicalVector(to_right.begin(),to_right.end());
    to_unsure_rcpp=LogicalVector(to_unsure.begin(),to_unsure.end());
  }
  
  // Per-variable upper-tail 1-df chi-square p-values of the best |z|.
  const arma::vec all_zscore=arma::join_cols(zscore_numeric,zscore_factor);
  arma::vec all_pvalue(all_zscore.n_elem);
  for(arma::uword jj=0;jj<all_zscore.n_elem;jj++){
    all_pvalue(jj)=R::pchisq(all_zscore(jj)*all_zscore(jj),1.0,false,false);
  }
  
  List result=Rcpp::List::create(
    Rcpp::Named("best_zscore") = best_zscore,
    Rcpp::Named("best_chisq") = best_chisq,
    Rcpp::Named("best_pvalue") = best_pvalue,
    Rcpp::Named("best_split_numeric") = best_split_numeric,
    Rcpp::Named("best_split_factor_left") = NumericVector(best_split_factor_left.begin(),best_split_factor_left.end()),
    Rcpp::Named("best_jj") = best_jj+1,
    Rcpp::Named("to_left") = to_left_rcpp,
    Rcpp::Named("to_right") = to_right_rcpp,
    Rcpp::Named("to_unsure") = to_unsure_rcpp,
    Rcpp::Named("sum_weights_left") = sum_weights_left,
    Rcpp::Named("sum_weights_right") = sum_weights_right,
    Rcpp::Named("all_zscore") = NumericVector(all_zscore.begin(),all_zscore.end()),
    Rcpp::Named("all_pvalue") = NumericVector(all_pvalue.begin(),all_pvalue.end()));
  return(result);
}

