package smile.classification;

import java.io.Serializable;
import java.util.Arrays;
import smile.math.Math;
import smile.math.matrix.ColumnMajorMatrix;
import smile.math.matrix.DenseMatrix;
import smile.math.matrix.EigenValueDecomposition;

/* loaded from: input_file:BOOT-INF/lib/libarx-3.8.0.jar:smile/classification/QDA.class */
/**
 * Quadratic discriminant analysis (QDA) classifier.
 *
 * <p>Every class is modeled with its own mean vector and covariance matrix. During
 * training each class covariance matrix is eigen-decomposed: the eigenvectors are kept
 * in {@code scaling[c]} and the eigenvalues in {@code ev[c]}. A sample is scored per
 * class as {@code ct[c] - 0.5 * sum_j (u_j^2 / ev[c][j])}, where {@code u} is the
 * centered sample mapped through {@code scaling[c].atx(...)}.
 *
 * <p>Class labels must form the contiguous range {@code 0 .. k-1} with {@code k >= 2}.
 */
public class QDA extends SoftClassifier<double[]> implements Serializable {
    private static final long serialVersionUID = 1;

    // Field names are kept unchanged so previously serialized models stay readable.

    /** Dimensionality of the input vectors. */
    private final int p;
    /** Number of classes. */
    private final int k;
    /** Per-class constant term: log(prior) - 0.5 * sum(log(eigenvalues)). */
    private final double[] ct;
    /** Prior probability of each class (user supplied, or class frequencies). */
    private final double[] priori;
    /** Per-class mean vectors, indexed {@code mu[class][feature]}. */
    private final double[][] mu;
    /** Per-class eigenvectors of the class covariance matrix. */
    private final DenseMatrix[] scaling;
    /** Per-class eigenvalues of the class covariance matrix. */
    private final double[][] ev;

    /* loaded from: input_file:BOOT-INF/lib/libarx-3.8.0.jar:smile/classification/QDA$Trainer.class */
    /** Builder-style trainer that forwards its settings to the {@link QDA} constructor. */
    public static class Trainer extends ClassifierTrainer<double[]> {
        /** Optional class priors; {@code null} means "estimate from class frequencies". */
        private double[] priori;
        /** Singularity tolerance forwarded to the QDA constructor. */
        private double tol;

        /**
         * Creates a trainer with no explicit priors and tolerance {@code 1E-4}.
         *
         * @param trainingInterrupt interrupt handle passed to the trained model
         */
        public Trainer(TrainingInterrupt trainingInterrupt) {
            super(trainingInterrupt);
            this.tol = 1.0E-4d;
        }

        /**
         * Sets the class prior probabilities; they are validated when training starts.
         *
         * @param priori prior probability of each class, or {@code null} to estimate them
         * @return this trainer
         */
        public Trainer setPriori(double[] priori) {
            this.priori = priori;
            return this;
        }

        /**
         * Sets the tolerance used to detect (near-)singular covariance matrices.
         *
         * @param tol non-negative tolerance
         * @return this trainer
         * @throws IllegalArgumentException if {@code tol} is negative
         */
        public Trainer setTolerance(double tol) {
            if (tol < 0.0d) {
                throw new IllegalArgumentException("Invalid tol: " + tol);
            }
            this.tol = tol;
            return this;
        }

        @Override // smile.classification.ClassifierTrainer
        public QDA train(double[][] x, int[] y) {
            return new QDA(x, y, this.priori, this.tol, this.interrupt);
        }
    }

    /**
     * Trains QDA with priors estimated from the class frequencies.
     *
     * @param x training samples
     * @param y class labels in {@code 0 .. k-1}
     * @param tol singularity tolerance
     * @param trainingInterrupt interrupt handle
     */
    public QDA(double[][] x, int[] y, double tol, TrainingInterrupt trainingInterrupt) {
        this(x, y, null, tol, trainingInterrupt);
    }

