/*
 * Decompiled with CFR 0.152.
 */
package org.extratrees;

import java.util.ArrayList;
import java.util.HashSet;
import java.util.Set;
import org.extratrees.AbstractTrees;
import org.extratrees.Aggregator;
import org.extratrees.BinaryTree;
import org.extratrees.TaskCutResult;
import org.extratrees.data.Array2D;

/**
 * Regression variant of the extra-trees ensemble over real-valued outputs.
 *
 * <p>A candidate split is scored as the weighted sum of the within-child output variances,
 * {@code wLeft * varLeft + wRight * varRight}; lower scores are better. Leaves store the
 * (optionally sample-weighted) mean output of the samples reaching them, and per-tree leaf
 * values are aggregated with {@link ArithmeticMean}.
 */
public class ExtraTrees extends AbstractTrees<BinaryTree, Double> {
    /** Variances below this value mark a child as constant (equals 1e-7 squared). */
    private static final double CONSTANT_VARIANCE_THRESHOLD = 9.999999999999998E-15;

    /** Regression targets, one per input row. */
    double[] output;
    /** Element-wise squares of {@link #output}, cached for variance computations. */
    double[] outputSq;

    /**
     * Creates a model without task labels.
     *
     * @param input  input matrix, one row per sample
     * @param output regression target for each row of {@code input}
     * @throws IllegalArgumentException if {@code output.length != input.nrows()}
     */
    public ExtraTrees(Array2D input, double[] output) {
        this(input, output, null);
    }

    /**
     * Creates a model, optionally with a task label per sample.
     *
     * @param input  input matrix, one row per sample
     * @param output regression target for each row of {@code input}
     * @param tasks  task id for each row, or {@code null} for no tasks
     * @throws IllegalArgumentException if {@code output} or a non-null {@code tasks} does not
     *                                  have one entry per row of {@code input}
     */
    public ExtraTrees(Array2D input, double[] output, int[] tasks) {
        if (input.nrows() != output.length) {
            throw new IllegalArgumentException("Input and output do not have same length.");
        }
        if (tasks != null && input.nrows() != tasks.length) {
            throw new IllegalArgumentException("Input and tasks do not have the same number of data points.");
        }
        this.setInput(input);
        this.output = output;
        this.outputSq = new double[output.length];
        for (int i = 0; i < output.length; i++) {
            this.outputSq[i] = output[i] * output[i];
        }
        this.setTasks(tasks);
    }

    /**
     * Returns a new model that shares this model's input/output and holds only the trees
     * whose index is flagged in {@code selection}.
     *
     * <p>NOTE(review): only input, output and the selected trees are carried over; tasks,
     * weights and other training settings of this instance are not copied. Confirm the
     * result is only used for prediction.
     *
     * @param selection {@code selection[i]} keeps tree {@code i}; must not be longer than the
     *                  number of trees
     * @return a model containing the selected trees, in their original order
     */
    public ExtraTrees selectTrees(boolean[] selection) {
        ExtraTrees subset = new ExtraTrees(this.input, this.output);
        subset.trees = new ArrayList<BinaryTree>();
        for (int i = 0; i < selection.length; i++) {
            if (selection[i]) {
                subset.trees.add(this.trees.get(i));
            }
        }
        return subset;
    }

    /** Per-tree predictions are combined by their arithmetic mean. */
    @Override
    Aggregator<Double> getNewAggregator() {
        return new ArithmeticMean();
    }

    @Override
    double convertToDouble(Double value) {
        return value;
    }

    /**
     * Unboxes a list of doubles into a primitive array.
     *
     * @param values list to copy; must not contain {@code null}
     * @return array with the same elements in the same order
     */
    protected static double[] listToDArray(ArrayList<Double> values) {
        double[] result = new double[values.size()];
        for (int i = 0; i < result.length; i++) {
            result[i] = values.get(i);
        }
        return result;
    }

    /** Predicts a value for every row of {@code input}. */
    public double[] getValues(Array2D input) {
        return ExtraTrees.listToDArray(this.getValuesD(input));
    }

    /** Predicts a value for every row of {@code input}, given each row's task id. */
    public double[] getValuesMT(Array2D input, int[] tasks) {
        return ExtraTrees.listToDArray(this.getValuesMTD(input, tasks));
    }

    /**
     * Builds an internal node over two finished subtrees. Its value is the mean of the
     * children's values, each weighted by that child's successor count.
     *
     * <p>NOTE(review): children are combined by sample counts, not sample weights. When
     * {@code useWeights} is set, the result can therefore differ from the weighted mean that
     * {@link #makeLeaf} would compute over the same samples.
     */
    @Override
    protected BinaryTree makeFilledTree(BinaryTree left, BinaryTree right, int column, double threshold, int nSuccessors) {
        BinaryTree node = new BinaryTree();
        node.column = column;
        node.threshold = threshold;
        node.nSuccessors = nSuccessors;
        node.left = left;
        node.right = right;
        double leftTotal = left.value * left.nSuccessors;
        double rightTotal = right.value * right.nSuccessors;
        node.value = (leftTotal + rightTotal) / nSuccessors;
        return node;
    }

