package smile.classification;

import java.io.Serializable;
import java.util.ArrayList;
import java.util.Arrays;
import java.util.Iterator;
import java.util.List;
import java.util.concurrent.Callable;
import org.slf4j.Logger;
import org.slf4j.LoggerFactory;
import smile.classification.DecisionTree;
import smile.data.Attribute;
import smile.data.NumericAttribute;
import smile.math.Math;
import smile.util.SmileUtils;
import smile.validation.Accuracy;
import smile.validation.ClassificationMeasure;

/* loaded from: input_file:BOOT-INF/lib/libarx-3.8.0.jar:smile/classification/RandomForest.class */
/**
 * Random forest classifier: an ensemble of {@link DecisionTree}s, each grown on a
 * class-weighted sample of the training data (a bootstrap sample when the sampling
 * rate is 1.0, otherwise a subsample drawn without replacement), with a limited
 * number of randomly selected features considered at each split.
 *
 * <p>Training also produces an out-of-bag (OOB) error estimate ({@link #error()}) and
 * the per-feature importance summed over all trees ({@link #importance()}). Each
 * tree's OOB accuracy is kept as its weight when combining posteriori probabilities
 * in {@link #predict(double[], double[])}.
 */
public class RandomForest extends SoftClassifier<double[]> implements Serializable {
    private static final long serialVersionUID = 1;
    private static final Logger logger = LoggerFactory.getLogger((Class<?>) RandomForest.class);

    /** The trained trees, each paired with its OOB-accuracy weight. */
    private List<Tree> trees;
    /** The number of classes. */
    private int k;
    /** OOB estimate of the misclassification rate. */
    private double error;
    /** Feature importance summed over all trees. */
    private double[] importance;

    /**
     * Builder-style trainer for {@link RandomForest}.
     */
    public static class Trainer extends ClassifierTrainer<double[]> {
        /** Number of trees to grow. */
        private int ntrees = 500;
        /** Split criterion used by every tree. */
        private DecisionTree.SplitRule rule = DecisionTree.SplitRule.GINI;
        /** Features tried per split; a value below 1 means "use floor(sqrt(p))". */
        private int mtry = -1;
        /** Minimum number of samples in a leaf. */
        private int nodeSize = 1;
        /** Maximum number of leaf nodes per tree. */
        private int maxNodes = 100;
        /** Sampling rate in (0, 1]; 1.0 selects sampling with replacement. */
        private double subsample = 1.0d;

        /**
         * @param attributes feature attributes
         * @param ntrees number of trees to grow (must be at least 1)
         * @throws IllegalArgumentException if {@code ntrees < 1}
         */
        public Trainer(Attribute[] attributes, int ntrees) {
            super(attributes);
            this.ntrees = checkNumTrees(ntrees);
        }

        /**
         * @param ntrees number of trees to grow (must be at least 1)
         * @param trainingInterrupt interrupt handle forwarded to the base trainer
         * @throws IllegalArgumentException if {@code ntrees < 1}
         */
        public Trainer(int ntrees, TrainingInterrupt trainingInterrupt) {
            super(trainingInterrupt);
            this.ntrees = checkNumTrees(ntrees);
        }

        /**
         * Creates a trainer with the default number of trees.
         *
         * @param trainingInterrupt interrupt handle forwarded to the base trainer
         */
        public Trainer(TrainingInterrupt trainingInterrupt) {
            super(trainingInterrupt);
        }

        /** Validates a tree count, returning it unchanged when valid. */
        private static int checkNumTrees(int ntrees) {
            if (ntrees < 1) {
                throw new IllegalArgumentException("Invalid number of trees: " + ntrees);
            }
            return ntrees;
        }

        /**
         * Sets the maximum number of leaf nodes per tree.
         *
         * @throws IllegalArgumentException if {@code maxNodes < 2}
         */
        public Trainer setMaxNodes(int maxNodes) {
            if (maxNodes < 2) {
                throw new IllegalArgumentException("Invalid maximum number of leaves: " + maxNodes);
            }
            this.maxNodes = maxNodes;
            return this;
        }

