package org.apache.mahout.classifier.sgd;

import com.google.common.collect.Lists;
import java.io.DataInput;
import java.io.DataOutput;
import java.io.IOException;
import java.util.Iterator;
import java.util.List;
import org.apache.hadoop.io.Writable;
import org.apache.mahout.classifier.AbstractVectorClassifier;
import org.apache.mahout.classifier.OnlineLearner;
import org.apache.mahout.math.DenseVector;
import org.apache.mahout.math.Vector;
import org.apache.mahout.math.function.DoubleDoubleFunction;
import org.apache.mahout.math.function.Functions;
import org.apache.mahout.math.stats.GlobalOnlineAuc;
import org.apache.mahout.math.stats.OnlineAuc;

/* loaded from: input_file:BOOT-INF/lib/libarx-3.8.0.jar:org/apache/mahout/classifier/sgd/CrossFoldLearner.class */
/**
 * An online learner that maintains several {@link OnlineLogisticRegression} models arranged as
 * k-fold partitions, so that held-out performance can be estimated while training proceeds.
 *
 * <p>Every training example is routed by its tracking key. The model at index
 * {@code trackingKey mod models.size()} is <em>evaluated</em> on the example, which updates the
 * running log-likelihood and percent-correct and, for two-category problems, the AUC estimator.
 * Every other model is <em>trained</em> on it. Classification averages the outputs of all models.
 *
 * <p>The running statistics use a divisor of {@code min(record, windowSize)}. They are therefore a
 * plain running mean for the first {@code windowSize} records. After that they become an
 * exponentially-weighted average with weight {@code 1 / windowSize}.
 */
public class CrossFoldLearner extends AbstractVectorClassifier implements OnlineLearner, Writable {
    /** Floor applied to a probability before taking its log, so a zero score cannot yield -Infinity. */
    private static final double MIN_SCORE = 1.0E-50d;

    /** Number of values in {@link #parameters}; fixed by the serialized format read in {@link #readFields}. */
    private static final int NUM_PARAMETERS = 4;

    /** Number of examples passed to {@code train}; also the default tracking key. */
    private int record;

    /** AUC estimator fed with held-out scores (only for two-category problems). */
    private OnlineAuc auc = new GlobalOnlineAuc();

    /** Windowed average log-probability that held-out models assign to the actual category. */
    private double logLikelihood;

    /** One model per fold. */
    private final List<OnlineLogisticRegression> models = Lists.newArrayList();

    /**
     * Parameter array that this class only stores and (de)serializes.
     * NOTE(review): presumably hyper-parameters managed by a caller — confirm.
     */
    private double[] parameters = new double[NUM_PARAMETERS];

    private int numFeatures;

    private PriorFunction prior;

    /** Windowed fraction of held-out examples whose top-scoring category was the actual one. */
    private double percentCorrect;

    /** Upper bound on the divisor used for the running statistics. */
    private int windowSize = Integer.MAX_VALUE;

    /**
     * Creates a learner with no models. Models must be supplied through {@link #addModel} or
     * {@link #readFields} before the learner can classify or report {@link #numCategories()}.
     */
    public CrossFoldLearner() {
    }

    /**
     * Creates a learner with {@code folds} freshly constructed models.
     *
     * <p>Each model is configured with {@code alpha = 1}, {@code stepOffset = 0} and
     * {@code decayExponent = 0}.
     *
     * @param folds         number of models (one per fold)
     * @param numCategories number of target categories for each model
     * @param numFeatures   number of features for each model
     * @param prior         prior function shared by all models
     */
    public CrossFoldLearner(int folds, int numCategories, int numFeatures, PriorFunction prior) {
        this.numFeatures = numFeatures;
        this.prior = prior;
        for (int fold = 0; fold < folds; fold++) {
            OnlineLogisticRegression model = new OnlineLogisticRegression(numCategories, numFeatures, prior);
            model.alpha(1.0d).stepOffset(0).decayExponent(0.0d);
            models.add(model);
        }
    }

    /**
     * Sets {@code lambda} on every model.
     *
     * @return this learner, for chaining
     */
    public CrossFoldLearner lambda(double lambda) {
        for (OnlineLogisticRegression model : models) {
            model.lambda(lambda);
        }
        return this;
    }

    /**
     * Sets the learning rate on every model.
     *
     * @return this learner, for chaining
     */
    public CrossFoldLearner learningRate(double learningRate) {
        for (OnlineLogisticRegression model : models) {
            model.learningRate(learningRate);
        }
        return this;
    }

    /**
     * Sets the step offset on every model.
     *
     * @return this learner, for chaining
     */
    public CrossFoldLearner stepOffset(int stepOffset) {
        for (OnlineLogisticRegression model : models) {
            model.stepOffset(stepOffset);
        }
        return this;
    }

