/*
 * Decompiled with CFR 0.152.
 */
package dr.evomodel.bigfasttree.thorney;

import dr.evolution.tree.MutableTreeModel;
import dr.evolution.tree.NodeRef;
import dr.evomodel.bigfasttree.thorney.MutationList;
import dr.evomodel.bigfasttree.thorney.ThorneyDataLikelihoodDelegate;
import dr.evomodel.branchratemodel.BranchRateModel;
import dr.evomodel.tree.TreeModel;
import dr.evomodel.tree.TreeParameterModel;
import dr.evomodel.treedatalikelihood.TreeDataLikelihood;
import dr.evomodel.treedatalikelihood.discrete.NodeHeightProxyParameter;
import dr.inference.hmc.GradientWrtParameterProvider;
import dr.inference.model.Likelihood;
import dr.inference.model.Parameter;
import dr.xml.Reportable;

/**
 * Gradient of a Thorney {@link TreeDataLikelihood} with respect to the heights of the internal
 * nodes of its tree (one entry per internal node, the root included).
 *
 * <p>The gradient is first evaluated per branch through the likelihood's
 * {@link ThorneyDataLikelihoodDelegate} and then mapped onto node heights by the chain rule.
 */
public class ThorneyTreeGradient implements GradientWrtParameterProvider, Reportable {

    /** Tolerance handed to the numerical gradient check in {@link #getReport()}. */
    private static final double TOLERANCE = 0.001;

    private final TreeDataLikelihood likelihood;
    private final NodeHeightProxyParameter nodeHeightProxyParameter;
    private final TreeModel tree;
    /** Used only for its node-number <-> parameter-index mapping over the branch buffer. */
    private final TreeParameterModel indexHelper;
    /** Per-branch gradient scratch buffer, indexed by {@link #indexHelper} parameter index. */
    private final double[] branchGradient;
    private final ThorneyDataLikelihoodDelegate dataLikelihoodDelegate;
    private final BranchRateModel branchRateModel;

    /**
     * Creates a gradient provider for the given likelihood.
     *
     * @param treeDataLikelihood likelihood whose tree must be a {@link TreeModel} and whose
     *     data-likelihood delegate must be a {@link ThorneyDataLikelihoodDelegate}; otherwise the
     *     casts below throw {@link ClassCastException}
     */
    public ThorneyTreeGradient(TreeDataLikelihood treeDataLikelihood) {
        this.likelihood = treeDataLikelihood;
        this.tree = (TreeModel) treeDataLikelihood.getTree();
        this.dataLikelihoodDelegate =
                (ThorneyDataLikelihoodDelegate) treeDataLikelihood.getDataLikelihoodDelegate();
        // NOTE(review): the boolean presumably means "include root"; getGradientLogDensity returns
        // one entry per internal node including the root, so the two must agree — confirm.
        this.nodeHeightProxyParameter = new NodeHeightProxyParameter(
                "ThorneyTreeGradient.NodeHeightProxyParameter", this.tree, true);
        // One slot per branch, i.e. per node except the root.
        this.branchGradient = new double[this.tree.getNodeCount() - 1];
        // NOTE(review): 'false' presumably excludes the root, matching the nodeCount - 1 buffer.
        this.indexHelper = new TreeParameterModel(
                (MutableTreeModel) this.tree,
                (Parameter) new Parameter.Default(this.branchGradient),
                false);
        this.branchRateModel = treeDataLikelihood.getBranchRateModel();
    }

    @Override
    public Likelihood getLikelihood() {
        return this.likelihood;
    }

    /** Returns the node-height proxy parameter this gradient is taken with respect to. */
    @Override
    public Parameter getParameter() {
        return this.nodeHeightProxyParameter;
    }

    @Override
    public int getDimension() {
        return this.nodeHeightProxyParameter.getDimension();
    }

    /**
     * Fills {@link #branchGradient}: for every non-root node, stores the value returned by the
     * branch-length likelihood delegate's {@code getGradientWrtTime} for the mutations on the
     * branch above that node, the branch length, and the branch rate.
     */
    private void calculateBranchGradient() {
        final int branchCount = this.tree.getNodeCount() - 1;
        for (int i = 0; i < branchCount; ++i) {
            final NodeRef node = this.tree.getNode(this.indexHelper.getNodeNumberFromParameterIndex(i));
            final double branchLength = this.tree.getBranchLength(node);
            final MutationList mutations = this.dataLikelihoodDelegate.getMutationMap().getMutations(node);
            final double rate = this.branchRateModel.getBranchRate(this.tree, node);
            this.branchGradient[i] = this.dataLikelihoodDelegate.getBranchLengthLikelihoodDelegate()
                    .getGradientWrtTime(mutations, branchLength, rate);
        }
    }

    /**
     * Computes the gradient with respect to every internal node height.
     *
     * <p>Raising a node's height lengthens the branches to its children and shortens the branch
     * above it, so each entry is the sum of the children's branch gradients minus the node's own
     * branch gradient (the root has no branch above it).
     *
     * @return a new array of length {@code tree.getInternalNodeCount()}; entry {@code i} belongs
     *     to node number {@code i + tree.getExternalNodeCount()}
     */
    @Override
    public double[] getGradientLogDensity() {
        calculateBranchGradient();

        final int internalNodeCount = this.tree.getInternalNodeCount();
        // NOTE(review): assumes internal nodes are numbered from externalNodeCount upward, in the
        // same order NodeHeightProxyParameter uses for its dimensions — confirm.
        final int firstInternalNodeNumber = this.tree.getExternalNodeCount();
        final double[] gradient = new double[internalNodeCount];

        for (int i = 0; i < internalNodeCount; ++i) {
            final NodeRef node = this.tree.getNode(i + firstInternalNodeNumber);

            double sum = 0.0;
            for (int c = 0; c < this.tree.getChildCount(node); ++c) {
                final NodeRef child = this.tree.getChild(node, c);
                sum += this.branchGradient[this.indexHelper.getParameterIndexFromNodeNumber(child.getNumber())];
            }
            if (!this.tree.isRoot(node)) {
                sum -= this.branchGradient[this.indexHelper.getParameterIndexFromNodeNumber(node.getNumber())];
            }
            gradient[i] = sum;
        }
        return gradient;
    }

    /**
     * Reports the analytic gradient checked against a numerical one. The {@code 0.0} and
     * {@code +Infinity} arguments are presumably the parameter's lower/upper bounds — confirm.
     */
    @Override
    public String getReport() {
        return GradientWrtParameterProvider.getReportAndCheckForError(
                this, 0.0, Double.POSITIVE_INFINITY, TOLERANCE);
    }
}

