/*
 * Decompiled with CFR 0.152.
 */
package dr.evomodel.continuous;

import dr.evolution.tree.MutableTreeModel;
import dr.evolution.tree.NodeRef;
import dr.evolution.tree.TreeUtils;
import dr.evomodel.continuous.FullyConjugateMultivariateTraitLikelihood;
import dr.math.KroneckerOperation;
import dr.math.matrixAlgebra.SymmetricMatrix;
import java.util.HashSet;

/**
 * Static helpers that build the tip-level mean vector and the tip-by-tip
 * variance / precision matrices implied by a
 * {@link FullyConjugateMultivariateTraitLikelihood} and its tree.
 *
 * <p>Two families of methods exist: the plain ones, which use rescaled
 * root-to-node path lengths, and the {@code ...OU} ones, which additionally
 * use the likelihood's {@code strengthOfSelection} and per-branch
 * "time-scaled selection" values.
 *
 * <p>All methods are stateless; they only read from the likelihood passed in.
 */
public class MultivariateTraitUtils {
    private static final boolean DEBUG = false;

    /**
     * Returns the most recent common ancestor of two taxa.
     *
     * @param fullyConjugateMultivariateTraitLikelihood likelihood whose tree is searched
     * @param n  index of the first taxon (as accepted by {@code getTaxonId})
     * @param n2 index of the second taxon
     * @return the node returned by {@code TreeUtils.getCommonAncestorNode} for the two taxon ids
     */
    public static NodeRef findMRCA(FullyConjugateMultivariateTraitLikelihood fullyConjugateMultivariateTraitLikelihood, int n, int n2) {
        MutableTreeModel tree = fullyConjugateMultivariateTraitLikelihood.getTreeModel();
        HashSet<String> taxonIds = new HashSet<String>();
        taxonIds.add(tree.getTaxonId(n));
        taxonIds.add(tree.getTaxonId(n2));
        return TreeUtils.getCommonAncestorNode(tree, taxonIds);
    }

    /**
     * Computes the tip-by-tip precision matrix as the inverse of the tree variance.
     *
     * <p>When the likelihood has a non-null {@code strengthOfSelection} the variance
     * comes from {@link #computeTreeVarianceOU}; otherwise from {@link #computeTreeVariance}.
     *
     * @param fullyConjugateMultivariateTraitLikelihood likelihood supplying tree and model
     * @param bl forwarded unchanged to the variance routine
     * @return the inverted variance matrix as a plain 2-D array
     */
    public static double[][] computeTreePrecision(FullyConjugateMultivariateTraitLikelihood fullyConjugateMultivariateTraitLikelihood, boolean bl) {
        double[][] treeVariance = fullyConjugateMultivariateTraitLikelihood.strengthOfSelection != null
                ? MultivariateTraitUtils.computeTreeVarianceOU(fullyConjugateMultivariateTraitLikelihood, bl)
                : MultivariateTraitUtils.computeTreeVariance(fullyConjugateMultivariateTraitLikelihood, bl);
        return new SymmetricMatrix(treeVariance).inverse().toComponents();
    }

    /**
     * Combines the tree precision with the diffusion model's precision matrix.
     *
     * @param fullyConjugateMultivariateTraitLikelihood likelihood supplying tree and diffusion model
     * @param bl forwarded to {@link #computeTreePrecision}
     * @return Kronecker product (or scalar multiple, for a 1x1 trait precision) of the two matrices
     */
    public static double[][] computeTreeTraitPrecision(FullyConjugateMultivariateTraitLikelihood fullyConjugateMultivariateTraitLikelihood, boolean bl) {
        double[][] treePrecision = MultivariateTraitUtils.computeTreePrecision(fullyConjugateMultivariateTraitLikelihood, bl);
        double[][] traitPrecision = fullyConjugateMultivariateTraitLikelihood.getDiffusionModel().getPrecisionmatrix();
        return MultivariateTraitUtils.productKronecker(treePrecision, traitPrecision);
    }

