/*
 * Decompiled with CFR 0.152.
 */
package dr.inference.operators.factorAnalysis;

import dr.inference.distribution.DistributionLikelihood;
import dr.inference.distribution.NormalDistributionModel;
import dr.inference.distribution.NormalStatisticsProvider;
import dr.inference.operators.GibbsOperator;
import dr.inference.operators.SimpleMCMCOperator;
import dr.inference.operators.factorAnalysis.FactorAnalysisOperatorAdaptor;
import dr.inference.operators.factorAnalysis.FactorAnalysisStatisticsProvider;
import dr.inference.operators.factorAnalysis.LoadingsSamplerConstraints;
import dr.math.MathUtils;
import dr.math.distributions.MultivariateNormalDistribution;
import dr.math.matrixAlgebra.CholeskyDecomposition;
import dr.math.matrixAlgebra.IllegalDimension;
import dr.math.matrixAlgebra.Matrix;
import dr.math.matrixAlgebra.SymmetricMatrix;
import dr.math.matrixAlgebra.Vector;
import dr.xml.Reportable;
import java.util.ArrayList;
import java.util.Iterator;
import java.util.List;
import java.util.Locale;
import java.util.concurrent.Callable;
import java.util.concurrent.ExecutionException;
import java.util.concurrent.ExecutorService;
import java.util.concurrent.Executors;
import java.util.concurrent.Future;

/**
 * Gibbs operator that redraws factor-analysis loadings one trait at a time from a multivariate
 * normal whose precision and mean are assembled from the factor statistics, the trait's column
 * precision, the prior, and the path parameter.
 *
 * <p>Draws run either sequentially (systematic scan over all traits, or a single randomly chosen
 * trait when {@code randomScan} is set) or in parallel on a fixed-size thread pool. The parallel
 * mode always draws every trait. In both modes the configured {@link ConstrainedSampler} is applied
 * after drawing, and listeners are notified via {@code fireLoadingsChanged()}.
 */
