/*
 * Decompiled with CFR 0.152.
 */
package dr.math.distributions;

import dr.math.GammaFunction;
import dr.math.distributions.MultivariateDistribution;

/**
 * Dirichlet distribution over vectors whose components add up to a fixed total.
 *
 * <p>The total is {@code 1.0} for the standard simplex, the number of elements when
 * the "sum to number of elements" mode is requested, or an arbitrary value supplied
 * by the caller. The log normalizing constant is computed once at construction as
 * {@code lnGamma(sum of counts) - sum_i lnGamma(counts[i]) - dim * log(total)}.
 *
 * <p>The concentration parameters are copied at construction, so later changes to the
 * array passed in do not affect this instance.
 */
public class DirichletDistribution
implements MultivariateDistribution {
    public static final String TYPE = "dirichletDistribution";
    public static final boolean DEBUG = false;
    /** Largest absolute deviation of the component sum from the required total that {@link #logPdf} accepts. */
    public static final double ACCURACY_THRESHOLD = 1.0E-12;

    /** Private copy of the concentration parameters. */
    private final double[] counts;
    /** Sum of all entries of {@link #counts}. */
    private final double countSum;
    /** Total that the components of an evaluated vector must add up to. */
    private final double countParameterSum;
    /** Number of components. */
    private final int dim;
    /** Records whether the instance was built in "sum to number of elements" mode. */
    private final boolean sumToNumberOfElements;
    /** Cached log normalizing constant (see class comment for the formula). */
    private final double logNormalizingConstant;

    /**
     * Creates a distribution whose components sum either to one or to the number of elements.
     *
     * @param dArray concentration parameters; the array is copied
     * @param bl     {@code true} if the components must sum to {@code dArray.length},
     *               {@code false} if they must sum to {@code 1.0}
     */
    public DirichletDistribution(double[] dArray, boolean bl) {
        this(dArray, bl ? (double)dArray.length : 1.0, bl);
    }

    /**
     * Creates a distribution whose components sum to an explicitly given total.
     *
     * @param dArray concentration parameters; the array is copied
     * @param d      total that the components of an evaluated vector must add up to
     */
    public DirichletDistribution(double[] dArray, double d) {
        this(dArray, d, false);
    }

    /** Shared initialisation for both public constructors. */
    private DirichletDistribution(double[] concentrations, double requiredTotal, boolean sumToNumberOfElements) {
        // Defensive copy: countSum and the normalizing constant are cached below, so an
        // aliased array mutated by the caller would silently disagree with them.
        this.counts = concentrations.clone();
        this.dim = this.counts.length;
        this.countParameterSum = requiredTotal;
        this.sumToNumberOfElements = sumToNumberOfElements;

        double total = 0.0;
        for (int i = 0; i < this.dim; ++i) {
            total += this.counts[i];
        }
        this.countSum = total;
        this.logNormalizingConstant =
                computeNormalizingConstant(this.counts, this.countSum, this.countParameterSum);
    }

    /**
     * Computes {@code lnGamma(countSum) - sum_i lnGamma(counts[i]) - dim * log(requiredTotal)}.
     */
    private static double computeNormalizingConstant(double[] counts, double countSum, double requiredTotal) {
        double logConstant = GammaFunction.lnGamma(countSum);
        for (int i = 0; i < counts.length; ++i) {
            logConstant -= GammaFunction.lnGamma(counts[i]);
        }
        logConstant -= (double)counts.length * Math.log(requiredTotal);
        return logConstant;
    }

    /**
     * Evaluates the log density at the given point.
     *
     * @param dArray point to evaluate; must have as many entries as there are concentration parameters
     * @return the log density, or {@link Double#NEGATIVE_INFINITY} if any component is negative or NaN,
     *         or if the components do not add up to the required total within {@link #ACCURACY_THRESHOLD}
     * @throws IllegalArgumentException if {@code dArray} has the wrong length
     */
    @Override
    public double logPdf(double[] dArray) {
        if (dArray.length != this.dim) {
            throw new IllegalArgumentException("data array is of the wrong dimension");
        }
        // Loop-invariant; hoisted out of the per-component loop.
        final double logTotal = Math.log(this.countParameterSum);
        double logDensity = this.logNormalizingConstant;
        double componentSum = 0.0;
        for (int i = 0; i < this.dim; ++i) {
            final double x = dArray[i];
            // Negative or NaN components are outside the support. Without this guard
            // Math.log would yield NaN and the sum check below could still pass.
            if (!(x >= 0.0)) {
                return Double.NEGATIVE_INFINITY;
            }
            componentSum += x;
            final double exponent = this.counts[i] - 1.0;
            // A zero exponent contributes nothing; skipping it avoids 0 * log(0) = NaN
            // when a component is exactly zero and its count is exactly one.
            if (exponent != 0.0) {
                logDensity += exponent * (Math.log(x) - logTotal);
            }
        }
        if (Math.abs(componentSum - this.countParameterSum) > ACCURACY_THRESHOLD) {
            return Double.NEGATIVE_INFINITY;
        }
        return logDensity;
    }

    /**
     * This class does not provide a scale matrix.
     *
     * @return always {@code null}
     */
    @Override
    public double[][] getScaleMatrix() {
        return null;
    }

    /**
     * Returns each concentration parameter divided by the sum of all concentration parameters.
     *
     * @return a newly allocated array of length equal to the number of components
     */
    @Override
    public double[] getMean() {
        final double[] mean = new double[this.dim];
        for (int i = 0; i < this.dim; ++i) {
            mean[i] = this.counts[i] / this.countSum;
        }
        return mean;
    }

    /**
     * @return the type identifier {@link #TYPE}
     */
    @Override
    public String getType() {
        return TYPE;
    }

    /**
     * Prints the log density at a few sample points for the standard and the scaled variant.
     *
     * @param stringArray ignored
     */
    public static void main(String[] stringArray) {
        System.out.println("Test Dirichlet distribution for the standard n-simplex");
        double[] dArray = new double[]{1.0, 2.0, 3.0};
        DirichletDistribution dirichletDistribution = new DirichletDistribution(dArray, false);
        double[] dArray2 = new double[]{0.5, 0.2, 0.3};
        System.out.println(dirichletDistribution.logPdf(dArray2));
        System.out.println("Test Scaled Dirichlet distribution");
        dirichletDistribution = new DirichletDistribution(dArray, true);
        dArray2[0] = 1.5;
        dArray2[1] = 0.6;
        dArray2[2] = 0.9;
        System.out.println(dirichletDistribution.logPdf(dArray2));
        dArray2[0] = 1.0;
        dArray2[1] = 1.0;
        dArray2[2] = 1.0;
        System.out.println(dirichletDistribution.logPdf(dArray2));
        dArray = new double[]{1.0, 1.0, 1.0, 1.0};
        dirichletDistribution = new DirichletDistribution(dArray, true);
        dArray2 = new double[]{0.5, 1.2, 1.3, 1.0};
        System.out.println(dirichletDistribution.logPdf(dArray2));
        dArray2[0] = 1.0;
        dArray2[1] = 1.0;
        dArray2[2] = 1.0;
        dArray2[3] = 1.0;
        System.out.println(dirichletDistribution.logPdf(dArray2));
    }
}