        /**
         * Sets the minimum number of samples in a leaf node.
         *
         * @throws IllegalArgumentException if {@code nodeSize < 1}
         */
        public Trainer setNodeSize(int nodeSize) {
            if (nodeSize < 1) {
                throw new IllegalArgumentException("Invalid minimum size of leaf nodes: " + nodeSize);
            }
            this.nodeSize = nodeSize;
            return this;
        }

        /**
         * Sets the number of randomly selected features considered at each split.
         *
         * @throws IllegalArgumentException if {@code mtry < 1}
         */
        public Trainer setNumRandomFeatures(int mtry) {
            if (mtry < 1) {
                throw new IllegalArgumentException("Invalid number of random selected features for splitting: " + mtry);
            }
            this.mtry = mtry;
            return this;
        }

        /**
         * Sets the number of trees to grow.
         *
         * @throws IllegalArgumentException if {@code ntrees < 1}
         */
        public Trainer setNumTrees(int ntrees) {
            this.ntrees = checkNumTrees(ntrees);
            return this;
        }

        /**
         * Sets the sampling rate; 1.0 means sampling with replacement.
         *
         * @throws IllegalArgumentException if the rate is not in (0, 1]
         */
        public Trainer setSamplingRates(double subsample) {
            if (subsample <= 0.0d || subsample > 1.0d) {
                throw new IllegalArgumentException("Invalid sampling rating: " + subsample);
            }
            this.subsample = subsample;
            return this;
        }

        /** Sets the split criterion used by every tree. */
        public Trainer setSplitRule(DecisionTree.SplitRule rule) {
            this.rule = rule;
            return this;
        }

        /**
         * Trains a random forest with this trainer's settings. If the number of
         * random features was never set, floor(sqrt(p)) is used, which is the same
         * default as the {@code RandomForest(attributes, x, y, ntrees, interrupt)}
         * constructor.
         */
        @Override // smile.classification.ClassifierTrainer
        public RandomForest train(double[][] x, int[] y) {
            int numFeatures = this.mtry;
            if (numFeatures < 1 && x.length > 0) {
                // The default mtry of -1 would always be rejected by the constructor.
                numFeatures = (int) Math.floor(Math.sqrt(x[0].length));
            }
            return new RandomForest(this.attributes, x, y, this.ntrees, this.maxNodes, this.nodeSize, numFeatures, this.subsample, this.rule, null);
        }
    }

    /**
     * Grows one tree of the forest. It also records the tree's votes for its
     * out-of-bag samples in the shared {@code prediction} matrix.
     */
    static class TrainingTask implements Callable<Tree> {
        Attribute[] attributes;
        double[][] x;
        int[] y;
        int mtry;
        int nodeSize;
        int maxNodes;
        double subsample;
        DecisionTree.SplitRule rule;
        /** Per-class divisor applied to the number of samples drawn from that class. */
        int[] classWeight;
        /** Pre-sorted feature index passed through to {@link DecisionTree}. */
        int[][] order;
        /** Shared OOB vote counts: {@code prediction[i][c]} votes for class c on sample i. */
        int[][] prediction;
        TrainingInterrupt interrupt;

        TrainingTask(Attribute[] attributes, double[][] x, int[] y, int maxNodes, int nodeSize, int mtry, double subsample, DecisionTree.SplitRule rule, int[] classWeight, int[][] order, int[][] prediction, TrainingInterrupt interrupt) {
            this.attributes = attributes;
            this.x = x;
            this.y = y;
            this.mtry = mtry;
            this.nodeSize = nodeSize;
            this.maxNodes = maxNodes;
            this.subsample = subsample;
            this.rule = rule;
            this.classWeight = classWeight;
            this.order = order;
            this.prediction = prediction;
            this.interrupt = interrupt;
        }

