/*
 * Decompiled with CFR 0.152.
 */
package dr.evomodel.treedatalikelihood.continuous;

import dr.evolution.tree.BranchRates;
import dr.evolution.tree.NodeRef;
import dr.evolution.tree.Tree;
import dr.evolution.tree.TreeUtils;
import dr.evomodel.treedatalikelihood.continuous.DiffusionProcessDelegate;
import dr.evomodel.treedatalikelihood.continuous.cdi.ContinuousDiffusionIntegrator;
import dr.math.KroneckerOperation;
import dr.math.matrixAlgebra.IllegalDimension;
import dr.math.matrixAlgebra.Matrix;
import java.util.HashSet;

/**
 * Brute-force debugging helpers for continuous-trait tree likelihoods: dense tree/graph variance
 * and precision matrices, accumulated drifts, tip depths and variance sums, all computed by
 * direct traversal of a {@link Tree}.
 *
 * <p>Every recursive traversal here visits only children 0 and 1 of an internal node, so the
 * tree is assumed to be bifurcating.
 */
public class MultivariateTraitDebugUtilities {

    /**
     * Returns the rate-scaled path length from {@code node} up to the root.
     *
     * <p>Each branch on the path contributes {@code rate * branchLength}, where the rate is taken
     * from {@code branchRates} (or 1.0 when {@code branchRates} is {@code null}). The root itself
     * contributes nothing, so the root's result is 0.
     *
     * @param tree        the tree
     * @param branchRates per-branch rates, or {@code null} for unit rates
     * @param node        the starting node
     * @return the summed rate-scaled branch lengths between {@code node} and the root
     */
    public static double getLengthToRoot(Tree tree, BranchRates branchRates, NodeRef node) {
        if (tree.isRoot(node)) {
            return 0.0;
        }
        final NodeRef parent = tree.getParent(node);
        final double rate = (branchRates == null) ? 1.0 : branchRates.getBranchRate(tree, node);
        return rate * tree.getBranchLength(node) + getLengthToRoot(tree, branchRates, parent);
    }

    /**
     * Returns the most recent common ancestor of two taxa, looked up by taxon id.
     *
     * @param taxonIndexA taxon index passed to {@link Tree#getTaxonId(int)}
     * @param taxonIndexB taxon index passed to {@link Tree#getTaxonId(int)}
     */
    private static NodeRef findMRCA(Tree tree, int taxonIndexA, int taxonIndexB) {
        HashSet<String> taxonIds = new HashSet<>();
        taxonIds.add(tree.getTaxonId(taxonIndexA));
        taxonIds.add(tree.getTaxonId(taxonIndexB));
        return TreeUtils.getCommonAncestorNode(tree, taxonIds);
    }

    /**
     * Writes the symmetric parent/child precision entry
     * {@code -1 / (branchLength(child) * scale)} into {@code precision}, then recurses into the
     * subtree below {@code child}.
     *
     * @param tree      the tree
     * @param parent    parent node (row/column index is its node number)
     * @param child     child node (row/column index is its node number)
     * @param precision node-by-node matrix updated in place
     * @param scale     multiplier applied to each branch length before inversion
     */
    public static void insertPrecision(Tree tree, NodeRef parent, NodeRef child,
                                       double[][] precision, double scale) {
        final double entry = -1.0 / (tree.getBranchLength(child) * scale);
        final int parentIndex = parent.getNumber();
        final int childIndex = child.getNumber();
        precision[childIndex][parentIndex] = entry;
        precision[parentIndex][childIndex] = entry;
        recurseGraph(tree, child, precision, scale);
    }

    /**
     * Fills the off-diagonal parent/child precision entries for every branch in the subtree
     * below {@code node}. External nodes are a no-op. Diagonal entries are not written here.
     */
    public static void recurseGraph(Tree tree, NodeRef node, double[][] precision, double scale) {
        if (tree.isExternal(node)) {
            return;
        }
        insertPrecision(tree, node, tree.getChild(node, 0), precision, scale);
        insertPrecision(tree, node, tree.getChild(node, 1), precision, scale);
    }

