/*
 * Decompiled with CFR 0.152.
 */
package dr.evomodel.treedatalikelihood.continuous;

import dr.evolution.tree.NodeRef;
import dr.evolution.tree.Tree;
import dr.evolution.tree.TreeTrait;
import dr.evomodel.treedatalikelihood.continuous.ContinuousDataLikelihoodDelegate;
import dr.evomodel.treedatalikelihood.continuous.ContinuousTraitDataModel;
import dr.evomodel.treedatalikelihood.continuous.ContinuousTraitPartialsProvider;
import dr.evomodel.treedatalikelihood.continuous.FullPrecisionContinuousTraitPartialsProvider;
import dr.evomodel.treedatalikelihood.continuous.cdi.PrecisionType;
import dr.evomodel.treedatalikelihood.preorder.ContinuousExtensionDelegate;
import dr.evomodel.treedatalikelihood.preorder.ModelExtensionProvider;
import dr.inference.model.CompoundParameter;
import dr.inference.model.MatrixParameterInterface;
import dr.inference.model.Parameter;
import dr.inference.model.Variable;
import dr.math.matrixAlgebra.Matrix;
import dr.math.matrixAlgebra.WrappedVector;
import dr.math.matrixAlgebra.missingData.InversionResult;
import dr.math.matrixAlgebra.missingData.MissingOps;
import java.util.ArrayList;
import java.util.Arrays;
import org.ejml.data.D1Matrix64F;
import org.ejml.data.DenseMatrix64F;
import org.ejml.ops.CommonOps;

/**
 * Continuous-trait data model that adds a sampling variance (the inverse of the matrix held in
 * {@code samplingPrecisionParameter}) to the tip partials produced by a child
 * {@link ContinuousTraitPartialsProvider}.
 *
 * <p>The child model may supply several repeated measurements per tip: its tip partial holds
 * {@code nRepeats} consecutive blocks of {@code precisionType.getPartialsDimension(dimTrait)}
 * entries. {@link #getTipPartial(int, boolean)} adds the sampling variance to every repeat that
 * carries information and merges those repeats into one partial. The constructor rejects
 * {@code n2 > 1} and {@code getTipPartial} asserts {@code numTraits == 1}.
 */