        /** Aborts training if the attached interrupt has been triggered. */
        protected void interrupt() throws TrainingInterruptedException {
            if (this.interrupt != null && this.interrupt.isInterrupted()) {
                throw new TrainingInterruptedException();
            }
        }

        /**
         * Draws the sample, grows the tree and scores it on the out-of-bag samples.
         *
         * @return the tree weighted by its OOB accuracy (1.0 if it has no OOB samples)
         */
        @Override // java.util.concurrent.Callable
        public Tree call() {
            int n = this.x.length;
            int numClasses = Math.max(this.y) + 1;

            // samples[i] = number of times instance i was drawn; 0 means out-of-bag.
            int[] samples = this.subsample == 1.0d ? drawBootstrap(n, numClasses) : drawSubsample(n, numClasses);

            DecisionTree tree = new DecisionTree(this.attributes, this.x, this.y, this.maxNodes, this.nodeSize, this.mtry, this.rule, samples, this.order, this.interrupt);

            int oob = 0;
            int correct = 0;
            for (int i = 0; i < n; i++) {
                interrupt();
                if (samples[i] != 0) {
                    continue;
                }
                oob++;
                int predicted = tree.predict(this.x[i]);
                if (predicted == this.y[i]) {
                    correct++;
                }
                // The vote matrix is shared between tasks; guard each row.
                synchronized (this.prediction[i]) {
                    this.prediction[i][predicted]++;
                }
            }

            double accuracy = 1.0d;
            if (oob != 0) {
                // Floating-point division: integer division would truncate to 0 or 1.
                accuracy = (double) correct / oob;
                RandomForest.logger.info("Random forest tree OOB size: {}, accuracy: {}", Integer.valueOf(oob), String.format("%.2f%%", Double.valueOf(100.0d * accuracy)));
            } else {
                RandomForest.logger.error("Random forest has a tree trained without OOB samples.");
            }
            return new Tree(tree, accuracy);
        }

        /**
         * Sampling with replacement, stratified by class. Class c contributes
         * {@code count(c) / classWeight[c]} draws (integer division).
         */
        private int[] drawBootstrap(int n, int numClasses) {
            int[] samples = new int[n];
            for (int c = 0; c < numClasses; c++) {
                List<Integer> members = new ArrayList<>();
                for (int i = 0; i < n; i++) {
                    interrupt();
                    if (this.y[i] == c) {
                        members.add(Integer.valueOf(i));
                    }
                }
                int size = members.size();
                int draws = size / this.classWeight[c];
                for (int d = 0; d < draws; d++) {
                    interrupt();
                    int idx = members.get(Math.randomInt(size)).intValue();
                    samples[idx]++;
                }
            }
            return samples;
        }

        /**
         * Sampling without replacement, stratified by class. Class c contributes
         * about {@code count(c) * subsample / classWeight[c]} instances, taken from
         * one random permutation of all indices.
         */
        private int[] drawSubsample(int n, int numClasses) {
            int[] samples = new int[n];
            int[] perm = new int[n];
            for (int i = 0; i < n; i++) {
                perm[i] = i;
            }
            Math.permutate(perm);

            int[] classCount = new int[numClasses];
            for (int i = 0; i < n; i++) {
                interrupt();
                classCount[this.y[i]]++;
            }

            for (int c = 0; c < numClasses; c++) {
                int quota = (int) Math.round((classCount[c] * this.subsample) / this.classWeight[c]);
                int taken = 0;
                for (int i = 0; i < n && taken < quota; i++) {
                    interrupt();
                    int idx = perm[i];
                    if (this.y[idx] == c) {
                        samples[idx]++;
                        taken++;
                    }
                }
            }
            return samples;
        }
    }

