/*
 * Decompiled with CFR 0.152.
 */
package weka.classifiers.functions;

import java.util.ArrayList;
import java.util.Enumeration;
import java.util.Random;
import java.util.Vector;
import weka.classifiers.RandomizableClassifier;
import weka.classifiers.UpdateableClassifier;
import weka.core.Aggregateable;
import weka.core.Capabilities;
import weka.core.Instance;
import weka.core.Instances;
import weka.core.Option;
import weka.core.OptionHandler;
import weka.core.RevisionUtils;
import weka.core.SelectedTag;
import weka.core.Tag;
import weka.core.Utils;
import weka.filters.Filter;
import weka.filters.unsupervised.attribute.NominalToBinary;
import weka.filters.unsupervised.attribute.Normalize;
import weka.filters.unsupervised.attribute.ReplaceMissingValues;

/*
 * This class specifies class file version 49.0 but uses Java 6 signatures.  Assumed Java 6.
 */
/**
 * Linear model trained by stochastic gradient descent.
 *
 * <p>Depending on the selected loss function the learner fits a binary-class model
 * (hinge loss or log loss) or a numeric-class model (squared, epsilon-insensitive or
 * Huber loss). The weight vector holds one slot per attribute of the filtered training
 * header (the slot at the class index is never read or updated) followed by one
 * trailing slot for the bias term.
 */
