package smile.stat.distribution;

import java.util.ArrayList;
import java.util.List;
import smile.math.Math;
import smile.stat.distribution.MultivariateMixture;

/* loaded from: input_file:BOOT-INF/lib/libarx-3.8.0.jar:smile/stat/distribution/MultivariateGaussianMixture.class */
/**
 * Finite mixture of multivariate Gaussian distributions.
 *
 * <p>A mixture can be built in three ways:
 * <ul>
 *   <li>from an explicit list of components;
 *   <li>from data with a fixed number of components, seeded k-means++-style and then refined by
 *       the inherited {@code EM} routine;
 *   <li>from data with an automatically chosen number of components, grown by repeatedly splitting
 *       the most dispersed component for as long as the BIC score keeps improving.
 * </ul>
 */
public class MultivariateGaussianMixture extends MultivariateExponentialFamilyMixture {
    /**
     * Creates a mixture from the given components.
     *
     * @param list the mixture components, passed unchanged to the superclass constructor.
     */
    public MultivariateGaussianMixture(List<MultivariateMixture.Component> list) {
        super(list);
    }

    /**
     * Fits a mixture of {@code k} Gaussian components (full covariance) to the data.
     * Equivalent to {@code MultivariateGaussianMixture(data, k, false)}.
     *
     * @param data the samples, one row per observation.
     * @param k the number of components, at least 2.
     */
    public MultivariateGaussianMixture(double[][] data, int k) {
        this(data, k, false);
    }

    /**
     * Fits a mixture of {@code k} Gaussian components to the data.
     *
     * <p>Every component starts with prior {@code 1/k} and the sample covariance of the whole data
     * set (denominator {@code n - 1}; only the diagonal is estimated when {@code diagonal} is true).
     * Component means are chosen by k-means++ seeding. The first mean is a sample picked with
     * {@code Math.randomInt}. Each further mean is a sample drawn with probability proportional to
     * its squared distance from the nearest mean chosen so far. The seeded mixture is then refined
     * with {@code EM}.
     *
     * @param data the samples, one row per observation; every row is assumed to have the
     *     dimension of {@code data[0]}.
     * @param k the number of components, at least 2.
     * @param diagonal whether to estimate only the diagonal of the initial covariance and mark
     *     every component's distribution as diagonal.
     * @throws IllegalArgumentException if {@code k < 2} or {@code data} has fewer than two rows.
     */
    public MultivariateGaussianMixture(double[][] data, int k, boolean diagonal) {
        if (k < 2) {
            throw new IllegalArgumentException("Invalid number of components in the mixture.");
        }
        // The covariance estimate divides by (n - 1). Zero rows would also fail on data[0].
        if (data.length < 2) {
            throw new IllegalArgumentException("Too few samples.");
        }

        int n = data.length;
        double[] mu = columnMeans(data);
        double[][] cov = sampleCovariance(data, mu, diagonal);
        double priori = 1.0d / k;

        // NOTE(review): every seeded component receives the same cov array. This assumes
        // MultivariateGaussianDistribution copies it or never mutates it — confirm.
        double[] center = data[Math.randomInt(n)];
        this.components.add(gaussianComponent(center, cov, priori, diagonal));

        // minDistance[j] = squared distance from sample j to its nearest already-chosen mean.
        double[] minDistance = new double[n];
        for (int j = 0; j < n; j++) {
            minDistance[j] = Double.MAX_VALUE;
        }
        for (int c = 1; c < k; c++) {
            for (int j = 0; j < n; j++) {
                double d2 = Math.squaredDistance(data[j], center);
                if (d2 < minDistance[j]) {
                    minDistance[j] = d2;
                }
            }
            center = data[drawIndexByWeight(minDistance)];
            this.components.add(gaussianComponent(center, cov, priori, diagonal));
        }

        EM(this.components, data);
    }

    /**
     * Fits a mixture with full covariance whose number of components is chosen automatically.
     * Equivalent to {@code MultivariateGaussianMixture(data, false)}.
     *
     * @param data the samples, one row per observation.
     * @throws IllegalArgumentException if {@code data} has fewer than 20 rows.
     */
    public MultivariateGaussianMixture(double[][] data) {
        this(data, false);
    }

    /**
     * Fits a mixture whose number of components is chosen automatically by BIC.
     *
     * <p>The fit starts from a single Gaussian built by
     * {@code MultivariateGaussianDistribution(data, diagonal)}. It then repeats two steps: split the
     * component with the largest {@code scatter()} in two, and re-fit the mixture with {@code EM}.
     * This continues while the score {@code logLikelihood - 0.5 * p * log(n)} strictly increases.
     * Here {@code p} is the summed {@code npara()} of all components and {@code n} the number of
     * samples. The mixture kept is the last one before the score stopped improving.
     *
     * @param data the samples, one row per observation.
     * @param diagonal passed to the initial single-Gaussian estimate; split children inherit it.
     * @throws IllegalArgumentException if {@code data} has fewer than 20 rows.
     */
    public MultivariateGaussianMixture(double[][] data, boolean diagonal) {
        if (data.length < 20) {
            throw new IllegalArgumentException("Too few samples.");
        }

        MultivariateMixture.Component single = new MultivariateMixture.Component();
        single.priori = 1.0d;
        single.distribution = new MultivariateGaussianDistribution(data, diagonal);
        ArrayList<MultivariateMixture.Component> mixture = new ArrayList<>();
        mixture.add(single);

        // log(n) is loop-invariant; compute it once.
        double logN = Math.log(data.length);

        double logLikelihood = 0.0d;
        for (double[] x : data) {
            double p = single.distribution.p(x);
            // Zero-density samples are skipped rather than contributing -Infinity.
            if (p > 0.0d) {
                logLikelihood += Math.log(p);
            }
        }
        double bic = logLikelihood - 0.5d * totalFreeParameters(mixture) * logN;

        double best = Double.NEGATIVE_INFINITY;
        while (bic > best) {
            best = bic;
            // Keep a snapshot that does not share Component objects with the list that split()
            // and EM() are about to modify.
            this.components = copyComponents(mixture);
            split(mixture);
            // EM's return value serves as the log-likelihood term of the score.
            double fitted = EM(mixture, data);
            bic = fitted - 0.5d * totalFreeParameters(mixture) * logN;
        }
    }