    /**
     * A trained decision tree paired with its OOB accuracy. The accuracy is used as
     * the tree's weight when combining posteriori probabilities.
     */
    public static class Tree implements Serializable {
        private static final long serialVersionUID = 7167971471207545655L;
        DecisionTree tree;
        double weight;

        Tree(DecisionTree tree, double weight) {
            this.tree = tree;
            this.weight = weight;
        }
    }

    /**
     * Trains a random forest.
     *
     * @param attributes feature attributes, or {@code null} to treat every column as numeric
     * @param x training samples (at least one row)
     * @param y class labels, which must be exactly 0..k-1 with no gaps
     * @param ntrees number of trees (at least 1)
     * @param maxNodes maximum number of leaf nodes per tree (at least 2)
     * @param nodeSize minimum number of samples in a leaf (at least 1)
     * @param mtry number of features considered at each split, in [1, p]
     * @param subsample sampling rate in (0, 1]; 1.0 means sampling with replacement
     * @param rule split criterion
     * @param classWeight per-class sampling divisors (each at least 1), or {@code null} for all 1
     * @param trainingInterrupt optional interrupt handle, may be {@code null}
     * @throws IllegalArgumentException on any invalid argument
     */
    public RandomForest(Attribute[] attributes, double[][] x, int[] y, int ntrees, int maxNodes, int nodeSize, int mtry, double subsample, DecisionTree.SplitRule rule, int[] classWeight, TrainingInterrupt trainingInterrupt) {
        super(trainingInterrupt);
        if (x.length != y.length) {
            throw new IllegalArgumentException(String.format("The sizes of X and Y don't match: %d != %d", Integer.valueOf(x.length), Integer.valueOf(y.length)));
        }
        if (ntrees < 1) {
            throw new IllegalArgumentException("Invalid number of trees: " + ntrees);
        }
        if (x.length == 0) {
            throw new IllegalArgumentException("Empty training data");
        }
        if (mtry < 1 || mtry > x[0].length) {
            throw new IllegalArgumentException("Invalid number of variables to split on at a node of the tree: " + mtry);
        }
        if (nodeSize < 1) {
            throw new IllegalArgumentException("Invalid minimum size of leaves: " + nodeSize);
        }
        if (maxNodes < 2) {
            throw new IllegalArgumentException("Invalid maximum number of leaves: " + maxNodes);
        }
        if (subsample <= 0.0d || subsample > 1.0d) {
            throw new IllegalArgumentException("Invalid sampling rating: " + subsample);
        }

        int[] labels = Math.unique(y);
        Arrays.sort(labels);
        for (int i = 0; i < labels.length; i++) {
            if (labels[i] < 0) {
                throw new IllegalArgumentException("Negative class label: " + labels[i]);
            }
            if (i > 0 && labels[i] - labels[i - 1] > 1) {
                // Parenthesised so the missing label is added, not string-concatenated.
                throw new IllegalArgumentException("Missing class: " + (labels[i - 1] + 1));
            }
        }
        this.k = labels.length;
        if (this.k < 2) {
            throw new IllegalArgumentException("Only one class.");
        }
        if (labels[0] != 0) {
            // Labels index per-class arrays of length k, so they must start at 0.
            throw new IllegalArgumentException("Missing class: 0");
        }

        if (attributes == null) {
            int p = x[0].length;
            attributes = new Attribute[p];
            for (int j = 0; j < p; j++) {
                attributes[j] = new NumericAttribute("V" + (j + 1));
            }
        }

        if (classWeight == null) {
            classWeight = new int[this.k];
            Arrays.fill(classWeight, 1);
        } else {
            if (classWeight.length < this.k) {
                throw new IllegalArgumentException(String.format("Invalid class weight vector size: %d, expected: %d", Integer.valueOf(classWeight.length), Integer.valueOf(this.k)));
            }
            for (int c = 0; c < this.k; c++) {
                // Weights are used as divisors of per-class sample counts.
                if (classWeight[c] < 1) {
                    throw new IllegalArgumentException("Invalid class weight: " + classWeight[c]);
                }
            }
        }

        int n = x.length;
        int[][] prediction = new int[n][this.k];
        int[][] order = SmileUtils.sort(attributes, x);

        List<TrainingTask> tasks = new ArrayList<>(ntrees);
        for (int t = 0; t < ntrees; t++) {
            tasks.add(new TrainingTask(attributes, x, y, maxNodes, nodeSize, mtry, subsample, rule, classWeight, order, prediction, trainingInterrupt));
        }
        this.trees = new ArrayList<>(ntrees);
        for (TrainingTask task : tasks) {
            this.trees.add(task.call());
        }

        // OOB error: majority vote among the trees for which a sample was out-of-bag.
        int scored = 0;
        for (int i = 0; i < n; i++) {
            interrupt();
            int voted = Math.whichMax(prediction[i]);
            if (prediction[i][voted] > 0) {
                scored++;
                if (voted != y[i]) {
                    this.error += 1.0d;
                }
            }
        }
        if (scored > 0) {
            this.error /= scored;
        }

        this.importance = new double[attributes.length];
        for (Tree tree : this.trees) {
            interrupt();
            double[] treeImportance = tree.tree.importance();
            for (int j = 0; j < treeImportance.length; j++) {
                this.importance[j] += treeImportance[j];
            }
        }
    }

