/*
 * Decompiled with CFR 0.152.
 */
package dr.evomodel.continuous.hmc;

import dr.evolution.tree.Tree;
import dr.evolution.tree.TreeTrait;
import dr.evomodel.continuous.hmc.LinearOrderTreePrecisionTraitProductProvider;
import dr.evomodel.treedatalikelihood.TreeDataLikelihood;
import dr.evomodel.treedatalikelihood.continuous.ContinuousDataLikelihoodDelegate;
import dr.evomodel.treedatalikelihood.continuous.ContinuousTraitPartialsProvider;
import dr.evomodel.treedatalikelihood.continuous.IntegratedFactorAnalysisLikelihood;
import dr.evomodel.treedatalikelihood.continuous.cdi.PrecisionType;
import dr.evomodel.treedatalikelihood.preorder.WrappedNormalSufficientStatistics;
import dr.evomodel.treedatalikelihood.preorder.WrappedTipFullConditionalDistributionDelegate;
import dr.inference.hmc.GradientWrtParameterProvider;
import dr.inference.model.CompoundLikelihood;
import dr.inference.model.CompoundParameter;
import dr.inference.model.Likelihood;
import dr.inference.model.Parameter;
import dr.inference.model.Variable;
import dr.inference.model.VariableListener;
import dr.math.matrixAlgebra.ReadableMatrix;
import dr.math.matrixAlgebra.ReadableVector;
import dr.math.matrixAlgebra.WrappedMatrix;
import dr.math.matrixAlgebra.WrappedVector;
import dr.math.matrixAlgebra.missingData.MissingOps;
import dr.util.StopWatch;
import dr.util.TaskPool;
import dr.xml.Reportable;
import java.util.ArrayList;
import java.util.List;
import org.ejml.data.DenseMatrix64F;

