/*
 * Decompiled with CFR 0.152.
 */
package dr.inference.operators.hmc;

import dr.evomodel.operators.NativeZigZag;
import dr.evomodel.operators.NativeZigZagWrapper;
import dr.evomodel.treedatalikelihood.TreeDataLikelihood;
import dr.evomodel.treedatalikelihood.continuous.ContinuousDataLikelihoodDelegate;
import dr.inference.hmc.GradientWrtParameterProvider;
import dr.inference.hmc.PrecisionColumnProvider;
import dr.inference.hmc.PrecisionMatrixVectorProductProvider;
import dr.inference.model.Parameter;
import dr.inference.operators.GibbsOperator;
import dr.inference.operators.SimpleMCMCOperator;
import dr.math.MathUtils;
import dr.math.matrixAlgebra.ReadableVector;
import dr.math.matrixAlgebra.WrappedVector;
import dr.util.BenchmarkTimer;
import dr.xml.Reportable;
import java.util.Arrays;

/**
 * Base class for operators that update a {@link Parameter} by moving a "particle" along a
 * trajectory.
 *
 * <p>Each call to {@link #doOperation()} reads the current parameter values as the initial
 * position, draws an initial momentum, delegates the actual integration to
 * {@link #integrateTrajectory(WrappedVector, WrappedVector)} and finally writes the resulting
 * position back into the parameter. The remaining methods are helpers (mask handling, precision
 * products/columns, vector updates) shared with subclasses in this package.
 */