    /**
     * Builds the node-by-node (all nodes, internal and external) variance matrix.
     *
     * <p>Entry (i, j) is {@code scale} times the root-path length of the common ancestor of nodes
     * i and j, as found by {@link TreeUtils#getCommonAncestorSafely}. The diagonal uses each
     * node's own root-path length. {@code 1 / priorSampleSize} is then added to every entry,
     * unless {@code priorSampleSize} is infinite.
     *
     * @param scale           multiplier applied to every path length
     * @param priorSampleSize added to every entry as its reciprocal; skipped when infinite
     *                        (presumably the root prior sample size; confirm with callers)
     */
    public static double[][] getGraphVariance(Tree tree, BranchRates branchRates, double scale,
                                              double priorSampleSize) {
        final int nodeCount = tree.getNodeCount();
        final double[][] variance = new double[nodeCount][nodeCount];
        for (int i = 0; i < nodeCount; ++i) {
            final NodeRef nodeI = tree.getNode(i);
            variance[i][i] = getLengthToRoot(tree, branchRates, nodeI) * scale;
            for (int j = i + 1; j < nodeCount; ++j) {
                final NodeRef mrca = TreeUtils.getCommonAncestorSafely(tree, nodeI, tree.getNode(j));
                variance[i][j] = getLengthToRoot(tree, branchRates, mrca) * scale;
            }
        }
        makeSymmetric(variance);
        addPrior(variance, priorSampleSize);
        return variance;
    }

    /**
     * Tip-by-tip variance built pairwise from MRCA root-path lengths, with the MRCA looked up by
     * taxon id.
     *
     * @deprecated use {@link #getTreeVariance(Tree, BranchRates, double, double)}, which
     *             accumulates the same quantities in a single post-order traversal.
     */
    @Deprecated
    public static double[][] getTreeVarianceOld(Tree tree, BranchRates branchRates, double scale,
                                                double priorSampleSize) {
        final int tipCount = tree.getExternalNodeCount();
        final double[][] variance = new double[tipCount][tipCount];
        for (int i = 0; i < tipCount; ++i) {
            variance[i][i] = getLengthToRoot(tree, branchRates, tree.getExternalNode(i)) * scale;
            for (int j = i + 1; j < tipCount; ++j) {
                final NodeRef mrca = findMRCA(tree, i, j);
                variance[i][j] = getLengthToRoot(tree, branchRates, mrca) * scale;
            }
        }
        makeSymmetric(variance);
        addPrior(variance, priorSampleSize);
        return variance;
    }

    /**
     * Builds the tip-by-tip variance matrix with a single post-order traversal.
     *
     * <p>Every non-root branch adds its scaled length ({@code length * rate * scale}) to entry
     * (i, j) for all tips i, j below it. As a result, entry (i, j) is the scaled length shared by
     * the root paths of tips i and j. The root's own branch is not included.
     * {@code 1 / priorSampleSize} is then added to every entry unless it is infinite.
     *
     * <p>Rows and columns are indexed by tip node number. NOTE(review): this assumes external
     * nodes are numbered {@code 0 .. externalNodeCount-1}; confirm for the Tree implementation.
     */
    public static double[][] getTreeVariance(Tree tree, BranchRates branchRates, double scale,
                                             double priorSampleSize) {
        final int tipCount = tree.getExternalNodeCount();
        final double[][] variance = new double[tipCount][tipCount];
        recursiveTreeVariance(variance, tree.getRoot(), tree, branchRates, scale, priorSampleSize);
        makeSymmetric(variance);
        addPrior(variance, priorSampleSize);
        return variance;
    }

    /**
     * Post-order step of {@link #getTreeVariance}. For each child, the child's scaled branch
     * length is added to all (i, j) pairs of tips below it. The method returns the tips below
     * {@code node} together with {@code node}'s own scaled branch length, for the caller to
     * accumulate.
     *
     * @param variance        tip-by-tip matrix updated in place
     * @param priorSampleSize unused here; retained for signature compatibility
     */
    public static PostOrderBranchStats recursiveTreeVariance(double[][] variance, NodeRef node,
                                                             Tree tree, BranchRates branchRates,
                                                             double scale, double priorSampleSize) {
        if (tree.isExternal(node)) {
            return new PostOrderBranchStats(new int[]{node.getNumber()},
                    getScaledBranchLength(node, tree, branchRates, scale));
        }
        final PostOrderBranchStats left =
                recursiveTreeVariance(variance, tree.getChild(node, 0), tree, branchRates, scale, priorSampleSize);
        final PostOrderBranchStats right =
                recursiveTreeVariance(variance, tree.getChild(node, 1), tree, branchRates, scale, priorSampleSize);
        accumulateBranchLengths(variance, left);
        accumulateBranchLengths(variance, right);
        return mergeChildStats(left, right, getScaledBranchLength(node, tree, branchRates, scale));
    }

