/*
 * Decompiled with CFR 0.152.
 */
package dr.inference.operators.hmc;

import dr.evomodel.operators.NativeZigZagOptions;
import dr.evomodel.operators.NativeZigZagWrapper;
import dr.evomodel.treedatalikelihood.TreeDataLikelihood;
import dr.evomodel.treedatalikelihood.continuous.ContinuousDataLikelihoodDelegate;
import dr.inference.hmc.GradientWrtParameterProvider;
import dr.inference.hmc.PrecisionColumnProvider;
import dr.inference.hmc.PrecisionMatrixVectorProductProvider;
import dr.inference.model.Parameter;
import dr.inference.operators.GibbsOperator;
import dr.inference.operators.SimpleMCMCOperator;
import dr.inference.operators.hmc.MassPreconditionScheduler;
import dr.inference.operators.hmc.MassPreconditioner;
import dr.inference.operators.hmc.MassPreconditioningOptions;
import dr.math.MathUtils;
import dr.math.matrixAlgebra.ReadableVector;
import dr.math.matrixAlgebra.WrappedVector;
import dr.util.BenchmarkTimer;
import dr.xml.Reportable;
import java.util.Arrays;

/**
 * Shared infrastructure for particle-trajectory MCMC operators.
 *
 * <p>Holds the gradient / precision providers, the optional mask, the per-dimension sign and
 * bound information, mass preconditioning state and (optionally) a native zig-zag back-end
 * ({@link NativeZigZagWrapper}). Subclasses implement {@link #integrateTrajectory}.
 */
