/*
 * Decompiled with CFR 0.152.
 */
package dr.inference.operators;

import cern.colt.matrix.impl.DenseDoubleMatrix2D;
import cern.colt.matrix.linalg.SingularValueDecomposition;
import dr.inference.model.CompoundParameter;
import dr.inference.model.MatrixParameter;
import dr.inference.model.Parameter;
import dr.inference.operators.AbstractAdaptableOperator;
import dr.inference.operators.AdaptationMode;
import dr.inference.operators.MCMCOperator;
import dr.math.MathUtils;
import dr.math.matrixAlgebra.CholeskyDecomposition;
import dr.math.matrixAlgebra.IllegalDimension;
import dr.math.matrixAlgebra.SymmetricMatrix;
import dr.util.Author;
import dr.util.Citable;
import dr.util.Citation;
import dr.util.Transform;
import dr.xml.AbstractXMLObjectParser;
import dr.xml.AttributeRule;
import dr.xml.ElementRule;
import dr.xml.XMLObject;
import dr.xml.XMLObjectParser;
import dr.xml.XMLParseException;
import dr.xml.XMLSyntaxRule;
import java.util.ArrayList;
import java.util.Collections;
import java.util.List;

/**
 * Adaptive multivariate normal random-walk operator (AVMVN).
 *
 * <p>The parameter is first mapped through per-group {@link Transform}s. A multivariate normal
 * step is then drawn from a proposal covariance matrix. After an initial period, that matrix
 * becomes a mixture of the running empirical covariance of the transformed samples (weight
 * {@code 1 - beta}) and a fixed starting matrix (weight {@code beta}). The result is mapped back
 * through the inverse transforms. {@link #doOperation()} returns the log-Jacobian correction of
 * the transforms.
 */