    /**
     * Returns the (optionally weighted) variance of the outputs of {@code ids}, computed as
     * {@code E[y^2] - E[y]^2}.
     */
    @Override
    protected double get1NaNScore(int[] ids) {
        double sum = 0.0;
        double sumSq = 0.0;
        double totalWeight = 0.0;
        for (int id : ids) {
            double w = this.useWeights ? this.weights[id] : 1.0;
            totalWeight += w;
            sum += this.output[id] * w;
            sumSq += this.outputSq[id] * w;
        }
        double mean = sum / totalWeight;
        return sumSq / totalWeight - mean * mean;
    }

    /**
     * Scores splitting {@code ids} on {@code column} at {@code threshold}.
     *
     * <p>Samples with value {@code < threshold} go left and the rest go right. When the
     * input has NaNs, NaN-valued samples are excluded and their weight is added to
     * {@code result.nanWeigth}. Left/right counts are incremented on {@code result}, and the
     * score and constant flags are written to it.
     */
    @Override
    protected void calculateCutScore(int[] ids, int column, double threshold, AbstractTrees.CutResult result) {
        double sumLeft = 0.0;
        double sumRight = 0.0;
        double sumSqLeft = 0.0;
        double sumSqRight = 0.0;
        double weightLeft = 0.0;
        double weightRight = 0.0;
        for (int id : ids) {
            double w = this.useWeights ? this.weights[id] : 1.0;
            double x = this.input.get(id, column);
            if (this.hasNaN && Double.isNaN(x)) {
                result.nanWeigth += w;
                continue;
            }
            if (x < threshold) {
                result.countLeft++;
                weightLeft += w;
                sumLeft += this.output[id] * w;
                sumSqLeft += this.outputSq[id] * w;
            } else {
                result.countRight++;
                weightRight += w;
                sumRight += this.output[id] * w;
                sumSqRight += this.outputSq[id] * w;
            }
        }
        cutResultFromSums(result, sumLeft, sumRight, sumSqLeft, sumSqRight, weightLeft, weightRight);
    }

    /**
     * Fills the score and constant-child flags of {@code result} from per-side sums.
     *
     * <p>The score is {@code weightLeft * varLeft + weightRight * varRight}. A side with zero
     * weight yields a NaN variance, and therefore a NaN score.
     */
    private static void cutResultFromSums(AbstractTrees.CutResult result, double sumLeft, double sumRight,
            double sumSqLeft, double sumSqRight, double weightLeft, double weightRight) {
        double meanLeft = sumLeft / weightLeft;
        double meanRight = sumRight / weightRight;
        double varLeft = sumSqLeft / weightLeft - meanLeft * meanLeft;
        double varRight = sumSqRight / weightRight - meanRight * meanRight;
        result.score = weightLeft * varLeft + weightRight * varRight;
        result.leftConst = varLeft < CONSTANT_VARIANCE_THRESHOLD;
        result.rightConst = varRight < CONSTANT_VARIANCE_THRESHOLD;
    }

    /**
     * Tries {@code numRandomTaskCuts} random thresholds on the per-task mean scores and
     * returns the best task split that beats {@code bestScore}.
     *
     * @param ids       sample indices at the current node
     * @param taskSet   tasks eligible for splitting
     * @param bestScore score a task cut must beat (lower is better)
     * @return the best improving cut, or {@code null} if there are fewer than two tasks or
     *         no random cut improves on {@code bestScore}
     */
    @Override
    protected TaskCutResult getTaskCut(int[] ids, Set<Integer> taskSet, double bestScore) {
        // The early-exit checks run before any statistics are computed; the helpers below
        // have no side effects, so skipping them is safe.
        if (taskSet.size() <= 1 || !this.hasAtLeast2Tasks(ids)) {
            return null;
        }
        double mean = this.getOutputMean(ids);
        int[] taskCounts = new int[this.nTasks];
        double[] taskWeights = new double[this.nTasks];
        double[] taskSums = new double[this.nTasks];
        double[] taskSumsSq = new double[this.nTasks];
        double[] taskScores = this.getTaskScores(ids, mean, taskSet, taskCounts, taskWeights, taskSums, taskSumsSq);
        // Only tasks in taskSet have a score; the other slots of taskScores are left at 0.0
        // and must not widen the threshold range.
        double[] range = ExtraTrees.getRange(scoresOf(taskScores, taskSet));
        TaskCutResult best = null;
        for (int i = 0; i < this.numRandomTaskCuts; i++) {
            double threshold = this.getRandom(range[0], range[1]);
            TaskCutResult candidate = new TaskCutResult();
            this.calculateTaskCutScore(taskScores, taskCounts, taskWeights, taskSums, taskSumsSq, threshold, candidate, taskSet);
            // A NaN score (one side empty) never compares less and is skipped.
            if (candidate.score < bestScore) {
                best = candidate;
                bestScore = candidate.score;
            }
        }
        return best;
    }

    /** Copies the scores of the tasks in {@code taskSet} into a compact array. */
    private static double[] scoresOf(double[] taskScores, Set<Integer> taskSet) {
        double[] selected = new double[taskSet.size()];
        int k = 0;
        for (int task : taskSet) {
            selected[k++] = taskScores[task];
        }
        return selected;
    }