    /**
     * Kronecker product of {@code left} with {@code right}.
     *
     * <p>If {@code right} has more than one row the work is delegated to
     * {@code KroneckerOperation.product}. Otherwise {@code right} is treated as the
     * scalar {@code right[0][0]} and {@code left} is scaled <em>in place</em> and
     * returned, so callers must pass a matrix they own.
     */
    private static double[][] productKronecker(double[][] left, double[][] right) {
        if (right.length > 1) {
            return KroneckerOperation.product(left, right);
        }
        final double scale = right[0][0];
        for (double[] row : left) {
            for (int col = 0; col < row.length; ++col) {
                row[col] = row[col] * scale;
            }
        }
        return left;
    }

    /**
     * Dense matrix product {@code left * right}.
     *
     * <p>The inner dimension is taken from {@code left[0].length}; {@code right}
     * must have at least one row so that its column count can be read.
     */
    private static double[][] productMatrices(double[][] left, double[][] right) {
        final int rows = left.length;
        final int cols = right[0].length;
        // Guarded so that an empty left operand does not index left[0].
        final int inner = rows == 0 ? 0 : left[0].length;
        double[][] product = new double[rows][cols];
        for (int i = 0; i < rows; ++i) {
            for (int j = 0; j < cols; ++j) {
                double sum = 0.0;
                for (int k = 0; k < inner; ++k) {
                    sum = sum + left[i][k] * right[k][j];
                }
                product[i][j] = sum;
            }
        }
        return product;
    }

    /**
     * Returns a newly allocated transpose of {@code matrix}.
     * The column count is read from the first row, so the input must be non-empty.
     */
    private static double[][] transposeMatrix(double[][] matrix) {
        final int rows = matrix.length;
        final int cols = matrix[0].length;
        double[][] transposed = new double[cols][rows];
        for (int r = 0; r < rows; ++r) {
            for (int c = 0; c < cols; ++c) {
                transposed[c][r] = matrix[r][c];
            }
        }
        return transposed;
    }

    /**
     * Builds a (tips x (2*tips - 2)) weight matrix.
     *
     * <p>For each tip the scan starts at column {@code tip} with weight 1. Every
     * time the scan reaches the column of the node currently being tracked, the
     * running weight is stored there, multiplied by
     * {@code exp(-getTimeScaledSelection(node))}, and tracking moves to the parent.
     * All other entries stay at the array default of zero.
     *
     * <p>NOTE(review): because columns are scanned in increasing order, an ancestor
     * is only recorded if its node number is larger than the column already
     * reached, and the walk dereferences the parent of every matched node. This
     * presumably relies on parents being numbered after their children and on the
     * root not having a number in {@code 0 .. 2*tips - 3} - confirm against the
     * tree model's numbering.
     */
    private static double[][] computeLinCombMatrix(FullyConjugateMultivariateTraitLikelihood trait) {
        MutableTreeModel tree = trait.getTreeModel();
        final int tipCount = tree.getExternalNodeCount();
        final int branchCount = 2 * tipCount - 2;
        double[][] weights = new double[tipCount][branchCount];
        for (int tip = 0; tip < tipCount; ++tip) {
            NodeRef current = tree.getExternalNode(tip);
            double weight = 1.0;
            int awaitedColumn = tip;
            for (int column = 0; column < branchCount; ++column) {
                if (column != awaitedColumn) {
                    continue;
                }
                weights[tip][column] = weight;
                weight *= Math.exp(-trait.getTimeScaledSelection(current));
                current = tree.getParent(current);
                awaitedColumn = current.getNumber();
            }
        }
        return weights;
    }

    /**
     * For every tip, multiplies {@code exp(-getTimeScaledSelection(node))} over the
     * tip itself and each of its ancestors up to, but not including, the root.
     *
     * @return one multiplier per external node, indexed like {@code getExternalNode}
     */
    private static double[] computeRootMultipliers(FullyConjugateMultivariateTraitLikelihood trait) {
        MutableTreeModel tree = trait.getTreeModel();
        final int tipCount = tree.getExternalNodeCount();
        double[] multipliers = new double[tipCount];
        for (int tip = 0; tip < tipCount; ++tip) {
            NodeRef node = tree.getExternalNode(tip);
            double multiplier = Math.exp(-trait.getTimeScaledSelection(node));
            for (node = tree.getParent(node); !tree.isRoot(node); node = tree.getParent(node)) {
                multiplier = multiplier * Math.exp(-trait.getTimeScaledSelection(node));
            }
            multipliers[tip] = multiplier;
        }
        return multipliers;
    }