    /**
     * Concatenates the tip indices of two child subtrees (left first) and pairs them with the
     * parent's scaled branch length.
     */
    private static PostOrderBranchStats mergeChildStats(PostOrderBranchStats left,
                                                        PostOrderBranchStats right,
                                                        double parentBranchLength) {
        final int leftCount = left.dims.length;
        final int rightCount = right.dims.length;
        final int[] tips = new int[leftCount + rightCount];
        System.arraycopy(left.dims, 0, tips, 0, leftCount);
        System.arraycopy(right.dims, 0, tips, leftCount, rightCount);
        return new PostOrderBranchStats(tips, parentBranchLength);
    }

    /**
     * Returns {@code branchLength(node) * rate(node) * scale}. The rate factor is omitted when
     * {@code branchRates} is {@code null}.
     */
    private static double getScaledBranchLength(NodeRef node, Tree tree, BranchRates branchRates,
                                                double scale) {
        double length = tree.getBranchLength(node);
        if (branchRates != null) {
            length *= branchRates.getBranchRate(tree, node);
        }
        return length * scale;
    }

    /** Adds {@code stats.branchLength} to every (i, j) with i and j in {@code stats.dims}. */
    private static void accumulateBranchLengths(double[][] matrix, PostOrderBranchStats stats) {
        for (int row : stats.dims) {
            final double[] matrixRow = matrix[row];
            for (int col : stats.dims) {
                matrixRow[col] += stats.branchLength;
            }
        }
    }

    /** Copies the strict upper triangle onto the lower triangle, in place. */
    private static void makeSymmetric(double[][] matrix) {
        for (int i = 0; i < matrix.length; ++i) {
            for (int j = i + 1; j < matrix[i].length; ++j) {
                matrix[j][i] = matrix[i][j];
            }
        }
    }

    /**
     * Adds {@code 1 / priorSampleSize} to every entry of {@code matrix}, in place. Nothing is
     * added when {@code priorSampleSize} is infinite.
     */
    private static void addPrior(double[][] matrix, double priorSampleSize) {
        if (Double.isInfinite(priorSampleSize)) {
            return;
        }
        final double priorVariance = 1.0 / priorSampleSize;
        for (double[] row : matrix) {
            for (int j = 0; j < row.length; ++j) {
                row[j] += priorVariance;
            }
        }
    }

    /**
     * Returns, for each external node i, the vector produced by
     * {@code delegate.getAccumulativeDrift(externalNode(i), rootDrift, integrator, dimTrait)}.
     *
     * @param rootDrift passed through unchanged to {@code getAccumulativeDrift}
     *                  ({@link #getGraphDrift} passes a zero vector here; presumably the value
     *                  at the root, so confirm against the delegate)
     */
    public static double[][] getTreeDrift(Tree tree, double[] rootDrift,
                                          ContinuousDiffusionIntegrator continuousDiffusionIntegrator,
                                          DiffusionProcessDelegate diffusionProcessDelegate) {
        final int dimTrait = continuousDiffusionIntegrator.getDimTrait();
        final int tipCount = tree.getExternalNodeCount();
        final double[][] drift = new double[tipCount][dimTrait];
        for (int i = 0; i < tipCount; ++i) {
            drift[i] = diffusionProcessDelegate.getAccumulativeDrift(
                    tree.getExternalNode(i), rootDrift, continuousDiffusionIntegrator, dimTrait);
        }
        return drift;
    }

