/*
 * Decompiled with CFR 0.152.
 */
package dr.evomodel.continuous;

import dr.evolution.tree.MutableTreeModel;
import dr.evolution.tree.NodeRef;
import dr.evomodel.branchratemodel.BranchRateModel;
import dr.evomodel.continuous.IntegratedMultivariateTraitLikelihood;
import dr.evomodel.continuous.MultivariateDiffusionModel;
import dr.evomodel.continuous.RestrictedPartials;
import dr.inference.model.CompoundParameter;
import dr.inference.model.Model;
import dr.math.distributions.MultivariateNormalDistribution;
import dr.math.matrixAlgebra.IllegalDimension;
import dr.math.matrixAlgebra.Matrix;
import java.util.List;

/**
 * Integrated multivariate trait likelihood whose root term incorporates a
 * multivariate normal prior on the root trait values.
 *
 * <p>The prior mean ({@code z}) and the prior's scale matrix are read once from
 * the {@link MultivariateNormalDistribution} passed to the constructor. This class
 * treats the scale matrix as the prior precision ({@code B}).
 *
 * <p>Wishart sufficient statistics are never computed, and ascertainment
 * correction is not supported.
 */
public class SemiConjugateMultivariateTraitLikelihood
extends IntegratedMultivariateTraitLikelihood {
    /** Prior mean of the root trait values ({@code z}). */
    protected double[] rootPriorMean;
    /** Prior precision of the root trait values ({@code B}); taken from the prior's scale matrix. */
    protected double[][] rootPriorPrecision;
    /** {@code log |B|}, cached when the root prior is installed. */
    protected double logRootPriorPrecisionDeterminant;
    /**
     * Vector filled by {@code computeWeightedAverageAndSumOfSquares} from the root prior.
     * It is added to {@code Ay} at the root (presumably {@code B z}; confirm against the superclass).
     */
    protected double[] Bz;
    /** Scalar returned by {@code computeWeightedAverageAndSumOfSquares} for the root prior (the "zBz" term). */
    private double zBz;

    /**
     * Creates the likelihood and installs the root prior.
     *
     * <p>All arguments except {@code multivariateNormalDistribution} are forwarded
     * unchanged to the {@link IntegratedMultivariateTraitLikelihood} constructor.
     * Several of that constructor's other arguments are passed as {@code null}.
     * See that class for the meaning of the list and boolean arguments.
     *
     * @param multivariateNormalDistribution prior on the root trait values; its mean
     *     and scale matrix are stored in {@link #rootPriorMean} and {@link #rootPriorPrecision}
     * @throws IllegalArgumentException if the determinant of the prior's scale matrix
     *     cannot be computed
     */
    public SemiConjugateMultivariateTraitLikelihood(String string, MutableTreeModel mutableTreeModel, MultivariateDiffusionModel multivariateDiffusionModel, CompoundParameter compoundParameter, List<Integer> list, boolean bl, boolean bl2, boolean bl3, BranchRateModel branchRateModel, Model model, boolean bl4, MultivariateNormalDistribution multivariateNormalDistribution, boolean bl5, List<RestrictedPartials> list2) {
        super(string, mutableTreeModel, multivariateDiffusionModel, compoundParameter, null, list, bl, bl2, bl3, branchRateModel, null, null, null, model, list2, bl4, bl5);
        this.setRootPrior(multivariateNormalDistribution);
    }

    /**
     * {@inheritDoc}
     *
     * <p>This implementation always returns {@code false}.
     */
    @Override
    public boolean getComputeWishartSufficientStatistics() {
        return false;
    }

    /**
     * Not supported for semi-conjugate trait likelihoods.
     *
     * @throws RuntimeException always
     */
    @Override
    protected double calculateAscertainmentCorrection(int n) {
        throw new RuntimeException("Ascertainment correction not yet implemented for semi-conjugate trait likelihoods");
    }

    /**
     * Sums {@code getRescaledBranchLengthForPrecision} over every branch on the path
     * from {@code nodeRef} up to the root.
     *
     * @param nodeRef node to start from
     * @return accumulated rescaled branch length to the root; {@code 0.0} when
     *     {@code nodeRef} is the root
     */
    public double getRescaledLengthToRoot(NodeRef nodeRef) {
        final NodeRef root = this.treeModel.getRoot();
        double length = 0.0;
        // Identity comparison. This assumes the tree model returns the same NodeRef
        // instance for the root as the one getParent eventually yields. TODO confirm.
        for (NodeRef node = nodeRef; node != root; node = this.treeModel.getParent(node)) {
            length += this.getRescaledBranchLengthForPrecision(node);
        }
        return length;
    }

    /**
     * Computes the root contribution to the log likelihood with the root prior folded in.
     *
     * <p>Let {@code A = precision * precisionScale} and {@code B = rootPriorPrecision}.
     * The method returns
     * {@code 0.5 * (log|B| - log|A+B| - zBz + (Ay+Bz)' (A+B)^{-1} (Ay+Bz))}.
     *
     * @param unused not read by this implementation
     * @param ay the {@code Ay} vector; overwritten in place with {@code Ay + Bz}
     * @param aPlusB output buffer; filled with {@code A + B} when {@code dimTrait > 1},
     *     left untouched in the one-dimensional case
     * @param precision matrix that is scaled by {@code precisionScale} to form {@code A}
     * @param precisionScale scalar multiplier applied to {@code precision}
     * @return the root log-likelihood contribution described above
     * @throws IllegalStateException if the determinant of {@code A + B} cannot be computed
     */
    @Override
    protected double integrateLogLikelihoodAtRoot(double[] unused, double[] ay, double[][] aPlusB, double[][] precision, double precisionScale) {
        double detAplusB;
        double square = 0.0;
        if (this.dimTrait > 1) {
            this.addRootPriorTerm(ay);
            this.fillPrecisionPlusPrior(aPlusB, precision, precisionScale);
            Matrix matrix = new Matrix(aPlusB);
            try {
                detAplusB = matrix.determinant();
            }
            catch (IllegalDimension e) {
                // Fail loudly. A determinant left at 0 would make -log(det) = +Infinity
                // and silently corrupt the likelihood.
                throw new IllegalStateException("Cannot compute determinant of (A+B) at the root", e);
            }
            double[][] invAplusB = matrix.inverse().toComponents();
            // Quadratic form (Ay+Bz)' (A+B)^{-1} (Ay+Bz)
            for (int i = 0; i < this.dimTrait; ++i) {
                for (int j = 0; j < this.dimTrait; ++j) {
                    square += ay[i] * invAplusB[i][j] * ay[j];
                }
            }
        } else {
            // One dimension: A+B is a scalar, so it is its own determinant.
            detAplusB = precision[0][0] * precisionScale + this.rootPriorPrecision[0][0];
            ay[0] += this.Bz[0];
            square = ay[0] * ay[0] / detAplusB;
        }
        double logDensity = 0.5 * (this.logRootPriorPrecisionDeterminant - Math.log(detAplusB) - this.zBz + square);
        if (DEBUG) {
            System.err.println("(Ay+Bz)(A+B)^{-1}(Ay+Bz) = " + square);
            System.err.println("density = " + logDensity);
            System.err.println("zBz = " + this.zBz);
        }
        return logDensity;
    }

    /** Recomputes {@link #Bz} and {@link #zBz} from the current root prior mean and precision. */
    private void setRootPriorSumOfSquares() {
        this.Bz = new double[this.dimTrait];
        this.zBz = SemiConjugateMultivariateTraitLikelihood.computeWeightedAverageAndSumOfSquares(this.rootPriorMean, this.Bz, this.rootPriorPrecision, this.dimTrait, 1.0);
    }

    /**
     * Installs {@code rootPrior} and caches the quantities derived from it.
     *
     * @throws IllegalArgumentException if the determinant of the prior's scale matrix
     *     cannot be computed
     */
    private void setRootPrior(MultivariateNormalDistribution rootPrior) {
        this.rootPriorMean = rootPrior.getMean();
        this.rootPriorPrecision = rootPrior.getScaleMatrix();
        try {
            this.logRootPriorPrecisionDeterminant = Math.log(new Matrix(this.rootPriorPrecision).determinant());
        }
        catch (IllegalDimension e) {
            // Fail at construction. A log-determinant silently left at 0.0 would
            // bias every subsequent root likelihood.
            throw new IllegalArgumentException("Cannot compute determinant of the root prior precision matrix", e);
        }
        this.setRootPriorSumOfSquares();
    }

    /**
     * Computes the root mean and the inverse combined precision, given the data and the root prior.
     *
     * <p>With {@code A = precision * precisionScale} and {@code B = rootPriorPrecision},
     * {@code rootMean} is overwritten with {@code (A+B)^{-1} (Ay+Bz)}. The method
     * returns {@code (A+B)^{-1}}.
     *
     * <p>Side effects: overwrites the inherited {@code Ay} and {@code tmpM} buffers.
     *
     * @param rootMean on entry, the values passed to {@code computeWeightedAverageAndSumOfSquares};
     *     on exit, the combined mean
     * @param precision matrix that is scaled by {@code precisionScale} to form {@code A}
     * @param unused not read by this implementation
     * @param precisionScale scalar multiplier applied to {@code precision}
     * @return a freshly allocated {@code (A+B)^{-1}}
     */
    @Override
    protected double[][] computeMarginalRootMeanAndVariance(double[] rootMean, double[][] precision, double[][] unused, double precisionScale) {
        SemiConjugateMultivariateTraitLikelihood.computeWeightedAverageAndSumOfSquares(rootMean, this.Ay, precision, this.dimTrait, precisionScale);
        this.addRootPriorTerm(this.Ay);
        double[][] aPlusB = this.tmpM;
        this.fillPrecisionPlusPrior(aPlusB, precision, precisionScale);
        double[][] invAplusB = new Matrix(aPlusB).inverse().toComponents();
        for (int i = 0; i < this.dimTrait; ++i) {
            rootMean[i] = 0.0;
            for (int j = 0; j < this.dimTrait; ++j) {
                rootMean[i] += invAplusB[i][j] * this.Ay[j];
            }
        }
        return invAplusB;
    }

    /** Adds the root-prior vector {@link #Bz} to {@code ay} in place. */
    private void addRootPriorTerm(double[] ay) {
        for (int i = 0; i < this.dimTrait; ++i) {
            ay[i] += this.Bz[i];
        }
    }

    /** Fills {@code target} element-wise with {@code precision * precisionScale + rootPriorPrecision}. */
    private void fillPrecisionPlusPrior(double[][] target, double[][] precision, double precisionScale) {
        for (int i = 0; i < this.dimTrait; ++i) {
            for (int j = 0; j < this.dimTrait; ++j) {
                target[i][j] = precision[i][j] * precisionScale + this.rootPriorPrecision[i][j];
            }
        }
    }
}

