package org.deidentifier.arx.aggregates.classification;

import org.apache.mahout.math.Matrix;
import org.apache.mahout.math.OrderedIntDoubleMapping;
import org.apache.mahout.math.Vector;
import org.apache.mahout.math.function.DoubleDoubleFunction;
import org.apache.mahout.math.function.DoubleFunction;
import org.apache.mahout.vectorizer.encoders.ConstantValueEncoder;
import org.apache.mahout.vectorizer.encoders.StaticWordValueEncoder;
import org.deidentifier.arx.DataHandleInternal;
import org.deidentifier.arx.aggregates.ClassificationConfigurationNaiveBayes;
import org.deidentifier.arx.common.WrappedBoolean;
import smile.classification.NaiveBayes;
import smile.classification.TrainingInterrupt;

/* loaded from: input_file:BOOT-INF/lib/libarx-3.8.0.jar:org/deidentifier/arx/aggregates/classification/MultiClassNaiveBayes.class */
public class MultiClassNaiveBayes extends ClassificationMethod {
    private final ClassificationConfigurationNaiveBayes config;
    private final ConstantValueEncoder interceptEncoder;
    private final NaiveBayes nb;
    private final ClassificationDataSpecification specification;
    private final StaticWordValueEncoder wordEncoder;
    private final DataHandleInternal inputHandle;

    /* JADX INFO: Access modifiers changed from: private */
    /* loaded from: input_file:BOOT-INF/lib/libarx-3.8.0.jar:org/deidentifier/arx/aggregates/classification/MultiClassNaiveBayes$NBVector.class */
    /**
     * Minimal array-backed {@link Vector} used as the target of the feature encoders
     * in the enclosing class.
     *
     * <p>Only {@link #get(int)}, {@link #set(int, double)} and {@link #size()} are
     * implemented; every other {@code Vector} operation throws
     * {@link UnsupportedOperationException}. The enclosing class reads the backing
     * array directly once encoding is done.
     *
     * <p>NOTE(review): presumably get/set/size are the only operations the Mahout
     * encoders invoke on the vector - confirm against the Mahout version in use.
     */
    public static class NBVector implements Vector {
        /** Backing storage; handed out as-is (no copy) by the enclosing class. */
        private final double[] array;

        /**
         * Creates a zero-initialized vector.
         *
         * @param length number of components
         */
        private NBVector(int length) {
            this.array = new double[length];
        }

        // ------------------------------------------------------------------
        // Supported operations
        // ------------------------------------------------------------------

        /** Returns the component at {@code index}. */
        @Override
        public double get(int index) {
            return this.array[index];
        }

        /** Overwrites the component at {@code index} with {@code value}. */
        @Override
        public void set(int index, double value) {
            this.array[index] = value;
        }

        /** Returns the number of components, i.e. the length of the backing array. */
        @Override
        public int size() {
            return this.array.length;
        }

        // ------------------------------------------------------------------
        // Unsupported operations: all of the following always throw
        // UnsupportedOperationException.
        // ------------------------------------------------------------------

        @Override
        public double aggregate(DoubleDoubleFunction aggregator, DoubleFunction map) {
            throw new UnsupportedOperationException();
        }

        @Override
        public double aggregate(Vector other, DoubleDoubleFunction aggregator, DoubleDoubleFunction combiner) {
            throw new UnsupportedOperationException();
        }

        @Override
        public Iterable<Vector.Element> all() {
            throw new UnsupportedOperationException();
        }

        @Override
        public String asFormatString() {
            throw new UnsupportedOperationException();
        }

        @Override
        public Vector assign(double value) {
            throw new UnsupportedOperationException();
        }

        @Override
        public Vector assign(double[] values) {
            throw new UnsupportedOperationException();
        }

        @Override
        public Vector assign(Vector other) {
            throw new UnsupportedOperationException();
        }

        @Override
        public Vector assign(DoubleFunction function) {
            throw new UnsupportedOperationException();
        }

        @Override
        public Vector assign(Vector other, DoubleDoubleFunction function) {
            throw new UnsupportedOperationException();
        }

        @Override
        public Vector assign(DoubleDoubleFunction function, double value) {
            throw new UnsupportedOperationException();
        }

        @Override
        public Matrix cross(Vector other) {
            throw new UnsupportedOperationException();
        }

        @Override
        public Vector divide(double divisor) {
            throw new UnsupportedOperationException();
        }