public class IntegratedLoadingsGradient
implements GradientWrtParameterProvider,
VariableListener,
Reportable {
    /** Tree trait queried in {@code getGradientLogDensity()} for one set of wrapped normal statistics per external node. */
    private final TreeTrait<List<WrappedNormalSufficientStatistics>> fullConditionalDensity;
    /** Factor analysis model that supplies the loadings, the precision parameter and the trait data. */
    protected final IntegratedFactorAnalysisLikelihood factorAnalysisLikelihood;
    /** Source of tip partials and of {@code partitionNormalStatistics}. */
    private final ContinuousTraitPartialsProvider partialsProvider;
    /** {@code getDataDimension()} of the factor analysis likelihood. */
    protected final int dimTrait;
    /** {@code getNumberOfFactors()} of the factor analysis likelihood. */
    protected final int dimFactors;
    /** {@code getTraitDimension()} of the partials provider; used when wrapping a tip partial. */
    protected final int dimPartials;
    /** Tree obtained from the tree data likelihood. */
    private final Tree tree;
    /** Compound of the tree data likelihood and the factor analysis likelihood. */
    private final Likelihood likelihood;
    /**
     * Trait values read once in the constructor via {@code getParameterValues()}; indexed as
     * {@code taxon * dimTrait + trait}. Never refreshed by this class: {@code variableChangedEvent} throws instead.
     */
    protected final double[] data;
    /** {@code missing[k]} is true when index {@code k} is listed in {@code getMissingDataIndices()}; same indexing as {@code data}. */
    private final boolean[] missing;
    /** Chooses between the serial loop and the task pool in {@code getGradientLogDensity()}. */
    private final ThreadUseProvider threadUseProvider;
    /** Decides whether {@code getGradientLogDensity()} evaluates the compound log-likelihood first. */
    private final RemainderCompProvider remainderCompProvider;
    /** Pool used for the parallel path; its thread count also sizes the per-thread accumulators. */
    private final TaskPool taskPool;
    /** Timers reported by {@code timingInfo()}; not assigned anywhere in this class. */
    protected StopWatch[] stopWatches;
    /** Compile-time switch; not referenced elsewhere in this class as decompiled. */
    protected static final boolean TIMING = false;
    /** Compile-time switch; not referenced elsewhere in this class as decompiled. */
    private static final boolean DEBUG = false;

    /**
     * Creates a provider for the gradient of the log-density with respect to the factor loadings.
     *
     * <p>The constructor looks up the wrapped tip full-conditional tree trait on {@code treeDataLikelihood}
     * (asking {@code continuousDataLikelihoodDelegate} to register it first when it is absent), reads the
     * trait data once, registers this object as a listener on the trait parameter, and builds a compound
     * likelihood from {@code treeDataLikelihood} and {@code integratedFactorAnalysisLikelihood}.
     *
     * @param treeDataLikelihood                 tree likelihood supplying the tree and the tree trait
     * @param continuousDataLikelihoodDelegate   delegate used to register the full-conditional trait when missing
     * @param integratedFactorAnalysisLikelihood factor model supplying dimensions, data and missing indices
     * @param continuousTraitPartialsProvider    provider of tip partials and the partial dimension
     * @param taskPool                           pool to use, or {@code null} to create
     *                                           {@code new TaskPool(externalNodeCount, 1)}
     * @param threadUseProvider                  selects the serial or pooled code path
     * @param remainderCompProvider              selects whether the likelihood is evaluated before the gradient
     * @throws IllegalArgumentException if the pool's taxon count differs from the tree's external node count
     */
    public IntegratedLoadingsGradient(TreeDataLikelihood treeDataLikelihood, ContinuousDataLikelihoodDelegate continuousDataLikelihoodDelegate, IntegratedFactorAnalysisLikelihood integratedFactorAnalysisLikelihood, ContinuousTraitPartialsProvider continuousTraitPartialsProvider, TaskPool taskPool, ThreadUseProvider threadUseProvider, RemainderCompProvider remainderCompProvider) {
        this.factorAnalysisLikelihood = integratedFactorAnalysisLikelihood;
        this.partialsProvider = continuousTraitPartialsProvider;

        // Make sure the full-conditional trait exists before fetching it.
        final String traitName = integratedFactorAnalysisLikelihood.getModelName();
        final String fcdName = WrappedTipFullConditionalDistributionDelegate.getName(traitName);
        if (treeDataLikelihood.getTreeTrait(fcdName) == null) {
            continuousDataLikelihoodDelegate.addWrappedFullConditionalDensityTrait(traitName);
        }
        this.fullConditionalDensity = LinearOrderTreePrecisionTraitProductProvider.castTreeTrait(treeDataLikelihood.getTreeTrait(fcdName));

        this.tree = treeDataLikelihood.getTree();
        this.dimTrait = integratedFactorAnalysisLikelihood.getDataDimension();
        this.dimFactors = integratedFactorAnalysisLikelihood.getNumberOfFactors();
        this.dimPartials = continuousTraitPartialsProvider.getTraitDimension();

        // The data values are read once here; later changes are rejected by variableChangedEvent.
        final CompoundParameter traitParameter = integratedFactorAnalysisLikelihood.getParameter();
        this.data = traitParameter.getParameterValues();
        traitParameter.addVariableListener(this);
        this.missing = this.getMissing(integratedFactorAnalysisLikelihood.getMissingDataIndices(), traitParameter.getDimension());

        final ArrayList<Likelihood> components = new ArrayList<Likelihood>();
        components.add(treeDataLikelihood);
        components.add(integratedFactorAnalysisLikelihood);
        this.likelihood = new CompoundLikelihood(components);

        // Fall back to a default pool when the caller supplies none.
        this.taskPool = taskPool != null ? taskPool : new TaskPool(this.tree.getExternalNodeCount(), 1);
        if (this.taskPool.getNumTaxon() != this.tree.getExternalNodeCount()) {
            throw new IllegalArgumentException("Incorrectly specified TaskPool");
        }
        this.threadUseProvider = threadUseProvider;
        this.remainderCompProvider = remainderCompProvider;
    }

    /**
     * Expands a list of indices into a boolean mask.
     *
     * @param list indices to flag
     * @param n    length of the returned mask
     * @return an array of length {@code n} that is {@code true} exactly at the listed indices
     */
    private boolean[] getMissing(List<Integer> list, int n) {
        final boolean[] mask = new boolean[n];
        for (final Integer flagged : list) {
            mask[flagged.intValue()] = true;
        }
        return mask;
    }

    /**
     * @return the compound likelihood assembled in the constructor
     */
    @Override
    public Likelihood getLikelihood() {
        return likelihood;
    }

    /**
     * @return the loadings parameter of the factor analysis likelihood
     */
    @Override
    public Parameter getParameter() {
        return factorAnalysisLikelihood.getLoadings();
    }

    /**
     * @return the number of loadings entries, {@code dimFactors * dimTrait}
     */
    @Override
    public int getDimension() {
        return dimFactors * dimTrait;
    }

    /**
     * Length of each per-thread gradient accumulator allocated in {@code getGradientLogDensity()}.
     *
     * @return {@code dimFactors * dimTrait}
     */
    protected int getGradientDimension() {
        return dimFactors * dimTrait;
    }

    /**
     * Adds the outer product of {@code readableVector} with itself to {@code wrappedMatrix}, in place.
     *
     * @param wrappedMatrix  square matrix that is modified and returned
     * @param readableVector vector whose dimension matches the matrix
     * @return the same {@code wrappedMatrix} instance after the update
     */
    private ReadableMatrix shiftToSecondMoment(WrappedMatrix wrappedMatrix, ReadableVector readableVector) {
        assert (wrappedMatrix.getMajorDim() == wrappedMatrix.getMinorDim());
        assert (wrappedMatrix.getMajorDim() == readableVector.getDim());
        final int dim = wrappedMatrix.getMajorDim();
        for (int row = 0; row < dim; ++row) {
            for (int col = 0; col < dim; ++col) {
                final double current = wrappedMatrix.get(row, col);
                final double outer = readableVector.get(row) * readableVector.get(col);
                wrappedMatrix.set(row, col, current + outer);
            }
        }
        return wrappedMatrix;
    }

    /**
     * Combines two (mean, precision) pairs into one set of wrapped normal statistics.
     *
     * <p>The two precision matrices are summed with {@code MissingOps.add}; the sum is passed to
     * {@code MissingOps.safeInvert2} together with a second buffer (presumably receiving the inverse,
     * i.e. the variance; confirm against {@code MissingOps}); the combined mean is produced by
     * {@code MissingOps.safeWeightedAverage} from copies of the input precisions.
     *
     * @param readableVector  first mean
     * @param readableMatrix  first precision
     * @param readableVector2 second mean
     * @param readableMatrix2 second precision
     * @return statistics built from the new mean buffer, the summed-precision matrix and the second buffer
     */
    private static WrappedNormalSufficientStatistics getWeightedAverage(ReadableVector readableVector, ReadableMatrix readableMatrix, ReadableVector readableVector2, ReadableMatrix readableMatrix2) {
        assert (readableVector.getDim() == readableVector2.getDim());
        assert (readableMatrix.getDim() == readableMatrix2.getDim());
        assert (readableVector.getDim() == readableMatrix.getMinorDim());
        assert (readableVector.getDim() == readableMatrix.getMajorDim());

        final int dim = readableVector.getDim();
        final WrappedVector.Raw combinedMean = new WrappedVector.Raw(new double[dim], 0, dim);
        final DenseMatrix64F summedPrecision = new DenseMatrix64F(dim, dim);
        final DenseMatrix64F inverseBuffer = new DenseMatrix64F(dim, dim);
        final WrappedMatrix.WrappedDenseMatrix wrappedPrecision = new WrappedMatrix.WrappedDenseMatrix(summedPrecision);
        final WrappedMatrix.WrappedDenseMatrix wrappedInverse = new WrappedMatrix.WrappedDenseMatrix(inverseBuffer);

        // The wrapper writes through to summedPrecision, which is then handed to safeInvert2.
        MissingOps.add(readableMatrix, readableMatrix2, wrappedPrecision);
        MissingOps.safeInvert2(summedPrecision, inverseBuffer, false);
        MissingOps.safeWeightedAverage(
                readableVector, MissingOps.copy(readableMatrix),
                readableVector2, MissingOps.copy(readableMatrix2),
                combinedMean, inverseBuffer, dim);

        return new WrappedNormalSufficientStatistics(combinedMean, wrappedPrecision, wrappedInverse);
    }

    /**
     * Computes the gradient by accumulating one contribution per external node of the tree.
     *
     * <p>Each thread of the task pool owns one accumulator row; the serial path writes only into row 0.
     * The rows are summed by {@code join} before returning. When the remainder provider requests it, the
     * compound log-likelihood is evaluated first and its value discarded (presumably to bring the
     * likelihood state up to date before the tree trait is read; confirm with the likelihood classes).
     *
     * @return the accumulated gradient, of length {@code getGradientDimension()}
     */
    @Override
    public double[] getGradientLogDensity() {
        final double[][] perThreadGradient = new double[taskPool.getNumThreads()][getGradientDimension()];

        final WrappedVector.Parameter precisionVector = new WrappedVector.Parameter(factorAnalysisLikelihood.getPrecision());
        final ReadableMatrix transposedLoadings = ReadableMatrix.Utils.transposeProxy(new WrappedMatrix.MatrixParameter(factorAnalysisLikelihood.getLoadings()));
        final double[] precisionValues = factorAnalysisLikelihood.getPrecision().getParameterValues();
        final double[] loadingsValues = ReadableMatrix.Utils.toArray(new WrappedMatrix.MatrixParameter(factorAnalysisLikelihood.getLoadings()));

        assert (precisionVector.getDim() == dimTrait);
        assert (transposedLoadings.getMajorDim() == dimFactors);
        assert (transposedLoadings.getMinorDim() == dimTrait);

        if (remainderCompProvider.computeRemainder()) {
            likelihood.getLogLikelihood();
        }

        final List<WrappedNormalSufficientStatistics> tipStatistics = fullConditionalDensity.getTrait(tree, null);
        assert (tipStatistics.size() == tree.getExternalNodeCount());

        if (threadUseProvider.usePool()) {
            // The pool passes (taxon, thread); the thread index selects the accumulator row.
            taskPool.fork((taxon, thread) -> computeGradientForOneTaxon(thread, taxon, transposedLoadings, loadingsValues, precisionVector, precisionValues, tipStatistics.get(taxon), perThreadGradient));
        } else {
            final int taxonCount = tree.getExternalNodeCount();
            for (int taxon = 0; taxon < taxonCount; ++taxon) {
                computeGradientForOneTaxon(0, taxon, transposedLoadings, loadingsValues, precisionVector, precisionValues, tipStatistics.get(taxon), perThreadGradient);
            }
        }

        return join(perThreadGradient);
    }

    /**
     * Builds the mean and the flattened second moment used for one taxon's gradient contribution.
     *
     * <p>The tip kernel for taxon {@code n} is combined with the supplied statistics through
     * {@code getWeightedAverage}, the result is passed through
     * {@code partialsProvider.partitionNormalStatistics}, and its variance is shifted in place by the
     * outer product of its mean.
     *
     * @param n                                 taxon index
     * @param wrappedNormalSufficientStatistics statistics for this taxon taken from the tree trait
     * @return the partitioned mean and the shifted variance flattened to an array
     */
    protected MeanAndMoment getMeanAndMoment(int n, WrappedNormalSufficientStatistics wrappedNormalSufficientStatistics) {
        final WrappedNormalSufficientStatistics tipKernel = getTipKernel(n);
        final WrappedVector kernelMean = tipKernel.getMean();
        final WrappedMatrix kernelPrecision = tipKernel.getPrecision();

        final WrappedVector fcdMean = wrappedNormalSufficientStatistics.getMean();
        final WrappedMatrix fcdPrecision = wrappedNormalSufficientStatistics.getPrecision();
        // NOTE(review): value is unused; call retained because getVariance() is not shown to be side-effect free.
        final WrappedMatrix fcdVariance = wrappedNormalSufficientStatistics.getVariance();

        final WrappedNormalSufficientStatistics combined = getWeightedAverage(fcdMean, fcdPrecision, kernelMean, kernelPrecision);
        final WrappedNormalSufficientStatistics partitioned = partialsProvider.partitionNormalStatistics(combined, factorAnalysisLikelihood);

        final WrappedVector mean = partitioned.getMean();
        final WrappedMatrix variance = partitioned.getVariance();
        final double[] secondMoment = ReadableMatrix.Utils.toArray(shiftToSecondMoment(variance, mean));
        return new MeanAndMoment(mean, secondMoment);
    }

    /**
     * Computes the two per-entry terms of one taxon's gradient contribution.
     *
     * <p>For every (factor, trait) pair whose data entry is not flagged missing, the first term is
     * {@code mean[factor] * data[taxon, trait]} and the second is the dot product of row {@code factor} of
     * the second moment with the {@code dimFactors} consecutive entries of {@code dArray} starting at
     * {@code trait * dimFactors}. Entries for missing data stay zero.
     *
     * @param n             taxon index into {@code data} and {@code missing}
     * @param dArray        flattened loadings, read in blocks of {@code dimFactors} per trait
     * @param meanAndMoment mean and flattened {@code dimFactors x dimFactors} second moment
     * @return both terms, each indexed by {@code factor * dimTrait + trait}
     */
    protected GradientComponents computeGradientComponents(int n, double[] dArray, MeanAndMoment meanAndMoment) {
        final ReadableVector mean = meanAndMoment.mean;
        final double[] secondMoment = meanAndMoment.moment;
        final double[] fty = new double[dimFactors * dimTrait];
        final double[] ftfl = new double[dimFactors * dimTrait];

        for (int factor = 0; factor < dimFactors; ++factor) {
            final double factorMean = mean.get(factor);
            for (int trait = 0; trait < dimTrait; ++trait) {
                final int dataIndex = n * dimTrait + trait;
                if (!missing[dataIndex]) {
                    double product = 0.0;
                    for (int k = 0; k < dimFactors; ++k) {
                        product += secondMoment[factor * dimFactors + k] * dArray[trait * dimFactors + k];
                    }
                    final int entry = factor * dimTrait + trait;
                    fty[entry] += factorMean * data[dataIndex];
                    ftfl[entry] += product;
                }
            }
        }
        return new GradientComponents(fty, ftfl);
    }

    /**
     * Adds one taxon's contribution to an accumulator row.
     *
     * <p>Each entry receives {@code (fty - ftfl)} scaled by the {@code dArray} value of its trait.
     *
     * @param n                  row of {@code dArray2} to accumulate into
     * @param gradientComponents the two terms, indexed by {@code factor * dimTrait + trait}
     * @param dArray             per-trait scale factors (the precision values in {@code getGradientLogDensity()})
     * @param dArray2            accumulator rows, modified in place
     */
    protected void computeLoadingsGradientForOneTaxon(int n, GradientComponents gradientComponents, double[] dArray, double[][] dArray2) {
        final double[] fty = gradientComponents.fty;
        final double[] ftfl = gradientComponents.ftfl;
        for (int factor = 0; factor < dimFactors; ++factor) {
            for (int trait = 0; trait < dimTrait; ++trait) {
                final int entry = factor * dimTrait + trait;
                dArray2[n][entry] += (fty[entry] - ftfl[entry]) * dArray[trait];
            }
        }
    }

    /**
     * Runs the three stages for a single taxon: moments, gradient components, accumulation.
     *
     * <p>{@code readableMatrix} and {@code readableVector} are accepted but not used by this implementation.
     *
     * @param n                                 accumulator row (thread index)
     * @param n2                                taxon index
     * @param readableMatrix                    unused here
     * @param dArray                            flattened loadings passed to {@code computeGradientComponents}
     * @param readableVector                    unused here
     * @param dArray2                           per-trait scale factors passed to the accumulation step
     * @param wrappedNormalSufficientStatistics statistics for this taxon
     * @param dArray3                           accumulator rows, modified in place
     */
    protected void computeGradientForOneTaxon(int n, int n2, ReadableMatrix readableMatrix, double[] dArray, ReadableVector readableVector, double[] dArray2, WrappedNormalSufficientStatistics wrappedNormalSufficientStatistics, double[][] dArray3) {
        final MeanAndMoment moments = getMeanAndMoment(n2, wrappedNormalSufficientStatistics);
        final GradientComponents components = computeGradientComponents(n2, dArray, moments);
        computeLoadingsGradientForOneTaxon(n, components, dArray2, dArray3);
    }

    /**
     * Sums all rows element-wise into the first row.
     *
     * @param dArray accumulator rows; row 0 is overwritten with the totals
     * @return row 0 of {@code dArray} (the same array instance, not a copy)
     */
    private static double[] join(double[][] dArray) {
        final int rowCount = dArray.length;
        final double[] total = dArray[0];
        final int width = total.length;
        for (int row = 1; row < rowCount; ++row) {
            final double[] partial = dArray[row];
            for (int k = 0; k < width; ++k) {
                total[k] += partial[k];
            }
        }
        return total;
    }

    /**
     * Wraps the tip partial of one taxon as normal sufficient statistics.
     *
     * @param n taxon index passed to {@code getTipPartial(n, false)}
     * @return statistics over the partial, starting at offset 0 with dimension {@code dimPartials} and
     *         {@code PrecisionType.FULL}
     */
    private WrappedNormalSufficientStatistics getTipKernel(int n) {
        return new WrappedNormalSufficientStatistics(
                partialsProvider.getTipPartial(n, false), 0, dimPartials, null, PrecisionType.FULL);
    }

    /**
     * Produces the report text by delegating to {@code GradientWrtParameterProvider.getReportAndCheckForError}
     * with infinite bounds and a {@code null} final argument.
     *
     * @return the delegated report, appended to an empty string
     */
    @Override
    public String getReport() {
        // Concatenation with "" is kept so a null result still yields the text "null".
        return "" + GradientWrtParameterProvider.getReportAndCheckForError(this, Double.NEGATIVE_INFINITY, Double.POSITIVE_INFINITY, null);
    }

    /**
     * Formats one line per stop watch and resets each watch afterwards.
     *
     * <p>{@code stopWatches} is never assigned in this class, so it may be {@code null}; in that case
     * (and for {@code null} entries) nothing is appended instead of throwing a
     * {@code NullPointerException}.
     *
     * @return the header line followed by one tab-indented line per non-null stop watch
     */
    private String timingInfo() {
        final StringBuilder report = new StringBuilder("\nTiming in IntegratedLoadingsGradient\n");
        if (this.stopWatches == null) {
            return report.toString();
        }
        for (final StopWatch stopWatch : this.stopWatches) {
            if (stopWatch == null) {
                continue;
            }
            report.append("\t").append(stopWatch.toString()).append("\n");
            stopWatch.reset();
        }
        return report.toString();
    }

    /**
     * Rejects every change notification. The trait values were read once in the constructor, so this
     * class cannot follow later changes to the parameter it listens to.
     *
     * @throws RuntimeException always
     */
    @Override
    public void variableChangedEvent(Variable variable, int index, Variable.ChangeType type) {
        throw new RuntimeException("Trait data is not cached");
    }

    /** Selects whether {@code getGradientLogDensity()} uses the task pool or a plain loop. */
    public static enum ThreadUseProvider {
        /** Dispatch taxa through the task pool. */
        PARALLEL(true),
        /** Iterate over taxa on the calling thread. */
        SERIAL(false);

        private final boolean pooled;

        ThreadUseProvider(boolean pooled) {
            this.pooled = pooled;
        }

        /**
         * @return {@code true} for {@code PARALLEL}, {@code false} for {@code SERIAL}
         */
        boolean usePool() {
            return this.pooled;
        }
    }

    /** Selects whether {@code getGradientLogDensity()} evaluates the compound log-likelihood first. */
    public static enum RemainderCompProvider {
        /** Evaluate the log-likelihood before computing the gradient. */
        FULL(true),
        /** Go straight to the gradient computation. */
        SKIP(false);

        private final boolean remainderComputed;

        RemainderCompProvider(boolean remainderComputed) {
            this.remainderComputed = remainderComputed;
        }

        /**
         * @return {@code true} for {@code FULL}, {@code false} for {@code SKIP}
         */
        boolean computeRemainder() {
            return this.remainderComputed;
        }
    }

    /** Holder pairing a mean vector with a flattened second-moment matrix. */
    protected class MeanAndMoment {
        /** Mean vector. */
        public final ReadableVector mean;
        /** Second moment flattened into a single array. */
        public final double[] moment;

        /**
         * @param mean   mean vector, stored by reference
         * @param moment flattened second moment, stored by reference
         */
        public MeanAndMoment(ReadableVector mean, double[] moment) {
            this.mean = mean;
            this.moment = moment;
        }
    }

    /** Holder for the two per-entry terms produced by {@code computeGradientComponents}. */
    protected class GradientComponents {
        /** Factor mean times observed data, indexed by {@code factor * dimTrait + trait}. */
        public final double[] fty;
        /** Second-moment row dotted with the loadings, same indexing as {@code fty}. */
        public final double[] ftfl;

        /**
         * @param fty  first term, stored by reference
         * @param ftfl second term, stored by reference
         */
        public GradientComponents(double[] fty, double[] ftfl) {
            this.fty = fty;
            this.ftfl = ftfl;
        }
    }
}