    /**
     * Sums {@code getShiftForBranchLength} over the path from {@code nodeRef} up to
     * the root. The root itself contributes a zero vector.
     *
     * <p>Recursion is kept (rather than an upward loop) so the terms are added in
     * root-to-node order.
     *
     * @return a fresh array of length {@code dimTrait}
     */
    private static double[] getShiftContributionToMean(NodeRef nodeRef, FullyConjugateMultivariateTraitLikelihood trait) {
        MutableTreeModel tree = trait.getTreeModel();
        double[] contribution = new double[trait.dimTrait];
        if (tree.isRoot(nodeRef)) {
            return contribution;
        }
        double[] fromAncestors = MultivariateTraitUtils.getShiftContributionToMean(tree.getParent(nodeRef), trait);
        // Fetched once per node; the value does not depend on the trait index.
        // Assumes the accessor has no side effects - TODO confirm.
        double[] branchShift = trait.getShiftForBranchLength(nodeRef);
        for (int i = 0; i < contribution.length; ++i) {
            contribution[i] = branchShift[i] + fromAncestors[i];
        }
        return contribution;
    }

    /**
     * Builds the stacked mean vector for all tips.
     *
     * <p>The base mean (the prior mean, or {@code dArray} when {@code bl} is true)
     * is repeated once per external node. If the likelihood has drift models, the
     * accumulated per-branch shift from root to each tip is added on top.
     *
     * @param fullyConjugateMultivariateTraitLikelihood likelihood supplying tree, prior and drift
     * @param dArray mean to use instead of the prior mean when {@code bl} is true
     * @param bl     when true, use {@code dArray} and print a "not yet fully implemented" warning to stderr
     * @return vector of length {@code mean.length * externalNodeCount}
     */
    public static double[] computeTreeTraitMean(FullyConjugateMultivariateTraitLikelihood fullyConjugateMultivariateTraitLikelihood, double[] dArray, boolean bl) {
        double[] baseMean = fullyConjugateMultivariateTraitLikelihood.getPriorMean();
        if (bl) {
            System.err.println("WARNING: Not yet fully implemented (conditioning on root in simulator)");
            baseMean = dArray;
        }
        final int tipCount = fullyConjugateMultivariateTraitLikelihood.getTreeModel().getExternalNodeCount();
        double[] stackedMean = new double[baseMean.length * tipCount];
        for (int tip = 0; tip < tipCount; ++tip) {
            System.arraycopy(baseMean, 0, stackedMean, tip * baseMean.length, baseMean.length);
        }
        if (fullyConjugateMultivariateTraitLikelihood.driftModels == null) {
            return stackedMean;
        }
        MutableTreeModel tree = fullyConjugateMultivariateTraitLikelihood.getTreeModel();
        final int dim = fullyConjugateMultivariateTraitLikelihood.dimTrait;
        for (int tip = 0; tip < tipCount; ++tip) {
            double[] shift = MultivariateTraitUtils.getShiftContributionToMean(tree.getExternalNode(tip), fullyConjugateMultivariateTraitLikelihood);
            final int offset = tip * dim;
            for (int j = 0; j < dim; ++j) {
                stackedMean[offset + j] = stackedMean[offset + j] + shift[j];
            }
        }
        return stackedMean;
    }