    /**
     * Sets the decay exponent on every model.
     *
     * @return this learner, for chaining
     */
    public CrossFoldLearner decayExponent(double decayExponent) {
        for (OnlineLogisticRegression model : models) {
            model.decayExponent(decayExponent);
        }
        return this;
    }

    /**
     * Sets {@code alpha} on every model.
     *
     * @return this learner, for chaining
     */
    public CrossFoldLearner alpha(double alpha) {
        for (OnlineLogisticRegression model : models) {
            model.alpha(alpha);
        }
        return this;
    }

    /** Trains on one example, using the current record count as the tracking key. */
    @Override
    public void train(int actual, Vector instance) {
        train(record, null, actual, instance);
    }

    /** Trains on one example with an explicit tracking key and no group key. */
    @Override
    public void train(long trackingKey, int actual, Vector instance) {
        train(trackingKey, null, actual, instance);
    }

    /**
     * Routes one example to the models.
     *
     * <p>The model at index {@code trackingKey mod models.size()} is scored on the example and the
     * running statistics are updated. All other models are trained on it.
     *
     * @param trackingKey key selecting the held-out fold (negative keys are wrapped into range)
     * @param groupKey    group key forwarded to the models and the AUC estimator; may be null
     * @param actual      index of the true category
     * @param instance    feature vector
     */
    @Override
    public void train(long trackingKey, String groupKey, int actual, Vector instance) {
        record++;
        int folds = models.size();
        if (folds == 0) {
            // Nothing to train or evaluate; also avoids computing a modulus by zero.
            return;
        }
        long heldOut = mod(trackingKey, folds);
        int fold = 0;
        for (OnlineLogisticRegression model : models) {
            if (fold == heldOut) {
                updateHeldOutStatistics(model, groupKey, actual, instance);
            } else {
                model.train(trackingKey, groupKey, actual, instance);
            }
            fold++;
        }
    }

    /** Scores {@code instance} with the held-out {@code model} and folds the result into the running statistics. */
    private void updateHeldOutStatistics(OnlineLogisticRegression model, String groupKey, int actual, Vector instance) {
        Vector scores = model.classifyFull(instance);
        // Once record exceeds windowSize the divisor stays fixed, so the mean becomes exponentially weighted.
        int divisor = Math.min(record, windowSize);
        logLikelihood += (Math.log(Math.max(scores.get(actual), MIN_SCORE)) - logLikelihood) / divisor;
        percentCorrect += ((scores.maxValueIndex() == actual ? 1 : 0) - percentCorrect) / divisor;
        if (numCategories() == 2) {
            auc.addSample(actual, groupKey, scores.get(1));
        }
    }

    /** Returns {@code value mod divisor} in the range {@code [0, divisor)}, even for negative {@code value}. */
    private static long mod(long value, int divisor) {
        long remainder = value % divisor;
        return remainder < 0 ? remainder + divisor : remainder;
    }

    /** Calls {@code close()} on every model. */
    @Override
    public void close() {
        for (OnlineLogisticRegression model : models) {
            model.close();
        }
    }

    /** Resets the record counter that drives the default tracking key and the averaging divisor. */
    public void resetLineCounter() {
        record = 0;
    }

    /**
     * Returns true only if every model reports {@code validModel()}; true when there are no models.
     * Every model is queried (no short-circuit).
     */
    public boolean validModel() {
        boolean valid = true;
        for (OnlineLogisticRegression model : models) {
            valid &= model.validModel();
        }
        return valid;
    }

    /** Returns the element-wise mean of every model's {@code classify} output ({@code numCategories() - 1} entries). */
    @Override
    public Vector classify(Vector instance) {
        DenseVector average = new DenseVector(numCategories() - 1);
        DoubleDoubleFunction accumulate = Functions.plusMult(1.0d / models.size());
        for (OnlineLogisticRegression model : models) {
            average.assign(model.classify(instance), accumulate);
        }
        return average;
    }

    /** Returns the element-wise mean of every model's {@code classifyNoLink} output ({@code numCategories() - 1} entries). */
    @Override
    public Vector classifyNoLink(Vector instance) {
        DenseVector average = new DenseVector(numCategories() - 1);
        DoubleDoubleFunction accumulate = Functions.plusMult(1.0d / models.size());
        for (OnlineLogisticRegression model : models) {
            average.assign(model.classifyNoLink(instance), accumulate);
        }
        return average;
    }

    /** Returns the mean of every model's {@code classifyScalar} output (NaN when there are no models). */
    @Override
    public double classifyScalar(Vector instance) {
        double sum = 0.0d;
        for (OnlineLogisticRegression model : models) {
            sum += model.classifyScalar(instance);
        }
        return sum / models.size();
    }

    /**
     * Returns the number of categories of the first model.
     *
     * @throws IndexOutOfBoundsException if the learner has no models
     */
    @Override
    public int numCategories() {
        return models.get(0).numCategories();
    }

    /** Returns the current estimate from the AUC evaluator. */
    public double auc() {
        return auc.auc();
    }

