/*
 * Decompiled with CFR 0.152.
 */
package dr.evomodel.treedatalikelihood.continuous;

import dr.evolution.tree.NodeRef;
import dr.evolution.tree.Tree;
import dr.evolution.tree.TreeTrait;
import dr.evomodel.branchratemodel.ArbitraryBranchRates;
import dr.evomodel.branchratemodel.BranchRateModel;
import dr.evomodel.treedatalikelihood.TreeDataLikelihood;
import dr.evomodel.treedatalikelihood.continuous.ContinuousDataLikelihoodDelegate;
import dr.evomodel.treedatalikelihood.preorder.BranchConditionalDistributionDelegate;
import dr.evomodel.treedatalikelihood.preorder.BranchSufficientStatistics;
import dr.evomodel.treedatalikelihood.preorder.ConditionalPrecisionAndTransform2;
import dr.evomodel.treedatalikelihood.preorder.MatrixSufficientStatistics;
import dr.evomodel.treedatalikelihood.preorder.NormalSufficientStatistics;
import dr.inference.hmc.GradientWrtParameterProvider;
import dr.inference.hmc.HessianWrtParameterProvider;
import dr.inference.loggers.LogColumn;
import dr.inference.loggers.Loggable;
import dr.inference.model.Likelihood;
import dr.inference.model.Parameter;
import dr.inference.operators.hmc.NumericalHessianFromGradient;
import dr.math.MultivariateFunction;
import dr.math.NumericalDerivative;
import dr.math.matrixAlgebra.Vector;
import dr.math.matrixAlgebra.WrappedVector;
import dr.math.matrixAlgebra.missingData.MissingOps;
import dr.math.matrixAlgebra.missingData.PermutationIndices;
import dr.xml.Reportable;
import java.util.List;
import org.ejml.data.D1Matrix64F;
import org.ejml.data.DenseMatrix64F;
import org.ejml.ops.CommonOps;

/**
 * Analytic gradient and diagonal Hessian of a continuous-trait {@link TreeDataLikelihood}
 * with respect to a branch-rate {@link Parameter}.
 *
 * <p>Per-branch derivatives are assembled from the {@link BranchSufficientStatistics} exposed by
 * the branch-conditional-density tree trait. Analytic derivatives require the likelihood's
 * branch-rate model to be an {@link ArbitraryBranchRates}. Only a single trait is supported, and
 * the full (off-diagonal) Hessian is not implemented. {@link #getReport()} compares the analytic
 * results against numerical approximations.
 */
