package org.deidentifier.arx.aggregates.classification;

import com.carrotsearch.hppc.IntArrayList;
import com.carrotsearch.hppc.IntIntOpenHashMap;
import java.util.ArrayList;
import java.util.List;
import org.deidentifier.arx.DataHandleInternal;
import org.deidentifier.arx.aggregates.ClassificationConfigurationRandomForest;
import org.deidentifier.arx.common.WrappedBoolean;
import smile.classification.DecisionTree;
import smile.classification.RandomForest;
import smile.classification.TrainingInterrupt;
import smile.data.Attribute;

/* loaded from: input_file:BOOT-INF/lib/libarx-3.8.0.jar:org/deidentifier/arx/aggregates/classification/MultiClassRandomForest.class */
/**
 * Multi-class classifier backed by a Smile {@link RandomForest}.
 *
 * <p>Training records are buffered by {@link #train}; the forest itself is built when
 * {@link #close()} is called. Only after that can {@link #classify} be used.
 */
public class MultiClassRandomForest extends ClassificationMethod {

    /** Configuration of the random forest (trees, leaf sizes, split rule, ...). */
    private final ClassificationConfigurationRandomForest config;
    /** Trained model; {@code null} until {@link #close()} has completed. */
    private RandomForest rm;
    /** Specification of the feature attributes and the class attribute. */
    private final ClassificationDataSpecification specification;
    /** Feature vectors buffered by {@link #train}. */
    private List<double[]> features = new ArrayList<>();
    /** Class codes buffered by {@link #train}, parallel to {@link #features}. */
    private IntArrayList classes = new IntArrayList();
    /** Number of variables considered at each split. */
    private final int numberOfVariablesToSplit;
    /** Handle read for numerically microaggregated features during classification. */
    private final DataHandleInternal inputHandle;
    /** Maps the dense label indices used by the forest back to the specification's class codes. */
    private IntIntOpenHashMap mapping;

    /**
     * Creates a new random forest classifier.
     *
     * @param interrupt     flag polled during training to allow aborting it
     * @param specification specification of features and class attribute
     * @param config        random forest configuration; a configured number of variables
     *                      to split of {@code 0} means "use floor(sqrt(#features))"
     * @param inputHandle   handle used for numerically microaggregated features when classifying
     */
    public MultiClassRandomForest(WrappedBoolean interrupt,
                                  ClassificationDataSpecification specification,
                                  ClassificationConfigurationRandomForest config,
                                  DataHandleInternal inputHandle) {
        super(interrupt);
        this.config = config;
        this.specification = specification;
        this.inputHandle = inputHandle;
        int configured = config.getNumberOfVariablesToSplit();
        this.numberOfVariablesToSplit = (configured == 0)
                ? (int) Math.floor(Math.sqrt(specification.featureIndices.length))
                : configured;
    }

    /**
     * Classifies a single record with the trained forest.
     *
     * @param handle handle providing the record's feature values
     * @param row    row index of the record
     * @return the predicted class code and per-class probabilities, indexed by the
     *         specification's class codes
     * @throws IllegalStateException if {@link #close()} has not successfully trained a model
     */
    @Override
    public ClassificationResult classify(DataHandleInternal handle, int row) {
        if (this.rm == null || this.mapping == null) {
            throw new IllegalStateException("Classifier has not been trained");
        }

        // The forest predicts dense labels 0..k-1 and fills one posterior per dense label
        double[] densePosteriori = new double[this.mapping.size()];
        int denseLabel = this.rm.predict(encodeFeatures(handle, row, true), densePosteriori);

        // Re-index the probabilities by the specification's (possibly sparse) class codes
        double[] probabilities = new double[this.specification.classMap.size()];
        for (int dense = 0; dense < densePosteriori.length; dense++) {
            probabilities[this.mapping.get(dense)] = densePosteriori[dense];
        }
        return new MultiClassRandomForestClassificationResult(this.mapping.get(denseLabel),
                                                              probabilities,
                                                              this.specification.classMap);
    }