    /** Returns the windowed average held-out log-likelihood. */
    public double logLikelihood() {
        return logLikelihood;
    }

    /** Returns the windowed held-out fraction of correct top-1 predictions. */
    public double percentCorrect() {
        return percentCorrect;
    }

    /**
     * Returns a new learner whose models are copies of this learner's models.
     *
     * <p>Side effect: each source model is {@code close()}d before being copied. Only the models
     * and the constructor arguments (fold count, categories, features, prior) are carried over.
     * The record count, running statistics, AUC evaluator, parameters and window size start from
     * their defaults.
     */
    public CrossFoldLearner copy() {
        CrossFoldLearner copy = new CrossFoldLearner(models.size(), numCategories(), numFeatures, prior);
        copy.models.clear();
        for (OnlineLogisticRegression source : models) {
            source.close();
            OnlineLogisticRegression clone =
                new OnlineLogisticRegression(source.numCategories(), source.numFeatures(), source.prior);
            clone.copyFrom(source);
            copy.models.add(clone);
        }
        return copy;
    }

    public int getRecord() {
        return record;
    }

    public void setRecord(int record) {
        this.record = record;
    }

    public OnlineAuc getAucEvaluator() {
        return auc;
    }

    public void setAucEvaluator(OnlineAuc auc) {
        this.auc = auc;
    }

    public double getLogLikelihood() {
        return logLikelihood;
    }

    public void setLogLikelihood(double logLikelihood) {
        this.logLikelihood = logLikelihood;
    }

    /** Returns the live (mutable) list of models. */
    public List<OnlineLogisticRegression> getModels() {
        return models;
    }

    public void addModel(OnlineLogisticRegression model) {
        models.add(model);
    }

    public double[] getParameters() {
        return parameters;
    }

    /**
     * Replaces the parameter array (stored by reference).
     * Only arrays of length {@value #NUM_PARAMETERS} can be serialized by {@link #write}.
     */
    public void setParameters(double[] parameters) {
        this.parameters = parameters;
    }

    public int getNumFeatures() {
        return numFeatures;
    }

    public void setNumFeatures(int numFeatures) {
        this.numFeatures = numFeatures;
    }

    /**
     * Sets the averaging window for the running statistics and forwards it to the AUC evaluator.
     *
     * @throws IllegalArgumentException if {@code windowSize} is not positive. Such a value would
     *         make the averaging divisor zero or negative.
     */
    public void setWindowSize(int windowSize) {
        if (windowSize <= 0) {
            throw new IllegalArgumentException("windowSize must be positive: " + windowSize);
        }
        this.windowSize = windowSize;
        auc.setWindowSize(windowSize);
    }

    public PriorFunction getPrior() {
        return prior;
    }

    public void setPrior(PriorFunction prior) {
        this.prior = prior;
    }

    /**
     * Serializes this learner. The field order must match {@link #readFields}.
     *
     * @throws IllegalStateException if the parameter array is null or not of length
     *         {@value #NUM_PARAMETERS}. Such an array could not be read back, so it is rejected
     *         before anything is written.
     */
    @Override
    public void write(DataOutput out) throws IOException {
        if (parameters == null || parameters.length != NUM_PARAMETERS) {
            throw new IllegalStateException("Expected " + NUM_PARAMETERS + " parameters but found "
                + (parameters == null ? "null" : String.valueOf(parameters.length)));
        }
        out.writeInt(record);
        PolymorphicWritable.write(out, auc);
        out.writeDouble(logLikelihood);
        out.writeInt(models.size());
        for (OnlineLogisticRegression model : models) {
            model.write(out);
        }
        for (double parameter : parameters) {
            out.writeDouble(parameter);
        }
        out.writeInt(numFeatures);
        PolymorphicWritable.write(out, prior);
        out.writeDouble(percentCorrect);
        out.writeInt(windowSize);
    }

    /**
     * Replaces this learner's state with the state read from {@code in}, in the order written by
     * {@link #write}. Any models already held are discarded, so the instance can safely be reused.
     */
    @Override
    public void readFields(DataInput in) throws IOException {
        record = in.readInt();
        auc = (OnlineAuc) PolymorphicWritable.read(in, OnlineAuc.class);
        logLikelihood = in.readDouble();
        int modelCount = in.readInt();
        // Writable instances are reused; without clearing, models from a previous state would accumulate.
        models.clear();
        for (int index = 0; index < modelCount; index++) {
            OnlineLogisticRegression model = new OnlineLogisticRegression();
            model.readFields(in);
            models.add(model);
        }
        parameters = new double[NUM_PARAMETERS];
        for (int index = 0; index < NUM_PARAMETERS; index++) {
            parameters[index] = in.readDouble();
        }
        numFeatures = in.readInt();
        prior = (PriorFunction) PolymorphicWritable.read(in, PriorFunction.class);
        percentCorrect = in.readDouble();
        windowSize = in.readInt();
    }
}