    /**
     * Splits {@code taskSet} at {@code threshold}: tasks whose score is {@code < threshold}
     * go left. Fills the task sets, sample counts, score and constant flags of
     * {@code result} from the per-task aggregates.
     */
    private void calculateTaskCutScore(double[] taskScores, int[] taskCounts, double[] taskWeights,
            double[] taskSums, double[] taskSumsSq, double threshold, TaskCutResult result, Set<Integer> taskSet) {
        double sumLeft = 0.0;
        double sumRight = 0.0;
        double sumSqLeft = 0.0;
        double sumSqRight = 0.0;
        double weightLeft = 0.0;
        double weightRight = 0.0;
        result.leftTasks = new HashSet<Integer>();
        result.rightTasks = new HashSet<Integer>();
        result.countLeft = 0;
        result.countRight = 0;
        for (int task : taskSet) {
            if (taskScores[task] < threshold) {
                result.leftTasks.add(task);
                result.countLeft += taskCounts[task];
                weightLeft += taskWeights[task];
                sumLeft += taskSums[task];
                sumSqLeft += taskSumsSq[task];
            } else {
                result.rightTasks.add(task);
                result.countRight += taskCounts[task];
                weightRight += taskWeights[task];
                sumRight += taskSums[task];
                sumSqRight += taskSumsSq[task];
            }
        }
        cutResultFromSums(result, sumLeft, sumRight, sumSqLeft, sumSqRight, weightLeft, weightRight);
    }

    /** Returns true if {@code ids} contains samples from at least two different tasks. */
    private boolean hasAtLeast2Tasks(int[] ids) {
        if (ids.length < 2) {
            return false;
        }
        int firstTask = this.tasks[ids[0]];
        for (int i = 1; i < ids.length; i++) {
            if (this.tasks[ids[i]] != firstTask) {
                return true;
            }
        }
        return false;
    }

    /**
     * Accumulates per-task statistics over {@code ids} into the output arrays and returns
     * per-task scores.
     *
     * <p>Each task in {@code taskSet} is scored with its mean output, shrunk toward
     * {@code mean} by a prior pseudo-count of 1: {@code (sum + mean) / (count + 1)}.
     * Tasks not in {@code taskSet} keep a score of 0.0.
     *
     * <p>NOTE(review): these statistics are unweighted even when {@code useWeights} is set,
     * whereas {@link #calculateCutScore} weights samples. Task-cut scores and feature-cut
     * scores are compared against each other, so confirm this is intended.
     *
     * @param taskCounts  out: sample count per task
     * @param taskWeights out: weight per task in {@code taskSet} (currently its sample count)
     * @param taskSums    out: sum of outputs per task
     * @param taskSumsSq  out: sum of squared outputs per task
     * @return score per task id, indexed {@code 0..nTasks-1}
     */
    private double[] getTaskScores(int[] ids, double mean, Set<Integer> taskSet, int[] taskCounts,
            double[] taskWeights, double[] taskSums, double[] taskSumsSq) {
        double priorWeight = 1.0;
        for (int id : ids) {
            int task = this.tasks[id];
            taskCounts[task]++;
            taskSums[task] += this.output[id];
            taskSumsSq[task] += this.outputSq[id];
        }
        double[] taskScores = new double[this.nTasks];
        for (int task : taskSet) {
            taskWeights[task] = taskCounts[task];
            taskScores[task] = (taskSums[task] + mean * priorWeight) / (taskCounts[task] + priorWeight);
        }
        return taskScores;
    }

    /** Returns the unweighted mean output of {@code ids} (NaN when {@code ids} is empty). */
    private double getOutputMean(int[] ids) {
        double sum = 0.0;
        for (int id : ids) {
            sum += this.output[id];
        }
        return sum / ids.length;
    }

    /**
     * Creates a leaf whose value is the (optionally weighted) mean output of {@code ids}.
     *
     * @param ids       sample indices reaching the leaf
     * @param leafTasks tasks present at the leaf
     * @return the new leaf; its value is NaN when the total weight is zero
     */
    @Override
    public BinaryTree makeLeaf(int[] ids, Set<Integer> leafTasks) {
        BinaryTree leaf = new BinaryTree();
        leaf.nSuccessors = ids.length;
        leaf.tasks = leafTasks;
        double weightedSum = 0.0;
        double totalWeight = 0.0;
        for (int id : ids) {
            double w = this.useWeights ? this.weights[id] : 1.0;
            weightedSum += this.output[id] * w;
            totalWeight += w;
        }
        leaf.value = weightedSum / totalWeight;
        return leaf;
    }

    /** Aggregator that averages leaf values; yields NaN if no leaf was processed. */
    public class ArithmeticMean
    implements Aggregator<Double> {
        double sum = 0.0;
        int count;

        @Override
        public void processLeaf(Double value) {
            this.sum += value.doubleValue();
            ++this.count;
        }

        @Override
        public Double getResult() {
            if (this.count == 0) {
                return Double.NaN;
            }
            return this.sum / (double) this.count;
        }
    }
}

