/*
 * Decompiled with CFR 0.152.
 */
package weka.classifiers.meta;

import java.io.Serializable;
import java.io.StreamTokenizer;
import java.io.StringReader;
import java.util.Enumeration;
import java.util.Random;
import java.util.Vector;
import weka.classifiers.Evaluation;
import weka.classifiers.RandomizableSingleClassifierEnhancer;
import weka.core.Capabilities;
import weka.core.Drawable;
import weka.core.FastVector;
import weka.core.Instance;
import weka.core.Instances;
import weka.core.Option;
import weka.core.OptionHandler;
import weka.core.Summarizable;
import weka.core.TechnicalInformation;
import weka.core.TechnicalInformationHandler;
import weka.core.Utils;

public class CVParameterSelection
extends RandomizableSingleClassifierEnhancer
implements Drawable,
Summarizable,
TechnicalInformationHandler {
    /** Serialization version identifier. */
    static final long serialVersionUID = -6529603380876641265L;
    /** Base classifier options with the searched flags removed; set in buildClassifier and combined with candidate values by createOptions. */
    protected String[] m_ClassifierOptions;
    /** Options that achieved the lowest cross-validated error; null before a search has produced a result. */
    protected String[] m_BestClassifierOptions;
    /** The base classifier's options as they were when buildClassifier started; null until a model has been built. */
    protected String[] m_InitOptions;
    /** Lowest cross-validated error seen so far; -99.0 is the "nothing evaluated yet" sentinel. */
    protected double m_BestPerformance;
    /** The CVParameter specifications to search over, in the order they were added. */
    protected FastVector m_CVParams = new FastVector();
    /** Number of attributes in the training data; substituted for an 'A' upper bound. */
    protected int m_NumAttributes;
    /** Number of instances in the first training fold; substituted for an 'I' upper bound. */
    protected int m_TrainFoldSize;
    /** Number of cross-validation folds used for the search (default 10). */
    protected int m_NumFolds = 10;

    /**
     * Builds the option array for the base classifier from the fixed options in
     * {@code m_ClassifierOptions} plus the current value of every CV parameter.
     *
     * <p>Each CV parameter contributes a {@code "-c", "value"} pair. Pairs for parameters
     * flagged {@code m_AddAtEnd} are written backwards from the tail of the array; all other
     * pairs are written at the front and are followed by {@code m_ClassifierOptions}.
     *
     * @return a new array of length {@code m_ClassifierOptions.length + 2 * m_CVParams.size()}
     */
    protected String[] createOptions() {
        String[] options = new String[this.m_ClassifierOptions.length + 2 * this.m_CVParams.size()];
        int front = 0;
        int back = options.length;
        for (int i = 0; i < this.m_CVParams.size(); ++i) {
            CVParameter cvParam = (CVParameter)this.m_CVParams.elementAt(i);
            double value = cvParam.m_ParamValue;
            if (cvParam.m_RoundParam) {
                // Math.floor(x + 0.5) rounds to nearest for negative values as well; an
                // (int) cast truncates toward zero (-4.75 -> -4) and saturates at
                // Integer.MAX_VALUE for large values.
                value = Math.floor(value + 0.5);
            }
            String flag = "-" + cvParam.m_ParamChar;
            String text = "" + Utils.doubleToString(value, 4);
            if (cvParam.m_AddAtEnd) {
                options[--back] = text;
                options[--back] = flag;
            } else {
                options[front++] = flag;
                options[front++] = text;
            }
        }
        System.arraycopy(this.m_ClassifierOptions, 0, options, front, this.m_ClassifierOptions.length);
        return options;
    }

    /**
     * Recursively tries every combination of CV parameter values and records the one with
     * the lowest cross-validated error.
     *
     * <p>While {@code n < m_CVParams.size()}, parameter {@code n} is stepped through its range
     * and the method recurses to the next parameter. At the leaf, the base classifier's options
     * are set to the current combination, an {@code m_NumFolds}-fold cross-validation is run on
     * {@code instances}, and {@code m_BestPerformance} / {@code m_BestClassifierOptions} are
     * updated if the error rate improves.
     *
     * @param n the index of the CV parameter to vary at this recursion depth
     * @param instances the training data to cross-validate on
     * @param random passed unchanged through the recursion
     * @throws Exception if setting options, training, or evaluating the base classifier fails
     */
    protected void findParamsByCrossValidation(int n, Instances instances, Random random) throws Exception {
        if (n < this.m_CVParams.size()) {
            CVParameter cVParameter = (CVParameter)this.m_CVParams.elementAt(n);
            double lower = cVParameter.m_Lower;
            double upper;
            // An upper bound of 'A' or 'I' is stored as m_Lower - 1 or m_Lower - 2; resolve it
            // to the attribute count or the training-fold size respectively.
            switch ((int)(lower - cVParameter.m_Upper + 0.5)) {
                case 1: {
                    upper = this.m_NumAttributes;
                    break;
                }
                case 2: {
                    upper = this.m_TrainFoldSize;
                    break;
                }
                default: {
                    upper = cVParameter.m_Upper;
                }
            }
            double increment = (upper - lower) / (cVParameter.m_Steps - 1.0);
            int numValues;
            if (upper < lower) {
                // Empty range (e.g. 'A'/'I' resolved below the lower bound): nothing to try.
                numValues = 0;
            } else if (upper == lower || cVParameter.m_Steps < 2.0) {
                // At most one value fits. Handling this explicitly also avoids stepping with a
                // zero (upper == lower) or negative (steps < 1) increment, which never passes upper.
                numValues = 1;
            } else {
                numValues = (int)Math.floor(cVParameter.m_Steps);
            }
            for (int k = 0; k < numValues; ++k) {
                // Derive each value from k instead of accumulating the increment, so that
                // floating-point drift cannot push the final value past upper and skip it.
                cVParameter.m_ParamValue = k == 0 ? lower : Math.min(lower + k * increment, upper);
                this.findParamsByCrossValidation(n + 1, instances, random);
            }
        } else {
            Evaluation evaluation = new Evaluation(instances);
            String[] options = this.createOptions();
            if (this.m_Debug) {
                System.err.print("Setting options for " + this.m_Classifier.getClass().getName() + ":");
                for (int i = 0; i < options.length; ++i) {
                    System.err.print(" " + options[i]);
                }
                System.err.println("");
            }
            this.m_Classifier.setOptions(options);
            for (int fold = 0; fold < this.m_NumFolds; ++fold) {
                // A fresh Random(1) per fold, presumably so that every candidate combination
                // sees the same fold randomization and the error rates stay comparable.
                Instances train = instances.trainCV(this.m_NumFolds, fold, new Random(1L));
                Instances test = instances.testCV(this.m_NumFolds, fold);
                this.m_Classifier.buildClassifier(train);
                evaluation.setPriors(train);
                evaluation.evaluateModel(this.m_Classifier, test);
            }
            double error = evaluation.errorRate();
            if (this.m_Debug) {
                System.err.println("Cross-validated error rate: " + Utils.doubleToString(error, 6, 4));
            }
            // -99.0 is the "nothing evaluated yet" sentinel set by buildClassifier.
            if (this.m_BestPerformance == -99.0 || error < this.m_BestPerformance) {
                this.m_BestPerformance = error;
                this.m_BestClassifierOptions = this.createOptions();
            }
        }
    }

    /**
     * Describes this meta-classifier for display in the GUI, followed by its literature reference.
     *
     * @return the description text
     */
    public String globalInfo() {
        String reference = this.getTechnicalInformation().toString();
        return "Class for performing parameter selection by cross-validation for any classifier.\n\n"
            + "For more information, see:\n\n"
            + reference;
    }

    /**
     * Returns the bibliographic reference (Kohavi's 1995 PhD thesis) for this method.
     *
     * @return the populated technical information record
     */
    public TechnicalInformation getTechnicalInformation() {
        TechnicalInformation.Field[] fields = new TechnicalInformation.Field[] {
            TechnicalInformation.Field.AUTHOR,
            TechnicalInformation.Field.YEAR,
            TechnicalInformation.Field.TITLE,
            TechnicalInformation.Field.SCHOOL,
            TechnicalInformation.Field.ADDRESS
        };
        String[] values = new String[] {
            "R. Kohavi",
            "1995",
            "Wrappers for Performance Enhancement and Oblivious Decision Graphs",
            "Stanford University",
            "Department of Computer Science, Stanford University"
        };
        TechnicalInformation info = new TechnicalInformation(TechnicalInformation.Type.PHDTHESIS);
        for (int idx = 0; idx < fields.length; idx++) {
            info.setValue(fields[idx], values[idx]);
        }
        return info;
    }

    /**
     * Lists this scheme's own options (-X, -P) followed by every option from the superclass.
     *
     * @return an enumeration of {@link Option} objects
     */
    public Enumeration listOptions() {
        Vector<Option> result = new Vector<Option>(2);
        result.addElement(new Option("\tNumber of folds used for cross validation (default 10).", "X", 1, "-X <number of folds>"));
        String paramDescription = "\tClassifier parameter options.\n"
            + "\teg: \"N 1 5 10\" Sets an optimisation parameter for the\n"
            + "\tclassifier with name -N, with lower bound 1, upper bound\n"
            + "\t5, and 10 optimisation steps. The upper bound may be the\n"
            + "\tcharacter 'A' or 'I' to substitute the number of\n"
            + "\tattributes or instances in the training data,\n"
            + "\trespectively. This parameter may be supplied more than\n"
            + "\tonce to optimise over several classifier options\n"
            + "\tsimultaneously.";
        result.addElement(new Option(paramDescription, "P", 1, "-P <classifier parameter>"));
        for (Enumeration inherited = super.listOptions(); inherited.hasMoreElements(); ) {
            result.addElement((Option)inherited.nextElement());
        }
        return result.elements();
    }

    /**
     * Parses -X (fold count, default 10) and any number of -P parameter specs, replacing the
     * current CV parameter list, then hands the remaining options to the superclass.
     *
     * @param stringArray the option array; consumed entries are removed by Utils.getOption
     * @throws Exception if a value is malformed or a -P spec cannot be parsed
     */
    public void setOptions(String[] stringArray) throws Exception {
        String folds = Utils.getOption('X', stringArray);
        this.setNumFolds(folds.length() == 0 ? 10 : Integer.parseInt(folds));
        this.m_CVParams = new FastVector();
        for (String spec = Utils.getOption('P', stringArray);
             spec.length() != 0;
             spec = Utils.getOption('P', stringArray)) {
            this.addCVParameter(spec);
        }
        super.setOptions(stringArray);
    }

    /**
     * Returns the current settings: one "-P spec" pair per CV parameter, "-X folds", then the
     * superclass options.
     *
     * <p>After a model has been built, the base classifier is temporarily reset to its initial
     * (pre-search) options so that {@code super.getOptions()} reports those. The selected options
     * are then put back if a search produced any. If none were produced (a failed build), the
     * initial options are left in place.
     *
     * @return the option array
     * @throws RuntimeException wrapping the underlying failure if the base classifier's options
     *         cannot be set
     */
    public String[] getOptions() {
        String[] superOptions;
        if (this.m_InitOptions != null) {
            try {
                this.m_Classifier.setOptions((String[])this.m_InitOptions.clone());
                superOptions = super.getOptions();
                if (this.m_BestClassifierOptions != null) {
                    this.m_Classifier.setOptions((String[])this.m_BestClassifierOptions.clone());
                }
            }
            catch (Exception exception) {
                // Keep the original failure as the cause instead of discarding it.
                throw new RuntimeException("CVParameterSelection: could not set options in getOptions().", exception);
            }
        } else {
            superOptions = super.getOptions();
        }
        String[] options = new String[superOptions.length + this.m_CVParams.size() * 2 + 2];
        int current = 0;
        for (int i = 0; i < this.m_CVParams.size(); ++i) {
            options[current++] = "-P";
            options[current++] = "" + this.getCVParameter(i);
        }
        options[current++] = "-X";
        options[current++] = "" + this.getNumFolds();
        System.arraycopy(superOptions, 0, options, current, superOptions.length);
        return options;
    }

    /**
     * Returns the superclass capabilities, additionally requiring at least as many training
     * instances as there are cross-validation folds.
     *
     * @return the capabilities of this classifier
     */
    public Capabilities getCapabilities() {
        Capabilities result = super.getCapabilities();
        int minimumInstances = this.m_NumFolds;
        result.setMinimumNumberInstances(minimumInstances);
        return result;
    }

    /**
     * Selects the CV parameter values with the lowest cross-validated error and then trains
     * the base classifier on all of the training data using those values.
     *
     * <p>Instances with a missing class are dropped and the copy is randomized with
     * {@code m_Seed}. It is stratified when the class is nominal. With no CV parameters
     * configured, the base classifier is simply trained with its current options.
     *
     * @param instances the training data (not modified; a copy is used)
     * @throws IllegalArgumentException if the base classifier is not an OptionHandler
     * @throws Exception if the data fails the capabilities test, no parameter value could be
     *         evaluated, or training/evaluation of the base classifier fails
     */
    public void buildClassifier(Instances instances) throws Exception {
        this.getCapabilities().testWithFail(instances);
        Instances instances2 = new Instances(instances);
        instances2.deleteWithMissingClass();
        if (!(this.m_Classifier instanceof OptionHandler)) {
            throw new IllegalArgumentException("Base classifier should be OptionHandler.");
        }
        this.m_InitOptions = this.m_Classifier.getOptions();
        this.m_BestPerformance = -99.0;
        this.m_NumAttributes = instances2.numAttributes();
        Random random = new Random(this.m_Seed);
        instances2.randomize(random);
        // Size of a training fold; used to resolve an 'I' upper bound.
        this.m_TrainFoldSize = instances2.trainCV(this.m_NumFolds, 0).numInstances();
        if (this.m_CVParams.size() == 0) {
            this.m_Classifier.buildClassifier(instances2);
            this.m_BestClassifierOptions = this.m_InitOptions;
            return;
        }
        if (instances2.classAttribute().isNominal()) {
            instances2.stratify(this.m_NumFolds);
        }
        this.m_BestClassifierOptions = null;
        this.m_ClassifierOptions = this.m_Classifier.getOptions();
        for (int i = 0; i < this.m_CVParams.size(); ++i) {
            // Called for its side effect on m_ClassifierOptions. Presumably it blanks out the
            // searched flag so that createOptions can supply candidate values; the return
            // value is ignored.
            Utils.getOption(((CVParameter)this.m_CVParams.elementAt(i)).m_ParamChar, this.m_ClassifierOptions);
        }
        this.findParamsByCrossValidation(0, instances2, random);
        if (this.m_BestClassifierOptions == null) {
            // Happens when a parameter range is empty, e.g. an 'A'/'I' bound that resolves
            // below the lower bound for this data set. Report it instead of hitting an NPE.
            throw new Exception("CVParameterSelection: no parameter values were evaluated; check that every -P range is non-empty.");
        }
        String[] stringArray = (String[])this.m_BestClassifierOptions.clone();
        this.m_Classifier.setOptions(stringArray);
        this.m_Classifier.buildClassifier(instances2);
    }

    /**
     * Predicts a class distribution by delegating to the base classifier, which buildClassifier
     * trains with the selected options.
     *
     * @param instance the instance to classify
     * @return the base classifier's class distribution
     * @throws Exception if the base classifier fails to predict
     */
    public double[] distributionForInstance(Instance instance) throws Exception {
        double[] distribution = this.m_Classifier.distributionForInstance(instance);
        return distribution;
    }

    /**
     * Parses a parameter specification such as {@code "N 1 5 10"} and appends it to the
     * search list.
     *
     * @param string the parameter specification
     * @throws Exception if the specification cannot be parsed
     */
    public void addCVParameter(String string) throws Exception {
        this.m_CVParams.addElement(new CVParameter(string));
    }

    /**
     * Returns the textual specification of the CV parameter at index {@code n}.
     *
     * @param n the parameter index
     * @return the specification, or an empty string if {@code n} is past the end of the list
     */
    public String getCVParameter(int n) {
        return n < this.m_CVParams.size()
            ? ((CVParameter)this.m_CVParams.elementAt(n)).toString()
            : "";
    }

    /**
     * Returns the GUI tip text for the CVParameters property.
     *
     * <p>The step count is the number of values tried, including both bounds: the increment is
     * (upper - lower) / (steps - 1). So searching 1..10 in increments of 1 needs 10 steps; the
     * previous example's 11 steps would have produced increments of 0.9.
     *
     * @return the tip text
     */
    public String CVParametersTipText() {
        return "Sets the scheme parameters which are to be set by cross-validation.\n"
            + "The format for each string should be:\n"
            + "param_char lower_bound upper_bound number_of_steps\n"
            + "eg to search a parameter -P from 1 to 10 by increments of 1:\n"
            + "    \"P 1 10 10\" ";
    }

    /**
     * Returns the textual specification of every CV parameter.
     *
     * @return an array whose runtime type is String[], one entry per parameter
     */
    public Object[] getCVParameters() {
        Object[] params = this.m_CVParams.toArray();
        String[] specs = new String[params.length];
        int pos = 0;
        for (Object param : params) {
            specs[pos++] = param.toString();
        }
        return specs;
    }

    /**
     * Replaces the CV parameter list with the given specifications. If any specification
     * fails to parse, the previous list is restored and the exception is rethrown.
     *
     * @param objectArray the specifications; each element is cast to String
     * @throws Exception if an element is not a String or cannot be parsed
     */
    public void setCVParameters(Object[] objectArray) throws Exception {
        FastVector previous = this.m_CVParams;
        this.m_CVParams = new FastVector();
        // Read outside the try: a null array propagates without restoring the previous list.
        int count = objectArray.length;
        try {
            for (int idx = 0; idx < count; ++idx) {
                this.addCVParameter((String)objectArray[idx]);
            }
        }
        catch (Exception failure) {
            this.m_CVParams = previous;
            throw failure;
        }
    }

    /**
     * Returns the GUI tip text for the numFolds property.
     *
     * @return the tip text
     */
    public String numFoldsTipText() {
        final String tip = "Get the number of folds used for cross-validation.";
        return tip;
    }

    /**
     * Returns the number of cross-validation folds used during the parameter search.
     *
     * @return the fold count
     */
    public int getNumFolds() {
        int folds = this.m_NumFolds;
        return folds;
    }

    /**
     * Sets the number of cross-validation folds used during the parameter search.
     *
     * <p>Zero is rejected, as the error message has always stated ("must be positive"). The
     * old {@code n < 0} check let it through, and a zero fold count makes both the
     * cross-validation loop and the minimum-instances capability meaningless.
     *
     * @param n the fold count; must be at least 1
     * @throws IllegalArgumentException if {@code n} is not positive
     */
    public void setNumFolds(int n) throws Exception {
        if (n <= 0) {
            throw new IllegalArgumentException("CVParameterSelection: Number of cross-validation folds must be positive.");
        }
        this.m_NumFolds = n;
    }

    /**
     * Reports the base classifier's graph type when it is Drawable.
     *
     * @return the base classifier's graph type, or 0 if it is not Drawable
     */
    public int graphType() {
        return this.m_Classifier instanceof Drawable
            ? ((Drawable)((Object)this.m_Classifier)).graphType()
            : 0;
    }

    /**
     * Returns the base classifier's graph description when it is Drawable.
     *
     * @return the graph produced by the base classifier
     * @throws Exception if the base classifier is not Drawable (the message names the
     *         classifier and, when known, its selected options), or if graphing fails
     */
    public String graph() throws Exception {
        if (this.m_Classifier instanceof Drawable) {
            return ((Drawable)((Object)this.m_Classifier)).graph();
        }
        // Before a model is built there are no selected options; avoid an NPE from
        // Utils.joinOptions(null) masking the intended exception.
        String options = this.m_BestClassifierOptions == null ? "" : Utils.joinOptions(this.m_BestClassifierOptions);
        throw new Exception("Classifier: " + this.m_Classifier.getClass().getName() + " " + options + " cannot be graphed");
    }

    /**
     * Describes the searched parameter ranges, the selected options and the trained base
     * classifier.
     *
     * <p>Returns the "No model built yet" message when no search result is available. That
     * covers the case where buildClassifier failed after recording the initial options, which
     * previously caused an NPE in {@code Utils.joinOptions}.
     *
     * @return a human-readable model description
     */
    public String toString() {
        if (this.m_InitOptions == null || this.m_BestClassifierOptions == null) {
            return "CVParameterSelection: No model built yet.";
        }
        StringBuilder text = new StringBuilder("Cross-validated Parameter selection.\nClassifier: ");
        text.append(this.m_Classifier.getClass().getName()).append("\n");
        try {
            for (int i = 0; i < this.m_CVParams.size(); ++i) {
                CVParameter cVParameter = (CVParameter)this.m_CVParams.elementAt(i);
                text.append("Cross-validation Parameter: '-").append(cVParameter.m_ParamChar).append("'")
                    .append(" ranged from ").append(cVParameter.m_Lower).append(" to ");
                // 'A' / 'I' upper bounds are stored as m_Lower - 1 / m_Lower - 2.
                switch ((int)(cVParameter.m_Lower - cVParameter.m_Upper + 0.5)) {
                    case 1: {
                        text.append(this.m_NumAttributes);
                        break;
                    }
                    case 2: {
                        text.append(this.m_TrainFoldSize);
                        break;
                    }
                    default: {
                        text.append(cVParameter.m_Upper);
                    }
                }
                text.append(" with ").append(cVParameter.m_Steps).append(" steps\n");
            }
        }
        catch (Exception exception) {
            text.append(exception.getMessage());
        }
        text.append("Classifier Options: ").append(Utils.joinOptions(this.m_BestClassifierOptions))
            .append("\n\n").append(this.m_Classifier.toString());
        return text.toString();
    }

    /**
     * Returns a one-line summary of the selected classifier options.
     *
     * <p>Before a model is built there are no selected options. In that case the "No model
     * built yet" message is returned instead of failing with an NPE in {@code Utils.joinOptions}.
     *
     * @return the summary, terminated by a newline
     */
    public String toSummaryString() {
        if (this.m_BestClassifierOptions == null) {
            return "CVParameterSelection: No model built yet.\n";
        }
        return "Selected values: " + Utils.joinOptions(this.m_BestClassifierOptions) + '\n';
    }

    /**
     * Command-line entry point: runs this scheme through the standard classifier driver.
     *
     * @param stringArray the command-line options
     */
    public static void main(String[] stringArray) {
        CVParameterSelection scheme = new CVParameterSelection();
        CVParameterSelection.runClassifier(scheme, stringArray);
    }

    /**
     * One parameter to optimise, parsed from a specification of the form
     * {@code "c lower upper steps [R]"}. Here {@code upper} is a number or a word starting
     * with 'A' (number of attributes) or 'I' (training-fold size), and a trailing word
     * starting with 'R' requests rounding of the tried values.
     */
    protected class CVParameter
    implements Serializable {
        static final long serialVersionUID = -4668812017709421953L;
        /** Option letter of the base classifier's parameter. */
        private char m_ParamChar;
        /** Lower bound of the search range. */
        private double m_Lower;
        /** Upper bound; m_Lower - 1 encodes 'A' and m_Lower - 2 encodes 'I'. */
        private double m_Upper;
        /** Number of values to try across the range. */
        private double m_Steps;
        /** Value currently being tried during the search. */
        private double m_ParamValue;
        /** Whether the option pair is placed at the end of the generated option array. */
        private boolean m_AddAtEnd;
        /** Whether tried values are rounded to whole numbers. */
        private boolean m_RoundParam;

        /**
         * Parses a parameter specification.
         *
         * @param string the specification, e.g. {@code "N 1 5 10"} or {@code "K 1 A 5 R"}
         * @throws Exception if a token is missing or has the wrong type, or if a numeric
         *         upper bound is below the lower bound
         */
        public CVParameter(String string) throws Exception {
            StreamTokenizer tokenizer = new StreamTokenizer(new StringReader(string));
            if (tokenizer.nextToken() != StreamTokenizer.TT_WORD) {
                throw new Exception("CVParameter " + string + ": Character parameter identifier expected");
            }
            this.m_ParamChar = tokenizer.sval.charAt(0);
            if (tokenizer.nextToken() != StreamTokenizer.TT_NUMBER) {
                throw new Exception("CVParameter " + string + ": Numeric lower bound expected");
            }
            this.m_Lower = tokenizer.nval;
            if (tokenizer.nextToken() == StreamTokenizer.TT_NUMBER) {
                this.m_Upper = tokenizer.nval;
                if (this.m_Upper < this.m_Lower) {
                    throw new Exception("CVParameter " + string + ": Upper bound is less than lower bound");
                }
            } else {
                // Character.toUpperCase is locale-independent. String.toUpperCase() in a
                // Turkish locale maps "i" to a dotted capital I, which would reject 'i'.
                char bound = tokenizer.ttype == StreamTokenizer.TT_WORD
                    ? Character.toUpperCase(tokenizer.sval.charAt(0))
                    : '\0';
                if (bound == 'A') {
                    this.m_Upper = this.m_Lower - 1.0;
                } else if (bound == 'I') {
                    this.m_Upper = this.m_Lower - 2.0;
                } else {
                    // The accepted symbolic bounds are 'A' and 'I' (the old message said 'N').
                    throw new Exception("CVParameter " + string + ": Upper bound must be numeric, or 'A' or 'I'");
                }
            }
            if (tokenizer.nextToken() != StreamTokenizer.TT_NUMBER) {
                throw new Exception("CVParameter " + string + ": Numeric number of steps expected");
            }
            this.m_Steps = tokenizer.nval;
            if (tokenizer.nextToken() == StreamTokenizer.TT_WORD
                    && Character.toUpperCase(tokenizer.sval.charAt(0)) == 'R') {
                this.m_RoundParam = true;
            }
        }

        /**
         * Renders this parameter back into the specification syntax accepted by the constructor.
         *
         * @return e.g. {@code "N 1.0 5.0 10.0"} or {@code "K 1.0 A 5.0 R"}
         */
        public String toString() {
            String string = this.m_ParamChar + " " + this.m_Lower + " ";
            switch ((int)(this.m_Lower - this.m_Upper + 0.5)) {
                case 1: {
                    string = string + "A";
                    break;
                }
                case 2: {
                    string = string + "I";
                    break;
                }
                default: {
                    string = string + this.m_Upper;
                }
            }
            string = string + " " + this.m_Steps;
            if (this.m_RoundParam) {
                string = string + " R";
            }
            return string;
        }
    }
}