    /** Trains with uniform class weights. */
    public RandomForest(Attribute[] attributes, double[][] x, int[] y, int ntrees, int maxNodes, int nodeSize, int mtry, double subsample, DecisionTree.SplitRule rule, TrainingInterrupt trainingInterrupt) {
        this(attributes, x, y, ntrees, maxNodes, nodeSize, mtry, subsample, rule, null, trainingInterrupt);
    }

    /** Trains with the GINI split rule and uniform class weights. */
    public RandomForest(Attribute[] attributes, double[][] x, int[] y, int ntrees, int maxNodes, int nodeSize, int mtry, double subsample, TrainingInterrupt trainingInterrupt) {
        this(attributes, x, y, ntrees, maxNodes, nodeSize, mtry, subsample, DecisionTree.SplitRule.GINI, trainingInterrupt);
    }

    /** Trains with maxNodes = 100, nodeSize = 5 and sampling with replacement. */
    public RandomForest(Attribute[] attributes, double[][] x, int[] y, int ntrees, int mtry, TrainingInterrupt trainingInterrupt) {
        this(attributes, x, y, ntrees, 100, 5, mtry, 1.0d, trainingInterrupt);
    }

    /** Trains with mtry = floor(sqrt(p)). */
    public RandomForest(Attribute[] attributes, double[][] x, int[] y, int ntrees, TrainingInterrupt trainingInterrupt) {
        this(attributes, x, y, ntrees, (int) Math.floor(Math.sqrt(x[0].length)), trainingInterrupt);
    }

    /** Trains on all-numeric features with the given mtry. */
    public RandomForest(double[][] x, int[] y, int ntrees, int mtry, TrainingInterrupt trainingInterrupt) {
        this(null, x, y, ntrees, mtry, trainingInterrupt);
    }

    /** Trains on all-numeric features with mtry = floor(sqrt(p)). */
    public RandomForest(double[][] x, int[] y, int ntrees, TrainingInterrupt trainingInterrupt) {
        this((Attribute[]) null, x, y, ntrees, trainingInterrupt);
    }

    /** Returns the out-of-bag estimate of the misclassification rate. */
    public double error() {
        return this.error;
    }

    /** Returns the underlying decision trees, in training order. */
    public DecisionTree[] getTrees() {
        DecisionTree[] result = new DecisionTree[this.trees.size()];
        for (int t = 0; t < result.length; t++) {
            result[t] = this.trees.get(t).tree;
        }
        return result;
    }