public class SGD
        extends RandomizableClassifier
        implements UpdateableClassifier, OptionHandler, Aggregateable<SGD> {
    private static final long serialVersionUID = -3732968666673530290L;

    /** Missing-value replacement filter; {@code null} when not in use. */
    protected ReplaceMissingValues m_replaceMissing;

    /** Nominal-to-binary filter; {@code null} when all input attributes are numeric. */
    protected Filter m_nominalToBinary;

    /** Normalization filter; {@code null} when not in use. */
    protected Normalize m_normalize;

    /** Regularization constant. */
    protected double m_lambda = 1.0E-4;

    /** Learning rate. */
    protected double m_learningRate = 0.01;

    /** Attribute weights followed by the bias in the last slot; {@code null} until a model is built. */
    protected double[] m_weights;

    /** Threshold used by the epsilon-insensitive and Huber losses. */
    protected double m_epsilon = 0.001;

    /** Update counter: set to 1 by {@link #reset()} and incremented after every processed instance. */
    protected double m_t;

    /** Number of instances seen by the last batch build (0 when built from a header only). */
    protected double m_numInstances;

    /** Number of passes over the data in batch training. */
    protected int m_epochs = 500;

    /** When true, the normalization filter is not created by {@link #buildClassifier(Instances)}. */
    protected boolean m_dontNormalize = false;

    /** When true, the missing-value filter is not created by {@link #buildClassifier(Instances)}. */
    protected boolean m_dontReplaceMissing = false;

    /** Header (no instances) of the filtered training data. */
    protected Instances m_data;

    public static final int HINGE = 0;
    public static final int LOGLOSS = 1;
    public static final int SQUAREDLOSS = 2;
    public static final int EPSILON_INSENSITIVE = 3;
    public static final int HUBER = 4;

    /** Currently selected loss function (one of the constants above). */
    protected int m_loss = HINGE;

    /** Tags describing the selectable loss functions. */
    public static final Tag[] TAGS_SELECTION = new Tag[]{
        new Tag(HINGE, "Hinge loss (SVM)"),
        new Tag(LOGLOSS, "Log loss (logistic regression)"),
        new Tag(SQUAREDLOSS, "Squared loss (regression)"),
        new Tag(EPSILON_INSENSITIVE, "Epsilon-insensitive loss (SVM regression)"),
        new Tag(HUBER, "Huber loss (robust regression)")
    };

    /** Number of models added via {@link #aggregate(SGD)} since the last finalization. */
    protected int m_numModels = 0;

    /**
     * Returns the capabilities of this classifier. A numeric class is enabled for the
     * squared, epsilon-insensitive and Huber losses; a binary class otherwise.
     *
     * @return the capabilities
     */
    @Override
    public Capabilities getCapabilities() {
        Capabilities result = super.getCapabilities();
        result.disableAll();

        result.enable(Capabilities.Capability.NOMINAL_ATTRIBUTES);
        result.enable(Capabilities.Capability.NUMERIC_ATTRIBUTES);
        result.enable(Capabilities.Capability.MISSING_VALUES);

        boolean regressionLoss = this.m_loss == SQUAREDLOSS
                || this.m_loss == EPSILON_INSENSITIVE
                || this.m_loss == HUBER;
        if (regressionLoss) {
            result.enable(Capabilities.Capability.NUMERIC_CLASS);
        } else {
            result.enable(Capabilities.Capability.BINARY_CLASS);
        }
        result.enable(Capabilities.Capability.MISSING_CLASS_VALUES);

        result.setMinimumNumberInstances(0);
        return result;
    }

    /** @return tip text for the epsilon property */
    public String epsilonTipText() {
        return "The epsilon threshold for epsilon insensitive and Huber loss. "
                + "An error with absolute value less than this threshold has loss of 0 for epsilon insensitive loss. "
                + "For Huber loss this is the boundary between the quadratic and linear parts of the loss function.";
    }

    /** @param e the epsilon threshold */
    public void setEpsilon(double e) {
        this.m_epsilon = e;
    }

    /** @return the epsilon threshold */
    public double getEpsilon() {
        return this.m_epsilon;
    }

    /** @return tip text for the lambda property */
    public String lambdaTipText() {
        return "The regularization constant. (default = 0.0001)";
    }

    /** @param lambda the regularization constant */
    public void setLambda(double lambda) {
        this.m_lambda = lambda;
    }

    /** @return the regularization constant */
    public double getLambda() {
        return this.m_lambda;
    }

    /** @param lr the learning rate */
    public void setLearningRate(double lr) {
        this.m_learningRate = lr;
    }

    /** @return the learning rate */
    public double getLearningRate() {
        return this.m_learningRate;
    }

    /** @return tip text for the learning rate property */
    public String learningRateTipText() {
        return "The learning rate. If normalization is turned off (as it is automatically for streaming data), "
                + "then the default learning rate will need to be reduced (try 0.0001).";
    }

    /** @return tip text for the epochs property */
    public String epochsTipText() {
        return "The number of epochs to perform (batch learning). The total number of iterations is epochs * num instances.";
    }

    /** @param e the number of epochs for batch training */
    public void setEpochs(int e) {
        this.m_epochs = e;
    }

    /** @return the number of epochs for batch training */
    public int getEpochs() {
        return this.m_epochs;
    }

    /** @param m true to turn normalization off */
    public void setDontNormalize(boolean m) {
        this.m_dontNormalize = m;
    }

    /** @return true if normalization is turned off */
    public boolean getDontNormalize() {
        return this.m_dontNormalize;
    }

    /** @return tip text for the dontNormalize property */
    public String dontNormalizeTipText() {
        return "Turn normalization off";
    }

    /** @param m true to turn missing-value replacement off */
    public void setDontReplaceMissing(boolean m) {
        this.m_dontReplaceMissing = m;
    }

    /** @return true if missing-value replacement is turned off */
    public boolean getDontReplaceMissing() {
        return this.m_dontReplaceMissing;
    }

    /** @return tip text for the dontReplaceMissing property */
    public String dontReplaceMissingTipText() {
        return "Turn off global replacement of missing values";
    }

    /**
     * Selects the loss function. The request is ignored unless the tag was created
     * from {@link #TAGS_SELECTION} (the arrays are compared by identity).
     *
     * @param function the selected loss function tag
     */
    public void setLossFunction(SelectedTag function) {
        if (function.getTags() == TAGS_SELECTION) {
            this.m_loss = function.getSelectedTag().getID();
        }
    }

    /** @return the currently selected loss function as a tag over {@link #TAGS_SELECTION} */
    public SelectedTag getLossFunction() {
        return new SelectedTag(this.m_loss, TAGS_SELECTION);
    }

    /** @return tip text for the loss function property */
    public String lossFunctionTipText() {
        return "The loss function to use. Hinge loss (SVM), log loss (logistic regression), "
                + "squared loss (regression), epsilon-insensitive loss (SVM regression) "
                + "or Huber loss (robust regression).";
    }

    /**
     * Lists the command-line options understood by {@link #setOptions(String[])}.
     *
     * @return an enumeration of the available options
     */
    @Override
    public Enumeration<Option> listOptions() {
        Vector<Option> newVector = new Vector<Option>();
        newVector.add(new Option("\tSet the loss function to minimize.\n\t0 = hinge loss (SVM), 1 = log loss (logistic regression),\n\t2 = squared loss (regression), 3 = epsilon insensitive loss (regression),\n\t4 = Huber loss (regression).\n\t(default = 0)", "F", 1, "-F"));
        newVector.add(new Option("\tThe learning rate. If normalization is\n\tturned off (as it is automatically for streaming data), then the\n\tdefault learning rate will need to be reduced (try 0.0001).\n\t(default = 0.01).", "L", 1, "-L"));
        newVector.add(new Option("\tThe lambda regularization constant (default = 0.0001)", "R", 1, "-R <double>"));
        newVector.add(new Option("\tThe number of epochs to perform (batch learning only, default = 500)", "E", 1, "-E <integer>"));
        newVector.add(new Option("\tThe epsilon threshold (epsilon-insensitive and Huber loss only, default = 1e-3)", "C", 1, "-C <double>"));
        newVector.add(new Option("\tDon't normalize the data", "N", 0, "-N"));
        newVector.add(new Option("\tDon't replace missing values", "M", 0, "-M"));
        return newVector.elements();
    }

    /**
     * Parses the given options. Any existing model is discarded first via {@link #reset()}.
     * Options with a value (-F, -R, -L, -E, -C) only change the setting when present;
     * the flags -N and -M are set to whether they occur in the array.
     *
     * @param options the option array to parse
     * @throws Exception if the superclass rejects the options or a value cannot be parsed
     */
    @Override
    public void setOptions(String[] options) throws Exception {
        this.reset();
        super.setOptions(options);

        String lossString = Utils.getOption('F', options);
        if (lossString.length() != 0) {
            this.setLossFunction(new SelectedTag(Integer.parseInt(lossString), TAGS_SELECTION));
        }

        String lambdaString = Utils.getOption('R', options);
        if (lambdaString.length() > 0) {
            this.setLambda(Double.parseDouble(lambdaString));
        }

        String learningRateString = Utils.getOption('L', options);
        if (learningRateString.length() > 0) {
            this.setLearningRate(Double.parseDouble(learningRateString));
        }

        String epochsString = Utils.getOption("E", options);
        if (epochsString.length() > 0) {
            this.setEpochs(Integer.parseInt(epochsString));
        }

        String epsilonString = Utils.getOption("C", options);
        if (epsilonString.length() > 0) {
            this.setEpsilon(Double.parseDouble(epsilonString));
        }

        this.setDontNormalize(Utils.getFlag("N", options));
        this.setDontReplaceMissing(Utils.getFlag('M', options));
    }

    /**
     * Returns the current settings of this class (loss, learning rate, lambda, epochs,
     * epsilon and the two flags) as an option array.
     *
     * @return the option array
     */
    @Override
    public String[] getOptions() {
        ArrayList<String> options = new ArrayList<String>();
        options.add("-F");
        options.add("" + this.getLossFunction().getSelectedTag().getID());
        options.add("-L");
        options.add("" + this.getLearningRate());
        options.add("-R");
        options.add("" + this.getLambda());
        options.add("-E");
        options.add("" + this.getEpochs());
        options.add("-C");
        options.add("" + this.getEpsilon());
        if (this.getDontNormalize()) {
            options.add("-N");
        }
        if (this.getDontReplaceMissing()) {
            options.add("-M");
        }
        return options.toArray(new String[0]);
    }

    /** @return a description of this classifier */
    public String globalInfo() {
        return "Implements stochastic gradient descent for learning various linear models (binary class SVM, binary class logistic regression, squared loss, Huber loss and epsilon-insensitive loss linear regression). Globally replaces all missing values and transforms nominal attributes into binary ones. It also normalizes all attributes, so the coefficients in the output are based on the normalized data.\nFor numeric class attributes, the squared, Huber or epsilon-insensitive loss function must be used. Epsilon-insensitive and Huber loss may require a much higher learning rate.";
    }

    /** Discards the current weights and restarts the update counter at 1. */
    public void reset() {
        this.m_t = 1.0;
        this.m_weights = null;
    }

    /**
     * Builds the model from scratch. Instances with a missing class are removed, the
     * configured filters are created and applied, and, if any instances remain, the data
     * is shuffled with the classifier's seed and trained for the configured number of
     * epochs. With zero instances only the header and a zero weight vector are set up.
     *
     * @param data the training data (not modified; a copy is used)
     * @throws Exception if the data fails the capabilities test or a filter fails
     */
    @Override
    public void buildClassifier(Instances data) throws Exception {
        this.reset();
        // Drop filters from a previous build: each one is only re-created conditionally
        // below, and a stale filter would otherwise still be applied at prediction time.
        this.m_replaceMissing = null;
        this.m_nominalToBinary = null;
        this.m_normalize = null;

        this.getCapabilities().testWithFail(data);

        data = new Instances(data);
        data.deleteWithMissingClass();
        boolean hasInstances = data.numInstances() > 0;

        if (hasInstances && !this.m_dontReplaceMissing) {
            this.m_replaceMissing = new ReplaceMissingValues();
            this.m_replaceMissing.setInputFormat(data);
            data = Filter.useFilter(data, this.m_replaceMissing);
        }

        boolean onlyNumeric = true;
        for (int i = 0; i < data.numAttributes(); ++i) {
            if (i != data.classIndex() && !data.attribute(i).isNumeric()) {
                onlyNumeric = false;
                break;
            }
        }
        if (!onlyNumeric) {
            // The supervised variant is used when there is data to train on; the
            // unsupervised variant is used when only a header is available.
            this.m_nominalToBinary = hasInstances
                    ? new weka.filters.supervised.attribute.NominalToBinary()
                    : new NominalToBinary();
            this.m_nominalToBinary.setInputFormat(data);
            data = Filter.useFilter(data, this.m_nominalToBinary);
        }

        if (!this.m_dontNormalize && data.numInstances() > 0) {
            this.m_normalize = new Normalize();
            this.m_normalize.setInputFormat(data);
            data = Filter.useFilter(data, this.m_normalize);
        }

        this.m_numInstances = data.numInstances();
        this.m_weights = new double[data.numAttributes() + 1];
        this.m_data = new Instances(data, 0);

        if (data.numInstances() > 0) {
            data.randomize(new Random(this.getSeed()));
            this.train(data);
        }
    }

    /**
     * Returns the gradient term for the current loss function.
     *
     * <ul>
     * <li>hinge: 1 when {@code z < 1}, otherwise 0</li>
     * <li>log loss: {@code 1 / (1 + e^z)}, evaluated so that only the exponential of a
     * non-positive number is taken</li>
     * <li>epsilon-insensitive: the sign of {@code z} when {@code |z| > epsilon}, otherwise 0</li>
     * <li>Huber: {@code z} clamped to {@code [-epsilon, epsilon]}</li>
     * <li>otherwise (squared loss): {@code z}</li>
     * </ul>
     *
     * @param z the margin (nominal class) or residual (numeric class)
     * @return the gradient term
     */
    protected double dloss(double z) {
        switch (this.m_loss) {
            case HINGE:
                return z < 1.0 ? 1.0 : 0.0;

            case LOGLOSS:
                if (z < 0.0) {
                    return 1.0 / (Math.exp(z) + 1.0);
                }
                double t = Math.exp(-z);
                return t / (t + 1.0);

            case EPSILON_INSENSITIVE:
                if (z > this.m_epsilon) {
                    return 1.0;
                }
                if (-z > this.m_epsilon) {
                    return -1.0;
                }
                return 0.0;

            case HUBER:
                if (Math.abs(z) <= this.m_epsilon) {
                    return z;
                }
                return z > 0.0 ? this.m_epsilon : -this.m_epsilon;

            default:
                return z;
        }
    }

    /**
     * Runs the configured number of epochs over the (already filtered) data.
     *
     * @param data the filtered training data
     * @throws Exception if an update fails
     */
    private void train(Instances data) throws Exception {
        for (int epoch = 0; epoch < this.m_epochs; ++epoch) {
            for (int i = 0; i < data.numInstances(); ++i) {
                this.updateClassifier(data.instance(i), false);
            }
        }
    }

    /**
     * Computes the dot product between an instance and the attribute weights. The class
     * attribute, missing values and the trailing bias slot of {@code weights} are skipped.
     *
     * @param inst1 the instance
     * @param weights attribute weights indexed by attribute index, with the bias in the last slot
     * @param classIndex index of the class attribute
     * @return the weighted sum over the instance's stored values
     */
    protected static double dotProd(Instance inst1, double[] weights, int classIndex) {
        double result = 0.0;
        int numStored = inst1.numValues();
        int numAttributeWeights = weights.length - 1;
        for (int p = 0; p < numStored; ++p) {
            // Weights are addressed by attribute index, the same slot that
            // updateClassifier writes to.
            int attIndex = inst1.index(p);
            if (attIndex < numAttributeWeights
                    && attIndex != classIndex
                    && !inst1.isMissingSparse(p)) {
                result += inst1.valueSparse(p) * weights[attIndex];
            }
        }
        return result;
    }

    /**
     * Passes an instance through whichever of the three filters are currently set, in
     * the order missing-value replacement, nominal-to-binary, normalization.
     *
     * @param instance the raw instance
     * @return the filtered instance
     * @throws Exception if a filter fails
     */
    private Instance filterInstance(Instance instance) throws Exception {
        if (this.m_replaceMissing != null) {
            this.m_replaceMissing.input(instance);
            instance = this.m_replaceMissing.output();
        }
        if (this.m_nominalToBinary != null) {
            this.m_nominalToBinary.input(instance);
            instance = this.m_nominalToBinary.output();
        }
        if (this.m_normalize != null) {
            this.m_normalize.input(instance);
            instance = this.m_normalize.output();
        }
        return instance;
    }

    /**
     * Performs one stochastic gradient step for the given instance. Instances with a
     * missing class are ignored entirely (the update counter is not advanced either).
     *
     * @param instance the training instance
     * @param filter true if the instance still has to be passed through the filters
     * @throws Exception if filtering fails
     */
    protected void updateClassifier(Instance instance, boolean filter) throws Exception {
        if (instance.classIsMissing()) {
            return;
        }
        if (filter) {
            instance = this.filterInstance(instance);
        }

        int biasIndex = this.m_weights.length - 1;
        double wx = SGD.dotProd(instance, this.m_weights, instance.classIndex());

        double y;
        double z;
        if (instance.classAttribute().isNominal()) {
            // First class label maps to -1, anything else to +1; z is the margin.
            y = instance.classValue() == 0.0 ? -1.0 : 1.0;
            z = y * (wx + this.m_weights[biasIndex]);
        } else {
            // Numeric class: z is the residual and the update direction is +1.
            y = instance.classValue();
            z = y - (wx + this.m_weights[biasIndex]);
            y = 1.0;
        }

        // Regularization shrinkage of the attribute weights (the bias is not shrunk).
        // When no batch size is known, the update counter is used as the divisor.
        double divisor = this.m_numInstances == 0.0 ? this.m_t : this.m_numInstances;
        double multiplier = 1.0 - this.m_learningRate * this.m_lambda / divisor;
        for (int i = 0; i < biasIndex; ++i) {
            this.m_weights[i] *= multiplier;
        }

        boolean takeGradientStep = this.m_loss == SQUAREDLOSS
                || this.m_loss == LOGLOSS
                || this.m_loss == HUBER
                || (this.m_loss == HINGE && z < 1.0)
                || (this.m_loss == EPSILON_INSENSITIVE && Math.abs(z) > this.m_epsilon);
        if (takeGradientStep) {
            double factor = this.m_learningRate * y * this.dloss(z);
            int numStored = instance.numValues();
            for (int p = 0; p < numStored; ++p) {
                int attIndex = instance.index(p);
                if (attIndex == instance.classIndex() || instance.isMissingSparse(p)) {
                    continue;
                }
                this.m_weights[attIndex] += factor * instance.valueSparse(p);
            }
            this.m_weights[biasIndex] += factor;
        }

        this.m_t += 1.0;
    }

    /**
     * Updates the model with a single raw (unfiltered) instance.
     *
     * @param instance the new training instance
     * @throws Exception if filtering fails
     */
    @Override
    public void updateClassifier(Instance instance) throws Exception {
        this.updateClassifier(instance, true);
    }

    /**
     * Predicts for a raw instance. For a numeric class the single-element result is the
     * linear output. For a nominal class the result has two elements: with log loss they
     * are the logistic probabilities, otherwise the predicted class gets 1 and the other 0
     * (a linear output of exactly 0 predicts the first class).
     *
     * @param inst the instance to predict
     * @return the prediction array
     * @throws Exception if filtering fails
     */
    @Override
    public double[] distributionForInstance(Instance inst) throws Exception {
        double[] result = inst.classAttribute().isNominal() ? new double[2] : new double[1];

        inst = this.filterInstance(inst);

        double wx = SGD.dotProd(inst, this.m_weights, inst.classIndex());
        double z = wx + this.m_weights[this.m_weights.length - 1];

        if (inst.classAttribute().isNumeric()) {
            result[0] = z;
            return result;
        }

        if (z <= 0.0) {
            if (this.m_loss == LOGLOSS) {
                result[0] = 1.0 / (1.0 + Math.exp(z));
                result[1] = 1.0 - result[0];
            } else {
                result[0] = 1.0;
            }
        } else if (this.m_loss == LOGLOSS) {
            result[1] = 1.0 / (1.0 + Math.exp(-z));
            result[0] = 1.0 - result[1];
        } else {
            result[1] = 1.0;
        }
        return result;
    }

    /** @return the internal weight array (not a copy), or {@code null} if no model is built */
    public double[] getWeights() {
        return this.m_weights;
    }

    /**
     * Returns a textual description of the model: the loss function followed by the
     * linear expression over the filtered attributes and the bias term.
     *
     * @return the model description
     */
    @Override
    public String toString() {
        if (this.m_weights == null) {
            return "SGD: No model built yet.\n";
        }

        StringBuilder buff = new StringBuilder();
        buff.append("Loss function: ");
        switch (this.m_loss) {
            case HINGE:
                buff.append("Hinge loss (SVM)\n\n");
                break;
            case LOGLOSS:
                buff.append("Log loss (logistic regression)\n\n");
                break;
            case EPSILON_INSENSITIVE:
                buff.append("Epsilon-insensitive loss (SVM regression)\n\n");
                break;
            case HUBER:
                buff.append("Huber loss (robust regression)\n\n");
                break;
            default:
                buff.append("Squared loss (linear regression)\n\n");
                break;
        }

        buff.append(this.m_data.classAttribute().name() + " = \n\n");

        int biasIndex = this.m_weights.length - 1;
        String normalizedNote = this.m_normalize != null ? "(normalized) " : "";
        int printed = 0;
        for (int i = 0; i < biasIndex; ++i) {
            if (i == this.m_data.classIndex()) {
                continue;
            }
            buff.append(printed > 0 ? " + " : "   ");
            buff.append(Utils.doubleToString(this.m_weights[i], 12, 4) + " " + normalizedNote + this.m_data.attribute(i).name() + "\n");
            ++printed;
        }

        double bias = this.m_weights[biasIndex];
        if (bias > 0.0) {
            buff.append(" + " + Utils.doubleToString(bias, 12, 4));
        } else {
            buff.append(" - " + Utils.doubleToString(-bias, 12, 4));
        }
        return buff.toString();
    }

    /** @return the revision string */
    @Override
    public String getRevision() {
        return RevisionUtils.extract("$Revision: 9785 $");
    }

    /**
     * Adds the weights (including the bias) of another model to this one. Call
     * {@link #finalizeAggregation()} afterwards to turn the sums into averages.
     *
     * @param toAggregate the model whose weights are added
     * @return this model
     * @throws Exception if either model is not built, or the headers or weight lengths differ
     */
    @Override
    public SGD aggregate(SGD toAggregate) throws Exception {
        if (this.m_weights == null) {
            throw new Exception("No model built yet, can't aggregate");
        }
        double[] otherWeights = toAggregate.getWeights();
        if (otherWeights == null) {
            throw new Exception("Can't aggregate - the SGD to aggregate has no model built yet");
        }
        if (!this.m_data.equalHeaders(toAggregate.m_data)) {
            throw new Exception("Can't aggregate - data headers don't match: " + this.m_data.equalHeadersMsg(toAggregate.m_data));
        }
        if (this.m_weights.length != otherWeights.length) {
            throw new Exception("Can't aggregate - SGD to aggregate has weight vector that differs in length from ours.");
        }

        for (int i = 0; i < this.m_weights.length; ++i) {
            this.m_weights[i] += otherWeights[i];
        }
        ++this.m_numModels;
        return this;
    }

    /**
     * Divides the summed weights by the number of aggregated models plus one (this
     * model's own weights are part of the sum) and resets the aggregation counter.
     *
     * @throws Exception if no model has been aggregated
     */
    @Override
    public void finalizeAggregation() throws Exception {
        if (this.m_numModels == 0) {
            throw new Exception("Unable to finalize aggregation - haven't seen any models to aggregate");
        }
        double divisor = (double) (this.m_numModels + 1);
        for (int i = 0; i < this.m_weights.length; ++i) {
            this.m_weights[i] /= divisor;
        }
        this.m_numModels = 0;
    }

    /**
     * Command-line entry point.
     *
     * @param args the command-line options
     */
    public static void main(String[] args) {
        SGD.runClassifier(new SGD(), args);
    }
}