    /**
     * Selection-aware counterpart of {@link #computeTreeTraitMean}.
     *
     * <p>For each of the first {@code 2*tips - 2} nodes a per-dimension term
     * {@code (1 - exp(-timeScaledSelection)) * optimalValue} is formed. Each tip's
     * mean is then the base mean scaled by that tip's root multiplier plus the
     * weight-matrix combination of those per-node terms.
     *
     * @param fullyConjugateMultivariateTraitLikelihood likelihood supplying tree, prior and selection values
     * @param dArray mean to use instead of the prior mean when {@code bl} is true
     * @param bl     when true, {@code dArray} replaces the prior mean
     * @return vector of length {@code mean.length * externalNodeCount}
     */
    public static double[] computeTreeTraitMeanOU(FullyConjugateMultivariateTraitLikelihood fullyConjugateMultivariateTraitLikelihood, double[] dArray, boolean bl) {
        double[] baseMean = fullyConjugateMultivariateTraitLikelihood.getPriorMean();
        MutableTreeModel tree = fullyConjugateMultivariateTraitLikelihood.getTreeModel();
        double[][] weights = MultivariateTraitUtils.computeLinCombMatrix(fullyConjugateMultivariateTraitLikelihood);
        double[] rootMultipliers = MultivariateTraitUtils.computeRootMultipliers(fullyConjugateMultivariateTraitLikelihood);
        if (bl) {
            baseMean = dArray;
        }
        final int tipCount = tree.getExternalNodeCount();
        final int branchCount = 2 * tipCount - 2;
        final int dim = fullyConjugateMultivariateTraitLikelihood.dimTrait;
        double[] stackedMean = new double[baseMean.length * tipCount];
        double[] branchTerms = new double[branchCount * dim];
        double[] combined = new double[tipCount * dim];

        // Per-node contribution; the exponential and the optimum vector do not
        // depend on the trait index, so they are evaluated once per node.
        // Assumes both accessors have no side effects - TODO confirm.
        for (int b = 0; b < branchCount; ++b) {
            NodeRef node = tree.getNode(b);
            final double pull = 1.0 - Math.exp(-fullyConjugateMultivariateTraitLikelihood.getTimeScaledSelection(node));
            double[] optimum = fullyConjugateMultivariateTraitLikelihood.getOptimalValue(node);
            for (int d = 0; d < dim; ++d) {
                branchTerms[b * dim + d] = pull * optimum[d];
            }
        }

        // combined = weights * branchTerms, with branchTerms viewed as (branchCount x dim).
        for (int tip = 0; tip < tipCount; ++tip) {
            for (int b = 0; b < branchCount; ++b) {
                final double weight = weights[tip][b];
                for (int d = 0; d < dim; ++d) {
                    combined[tip * dim + d] = combined[tip * dim + d] + weight * branchTerms[b * dim + d];
                }
            }
        }

        for (int tip = 0; tip < tipCount; ++tip) {
            System.arraycopy(baseMean, 0, stackedMean, tip * baseMean.length, baseMean.length);
            for (int d = 0; d < dim; ++d) {
                final int idx = tip * dim + d;
                stackedMean[idx] = stackedMean[idx] * rootMultipliers[tip] + combined[idx];
            }
        }
        return stackedMean;
    }

    /**
     * Combines the tree variance with the inverse of the diffusion model's precision matrix.
     *
     * <p>NOTE(review): this always uses {@link #computeTreeVariance}, even when
     * {@code strengthOfSelection} is set, unlike {@link #computeTreePrecision} -
     * confirm whether that is intended.
     *
     * @param fullyConjugateMultivariateTraitLikelihood likelihood supplying tree and diffusion model
     * @param bl forwarded to {@link #computeTreeVariance}
     * @return Kronecker product (or scalar multiple, for a 1x1 trait variance) of the two matrices
     */
    public static double[][] computeTreeTraitVariance(FullyConjugateMultivariateTraitLikelihood fullyConjugateMultivariateTraitLikelihood, boolean bl) {
        double[][] treeVariance = MultivariateTraitUtils.computeTreeVariance(fullyConjugateMultivariateTraitLikelihood, bl);
        double[][] traitPrecision = fullyConjugateMultivariateTraitLikelihood.getDiffusionModel().getPrecisionmatrix();
        double[][] traitVariance = new SymmetricMatrix(traitPrecision).inverse().toComponents();
        return MultivariateTraitUtils.productKronecker(treeVariance, traitVariance);
    }