public class NewLoadingsGibbsOperator
extends SimpleMCMCOperator
implements GibbsOperator,
Reportable {
    /** Number of operator sweeps {@link #getReport()} uses to estimate loadings moments. */
    private static final int REPORT_SAMPLE_COUNT = 20000;
    private static final boolean DEBUG = false;

    private final FactorAnalysisStatisticsProvider statisticsProvider;
    private final NormalDistributionModel workingPrior;
    // Per-column-dimension scratch storage, used only in sequential mode (see allocateStorage).
    private final ArrayList<double[][]> precisionArray;
    private final ArrayList<double[]> meanMidArray;
    private final ArrayList<double[]> meanArray;
    private final boolean randomScan;
    private double pathParameter = 1.0;
    private final NormalStatisticsProvider prior;
    private final double priorPrecisionWorking;
    private final FactorAnalysisOperatorAdaptor adaptor;
    private final ConstrainedSampler constrainedSampler;
    private final LoadingsSamplerConstraints columnDimProvider;
    // One task per trait with a non-empty column; used only in multithreaded mode.
    private final List<Callable<Double>> drawCallers = new ArrayList<>();
    // null in sequential mode.
    private final ExecutorService pool;

    /**
     * Creates the operator.
     *
     * @param statisticsProvider     source of the factor inner/trait products and of the adaptor
     * @param prior                  prior on the loadings (per-element normal mean and SD)
     * @param weight                 operator weight
     * @param randomScan             in sequential mode, draw one random trait per call instead of all
     * @param workingPriorLikelihood optional working prior. When non-null its distribution is cast to
     *                               {@link NormalDistributionModel} and its SD defines the working
     *                               prior precision; otherwise the SD of prior element 0 is used
     * @param multiThreaded          draw traits in parallel on a thread pool
     * @param numThreads             pool size when {@code multiThreaded}
     * @param constrainedSampler     constraint applied to the loadings after each draw
     * @param columnDimProvider      gives the number of free loadings (column dimension) per trait
     * @throws IllegalArgumentException if cached precisions are combined with more than one thread
     */
    public NewLoadingsGibbsOperator(FactorAnalysisStatisticsProvider statisticsProvider,
                                    NormalStatisticsProvider prior,
                                    double weight,
                                    boolean randomScan,
                                    DistributionLikelihood workingPriorLikelihood,
                                    boolean multiThreaded,
                                    int numThreads,
                                    ConstrainedSampler constrainedSampler,
                                    LoadingsSamplerConstraints columnDimProvider) {
        // Validate before any thread pool is created so a rejected configuration leaks no threads.
        if (statisticsProvider.useCache() && multiThreaded && numThreads > 1) {
            throw new IllegalArgumentException("Cannot currently parallelize cached precisions");
        }
        this.setWeight(weight);
        this.statisticsProvider = statisticsProvider;
        this.adaptor = statisticsProvider.getAdaptor();
        this.prior = prior;
        this.workingPrior = workingPriorLikelihood == null
                ? null
                : (NormalDistributionModel) workingPriorLikelihood.getDistribution();
        this.precisionArray = new ArrayList<>();
        this.meanMidArray = new ArrayList<>();
        this.meanArray = new ArrayList<>();
        this.randomScan = randomScan;
        this.constrainedSampler = constrainedSampler;
        this.columnDimProvider = columnDimProvider;
        if (workingPriorLikelihood == null) {
            this.priorPrecisionWorking = this.getPrecision(prior, 0);
        } else {
            double workingSd = this.workingPrior.getStdev();
            this.priorPrecisionWorking = 1.0 / (workingSd * workingSd);
        }

        int numFactors = this.adaptor.getNumberOfFactors();
        if (multiThreaded) {
            for (int trait = 0; trait < this.adaptor.getNumberOfTraits(); ++trait) {
                int dim = columnDimProvider.getColumnDim(trait, numFactors);
                // Match sequential drawI(int): traits with no free loadings are not redrawn.
                if (dim > 0) {
                    this.drawCallers.add(new DrawCaller(trait, new double[dim][dim], new double[dim], new double[dim]));
                }
            }
            this.pool = Executors.newFixedThreadPool(numThreads);
        } else {
            this.pool = null;
            columnDimProvider.allocateStorage(this.precisionArray, this.meanMidArray, this.meanArray, numFactors);
        }
    }

    public FactorAnalysisOperatorAdaptor getAdaptor() {
        return this.adaptor;
    }

    /** Returns {@code 1 / sd^2} for element {@code index} of {@code provider}. */
    private double getPrecision(NormalStatisticsProvider provider, int index) {
        double sd = provider.getNormalSD(index);
        return 1.0 / (sd * sd);
    }

    /**
     * Fills the upper-left {@code dim x dim} block of {@code precision} with the conditional
     * precision for {@code trait}: the factor inner product scaled by the trait's column precision
     * and the path parameter, plus the adjusted prior precision on the diagonal. The result is
     * symmetric.
     */
    private void getPrecisionOfTruncated(int dim, int trait, double[][] precision) {
        this.statisticsProvider.getFactorInnerProduct(trait, dim, precision);
        double columnPrecision = this.adaptor.getColumnPrecision(trait);
        int numTraits = this.adaptor.getNumberOfTraits();
        for (int i = 0; i < dim; ++i) {
            precision[i][i] = precision[i][i] * columnPrecision * this.pathParameter
                    + this.getAdjustedPriorPrecision(numTraits * i + trait);
            // Only the upper triangle is read; the lower triangle is mirrored from it.
            for (int j = i + 1; j < dim; ++j) {
                precision[i][j] = precision[i][j] * columnPrecision * this.pathParameter;
                precision[j][i] = precision[i][j];
            }
        }
    }

    /**
     * Computes {@code mean = variance * midMean}, where {@code midMean} is first filled with the
     * factor-trait product times the column precision plus the prior mean times the prior precision.
     */
    private void getTruncatedMean(int dim, int trait, double[][] variance, double[] midMean, double[] mean) {
        this.statisticsProvider.getFactorTraitProduct(trait, dim, midMean);
        double columnPrecision = this.adaptor.getColumnPrecision(trait);
        int numTraits = this.adaptor.getNumberOfTraits();
        for (int i = 0; i < dim; ++i) {
            int priorIndex = numTraits * i + trait;
            midMean[i] = midMean[i] * columnPrecision
                    + this.prior.getNormalMean(priorIndex) * this.getPrecision(this.prior, priorIndex);
        }
        for (int i = 0; i < dim; ++i) {
            double sum = 0.0;
            for (int j = 0; j < dim; ++j) {
                sum += variance[i][j] * midMean[j];
            }
            mean[i] = sum;
        }
    }

    private void getPrecision(int trait, double[][] precision) {
        int dim = this.columnDimProvider.getColumnDim(trait, this.adaptor.getNumberOfFactors());
        this.getPrecisionOfTruncated(dim, trait, precision);
    }

    /** Computes the conditional mean for {@code trait} and scales every entry by the path parameter. */
    private void getMean(int trait, double[][] variance, double[] midMean, double[] mean) {
        int dim = this.columnDimProvider.getColumnDim(trait, this.adaptor.getNumberOfFactors());
        this.getTruncatedMean(dim, trait, variance, midMean, mean);
        for (int i = 0; i < mean.length; ++i) {
            mean[i] *= this.pathParameter;
        }
    }

    /**
     * Draws the loadings of {@code trait} from their conditional multivariate normal, using the
     * supplied scratch arrays, and stores them with {@code setLoadingsForTraitQuietly}.
     *
     * @throws RuntimeException if the Cholesky decomposition reports an illegal dimension
     */
    private void drawI(int trait, double[][] precision, double[] midMean, double[] mean) {
        this.getPrecision(trait, precision);
        double[][] variance = new SymmetricMatrix(precision).inverse().toComponents();
        double[][] choleskyFactor;
        try {
            choleskyFactor = new CholeskyDecomposition(variance).getL();
        } catch (IllegalDimension illegalDimension) {
            // Previously printed and continued with a null factor, causing an opaque NPE below.
            throw new RuntimeException("Cholesky decomposition failed for trait " + trait, illegalDimension);
        }
        this.getMean(trait, variance, midMean, mean);
        double[] draw = MultivariateNormalDistribution.nextMultivariateNormalCholesky(mean, choleskyFactor);
        this.adaptor.setLoadingsForTraitQuietly(trait, draw);
        if (DEBUG) {
            System.err.println("draw: " + new Vector(draw));
        }
    }

    /** Sequential-mode draw for one trait; traits with no free loadings are skipped. */
    private void drawI(int trait) {
        int numFactors = this.adaptor.getNumberOfFactors();
        if (this.columnDimProvider.getColumnDim(trait, numFactors) > 0) {
            int index = this.columnDimProvider.getArrayIndex(trait, numFactors);
            this.drawI(trait, this.precisionArray.get(index), this.meanMidArray.get(index), this.meanArray.get(index));
        }
    }

    /**
     * Runs every {@link DrawCaller} on the pool and waits for completion.
     *
     * @throws RuntimeException if the calling thread is interrupted (the interrupt flag is restored)
     *                          or if any individual draw throws
     */
    private void drawAllInParallel() {
        List<Future<Double>> results;
        try {
            results = this.pool.invokeAll(this.drawCallers);
        } catch (InterruptedException interruptedException) {
            Thread.currentThread().interrupt();
            throw new RuntimeException("Interrupted while drawing loadings in parallel", interruptedException);
        }
        // invokeAll captures task exceptions in the futures; surface them instead of losing them.
        for (Future<Double> result : results) {
            try {
                result.get();
            } catch (InterruptedException interruptedException) {
                Thread.currentThread().interrupt();
                throw new RuntimeException("Interrupted while drawing loadings in parallel", interruptedException);
            } catch (ExecutionException executionException) {
                throw new RuntimeException("Parallel loadings draw failed", executionException.getCause());
            }
        }
    }

    @Override
    public String getOperatorName() {
        return "newLoadingsGibbsOperator";
    }

    @Override
    public double doOperation() {
        if (DEBUG) {
            System.err.println("Start doOp");
        }
        this.adaptor.drawFactors();
        if (this.pool != null) {
            if (DEBUG) {
                System.err.println("!= poll");
            }
            this.drawAllInParallel();
        } else {
            if (DEBUG) {
                System.err.println("inner");
            }
            if (this.randomScan) {
                this.drawI(MathUtils.nextInt(this.adaptor.getNumberOfTraits()));
            } else {
                int numTraits = this.adaptor.getNumberOfTraits();
                for (int trait = 0; trait < numTraits; ++trait) {
                    this.drawI(trait);
                }
            }
        }
        // Apply the constraint in both modes; the parallel path previously skipped it.
        this.constrainedSampler.applyConstraint(this.adaptor);
        this.adaptor.fireLoadingsChanged();
        if (DEBUG) {
            for (double[] mean : this.meanArray) {
                System.err.println(new Vector(mean));
            }
            for (double[] midMean : this.meanMidArray) {
                System.err.println(new Vector(midMean));
            }
            for (double[][] precision : this.precisionArray) {
                System.err.println(new Matrix(precision));
            }
            System.err.println("End doOp");
        }
        return 0.0;
    }

    @Override
    public void setPathParameter(double d) {
        this.pathParameter = d;
    }

    /** Interpolates between the prior precision and the working prior precision using the path parameter. */
    private double getAdjustedPriorPrecision(int index) {
        return this.getPrecision(this.prior, index) * this.pathParameter + (1.0 - this.pathParameter) * this.priorPrecisionWorking;
    }

    /**
     * Runs the operator {@link #REPORT_SAMPLE_COUNT} times and reports the empirical mean and
     * covariance of the loadings, then restores the loadings to their starting values.
     * NOTE(review): only loadings are restored; whatever {@code adaptor.drawFactors()} changes during
     * the sweeps is not — confirm this is intended.
     */
    @Override
    public String getReport() {
        StringBuilder stringBuilder = new StringBuilder();
        stringBuilder.append(this.adaptor.getReport());
        stringBuilder.append("\n\n");

        int numFactors = this.adaptor.getNumberOfFactors();
        int numTraits = this.adaptor.getNumberOfTraits();
        int dim = numTraits * numFactors;
        double[] mean = new double[dim];
        double[][] covariance = new double[dim][dim];
        double[] originalLoadings = new double[dim];
        for (int k = 0; k < dim; ++k) {
            originalLoadings[k] = this.adaptor.getLoadingsValue(k);
        }

        for (int sample = 0; sample < REPORT_SAMPLE_COUNT; ++sample) {
            // doOperation() already fires loadingsChanged.
            this.doOperation();
            for (int k = 0; k < dim; ++k) {
                double valueK = this.adaptor.getLoadingsValue(k);
                mean[k] += valueK;
                for (int l = k; l < dim; ++l) {
                    covariance[k][l] += valueK * this.adaptor.getLoadingsValue(l);
                }
            }
        }
        this.restoreLoadings(originalLoadings);
        this.adaptor.fireLoadingsChanged();

        for (int k = 0; k < dim; ++k) {
            mean[k] /= (double) REPORT_SAMPLE_COUNT;
        }
        // E[xy] - E[x]E[y], computed on the upper triangle and mirrored.
        for (int k = 0; k < dim; ++k) {
            for (int l = k; l < dim; ++l) {
                covariance[k][l] = covariance[k][l] / (double) REPORT_SAMPLE_COUNT - mean[k] * mean[l];
                covariance[l][k] = covariance[k][l];
            }
        }

        stringBuilder.append(this.getOperatorName() + "Report:\n");
        stringBuilder.append("Loadings mean:\n");
        stringBuilder.append(new Vector(mean));
        stringBuilder.append("\n\n");
        stringBuilder.append("Loadings covariance:\n");
        stringBuilder.append(new Matrix(covariance));
        stringBuilder.append("\n\n");
        return stringBuilder.toString();
    }

    /** Writes back loadings stored with index {@code factor * numTraits + trait}. */
    private void restoreLoadings(double[] values) {
        int numTraits = this.adaptor.getNumberOfTraits();
        int numFactors = this.adaptor.getNumberOfFactors();
        double[] row = new double[numFactors];
        for (int trait = 0; trait < numTraits; ++trait) {
            for (int factor = 0; factor < numFactors; ++factor) {
                row[factor] = values[factor * numTraits + trait];
            }
            this.adaptor.setLoadingsForTraitQuietly(trait, row);
        }
    }

    /** Post-draw constraint applied to the loadings. */
    public static enum ConstrainedSampler {
        NONE("none") {
            @Override
            void applyConstraint(FactorAnalysisOperatorAdaptor adaptor) {
                // No constraint.
            }
        },
        REFLECTION("reflection") {
            @Override
            void applyConstraint(FactorAnalysisOperatorAdaptor adaptor) {
                for (int i = 0; i < adaptor.getNumberOfFactors(); ++i) {
                    adaptor.reflectLoadingsForFactor(i);
                }
            }
        };

        private final String name;

        private ConstrainedSampler(String name) {
            this.name = name;
        }

        public String getName() {
            return this.name;
        }

        /**
         * Case-insensitive lookup by {@link #getName()}.
         *
         * @throws IllegalArgumentException if no sampler has that name
         */
        public static ConstrainedSampler parse(String string) {
            // Locale.ROOT so e.g. "REFLECTION" still matches under a Turkish default locale.
            String key = string.toLowerCase(Locale.ROOT);
            for (ConstrainedSampler constrainedSampler : ConstrainedSampler.values()) {
                if (constrainedSampler.getName().equals(key)) {
                    return constrainedSampler;
                }
            }
            throw new IllegalArgumentException("Unknown sampler type: " + string);
        }

        abstract void applyConstraint(FactorAnalysisOperatorAdaptor adaptor);
    }

    /** Task that draws the loadings of one trait using its own scratch arrays. */
    class DrawCaller
    implements Callable<Double> {
        private final int i;
        private final double[][] precision;
        private final double[] midMean;
        private final double[] mean;

        DrawCaller(int trait, double[][] precision, double[] midMean, double[] mean) {
            this.i = trait;
            this.precision = precision;
            this.midMean = midMean;
            this.mean = mean;
        }

        @Override
        public Double call() throws Exception {
            NewLoadingsGibbsOperator.this.drawI(this.i, this.precision, this.midMean, this.mean);
            return null;
        }
    }
}

