/*
 * Decompiled with CFR 0.152.
 */
package weka.classifiers.functions.pace;

import java.util.Arrays;
import java.util.Comparator;
import java.util.Random;
import weka.classifiers.functions.pace.MixtureDistribution;
import weka.classifiers.functions.pace.NormalMixture;
import weka.classifiers.functions.pace.PaceMatrix;
import weka.core.RevisionUtils;
import weka.core.matrix.DoubleVector;
import weka.core.matrix.Maths;

/**
 * Mixture of non-central chi-square distributions with one degree of freedom
 * (Chi^2_1(ncp)), used by the PACE estimators.
 *
 * <p>The inherited {@code mixingDistribution} supplies the support points (non-centralities,
 * via {@code getPointValues()}) and their weights (via {@code getFunctionValues()}). Several
 * computations are done on the square-root scale with unit-variance normal densities, since a
 * Chi^2_1(ncp) variable is the square of a N(sqrt(ncp), 1) variable.
 */
public class ChisqMixture
extends MixtureDistribution {
    /** Threshold passed to {@link NormalMixture#setSeparatingThreshold} in {@link #separable}. */
    protected double separatingThreshold = 0.05;
    /** Estimates less than or equal to this value are set to 0 by {@link #trim}. */
    protected double trimingThreshold = 0.5;
    /** Data values strictly above this become support points in {@link #supportPoints}. */
    protected double supportThreshold = 0.5;
    /** Upper bound on the number of support points; more causes an exception. */
    protected int maxNumSupportPoints = 200;
    /** Width (on the square-root scale) of the left fitting intervals. */
    protected int fittingIntervalLength = 3;
    /** Lower cut-off used when building fitting intervals. */
    protected double fittingIntervalThreshold = 0.5;

    /** @return the separating threshold */
    public double getSeparatingThreshold() {
        return this.separatingThreshold;
    }

    /** @param t the new separating threshold */
    public void setSeparatingThreshold(double t) {
        this.separatingThreshold = t;
    }

    /** @return the trimming threshold used by {@link #trim} */
    public double getTrimingThreshold() {
        return this.trimingThreshold;
    }

    /** @param t the new trimming threshold used by {@link #trim} */
    public void setTrimingThreshold(double t) {
        this.trimingThreshold = t;
    }

    /**
     * Decides separability by delegating to {@link NormalMixture#separable} on the square-root
     * scale: both {@code data} and {@code x} are square-rooted first.
     *
     * @param data the observations
     * @param i0 start index, passed through unchanged
     * @param i1 end index, passed through unchanged
     * @param x the split value (on the chi-square scale)
     * @return the normal-mixture separability decision
     */
    @Override
    public boolean separable(DoubleVector data, int i0, int i1, double x) {
        NormalMixture normal = new NormalMixture();
        normal.setSeparatingThreshold(this.separatingThreshold);
        return normal.separable(data.sqrt(), i0, i1, Math.sqrt(x));
    }

    /**
     * Builds the candidate support points of the mixing distribution.
     *
     * <p>The result starts with {@code 0.0} when the first data value is below
     * {@code supportThreshold} or {@code ne != 0}. It then contains every data value strictly
     * greater than {@code supportThreshold}, in data order. Reads {@code data.get(0)}, so
     * {@code data} must be non-empty. Presumably {@code data} is sorted ascending, so that
     * element 0 is the minimum — verify against callers.
     *
     * @param data the observations
     * @param ne a count which, when non-zero, forces 0 into the support
     * @return the support points
     * @throws IllegalArgumentException if more than {@code maxNumSupportPoints} points result
     */
    @Override
    public DoubleVector supportPoints(DoubleVector data, int ne) {
        DoubleVector sp = new DoubleVector();
        sp.setCapacity(data.size() + 1);
        if (data.get(0) < this.supportThreshold || ne != 0) {
            sp.addElement(0.0);
        }
        for (int i = 0; i < data.size(); ++i) {
            if (data.get(i) > this.supportThreshold) {
                sp.addElement(data.get(i));
            }
        }
        if (sp.size() > this.maxNumSupportPoints) {
            throw new IllegalArgumentException("Too many support points. ");
        }
        return sp;
    }

    /**
     * Builds the fitting intervals: two intervals per observation, one row {@code [left, right]}
     * each. The first {@code n} rows end at the observation; the last {@code n} start at it.
     *
     * @param data the observations (chi-square scale, assumed non-negative since they are square-rooted)
     * @return a {@code 2n x 2} matrix of interval end points
     */
    @Override
    public PaceMatrix fittingIntervals(DoubleVector data) {
        int n = data.size();
        PaceMatrix a = new PaceMatrix(n * 2, 2);
        DoubleVector v = data.sqrt();
        int count = 0;
        // Family 1: [(sqrt(x) - fittingIntervalLength)^2, x]. A sqrt-scale left end below the
        // threshold snaps to 0; the right end is raised to at least the threshold.
        for (int i = 0; i < n; ++i) {
            double left = v.get(i) - this.fittingIntervalLength;
            if (left < this.fittingIntervalThreshold) {
                left = 0.0;
            }
            left *= left;
            double right = data.get(i);
            if (right < this.fittingIntervalThreshold) {
                right = this.fittingIntervalThreshold;
            }
            a.set(count, 0, left);
            a.set(count, 1, right);
            ++count;
        }
        // Family 2: [x, (sqrt(x) + fittingIntervalThreshold)^2].
        // NOTE(review): this widens by fittingIntervalThreshold rather than fittingIntervalLength,
        // unlike family 1. It looks asymmetric; confirm the intent before changing, since the
        // change would alter the fitted mixing distribution.
        for (int i = 0; i < n; ++i) {
            double left = data.get(i);
            if (left < this.fittingIntervalThreshold) {
                left = 0.0;
            }
            double right = v.get(i) + this.fittingIntervalThreshold;
            right *= right;
            a.set(count, 0, left);
            a.set(count, 1, right);
            ++count;
        }
        a.setRowDimension(count);
        return a;
    }

    /**
     * Computes the probability of each fitting interval under each support point.
     * Entry {@code (i, j)} is {@code pchisq(right_i, s_j) - pchisq(left_i, s_j)}.
     *
     * @param s the support points (non-centralities)
     * @param intervals the fitting intervals, one {@code [left, right]} row each
     * @return an {@code intervals.rows x s.size()} probability matrix
     */
    @Override
    public PaceMatrix probabilityMatrix(DoubleVector s, PaceMatrix intervals) {
        int ns = s.size();
        int nr = intervals.getRowDimension();
        PaceMatrix p = new PaceMatrix(nr, ns);
        for (int i = 0; i < nr; ++i) {
            for (int j = 0; j < ns; ++j) {
                p.set(i, j, Maths.pchisq(intervals.get(i, 1), s.get(j)) - Maths.pchisq(intervals.get(i, 0), s.get(j)));
            }
        }
        return p;
    }

    /**
     * PACE6 estimate for a single observation.
     *
     * <p>Weights each support point by its mixing weight times the chi-square density at
     * {@code x}. It takes the weighted mean of the square roots of the support points and
     * squares it. Values above 100 are returned unchanged.
     *
     * @param x an observation on the chi-square scale
     * @return the estimate on the chi-square scale
     */
    public double pace6(double x) {
        if (x > 100.0) {
            return x;
        }
        DoubleVector points = this.mixingDistribution.getPointValues();
        DoubleVector values = this.mixingDistribution.getFunctionValues();
        DoubleVector mean = points.sqrt();
        DoubleVector d = Maths.dchisqLog(x, points);
        // Shift log-densities by their maximum before exponentiating, to avoid underflow.
        d.minusEquals(d.max());
        d = d.map("java.lang.Math", "exp").timesEquals(values);
        double atilde = mean.innerProduct(d) / d.sum();
        return atilde * atilde;
    }

    /**
     * Applies {@link #pace6(double)} element-wise, then {@link #trim}.
     *
     * @param x observations on the chi-square scale
     * @return a new vector of trimmed PACE6 estimates
     */
    public DoubleVector pace6(DoubleVector x) {
        DoubleVector pred = new DoubleVector(x.size());
        for (int i = 0; i < x.size(); ++i) {
            pred.set(i, this.pace6(x.get(i)));
        }
        this.trim(pred);
        return pred;
    }

    /**
     * PACE2 estimate.
     *
     * <p>Computes the cumulative sum of {@link #hf} over {@code x} and finds the index of its
     * maximum. Every element after that index is set to 0, and the result is then trimmed.
     * Presumably {@code x} is sorted in decreasing order (the demo passes {@code a.rev()}) —
     * confirm with callers.
     *
     * @param x observations on the chi-square scale
     * @return a new vector with the PACE2 estimates
     */
    public DoubleVector pace2(DoubleVector x) {
        DoubleVector chf = new DoubleVector(x.size());
        for (int i = 0; i < x.size(); ++i) {
            chf.set(i, this.hf(x.get(i)));
        }
        chf.cumulateInPlace();
        int index = chf.indexOfMax();
        DoubleVector copy = x.copy();
        if (index < x.size() - 1) {
            copy.set(index + 1, x.size() - 1, 0.0);
        }
        this.trim(copy);
        return copy;
    }

    /**
     * PACE4 estimate: each element whose {@link #h(double)} value is {@code <= 0} is set to 0,
     * and the result is then trimmed.
     *
     * @param x observations on the chi-square scale
     * @return a new vector with the PACE4 estimates
     */
    public DoubleVector pace4(DoubleVector x) {
        DoubleVector hv = this.h(x);
        DoubleVector copy = x.copy();
        for (int i = 0; i < x.size(); ++i) {
            if (hv.get(i) <= 0.0) {
                copy.set(i, 0.0);
            }
        }
        this.trim(copy);
        return copy;
    }

    /**
     * In place, sets every element that is {@code <= trimingThreshold} to 0.
     *
     * @param x the vector to trim (modified)
     */
    public void trim(DoubleVector x) {
        for (int i = 0; i < x.size(); ++i) {
            if (x.get(i) <= this.trimingThreshold) {
                x.set(i, 0.0);
            }
        }
    }

    /**
     * Normalised form of {@link #h(double)}.
     *
     * <p>Same numerator as {@code h}, but built from log-normal densities rescaled by a common
     * factor. It is divided by the sum of those rescaled densities at {@code +sqrt(AHat)} and
     * {@code -sqrt(AHat)}.
     *
     * @param AHat an observation on the chi-square scale
     * @return the normalised h value
     */
    public double hf(double AHat) {
        DoubleVector points = this.mixingDistribution.getPointValues();
        DoubleVector values = this.mixingDistribution.getFunctionValues();
        double x = Math.sqrt(AHat);
        DoubleVector mean = points.sqrt();
        DoubleVector d1 = Maths.dnormLog(x, mean, 1.0);
        double d1max = d1.max();
        d1.minusEquals(d1max);
        DoubleVector d2 = Maths.dnormLog(-x, mean, 1.0);
        // Both are shifted by the same constant, so the ratio below is unaffected by the rescaling.
        d2.minusEquals(d1max);
        d1 = d1.map("java.lang.Math", "exp");
        d1.timesEquals(values);
        d2 = d2.map("java.lang.Math", "exp");
        d2.timesEquals(values);
        return (points.minus(x / 2.0).innerProduct(d1) - points.plus(x / 2.0).innerProduct(d2)) / (d1.sum() + d2.sum());
    }

    /**
     * Computes the (unnormalised) h function used by PACE4.
     *
     * <p>With {@code a = sqrt(AHat)}, it returns
     * {@code sum_j w_j [(p_j - a/2) phi(a - sqrt(p_j)) - (p_j + a/2) phi(-a - sqrt(p_j))]}.
     * Here {@code phi} is the standard normal density, {@code p_j} are the support points and
     * {@code w_j} their weights. Returns 0 for {@code AHat == 0}.
     *
     * @param AHat an observation on the chi-square scale
     * @return the h value
     */
    public double h(double AHat) {
        if (AHat == 0.0) {
            return 0.0;
        }
        DoubleVector points = this.mixingDistribution.getPointValues();
        DoubleVector values = this.mixingDistribution.getFunctionValues();
        double aHat = Math.sqrt(AHat);
        DoubleVector aStar = points.sqrt();
        DoubleVector d1 = Maths.dnorm(aHat, aStar, 1.0).timesEquals(values);
        DoubleVector d2 = Maths.dnorm(-aHat, aStar, 1.0).timesEquals(values);
        return points.minus(aHat / 2.0).innerProduct(d1) - points.plus(aHat / 2.0).innerProduct(d2);
    }

    /**
     * Applies {@link #h(double)} element-wise.
     *
     * @param AHat observations on the chi-square scale
     * @return a new vector of h values
     */
    public DoubleVector h(DoubleVector AHat) {
        DoubleVector result = new DoubleVector(AHat.size());
        for (int i = 0; i < AHat.size(); ++i) {
            result.set(i, this.h(AHat.get(i)));
        }
        return result;
    }

    /**
     * Mixture density at {@code x}: the sum over support points of weight times chi-square
     * density.
     *
     * @param x a point on the chi-square scale
     * @return the mixture density
     */
    public double f(double x) {
        DoubleVector points = this.mixingDistribution.getPointValues();
        DoubleVector values = this.mixingDistribution.getFunctionValues();
        return Maths.dchisq(x, points).timesEquals(values).sum();
    }

    /**
     * Applies {@link #f(double)} element-wise.
     *
     * @param x points on the chi-square scale
     * @return a new vector of mixture densities, one per element of {@code x}
     */
    public DoubleVector f(DoubleVector x) {
        DoubleVector densities = new DoubleVector(x.size());
        for (int i = 0; i < x.size(); ++i) {
            // Fixed: previously evaluated h() on the (all-zero) output vector instead of f() on x.
            densities.set(i, this.f(x.get(i)));
        }
        return densities;
    }

    /** @return the string form of the fitted mixing distribution */
    @Override
    public String toString() {
        return this.mixingDistribution.toString();
    }

    /** @return the revision string */
    @Override
    public String getRevision() {
        return RevisionUtils.extract("$Revision: 1.5 $");
    }

    /**
     * Demo: fits the mixing distribution of 0.5 * Chi^2_1(ncp1) + 0.5 * Chi^2_1(ncp2) from
     * simulated data. It prints the PACE2/4/6 estimates and their quadratic losses.
     *
     * @param args ignored
     */
    public static void main(String[] args) {
        int n1 = 50;
        int n2 = 50;
        double ncp1 = 0.0;
        double ncp2 = 10.0;
        double mu1 = Math.sqrt(ncp1);
        double mu2 = Math.sqrt(ncp2);
        Random random = new Random();
        final DoubleVector rawNormal = Maths.rnorm(n1, mu1, 1.0, random).cat(Maths.rnorm(n2, mu2, 1.0, random));
        DoubleVector rawMeans = new DoubleVector(n1, mu1).cat(new DoubleVector(n2, mu2));

        // The estimators need the squared observations sorted. Sort an index permutation by
        // |observation| (the same order as by square) and apply it to the observations, their
        // true means and their squares. This keeps signs and means aligned with the predictions.
        int n = rawNormal.size();
        Integer[] order = new Integer[n];
        for (int i = 0; i < n; ++i) {
            order[i] = i;
        }
        Arrays.sort(order, new Comparator<Integer>() {
            @Override
            public int compare(Integer p, Integer q) {
                return Double.compare(Math.abs(rawNormal.get(p)), Math.abs(rawNormal.get(q)));
            }
        });
        DoubleVector aNormal = new DoubleVector(n);
        DoubleVector means = new DoubleVector(n);
        DoubleVector a = new DoubleVector(n);
        for (int i = 0; i < n; ++i) {
            double obs = rawNormal.get(order[i]);
            aNormal.set(i, obs);
            means.set(i, rawMeans.get(order[i]));
            a.set(i, obs * obs);
        }

        System.out.println("==========================================================");
        System.out.println("This is to test the estimation of the mixing\ndistribution of the mixture of non-central Chi-square\ndistributions. The example mixture used is of the form: \n\n   0.5 * Chi^2_1(ncp1) + 0.5 * Chi^2_1(ncp2)\n");
        System.out.println("It also tests the PACE estimators. Quadratic losses of the\nestimators are given, measuring their performance.");
        System.out.println("==========================================================");
        System.out.println("ncp1 = " + ncp1 + " ncp2 = " + ncp2 + "\n");
        System.out.println(String.valueOf(a.size()) + " observations are: \n\n" + a);
        System.out.println("\nQuadratic loss of the raw data (i.e., the MLE) = " + aNormal.sum2(means));
        System.out.println("==========================================================");
        ChisqMixture d = new ChisqMixture();
        d.fit(a, 1);
        System.out.println("The estimated mixing distribution is\n" + d);
        DoubleVector pred = d.pace2(a.rev()).rev();
        System.out.println("\nThe PACE2 Estimate = \n" + pred);
        System.out.println("Quadratic loss = " + pred.sqrt().times(aNormal.sign()).sum2(means));
        pred = d.pace4(a);
        System.out.println("\nThe PACE4 Estimate = \n" + pred);
        System.out.println("Quadratic loss = " + pred.sqrt().times(aNormal.sign()).sum2(means));
        pred = d.pace6(a);
        System.out.println("\nThe PACE6 Estimate = \n" + pred);
        System.out.println("Quadratic loss = " + pred.sqrt().times(aNormal.sign()).sum2(means));
    }
}

