/*
 * Decompiled with CFR 0.152.
 */
package weka.classifiers.trees.lmt;

import weka.classifiers.Classifier;
import weka.classifiers.Evaluation;
import weka.classifiers.functions.SimpleLinearRegression;
import weka.core.Attribute;
import weka.core.Instance;
import weka.core.Instances;
import weka.core.Utils;
import weka.core.WeightedInstancesHandler;

public class LogisticBase
extends Classifier
implements WeightedInstancesHandler {
    static final long serialVersionUID = 168765678097825064L;
    protected Instances m_numericDataHeader;
    protected Instances m_numericData;
    protected Instances m_train;
    protected boolean m_useCrossValidation;
    protected boolean m_errorOnProbabilities;
    protected int m_fixedNumIterations;
    protected int m_heuristicStop = 50;
    protected int m_numRegressions = 0;
    protected int m_maxIterations;
    protected int m_numClasses;
    protected SimpleLinearRegression[][] m_regressions;
    protected static int m_numFoldsBoosting = 5;
    protected static final double Z_MAX = 3.0;
    private boolean m_useAIC = false;
    protected double m_numParameters = 0.0;
    protected double m_weightTrimBeta = 0.0;

    public LogisticBase() {
        this.m_fixedNumIterations = -1;
        this.m_useCrossValidation = true;
        this.m_errorOnProbabilities = false;
        this.m_maxIterations = 500;
        this.m_useAIC = false;
        this.m_numParameters = 0.0;
    }

    public LogisticBase(int n, boolean bl, boolean bl2) {
        this.m_fixedNumIterations = n;
        this.m_useCrossValidation = bl;
        this.m_errorOnProbabilities = bl2;
        this.m_maxIterations = 500;
        this.m_useAIC = false;
        this.m_numParameters = 0.0;
    }

    public void buildClassifier(Instances instances) throws Exception {
        this.m_train = new Instances(instances);
        this.m_numClasses = this.m_train.numClasses();
        this.m_regressions = this.initRegressions();
        this.m_numRegressions = 0;
        this.m_numericData = this.getNumericData(this.m_train);
        this.m_numericDataHeader = new Instances(this.m_numericData, 0);
        if (this.m_fixedNumIterations > 0) {
            this.performBoosting(this.m_fixedNumIterations);
        } else if (this.m_useAIC) {
            this.performBoostingInfCriterion();
        } else if (this.m_useCrossValidation) {
            this.performBoostingCV();
        } else {
            this.performBoosting();
        }
        this.m_regressions = this.selectRegressions(this.m_regressions);
    }

    protected void performBoostingCV() throws Exception {
        int n;
        int n2 = this.m_maxIterations;
        Instances instances = new Instances(this.m_train);
        instances.stratify(m_numFoldsBoosting);
        double[] dArray = new double[this.m_maxIterations + 1];
        for (n = 0; n < m_numFoldsBoosting; ++n) {
            Instances instances2 = instances.trainCV(m_numFoldsBoosting, n);
            Instances instances3 = instances.testCV(m_numFoldsBoosting, n);
            this.m_numRegressions = 0;
            this.m_regressions = this.initRegressions();
            int n3 = this.performBoosting(instances2, instances3, dArray, n2);
            if (n3 >= n2) continue;
            n2 = n3;
        }
        n = this.getBestIteration(dArray, n2);
        this.m_numRegressions = 0;
        this.performBoosting(n);
    }

    protected void performBoostingInfCriterion() throws Exception {
        boolean bl;
        double d = 0.0;
        double d2 = Double.MAX_VALUE;
        int n = 0;
        int n2 = 0;
        double d3 = Double.MAX_VALUE;
        double[][] dArray = this.getYs(this.m_train);
        double[][] dArray2 = this.getFs(this.m_numericData);
        double[][] dArray3 = this.getProbs(dArray2);
        boolean[][] blArray = new boolean[this.m_numClasses][this.m_numericDataHeader.numAttributes()];
        int n3 = 0;
        while (n3 < this.m_maxIterations && (bl = this.performIteration(n3, dArray, dArray2, dArray3, this.m_numericData))) {
            this.m_numRegressions = ++n3;
            double d4 = this.m_numParameters + (double)n3;
            d3 = 2.0 * this.negativeLogLikelihood(dArray, dArray3) + 2.0 * d4;
            if (n2 > this.m_heuristicStop) break;
            if (d3 < d2) {
                d2 = d3;
                n = n3;
                n2 = 0;
                continue;
            }
            ++n2;
        }
        this.m_numRegressions = 0;
        this.performBoosting(n);
    }

    protected int performBoosting(Instances instances, Instances instances2, double[] dArray, int n) throws Exception {
        boolean bl;
        Instances instances3 = this.getNumericData(instances);
        double[][] dArray2 = this.getYs(instances);
        double[][] dArray3 = this.getFs(instances3);
        double[][] dArray4 = this.getProbs(dArray3);
        int n2 = 0;
        int n3 = 0;
        double d = Double.MAX_VALUE;
        dArray[0] = this.m_errorOnProbabilities ? dArray[0] + this.getMeanAbsoluteError(instances2) : dArray[0] + this.getErrorRate(instances2);
        while (n2 < n && (bl = this.performIteration(n2, dArray2, dArray3, dArray4, instances3))) {
            this.m_numRegressions = ++n2;
            if (this.m_errorOnProbabilities) {
                int n4 = n2;
                dArray[n4] = dArray[n4] + this.getMeanAbsoluteError(instances2);
            } else {
                int n5 = n2;
                dArray[n5] = dArray[n5] + this.getErrorRate(instances2);
            }
            if (n3 > this.m_heuristicStop) break;
            if (dArray[n2] < d) {
                d = dArray[n2];
                n3 = 0;
                continue;
            }
            ++n3;
        }
        return n2;
    }

    protected void performBoosting(int n) throws Exception {
        boolean bl;
        int n2;
        double[][] dArray = this.getYs(this.m_train);
        double[][] dArray2 = this.getFs(this.m_numericData);
        double[][] dArray3 = this.getProbs(dArray2);
        for (n2 = 0; n2 < n && (bl = this.performIteration(n2, dArray, dArray2, dArray3, this.m_numericData)); ++n2) {
        }
        this.m_numRegressions = n2;
    }

    protected void performBoosting() throws Exception {
        boolean bl;
        double[][] dArray = this.getYs(this.m_train);
        double[][] dArray2 = this.getFs(this.m_numericData);
        double[][] dArray3 = this.getProbs(dArray2);
        int n = 0;
        double[] dArray4 = new double[this.m_maxIterations + 1];
        dArray4[0] = this.getErrorRate(this.m_train);
        int n2 = 0;
        double d = Double.MAX_VALUE;
        while (n < this.m_maxIterations && (bl = this.performIteration(n, dArray, dArray2, dArray3, this.m_numericData))) {
            this.m_numRegressions = ++n;
            dArray4[n] = this.getErrorRate(this.m_train);
            if (n2 > this.m_heuristicStop) break;
            if (dArray4[n] < d) {
                d = dArray4[n];
                n2 = 0;
                continue;
            }
            ++n2;
        }
        this.m_numRegressions = this.getBestIteration(dArray4, n);
    }

    protected double getErrorRate(Instances instances) throws Exception {
        Evaluation evaluation = new Evaluation(instances);
        evaluation.evaluateModel(this, instances);
        return evaluation.errorRate();
    }

    protected double getMeanAbsoluteError(Instances instances) throws Exception {
        Evaluation evaluation = new Evaluation(instances);
        evaluation.evaluateModel(this, instances);
        return evaluation.meanAbsoluteError();
    }

    protected int getBestIteration(double[] dArray, int n) {
        double d = dArray[0];
        int n2 = 0;
        for (int i = 1; i <= n; ++i) {
            if (!(dArray[i] < d)) continue;
            d = dArray[i];
            n2 = i;
        }
        return n2;
    }

    protected boolean performIteration(int n, double[][] dArray, double[][] dArray2, double[][] dArray3, Instances instances) throws Exception {
        double d;
        double[] dArray4;
        int n2;
        for (n2 = 0; n2 < this.m_numClasses; ++n2) {
            double d2;
            dArray4 = new double[instances.numInstances()];
            d = 0.0;
            Instances instances2 = new Instances(instances);
            for (int i = 0; i < instances.numInstances(); ++i) {
                d2 = dArray3[i][n2];
                double d3 = dArray[i][n2];
                double d4 = this.getZ(d3, d2);
                double d5 = (d3 - d2) / d4;
                Instance instance = instances2.instance(i);
                instance.setValue(instances2.classIndex(), d4);
                instance.setWeight(instance.weight() * d5);
                dArray4[i] = instance.weight();
                d += instance.weight();
            }
            Instances instances3 = new Instances(instances2);
            if (d > 0.0) {
                if (this.m_weightTrimBeta > 0.0) {
                    d2 = 0.0;
                    int[] nArray = new int[instances.numInstances()];
                    nArray = Utils.sort(dArray4);
                    instances3.delete();
                    for (int i = nArray.length - 1; i >= 0 && d2 < 1.0 - this.m_weightTrimBeta; d2 += dArray4[nArray[i]] / d, --i) {
                        instances3.add(instances2.instance(nArray[i]));
                    }
                }
                d = instances3.sumOfWeights();
                for (int i = 0; i < instances3.numInstances(); ++i) {
                    Instance instance = instances3.instance(i);
                    instance.setWeight(instance.weight() * (double)instances3.numInstances() / d);
                }
            }
            this.m_regressions[n2][n].buildClassifier(instances3);
            boolean bl = this.m_regressions[n2][n].foundUsefulAttribute();
            if (bl) continue;
            return false;
        }
        for (n2 = 0; n2 < dArray2.length; ++n2) {
            int n3;
            dArray4 = new double[this.m_numClasses];
            d = 0.0;
            for (n3 = 0; n3 < this.m_numClasses; ++n3) {
                dArray4[n3] = this.m_regressions[n3][n].classifyInstance(instances.instance(n2));
                d += dArray4[n3];
            }
            d /= (double)this.m_numClasses;
            for (n3 = 0; n3 < this.m_numClasses; ++n3) {
                double[] dArray5 = dArray2[n2];
                int n4 = n3;
                dArray5[n4] = dArray5[n4] + (dArray4[n3] - d) * (double)(this.m_numClasses - 1) / (double)this.m_numClasses;
            }
        }
        for (n2 = 0; n2 < dArray.length; ++n2) {
            dArray3[n2] = this.probs(dArray2[n2]);
        }
        return true;
    }

    protected SimpleLinearRegression[][] initRegressions() {
        SimpleLinearRegression[][] simpleLinearRegressionArray = new SimpleLinearRegression[this.m_numClasses][this.m_maxIterations];
        for (int i = 0; i < this.m_numClasses; ++i) {
            for (int j = 0; j < this.m_maxIterations; ++j) {
                simpleLinearRegressionArray[i][j] = new SimpleLinearRegression();
                simpleLinearRegressionArray[i][j].setSuppressErrorMessage(true);
            }
        }
        return simpleLinearRegressionArray;
    }

    protected Instances getNumericData(Instances instances) throws Exception {
        Instances instances2 = new Instances(instances);
        int n = instances2.classIndex();
        instances2.setClassIndex(-1);
        instances2.deleteAttributeAt(n);
        instances2.insertAttributeAt(new Attribute("'pseudo class'"), n);
        instances2.setClassIndex(n);
        return instances2;
    }

    protected SimpleLinearRegression[][] selectRegressions(SimpleLinearRegression[][] simpleLinearRegressionArray) {
        SimpleLinearRegression[][] simpleLinearRegressionArray2 = new SimpleLinearRegression[this.m_numClasses][this.m_numRegressions];
        for (int i = 0; i < this.m_numClasses; ++i) {
            for (int j = 0; j < this.m_numRegressions; ++j) {
                simpleLinearRegressionArray2[i][j] = simpleLinearRegressionArray[i][j];
            }
        }
        return simpleLinearRegressionArray2;
    }

    protected double getZ(double d, double d2) {
        double d3;
        if (d == 1.0) {
            d3 = 1.0 / d2;
            if (d3 > 3.0) {
                d3 = 3.0;
            }
        } else {
            d3 = -1.0 / (1.0 - d2);
            if (d3 < -3.0) {
                d3 = -3.0;
            }
        }
        return d3;
    }

    protected double[][] getZs(double[][] dArray, double[][] dArray2) {
        double[][] dArray3 = new double[dArray.length][this.m_numClasses];
        for (int i = 0; i < this.m_numClasses; ++i) {
            for (int j = 0; j < dArray.length; ++j) {
                dArray3[j][i] = this.getZ(dArray2[j][i], dArray[j][i]);
            }
        }
        return dArray3;
    }

    protected double[][] getWs(double[][] dArray, double[][] dArray2) {
        double[][] dArray3 = new double[dArray.length][this.m_numClasses];
        for (int i = 0; i < this.m_numClasses; ++i) {
            for (int j = 0; j < dArray.length; ++j) {
                double d = this.getZ(dArray2[j][i], dArray[j][i]);
                dArray3[j][i] = (dArray2[j][i] - dArray[j][i]) / d;
            }
        }
        return dArray3;
    }

    protected double[] probs(double[] dArray) {
        double d = -1.7976931348623157E308;
        for (int i = 0; i < dArray.length; ++i) {
            if (!(dArray[i] > d)) continue;
            d = dArray[i];
        }
        double d2 = 0.0;
        double[] dArray2 = new double[dArray.length];
        for (int i = 0; i < dArray.length; ++i) {
            dArray2[i] = Math.exp(dArray[i] - d);
            d2 += dArray2[i];
        }
        Utils.normalize(dArray2, d2);
        return dArray2;
    }

    protected double[][] getYs(Instances instances) {
        double[][] dArray = new double[instances.numInstances()][this.m_numClasses];
        for (int i = 0; i < this.m_numClasses; ++i) {
            for (int j = 0; j < instances.numInstances(); ++j) {
                dArray[j][i] = instances.instance(j).classValue() == (double)i ? 1.0 : 0.0;
            }
        }
        return dArray;
    }

    protected double[] getFs(Instance instance) throws Exception {
        double[] dArray = new double[this.m_numClasses];
        double[] dArray2 = new double[this.m_numClasses];
        for (int i = 0; i < this.m_numRegressions; ++i) {
            int n;
            double d = 0.0;
            for (n = 0; n < this.m_numClasses; ++n) {
                dArray[n] = this.m_regressions[n][i].classifyInstance(instance);
                d += dArray[n];
            }
            d /= (double)this.m_numClasses;
            for (n = 0; n < this.m_numClasses; ++n) {
                int n2 = n;
                dArray2[n2] = dArray2[n2] + (dArray[n] - d) * (double)(this.m_numClasses - 1) / (double)this.m_numClasses;
            }
        }
        return dArray2;
    }

    protected double[][] getFs(Instances instances) throws Exception {
        double[][] dArrayArray = new double[instances.numInstances()][];
        for (int i = 0; i < instances.numInstances(); ++i) {
            dArrayArray[i] = this.getFs(instances.instance(i));
        }
        return dArrayArray;
    }

    protected double[][] getProbs(double[][] dArray) {
        int n = dArray.length;
        double[][] dArrayArray = new double[n][];
        for (int i = 0; i < n; ++i) {
            dArrayArray[i] = this.probs(dArray[i]);
        }
        return dArrayArray;
    }

    protected double negativeLogLikelihood(double[][] dArray, double[][] dArray2) {
        double d = 0.0;
        for (int i = 0; i < dArray.length; ++i) {
            for (int j = 0; j < this.m_numClasses; ++j) {
                if (dArray[i][j] != 1.0) continue;
                d -= Math.log(dArray2[i][j]);
            }
        }
        return d;
    }

    public int[][] getUsedAttributes() {
        int[][] nArrayArray = new int[this.m_numClasses][];
        double[][] dArray = this.getCoefficients();
        for (int i = 0; i < this.m_numClasses; ++i) {
            int n;
            boolean[] blArray = new boolean[this.m_numericDataHeader.numAttributes()];
            for (n = 0; n < blArray.length; ++n) {
                if (Utils.eq(dArray[i][n + 1], 0.0)) continue;
                blArray[n] = true;
            }
            n = 0;
            for (int j = 0; j < this.m_numericDataHeader.numAttributes(); ++j) {
                if (!blArray[j]) continue;
                ++n;
            }
            int[] nArray = new int[n];
            int n2 = 0;
            for (int j = 0; j < this.m_numericDataHeader.numAttributes(); ++j) {
                if (!blArray[j]) continue;
                nArray[n2] = j;
                ++n2;
            }
            nArrayArray[i] = nArray;
        }
        return nArrayArray;
    }

    public int getNumRegressions() {
        return this.m_numRegressions;
    }

    public double getWeightTrimBeta() {
        return this.m_weightTrimBeta;
    }

    public boolean getUseAIC() {
        return this.m_useAIC;
    }

    public void setMaxIterations(int n) {
        this.m_maxIterations = n;
    }

    public void setHeuristicStop(int n) {
        this.m_heuristicStop = n;
    }

    public void setWeightTrimBeta(double d) {
        this.m_weightTrimBeta = d;
    }

    public void setUseAIC(boolean bl) {
        this.m_useAIC = bl;
    }

    public int getMaxIterations() {
        return this.m_maxIterations;
    }

    protected double[][] getCoefficients() {
        double[][] dArray = new double[this.m_numClasses][this.m_numericDataHeader.numAttributes() + 1];
        for (int i = 0; i < this.m_numClasses; ++i) {
            for (int j = 0; j < this.m_numRegressions; ++j) {
                double d = this.m_regressions[i][j].getSlope();
                double d2 = this.m_regressions[i][j].getIntercept();
                int n = this.m_regressions[i][j].getAttributeIndex();
                double[] dArray2 = dArray[i];
                dArray2[0] = dArray2[0] + d2;
                double[] dArray3 = dArray[i];
                int n2 = n + 1;
                dArray3[n2] = dArray3[n2] + d;
            }
        }
        return dArray;
    }

    public double percentAttributesUsed() {
        boolean[] blArray = new boolean[this.m_numericDataHeader.numAttributes()];
        double[][] dArray = this.getCoefficients();
        for (int i = 0; i < this.m_numClasses; ++i) {
            for (int j = 1; j < this.m_numericDataHeader.numAttributes() + 1; ++j) {
                if (Utils.eq(dArray[i][j], 0.0)) continue;
                blArray[j - 1] = true;
            }
        }
        double d = 0.0;
        for (int i = 0; i < blArray.length; ++i) {
            if (!blArray[i]) continue;
            d += 1.0;
        }
        return d / (double)(this.m_numericDataHeader.numAttributes() - 1) * 100.0;
    }

    public String toString() {
        StringBuffer stringBuffer = new StringBuffer();
        int[][] nArray = this.getUsedAttributes();
        double[][] dArray = this.getCoefficients();
        for (int i = 0; i < this.m_numClasses; ++i) {
            stringBuffer.append("\nClass " + i + " :\n");
            stringBuffer.append(Utils.doubleToString(dArray[i][0], 4, 2) + " + \n");
            for (int j = 0; j < nArray[i].length; ++j) {
                stringBuffer.append("[" + this.m_numericDataHeader.attribute(nArray[i][j]).name() + "]");
                stringBuffer.append(" * " + Utils.doubleToString(dArray[i][nArray[i][j] + 1], 4, 2));
                if (j != nArray[i].length - 1) {
                    stringBuffer.append(" +");
                }
                stringBuffer.append("\n");
            }
        }
        return new String(stringBuffer);
    }

    public double[] distributionForInstance(Instance instance) throws Exception {
        instance = (Instance)instance.copy();
        instance.setDataset(this.m_numericDataHeader);
        return this.probs(this.getFs(instance));
    }

    public void cleanup() {
        this.m_train = new Instances(this.m_train, 0);
        this.m_numericData = null;
    }
}