    /**
     * Returns, for every node i (internal and external), the accumulated drift from
     * {@code delegate.getAccumulativeDrift}, starting from a zero vector of length dimTrait.
     */
    public static double[][] getGraphDrift(Tree tree,
                                           ContinuousDiffusionIntegrator continuousDiffusionIntegrator,
                                           DiffusionProcessDelegate diffusionProcessDelegate) {
        final int dimTrait = continuousDiffusionIntegrator.getDimTrait();
        final int nodeCount = tree.getNodeCount();
        final double[][] drift = new double[nodeCount][dimTrait];
        final double[] zeroDrift = new double[dimTrait];
        for (int i = 0; i < nodeCount; ++i) {
            drift[i] = diffusionProcessDelegate.getAccumulativeDrift(
                    tree.getNode(i), zeroDrift, continuousDiffusionIntegrator, dimTrait);
        }
        return drift;
    }

    /**
     * Builds a joint variance matrix via {@code delegate.getJointVariance}.
     *
     * <p>Without actualization, the delegate's result is returned directly. With actualization,
     * the delegate's result V is transformed as {@code M V M^T}. Here
     * {@code M = I ⊗ matrix} and I is the identity of size {@code dArray2[0].length}.
     *
     * @throws IllegalStateException if the matrix dimensions are incompatible (previously this
     *                               printed a stack trace and returned {@code null})
     */
    public static Matrix getJointVarianceFactor(double d, double[][] dArray, double[][] dArray2,
                                                double[][] dArray3, double[][] dArray4,
                                                DiffusionProcessDelegate diffusionProcessDelegate,
                                                Matrix matrix) {
        if (!diffusionProcessDelegate.hasActualization()) {
            // NOTE(review): this path passes dArray for both the 2nd and 3rd arguments and uses
            // dArray3, whereas the actualized path uses dArray2 and dArray4. Confirm this
            // asymmetry against DiffusionProcessDelegate.getJointVariance.
            return new Matrix(diffusionProcessDelegate.getJointVariance(d, dArray, dArray, dArray3));
        }
        final Matrix jointVariance =
                new Matrix(diffusionProcessDelegate.getJointVariance(d, dArray, dArray2, dArray4));
        final double[][] identity = KroneckerOperation.makeIdentityMatrixArray(dArray2[0].length);
        final Matrix factor = new Matrix(KroneckerOperation.product(identity, matrix.toComponents()));
        try {
            return factor.product(jointVariance.product(factor.transpose()));
        } catch (IllegalDimension illegalDimension) {
            // Fail here with the cause attached rather than returning null to the caller.
            throw new IllegalStateException("Incompatible dimensions in joint variance factor",
                    illegalDimension);
        }
    }

    /**
     * Returns {@code scale} times the sum, over ordered pairs of distinct tips (i, j), of the
     * rate-scaled branch length shared by their root paths. The root branch is excluded and no
     * prior term is added.
     */
    public static double getVarianceOffDiagonalSum(Tree tree, BranchRates branchRates, double scale) {
        final Accumulator.BranchCumulant total =
                Accumulator.OFF_DIAGONAL.postOrderAccumulation(tree, tree.getRoot(), branchRates);
        return total.sharedLength * scale;
    }

    /**
     * Returns {@code scale} times the sum over tips of the rate-scaled root-to-tip length. The
     * root branch is excluded and no prior term is added.
     */
    public static double getVarianceDiagonalSum(Tree tree, BranchRates branchRates, double scale) {
        final Accumulator.BranchCumulant total =
                Accumulator.DIAGONAL.postOrderAccumulation(tree, tree.getRoot(), branchRates);
        return total.sharedLength * scale;
    }

    /**
     * Returns the scaled root-to-tip length of every tip, indexed by tip node number. The root
     * branch is excluded. NOTE(review): this assumes tips are numbered
     * {@code 0 .. externalNodeCount-1}.
     */
    public static double[] getTreeDepths(Tree tree, BranchRates branchRates, double scale) {
        final double[] depths = new double[tree.getExternalNodeCount()];
        recursiveTreeDepth(depths, tree.getRoot(), tree, branchRates, scale);
        return depths;
    }

