/*
 * Decompiled with CFR 0.152.
 */
package dr.evomodel.substmodel;

import dr.evolution.datatype.Codons;
import dr.evolution.tree.NodeRef;
import dr.evolution.tree.Tree;
import dr.evolution.tree.TreeTrait;
import dr.evolution.tree.TreeTraitProvider;
import dr.evomodel.branchratemodel.BranchRateModel;
import dr.evomodel.siteratemodel.SiteRateModel;
import dr.evomodel.substmodel.CodonLabeling;
import dr.evomodel.substmodel.MarkovJumpsSubstitutionModel;
import dr.evomodel.substmodel.ProductChainSubstitutionModel;
import dr.evomodel.substmodel.StratifiedTraitOutputFormat;
import dr.evomodel.substmodel.SubstitutionModel;
import dr.evomodel.substmodel.UniformizedSubstitutionModel;
import dr.evomodel.tree.TreeModel;
import dr.evomodel.treelikelihood.AncestralStateBeagleTreeLikelihood;
import dr.evomodel.treelikelihood.utilities.TreeTraitLogger;
import dr.inference.loggers.LogColumn;
import dr.inference.loggers.Loggable;
import dr.inference.markovjumps.StateHistory;
import dr.inference.model.AbstractModel;
import dr.inference.model.Model;
import dr.inference.model.Variable;
import dr.math.MathUtils;
import dr.util.Citable;
import dr.util.Citation;
import dr.util.CommonCitations;
import java.util.ArrayList;
import java.util.List;
import java.util.regex.Pattern;

/**
 * Robust counting of labeled codon substitutions on a tree whose alignment is split into three
 * codon-position partitions.
 *
 * <p>The nucleotide states reconstructed by the three {@link AncestralStateBeagleTreeLikelihood}
 * partitions are combined (in partition order) into 64-state codons. These are analysed under a
 * {@link ProductChainSubstitutionModel} built from the partitions' root substitution models and
 * site-rate models. Which transitions are counted is defined by the registration matrix from
 * {@link CodonLabeling#getRegisterMatrix}.
 *
 * <p>The following tree traits are exposed:
 * <ul>
 *   <li>{@code base_<label>}: per-site expected counts for each non-root branch, conditioned on
 *       the reconstructed parent and child codons;</li>
 *   <li>{@code <label>}: the per-branch sum of {@code base_<label>};</li>
 *   <li>{@code c_<label>}: the per-site sum of {@code base_<label>} over the tree;</li>
 *   <li>{@code [prefix]total_<label>}: the grand total over the tree (logged);</li>
 *   <li>{@code u_<label>}: unconditioned per-site counts, drawn by simulating from the
 *       (optionally rate-averaged) product chain over the expected length of the included
 *       branches;</li>
 *   <li>optionally, per-branch unconditioned counts ({@code b_u_<label>}) and their total
 *       ({@code [prefix]utotal_<label>}, logged);</li>
 *   <li>optionally, the filtered substitution history per site ({@code all_<label>}). This is
 *       only recorded when uniformization is enabled.</li>
 * </ul>
 *
 * <p>Instances hold unsynchronized lazy caches and a shared scratch matrix, so they must not be
 * used from several threads without external synchronization.
 */
