package smile.regression;

import java.io.Serializable;
import smile.math.Math;
import smile.math.kernel.MercerKernel;
import smile.math.matrix.CholeskyDecomposition;
import smile.math.matrix.ColumnMajorMatrix;
import smile.math.matrix.DenseMatrix;
import smile.math.matrix.EigenValueDecomposition;
import smile.math.matrix.LUDecomposition;

/* loaded from: input_file:BOOT-INF/lib/libarx-3.8.0.jar:smile/regression/GaussianProcessRegression.class */
/**
 * Kernel regression model whose prediction for an input {@code x} is
 * {@code sum_i w[i] * k(x, knots[i])} for a Mercer kernel {@code k}.
 *
 * <p>The constructors offer three ways of fitting the weights {@code w}:
 * <ul>
 *   <li>Exact: the knots are the training inputs {@code x} and {@code w} solves
 *       {@code (K + lambda * I) w = y}, with {@code K[i][j] = k(x[i], x[j])} (Cholesky).</li>
 *   <li>Subset of regressors: the knots are a separate set {@code t}, and {@code w} solves
 *       {@code (C'C + lambda * W) w = C'y}, with {@code C[i][j] = k(x[i], t[j])} and
 *       {@code W[i][j] = k(t[i], t[j])} (LU).</li>
 *   <li>Nystrom (package-private): the knots are the training inputs and {@code w} solves
 *       {@code (C W^-1 C' + lambda * I) w = y}, i.e. the exact system with {@code K}
 *       replaced by its Nystrom approximation built from {@code t}.</li>
 * </ul>
 *
 * @param <T> the input type accepted by the kernel
 */
public class GaussianProcessRegression<T> implements Regression<T>, Serializable {
    private static final long serialVersionUID = 1;

    /** Points against which the kernel is evaluated in {@link #predict}; same length as {@link #w}. */
    private final T[] knots;
    /** One weight per knot. */
    private final double[] w;
    /** Kernel used both for fitting and for prediction. */
    private final MercerKernel<T> kernel;
    /** Regularization (shrinkage) parameter used during fitting. */
    private final double lambda;

    /**
     * Trainer that fits a {@link GaussianProcessRegression} with a fixed kernel and lambda.
     *
     * @param <T> the input type accepted by the kernel
     */
    public static class Trainer<T> extends RegressionTrainer<T> {
        private final MercerKernel<T> kernel;
        private final double lambda;

        /**
         * @param mercerKernel the kernel to fit with
         * @param d the regularization parameter lambda (validated when a model is built)
         */
        public Trainer(MercerKernel<T> mercerKernel, double d) {
            this.kernel = mercerKernel;
            this.lambda = d;
        }

        /** Fits the exact model, using every training input as a knot. */
        @Override // smile.regression.RegressionTrainer
        public GaussianProcessRegression<T> train(T[] tArr, double[] dArr) {
            return new GaussianProcessRegression<>(tArr, dArr, this.kernel, this.lambda);
        }

        /** Fits the subset-of-regressors model, using {@code tArr2} as the knots. */
        public GaussianProcessRegression<T> train(T[] tArr, double[] dArr, T[] tArr2) {
            return new GaussianProcessRegression<>(tArr, dArr, tArr2, this.kernel, this.lambda);
        }
    }

    /**
     * Fits the exact model: {@code w} solves {@code (K + lambda * I) w = y}.
     *
     * @param x training inputs; also retained (not copied) as the knots
     * @param y responses, one per input
     * @param kernel the Mercer kernel
     * @param lambda regularization parameter, must be non-negative
     * @throws IllegalArgumentException if the lengths of {@code x} and {@code y} differ,
     *     or {@code lambda} is negative or NaN
     */
    public GaussianProcessRegression(T[] x, double[] y, MercerKernel<T> kernel, double lambda) {
        checkArguments(x.length, y.length, lambda);
        this.kernel = kernel;
        this.lambda = lambda;
        this.knots = x;

        int n = x.length;
        double[][] gram = new double[n][n];
        for (int i = 0; i < n; i++) {
            // Fill the symmetric Gram matrix one lower-triangle row at a time.
            for (int j = 0; j <= i; j++) {
                gram[i][j] = kernel.k(x[i], x[j]);
                gram[j][i] = gram[i][j];
            }
            // Ridge term on the diagonal: K + lambda * I.
            gram[i][i] += lambda;
        }
        this.w = new double[n];
        new CholeskyDecomposition(gram).solve(y, this.w);
    }

    /**
     * Fits the subset-of-regressors model: {@code w} solves
     * {@code (C'C + lambda * W) w = C'y}, where {@code C[i][j] = k(x[i], t[j])} and
     * {@code W[i][j] = k(t[i], t[j])}.
     *
     * @param x training inputs
     * @param y responses, one per input
     * @param t knots; retained (not copied) for prediction
     * @param kernel the Mercer kernel
     * @param lambda regularization parameter, must be non-negative
     * @throws IllegalArgumentException if the lengths of {@code x} and {@code y} differ,
     *     or {@code lambda} is negative or NaN
     */
    public GaussianProcessRegression(T[] x, double[] y, T[] t, MercerKernel<T> kernel, double lambda) {
        checkArguments(x.length, y.length, lambda);
        this.kernel = kernel;
        this.lambda = lambda;
        this.knots = t;

        int n = x.length;
        int m = t.length;
        // Cross-kernel between training inputs (rows) and knots (columns).
        double[][] cross = new double[n][m];
        for (int i = 0; i < n; i++) {
            for (int j = 0; j < m; j++) {
                cross[i][j] = kernel.k(x[i], t[j]);
            }
        }
        // Normal-equation matrix: C'C (from Math.atamm) plus lambda times the knot Gram matrix.
        double[][] normal = Math.atamm(cross);
        for (int i = 0; i < m; i++) {
            for (int j = 0; j <= i; j++) {
                normal[i][j] += lambda * kernel.k(t[i], t[j]);
                normal[j][i] = normal[i][j];
            }
        }
        // Right-hand side C'y.
        double[] rhs = new double[m];
        Math.atx(cross, y, rhs);
        this.w = new double[m];
        new LUDecomposition(normal).solve(rhs, this.w);
    }