    /**
     * Post-order step of {@link #getTreeDepths}. Each child's scaled branch length is added to
     * every tip below that child. The method returns the tips below {@code node} with
     * {@code node}'s own scaled branch length.
     */
    public static PostOrderBranchStats recursiveTreeDepth(double[] depths, NodeRef node, Tree tree,
                                                          BranchRates branchRates, double scale) {
        if (tree.isExternal(node)) {
            return new PostOrderBranchStats(new int[]{node.getNumber()},
                    getScaledBranchLength(node, tree, branchRates, scale));
        }
        final PostOrderBranchStats left =
                recursiveTreeDepth(depths, tree.getChild(node, 0), tree, branchRates, scale);
        final PostOrderBranchStats right =
                recursiveTreeDepth(depths, tree.getChild(node, 1), tree, branchRates, scale);
        accumulateBranchLengths(depths, left);
        accumulateBranchLengths(depths, right);
        return mergeChildStats(left, right, getScaledBranchLength(node, tree, branchRates, scale));
    }

    /** Adds {@code stats.branchLength} to every entry of {@code vector} listed in {@code stats.dims}. */
    private static void accumulateBranchLengths(double[] vector, PostOrderBranchStats stats) {
        for (int index : stats.dims) {
            vector[index] += stats.branchLength;
        }
    }

    /** Tip node numbers below a node, paired with that node's scaled branch length. */
    private static final class PostOrderBranchStats {
        private final int[] dims;
        private final double branchLength;

        PostOrderBranchStats(int[] dims, double branchLength) {
            this.dims = dims;
            this.branchLength = branchLength;
        }
    }

    /**
     * Post-order accumulators over rate-scaled (but not {@code scale}-multiplied) branch
     * lengths. Each non-root branch above a clade of n tips contributes {@code n * length}
     * (DIAGONAL) or {@code n * (n - 1) * length} (OFF_DIAGONAL).
     */
    private enum Accumulator {
        OFF_DIAGONAL {
            @Override
            BranchCumulant accumulate(BranchCumulant left, double leftLength,
                                      BranchCumulant right, double rightLength) {
                // Use double arithmetic so n * (n - 1) cannot overflow int for huge clades.
                return new BranchCumulant(left.nTaxa + right.nTaxa,
                        left.sharedLength + right.sharedLength
                                + (double) left.nTaxa * (left.nTaxa - 1) * leftLength
                                + (double) right.nTaxa * (right.nTaxa - 1) * rightLength);
            }
        },
        DIAGONAL {
            @Override
            BranchCumulant accumulate(BranchCumulant left, double leftLength,
                                      BranchCumulant right, double rightLength) {
                return new BranchCumulant(left.nTaxa + right.nTaxa,
                        left.sharedLength + right.sharedLength
                                + (double) left.nTaxa * leftLength
                                + (double) right.nTaxa * rightLength);
            }
        };

        /**
         * Combines two child cumulants with the rate-scaled lengths of the branches leading to
         * each child.
         */
        abstract BranchCumulant accumulate(BranchCumulant left, double leftLength,
                                           BranchCumulant right, double rightLength);

        /**
         * Recursively accumulates over the subtree below {@code node}. A tip yields
         * {@code (1 taxon, 0.0)}.
         */
        private BranchCumulant postOrderAccumulation(Tree tree, NodeRef node, BranchRates branchRates) {
            if (tree.isExternal(node)) {
                return new BranchCumulant(1, 0.0);
            }
            final NodeRef leftChild = tree.getChild(node, 0);
            final NodeRef rightChild = tree.getChild(node, 1);
            final BranchCumulant left = postOrderAccumulation(tree, leftChild, branchRates);
            final BranchCumulant right = postOrderAccumulation(tree, rightChild, branchRates);
            double leftLength = tree.getBranchLength(leftChild);
            double rightLength = tree.getBranchLength(rightChild);
            if (branchRates != null) {
                leftLength *= branchRates.getBranchRate(tree, leftChild);
                rightLength *= branchRates.getBranchRate(tree, rightChild);
            }
            return accumulate(left, leftLength, right, rightLength);
        }

        /** Number of tips in a clade and the shared-length sum accumulated so far. */
        private static final class BranchCumulant {
            final int nTaxa;
            final double sharedLength;

            BranchCumulant(int nTaxa, double sharedLength) {
                this.nTaxa = nTaxa;
                this.sharedLength = sharedLength;
            }
        }
    }
}

