package smile.classification;

import java.io.Serializable;
import java.util.Arrays;
import smile.math.Math;
import smile.math.matrix.ColumnMajorMatrix;
import smile.math.matrix.DenseMatrix;
import smile.math.matrix.EigenValueDecomposition;
import smile.projection.Projection;

/* loaded from: input_file:BOOT-INF/lib/libarx-3.8.0.jar:smile/classification/FLD.class */
/**
 * Fisher's linear discriminant (FLD).
 *
 * <p>The constructor learns a linear projection from labelled training data
 * and stores, alongside it, the projected overall mean and the projected
 * (mean-centered) class means. {@link #project(double[])} maps a sample into
 * the reduced space; {@link #predict(double[])} returns the class whose
 * projected mean is nearest to the projected sample.
 *
 * <p>Class labels must be the consecutive integers {@code 0 .. k-1}.
 */
public class FLD implements Classifier<double[]>, Projection<double[]>, Serializable {
    private static final long serialVersionUID = 1;

    /** Dimension of the input space (length of a training sample). */
    private final int p;
    /** Number of distinct class labels. */
    private final int k;
    /** Column-wise mean of all training samples. */
    private final double[] mean;
    /** Per-class means with the overall mean subtracted; shape k x p. */
    private final double[][] mu;
    /** Projection matrix; shape p x (dimension of the mapped space). */
    private final double[][] scaling;
    /** Overall mean mapped through {@link #scaling}. */
    private final double[] smean;
    /** Centered class means mapped through {@link #scaling}. */
    private final double[][] smu;

    /* loaded from: input_file:BOOT-INF/lib/libarx-3.8.0.jar:smile/classification/FLD$Trainer.class */
    /**
     * Trainer that carries the mapped-space dimension and the tolerance and
     * forwards them to {@link FLD#FLD(double[][], int[], int, double)}.
     */
    public static class Trainer extends ClassifierTrainer<double[]> {
        /** Requested mapped-space dimension; -1 lets the FLD constructor choose k - 1. */
        private int L;
        /** Tolerance handed to the FLD constructor for its singularity test. */
        private double tol;

        /**
         * Creates a trainer with the default dimension (-1) and tolerance (1E-4).
         *
         * @param trainingInterrupt passed through to the superclass constructor
         */
        public Trainer(TrainingInterrupt trainingInterrupt) {
            super(trainingInterrupt);
            this.L = -1;
            this.tol = 1.0E-4d;
        }

        /**
         * Sets the dimension of the mapped space.
         *
         * @param i the dimension; must be at least 1
         * @return this trainer
         * @throws IllegalArgumentException if {@code i < 1}
         */
        public Trainer setDimension(int i) {
            if (i < 1) {
                throw new IllegalArgumentException("Invalid mapping space dimension: " + i);
            }
            this.L = i;
            return this;
        }

        /**
         * Sets the tolerance used by the singularity test during training.
         *
         * @param d the tolerance; must not be negative
         * @return this trainer
         * @throws IllegalArgumentException if {@code d < 0}
         */
        public Trainer setTolerance(double d) {
            if (d < 0.0d) {
                throw new IllegalArgumentException("Invalid tol: " + d);
            }
            this.tol = d;
            return this;
        }

        /**
         * Trains an FLD with this trainer's dimension and tolerance.
         *
         * @param dArr training samples, one row per sample
         * @param iArr class label of each sample
         * @return the trained model
         */
        @Override // smile.classification.ClassifierTrainer
        public FLD train(double[][] dArr, int[] iArr) {
            return new FLD(dArr, iArr, this.L, this.tol);
        }
    }

    /**
     * Trains with the default mapped-space dimension (k - 1) and tolerance 1E-4.
     *
     * @param dArr training samples, one row per sample
     * @param iArr class label of each sample
     */
    public FLD(double[][] dArr, int[] iArr) {
        this(dArr, iArr, -1);
    }

    /**
     * Trains with the given mapped-space dimension and tolerance 1E-4.
     *
     * @param dArr training samples, one row per sample
     * @param iArr class label of each sample
     * @param i    dimension of the mapped space; a value {@code <= 0} selects k - 1
     */
    public FLD(double[][] dArr, int[] iArr, int i) {
        this(dArr, iArr, i, 1.0E-4d);
    }