    /**
     * Fits the Nystrom-approximated model: {@code w} solves
     * {@code (C W^-1 C' + lambda * I) w = y}, with {@code C} and {@code W} as in the
     * subset-of-regressors constructor. The knots are the training inputs {@code x}.
     *
     * @param x training inputs; also retained (not copied) as the knots
     * @param y responses, one per input
     * @param t landmark points used to build the Nystrom approximation
     * @param kernel the Mercer kernel
     * @param lambda regularization parameter, must be strictly positive (it is a divisor below)
     * @param nystrom unused; only distinguishes this overload
     * @throws IllegalArgumentException if the lengths of {@code x} and {@code y} differ,
     *     or {@code lambda} is not strictly positive
     */
    GaussianProcessRegression(T[] x, double[] y, T[] t, MercerKernel<T> kernel, double lambda, boolean nystrom) {
        checkArguments(x.length, y.length, lambda);
        if (lambda <= 0.0) {
            // The final Woodbury step divides by lambda; zero would yield Infinity/NaN weights.
            throw new IllegalArgumentException("Invalid regularization parameter lambda = " + lambda
                    + " (the Nystrom approximation requires lambda > 0)");
        }
        this.kernel = kernel;
        this.lambda = lambda;
        this.knots = x;

        int n = x.length;
        int m = t.length;
        // C: cross-kernel between training inputs and landmarks.
        ColumnMajorMatrix cross = new ColumnMajorMatrix(n, m);
        for (int i = 0; i < n; i++) {
            for (int j = 0; j < m; j++) {
                cross.set(i, j, kernel.k(x[i], t[j]));
            }
        }
        // W: symmetric Gram matrix of the landmarks.
        ColumnMajorMatrix landmarkGram = new ColumnMajorMatrix(m, m);
        for (int i = 0; i < m; i++) {
            for (int j = 0; j <= i; j++) {
                double kij = kernel.k(t[i], t[j]);
                landmarkGram.set(i, j, kij);
                landmarkGram.set(j, i, kij);
            }
        }

        // W^{-1/2} = U D^{-1/2} U' from the eigendecomposition W = U D U'.
        EigenValueDecomposition eigen = new EigenValueDecomposition((DenseMatrix) landmarkGram, true);
        DenseMatrix vectors = eigen.getEigenVectors();
        DenseMatrix values = eigen.getD();
        for (int i = 0; i < m; i++) {
            double ev = values.get(i, i);
            // Pseudo-inverse square root: a zero or round-off-negative eigenvalue (singular W,
            // e.g. duplicate landmarks) would otherwise turn every coefficient into Infinity/NaN.
            values.set(i, i, ev > 0.0 ? 1.0d / Math.sqrt(ev) : 0.0d);
        }
        // L = C W^{-1/2}, hence L L' = C W^{-1} C' approximates the full Gram matrix.
        DenseMatrix factor = cross.abmm((ColumnMajorMatrix) vectors.abmm(values).abtmm(vectors));
        DenseMatrix ltl = factor.ata();
        for (int i = 0; i < m; i++) {
            ltl.add(i, i, lambda);
        }
        // Woodbury identity:
        // (L L' + lambda I)^{-1} y = (y - L (L'L + lambda I)^{-1} L' y) / lambda.
        // The projection below is symmetric, so atx (P' y) equals P y.
        DenseMatrix projection = factor.abmm(new CholeskyDecomposition(ltl).inverse()).abtmm(factor);
        this.w = new double[n];
        projection.atx(y, this.w);
        for (int i = 0; i < n; i++) {
            this.w[i] = (y[i] - this.w[i]) / lambda;
        }
    }

    /**
     * Validates the arguments shared by every constructor.
     *
     * @param n number of training inputs
     * @param ny number of responses
     * @param lambda regularization parameter
     * @throws IllegalArgumentException if {@code n != ny} or {@code lambda} is negative or NaN
     */
    private static void checkArguments(int n, int ny, double lambda) {
        if (n != ny) {
            throw new IllegalArgumentException(String.format("The sizes of X and Y don't match: %d != %d", n, ny));
        }
        // Written as !(>= 0) so that NaN is rejected too.
        if (!(lambda >= 0.0d)) {
            throw new IllegalArgumentException("Invalid regularization parameter lambda = " + lambda);
        }
    }

    /**
     * Returns the fitted per-knot weights. This is the model's internal array, not a copy;
     * modifying it changes subsequent predictions.
     *
     * @return the weight vector, one entry per knot
     */
    public double[] coefficients() {
        return this.w;
    }

    /**
     * Returns the regularization parameter lambda used during fitting.
     *
     * @return lambda
     */
    public double shrinkage() {
        return this.lambda;
    }

    /**
     * Predicts the response for {@code t} as {@code sum_i w[i] * k(t, knots[i])}.
     *
     * @param t the input
     * @return the predicted response
     */
    @Override // smile.regression.Regression
    public double predict(T t) {
        double sum = 0.0d;
        for (int i = 0; i < this.knots.length; i++) {
            sum += this.w[i] * this.kernel.k(t, this.knots[i]);
        }
        return sum;
    }
}