    /**
     * Trains a QDA model.
     *
     * @param x training samples; every row must have the same length
     * @param y class labels, which must be exactly the contiguous range {@code 0 .. k-1}
     * @param priori prior probability of each class (each in (0, 1), summing to one), or
     *     {@code null} to use the observed class frequencies. The array is copied.
     * @param tol tolerance: a class covariance whose diagonal entry or eigenvalue is below
     *     {@code tol * tol} is rejected as close to singular
     * @param trainingInterrupt interrupt handle
     * @throws IllegalArgumentException on any invalid argument, on a class with at most
     *     one sample, or on a (near-)singular class covariance matrix
     */
    public QDA(double[][] x, int[] y, double[] priori, double tol, TrainingInterrupt trainingInterrupt) {
        super(trainingInterrupt);
        if (x.length != y.length) {
            throw new IllegalArgumentException(String.format("The sizes of X and Y don't match: %d != %d", x.length, y.length));
        }
        if (priori != null) {
            checkPriori(priori);
        }
        this.k = countClasses(y);
        if (priori != null && this.k != priori.length) {
            throw new IllegalArgumentException("The number of classes and the number of priori probabilities don't match.");
        }
        if (tol < 0.0d) {
            throw new IllegalArgumentException("Invalid tol: " + tol);
        }
        int n = x.length;
        if (n <= this.k) {
            throw new IllegalArgumentException(String.format("Sample size is too small: %d <= %d", n, this.k));
        }
        this.p = x[0].length;
        for (double[] row : x) {
            // Reject ragged input instead of truncating or hitting an index error later.
            if (row.length != this.p) {
                throw new IllegalArgumentException(String.format("Invalid input vector size: %d, expected: %d", row.length, this.p));
            }
        }

        // Class sample counts and per-class feature sums.
        int[] ni = new int[this.k];
        this.mu = new double[this.k][this.p];
        for (int i = 0; i < n; i++) {
            int c = y[i];
            ni[c]++;
            double[] mean = this.mu[c];
            for (int j = 0; j < this.p; j++) {
                mean[j] += x[i][j];
            }
        }

        DenseMatrix[] cov = new DenseMatrix[this.k];
        for (int c = 0; c < this.k; c++) {
            // An unbiased covariance estimate divides by (ni - 1), so at least two samples are needed.
            if (ni[c] <= 1) {
                throw new IllegalArgumentException(String.format("Class %d has only one sample.", c));
            }
            cov[c] = new ColumnMajorMatrix(this.p, this.p);
            double[] mean = this.mu[c];
            for (int j = 0; j < this.p; j++) {
                mean[j] /= ni[c];
            }
        }

        double[] prior;
        if (priori == null) {
            prior = new double[this.k];
            for (int c = 0; c < this.k; c++) {
                // Cast before dividing: int / int would truncate every prior to 0.
                prior[c] = (double) ni[c] / n;
            }
        } else {
            prior = priori.clone();
        }
        this.priori = prior;

        // Accumulate the lower triangle of each class scatter matrix.
        for (int i = 0; i < n; i++) {
            int c = y[i];
            double[] xi = x[i];
            double[] mean = this.mu[c];
            for (int j = 0; j < this.p; j++) {
                double dj = xi[j] - mean[j];
                for (int l = 0; l <= j; l++) {
                    cov[c].add(j, l, dj * (xi[l] - mean[l]));
                }
            }
        }

        double tol2 = tol * tol;
        this.ev = new double[this.k][];
        for (int c = 0; c < this.k; c++) {
            DenseMatrix s = cov[c];
            for (int j = 0; j < this.p; j++) {
                // Scale to the sample covariance and mirror the lower triangle into the upper one.
                for (int l = 0; l <= j; l++) {
                    s.div(j, l, ni[c] - 1);
                    s.set(l, j, s.get(j, l));
                }
                if (s.get(j, j) < tol2) {
                    throw new IllegalArgumentException(String.format("Class %d covariance matrix (variable %d) is close to singular.", c, j));
                }
            }
            EigenValueDecomposition evd = new EigenValueDecomposition(s, true);
            double[] eigenvalues = evd.getEigenValues();
            for (double lambda : eigenvalues) {
                if (lambda < tol2) {
                    throw new IllegalArgumentException(String.format("Class %d covariance matrix is close to singular.", c));
                }
            }
            this.ev[c] = eigenvalues;
            cov[c] = evd.getEigenVectors();
        }
        this.scaling = cov;

        this.ct = new double[this.k];
        for (int c = 0; c < this.k; c++) {
            // The sum of log eigenvalues is the log-determinant of the class covariance.
            double logDet = 0.0d;
            for (int j = 0; j < this.p; j++) {
                logDet += Math.log(this.ev[c][j]);
            }
            this.ct[c] = Math.log(prior[c]) - (0.5d * logDet);
        }
    }

    /**
     * Trains QDA with tolerance {@code 1E-4}.
     *
     * @param x training samples
     * @param y class labels in {@code 0 .. k-1}
     * @param priori class priors, or {@code null} to estimate them
     * @param trainingInterrupt interrupt handle
     */
    public QDA(double[][] x, int[] y, double[] priori, TrainingInterrupt trainingInterrupt) {
        this(x, y, priori, 1.0E-4d, trainingInterrupt);
    }