    /**
     * Trains the discriminant.
     *
     * @param dArr training samples, one row per sample
     * @param iArr class label of each sample; labels must be exactly {@code 0 .. k-1}
     * @param i    dimension of the mapped space; a value {@code <= 0} selects k - 1
     * @param d    tolerance; an eigenvalue of the covariance matrix below
     *             {@code d * d} is reported as near-singularity
     * @throws IllegalArgumentException if the sample and label counts differ, a
     *         label is negative, a label in {@code 0 .. max} is absent, fewer than
     *         two classes are present, {@code d < 0}, there are not more samples
     *         than classes, {@code i >= k}, or the covariance matrix is close to
     *         singular
     */
    public FLD(double[][] dArr, int[] iArr, int i, double d) {
        if (dArr.length != iArr.length) {
            throw new IllegalArgumentException(String.format("The sizes of X and Y don't match: %d != %d", dArr.length, iArr.length));
        }

        int[] labels = Math.unique(iArr);
        Arrays.sort(labels);
        for (int c = 0; c < labels.length; c++) {
            if (labels[c] < 0) {
                throw new IllegalArgumentException("Negative class label: " + labels[c]);
            }
            // Labels index the per-class arrays directly, so they must be
            // 0, 1, ..., k-1 without gaps. Treat -1 as the predecessor of the
            // first label so that a missing class 0 is detected as well.
            int previous = c == 0 ? -1 : labels[c - 1];
            if (labels[c] - previous > 1) {
                // Parenthesised so the sum is arithmetic, not string concatenation.
                throw new IllegalArgumentException("Missing class: " + (previous + 1));
            }
        }

        this.k = labels.length;
        if (this.k < 2) {
            throw new IllegalArgumentException("Only one class.");
        }
        if (d < 0.0d) {
            throw new IllegalArgumentException("Invalid tol: " + d);
        }
        if (dArr.length <= this.k) {
            throw new IllegalArgumentException(String.format("Sample size is too small: %d <= %d", dArr.length, this.k));
        }
        if (i >= this.k) {
            throw new IllegalArgumentException(String.format("The dimensionality of mapped space is too high: %d >= %d", i, this.k));
        }
        // A non-positive request means "use the default", which is k - 1.
        final int dim = i <= 0 ? this.k - 1 : i;

        final int n = dArr.length;
        this.p = dArr[0].length;
        this.mean = Math.colMean(dArr);

        // Per-class sums and sample counts.
        int[] counts = new int[this.k];
        this.mu = new double[this.k][this.p];
        for (int s = 0; s < n; s++) {
            int label = iArr[s];
            counts[label]++;
            for (int j = 0; j < this.p; j++) {
                this.mu[label][j] += dArr[s][j];
            }
        }
        // Turn the sums into class means and center them on the overall mean.
        for (int c = 0; c < this.k; c++) {
            for (int j = 0; j < this.p; j++) {
                this.mu[c][j] = (this.mu[c][j] / counts[c]) - this.mean[j];
            }
        }

        // Covariance of the samples about the overall mean: accumulate the
        // lower triangle, then divide by n and mirror into the upper triangle.
        ColumnMajorMatrix total = new ColumnMajorMatrix(this.p, this.p);
        for (int s = 0; s < n; s++) {
            for (int r = 0; r < this.p; r++) {
                for (int c = 0; c <= r; c++) {
                    total.add(r, c, (dArr[s][r] - this.mean[r]) * (dArr[s][c] - this.mean[c]));
                }
            }
        }
        for (int r = 0; r < this.p; r++) {
            for (int c = 0; c <= r; c++) {
                total.div(r, c, n);
                total.set(c, r, total.get(r, c));
            }
        }

        // Scatter of the centered class means, built the same way and
        // averaged over the k classes.
        ColumnMajorMatrix between = new ColumnMajorMatrix(this.p, this.p);
        for (int cls = 0; cls < this.k; cls++) {
            for (int r = 0; r < this.p; r++) {
                for (int c = 0; c <= r; c++) {
                    between.add(r, c, this.mu[cls][r] * this.mu[cls][c]);
                }
            }
        }
        for (int r = 0; r < this.p; r++) {
            for (int c = 0; c <= r; c++) {
                between.div(r, c, this.k);
                between.set(c, r, between.get(r, c));
            }
        }

        // Eigen-decompose the covariance. The meaning of the boolean flag is
        // defined by EigenValueDecomposition (presumably "symmetric" - confirm).
        EigenValueDecomposition totalEigen = new EigenValueDecomposition((DenseMatrix) total, true);
        double threshold = d * d;
        double[] invEigenValues = totalEigen.getEigenValues();
        for (int e = 0; e < invEigenValues.length; e++) {
            if (invEigenValues[e] < threshold) {
                throw new IllegalArgumentException("The covariance matrix is close to singular.");
            }
            // Replaced in place by the reciprocal for the scaling step below.
            invEigenValues[e] = 1.0d / invEigenValues[e];
        }

        DenseMatrix totalVectors = totalEigen.getEigenVectors();
        DenseMatrix scaled = totalVectors.atbmm(between);
        // NOTE(review): this loop visits rows 0..k-1 and scales entry (r, c) by
        // the c-th reciprocal eigenvalue. 'scaled' is presumably p x p, so the
        // row bound looks like it overruns when k > p, and a diagonal inverse
        // would normally scale row r by invEigenValues[r]. Left exactly as in
        // the original because changing it alters trained models - confirm
        // against the matrix API and the reference algorithm before touching.
        for (int r = 0; r < this.k; r++) {
            for (int c = 0; c < this.p; c++) {
                scaled.mul(r, c, invEigenValues[c]);
            }
        }

        // The leading 'dim' eigenvector columns of the combined matrix form the projection.
        DenseMatrix directions = new EigenValueDecomposition(totalVectors.abmm(scaled), true).getEigenVectors();
        this.scaling = new double[this.p][dim];
        for (int r = 0; r < this.p; r++) {
            for (int c = 0; c < dim; c++) {
                this.scaling[r][c] = directions.get(r, c);
            }
        }

        // Cache the projected overall mean and the projected centered class means.
        this.smean = new double[dim];
        Math.atx(this.scaling, this.mean, this.smean);
        this.smu = Math.abmm(this.mu, this.scaling);
    }