public class CodonPartitionedRobustCounting
        extends AbstractModel
        implements TreeTraitProvider, Loggable, Citable {

    public static final String UNCONDITIONED_PREFIX = "u_";
    public static final String SITE_SPECIFIC_PREFIX = "c_";
    public static final String TOTAL_PREFIX = "total_";
    public static final String UNCONDITIONED_TOTAL_PREFIX = "utotal_";
    public static final String BASE_TRAIT_PREFIX = "base_";
    public static final String COMPLETE_HISTORY_PREFIX = "all_";
    public static final String UNCONDITIONED_PER_BRANCH_PREFIX = "b_u_";

    /** Number of nucleotide states at a single codon position. */
    private static final int NUCLEOTIDE_STATES = 4;
    /** Number of codon states in the product chain (4^3 = 64). */
    private static final int CODON_STATES = NUCLEOTIDE_STATES * NUCLEOTIDE_STATES * NUCLEOTIDE_STATES;
    /** Size of a row-major CODON_STATES x CODON_STATES matrix. */
    private static final int CODON_MATRIX_SIZE = CODON_STATES * CODON_STATES;
    /**
     * Splits a history string of the form "{...},{...}" at each comma between "}" and "{".
     * Because the look-around assertions are zero-width, the braces stay on each piece.
     */
    private static final Pattern HISTORY_EVENT_SEPARATOR = Pattern.compile("(?<=\\}),(?=\\{)");

    /** One likelihood per codon position; index 0, 1, 2 are combined in that order into a codon. */
    private final AncestralStateBeagleTreeLikelihood[] partition;
    private final MarkovJumpsSubstitutionModel markovJumps;
    // NOTE(review): built when rate averaging is forced but never read in this class; confirm it is still needed.
    private final MarkovJumpsSubstitutionModel averagedMarkovJumps;
    private final boolean forceUnconditionalAverageRate;
    private final boolean useUniformization;
    private final BranchRateModel branchRateModel;
    private final ProductChainSubstitutionModel productChainModel;
    /** Rate-averaged product chain; non-null only when forceUnconditionalAverageRate is set. */
    private final ProductChainSubstitutionModel averagedProductChainModel;
    private final CodonLabeling codonLabeling;
    private final Tree tree;
    /** Optional prefix for the logged total trait names; may be null. */
    private final String prefix;
    // NOTE(review): the two formats are stored but not read anywhere in this class.
    private final StratifiedTraitOutputFormat branchFormat;
    private final StratifiedTraitOutputFormat logFormat;
    /**
     * Scratch matrix indexed as [parentState * 64 + childState]. It holds conditional expectations
     * (plain Markov jumps) or transition probabilities (uniformization).
     */
    private final double[] condMeanMatrix;
    /** Number of alignment patterns in each partition; all three partitions must agree. */
    private final int numCodons;
    private boolean countsKnown = false;
    private boolean unconditionsKnown = false;
    private boolean unconditionsPerBranchKnown = false;
    private double[] unconditionedCounts;
    private double[][] unconditionedCountsPerBranch;
    /** Per-node expected counts; the root's row stays null. */
    private final double[][] computedCounts;
    /** Per-node, per-site history strings; allocated lazily, and only when uniformizing. */
    private String[][] completeHistoryPerNode;
    protected TreeTraitProvider.Helper treeTraits = new TreeTraitProvider.Helper();
    protected TreeTraitLogger treeTraitLogger;
    private final boolean includeExternalBranches;
    private final boolean includeInternalBranches;
    private final boolean doUnconditionedPerBranch;
    private final boolean saveCompleteHistory;
    // NOTE(review): stored but never read in this class.
    private final boolean tryNewNeutralModel;

    /**
     * Convenience constructor that never forces rate averaging for the unconditioned counts.
     *
     * <p>NOTE(review): this overload passes {@code false} for {@code forceUnconditionalAverageRate}
     * and forwards its 11th argument into the {@code tryNewNeutralModel} slot of the full
     * constructor. That behaviour is preserved here; confirm it is intended.
     */
    public CodonPartitionedRobustCounting(String name, TreeModel treeModel,
                                          AncestralStateBeagleTreeLikelihood[] partition,
                                          Codons codons, CodonLabeling codonLabeling,
                                          boolean useUniformization,
                                          boolean includeExternalBranches,
                                          boolean includeInternalBranches,
                                          boolean doUnconditionedPerBranch,
                                          boolean saveCompleteHistory,
                                          boolean tryNewNeutralModel,
                                          StratifiedTraitOutputFormat branchFormat,
                                          StratifiedTraitOutputFormat logFormat,
                                          String prefix) {
        this(name, treeModel, partition, codons, codonLabeling, useUniformization,
                includeExternalBranches, includeInternalBranches, doUnconditionedPerBranch,
                saveCompleteHistory, false, tryNewNeutralModel, branchFormat, logFormat, prefix);
    }

    /**
     * Builds the product-chain counting model and registers all tree traits.
     *
     * @param name                          model name passed to {@link AbstractModel}
     * @param treeModel                     tree over which counts are computed
     * @param partition                     exactly three likelihoods, one per codon position
     * @param codons                        codon data type used to build the registration matrix
     * @param codonLabeling                 which class of codon changes is counted
     * @param useUniformization             use {@link UniformizedSubstitutionModel} instead of the
     *                                      analytic {@link MarkovJumpsSubstitutionModel}
     * @param includeExternalBranches       include tip branches in tree-level sums
     * @param includeInternalBranches       include internal (non-root) branches in tree-level sums
     * @param doUnconditionedPerBranch      also expose per-branch unconditioned counts
     * @param saveCompleteHistory           record per-site histories (only effective with uniformization)
     * @param forceUnconditionalAverageRate simulate unconditioned counts from a rate-averaged product chain
     * @param tryNewNeutralModel            stored only; not used by this class
     * @param branchFormat                  stored only; not used by this class
     * @param logFormat                     stored only; not used by this class
     * @param prefix                        optional prefix for logged total trait names (may be null)
     * @throws IllegalArgumentException if there are not exactly three partitions or their pattern
     *                                  counts differ
     */
    public CodonPartitionedRobustCounting(String name, TreeModel treeModel,
                                          AncestralStateBeagleTreeLikelihood[] partition,
                                          Codons codons, CodonLabeling codonLabeling,
                                          boolean useUniformization,
                                          boolean includeExternalBranches,
                                          boolean includeInternalBranches,
                                          boolean doUnconditionedPerBranch,
                                          boolean saveCompleteHistory,
                                          boolean forceUnconditionalAverageRate,
                                          boolean tryNewNeutralModel,
                                          StratifiedTraitOutputFormat branchFormat,
                                          StratifiedTraitOutputFormat logFormat,
                                          String prefix) {
        super(name);
        // Validate before registering as a listener on any sub-model.
        if (partition.length != 3) {
            throw new IllegalArgumentException("CodonPartition models require 3 partitions");
        }
        this.tree = treeModel;
        addModel(treeModel);
        this.partition = partition;
        this.codonLabeling = codonLabeling;

        this.branchRateModel = partition[0].getBranchRateModel();
        addModel(branchRateModel);

        this.numCodons = partition[0].getPatternWeights().length;
        final List<SubstitutionModel> substitutionModels = new ArrayList<>(3);
        final List<SiteRateModel> siteRateModels = new ArrayList<>(3);
        for (AncestralStateBeagleTreeLikelihood position : partition) {
            substitutionModels.add(position.getBranchModel().getRootSubstitutionModel());
            siteRateModels.add(position.getSiteRateModel());
            if (position.getPatternWeights().length != numCodons) {
                throw new IllegalArgumentException("All sequence lengths must be equal in CodonPartitionedRobustCounting");
            }
        }

        this.productChainModel = new ProductChainSubstitutionModel("codonLabeling", substitutionModels, siteRateModels, false);
        addModel(productChainModel);

        this.forceUnconditionalAverageRate = forceUnconditionalAverageRate;
        if (forceUnconditionalAverageRate) {
            this.averagedProductChainModel = new ProductChainSubstitutionModel("codonLabeling", substitutionModels, siteRateModels, true);
            addModel(averagedProductChainModel);
        } else {
            this.averagedProductChainModel = null;
        }

        this.useUniformization = useUniformization;
        this.saveCompleteHistory = saveCompleteHistory;
        this.markovJumps = createMarkovJumps(productChainModel, useUniformization, saveCompleteHistory);
        this.averagedMarkovJumps = forceUnconditionalAverageRate
                ? createMarkovJumps(averagedProductChainModel, useUniformization, saveCompleteHistory)
                : null;
        markovJumps.setRegistration(CodonLabeling.getRegisterMatrix(codonLabeling, codons, true));

        this.condMeanMatrix = new double[CODON_MATRIX_SIZE];
        this.branchFormat = branchFormat;
        this.logFormat = logFormat;
        this.computedCounts = new double[treeModel.getNodeCount()][];
        this.includeExternalBranches = includeExternalBranches;
        this.includeInternalBranches = includeInternalBranches;
        this.doUnconditionedPerBranch = doUnconditionedPerBranch;
        this.tryNewNeutralModel = tryNewNeutralModel;
        this.prefix = prefix;
        setupTraits();
    }

    /** Wraps {@code model} in a uniformized or analytic Markov-jumps model. */
    private static MarkovJumpsSubstitutionModel createMarkovJumps(ProductChainSubstitutionModel model,
                                                                  boolean useUniformization,
                                                                  boolean saveCompleteHistory) {
        if (useUniformization) {
            final UniformizedSubstitutionModel uniformized = new UniformizedSubstitutionModel(model);
            uniformized.setSaveCompleteHistory(saveCompleteHistory);
            return uniformized;
        }
        return new MarkovJumpsSubstitutionModel(model);
    }

    /**
     * Returns the simulated unconditioned per-site counts for the branch above {@code nodeRef}.
     * All branches are recomputed (redrawn) if the cache has been invalidated.
     */
    public double[] getUnconditionalCountsForBranch(NodeRef nodeRef) {
        if (!unconditionsPerBranchKnown) {
            computeAllUnconditionalCountsPerBranch();
            unconditionsPerBranchKnown = true;
        }
        return unconditionedCountsPerBranch[nodeRef.getNumber()];
    }

    /**
     * Returns the conditioned per-site expected counts for the branch above {@code nodeRef}.
     * Returns null for the root.
     */
    public double[] getExpectedCountsForBranch(NodeRef nodeRef) {
        if (!countsKnown) {
            computeAllExpectedCounts();
        }
        return computedCounts[nodeRef.getNumber()];
    }

    /** Recomputes the expected counts for every non-root branch and marks the cache valid. */
    private void computeAllExpectedCounts() {
        for (int i = 0; i < tree.getNodeCount(); ++i) {
            final NodeRef node = tree.getNode(i);
            if (!tree.isRoot(node)) {
                computedCounts[node.getNumber()] = computeExpectedCountsForBranch(node);
            }
        }
        countsKnown = true;
    }

    /**
     * Computes per-site expected counts on the branch from {@code child}'s parent to {@code child}.
     * The counts are conditioned on the reconstructed codons at both ends. If requested, the
     * per-site history is also recorded.
     */
    private double[] computeExpectedCountsForBranch(NodeRef child) {
        // Child states are queried before parent states, in partition order, as before.
        final int[] childPos0 = partition[0].getStatesForNode(tree, child);
        final int[] childPos1 = partition[1].getStatesForNode(tree, child);
        final int[] childPos2 = partition[2].getStatesForNode(tree, child);
        final NodeRef parent = tree.getParent(child);
        final int[] parentPos0 = partition[0].getStatesForNode(tree, parent);
        final int[] parentPos1 = partition[1].getStatesForNode(tree, parent);
        final int[] parentPos2 = partition[2].getStatesForNode(tree, parent);

        final double substTime = getExpectedBranchLength(child);
        if (useUniformization) {
            // Transition probabilities; the uniformized model conditions on them per site below.
            markovJumps.getSubstitutionModel().getTransitionProbabilities(substTime, condMeanMatrix);
        } else {
            markovJumps.computeCondStatMarkovJumps(substTime, condMeanMatrix);
        }

        final double[] expectedCounts = new double[numCodons];
        for (int site = 0; site < numCodons; ++site) {
            final int childState = getCanonicalState(childPos0[site], childPos1[site], childPos2[site]);
            final int parentState = getCanonicalState(parentPos0[site], parentPos1[site], parentPos2[site]);
            final double entry = condMeanMatrix[parentState * CODON_STATES + childState];
            if (useUniformization) {
                expectedCounts[site] = ((UniformizedSubstitutionModel) markovJumps)
                        .computeCondStatMarkovJumps(parentState, childState, substTime, entry);
                if (saveCompleteHistory) {
                    recordCompleteHistory(child, parent, site);
                }
            } else {
                expectedCounts[site] = entry;
            }
        }
        return expectedCounts;
    }

    /**
     * Stores the registered changes from the uniformized model's last state history for
     * {@code site} on the branch above {@code child}. The event times are rescaled to the
     * parent/child node heights. Stores null when no registered change occurred.
     */
    private void recordCompleteHistory(NodeRef child, NodeRef parent, int site) {
        final UniformizedSubstitutionModel uniformized = (UniformizedSubstitutionModel) markovJumps;
        if (completeHistoryPerNode == null) {
            completeHistoryPerNode = new String[tree.getNodeCount()][numCodons];
        }
        final StateHistory history = uniformized.getStateHistory().filterChanges(uniformized.getRegistration());
        String encoded = null;
        if (history.getNumberOfJumps() > 0) {
            history.rescaleTimesOfEvents(tree.getNodeHeight(parent), tree.getNodeHeight(child));
            encoded = history.toStringChanges(site + 1, uniformized.dataType, false);
        }
        completeHistoryPerNode[child.getNumber()][site] = encoded;
    }

    /**
     * Returns the individual "{...}" history events recorded on the branch above {@code nodeRef}
     * across all sites. Returns an empty array when no history was recorded, e.g. because
     * uniformization is disabled.
     */
    private String[] getCompleteHistoryForBranch(NodeRef nodeRef) {
        // Called for its side effect: this (re)computes counts, which also records the histories.
        getExpectedCountsForBranch(nodeRef);
        if (completeHistoryPerNode == null) {
            return new String[0];
        }
        final List<String> events = new ArrayList<>();
        for (String history : completeHistoryPerNode[nodeRef.getNumber()]) {
            if (history == null) {
                continue;
            }
            for (String event : HISTORY_EVENT_SEPARATOR.split(history)) {
                events.add(event);
            }
        }
        return events.toArray(new String[0]);
    }

    /** Creates the string-array trait that exposes per-branch substitution histories. */
    private TreeTrait.SA createCompleteHistoryTrait() {
        return new TreeTrait.SA() {

            @Override
            public String getTraitName() {
                return CodonPartitionedRobustCounting.COMPLETE_HISTORY_PREFIX + CodonPartitionedRobustCounting.this.codonLabeling.getText();
            }

            @Override
            public TreeTrait.Intent getIntent() {
                return TreeTrait.Intent.BRANCH;
            }

            @Override
            public boolean getFormatAsArray() {
                return true;
            }

            @Override
            public String[] getTrait(Tree tree, NodeRef nodeRef) {
                return CodonPartitionedRobustCounting.this.getCompleteHistoryForBranch(nodeRef);
            }

            @Override
            public boolean getLoggable() {
                return true;
            }
        };
    }

    /** Prepends the optional {@link #prefix} to a logged trait name. */
    private String withPrefix(String name) {
        return prefix == null ? name : prefix + name;
    }

    /**
     * Builds and registers every tree trait and the logger for the tree totals.
     * Traits are registered in a fixed order.
     */
    private void setupTraits() {
        final String label = codonLabeling.getText();

        // Conditioned per-site counts on each branch; the base for the summary traits below.
        final TreeTrait.DA expectedCountsTrait = new TreeTrait.DA() {

            @Override
            public String getTraitName() {
                return CodonPartitionedRobustCounting.BASE_TRAIT_PREFIX + CodonPartitionedRobustCounting.this.codonLabeling.getText();
            }

            @Override
            public TreeTrait.Intent getIntent() {
                return TreeTrait.Intent.BRANCH;
            }

            @Override
            public double[] getTrait(Tree tree, NodeRef nodeRef) {
                return CodonPartitionedRobustCounting.this.getExpectedCountsForBranch(nodeRef);
            }

            @Override
            public boolean getLoggable() {
                return false;
            }
        };

        if (saveCompleteHistory) {
            treeTraits.addTrait(createCompleteHistoryTrait());
        }

        // Whole-tree unconditioned per-site counts.
        final TreeTrait.DA unconditionedTrait = new TreeTrait.DA() {

            @Override
            public String getTraitName() {
                return CodonPartitionedRobustCounting.UNCONDITIONED_PREFIX + CodonPartitionedRobustCounting.this.codonLabeling.getText();
            }

            @Override
            public TreeTrait.Intent getIntent() {
                return TreeTrait.Intent.WHOLE_TREE;
            }

            @Override
            public double[] getTrait(Tree tree, NodeRef nodeRef) {
                return CodonPartitionedRobustCounting.this.getUnconditionedTraitValues();
            }

            @Override
            public boolean getLoggable() {
                return false;
            }
        };

        final TreeTrait.SumOverTreeDA siteSpecificCounts = new TreeTrait.SumOverTreeDA(
                SITE_SPECIFIC_PREFIX + label, expectedCountsTrait,
                includeExternalBranches, includeInternalBranches) {

            @Override
            public boolean getLoggable() {
                return false;
            }
        };

        final TreeTrait.SumAcrossArrayD branchTotals = new TreeTrait.SumAcrossArrayD(label, expectedCountsTrait) {

            @Override
            public boolean getLoggable() {
                return true;
            }
        };

        final TreeTrait.SumOverTreeD treeTotal = new TreeTrait.SumOverTreeD(
                withPrefix(TOTAL_PREFIX + label), branchTotals,
                includeExternalBranches, includeInternalBranches) {

            @Override
            public boolean getLoggable() {
                return true;
            }
        };

        treeTraits.addTrait(expectedCountsTrait);
        treeTraits.addTrait(unconditionedTrait);
        treeTraits.addTrait(branchTotals);
        treeTraits.addTrait(siteSpecificCounts);
        treeTraits.addTrait(treeTotal);

        if (!doUnconditionedPerBranch) {
            treeTraitLogger = new TreeTraitLogger(tree, new TreeTrait[]{treeTotal});
            return;
        }

        // Per-branch unconditioned counts, with their per-branch and tree-level sums.
        final TreeTrait.DA unconditionedPerBranchTrait = new TreeTrait.DA() {

            @Override
            public String getTraitName() {
                return CodonPartitionedRobustCounting.UNCONDITIONED_PER_BRANCH_PREFIX + CodonPartitionedRobustCounting.this.codonLabeling.getText();
            }

            @Override
            public TreeTrait.Intent getIntent() {
                return TreeTrait.Intent.BRANCH;
            }

            @Override
            public double[] getTrait(Tree tree, NodeRef nodeRef) {
                return CodonPartitionedRobustCounting.this.getUnconditionalCountsForBranch(nodeRef);
            }

            @Override
            public boolean getLoggable() {
                return false;
            }
        };

        final TreeTrait.SumAcrossArrayD unconditionedBranchTotals = new TreeTrait.SumAcrossArrayD(
                UNCONDITIONED_PER_BRANCH_PREFIX + label, unconditionedPerBranchTrait) {

            @Override
            public boolean getLoggable() {
                return true;
            }
        };

        final TreeTrait.SumOverTreeD unconditionedTreeTotal = new TreeTrait.SumOverTreeD(
                withPrefix(UNCONDITIONED_TOTAL_PREFIX + label), unconditionedBranchTotals,
                includeExternalBranches, includeInternalBranches) {

            @Override
            public boolean getLoggable() {
                return true;
            }
        };

        // The unconditioned tree total is logged but, as before, not registered as a tree trait.
        treeTraitLogger = new TreeTraitLogger(tree, new TreeTrait[]{treeTotal, unconditionedTreeTotal});
        treeTraits.addTrait(unconditionedPerBranchTrait);
        treeTraits.addTrait(unconditionedBranchTotals);
    }

    @Override
    public TreeTrait[] getTreeTraits() {
        return treeTraits.getTreeTraits();
    }

    @Override
    public TreeTrait getTreeTrait(String string) {
        return treeTraits.getTreeTrait(string);
    }

    /** Combines three nucleotide states (positions 1, 2, 3) into a codon index in [0, 64). */
    private static int getCanonicalState(int pos1, int pos2, int pos3) {
        return (pos1 * NUCLEOTIDE_STATES + pos2) * NUCLEOTIDE_STATES + pos3;
    }

    @Override
    public LogColumn[] getColumns() {
        return treeTraitLogger.getColumns();
    }

    /** Returns the number of sites (alignment patterns) per codon-position partition. */
    public int getDimension() {
        return numCodons;
    }

    /** Redraws unconditioned per-site counts for every non-root branch, using its expected length. */
    private void computeAllUnconditionalCountsPerBranch() {
        if (unconditionedCountsPerBranch == null) {
            unconditionedCountsPerBranch = new double[tree.getNodeCount()][numCodons];
        }
        final double[] rootFrequencies = getUnconditionalRootDistribution();
        for (int i = 0; i < tree.getNodeCount(); ++i) {
            final NodeRef node = tree.getNode(i);
            if (!tree.isRoot(node)) {
                fillInUnconditionalTraitValues(getExpectedBranchLength(node), rootFrequencies,
                        unconditionedCountsPerBranch[node.getNumber()]);
            }
        }
    }

    /** Redraws whole-tree unconditioned per-site counts over the expected length of the included branches. */
    private void computeUnconditionedTraitValues() {
        if (unconditionedCounts == null) {
            unconditionedCounts = new double[numCodons];
        }
        final double treeLength = getExpectedTreeLength();
        final double[] rootFrequencies = getUnconditionalRootDistribution();
        fillInUnconditionalTraitValues(treeLength, rootFrequencies, unconditionedCounts);
    }

    /** Frequencies of the (optionally rate-averaged) product chain, used to draw starting states. */
    private double[] getUnconditionalRootDistribution() {
        final ProductChainSubstitutionModel model =
                forceUnconditionalAverageRate ? averagedProductChainModel : productChainModel;
        return model.getFrequencyModel().getFrequencies();
    }

    /** Writes the (optionally rate-averaged) product chain's infinitesimal matrix into {@code matrix}. */
    private void fillInUnconditionalQMatrix(double[] matrix) {
        final ProductChainSubstitutionModel model =
                forceUnconditionalAverageRate ? averagedProductChainModel : productChainModel;
        model.getInfinitesimalMatrix(matrix);
    }

    /**
     * For each site, draws a starting codon from {@code rootFrequencies}, simulates an
     * unconditioned history of duration {@code time}, and stores the registered count in
     * {@code out[site]}.
     */
    private void fillInUnconditionalTraitValues(double time, double[] rootFrequencies, double[] out) {
        final double[] qMatrix = new double[CODON_MATRIX_SIZE];
        fillInUnconditionalQMatrix(qMatrix);
        for (int site = 0; site < numCodons; ++site) {
            final int startState = MathUtils.randomChoicePDF(rootFrequencies);
            final StateHistory history = StateHistory.simulateUnconditionalOnEndingState(
                    0.0, startState, time, qMatrix, CODON_STATES);
            out[site] = markovJumps.getProcessForSimulant(history);
        }
    }

    /** Returns the cached whole-tree unconditioned counts, redrawing them if invalidated. */
    private double[] getUnconditionedTraitValues() {
        if (!unconditionsKnown) {
            computeUnconditionedTraitValues();
            unconditionsKnown = true;
        }
        return unconditionedCounts;
    }

    /**
     * Draws a single unconditioned count over the expected length of the included branches.
     * This is uncached; every call makes a fresh random draw.
     */
    public Double getUnconditionedTraitValue() {
        final double treeLength = getExpectedTreeLength();
        final double[] rootFrequencies = getUnconditionalRootDistribution();
        final int startState = MathUtils.randomChoicePDF(rootFrequencies);
        final double[] qMatrix = new double[CODON_MATRIX_SIZE];
        fillInUnconditionalQMatrix(qMatrix);
        final StateHistory history = StateHistory.simulateUnconditionalOnEndingState(
                0.0, startState, treeLength, qMatrix, CODON_STATES);
        return markovJumps.getProcessForSimulant(history);
    }

    /** Branch rate times branch length for the branch above {@code nodeRef}. */
    private double getExpectedBranchLength(NodeRef nodeRef) {
        return branchRateModel.getBranchRate(tree, nodeRef) * tree.getBranchLength(nodeRef);
    }

    /**
     * Sums expected branch lengths over the tip branches and/or the internal non-root branches,
     * according to the include flags. Tips are summed first, then internal nodes.
     */
    private double getExpectedTreeLength() {
        double length = 0.0;
        if (includeExternalBranches) {
            for (int i = 0; i < tree.getExternalNodeCount(); ++i) {
                length += getExpectedBranchLength(tree.getExternalNode(i));
            }
        }
        if (includeInternalBranches) {
            for (int i = 0; i < tree.getInternalNodeCount(); ++i) {
                final NodeRef node = tree.getInternalNode(i);
                if (!tree.isRoot(node)) {
                    length += getExpectedBranchLength(node);
                }
            }
        }
        return length;
    }

    /** Marks every cached count (conditioned, whole-tree unconditioned, per-branch unconditioned) stale. */
    private void invalidateCounts() {
        countsKnown = false;
        unconditionsKnown = false;
        unconditionsPerBranchKnown = false;
    }

    @Override
    protected void handleModelChangedEvent(Model model, Object object, int n) {
        invalidateCounts();
    }

    /**
     * This class registers only sub-models (via addModel), so this is presumably rarely invoked.
     * It now also invalidates the per-branch unconditioned counts, which were previously left
     * stale here.
     */
    @Override
    protected void handleVariableChangedEvent(Variable variable, int n, Variable.ChangeType changeType) {
        invalidateCounts();
    }

    @Override
    protected void storeState() {
    }

    @Override
    protected void restoreState() {
        invalidateCounts();
    }

    @Override
    protected void acceptState() {
    }

    @Override
    public Citation.Category getCategory() {
        return Citation.Category.COUNTING_PROCESSES;
    }

    @Override
    public String getDescription() {
        final StringBuilder description = new StringBuilder("Using robust counting (first citation) for labeled distances between sequences to efficiently estimate site-specific dN/dS rate ratios (second citation)");
        if (saveCompleteHistory) {
            description.append(" and inferring the complete transition history (third citation)");
        }
        return description.toString();
    }

    @Override
    public List<Citation> getCitations() {
        final List<Citation> citations = new ArrayList<>();
        citations.add(CommonCitations.OBRIEN_2009_LEARNING);
        citations.add(CommonCitations.LEMEY_2012_RENAISSANCE);
        if (saveCompleteHistory) {
            citations.add(CommonCitations.BLOOM_2013_STABILITY);
        }
        return citations;
    }
}