    /**
     * Trains QDA with estimated priors and tolerance {@code 1E-4}.
     *
     * @param x training samples
     * @param y class labels in {@code 0 .. k-1}
     * @param trainingInterrupt interrupt handle
     */
    public QDA(double[][] x, int[] y, TrainingInterrupt trainingInterrupt) {
        this(x, y, (double[]) null, trainingInterrupt);
    }

    /**
     * Validates user-supplied priors: at least two entries, each strictly inside (0, 1),
     * summing to one within 1E-10.
     */
    private static void checkPriori(double[] priori) {
        if (priori.length < 2) {
            throw new IllegalArgumentException("Invalid number of priori probabilities: " + priori.length);
        }
        double sum = 0.0d;
        for (double pr : priori) {
            if (pr <= 0.0d || pr >= 1.0d) {
                throw new IllegalArgumentException("Invalid priori probability: " + pr);
            }
            sum += pr;
        }
        if (Math.abs(sum - 1.0d) > 1.0E-10d) {
            throw new IllegalArgumentException("The sum of priori probabilities is not one: " + sum);
        }
    }

    /**
     * Returns the number of classes, requiring the distinct labels to be exactly
     * {@code 0 .. k-1} with {@code k >= 2}.
     */
    private static int countClasses(int[] y) {
        int[] labels = Math.unique(y);
        Arrays.sort(labels);
        for (int i = 0; i < labels.length; i++) {
            if (labels[i] < 0) {
                throw new IllegalArgumentException("Negative class label: " + labels[i]);
            }
            if (i == 0 && labels[0] > 0) {
                // Labels are used directly as array indices, so class 0 must be present.
                throw new IllegalArgumentException("Missing class: 0");
            }
            if (i > 0 && labels[i] - labels[i - 1] > 1) {
                // Parenthesized so the label is added numerically, not concatenated as text.
                throw new IllegalArgumentException("Missing class: " + (labels[i - 1] + 1));
            }
        }
        if (labels.length < 2) {
            throw new IllegalArgumentException("Only one class.");
        }
        return labels.length;
    }

    /**
     * Returns the class prior probabilities used by this model. The internal array is
     * returned, not a copy.
     *
     * @return prior probability of each class
     */
    public double[] getPriori() {
        return this.priori;
    }

    @Override // smile.classification.Classifier
    public int predict(double[] x) {
        return predict(x, (double[]) null);
    }

    /**
     * Predicts the class of {@code x} and optionally fills posterior probabilities.
     *
     * @param x input vector of length {@code p}
     * @param posteriori {@code null}, or an array of length {@code k} that receives the
     *     normalized posterior probability of each class
     * @return index of the class with the largest discriminant score
     * @throws IllegalArgumentException if an array has the wrong length
     */
    @Override // smile.classification.SoftClassifier
    public int predict(double[] x, double[] posteriori) {
        if (x.length != this.p) {
            throw new IllegalArgumentException(String.format("Invalid input vector size: %d, expected: %d", x.length, this.p));
        }
        if (posteriori != null && posteriori.length != this.k) {
            throw new IllegalArgumentException(String.format("Invalid posteriori vector size: %d, expected: %d", posteriori.length, this.k));
        }
        int best = 0;
        double maxScore = Double.NEGATIVE_INFINITY;
        double[] centered = new double[this.p];
        double[] projected = new double[this.p];
        for (int c = 0; c < this.k; c++) {
            for (int j = 0; j < this.p; j++) {
                centered[j] = x[j] - this.mu[c][j];
            }
            // Presumably projected = scaling[c]^T * centered, i.e. coordinates in the eigenbasis.
            this.scaling[c].atx(centered, projected);
            double dist = 0.0d;
            for (int j = 0; j < this.p; j++) {
                dist += (projected[j] * projected[j]) / this.ev[c][j];
            }
            double score = this.ct[c] - (0.5d * dist);
            if (maxScore < score) {
                maxScore = score;
                best = c;
            }
            if (posteriori != null) {
                posteriori[c] = score;
            }
        }
        if (posteriori != null) {
            // Softmax over the log-scores, shifted by the maximum for numerical stability.
            double sum = 0.0d;
            for (int c = 0; c < this.k; c++) {
                posteriori[c] = Math.exp(posteriori[c] - maxScore);
                sum += posteriori[c];
            }
            for (int c = 0; c < this.k; c++) {
                posteriori[c] /= sum;
            }
        }
        return best;
    }
}