    /**
     * Replaces the component with the largest {@code scatter()} by two half-weight children.
     * The children share the parent's covariance and diagonal flag. Their means are offset by
     * plus and minus half the per-coordinate standard deviation {@code sqrt(cov[j][j]) / 2}.
     *
     * @param list the mixture to modify in place.
     * @throws IllegalStateException if no component has a positive scatter.
     */
    private void split(List<MultivariateMixture.Component> list) {
        MultivariateMixture.Component widest = null;
        double maxScatter = 0.0d;
        for (MultivariateMixture.Component candidate : list) {
            double scatter = ((MultivariateGaussianDistribution) candidate.distribution).scatter();
            if (scatter > maxScatter) {
                maxScatter = scatter;
                widest = candidate;
            }
        }
        if (widest == null) {
            throw new IllegalStateException("No mixture component with positive scatter to split.");
        }

        MultivariateGaussianDistribution parent = (MultivariateGaussianDistribution) widest.distribution;
        double[][] cov = parent.cov();
        double[] mean = parent.mean();
        double halfPriori = widest.priori / 2.0d;

        double[] upper = new double[mean.length];
        double[] lower = new double[mean.length];
        for (int j = 0; j < mean.length; j++) {
            double offset = Math.sqrt(cov[j][j]) / 2.0d;
            upper[j] = mean[j] + offset;
            lower[j] = mean[j] - offset;
        }

        // Propagate the parent's diagonal flag. The (mean, cov) constructor does not take one, and
        // the children would otherwise not be marked diagonal in a diagonal-mode fit.
        list.add(gaussianComponent(upper, cov, halfPriori, parent.diagonal));
        list.add(gaussianComponent(lower, cov, halfPriori, parent.diagonal));
        list.remove(widest);
    }

    /** Builds a component wrapping a Gaussian with the given mean, covariance, prior and flag. */
    private static MultivariateMixture.Component gaussianComponent(
            double[] mean, double[][] cov, double priori, boolean diagonal) {
        MultivariateMixture.Component component = new MultivariateMixture.Component();
        component.priori = priori;
        MultivariateGaussianDistribution gaussian = new MultivariateGaussianDistribution(mean, cov);
        gaussian.diagonal = diagonal;
        component.distribution = gaussian;
        return component;
    }

    /** Returns the per-column arithmetic mean of {@code data}; assumes at least one row. */
    private static double[] columnMeans(double[][] data) {
        int dim = data[0].length;
        double[] mu = new double[dim];
        for (double[] x : data) {
            for (int j = 0; j < dim; j++) {
                mu[j] += x[j];
            }
        }
        for (int j = 0; j < dim; j++) {
            mu[j] /= data.length;
        }
        return mu;
    }

    /**
     * Returns the sample covariance of {@code data} around {@code mu} (denominator n - 1).
     * When {@code diagonal} is true only the diagonal is filled; other entries stay 0.
     */
    private static double[][] sampleCovariance(double[][] data, double[] mu, boolean diagonal) {
        int dim = mu.length;
        double[][] cov = new double[dim][dim];
        // Accumulate the lower triangle (or just the diagonal); the matrix is symmetric.
        for (double[] x : data) {
            for (int j = 0; j < dim; j++) {
                for (int l = diagonal ? j : 0; l <= j; l++) {
                    cov[j][l] += (x[j] - mu[j]) * (x[l] - mu[l]);
                }
            }
        }
        for (int j = 0; j < dim; j++) {
            for (int l = diagonal ? j : 0; l <= j; l++) {
                cov[j][l] /= (data.length - 1);
                cov[l][j] = cov[j][l];
            }
        }
        return cov;
    }

    /**
     * Draws an index with probability proportional to its weight, using {@code Math.random()}.
     * Assumes a non-empty array of non-negative weights.
     */
    private static int drawIndexByWeight(double[] weights) {
        double target = Math.random() * Math.sum(weights);
        double cumulative = 0.0d;
        for (int j = 0; j < weights.length; j++) {
            cumulative += weights[j];
            if (cumulative >= target) {
                return j;
            }
        }
        // Reached only when the running sum never meets the target (e.g. NaN weights). Clamp to
        // the last index instead of indexing past the end of the array.
        return weights.length - 1;
    }

    /** Returns the total number of free parameters ({@code npara()}) over all components. */
    private static int totalFreeParameters(List<MultivariateMixture.Component> mixture) {
        int total = 0;
        for (MultivariateMixture.Component component : mixture) {
            total += component.distribution.npara();
        }
        return total;
    }

    /**
     * Returns a new list of new Component objects carrying the same prior and distribution
     * references.
     */
    // NOTE(review): assumes Component has no state beyond priori and distribution — confirm.
    private static ArrayList<MultivariateMixture.Component> copyComponents(
            List<MultivariateMixture.Component> mixture) {
        ArrayList<MultivariateMixture.Component> copy = new ArrayList<>(mixture.size());
        for (MultivariateMixture.Component component : mixture) {
            MultivariateMixture.Component duplicate = new MultivariateMixture.Component();
            duplicate.priori = component.priori;
            duplicate.distribution = component.distribution;
            copy.add(duplicate);
        }
        return copy;
    }
}