        @Override
        public double dot(Vector other) {
            throw new UnsupportedOperationException();
        }

        @Override
        public double getDistanceSquared(Vector other) {
            throw new UnsupportedOperationException();
        }

        @Override
        public Vector.Element getElement(int index) {
            throw new UnsupportedOperationException();
        }

        @Override
        public double getIteratorAdvanceCost() {
            throw new UnsupportedOperationException();
        }

        @Override
        public double getLengthSquared() {
            throw new UnsupportedOperationException();
        }

        @Override
        public double getLookupCost() {
            throw new UnsupportedOperationException();
        }

        @Override
        public int getNumNonZeroElements() {
            throw new UnsupportedOperationException();
        }

        @Override
        public int getNumNondefaultElements() {
            throw new UnsupportedOperationException();
        }

        @Override
        public double getQuick(int index) {
            throw new UnsupportedOperationException();
        }

        @Override
        public void incrementQuick(int index, double increment) {
            throw new UnsupportedOperationException();
        }

        @Override
        public boolean isAddConstantTime() {
            throw new UnsupportedOperationException();
        }

        @Override
        public boolean isDense() {
            throw new UnsupportedOperationException();
        }

        @Override
        public boolean isSequentialAccess() {
            throw new UnsupportedOperationException();
        }

        @Override
        public Vector like() {
            throw new UnsupportedOperationException();
        }

        @Override
        public Vector like(int cardinality) {
            throw new UnsupportedOperationException();
        }

        @Override
        public Vector logNormalize() {
            throw new UnsupportedOperationException();
        }

        @Override
        public Vector logNormalize(double power) {
            throw new UnsupportedOperationException();
        }

        @Override
        public double maxValue() {
            throw new UnsupportedOperationException();
        }

        @Override
        public int maxValueIndex() {
            throw new UnsupportedOperationException();
        }

        @Override
        public void mergeUpdates(OrderedIntDoubleMapping updates) {
            throw new UnsupportedOperationException();
        }

        @Override
        public double minValue() {
            throw new UnsupportedOperationException();
        }

        @Override
        public int minValueIndex() {
            throw new UnsupportedOperationException();
        }

        @Override
        public Vector minus(Vector other) {
            throw new UnsupportedOperationException();
        }

        @Override
        public Iterable<Vector.Element> nonZeroes() {
            throw new UnsupportedOperationException();
        }

        @Override
        public double norm(double power) {
            throw new UnsupportedOperationException();
        }

        @Override
        public Vector normalize() {
            throw new UnsupportedOperationException();
        }

        @Override
        public Vector normalize(double power) {
            throw new UnsupportedOperationException();
        }

        @Override
        public Vector plus(double value) {
            throw new UnsupportedOperationException();
        }

        @Override
        public Vector plus(Vector other) {
            throw new UnsupportedOperationException();
        }

        @Override
        public void setQuick(int index, double value) {
            throw new UnsupportedOperationException();
        }

        @Override
        public Vector times(double factor) {
            throw new UnsupportedOperationException();
        }

        @Override
        public Vector times(Vector other) {
            throw new UnsupportedOperationException();
        }

        @Override
        public Vector viewPart(int offset, int length) {
            throw new UnsupportedOperationException();
        }

        @Override
        public double zSum() {
            throw new UnsupportedOperationException();
        }

        /**
         * Cloning is not supported.
         *
         * <p>The decompiler had renamed this method to {@code m4500clone}; it must be
         * called {@code clone} to actually override the method declared by
         * {@code Vector} (the covariant return type is kept).
         *
         * @throws UnsupportedOperationException always
         */
        @Override
        public NBVector clone() {
            throw new UnsupportedOperationException();
        }
    }

    public MultiClassNaiveBayes(WrappedBoolean wrappedBoolean, ClassificationDataSpecification classificationDataSpecification, ClassificationConfigurationNaiveBayes classificationConfigurationNaiveBayes, DataHandleInternal dataHandleInternal) {
        super(wrappedBoolean);
        this.config = classificationConfigurationNaiveBayes;
        this.specification = classificationDataSpecification;
        this.inputHandle = dataHandleInternal;
        this.nb = new NaiveBayes(classificationConfigurationNaiveBayes.getType() == ClassificationConfigurationNaiveBayes.Type.BERNOULLI ? NaiveBayes.Model.BERNOULLI : NaiveBayes.Model.MULTINOMIAL, this.specification.classMap.size(), classificationConfigurationNaiveBayes.getVectorLength(), classificationConfigurationNaiveBayes.getSigma(), (TrainingInterrupt) null);
        this.interceptEncoder = new ConstantValueEncoder("intercept");
        this.wordEncoder = new StaticWordValueEncoder("feature");
    }