public class BranchRateGradient
        implements GradientWrtParameterProvider, HessianWrtParameterProvider, Reportable, Loggable {

    private final TreeDataLikelihood treeDataLikelihood;
    private final TreeTrait<List<BranchSufficientStatistics>> treeTraitProvider;
    private final Tree tree;
    private final int nTraits;
    private final int traitDim;
    private final Parameter rateParameter;
    /** The likelihood's branch-rate model, or {@code null} if it is not an ArbitraryBranchRates. */
    private final ArbitraryBranchRates branchRateModel;
    private final ContinuousTraitGradientForBranch.Default branchProvider;

    /**
     * The log-likelihood as a function of the whole rate vector. It is used only for the
     * numerical checks in {@link #getReport()}. Each evaluation overwrites {@link #rateParameter};
     * callers must restore the original values afterwards.
     */
    private final MultivariateFunction numeric1 = new MultivariateFunction() {

        @Override
        public double evaluate(double[] argument) {
            setRateParameterValues(argument);
            treeDataLikelihood.makeDirty();
            return treeDataLikelihood.getLogLikelihood();
        }

        @Override
        public int getNumArguments() {
            return rateParameter.getDimension();
        }

        @Override
        public double getLowerBound(int n) {
            return 0.0;
        }

        @Override
        public double getUpperBound(int n) {
            return Double.POSITIVE_INFINITY;
        }
    };

    /**
     * @param traitName          name of the continuous trait; used to locate (or register) the
     *                           branch-conditional-density tree trait
     * @param treeDataLikelihood likelihood whose derivatives are computed; must not be null
     * @param likelihoodDelegate delegate used to register the branch-conditional-density trait
     *                           when it does not exist yet
     * @param rateParameter      the branch-rate parameter the derivatives are taken with respect to
     * @throws RuntimeException if the likelihood carries more than one trait
     */
    public BranchRateGradient(String traitName, TreeDataLikelihood treeDataLikelihood,
                              ContinuousDataLikelihoodDelegate likelihoodDelegate,
                              Parameter rateParameter) {
        assert (treeDataLikelihood != null);
        this.treeDataLikelihood = treeDataLikelihood;
        this.tree = treeDataLikelihood.getTree();
        this.rateParameter = rateParameter;

        BranchRateModel rateModel = treeDataLikelihood.getBranchRateModel();
        this.branchRateModel = rateModel instanceof ArbitraryBranchRates
                ? (ArbitraryBranchRates) rateModel
                : null;

        // Register the branch-conditional-density trait on demand so the branch statistics exist.
        String bcdName = BranchConditionalDistributionDelegate.getName(traitName);
        if (treeDataLikelihood.getTreeTrait(bcdName) == null) {
            likelihoodDelegate.addBranchConditionalDensityTrait(traitName);
        }
        @SuppressWarnings({"unchecked", "rawtypes"})
        TreeTrait<List<BranchSufficientStatistics>> trait =
                (TreeTrait) treeDataLikelihood.getTreeTrait(bcdName);
        this.treeTraitProvider = trait;
        assert (this.treeTraitProvider != null);

        this.nTraits = treeDataLikelihood.getDataLikelihoodDelegate().getTraitCount();
        if (this.nTraits != 1) {
            throw new RuntimeException("Not yet implemented for >1 traits");
        }
        this.traitDim = treeDataLikelihood.getDataLikelihoodDelegate().getTraitDim();
        this.branchProvider = new ContinuousTraitGradientForBranch.Default(this.traitDim);
    }

    @Override
    public Likelihood getLikelihood() {
        return treeDataLikelihood;
    }

    @Override
    public Parameter getParameter() {
        return rateParameter;
    }

    @Override
    public int getDimension() {
        return getParameter().getDimension();
    }

    /**
     * Returns d(log-likelihood)/d(rate) for every non-root branch, indexed by rate-parameter index.
     *
     * @throws IllegalStateException if the branch-rate model is not an ArbitraryBranchRates
     */
    @Override
    public double[] getGradientLogDensity() {
        ArbitraryBranchRates rates = requireArbitraryBranchRates();
        treeDataLikelihood.makeDirty();

        double[] gradient = new double[rateParameter.getDimension()];
        for (int i = 0; i < tree.getNodeCount(); ++i) {
            NodeRef node = tree.getNode(i);
            if (tree.isRoot(node)) {
                continue;
            }
            List<BranchSufficientStatistics> statistics = treeTraitProvider.getTrait(tree, node);
            assert (statistics.size() == nTraits);

            // The branch statistics are rescaled by rate'/rate. This presumably relies on the
            // branch variance being proportional to the rate -- confirm against the rate model.
            double rate = rates.getBranchRate(tree, node);
            double scale = rates.getBranchRateDifferential(tree, node) / rate;

            double sum = 0.0;
            for (int trait = 0; trait < nTraits; ++trait) {
                sum += branchProvider.getGradientForBranch(statistics.get(trait), scale);
            }

            int index = getParameterIndexFromNode(node);
            assert (index != -1);
            gradient[index] = sum;
        }
        return gradient;
    }

    /** Maps a node to its rate-parameter index (falls back to the node number without ArbitraryBranchRates). */
    private int getParameterIndexFromNode(NodeRef node) {
        return branchRateModel == null
                ? node.getNumber()
                : branchRateModel.getParameterIndexFromNode(node);
    }

    /**
     * Returns the diagonal of d²(log-likelihood)/d(rate)² for every non-root branch, indexed by
     * rate-parameter index.
     *
     * @throws IllegalStateException if the branch-rate model is not an ArbitraryBranchRates
     */
    @Override
    public double[] getDiagonalHessianLogDensity() {
        ArbitraryBranchRates rates = requireArbitraryBranchRates();
        // Same as getGradientLogDensity: force fresh branch statistics.
        treeDataLikelihood.makeDirty();

        double[] hessian = new double[rateParameter.getDimension()];
        for (int i = 0; i < tree.getNodeCount(); ++i) {
            NodeRef node = tree.getNode(i);
            if (tree.isRoot(node)) {
                continue;
            }
            List<BranchSufficientStatistics> statistics = treeTraitProvider.getTrait(tree, node);
            assert (statistics.size() == nTraits);

            double rate = rates.getBranchRate(tree, node);
            double firstScale = rates.getBranchRateDifferential(tree, node) / rate;
            double secondScale = rates.getBranchRateSecondDifferential(tree, node) / rate;

            double sum = 0.0;
            for (int trait = 0; trait < nTraits; ++trait) {
                sum += getDiagonalHessianForBranch(statistics.get(trait), firstScale, secondScale);
            }

            int index = getParameterIndexFromNode(node);
            assert (index != -1);
            hessian[index] = sum;
        }
        return hessian;
    }

    /**
     * Second derivative of one branch's log-likelihood contribution with respect to its rate.
     *
     * <p>Notation: P = above precision, V = joint variance, delta = joint mean - above mean,
     * dS = firstScale * branch variance, d2S = secondScale * branch variance.
     *
     * <p>NOTE(review): unlike {@link ContinuousTraitGradientForBranch.Default#getGradientForBranch},
     * no drift (branch mean) terms are included here.
     */
    private double getDiagonalHessianForBranch(BranchSufficientStatistics statistics,
                                               double firstScale, double secondScale) {
        final int n = traitDim;
        DenseMatrix64F work0 = new DenseMatrix64F(n, n);
        DenseMatrix64F work1 = new DenseMatrix64F(n, n);
        DenseMatrix64F work2 = new DenseMatrix64F(n, n);
        DenseMatrix64F delta = new DenseMatrix64F(n, 1);

        NormalSufficientStatistics below = statistics.getBelow();
        MatrixSufficientStatistics branch = statistics.getBranch();
        NormalSufficientStatistics above = statistics.getAbove();

        NormalSufficientStatistics joint =
                ContinuousTraitGradientForBranch.Default.computeJointStatistics(below, above, n);
        branchProvider.makeDeltaVector(delta, joint, above);
        DenseMatrix64F precision = above.getRawPrecision();
        DenseMatrix64F jointVariance = joint.getRawVariance();

        // work1 = P dS P, work0 = P dS
        branchProvider.makeGradientMatrices0(work1, work0, statistics, firstScale);

        // Term 1: 0.5 tr((P dS)^2)
        CommonOps.mult(work0, work0, work2);
        double term1 = 0.0;
        for (int i = 0; i < n; ++i) {
            term1 += 0.5 * work2.unsafe_get(i, i);
        }

        // Term 2: -delta' (P dS)^2 P delta - tr(V (P dS)^2 P)
        CommonOps.mult(work2, precision, work0);        // work0 = (P dS)^2 P
        CommonOps.mult(jointVariance, work0, work2);    // work2 = V (P dS)^2 P
        double term2 = 0.0;
        for (int i = 0; i < n; ++i) {
            for (int j = 0; j < n; ++j) {
                term2 -= delta.unsafe_get(i, 0) * work0.unsafe_get(i, j) * delta.unsafe_get(j, 0);
            }
            term2 -= work2.unsafe_get(i, i);
        }

        // Term 3: delta' (P dS P V P dS P) delta + 0.5 tr(P dS P V P dS P V)
        branchProvider.makeGradientMatrices1(work2, work1, joint);  // work2 = P dS P V
        CommonOps.mult(work2, work1, work0);            // work0 = P dS P V P dS P
        CommonOps.mult(work0, jointVariance, work2);    // work2 = P dS P V P dS P V
        double term3 = 0.0;
        for (int i = 0; i < n; ++i) {
            for (int j = 0; j < n; ++j) {
                term3 += delta.unsafe_get(i, 0) * work0.unsafe_get(i, j) * delta.unsafe_get(j, 0);
            }
            term3 += 0.5 * work2.unsafe_get(i, i);
        }

        // Term 4: the first-derivative formula applied to d2S, i.e.
        // 0.5 delta' P d2S P delta - 0.5 tr(P d2S) + 0.5 tr(V P d2S P).
        // The quadratic weight is 0.5, matching the gradient (previously it was 1, a factor-2 error).
        CommonOps.scale(secondScale, branch.getRawVariance(), work0);  // work0 = d2S
        CommonOps.mult(precision, work0, work2);        // work2 = P d2S
        CommonOps.mult(work2, precision, work0);        // work0 = P d2S P
        CommonOps.mult(jointVariance, work0, work1);    // work1 = V P d2S P
        double term4 = 0.0;
        for (int i = 0; i < n; ++i) {
            for (int j = 0; j < n; ++j) {
                term4 += 0.5 * delta.unsafe_get(i, 0) * work0.unsafe_get(i, j) * delta.unsafe_get(j, 0);
            }
            term4 -= 0.5 * work2.unsafe_get(i, i);
            term4 += 0.5 * work1.unsafe_get(i, i);
        }

        return term1 + term2 + term3 + term4;
    }

    /** The full Hessian is not supported; use {@link #getDiagonalHessianLogDensity()}. */
    @Override
    public double[][] getHessianLogDensity() {
        throw new UnsupportedOperationException("Not yet implemented!");
    }

    /**
     * Finite-difference gradient of the log-likelihood. The rate parameter is restored to its
     * original values before returning.
     */
    public double[] getNumericalGradient() {
        double[] savedValues = rateParameter.getParameterValues();
        double[] gradient = NumericalDerivative.gradient(numeric1, rateParameter.getParameterValues());
        setRateParameterValues(savedValues);
        return gradient;
    }

    @Override
    public String getReport() {
        double[] savedValues = rateParameter.getParameterValues();
        double[] numericGradient = getNumericalGradient();
        NumericalHessianFromGradient hessianFromGradient = new NumericalHessianFromGradient(this);

        StringBuilder sb = new StringBuilder();
        sb.append("Peeling: ").append(new Vector(getGradientLogDensity()));
        sb.append("\n");
        sb.append("numeric: ").append(new Vector(numericGradient));
        sb.append("\n");
        sb.append("Peeling diagonal hessian: ").append(new Vector(getDiagonalHessianLogDensity()));
        sb.append("\n");

        double[] numericDiagonalHessian =
                NumericalDerivative.diagonalHessian(numeric1, getParameter().getParameterValues());
        // numeric1 overwrites the rate parameter; restore it before evaluating anything else.
        setRateParameterValues(savedValues);
        sb.append("numeric diagonal hessian: ").append(new Vector(numericDiagonalHessian));
        sb.append("\n");
        sb.append("Another numeric diagonal hessian: ")
                .append(new Vector(hessianFromGradient.getDiagonalHessianLogDensity()));
        sb.append("\n");
        // Leave the model exactly as we found it.
        setRateParameterValues(savedValues);
        return sb.toString();
    }

    @Override
    public LogColumn[] getColumns() {
        return new LogColumn[]{new LogColumn.Default("gradient report", new Object() {
            @Override
            public String toString() {
                return "\n" + getReport();
            }
        })};
    }

    /** Writes {@code values} element-wise into the rate parameter. */
    private void setRateParameterValues(double[] values) {
        for (int i = 0; i < values.length; ++i) {
            rateParameter.setParameterValue(i, values[i]);
        }
    }

    /** Returns the ArbitraryBranchRates model, failing clearly if the likelihood uses another model. */
    private ArbitraryBranchRates requireArbitraryBranchRates() {
        if (branchRateModel == null) {
            throw new IllegalStateException(
                    "Analytic branch-rate derivatives require an ArbitraryBranchRates branch-rate model");
        }
        return branchRateModel;
    }

    /** Computes one branch's contribution to the branch-rate gradient. */
    static interface ContinuousTraitGradientForBranch {

        /**
         * @param statistics          below/branch/above sufficient statistics for the branch
         * @param differentialScaling factor applied to the branch variance and mean (rate'/rate)
         * @return the branch's contribution to d(log-likelihood)/d(rate)
         */
        public double getGradientForBranch(BranchSufficientStatistics statistics, double differentialScaling);

        /**
         * Default implementation for a single trait of dimension {@code dim}. It reuses internal
         * scratch matrices, so an instance must not be shared between concurrent callers.
         */
        public static class Default implements ContinuousTraitGradientForBranch {
            private final DenseMatrix64F matrix0;
            private final DenseMatrix64F matrix1;
            private final DenseMatrix64F vector0;
            private final DenseMatrix64F vector1;
            private final int dim;

            public Default(int dim) {
                this.dim = dim;
                this.matrix0 = new DenseMatrix64F(dim, dim);
                this.matrix1 = new DenseMatrix64F(dim, dim);
                this.vector0 = new DenseMatrix64F(dim, 1);
                this.vector1 = new DenseMatrix64F(dim, 1);
            }

            /**
             * Returns the sum of three terms, with P = above precision, V = joint variance,
             * delta = joint mean - above mean and dS = differentialScaling * branch variance:
             * -0.5 tr(P dS) + 0.5 delta' P dS P delta + 0.5 tr(P dS P V)
             * + delta' P (differentialScaling * branch mean).
             */
            @Override
            public double getGradientForBranch(BranchSufficientStatistics statistics, double differentialScaling) {
                NormalSufficientStatistics below = statistics.getBelow();
                MatrixSufficientStatistics branch = statistics.getBranch();
                NormalSufficientStatistics above = statistics.getAbove();
                DenseMatrix64F precision = above.getRawPrecision();

                // matrix1 = P dS P, matrix0 = P dS
                makeGradientMatrices0(matrix1, matrix0, statistics, differentialScaling);
                double traceTerm = 0.0;
                for (int i = 0; i < dim; ++i) {
                    traceTerm -= 0.5 * matrix0.unsafe_get(i, i);
                }

                NormalSufficientStatistics joint = computeJointStatistics(below, above, dim);
                makeGradientMatrices1(matrix0, matrix1, joint);  // matrix0 = P dS P V
                makeDeltaVector(vector0, joint, above);

                double quadraticTerm = 0.0;
                for (int i = 0; i < dim; ++i) {
                    for (int j = 0; j < dim; ++j) {
                        quadraticTerm += 0.5 * vector0.unsafe_get(i, 0) * matrix1.unsafe_get(i, j) * vector0.unsafe_get(j, 0);
                    }
                    quadraticTerm += 0.5 * matrix0.unsafe_get(i, i);
                }

                // vector1 = scaled branch mean (reused scratch instead of a per-call allocation)
                CommonOps.scale(differentialScaling, branch.getRawMean(), vector1);
                double driftTerm = 0.0;
                for (int i = 0; i < dim; ++i) {
                    for (int j = 0; j < dim; ++j) {
                        driftTerm += vector0.unsafe_get(i, 0) * precision.unsafe_get(i, j) * vector1.unsafe_get(j, 0);
                    }
                }
                return traceTerm + quadraticTerm + driftTerm;
            }

            /**
             * Combines the below and above statistics of a branch into joint statistics. The case
             * is chosen from the diagonal of the below precision (all infinite, all zero, none zero
             * or none infinite, otherwise mixed).
             */
            public static NormalSufficientStatistics computeJointStatistics(NormalSufficientStatistics below,
                                                                            NormalSufficientStatistics above,
                                                                            int dim) {
                PermutationIndices indices = new PermutationIndices(below.getRawPrecision());
                if (indices.getNumberOfInfiniteDiagonals() == dim) {
                    return computeJointFullyObserved(below, dim);
                }
                if (indices.getNumberOfZeroDiagonals() == dim) {
                    return computeJointFullyMissing(above, dim);
                }
                if (indices.getNumberOfZeroDiagonals() == 0 || indices.getNumberOfInfiniteDiagonals() == 0) {
                    return computeJointLatent(below, above, dim);
                }
                return computeJointPartiallyMissing(below, above, indices, dim);
            }

            /** Joint = below mean and precision with a zero variance. */
            private static NormalSufficientStatistics computeJointFullyObserved(NormalSufficientStatistics below, int dim) {
                return new NormalSufficientStatistics(below.getRawMean(), below.getRawPrecision(), new DenseMatrix64F(dim, dim));
            }

            /** Joint = the above statistics unchanged. */
            private static NormalSufficientStatistics computeJointFullyMissing(NormalSufficientStatistics above, int dim) {
                return new NormalSufficientStatistics(above.getRawMean(), above.getRawPrecision(), above.getRawVariance());
            }

            /** Joint precision = below + above precision; joint mean = precision-weighted average of both means. */
            private static NormalSufficientStatistics computeJointLatent(NormalSufficientStatistics below,
                                                                         NormalSufficientStatistics above,
                                                                         int dim) {
                DenseMatrix64F mean = new DenseMatrix64F(dim, 1);
                DenseMatrix64F precision = new DenseMatrix64F(dim, dim);
                DenseMatrix64F variance = new DenseMatrix64F(dim, dim);
                CommonOps.add((D1Matrix64F) below.getRawPrecision(), above.getRawPrecision(), (D1Matrix64F) precision);
                MissingOps.safeInvert2(precision, variance, false);
                MissingOps.safeWeightedAverage(
                        new WrappedVector.Raw(below.getRawMean().getData(), 0, dim), below.getRawPrecision(),
                        new WrappedVector.Raw(above.getRawMean().getData(), 0, dim), above.getRawPrecision(),
                        new WrappedVector.Raw(mean.getData(), 0, dim), variance, dim);
                return new NormalSufficientStatistics(mean, precision, variance);
            }

            /**
             * Mixed observed/missing case. Missing (zero-precision) coordinates are conditioned on
             * the observed (infinite-precision) ones using the above precision. Observed coordinates
             * keep the below mean and get infinite precision.
             *
             * @throws RuntimeException if some below precision diagonals are finite and non-zero
             */
            private static NormalSufficientStatistics computeJointPartiallyMissing(NormalSufficientStatistics below,
                                                                                   NormalSufficientStatistics above,
                                                                                   PermutationIndices indices,
                                                                                   int dim) {
                DenseMatrix64F mean = new DenseMatrix64F(dim, 1);
                DenseMatrix64F precision = new DenseMatrix64F(dim, dim);
                DenseMatrix64F variance = new DenseMatrix64F(dim, dim);
                if (indices.getNumberOfNonZeroFiniteDiagonals() != 0) {
                    throw new RuntimeException("Unsure if this works for latent trait below");
                }
                ConditionalPrecisionAndTransform2 transform = new ConditionalPrecisionAndTransform2(
                        above.getRawPrecision(), indices.getZeroIndices(), indices.getInfiniteIndices());
                double[] conditionalMean = transform.getConditionalMean(
                        below.getRawMean().getData(), 0, above.getRawMean().getData(), 0);
                MissingOps.copyRowsAndColumns(above.getRawPrecision(), precision,
                        indices.getZeroIndices(), indices.getZeroIndices(), false);
                MissingOps.scatterRowsAndColumns(transform.getConditionalVariance(), variance,
                        indices.getZeroIndices(), indices.getZeroIndices(), false);

                int position = 0;
                for (int missing : indices.getZeroIndices()) {
                    mean.unsafe_set(missing, 0, conditionalMean[position++]);
                }
                for (int observed : indices.getInfiniteIndices()) {
                    mean.unsafe_set(observed, 0, below.getMean(observed));
                    precision.unsafe_set(observed, observed, Double.POSITIVE_INFINITY);
                }
                return new NormalSufficientStatistics(mean, precision, variance);
            }

            /**
             * With P = above precision and dS = scaling * branch variance, writes
             * {@code pSigmaP = P dS P} and {@code pSigma = P dS}.
             */
            public void makeGradientMatrices0(DenseMatrix64F pSigmaP, DenseMatrix64F pSigma,
                                              BranchSufficientStatistics statistics, double scaling) {
                DenseMatrix64F precision = statistics.getAbove().getRawPrecision();
                CommonOps.scale(scaling, statistics.getBranch().getRawVariance(), pSigmaP);
                CommonOps.mult(precision, pSigmaP, pSigma);
                CommonOps.mult(pSigma, precision, pSigmaP);
            }

            /** Writes {@code result = left * joint.getRawVariance()}. */
            public void makeGradientMatrices1(DenseMatrix64F result, DenseMatrix64F left,
                                              NormalSufficientStatistics joint) {
                CommonOps.mult(left, joint.getRawVariance(), result);
            }

            /** Writes {@code delta[i] = joint mean[i] - above mean[i]} into column 0 of {@code delta}. */
            public void makeDeltaVector(DenseMatrix64F delta, NormalSufficientStatistics joint,
                                        NormalSufficientStatistics above) {
                for (int i = 0; i < dim; ++i) {
                    delta.unsafe_set(i, 0, joint.getRawMean().unsafe_get(i, 0) - above.getMean(i));
                }
            }
        }
    }
}