    /**
     * Builds the symmetric tip-by-tip variance matrix from rescaled path lengths.
     *
     * <p>Diagonal entries are the rescaled length from the root to each tip;
     * off-diagonal entries are the rescaled length from the root to the pair's
     * most recent common ancestor. When {@code bl} is false,
     * {@code 1 / getPriorSampleSize()} is added to every entry.
     *
     * @param fullyConjugateMultivariateTraitLikelihood likelihood supplying the tree and rescaled lengths
     * @param bl when true, the prior-sample-size term is omitted
     * @return a new (tips x tips) matrix
     */
    public static double[][] computeTreeVariance(FullyConjugateMultivariateTraitLikelihood fullyConjugateMultivariateTraitLikelihood, boolean bl) {
        MutableTreeModel tree = fullyConjugateMultivariateTraitLikelihood.getTreeModel();
        final int tipCount = tree.getExternalNodeCount();
        double[][] variance = new double[tipCount][tipCount];
        for (int i = 0; i < tipCount; ++i) {
            variance[i][i] = fullyConjugateMultivariateTraitLikelihood.getRescaledLengthToRoot(tree.getExternalNode(i));
            for (int j = i + 1; j < tipCount; ++j) {
                NodeRef mrca = MultivariateTraitUtils.findMRCA(fullyConjugateMultivariateTraitLikelihood, i, j);
                final double shared = fullyConjugateMultivariateTraitLikelihood.getRescaledLengthToRoot(mrca);
                // Fill both triangles at once; the matrix is symmetric.
                variance[i][j] = shared;
                variance[j][i] = shared;
            }
        }
        if (!bl) {
            final double priorTerm = 1.0 / fullyConjugateMultivariateTraitLikelihood.getPriorSampleSize();
            for (int i = 0; i < tipCount; ++i) {
                for (int j = 0; j < tipCount; ++j) {
                    variance[i][j] = variance[i][j] + priorTerm;
                }
            }
        }
        return variance;
    }

    /**
     * Selection-aware tip-by-tip variance: {@code W * D * W^T}, where {@code W} is
     * the weight matrix from {@link #computeLinCombMatrix} and {@code D} is diagonal
     * with entries {@code (1 - exp(-2 * timeScaledSelection)) / (2 * branchRate)}
     * for each of the first {@code 2*tips - 2} nodes.
     *
     * <p>Requires {@code strengthOfSelection} to be non-null; it is dereferenced
     * without a check.
     *
     * @param fullyConjugateMultivariateTraitLikelihood likelihood supplying tree and selection values
     * @param bl accepted for signature symmetry with {@link #computeTreeVariance}; not read here
     * @return a new (tips x tips) matrix
     */
    public static double[][] computeTreeVarianceOU(FullyConjugateMultivariateTraitLikelihood fullyConjugateMultivariateTraitLikelihood, boolean bl) {
        MutableTreeModel tree = fullyConjugateMultivariateTraitLikelihood.getTreeModel();
        final int tipCount = tree.getExternalNodeCount();
        final int branchCount = 2 * tipCount - 2;
        double[][] weights = MultivariateTraitUtils.computeLinCombMatrix(fullyConjugateMultivariateTraitLikelihood);
        double[][] branchVariance = new double[branchCount][branchCount];
        for (int b = 0; b < branchCount; ++b) {
            NodeRef node = tree.getNode(b);
            final double decay = Math.exp(-2.0 * fullyConjugateMultivariateTraitLikelihood.getTimeScaledSelection(node));
            final double rate = fullyConjugateMultivariateTraitLikelihood.strengthOfSelection.getBranchRate(tree, node);
            branchVariance[b][b] = (1.0 - decay) / (2.0 * rate);
        }
        double[][] weighted = MultivariateTraitUtils.productMatrices(weights, branchVariance);
        return MultivariateTraitUtils.productMatrices(weighted, MultivariateTraitUtils.transposeMatrix(weights));
    }
}