    /** Returns the feature importance summed over all trees (the internal array). */
    public double[] importance() {
        return this.importance;
    }

    /** Predicts the class of {@code x} by unweighted majority vote of the trees. */
    @Override // smile.classification.Classifier
    public int predict(double[] x) {
        int[] votes = new int[this.k];
        for (Tree t : this.trees) {
            votes[t.tree.predict(x)]++;
        }
        return Math.whichMax(votes);
    }

    /**
     * Predicts the class of {@code x} by unweighted majority vote. It also fills
     * {@code posteriori} with the OOB-accuracy-weighted sum of the trees' posteriori
     * vectors, passed through {@code Math.unitize1} (presumably sum-to-one
     * normalisation; confirm against smile.math.Math).
     *
     * @throws IllegalArgumentException if {@code posteriori.length != k}
     */
    @Override // smile.classification.SoftClassifier
    public int predict(double[] x, double[] posteriori) {
        if (posteriori.length != this.k) {
            throw new IllegalArgumentException(String.format("Invalid posteriori vector size: %d, expected: %d", Integer.valueOf(posteriori.length), Integer.valueOf(this.k)));
        }
        Arrays.fill(posteriori, 0.0d);
        int[] votes = new int[this.k];
        double[] treePosteriori = new double[this.k];
        for (Tree t : this.trees) {
            votes[t.tree.predict(x, treePosteriori)]++;
            for (int c = 0; c < this.k; c++) {
                posteriori[c] += t.weight * treePosteriori[c];
            }
        }
        Math.unitize1(posteriori);
        return Math.whichMax(votes);
    }

    /** Returns the number of trees in the forest. */
    public int size() {
        return this.trees.size();
    }

    /**
     * Evaluates growing prefixes of the forest on a test set.
     *
     * @return element i is the accuracy of the ensemble made of the first i + 1 trees
     */
    public double[] test(double[][] x, int[] y) {
        int ntrees = this.trees.size();
        double[] accuracy = new double[ntrees];
        int n = x.length;
        int[] label = new int[n];
        int[][] votes = new int[n][this.k];
        Accuracy measure = new Accuracy();
        for (int t = 0; t < ntrees; t++) {
            DecisionTree tree = this.trees.get(t).tree;
            for (int i = 0; i < n; i++) {
                votes[i][tree.predict(x[i])]++;
                label[i] = Math.whichMax(votes[i]);
            }
            accuracy[t] = measure.measure(y, label);
        }
        return accuracy;
    }

    /**
     * Evaluates growing prefixes of the forest on a test set with several measures.
     *
     * @return {@code result[i][j]} is measure j for the ensemble of the first i + 1 trees
     */
    public double[][] test(double[][] x, int[] y, ClassificationMeasure[] measures) {
        int ntrees = this.trees.size();
        int m = measures.length;
        double[][] results = new double[ntrees][m];
        int n = x.length;
        int[] label = new int[n];
        double[][] votes = new double[n][this.k];
        for (int t = 0; t < ntrees; t++) {
            DecisionTree tree = this.trees.get(t).tree;
            for (int i = 0; i < n; i++) {
                votes[i][tree.predict(x[i])] += 1.0d;
                label[i] = Math.whichMax(votes[i]);
            }
            for (int j = 0; j < m; j++) {
                results[t][j] = measures[j].measure(y, label);
            }
        }
        return results;
    }

    /**
     * Keeps only the first {@code ntrees} trees of the forest.
     *
     * @throws IllegalArgumentException if {@code ntrees} exceeds the current size or is not positive
     */
    public void trim(int ntrees) {
        if (ntrees > this.trees.size()) {
            throw new IllegalArgumentException("The new model size is larger than the current size.");
        }
        if (ntrees <= 0) {
            throw new IllegalArgumentException("Invalid new model size: " + ntrees);
        }
        this.trees = new ArrayList<>(this.trees.subList(0, ntrees));
    }
}