public class RepeatedMeasuresTraitDataModel
extends ContinuousTraitDataModel
implements FullPrecisionContinuousTraitPartialsProvider,
ModelExtensionProvider.NormalExtensionProvider {
    private final String traitName;
    private final MatrixParameterInterface samplingPrecisionParameter;
    // Never assigned true in this class; when true, the single-repeat path only uses the
    // diagonal of the sampling precision.
    private boolean diagonalOnly = false;
    // True when samplingPrecisionParameter changed and samplingPrecision must be re-read.
    private boolean variableChanged = true;
    // True when samplingVariance is the inverse of the current samplingPrecision.
    private boolean varianceKnown = false;
    private Matrix samplingPrecision;
    // Lazily computed inverse of samplingPrecision; null until first requested.
    private Matrix samplingVariance;
    private Matrix storedSamplingPrecision;
    private Matrix storedSamplingVariance;
    private boolean storedVarianceKnown = false;
    private boolean storedVariableChanged = true;
    private boolean[] missingTraitIndicators = null;
    private final ContinuousTraitPartialsProvider childModel;
    // Number of repeated-measurement partials the child model supplies per tip.
    private final int nRepeats;
    // For each tip, the indices of repeats whose precision diagonal has a positive entry.
    private final ArrayList<Integer>[] relevantRepeats;
    // Total number of relevant repeats across all tips.
    private final int nObservedTips;
    private static final double LOG2PI = Math.log(Math.PI * 2);

    /**
     * Creates the model and records, for every tip, which repeats of the child partial carry data.
     *
     * @param string trait name
     * @param continuousTraitPartialsProvider child model supplying per-repeat tip partials
     * @param compoundParameter forwarded to the superclass; its parameter count is the tip count
     * @param blArray forwarded to the superclass
     * @param bl forwarded to the superclass
     * @param n forwarded to the superclass
     * @param n2 forwarded to the superclass; values above 1 are rejected
     * @param matrixParameterInterface sampling precision matrix (bounded below by 0 here)
     * @param precisionType layout used to read and write partials
     * @throws RuntimeException if {@code n2 > 1}
     */
    public RepeatedMeasuresTraitDataModel(String string, ContinuousTraitPartialsProvider continuousTraitPartialsProvider, CompoundParameter compoundParameter, boolean[] blArray, boolean bl, int n, int n2, MatrixParameterInterface matrixParameterInterface, PrecisionType precisionType) {
        super(string, compoundParameter, blArray, bl, n, n2, precisionType);
        if (n2 > 1) {
            throw new RuntimeException("not currently implemented");
        }
        this.childModel = continuousTraitPartialsProvider;
        this.traitName = string;
        this.samplingPrecisionParameter = matrixParameterInterface;
        this.nRepeats = continuousTraitPartialsProvider.getTraitCount() / n2;
        this.addVariable(matrixParameterInterface);
        this.calculatePrecisionInfo();
        this.samplingVariance = null;
        this.samplingPrecisionParameter.addBounds(new Parameter.DefaultBounds(Double.POSITIVE_INFINITY, 0.0, matrixParameterInterface.getDimension()));

        int partialsDim = precisionType.getPartialsDimension(this.dimTrait);
        int precisionOffset = precisionType.getPrecisionOffset(this.dimTrait);
        int tipCount = this.getParameter().getParameterCount();
        // Arrays of a parameterized type cannot be created directly; each slot is filled below.
        @SuppressWarnings("unchecked")
        ArrayList<Integer>[] repeatsPerTip = new ArrayList[tipCount];
        int observedCount = 0;
        for (int tip = 0; tip < tipCount; ++tip) {
            repeatsPerTip[tip] = new ArrayList<>();
            double[] childPartial = continuousTraitPartialsProvider.getTipPartial(tip, false);
            for (int repeat = 0; repeat < this.nRepeats; ++repeat) {
                int offset = precisionOffset + repeat * partialsDim;
                DenseMatrix64F precision = MissingOps.wrap(childPartial, offset, this.dimTrait, this.dimTrait);
                if (hasPositiveDiagonal(precision, this.dimTrait)) {
                    repeatsPerTip[tip].add(repeat);
                    ++observedCount;
                }
            }
        }
        this.relevantRepeats = repeatsPerTip;
        this.nObservedTips = observedCount;
    }

    /**
     * Returns true if any of the first {@code dim} diagonal entries is strictly positive
     * (NaN is treated as not positive).
     */
    private static boolean hasPositiveDiagonal(DenseMatrix64F matrix, int dim) {
        for (int k = 0; k < dim; ++k) {
            if (matrix.get(k, k) > 0.0) {
                return true;
            }
        }
        return false;
    }

    /**
     * Returns the tip partial for tip {@code n} with the sampling variance included.
     *
     * @param n tip index
     * @param bl must be false
     * @return for a single repeat, the child's partial updated in place; otherwise a new partial
     *     combining all relevant repeats of the tip
     * @throws RuntimeException if {@code bl} is true
     */
    @Override
    public double[] getTipPartial(int n, boolean bl) {
        assert (this.numTraits == 1);
        assert (this.samplingPrecision.rows() == this.dimTrait && this.samplingPrecision.columns() == this.dimTrait);
        this.recomputeVariance();
        if (bl) {
            throw new RuntimeException("Incompatible with this model.");
        }
        double[] childPartial = this.childModel.getTipPartial(n, bl);
        if (this.nRepeats == 1) {
            return this.addSamplingVarianceToSinglePartial(childPartial);
        }
        return this.combineRepeatedPartials(childPartial, this.relevantRepeats[n]);
    }

    /**
     * Adds the sampling variance to a single-repeat partial in place and refreshes its precision.
     *
     * <p>SCALAR partials are returned unchanged. Otherwise the variance block at offset
     * {@code dimTrait + dimTrait^2} is read, augmented, and inverted into the precision block at
     * offset {@code dimTrait}. Both blocks are then written back into {@code partial}.
     * NOTE(review): these hard-coded offsets presumably match the FULL partial layout; confirm for
     * other non-SCALAR precision types. The child's array is modified in place; confirm the
     * child returns a fresh array.
     */
    private double[] addSamplingVarianceToSinglePartial(double[] partial) {
        if (this.precisionType == PrecisionType.SCALAR) {
            return partial;
        }
        final int dim = this.dimTrait;
        final int precisionOffset = dim;
        final int varianceOffset = dim + dim * dim;
        DenseMatrix64F variance = MissingOps.wrap(partial, varianceOffset, dim, dim);
        if (this.diagonalOnly) {
            for (int i = 0; i < dim; ++i) {
                variance.set(i, i, variance.get(i, i) + 1.0 / this.samplingPrecision.component(i, i));
            }
        } else {
            for (int i = 0; i < dim; ++i) {
                for (int j = 0; j < dim; ++j) {
                    variance.set(i, j, variance.get(i, j) + this.samplingVariance.component(i, j));
                }
            }
        }
        DenseMatrix64F precision = new DenseMatrix64F(dim, dim);
        MissingOps.safeInvert2(variance, precision, false);
        MissingOps.unwrap(precision, partial, precisionOffset);
        MissingOps.unwrap(variance, partial, varianceOffset);
        return partial;
    }

    /**
     * Merges the given repeats of one tip into a single partial.
     *
     * <p>For each repeat, the sampling variance is added to the repeat's variance on the
     * dimensions whose variance diagonal is finite. The result is inverted to a precision and
     * summed into the combined precision. It also accumulates a precision-weighted mean sum and a
     * log-remainder term. The combined mean is obtained by solving against the summed precision.
     * The stored remainder is half the accumulated value.
     *
     * @param childPartial concatenated per-repeat partials from the child model (read only)
     * @param repeats indices of the repeats to combine
     * @return a new partial of length {@code precisionType.getPartialsDimension(dimTrait)}
     */
    private double[] combineRepeatedPartials(double[] childPartial, ArrayList<Integer> repeats) {
        final int dim = this.dimTrait;
        int partialsDim = this.precisionType.getPartialsDimension(dim);
        int varianceOffset = this.precisionType.getVarianceOffset(dim);
        int meanOffset = this.precisionType.getMeanOffset(dim);
        int varianceLength = this.precisionType.getVarianceLength(dim);
        int remainderOffset = this.precisionType.getRemainderOffset(dim);

        DenseMatrix64F repeatPrecision = new DenseMatrix64F(dim, dim);
        DenseMatrix64F repeatVariance = new DenseMatrix64F(dim, dim);
        DenseMatrix64F totalPrecision = new DenseMatrix64F(dim, dim);
        DenseMatrix64F totalVariance = new DenseMatrix64F(dim, dim);
        DenseMatrix64F weightedMeanSum = new DenseMatrix64F(dim, 1);
        DenseMatrix64F combinedMean = new DenseMatrix64F(dim, 1);

        double logRemainder = 0.0;
        for (int repeat : repeats) {
            int repeatStart = partialsDim * repeat;
            // Buffers are reused across repeats; only varianceLength entries are overwritten here.
            System.arraycopy(childPartial, repeatStart + varianceOffset, repeatVariance.data, 0, varianceLength);
            this.addSamplingVarianceOnFiniteDimensions(repeatVariance);
            InversionResult repeatInversion = MissingOps.safeInvert2(repeatVariance, repeatPrecision, true);
            CommonOps.addEquals(totalPrecision, repeatPrecision);

            int repeatMeanStart = repeatStart + meanOffset;
            double quadraticForm = 0.0;
            for (int i = 0; i < dim; ++i) {
                double meanI = childPartial[repeatMeanStart + i];
                double rowSum = 0.0;
                for (int j = 0; j < dim; ++j) {
                    double term = repeatPrecision.get(i, j) * childPartial[repeatMeanStart + j];
                    rowSum += term;
                    quadraticForm += term * meanI;
                }
                weightedMeanSum.add(i, 0, rowSum);
            }
            logRemainder += childPartial[repeatStart + remainderOffset];
            logRemainder -= (double) repeatInversion.getEffectiveDimension() * LOG2PI + quadraticForm + repeatInversion.getLogDeterminant();
        }

        MissingOps.safeSolve(totalPrecision, weightedMeanSum, combinedMean, false);
        InversionResult totalInversion = MissingOps.safeInvertPrecision(totalPrecision, totalVariance, true);
        if (totalInversion.getReturnCode() == InversionResult.Code.NOT_OBSERVED) {
            // The summed precision is reported as not observed: the remainder is reset to zero.
            logRemainder = 0.0;
        } else {
            double meanQuadraticForm = 0.0;
            for (int i = 0; i < dim; ++i) {
                for (int j = 0; j < dim; ++j) {
                    meanQuadraticForm += combinedMean.get(i, 0) * combinedMean.get(j, 0) * totalPrecision.get(i, j);
                }
            }
            logRemainder += (double) totalInversion.getEffectiveDimension() * LOG2PI + meanQuadraticForm - totalInversion.getLogDeterminant();
        }

        double[] combined = new double[partialsDim];
        System.arraycopy(combinedMean.data, 0, combined, meanOffset, dim);
        System.arraycopy(totalPrecision.data, 0, combined, this.precisionType.getPrecisionOffset(dim), varianceLength);
        System.arraycopy(totalVariance.data, 0, combined, varianceOffset, varianceLength);
        this.precisionType.fillRemainderInPartials(combined, 0, 0.5 * logRemainder, dim);
        return combined;
    }

    /**
     * Adds the sampling variance to {@code variance} in place, restricted to rows and columns
     * whose diagonal entry is finite ({@code < +Infinity}; NaN counts as non-finite). Other
     * entries are left untouched. Each updated lower-triangle entry is mirrored into the upper
     * triangle.
     */
    private void addSamplingVarianceOnFiniteDimensions(DenseMatrix64F variance) {
        for (int i = 0; i < this.dimTrait; ++i) {
            if (!(variance.get(i, i) < Double.POSITIVE_INFINITY)) {
                continue;
            }
            variance.set(i, i, variance.get(i, i) + this.samplingVariance.component(i, i));
            for (int j = 0; j < i; ++j) {
                if (!(variance.get(j, j) < Double.POSITIVE_INFINITY)) {
                    continue;
                }
                variance.set(i, j, variance.get(i, j) + this.samplingVariance.component(i, j));
                variance.set(j, i, variance.get(i, j));
            }
        }
    }

    /**
     * Returns an all-true indicator array sized to the trait parameter's dimension, or null when
     * {@code getDataMissingIndicators()} is null. The array is created once and then cached.
     */
    @Override
    public boolean[] getTraitMissingIndicators() {
        if (this.getDataMissingIndicators() == null) {
            return null;
        }
        if (this.missingTraitIndicators == null) {
            this.missingTraitIndicators = new boolean[this.getParameter().getDimension()];
            Arrays.fill(this.missingTraitIndicators, true);
        }
        return this.missingTraitIndicators;
    }

    /** Refreshes the cached precision if needed, then inverts it into samplingVariance if stale. */
    private void recomputeVariance() {
        this.checkVariableChanged();
        if (!this.varianceKnown) {
            this.samplingVariance = this.samplingPrecision.inverse();
            this.varianceKnown = true;
        }
    }

    /** Returns the sampling variance, recomputing it first if the precision changed. */
    public Matrix getSamplingVariance() {
        this.recomputeVariance();
        return this.samplingVariance;
    }

    public String getTraitName() {
        return this.traitName;
    }

    /** Marks the cached precision and variance stale when the sampling precision changes. */
    @Override
    protected void handleVariableChangedEvent(Variable variable, int n, Variable.ChangeType changeType) {
        super.handleVariableChangedEvent(variable, n, changeType);
        if (variable == this.samplingPrecisionParameter) {
            this.variableChanged = true;
            this.varianceKnown = false;
            this.fireModelChanged();
        }
    }

    /** Copies the current sampling precision parameter into a {@link Matrix}. */
    private void calculatePrecisionInfo() {
        this.samplingPrecision = new Matrix(this.samplingPrecisionParameter.getParameterAsMatrix());
    }

    /** Re-reads the precision if the parameter changed, invalidating the cached variance. */
    private void checkVariableChanged() {
        if (this.variableChanged) {
            this.calculatePrecisionInfo();
            this.variableChanged = false;
            this.varianceKnown = false;
        }
    }

    /**
     * Snapshots the cached precision, variance and their flags. The variance is computed lazily
     * and may still be null, so it is only cloned when present.
     */
    @Override
    protected void storeState() {
        this.storedSamplingPrecision = this.samplingPrecision.clone();
        this.storedSamplingVariance = this.samplingVariance == null ? null : this.samplingVariance.clone();
        this.storedVarianceKnown = this.varianceKnown;
        this.storedVariableChanged = this.variableChanged;
    }

    /** Swaps the current and stored caches (rather than copying) and restores the flags. */
    @Override
    protected void restoreState() {
        Matrix swap = this.samplingPrecision;
        this.samplingPrecision = this.storedSamplingPrecision;
        this.storedSamplingPrecision = swap;
        swap = this.samplingVariance;
        this.samplingVariance = this.storedSamplingVariance;
        this.storedSamplingVariance = swap;
        this.varianceKnown = this.storedVarianceKnown;
        this.variableChanged = this.storedVariableChanged;
    }

    @Override
    public ContinuousExtensionDelegate getExtensionDelegate(ContinuousDataLikelihoodDelegate continuousDataLikelihoodDelegate, TreeTrait treeTrait, Tree tree) {
        this.checkVariableChanged();
        return new ContinuousExtensionDelegate.MultivariateNormalExtensionDelegate(continuousDataLikelihoodDelegate, treeTrait, this, tree);
    }

    @Override
    public boolean diagonalVariance() {
        return false;
    }

    /** Returns the current sampling variance wrapped as a {@code dimTrait x dimTrait} matrix. */
    @Override
    public DenseMatrix64F getExtensionVariance() {
        this.recomputeVariance();
        double[] components = this.samplingVariance.toArrayComponents();
        return DenseMatrix64F.wrap(this.dimTrait, this.dimTrait, components);
    }

    /** Same as {@link #getExtensionVariance()}; the variance does not depend on the node. */
    @Override
    public DenseMatrix64F getExtensionVariance(NodeRef nodeRef) {
        return this.getExtensionVariance();
    }

    @Override
    public MatrixParameterInterface getExtensionPrecision() {
        return this.getExtensionPrecisionParameter();
    }

    /** Copies {@code denseMatrix64F} into {@code denseMatrix64F2} (scaling by 1.0). */
    public void getMeanTipVariances(DenseMatrix64F denseMatrix64F, DenseMatrix64F denseMatrix64F2) {
        CommonOps.scale(1.0, denseMatrix64F, denseMatrix64F2);
    }

    @Override
    public MatrixParameterInterface getExtensionPrecisionParameter() {
        this.checkVariableChanged();
        return this.samplingPrecisionParameter;
    }

    @Override
    public int getDataDimension() {
        return this.dimTrait;
    }

    @Override
    public boolean suppliesWishartStatistics() {
        return false;
    }

    /** No-op in this model. */
    @Override
    public void chainRuleWrtVariance(double[] dArray, NodeRef nodeRef) {
    }

    @Override
    public ContinuousTraitPartialsProvider[] getChildModels() {
        return new ContinuousTraitPartialsProvider[]{this.childModel};
    }

    /**
     * Draws one value per relevant repeat of every tip, conditional on the child data and on the
     * tip traits in {@code dArray} ({@code dimTrait} values per tip, in tip order).
     *
     * <p>If no diagonal entry of a repeat's data precision is below +Infinity, the repeat's data
     * mean is copied directly. Otherwise a normal draw is taken whose variance inverts the sum of
     * the data precision and the sampling precision. Its mean comes from
     * {@code MissingOps.safeWeightedAverage}.
     *
     * @return {@code nObservedTips * dimTrait} values, ordered by tip then repeat
     * @throws RuntimeException if more than one trait is modelled
     */
    @Override
    public double[] drawTraitsBelowConditionalOnDataAndTraitsAbove(double[] dArray) {
        if (this.numTraits > 1) {
            throw new RuntimeException("not yet implemented");
        }
        final int dim = this.dimTrait;
        double[] draws = new double[this.nObservedTips * dim];
        int tipCount = this.getParameter().getParameterCount();
        DenseMatrix64F samplingPrecisionMatrix = DenseMatrix64F.wrap(dim, dim, this.samplingPrecisionParameter.getParameterValues());
        DenseMatrix64F conditionalPrecision = new DenseMatrix64F(dim, dim);
        DenseMatrix64F conditionalVariance = new DenseMatrix64F(dim, dim);
        DenseMatrix64F dataPrecision = new DenseMatrix64F(dim, dim);
        int[] identityIndices = new int[dim];
        for (int k = 0; k < dim; ++k) {
            identityIndices[k] = k;
        }
        WrappedVector.Raw conditionalMean = new WrappedVector.Raw(new double[dim]);
        int precisionOffset = this.precisionType.getPrecisionOffset(dim);
        int meanOffset = this.precisionType.getMeanOffset(dim);
        int partialsDim = this.precisionType.getPartialsDimension(dim);
        int precisionLength = this.precisionType.getPrecisionLength(dim);

        int aboveOffset = 0;
        int drawOffset = 0;
        for (int tip = 0; tip < tipCount; ++tip) {
            double[] childPartial = this.childModel.getTipPartial(tip, false);
            WrappedVector.Indexed traitAbove = new WrappedVector.Indexed(dArray, aboveOffset, identityIndices);
            for (int repeat : this.relevantRepeats[tip]) {
                int repeatStart = repeat * partialsDim;
                System.arraycopy(childPartial, repeatStart + precisionOffset, dataPrecision.data, 0, precisionLength);
                WrappedVector.Indexed dataMean = new WrappedVector.Indexed(childPartial, repeatStart + meanOffset, identityIndices);

                boolean noFiniteDiagonal = true;
                for (int k = 0; k < dim; ++k) {
                    if (dataPrecision.get(k, k) < Double.POSITIVE_INFINITY) {
                        noFiniteDiagonal = false;
                        break;
                    }
                }

                if (noFiniteDiagonal) {
                    for (int k = 0; k < dim; ++k) {
                        draws[drawOffset + k] = dataMean.get(k);
                    }
                } else {
                    CommonOps.add((D1Matrix64F) dataPrecision, samplingPrecisionMatrix, (D1Matrix64F) conditionalPrecision);
                    MissingOps.safeInvert2(conditionalPrecision, conditionalVariance, false);
                    MissingOps.safeWeightedAverage(dataMean, dataPrecision, traitAbove, samplingPrecisionMatrix, conditionalMean, conditionalVariance, dim);
                    double[] draw = MissingOps.nextPossiblyDegenerateNormal(conditionalMean, conditionalVariance);
                    System.arraycopy(draw, 0, draws, drawOffset, dim);
                }
                drawOffset += dim;
            }
            aboveOffset += dim;
        }
        return draws;
    }

    /**
     * Expands tree traits to repeat level. For each tip, its {@code dimTrait} values in
     * {@code dArray} are copied once per relevant repeat of that tip.
     *
     * @return {@code nObservedTips * dimTrait} values, ordered by tip then repeat
     */
    @Override
    public double[] transformTreeTraits(double[] dArray) {
        double[] expanded = new double[this.dimTrait * this.nObservedTips];
        int sourceOffset = 0;
        int targetOffset = 0;
        for (ArrayList<Integer> repeats : this.relevantRepeats) {
            for (int r = 0; r < repeats.size(); ++r) {
                System.arraycopy(dArray, sourceOffset, expanded, targetOffset, this.dimTrait);
                targetOffset += this.dimTrait;
            }
            sourceOffset += this.dimTrait;
        }
        return expanded;
    }

    @Override
    public void updateTipDataGradient(DenseMatrix64F denseMatrix64F, DenseMatrix64F denseMatrix64F2, NodeRef nodeRef, int n, int n2) {
        ModelExtensionProvider.NormalExtensionProvider.extendTipDataGradient(this, denseMatrix64F, denseMatrix64F2, nodeRef, n, n2);
    }

    @Override
    public boolean needToUpdateTipDataGradient(int n, int n2) {
        return true;
    }
}