public abstract class AbstractParticleOperator
extends SimpleMCMCOperator
implements GibbsOperator,
Reportable {
    /** When {@code true}, {@link #doOperation()} periodically asks the product provider for its eigen-based time scale. */
    private static final boolean CHECK_MATRIX_ILL_CONDITIONED = false;
    /** Source of the log-density gradient; also supplies the sampled {@link #parameter}. */
    protected final GradientWrtParameterProvider gradientProvider;
    /** Supplies precision-matrix/vector products, the time scale and the mass vector. */
    private final PrecisionMatrixVectorProductProvider productProvider;
    /** Supplies single columns of the precision matrix. */
    private final PrecisionColumnProvider columnProvider;
    /** The parameter this operator updates (obtained from {@link #gradientProvider}). */
    protected final Parameter parameter;
    private final Options runtimeOptions;
    /** Optional element-wise mask; {@code null} means "no masking". */
    final Parameter mask;
    /** Snapshot of the mask values taken at construction time; {@code null} when {@link #mask} is {@code null}. */
    private final double[] maskVector;
    /** Event counter maintained through {@link #initializeNumEvent()} and {@link #recordOneMoreEvent()}. */
    int numEvents;
    Preconditioning preconditioning;
    /** {@code true} at index {@code i} when dimension {@code i} of the parameter has no finite bound on either side. */
    private final boolean[] missingDataMask;
    /** Per-dimension offset added to the input of {@link #getPrecisionProduct(ReadableVector)}. */
    private final double[] meanVector;
    static final boolean TIMING = true;
    BenchmarkTimer timer = new BenchmarkTimer();
    private static final boolean TEST_NATIVE_OPERATOR = false;
    static final boolean TEST_NATIVE_BOUNCE = false;
    static final boolean TEST_NATIVE_INNER_BOUNCE = false;
    NativeZigZagWrapper nativeZigZag;

    /**
     * Creates the operator.
     *
     * @param gradientWrtParameterProvider gradient source; its parameter becomes the sampled parameter
     * @param precisionMatrixVectorProductProvider provider of precision products, time scale and mass vector
     * @param precisionColumnProvider provider of precision-matrix columns
     * @param d operator weight, forwarded to {@code setWeight}
     * @param options run-time options (travel-time jitter, preconditioning update frequency)
     * @param parameter the mask (NOT the sampled parameter); may be {@code null} for no masking
     * @throws IllegalArgumentException if the sampled parameter currently violates its bounds
     */
    AbstractParticleOperator(GradientWrtParameterProvider gradientWrtParameterProvider, PrecisionMatrixVectorProductProvider precisionMatrixVectorProductProvider, PrecisionColumnProvider precisionColumnProvider, double d, Options options, Parameter parameter) {
        this.gradientProvider = gradientWrtParameterProvider;
        this.productProvider = precisionMatrixVectorProductProvider;
        this.columnProvider = precisionColumnProvider;
        this.parameter = gradientWrtParameterProvider.getParameter();
        // The constructor argument `parameter` shadows the field of the same name: it is the mask.
        this.mask = parameter;
        this.maskVector = parameter != null ? parameter.getParameterValues() : null;
        this.runtimeOptions = options;
        this.preconditioning = this.setupPreconditioning();
        // NOTE(review): getMeanVector is overridable and is invoked here before subclass
        // constructors have run.
        this.meanVector = this.getMeanVector(gradientWrtParameterProvider);
        this.setWeight(d);
        this.missingDataMask = this.getMissingDataMask();
        this.checkParameterBounds(this.parameter);
        // The drawn value is not used, but the call is retained so that constructing an operator
        // keeps consuming exactly one long from the shared random-number stream.
        MathUtils.nextLong();
    }

    /**
     * Tells whether dimension {@code index} of the sampled parameter is unbounded on both sides,
     * i.e. its upper limit is {@code +Infinity} and its lower limit is {@code -Infinity}.
     */
    private boolean isUnbounded(int index) {
        return this.parameter.getBounds().getUpperLimit(index) == Double.POSITIVE_INFINITY
                && this.parameter.getBounds().getLowerLimit(index) == Double.NEGATIVE_INFINITY;
    }

    /**
     * Builds the boolean mask that is {@code true} for every dimension without finite bounds.
     *
     * @return a new array with one entry per parameter dimension
     */
    private boolean[] getMissingDataMask() {
        int dim = this.parameter.getDimension();
        boolean[] unbounded = new boolean[dim];
        assert (dim == this.parameter.getBounds().getBoundsDimension());
        for (int i = 0; i < dim; ++i) {
            unbounded[i] = this.isUnbounded(i);
        }
        return unbounded;
    }

    /**
     * Builds the numeric complement of {@link #getMissingDataMask()}: {@code 0.0} for dimensions
     * without finite bounds and {@code 1.0} otherwise. Currently not called from within this class.
     *
     * @return a new array with one entry per parameter dimension
     */
    private double[] getObservedDataMask() {
        int dim = this.parameter.getDimension();
        double[] observed = new double[dim];
        assert (dim == this.parameter.getBounds().getBoundsDimension());
        for (int i = 0; i < dim; ++i) {
            observed[i] = this.isUnbounded(i) ? 0.0 : 1.0;
        }
        return observed;
    }

    /**
     * Performs one update: optionally refreshes the preconditioning, integrates a trajectory
     * starting at the current parameter values and stores the end position in the parameter.
     *
     * @return the value returned by {@link #integrateTrajectory(WrappedVector, WrappedVector)}
     */
    @Override
    public double doOperation() {
        if (this.shouldUpdatePreconditioning()) {
            this.preconditioning = this.setupPreconditioning();
        }
        WrappedVector position = this.getInitialPosition();
        WrappedVector momentum = this.drawInitialMomentum();
        double result = this.integrateTrajectory(position, momentum);
        // The position buffer is a copy of the parameter values, so it has to be written back.
        ReadableVector.Utils.setParameter((ReadableVector)position, this.parameter);
        // Diagnostic hook, disabled unless CHECK_MATRIX_ILL_CONDITIONED is switched on; the
        // short-circuit && keeps the count from being evaluated while the flag is off.
        if (CHECK_MATRIX_ILL_CONDITIONED && this.getCount() % 100L == 0L) {
            this.productProvider.getTimeScaleEigen();
        }
        return result;
    }

    /**
     * Integrates the particle trajectory.
     *
     * @param var1 the position; {@link #doOperation()} copies its final contents into the parameter
     * @param var2 the initial momentum produced by {@link #drawInitialMomentum()}
     * @return the value that {@link #doOperation()} returns to its caller
     */
    abstract double integrateTrajectory(WrappedVector var1, WrappedVector var2);

    /**
     * Default momentum: a {@link WrappedVector.Raw} constructed with a {@code null} buffer and the
     * arguments {@code (0, 0)}. Presumably concrete operators override this with a real draw.
     */
    WrappedVector drawInitialMomentum() {
        return new WrappedVector.Raw(null, 0, 0);
    }

    /**
     * Draws the travel time for one trajectory: the preconditioned total travel time multiplied
     * by {@code 1 + randomTimeWidth * (u - 0.5)}, where {@code u} is {@code MathUtils.nextDouble()}.
     */
    double drawTotalTravelTime() {
        double jitter = 1.0 + this.runtimeOptions.randomTimeWidth * (MathUtils.nextDouble() - 0.5);
        return this.preconditioning.totalTravelTime * jitter;
    }

    /**
     * In-place update {@code wrappedVector[i] -= d * wrappedVector2[i]} over the whole buffer of
     * the first vector.
     *
     * @param wrappedVector vector that is modified (by its name, the gradient)
     * @param d scalar multiplier
     * @param wrappedVector2 vector that is scaled and subtracted; its buffer must be at least as long
     */
    static void updateGradient(WrappedVector wrappedVector, double d, WrappedVector wrappedVector2) {
        double[] target = wrappedVector.getBuffer();
        double[] delta = wrappedVector2.getBuffer();
        for (int i = 0; i < target.length; ++i) {
            target[i] -= d * delta[i];
        }
    }

    /**
     * In-place update {@code wrappedVector[i] += d * wrappedVector2[i]} on the underlying buffers.
     *
     * @param wrappedVector vector that is modified (by its name, the position)
     * @param wrappedVector2 vector that is scaled and added
     * @param d scalar multiplier
     */
    static void updatePosition(WrappedVector wrappedVector, WrappedVector wrappedVector2, double d) {
        updatePosition(wrappedVector.getBuffer(), wrappedVector2.getBuffer(), d);
    }

    /**
     * In-place update {@code dArray[i] += d * dArray2[i]} for every index of {@code dArray}.
     *
     * @param dArray array that is modified
     * @param dArray2 array that is scaled and added; must be at least as long as {@code dArray}
     * @param d scalar multiplier
     */
    static void updatePosition(double[] dArray, double[] dArray2, double d) {
        for (int i = 0; i < dArray.length; ++i) {
            dArray[i] += d * dArray2[i];
        }
    }

    /**
     * In-place update {@code dArray3[i] = dArray3[i] + d * dArray2[i] - (d * d / 2) * dArray[i]}.
     *
     * @param dArray coefficients of the second-order term
     * @param dArray2 coefficients of the first-order term
     * @param dArray3 array that is modified (by the method name, the momentum)
     * @param d scalar step
     */
    static void updateMomentum(double[] dArray, double[] dArray2, double[] dArray3, double d) {
        double halfSquared = d * d / 2.0;
        for (int i = 0; i < dArray3.length; ++i) {
            // Evaluation order is kept exactly as written so floating-point results do not change.
            dArray3[i] = dArray3[i] + d * dArray2[i] - halfSquared * dArray[i];
        }
    }

    /**
     * Fetches the log-density gradient from the gradient provider, applies the mask (if any)
     * and wraps the resulting array.
     */
    WrappedVector getInitialGradient() {
        double[] gradient = this.gradientProvider.getGradientLogDensity();
        this.applyMaskIfPresent(gradient);
        return new WrappedVector.Raw(gradient);
    }

    /** Applies the mask in place to the buffer backing {@code wrappedVector}. */
    void applyMask(WrappedVector wrappedVector) {
        this.applyMask(wrappedVector.getBuffer());
    }

    /**
     * Multiplies every entry of {@code dArray} by the corresponding mask value, in place, and
     * records the time spent under the timer key {@code "applyMask"}.
     *
     * <p>Requires a non-null mask: the cached mask values are dereferenced unconditionally.
     *
     * @param dArray array to mask; expected to have the mask's dimension
     */
    void applyMask(double[] dArray) {
        this.timer.startTimer("applyMask");
        assert (dArray.length == this.mask.getDimension());
        for (int i = 0; i < dArray.length; ++i) {
            dArray[i] *= this.maskVector[i];
        }
        this.timer.stopTimer("applyMask");
    }

    /** Applies the mask to {@code values} when a mask was supplied; otherwise does nothing. */
    private void applyMaskIfPresent(double[] values) {
        if (this.mask != null) {
            this.applyMask(values);
        }
    }

    /**
     * Computes the precision product for {@code readableVector} shifted by the mean vector.
     *
     * <p>Side effect: the sampled parameter is overwritten with
     * {@code readableVector + meanVector} before the product provider is queried.
     *
     * @param readableVector input vector with one entry per parameter dimension
     * @return the (masked, if a mask is present) product wrapped as a vector
     */
    WrappedVector getPrecisionProduct(ReadableVector readableVector) {
        WrappedVector.Raw shifted = new WrappedVector.Raw(new double[readableVector.getDim()]);
        int dim = shifted.getDim();
        for (int i = 0; i < dim; ++i) {
            shifted.set(i, readableVector.get(i) + this.meanVector[i]);
        }
        ReadableVector.Utils.setParameter((ReadableVector)shifted, this.parameter);
        double[] product = this.productProvider.getProduct(this.parameter);
        this.applyMaskIfPresent(product);
        return new WrappedVector.Raw(product);
    }

    /**
     * Retrieves column {@code n} from the column provider (timed under {@code "getColumn"}),
     * applies the mask if present and wraps the result.
     *
     * @param n column index
     */
    WrappedVector getPrecisionColumn(int n) {
        this.timer.startTimer("getColumn");
        double[] column = this.columnProvider.getColumn(n);
        this.timer.stopTimer("getColumn");
        this.applyMaskIfPresent(column);
        return new WrappedVector.Raw(column);
    }

    /**
     * In-place update {@code wrappedVector += 2 * readableVector[n] * column(n)}, where
     * {@code column(n)} comes from {@link #getPrecisionColumn(int)}; the updated buffer is then
     * masked if a mask is present. The accumulation loop is timed under {@code "updateAction"}.
     *
     * @param wrappedVector vector that is modified (by the method name, the action)
     * @param readableVector vector whose {@code n}-th entry scales the column
     * @param n index of the column / entry
     */
    void updateAction(WrappedVector wrappedVector, ReadableVector readableVector, int n) {
        WrappedVector column = this.getPrecisionColumn(n);
        this.timer.startTimer("updateAction");
        double[] action = wrappedVector.getBuffer();
        double[] columnValues = column.getBuffer();
        double scale = 2.0 * readableVector.get(n);
        for (int i = 0; i < action.length; ++i) {
            action[i] += scale * columnValues[i];
        }
        this.timer.stopTimer("updateAction");
        this.applyMaskIfPresent(action);
    }

    /**
     * Tells whether the particle is heading towards a boundary in dimension {@code n}: never for
     * dimensions without finite bounds, otherwise exactly when {@code d} and {@code d2} have
     * strictly opposite signs.
     */
    boolean headingTowardsBoundary(double d, double d2, int n) {
        return !this.missingDataMask[n] && d * d2 < 0.0;
    }

    /** Wraps the array returned by the parameter's {@code getParameterValues()} as the starting position. */
    private WrappedVector getInitialPosition() {
        double[] values = this.parameter.getParameterValues();
        return new WrappedVector.Raw(values);
    }

    /**
     * Verifies that every value of {@code parameter} lies within its lower and upper limits.
     *
     * @throws IllegalArgumentException if any value is below its lower or above its upper limit
     */
    private void checkParameterBounds(Parameter parameter) {
        int dim = parameter.getDimension();
        for (int i = 0; i < dim; ++i) {
            double value = parameter.getParameterValue(i);
            if (value < parameter.getBounds().getLowerLimit(i) || value > parameter.getBounds().getUpperLimit(i)) {
                throw new IllegalArgumentException("Parameter '" + parameter.getId() + "' is out-of-bounds");
            }
        }
    }

    /**
     * Creates a fresh {@link Preconditioning} with a unit mass vector and the total travel time
     * reported by the product provider's {@code getTimeScale()}.
     */
    private Preconditioning setupPreconditioning() {
        double[] unitMass = new double[this.parameter.getDimension()];
        Arrays.fill(unitMass, 1.0);
        // NOTE(review): the returned mass vector is discarded; the call is kept in case the
        // provider relies on being asked — confirm whether it can be dropped.
        this.productProvider.getMassVector();
        double timeScale = this.productProvider.getTimeScale();
        return new Preconditioning(new WrappedVector.Raw(unitMass), timeScale);
    }

    /**
     * Tells whether the preconditioning is due for a refresh: the configured frequency must be
     * positive and the current operator count a multiple of it.
     */
    private boolean shouldUpdatePreconditioning() {
        int frequency = this.runtimeOptions.preconditioningUpdateFrequency;
        return frequency > 0 && this.getCount() % (long)frequency == 0L;
    }

    /** Resets the event counter to zero. */
    void initializeNumEvent() {
        this.numEvents = 0;
    }

    /** Increments the event counter by one. */
    void recordOneMoreEvent() {
        ++this.numEvents;
    }

    /**
     * Builds the per-dimension mean vector.
     *
     * <p>When the provider's likelihood is a {@link TreeDataLikelihood}, the root-prior mean of
     * its {@link ContinuousDataLikelihoodDelegate} is tiled across the parameter in consecutive
     * chunks of the trait dimension; otherwise the result is all zeros.
     *
     * @param gradientWrtParameterProvider provider whose likelihood is inspected
     * @return a new array with one entry per parameter dimension
     * @throws ClassCastException if the tree likelihood's delegate is not a
     *         {@link ContinuousDataLikelihoodDelegate}
     */
    double[] getMeanVector(GradientWrtParameterProvider gradientWrtParameterProvider) {
        int dim = this.parameter.getDimension();
        double[] mean = new double[dim];
        if (!(gradientWrtParameterProvider.getLikelihood() instanceof TreeDataLikelihood)) {
            return mean;
        }
        TreeDataLikelihood treeLikelihood = (TreeDataLikelihood)gradientWrtParameterProvider.getLikelihood();
        ContinuousDataLikelihoodDelegate delegate = (ContinuousDataLikelihoodDelegate)treeLikelihood.getDataLikelihoodDelegate();
        double[] rootMean = delegate.getRootPrior().getMean();
        int traitDim = delegate.getTraitDim();
        int chunkCount = dim / traitDim;
        for (int chunk = 0; chunk < chunkCount; ++chunk) {
            int offset = chunk * traitDim;
            for (int j = 0; j < traitDim; ++j) {
                mean[offset + j] = rootMean[j];
            }
        }
        return mean;
    }

    /** Returns the string form of the benchmark timer. */
    @Override
    public String getReport() {
        return this.timer.toString();
    }

    /** Kind of event stored in a {@link BounceState}. */
    enum Type {
        NONE,
        BOUNDARY,
        GRADIENT,
        REFRESHMENT;

        /**
         * Decodes an integer code: {@code 0 -> NONE}, {@code 1 -> BOUNDARY}, {@code 2 -> GRADIENT}.
         *
         * <p>NOTE(review): {@code REFRESHMENT} has no integer code here — confirm that this is
         * intentional for whoever produces these codes.
         *
         * @param n integer code
         * @return the matching constant
         * @throws RuntimeException if {@code n} is not one of the codes listed above
         */
        public static Type castFromInt(int n) {
            switch (n) {
                case 0:
                    return NONE;
                case 1:
                    return BOUNDARY;
                case 2:
                    return GRADIENT;
                default:
                    throw new RuntimeException("Unknown type: " + n);
            }
        }
    }

    /** Immutable record of the most recent event type, its dimension index and the time left to travel. */
    class BounceState {
        final Type type;
        final int index;
        final double remainingTime;

        /**
         * @param type kind of event
         * @param n index of the dimension the event refers to
         * @param d remaining travel time
         */
        BounceState(Type type, int n, double d) {
            this.type = type;
            this.index = n;
            this.remainingTime = d;
        }

        /** Creates a state with type {@link Type#NONE}, index {@code -1} and the given remaining time. */
        BounceState(double d) {
            this(Type.NONE, -1, d);
        }

        /** Returns {@code true} while the remaining time is strictly positive. */
        boolean isTimeRemaining() {
            return this.remainingTime > 0.0;
        }

        @Override
        public String toString() {
            return "remainingTime : " + this.remainingTime + " lastBounceType: " + this.type + " in dim: " + this.index;
        }
    }

    /** Mass vector and total travel time used while integrating a trajectory. */
    protected class Preconditioning {
        final WrappedVector mass;
        double totalTravelTime;

        private Preconditioning(WrappedVector wrappedVector, double d) {
            this.mass = wrappedVector;
            this.totalTravelTime = d;
        }
    }

    /** Immutable run-time options of the operator. */
    public static class Options {
        /** Width of the multiplicative jitter applied in {@code drawTotalTravelTime()}. */
        final double randomTimeWidth;
        /** Preconditioning is refreshed every this many operations; values {@code <= 0} disable the refresh. */
        final int preconditioningUpdateFrequency;

        /**
         * @param d the random time width
         * @param n the preconditioning update frequency
         */
        public Options(double d, int n) {
            this.randomTimeWidth = d;
            this.preconditioningUpdateFrequency = n;
        }
    }
}