    /**
     * Returns the projection matrix.
     *
     * @return the internal p x dim array itself, not a copy
     */
    public double[][] getProjection() {
        return this.scaling;
    }

    /**
     * Predicts the class of a sample as the class whose projected mean is
     * nearest (by {@code Math.distance}) to the projected sample. On a tie the
     * lowest class index wins.
     *
     * @param dArr the sample; must have length p
     * @return the predicted class label
     * @throws IllegalArgumentException if the sample length is not p
     */
    @Override // smile.classification.Classifier
    public int predict(double[] dArr) {
        checkLength(dArr);
        double[] projected = project(dArr);
        int best = 0;
        double bestDistance = Double.POSITIVE_INFINITY;
        for (int c = 0; c < this.k; c++) {
            double distance = Math.distance(projected, this.smu[c]);
            if (distance < bestDistance) {
                bestDistance = distance;
                best = c;
            }
        }
        return best;
    }

    /**
     * Projects a sample into the mapped space and subtracts the projected
     * overall mean.
     *
     * @param dArr the sample; must have length p
     * @return a new array holding the projected, centered sample
     * @throws IllegalArgumentException if the sample length is not p
     */
    @Override // smile.projection.Projection
    public double[] project(double[] dArr) {
        checkLength(dArr);
        double[] projected = new double[this.scaling[0].length];
        Math.atx(this.scaling, dArr, projected);
        Math.minus(projected, this.smean);
        return projected;
    }

    /**
     * Projects each row into the mapped space and subtracts the projected
     * overall mean.
     *
     * @param dArr the samples, one per row; every row must have length p
     * @return a new array with one projected, centered row per input row
     * @throws IllegalArgumentException if any row's length is not p
     */
    @Override // smile.projection.Projection
    public double[][] project(double[][] dArr) {
        double[][] projected = new double[dArr.length][this.scaling[0].length];
        for (int row = 0; row < dArr.length; row++) {
            checkLength(dArr[row]);
            Math.atx(this.scaling, dArr[row], projected[row]);
            Math.minus(projected[row], this.smean);
        }
        return projected;
    }

    /** Rejects an input vector whose length differs from the training dimension p. */
    private void checkLength(double[] sample) {
        if (sample.length != this.p) {
            throw new IllegalArgumentException(String.format("Invalid input vector size: %d, expected: %d", sample.length, this.p));
        }
    }
}