    /**
     * Classifies one row.
     *
     * <p>The row is encoded in classification mode, handed to the model together
     * with a buffer that has one slot per entry of the class map, and the model's
     * prediction plus that buffer are wrapped into a result object.
     *
     * @param dataHandleInternal handle to read feature values from
     * @param i row index
     * @return the classification result for the row
     */
    @Override // org.deidentifier.arx.aggregates.classification.ClassificationMethod
    public ClassificationResult classify(DataHandleInternal dataHandleInternal, int i) {
        final double[] probabilities = new double[this.specification.classMap.size()];
        final double[] features = encodeFeatures(dataHandleInternal, i, true);
        final int predicted = this.nb.predict(features, probabilities);
        return new MultiClassNaiveBayesClassificationResult(predicted, probabilities, this.specification.classMap);
    }

    /**
     * Closes this classification method. This implementation performs no action.
     */
    @Override // org.deidentifier.arx.aggregates.classification.ClassificationMethod
    public void close() {
        // Intentionally empty
    }

    /**
     * Feeds one training sample to the model.
     *
     * @param dataHandleInternal handle the feature values of the row are read from
     * @param dataHandleInternal2 handle the class value of the row is read from
     * @param i row index
     */
    @Override // org.deidentifier.arx.aggregates.classification.ClassificationMethod
    public void train(DataHandleInternal dataHandleInternal, DataHandleInternal dataHandleInternal2, int i) {
        // Features are encoded in training mode (flag = false), then the label
        final double[] features = encodeFeatures(dataHandleInternal, i, false);
        final int label = encodeClass(dataHandleInternal2, i);
        this.nb.learn(features, label);
    }

    /**
     * Translates the class value of a row into its numeric code.
     *
     * @param dataHandleInternal handle to read the class value from
     * @param i row index
     * @return the code registered for the value in the specification's class map
     */
    private int encodeClass(DataHandleInternal dataHandleInternal, int i) {
        final String label = dataHandleInternal.getValue(i, this.specification.classIndex, true);
        return this.specification.classMap.get(label).intValue();
    }

    /**
     * Encodes the features of one row into a vector of the configured length.
     *
     * <p>An intercept entry is always added. If the specification lists no features,
     * a single constant feature is added instead. Otherwise each feature is added
     * either under its attribute name with its numeric value as weight, or - when
     * the metadata yields NaN for the value - under attribute name plus value with
     * weight 1.
     *
     * @param dataHandleInternal handle to read feature values from
     * @param i row index
     * @param z when {@code true}, features whose metadata reports numeric
     *        microaggregation are read from the input handle instead of
     *        {@code dataHandleInternal}
     * @return the backing array of the encoded vector (not a copy)
     */
    private double[] encodeFeatures(DataHandleInternal dataHandleInternal, int i, boolean z) {
        final NBVector vector = new NBVector(this.config.getVectorLength());
        this.interceptEncoder.addToVector("1", vector);

        // Special case: nothing to encode besides a constant feature
        if (this.specification.featureIndices.length == 0) {
            this.wordEncoder.addToVector("Feature:1", 1.0d, vector);
            return vector.array;
        }

        // Metadata is stored by feature position, not by column index
        int position = 0;
        for (int column : this.specification.featureIndices) {
            final ClassificationFeatureMetadata metadata = this.specification.featureMetadata[position];

            // Pick the handle to read the value from
            final String value;
            if (z && metadata.isNumericMicroaggregation()) {
                value = this.inputHandle.getValue(i, column, true);
            } else {
                value = dataHandleInternal.getValue(i, column, true);
            }

            // NaN selects the per-value encoding, anything else is used as weight
            final double numeric = metadata.getNumericValue(value);
            if (Double.isNaN(numeric)) {
                this.wordEncoder.addToVector("Attribute-" + column + ":" + value, 1.0d, vector);
            } else {
                this.wordEncoder.addToVector("Attribute-" + column, numeric, vector);
            }
            position++;
        }
        return vector.array;
    }
}