public class AdaptableVarianceMultivariateNormalOperator
extends AbstractAdaptableOperator
implements Citable {
    public static final String AVMVN_OPERATOR = "adaptableVarianceMultivariateNormalOperator";
    public static final String SCALE_FACTOR = "scaleFactor";
    public static final String BETA = "beta";
    public static final String INITIAL = "initial";
    public static final String BURNIN = "burnin";
    public static final String UPDATE_EVERY = "updateEvery";
    public static final String FORM_XTX = "formXtXInverse";
    public static final String COEFFICIENT = "coefficient";
    public static final String SKIP_RANK_CHECK = "skipRankCheck";
    public static final String TRANSFORM = "transform";
    public static final String TYPE = "type";
    public static final boolean DEBUG = false;
    public static final boolean PRINT_FULL_MATRIX = false;
    private double scaleFactor;
    private double beta;
    private int iterations;
    private int updates;
    private int initial;
    private int burnin;
    private int every;
    private final Parameter parameter;
    // One entry per transformation group. A group covers either a single element, or
    // transformationSizes[g] consecutive elements for a multivariate (constrained-sum) transform.
    private final Transform[] transformations;
    private final int[] transformationSizes;
    private final double[] transformationSums;
    private final int dim;
    // After each adaptive update and the end-of-step swap, oldMeans holds the current running mean.
    private double[] oldMeans;
    private double[] newMeans;
    // Fixed (starting) variance matrix, mixed into the proposal with weight beta.
    final double[][] matrix;
    private double[][] empirical;
    private double[][] cholesky;
    private double[] epsilon;
    private double[][] proposal;
    public static final boolean MULTI = true;

    public static XMLObjectParser PARSER = new AbstractXMLObjectParser() {
        private final XMLSyntaxRule[] rules = new XMLSyntaxRule[]{
                AttributeRule.newDoubleRule(SCALE_FACTOR),
                AttributeRule.newDoubleRule("weight"),
                AttributeRule.newDoubleRule(BETA),
                AttributeRule.newDoubleRule(COEFFICIENT),
                AttributeRule.newIntegerRule(INITIAL),
                AttributeRule.newIntegerRule(BURNIN, true),
                AttributeRule.newIntegerRule(UPDATE_EVERY, true),
                AttributeRule.newBooleanRule("autoOptimize", true),
                AttributeRule.newBooleanRule(FORM_XTX, true),
                AttributeRule.newBooleanRule(SKIP_RANK_CHECK, true),
                new ElementRule(Parameter.class, 0, Integer.MAX_VALUE),
                new ElementRule(Transform.ParsedTransform.class, 0, Integer.MAX_VALUE)
        };

        @Override
        public String getParserName() {
            return AdaptableVarianceMultivariateNormalOperator.AVMVN_OPERATOR;
        }

        /**
         * Builds the operator from XML.
         *
         * <p>Two layouts are supported:
         * <ul>
         *   <li>If the first transform child lists its own parameters, every {@code Parameter} and
         *   transform child is concatenated into one {@link CompoundParameter}. Bare parameters get
         *   {@code Transform.NONE}.</li>
         *   <li>Otherwise a single {@code Parameter} child is used, and each transform child applies
         *   to the index range {@code [start, end)} of that parameter.</li>
         * </ul>
         *
         * @throws XMLParseException on invalid burn-in, update interval, scale factor, a missing
         *     transform or parameter, an unknown child element, or an empty transformation group
         */
        @Override
        public Object parseXMLObject(XMLObject xo) throws XMLParseException {
            AdaptationMode adaptationMode = AdaptationMode.parseMode(xo);
            double weight = xo.getDoubleAttribute("weight");
            double betaValue = xo.getDoubleAttribute(BETA);
            int initialValue = xo.getIntegerAttribute(INITIAL);
            double scale = xo.getDoubleAttribute(SCALE_FACTOR);
            double coefficient = xo.getDoubleAttribute(COEFFICIENT);

            int burnIn = 0;
            if (xo.hasAttribute(BURNIN)) {
                burnIn = xo.getIntegerAttribute(BURNIN);
            }
            if (burnIn > initialValue || burnIn < 0) {
                throw new XMLParseException("Burn-in must be smaller than the initial period.");
            }
            int updateEvery = 1;
            if (xo.hasAttribute(UPDATE_EVERY)) {
                updateEvery = xo.getIntegerAttribute(UPDATE_EVERY);
            }
            if (updateEvery <= 0) {
                throw new XMLParseException("Covariance matrix needs to be updated at least every single iteration.");
            }
            if (scale <= 0.0) {
                throw new XMLParseException("ScaleFactor must be greater than zero.");
            }

            boolean formXtXInverse = xo.getAttribute(FORM_XTX, false);
            Transform.ParsedTransform firstTransform =
                    (Transform.ParsedTransform) xo.getChild(Transform.ParsedTransform.class);
            if (firstTransform == null) {
                throw new XMLParseException("No valid transformations have been provided in the XML file.");
            }
            boolean usesIndexRanges = firstTransform.parameters == null;

            List<Transform> groupTransforms = new ArrayList<Transform>();
            List<Integer> groupSizes = new ArrayList<Integer>();
            List<Double> groupSums = new ArrayList<Double>();
            final Parameter parameter;

            if (!usesIndexRanges) {
                CompoundParameter allParameters = new CompoundParameter("allParameters");
                for (Object child : xo.getChildren()) {
                    if (child instanceof Parameter) {
                        Parameter p = (Parameter) child;
                        allParameters.addParameter(p);
                        addGroup(Transform.NONE, p.getDimension(), 0.0,
                                groupTransforms, groupSizes, groupSums);
                    } else if (child instanceof Transform.ParsedTransform) {
                        Transform.ParsedTransform parsed = (Transform.ParsedTransform) child;
                        int childDimension = 0;
                        for (Parameter p : parsed.parameters) {
                            allParameters.addParameter(p);
                            childDimension += p.getDimension();
                        }
                        addGroup(parsed.transform, childDimension, parsed.fixedSum,
                                groupTransforms, groupSizes, groupSums);
                    } else {
                        throw new XMLParseException("Unknown element in adaptableVarianceMultivariateNormalOperator");
                    }
                }
                parameter = allParameters;
            } else {
                parameter = (Parameter) xo.getChild(Parameter.class);
                if (parameter == null) {
                    throw new XMLParseException("No parameter provided to adaptableVarianceMultivariateNormalOperator");
                }
                for (int i = 0; i < xo.getChildCount(); ++i) {
                    Object child = xo.getChild(i);
                    if (!(child instanceof Transform.ParsedTransform)) {
                        continue;
                    }
                    Transform.ParsedTransform parsed = (Transform.ParsedTransform) child;
                    // The constrained sum is the transform's fixedSum, not the range length,
                    // matching the multi-parameter branch above.
                    addGroup(parsed.transform, parsed.end - parsed.start, parsed.fixedSum,
                            groupTransforms, groupSizes, groupSums);
                }
            }

            int groupCount = groupSizes.size();
            Transform[] transformations = new Transform[groupCount];
            int[] transformationSizes = new int[groupCount];
            double[] transformationSums = new double[groupCount];
            for (int g = 0; g < groupCount; ++g) {
                transformations[g] = groupTransforms.get(g);
                transformationSizes[g] = groupSizes.get(g);
                transformationSums[g] = groupSums.get(g);
                if (transformationSizes[g] == 0) {
                    throw new XMLParseException("Transformation size 0 encountered");
                }
            }

            int dimension = parameter.getDimension();
            if (initialValue <= 2 * dimension) {
                initialValue = 2 * dimension;
            }

            // Starting variance: diagonal with coefficient^2 / dimension on the diagonal.
            double diagonal = Math.pow(coefficient, 2.0) / (double) dimension;
            Parameter[] columns = new Parameter[dimension];
            for (int i = 0; i < dimension; ++i) {
                columns[i] = new Parameter.Default(dimension, 0.0);
                columns[i].setParameterValue(i, diagonal);
            }
            MatrixParameter varMatrix = new MatrixParameter(null, columns);
            if (!formXtXInverse && varMatrix.getColumnDimension() != varMatrix.getRowDimension()) {
                throw new XMLParseException("The variance matrix is not square");
            }
            if (varMatrix.getColumnDimension() != parameter.getDimension()) {
                throw new XMLParseException("The parameter and variance matrix have differing dimensions");
            }
            boolean skipRankCheck = xo.getAttribute(SKIP_RANK_CHECK, false);
            return new AdaptableVarianceMultivariateNormalOperator(parameter, transformations,
                    transformationSizes, transformationSums, scale, varMatrix, weight, betaValue,
                    initialValue, burnIn, updateEvery, adaptationMode, !formXtXInverse, skipRankCheck);
        }

        /**
         * Appends transformation groups for a block of {@code size} consecutive elements.
         * A log-constrained-sum transform becomes one multivariate group. Any other transform
         * becomes {@code size} univariate groups.
         */
        private void addGroup(Transform transform, int size, double sum,
                              List<Transform> transforms, List<Integer> sizes, List<Double> sums) {
            if (isLogConstrainedSum(transform)) {
                transforms.add(transform);
                sizes.add(size);
                sums.add(sum);
            } else {
                for (int i = 0; i < size; ++i) {
                    transforms.add(transform);
                    sizes.add(1);
                    sums.add(sum);
                }
            }
        }

        @Override
        public String getParserDescription() {
            return "This element returns an adaptable variance multivariate normal operator on a given parameter.";
        }

        @Override
        public Class getReturnType() {
            return MCMCOperator.class;
        }

        @Override
        public XMLSyntaxRule[] getSyntaxRules() {
            return this.rules;
        }
    };

    /**
     * Creates the operator.
     *
     * @param parameter           parameter being proposed on
     * @param transformations     one transform per group
     * @param transformationSizes number of consecutive elements covered by each group
     * @param transformationSums  constraint sum passed to multivariate inverse transforms
     * @param scaleFactor         initial scale applied to the standard-normal draws
     * @param inMatrix            variance matrix, or a design matrix X if {@code isVarianceMatrix} is false
     * @param weight              operator weight
     * @param beta                mixing weight of the fixed matrix in the adapted proposal
     * @param initial             iterations before the empirical covariance is used
     * @param burnIn              iterations discarded before accumulating statistics
     * @param every               adaptation interval in iterations
     * @param adaptationMode      adaptation mode for the scale factor
     * @param isVarianceMatrix    if false, {@code (X^T X)^-1} is formed from {@code inMatrix}
     * @param skipRankCheck       skip the SVD full-rank check on {@code inMatrix}
     * @throws RuntimeException if the matrix is rank-deficient or cannot be Cholesky-decomposed
     */
    public AdaptableVarianceMultivariateNormalOperator(Parameter parameter, Transform[] transformations, int[] transformationSizes, double[] transformationSums, double scaleFactor, double[][] inMatrix, double weight, double beta, int initial, int burnIn, int every, AdaptationMode adaptationMode, boolean isVarianceMatrix, boolean skipRankCheck) {
        super(adaptationMode);
        this.scaleFactor = scaleFactor;
        this.parameter = parameter;
        this.transformations = transformations;
        this.transformationSizes = transformationSizes;
        this.transformationSums = transformationSums;
        this.beta = beta;
        this.iterations = 0;
        this.updates = 0;
        this.setWeight(weight);
        this.dim = parameter.getDimension();
        this.initial = initial;
        this.burnin = burnIn;
        this.every = every;
        this.empirical = new double[this.dim][this.dim];
        this.oldMeans = new double[this.dim];
        this.newMeans = new double[this.dim];
        this.epsilon = new double[this.dim];
        this.proposal = new double[this.dim][this.dim];
        if (!skipRankCheck) {
            SingularValueDecomposition svd = new SingularValueDecomposition(new DenseDoubleMatrix2D(inMatrix));
            if (inMatrix[0].length != svd.rank()) {
                throw new RuntimeException("Variance matrix in AdaptableVarianceMultivariateNormalOperator is not of full rank");
            }
        }
        this.matrix = isVarianceMatrix ? inMatrix : this.formXtXInverse(inMatrix);
        try {
            this.cholesky = new CholeskyDecomposition(this.matrix).getL();
        } catch (IllegalDimension e) {
            throw new RuntimeException("Unable to decompose matrix in AdaptableVarianceMultivariateNormalOperator", e);
        }
    }

    /** Convenience constructor taking the variance (or design) matrix as a {@link MatrixParameter}. */
    public AdaptableVarianceMultivariateNormalOperator(Parameter parameter, Transform[] transformations, int[] transformationSizes, double[] transformationSums, double scaleFactor, MatrixParameter varMatrix, double weight, double beta, int initial, int burnIn, int every, AdaptationMode adaptationMode, boolean isVarianceMatrix, boolean skipRankCheck) {
        this(parameter, transformations, transformationSizes, transformationSums, scaleFactor, varMatrix.getParameterAsMatrix(), weight, beta, initial, burnIn, every, adaptationMode, isVarianceMatrix, skipRankCheck);
    }

    /** Returns true if {@code transform} is the multivariate log-constrained-sum transform. */
    private static boolean isLogConstrainedSum(Transform transform) {
        return transform.getTransformName().equals(Transform.LOG_CONSTRAINED_SUM.getTransformName());
    }

    /**
     * Computes {@code (X^T X)^-1} for an N x P design matrix {@code x}.
     * The accumulator is a {@code double}; an {@code int} would silently truncate each entry.
     */
    private double[][] formXtXInverse(double[][] x) {
        int rows = x.length;
        int cols = x[0].length;
        double[][] xtx = new double[cols][cols];
        for (int i = 0; i < cols; ++i) {
            for (int j = 0; j < cols; ++j) {
                double total = 0.0;
                for (int k = 0; k < rows; ++k) {
                    total += x[k][i] * x[k][j];
                }
                xtx[i][j] = total;
            }
        }
        return new SymmetricMatrix(xtx).inverse().toComponents();
    }

    /**
     * Recursive update of one entry of the unbiased (divisor n-1) empirical covariance after
     * adding the n-th sample. Uses {@code oldMeans} (mean of n-1 samples) and {@code newMeans}
     * (mean of n samples).
     */
    private double calculateCovariance(int n, double previous, double[] sample, int i, int j) {
        double value = previous * (double) (n - 2);
        value += sample[i] * sample[j];
        value += (double) (n - 1) * this.oldMeans[i] * this.oldMeans[j] - (double) n * this.newMeans[i] * this.newMeans[j];
        return value / (double) (n - 1);
    }

    /**
     * Performs one proposal: transforms the parameter, adds a correlated Gaussian step, and maps
     * the result back.
     *
     * @return log of the Jacobian ratio of the transforms (the Hastings correction)
     */
    @Override
    public double doOperation() {
        ++this.iterations;
        double[] currentValues = this.parameter.getParameterValues();
        double[] x = this.toTransformedSpace(currentValues);

        if (this.iterations > 1 && this.iterations > this.burnin) {
            if (this.iterations > this.burnin + 1) {
                if (this.iterations % this.every == 0) {
                    this.updateEmpiricalEstimates(x);
                }
            } else if (this.iterations == this.burnin + 1) {
                // First post-burn-in iteration: discard any accumulated statistics.
                this.resetAccumulators();
            }
        } else if (this.iterations == 1) {
            this.resetAccumulators();
            for (int i = 0; i < this.dim; ++i) {
                System.arraycopy(this.matrix[i], 0, this.proposal[i], 0, this.dim);
            }
        }

        for (int i = 0; i < this.dim; ++i) {
            this.epsilon[i] = this.scaleFactor * MathUtils.nextGaussian();
        }
        if (this.iterations > this.initial && this.iterations % this.every == 0) {
            this.refreshProposalCholesky();
        }
        // Correlated step; the decomposition is lower triangular, so only cholesky[j][i] with j >= i is used.
        for (int i = 0; i < this.dim; ++i) {
            for (int j = i; j < this.dim; ++j) {
                x[i] += this.cholesky[j][i] * this.epsilon[j];
            }
        }

        double logJacobian = this.applyInverseTransforms(currentValues, x);
        this.parameter.fireParameterChangedEvent();
        if (this.iterations % this.every == 0) {
            // After this swap, oldMeans holds the current running mean for the next update.
            double[] tmp = this.oldMeans;
            this.oldMeans = this.newMeans;
            this.newMeans = tmp;
        }
        return logJacobian;
    }

    /** Maps the raw parameter values into the transformed (unconstrained) space, group by group. */
    private double[] toTransformedSpace(double[] values) {
        double[] transformed = new double[this.dim];
        int offset = 0;
        for (int g = 0; g < this.transformationSizes.length; ++g) {
            int size = this.transformationSizes[g];
            if (size > 1) {
                System.arraycopy(this.transformations[g].transform(values, offset, offset + size - 1), 0, transformed, offset, size);
            } else {
                transformed[offset] = this.transformations[g].transform(values[offset]);
            }
            offset += size;
        }
        return transformed;
    }

    /**
     * Writes the inverse-transformed proposal into the parameter, without firing events.
     *
     * @return sum over groups of logJacobian(original) - logJacobian(proposed)
     */
    private double applyInverseTransforms(double[] originalValues, double[] proposed) {
        double logJacobian = 0.0;
        int offset = 0;
        for (int g = 0; g < this.transformationSizes.length; ++g) {
            int size = this.transformationSizes[g];
            Transform transform = this.transformations[g];
            if (size > 1) {
                double[] values = transform.inverse(proposed, offset, offset + size - 1, this.transformationSums[g]);
                for (int k = 0; k < values.length; ++k) {
                    this.parameter.setParameterValueQuietly(offset + k, values[k]);
                }
                logJacobian += transform.getLogJacobian(originalValues, offset, offset + size - 1) - transform.getLogJacobian(values, 0, size - 1);
            } else {
                this.parameter.setParameterValueQuietly(offset, transform.inverse(proposed[offset]));
                logJacobian += transform.getLogJacobian(originalValues[offset]) - transform.getLogJacobian(this.parameter.getParameterValue(offset));
            }
            offset += size;
        }
        return logJacobian;
    }

    /** Adds one transformed sample to the running mean and (from the second sample on) the empirical covariance. */
    private void updateEmpiricalEstimates(double[] x) {
        ++this.updates;
        for (int i = 0; i < this.dim; ++i) {
            this.newMeans[i] = (this.oldMeans[i] * (double) (this.updates - 1) + x[i]) / (double) this.updates;
        }
        if (this.updates > 1) {
            for (int i = 0; i < this.dim; ++i) {
                for (int j = i; j < this.dim; ++j) {
                    this.empirical[i][j] = this.calculateCovariance(this.updates, this.empirical[i][j], x, i, j);
                    this.empirical[j][i] = this.empirical[i][j];
                }
            }
        }
    }

    /** Zeroes both mean buffers and the empirical covariance. */
    private void resetAccumulators() {
        for (int i = 0; i < this.dim; ++i) {
            this.oldMeans[i] = 0.0;
            this.newMeans[i] = 0.0;
        }
        for (int i = 0; i < this.dim; ++i) {
            for (int j = 0; j < this.dim; ++j) {
                this.empirical[i][j] = 0.0;
            }
        }
    }

    /** Rebuilds the proposal as (1 - beta) * empirical + beta * matrix and recomputes its Cholesky factor. */
    private void refreshProposalCholesky() {
        for (int i = 0; i < this.dim; ++i) {
            for (int j = i; j < this.dim; ++j) {
                double value = (1.0 - this.beta) * this.empirical[i][j] + this.beta * this.matrix[i][j];
                this.proposal[i][j] = value;
                this.proposal[j][i] = value;
            }
        }
        try {
            this.cholesky = new CholeskyDecomposition(this.proposal).getL();
        } catch (IllegalDimension e) {
            throw new RuntimeException("Unable to decompose matrix in AdaptableVarianceMultivariateNormalOperator", e);
        }
    }

    @Override
    public String toString() {
        return "adaptableVarianceMultivariateNormalOperator(" + this.parameter.getParameterName() + ")";
    }

    public Parameter getParameter() {
        return this.parameter;
    }

    /**
     * Seeds the empirical mean and covariance from existing samples (one list per dimension,
     * truncated to the shortest list). It also sets {@code beta} to 0, so later proposals use
     * only the empirical covariance.
     *
     * <p>The covariance uses divisor N, whereas the online update uses N-1.
     * NOTE(review): confirm which convention is intended.
     *
     * @throws RuntimeException if the number of sample lists differs from the parameter
     *     dimension, or if the transforms are not one element-wise transform per dimension
     */
    public void provideSamples(ArrayList<ArrayList<Double>> parameterSamples) {
        if (this.parameter.getDimension() != parameterSamples.size()) {
            throw new RuntimeException("Dimension mismatch in AVMVN Operator: inconsistent parameter dimensions");
        }
        if (this.transformations.length != this.dim) {
            throw new RuntimeException("provideSamples in AVMVN Operator requires one element-wise transformation per dimension");
        }
        int numberOfEntries = parameterSamples.isEmpty() ? 0 : parameterSamples.get(0).size();
        for (ArrayList<Double> samples : parameterSamples) {
            numberOfEntries = Math.min(numberOfEntries, samples.size());
        }
        this.iterations = numberOfEntries;
        this.updates = numberOfEntries;
        this.beta = 0.0;

        // Transform each sample once instead of O(dim) times inside the covariance loop.
        double[][] transformed = new double[this.dim][numberOfEntries];
        for (int i = 0; i < this.dim; ++i) {
            for (int k = 0; k < numberOfEntries; ++k) {
                transformed[i][k] = this.transformations[i].transform(parameterSamples.get(i).get(k));
            }
        }
        for (int i = 0; i < this.dim; ++i) {
            double sum = 0.0;
            for (int k = 0; k < numberOfEntries; ++k) {
                sum += transformed[i][k];
            }
            this.newMeans[i] = sum / (double) numberOfEntries;
        }
        for (int i = 0; i < this.dim; ++i) {
            for (int j = i; j < this.dim; ++j) {
                double cross = 0.0;
                for (int k = 0; k < numberOfEntries; ++k) {
                    cross += transformed[i][k] * transformed[j][k];
                }
                cross /= (double) numberOfEntries;
                cross -= this.newMeans[i] * this.newMeans[j];
                this.empirical[i][j] = cross;
                this.empirical[j][i] = cross;
            }
        }
        // doOperation reads the running mean from oldMeans on its next update.
        System.arraycopy(this.newMeans, 0, this.oldMeans, 0, this.dim);
    }

    @Override
    public final String getOperatorName() {
        return "adaptableVarianceMultivariateNormal(" + this.parameter.getParameterName() + ")";
    }

    /** The adapted quantity is log(scaleFactor). */
    @Override
    protected double getAdaptableParameterValue() {
        return Math.log(this.scaleFactor);
    }

    @Override
    public void setAdaptableParameterValue(double value) {
        this.scaleFactor = Math.exp(value);
    }

    @Override
    public double getRawParameter() {
        return this.scaleFactor;
    }

    public double getScaleFactor() {
        return this.scaleFactor;
    }

    @Override
    public String getAdaptableParameterName() {
        return SCALE_FACTOR;
    }

    @Override
    public Citation.Category getCategory() {
        return Citation.Category.FRAMEWORK;
    }

    @Override
    public String getDescription() {
        return "Adaptive MCMC estimation method of continuous parameters";
    }

    @Override
    public List<Citation> getCitations() {
        return Collections.singletonList(new Citation(new Author[]{new Author("G", "Baele"), new Author("P", "Lemey"), new Author("A", "Rambaut"), new Author("MA", "Suchard")}, "Adaptive MCMC in Bayesian phylogenetics: an application to analyzing partitioned data in BEAST", 2017, "Bioinformatics", 33, 1798, 1805, Citation.Status.PUBLISHED));
    }
}