public abstract class AbstractParticleOperator extends SimpleMCMCOperator
        implements GibbsOperator, Reportable {

    /** When enabled, {@link #doOperation()} periodically queries the time-scale eigen decomposition. */
    private static final boolean CHECK_MATRIX_ILL_CONDITIONED = false;

    protected final GradientWrtParameterProvider gradientProvider;
    private final PrecisionMatrixVectorProductProvider productProvider;
    private final PrecisionColumnProvider columnProvider;
    /** The parameter being sampled, taken from {@code gradientProvider.getParameter()}. */
    protected final Parameter parameter;
    protected final Options runtimeOptions;
    protected boolean refreshVelocity;
    protected final NativeCodeOptions nativeCodeOptions;
    /** Optional mask parameter; {@code null} means "no masking". */
    final Parameter mask;
    /** +1 / -1 per dimension, from the sign of the parameter's initial value. */
    final double[] parameterSign;
    /** Snapshot of {@link #mask}'s values taken at construction, or {@code null} without a mask. */
    protected final double[] maskVector;
    int numEvents;
    int numBoundaryEvents;
    int numGradientEvents;
    protected WrappedVector storedVelocity;
    Preconditioning preconditioning;
    protected final MassPreconditioner massPreconditioning;
    protected final MassPreconditionScheduler preconditionScheduler;
    /** 0 where a dimension is unbounded on both sides, 1 otherwise. */
    protected final double[] observedDataMask;
    /** Offset added to positions before computing precision products (zeros unless set below). */
    private final double[] meanVector;
    /** Category-class parameter values tiled across the parameter dimension (zeros if absent). */
    protected final int[] categoryClasses;
    static final boolean TIMING = true;
    BenchmarkTimer timer = new BenchmarkTimer();
    NativeZigZagWrapper nativeZigZag;

    /**
     * Builds the operator. The initialization order below is significant: later steps read
     * fields assigned by earlier ones.
     *
     * @param gradientProvider         supplies the sampled parameter and its log-density gradient
     * @param productProvider          supplies precision-matrix/vector products and time scales
     * @param columnProvider           supplies individual precision-matrix columns
     * @param weight                   operator weight
     * @param runtimeOptions           runtime/preconditioning options
     * @param nativeCodeOptions        switches for the native back-end
     * @param refreshVelocity          initial value of {@link #refreshVelocity}
     * @param mask                     optional mask parameter (may be {@code null})
     * @param categoryClassesParameter optional category-class parameter (may be {@code null})
     * @param massPreconditioner       mass preconditioner used by {@link #updatePreconditioning}
     * @param preconditionerType       factory type for the preconditioning scheduler
     * @throws RuntimeException         if an unmasked dimension starts exactly at zero
     * @throws IllegalArgumentException if the starting value is outside the parameter bounds
     */
    AbstractParticleOperator(GradientWrtParameterProvider gradientProvider,
                             PrecisionMatrixVectorProductProvider productProvider,
                             PrecisionColumnProvider columnProvider,
                             double weight,
                             Options runtimeOptions,
                             NativeCodeOptions nativeCodeOptions,
                             boolean refreshVelocity,
                             Parameter mask,
                             Parameter categoryClassesParameter,
                             MassPreconditioner massPreconditioner,
                             MassPreconditionScheduler.Type preconditionerType) {
        this.gradientProvider = gradientProvider;
        this.productProvider = productProvider;
        this.columnProvider = columnProvider;
        this.parameter = gradientProvider.getParameter();
        this.mask = mask;
        this.maskVector = mask != null ? mask.getParameterValues() : null;
        // Reads this.mask, so it must come after the assignment above.
        this.parameterSign = setParameterSign(gradientProvider);
        this.runtimeOptions = runtimeOptions;
        this.nativeCodeOptions = nativeCodeOptions;
        this.refreshVelocity = refreshVelocity;
        this.preconditioning = setupPreconditioning();
        this.meanVector = getMeanVector(gradientProvider);
        this.massPreconditioning = massPreconditioner;
        this.preconditionScheduler = preconditionerType.factory(runtimeOptions, this);
        setWeight(weight);
        this.observedDataMask = getObservedDataMask();
        this.categoryClasses = getCategoryClasses(categoryClassesParameter);
        checkParameterBounds(this.parameter);

        // Positional arguments to NativeZigZagOptions; their meaning is defined by that class.
        // The seed is drawn unconditionally (even when native code is disabled): moving the
        // draw inside the branch would shift every subsequent random number in the run.
        final long nativeFlags = 128L;
        final long nativeSeed = MathUtils.nextLong();
        final int nativeInfo = 4;
        if (nativeCodeOptions.testNativeFindNextBounce
                || nativeCodeOptions.useNativeFindNextBounce
                || nativeCodeOptions.useNativeUpdateDynamics) {
            NativeZigZagOptions nativeOptions = new NativeZigZagOptions(nativeFlags, nativeSeed, nativeInfo);
            double[] lowerBounds = getLowerBoundVector();
            double[] upperBounds = getUpperBoundVector();
            // Reuse the mask computed above rather than rebuilding it.
            this.nativeZigZag = new NativeZigZagWrapper(this.parameter.getDimension(), nativeOptions,
                    this.maskVector, this.observedDataMask, this.parameterSign, lowerBounds, upperBounds);
        }
    }

    /**
     * Records the sign (+1 / -1) of each starting coordinate.
     *
     * @throws RuntimeException if an unmasked (or, without a mask, any) coordinate is exactly 0
     */
    private double[] setParameterSign(GradientWrtParameterProvider gradientWrtParameterProvider) {
        double[] values = gradientWrtParameterProvider.getParameter().getParameterValues();
        double[] signs = new double[values.length];
        for (int i = 0; i < values.length; ++i) {
            if (values[i] == 0.0 && (this.mask == null || this.mask.getParameterValue(i) == 1.0)) {
                throw new RuntimeException("Must start from either positive or negative value!");
            }
            signs[i] = values[i] > 0.0 ? 1.0 : -1.0;
        }
        return signs;
    }

    /** Returns 0 for dimensions with both bounds infinite, 1 for every other dimension. */
    private double[] getObservedDataMask() {
        int dim = this.parameter.getDimension();
        double[] result = new double[dim];
        assert (dim == this.parameter.getBounds().getBoundsDimension());
        for (int i = 0; i < dim; ++i) {
            boolean unbounded = this.parameter.getBounds().getUpperLimit(i) == Double.POSITIVE_INFINITY
                    && this.parameter.getBounds().getLowerLimit(i) == Double.NEGATIVE_INFINITY;
            result[i] = unbounded ? 0.0 : 1.0;
        }
        return result;
    }

    /** Per-dimension lower limits of {@link #parameter}'s bounds. */
    private double[] getLowerBoundVector() {
        int dim = this.parameter.getDimension();
        double[] lower = new double[dim];
        for (int i = 0; i < dim; ++i) {
            lower[i] = this.parameter.getBounds().getLowerLimit(i);
        }
        return lower;
    }

    /** Per-dimension upper limits of {@link #parameter}'s bounds. */
    private double[] getUpperBoundVector() {
        int dim = this.parameter.getDimension();
        double[] upper = new double[dim];
        for (int i = 0; i < dim; ++i) {
            upper[i] = this.parameter.getBounds().getUpperLimit(i);
        }
        return upper;
    }

    /**
     * Tiles the (integer-truncated) values of {@code categoryParameter} across the sampled
     * parameter's dimension. Entries beyond the last full tile remain 0.
     *
     * @param categoryParameter category-class parameter, or {@code null} for all zeros
     */
    private int[] getCategoryClasses(Parameter categoryParameter) {
        int[] classes = new int[this.parameter.getDimension()];
        if (categoryParameter != null) {
            int blockSize = categoryParameter.getDimension();
            // Fetch the values once; getParameterValues() may allocate a fresh copy per call.
            double[] values = categoryParameter.getParameterValues();
            int[] block = new int[blockSize];
            for (int j = 0; j < blockSize; ++j) {
                block[j] = (int) values[j];
            }
            int repeats = classes.length / blockSize;
            for (int r = 0; r < repeats; ++r) {
                System.arraycopy(block, 0, classes, r * blockSize, blockSize);
            }
        }
        return classes;
    }

    /**
     * Runs one trajectory from the current parameter value, optionally updating the
     * preconditioner first, and writes the final position back into {@link #parameter}.
     *
     * @return the value returned by {@link #integrateTrajectory}
     */
    @Override
    public double doOperation() {
        WrappedVector position = getInitialPosition();
        WrappedVector momentum = drawInitialMomentum();
        if (this.preconditionScheduler.shouldUpdatePreconditioning()) {
            updatePreconditioning(position);
        }
        double result = integrateTrajectory(position, momentum);
        ReadableVector.Utils.setParameter((ReadableVector) position, this.parameter);
        // Diagnostic hook, disabled by default (previously a dead `false & ...` expression).
        if (CHECK_MATRIX_ILL_CONDITIONED && this.getCount() % 100L == 0L) {
            this.productProvider.getTimeScaleEigen();
        }
        return result;
    }

    abstract double integrateTrajectory(WrappedVector var1, WrappedVector var2);

    /** Default momentum: an empty placeholder vector; subclasses may override. */
    WrappedVector drawInitialMomentum() {
        return new WrappedVector.Raw(null, 0, 0);
    }

    /** Total travel time jittered uniformly by +/- {@code randomTimeWidth / 2} (relative). */
    double drawTotalTravelTime() {
        double jitter = 1.0 + this.runtimeOptions.randomTimeWidth * (MathUtils.nextDouble() - 0.5);
        return this.preconditioning.totalTravelTime * jitter;
    }

    /** In place: {@code gradient -= time * action}. */
    static void updateGradient(WrappedVector gradient, double time, WrappedVector action) {
        double[] g = gradient.getBuffer();
        double[] a = action.getBuffer();
        for (int i = 0; i < g.length; ++i) {
            g[i] -= time * a[i];
        }
    }

    /** In place: {@code position += time * velocity}. */
    static void updatePosition(WrappedVector position, WrappedVector velocity, double time) {
        updatePosition(position.getBuffer(), velocity.getBuffer(), time);
    }

    /** In place: {@code position += time * velocity}, iterating over {@code position.length}. */
    static void updatePosition(double[] position, double[] velocity, double time) {
        for (int i = 0; i < position.length; ++i) {
            position[i] += time * velocity[i];
        }
    }

    /** In place: {@code momentum = momentum + t * gradient - (t^2 / 2) * action}. */
    static void updateMomentum(double[] action, double[] gradient, double[] momentum, double time) {
        double halfTimeSquared = time * time / 2.0;
        for (int i = 0; i < momentum.length; ++i) {
            // Keep (m + t*g) - c*a evaluation order to match the original floating-point results.
            momentum[i] = momentum[i] + time * gradient[i] - halfTimeSquared * action[i];
        }
    }

    /** Current log-density gradient, masked when a mask is present. */
    WrappedVector getInitialGradient() {
        double[] gradient = this.gradientProvider.getGradientLogDensity();
        if (this.mask != null) {
            applyMask(gradient);
        }
        return new WrappedVector.Raw(gradient);
    }

    void applyMask(WrappedVector vector) {
        applyMask(vector.getBuffer());
    }

    /** Multiplies {@code values} element-wise by {@link #maskVector} in place (timed). */
    void applyMask(double[] values) {
        this.timer.startTimer("applyMask");
        assert (values.length == this.mask.getDimension());
        for (int i = 0; i < values.length; ++i) {
            values[i] *= this.maskVector[i];
        }
        this.timer.stopTimer("applyMask");
    }

    /**
     * Writes {@code vector + meanVector} into {@link #parameter} (side effect) and returns the
     * provider's precision product for that parameter, masked when a mask is present.
     */
    WrappedVector getPrecisionProduct(ReadableVector vector) {
        WrappedVector.Raw shifted = new WrappedVector.Raw(new double[vector.getDim()]);
        for (int i = 0; i < shifted.getDim(); ++i) {
            shifted.set(i, vector.get(i) + this.meanVector[i]);
        }
        ReadableVector.Utils.setParameter((ReadableVector) shifted, this.parameter);
        double[] product = this.productProvider.getProduct(this.parameter);
        if (this.mask != null) {
            applyMask(product);
        }
        return new WrappedVector.Raw(product);
    }

    /** Column {@code index} of the precision matrix, masked when a mask is present. */
    WrappedVector getPrecisionColumn(int index) {
        this.timer.startTimer("getColumn");
        double[] column = this.columnProvider.getColumn(index);
        this.timer.stopTimer("getColumn");
        if (this.mask != null) {
            applyMask(column);
        }
        return new WrappedVector.Raw(column);
    }

    /** In place: {@code action += 2 * velocity[index] * column(index)}, then re-mask. */
    void updateAction(WrappedVector action, ReadableVector velocity, int index) {
        WrappedVector column = getPrecisionColumn(index);
        this.timer.startTimer("updateAction");
        double[] a = action.getBuffer();
        double[] c = column.getBuffer();
        double scale = 2.0 * velocity.get(index);
        for (int i = 0; i < a.length; ++i) {
            a[i] += scale * c[i];
        }
        this.timer.stopTimer("updateAction");
        if (this.mask != null) {
            applyMask(a);
        }
    }

    /** True when a bounded dimension {@code index} is moving with speed {@code v} towards zero. */
    boolean headingTowardsBinaryBoundary(double v, int index) {
        return this.observedDataMask[index] * this.parameterSign[index] * v < 0.0;
    }

    private WrappedVector getInitialPosition() {
        return new WrappedVector.Raw(this.parameter.getParameterValues());
    }

    /** @throws IllegalArgumentException if any coordinate lies outside its bounds */
    private void checkParameterBounds(Parameter p) {
        int dim = p.getDimension();
        for (int i = 0; i < dim; ++i) {
            double value = p.getParameterValue(i);
            if (value < p.getBounds().getLowerLimit(i) || value > p.getBounds().getUpperLimit(i)) {
                throw new IllegalArgumentException("Parameter '" + p.getId() + "' is out-of-bounds");
            }
        }
    }

    /** Unit mass vector with total travel time taken from the product provider. */
    private Preconditioning setupPreconditioning() {
        double[] unitMass = new double[this.parameter.getDimension()];
        Arrays.fill(unitMass, 1.0);
        // Return value is discarded; call kept in case the provider relies on it (TODO confirm).
        this.productProvider.getMassVector();
        double travelTime = this.productProvider.getTimeScale();
        return new Preconditioning(new WrappedVector.Raw(unitMass), travelTime);
    }

    /** Feeds {@code position} to the mass preconditioner and adopts its updated mass. */
    void updatePreconditioning(WrappedVector position) {
        this.massPreconditioning.updateVariance(position);
        this.massPreconditioning.updateMass();
        this.preconditioning.mass = this.massPreconditioning.getMass();
    }

    void initializeNumEvent() {
        this.numEvents = 0;
        this.numBoundaryEvents = 0;
        this.numGradientEvents = 0;
    }

    void recordOneMoreEvent() {
        ++this.numEvents;
    }

    /** Counts an event; both boundary kinds increment {@link #numBoundaryEvents}. */
    void recordEvents(Type type) {
        ++this.numEvents;
        if (type == Type.BINARY_BOUNDARY || type == Type.CATE_BOUNDARY) {
            ++this.numBoundaryEvents;
        } else if (type == Type.GRADIENT) {
            ++this.numGradientEvents;
        }
    }

    /** Keeps a reference (not a copy) to {@code velocity}. */
    void storeVelocity(WrappedVector velocity) {
        this.storedVelocity = velocity;
    }

    /**
     * Zeros, unless the likelihood is a {@link TreeDataLikelihood}: then the root-prior mean
     * (first {@code traitDim} entries) is tiled across the parameter dimension.
     * Assumes that delegate is a {@link ContinuousDataLikelihoodDelegate} (cast is unchecked).
     */
    double[] getMeanVector(GradientWrtParameterProvider gradientWrtParameterProvider) {
        double[] mean = new double[this.parameter.getDimension()];
        if (gradientWrtParameterProvider.getLikelihood() instanceof TreeDataLikelihood) {
            TreeDataLikelihood treeLikelihood = (TreeDataLikelihood) gradientWrtParameterProvider.getLikelihood();
            ContinuousDataLikelihoodDelegate delegate =
                    (ContinuousDataLikelihoodDelegate) treeLikelihood.getDataLikelihoodDelegate();
            double[] rootMean = delegate.getRootPrior().getMean();
            int traitDim = delegate.getTraitDim();
            int taxa = this.parameter.getDimension() / traitDim;
            for (int t = 0; t < taxa; ++t) {
                System.arraycopy(rootMean, 0, mean, t * traitDim, traitDim);
            }
        }
        return mean;
    }

    @Override
    public String getReport() {
        return this.timer.toString();
    }

    /** Runtime and preconditioning settings. */
    public static class Options implements MassPreconditioningOptions {
        final double randomTimeWidth;
        final int preconditioningUpdateFrequency;
        final int preconditioningMaxUpdate;
        final int preconditioningDelay;
        final int updateSampleCovFrequency;
        final int updateSampleCovDelay;

        public Options(double randomTimeWidth, int preconditioningUpdateFrequency,
                       int preconditioningMaxUpdate, int preconditioningDelay,
                       int updateSampleCovFrequency, int updateSampleCovDelay) {
            this.randomTimeWidth = randomTimeWidth;
            this.preconditioningUpdateFrequency = preconditioningUpdateFrequency;
            this.preconditioningMaxUpdate = preconditioningMaxUpdate;
            this.preconditioningDelay = preconditioningDelay;
            this.updateSampleCovFrequency = updateSampleCovFrequency;
            this.updateSampleCovDelay = updateSampleCovDelay;
        }

        @Override
        public int preconditioningUpdateFrequency() {
            return this.preconditioningUpdateFrequency;
        }

        @Override
        public int preconditioningDelay() {
            return this.preconditioningDelay;
        }

        @Override
        public int preconditioningMaxUpdate() {
            return this.preconditioningMaxUpdate;
        }

        @Override
        public int preconditioningMemory() {
            return 0;
        }

        /** @throws UnsupportedOperationException always */
        @Override
        public Parameter preconditioningEigenLowerBound() {
            throw new UnsupportedOperationException("Not yet implemented.");
        }

        /** @throws UnsupportedOperationException always */
        @Override
        public Parameter preconditioningEigenUpperBound() {
            throw new UnsupportedOperationException("Not yet implemented.");
        }
    }

    /** Switches controlling use of the native zig-zag back-end. */
    public static class NativeCodeOptions {
        final boolean testNativeFindNextBounce;
        final boolean useNativeFindNextBounce;
        final boolean useNativeUpdateDynamics;

        public NativeCodeOptions(boolean testNativeFindNextBounce, boolean useNativeFindNextBounce,
                                 boolean useNativeUpdateDynamics) {
            this.testNativeFindNextBounce = testNativeFindNextBounce;
            this.useNativeFindNextBounce = useNativeFindNextBounce;
            this.useNativeUpdateDynamics = useNativeUpdateDynamics;
        }
    }

    /** Mutable mass vector and nominal total travel time. */
    protected class Preconditioning {
        WrappedVector mass;
        double totalTravelTime;

        private Preconditioning(WrappedVector mass, double totalTravelTime) {
            this.mass = mass;
            this.totalTravelTime = totalTravelTime;
        }
    }

    /** Kinds of trajectory events. */
    enum Type {
        NONE,
        BINARY_BOUNDARY,
        CATE_BOUNDARY,
        GRADIENT,
        REFRESHMENT;

        /**
         * Maps an integer event code to a {@link Type}, following declaration order.
         * Previously code 2 duplicated the BINARY_BOUNDARY branch, making CATE_BOUNDARY unreachable.
         *
         * @throws RuntimeException for codes other than 0..3
         */
        public static Type castFromInt(int n) {
            switch (n) {
                case 0:
                    return NONE;
                case 1:
                    return BINARY_BOUNDARY;
                case 2:
                    return CATE_BOUNDARY;
                case 3:
                    return GRADIENT;
                default:
                    throw new RuntimeException("Unknown type");
            }
        }
    }

    /** Immutable record of the last event type, its dimension and the remaining time. */
    class BounceState {
        final Type type;
        final int index;
        final double remainingTime;

        BounceState(Type type, int index, double remainingTime) {
            this.type = type;
            this.index = index;
            this.remainingTime = remainingTime;
        }

        /** No event: type {@link Type#NONE}, index -1. */
        BounceState(double remainingTime) {
            this(Type.NONE, -1, remainingTime);
        }

        boolean isTimeRemaining() {
            return this.remainingTime > 0.0;
        }

        @Override
        public String toString() {
            return "remainingTime : " + this.remainingTime + " lastBounceType: " + this.type + " in dim: " + this.index;
        }
    }
}