    /**
     * Builds the random forest from all records buffered by {@link #train} and then
     * releases the buffered training data. Training polls the interrupt flag passed to
     * the constructor.
     */
    @Override
    public void close() {
        DecisionTree.SplitRule splitRule = toSmileSplitRule();

        // Class codes are remapped to dense indices 0..k-1; 'mapping' translates them back
        this.mapping = new IntIntOpenHashMap();
        int[] labels = remapClasses(this.mapping);

        double[][] samples = this.features.toArray(new double[this.features.size()][]);
        this.rm = new RandomForest((Attribute[]) null,
                                   samples,
                                   labels,
                                   this.config.getNumberOfTrees(),
                                   this.config.getMaximumNumberOfLeafNodes(),
                                   this.config.getMinimumSizeOfLeafNodes(),
                                   this.numberOfVariablesToSplit,
                                   this.config.getSubsample(),
                                   splitRule,
                                   new TrainingInterrupt() {
                                       @Override
                                       public boolean isInterrupted() {
                                           return MultiClassRandomForest.this.interrupt.value;
                                       }
                                   });

        // Drop references to the training data so it can be garbage collected
        this.features = new ArrayList<>();
        this.classes = new IntArrayList();
    }

    /**
     * Buffers one training record.
     *
     * @param featuresHandle handle providing the feature values
     * @param classHandle    handle providing the class value
     * @param row            row index of the record
     */
    @Override
    public void train(DataHandleInternal featuresHandle, DataHandleInternal classHandle, int row) {
        this.features.add(encodeFeatures(featuresHandle, row, false));
        this.classes.add(encodeClass(classHandle, row));
    }

    /** Translates the configured split rule into Smile's split rule. */
    private DecisionTree.SplitRule toSmileSplitRule() {
        switch (this.config.getSplitRule()) {
            case CLASSIFICATION_ERROR:
                return DecisionTree.SplitRule.CLASSIFICATION_ERROR;
            case ENTROPY:
                return DecisionTree.SplitRule.ENTROPY;
            case GINI:
                return DecisionTree.SplitRule.GINI;
            default:
                throw new IllegalStateException("Unknown split rule");
        }
    }

    /**
     * Remaps the buffered class codes to dense indices 0..k-1, assigned in order of
     * first appearance.
     *
     * @param denseToOriginal receives the inverse mapping (dense index -> class code)
     * @return the dense label of every buffered record, in buffer order
     */
    private int[] remapClasses(IntIntOpenHashMap denseToOriginal) {
        IntIntOpenHashMap originalToDense = new IntIntOpenHashMap();
        int[] labels = new int[this.classes.size()];
        for (int i = 0; i < labels.length; i++) {
            int original = this.classes.get(i);
            int dense;
            if (originalToDense.containsKey(original)) {
                // lget() returns the value for the key found by the preceding containsKey()
                dense = originalToDense.lget();
            } else {
                dense = originalToDense.size();
                originalToDense.put(original, dense);
                denseToOriginal.put(dense, original);
            }
            labels[i] = dense;
        }
        return labels;
    }

    /** Looks up the class code of the given row's class value in the specification's class map. */
    private int encodeClass(DataHandleInternal handle, int row) {
        return this.specification.classMap.get(handle.getValue(row, this.specification.classIndex, true)).intValue();
    }

    /**
     * Encodes the feature attributes of one row as a numeric vector.
     *
     * @param handle   handle providing the values
     * @param row      row index
     * @param classify {@code true} when encoding for classification; numerically
     *                 microaggregated features are then read from {@link #inputHandle}
     * @return one entry per feature index of the specification
     */
    private double[] encodeFeatures(DataHandleInternal handle, int row, boolean classify) {
        int[] indices = this.specification.featureIndices;
        double[] vector = new double[indices.length];
        for (int i = 0; i < indices.length; i++) {
            int column = indices[i];
            ClassificationFeatureMetadata metadata = this.specification.featureMetadata[i];
            String value = (classify && metadata.isNumericMicroaggregation())
                    ? this.inputHandle.getValue(row, column, true)
                    : handle.getValue(row, column, true);
            double numeric = metadata.getNumericValue(value);
            // Values without a numeric interpretation fall back to the handle's identifier for the value
            vector[i] = Double.isNaN(numeric) ? handle.getValueIdentifier(column, value) : numeric;
        }
        return vector;
    }
}
