package org.apache.mahout.common.distance;

import com.google.common.base.Preconditions;
import com.google.common.collect.Lists;
import com.google.common.io.Closeables;
import java.io.FileNotFoundException;
import java.io.IOException;
import java.util.Collection;
import java.util.List;
import org.apache.hadoop.conf.Configuration;
import org.apache.hadoop.fs.FSDataInputStream;
import org.apache.hadoop.fs.FileSystem;
import org.apache.hadoop.fs.Path;
import org.apache.mahout.common.ClassUtils;
import org.apache.mahout.common.parameters.ClassParameter;
import org.apache.mahout.common.parameters.Parameter;
import org.apache.mahout.common.parameters.Parametered;
import org.apache.mahout.common.parameters.PathParameter;
import org.apache.mahout.math.Algebra;
import org.apache.mahout.math.CardinalityException;
import org.apache.mahout.math.DenseMatrix;
import org.apache.mahout.math.DenseVector;
import org.apache.mahout.math.Matrix;
import org.apache.mahout.math.MatrixWritable;
import org.apache.mahout.math.SingularValueDecomposition;
import org.apache.mahout.math.Vector;
import org.apache.mahout.math.VectorWritable;

/* loaded from: input_file:BOOT-INF/lib/libarx-3.8.0.jar:org/apache/mahout/common/distance/MahalanobisDistanceMeasure.class */
/**
 * {@link DistanceMeasure} that computes the Mahalanobis distance
 * {@code sqrt((x - y)^T * M * (x - y))}, where {@code M} is the configured inverse covariance
 * matrix. The single-argument {@link #distance(Vector)} measures against the configured mean
 * vector instead of a second point.
 */
public class MahalanobisDistanceMeasure implements DistanceMeasure {
    private Matrix inverseCovarianceMatrix;
    private Vector meanVector;
    private ClassParameter vectorClass;
    private ClassParameter matrixClass;
    private List<Parameter<?>> parameters;
    private Parameter<Path> inverseCovarianceFile;
    private Parameter<Path> meanVectorFile;

    /**
     * Loads the inverse covariance matrix and/or mean vector from the paths given by the
     * {@code inverseCovarianceFile} / {@code meanVectorFile} parameters, when those are set.
     * Parameters are created from {@code configuration} first if that has not happened yet.
     *
     * @param configuration Hadoop configuration used for parameter lookup and file system access
     * @throws IllegalStateException wrapping any {@link IOException} (including
     *     {@link FileNotFoundException} for a configured path that does not exist)
     * @throws IllegalArgumentException if a file deserializes to a {@code null} matrix or vector
     */
    @Override
    public void configure(Configuration configuration) {
        if (this.parameters == null) {
            Parametered.ParameteredGeneralizations.configureParameters(this, configuration);
        }
        try {
            Path inverseCovariancePath = this.inverseCovarianceFile.get();
            if (inverseCovariancePath != null) {
                FileSystem fileSystem = FileSystem.get(inverseCovariancePath.toUri(), configuration);
                MatrixWritable matrixWritable = (MatrixWritable) ClassUtils.instantiateAs(this.matrixClass.get(), MatrixWritable.class);
                FSDataInputStream in = openExisting(fileSystem, inverseCovariancePath);
                try {
                    matrixWritable.readFields(in);
                } finally {
                    // Close on every path (the decompiled version leaked on read failure);
                    // 'true' swallows close() errors, as before.
                    Closeables.close(in, true);
                }
                this.inverseCovarianceMatrix = matrixWritable.get();
                Preconditions.checkArgument(this.inverseCovarianceMatrix != null, "inverseCovarianceMatrix not initialized");
            }
            Path meanVectorPath = this.meanVectorFile.get();
            if (meanVectorPath != null) {
                FileSystem fileSystem = FileSystem.get(meanVectorPath.toUri(), configuration);
                VectorWritable vectorWritable = (VectorWritable) ClassUtils.instantiateAs(this.vectorClass.get(), VectorWritable.class);
                FSDataInputStream in = openExisting(fileSystem, meanVectorPath);
                try {
                    vectorWritable.readFields(in);
                } finally {
                    Closeables.close(in, true);
                }
                this.meanVector = vectorWritable.get();
                Preconditions.checkArgument(this.meanVector != null, "meanVector not initialized");
            }
        } catch (IOException e) {
            throw new IllegalStateException(e);
        }
    }

    /** Opens {@code path} on {@code fileSystem}, throwing {@link FileNotFoundException} if it is absent. */
    private static FSDataInputStream openExisting(FileSystem fileSystem, Path path) throws IOException {
        if (!fileSystem.exists(path)) {
            throw new FileNotFoundException(path.toString());
        }
        return fileSystem.open(path);
    }

    /**
     * Returns the parameters created by {@link #createParameters}, or {@code null} if they have
     * not been created yet.
     */
    @Override
    public Collection<Parameter<?>> getParameters() {
        return this.parameters;
    }

    /**
     * Creates the four configurable parameters (inverse covariance file, its matrix class, mean
     * vector file, its vector class) under the given prefix.
     *
     * @param str parameter name prefix
     * @param configuration configuration the parameters read their values from
     */
    @Override
    public void createParameters(String str, Configuration configuration) {
        this.parameters = Lists.newArrayList();
        this.inverseCovarianceFile = new PathParameter(str, "inverseCovarianceFile", configuration, null, "Path on DFS to a file containing the inverse covariance matrix.");
        this.parameters.add(this.inverseCovarianceFile);
        // NOTE(review): the key "maxtrixClass" is misspelled but kept for configuration compatibility.
        // NOTE(review): the defaults DenseMatrix/DenseVector are passed to
        // ClassUtils.instantiateAs(..., MatrixWritable.class / VectorWritable.class) in configure();
        // unless instantiateAs tolerates non-Writable classes these defaults look incompatible — confirm.
        this.matrixClass = new ClassParameter(str, "maxtrixClass", configuration, DenseMatrix.class, "Class<Matix> file specified in parameter inverseCovarianceFile has been serialized with.");
        this.parameters.add(this.matrixClass);
        this.meanVectorFile = new PathParameter(str, "meanVectorFile", configuration, null, "Path on DFS to a file containing the mean Vector.");
        this.parameters.add(this.meanVectorFile);
        this.vectorClass = new ClassParameter(str, "vectorClass", configuration, DenseVector.class, "Class file specified in parameter meanVectorFile has been serialized with.");
        this.parameters.add(this.vectorClass);
    }

    /**
     * Returns the Mahalanobis distance between {@code vector} and the configured mean vector.
     * Both the mean vector and the inverse covariance matrix must have been set.
     */
    public double distance(Vector vector) {
        Vector delta = vector.minus(this.meanVector);
        return Math.sqrt(delta.dot(Algebra.mult(this.inverseCovarianceMatrix, delta)));
    }

    /**
     * Returns the Mahalanobis distance between two vectors.
     *
     * @throws CardinalityException if the vectors differ in size
     */
    @Override
    public double distance(Vector vector, Vector vector2) {
        if (vector.size() != vector2.size()) {
            throw new CardinalityException(vector.size(), vector2.size());
        }
        Vector delta = vector.minus(vector2);
        return Math.sqrt(delta.dot(Algebra.mult(this.inverseCovarianceMatrix, delta)));
    }

    /** Ignores the precomputed length {@code d} and delegates to {@link #distance(Vector, Vector)}. */
    @Override
    public double distance(double d, Vector vector, Vector vector2) {
        return distance(vector, vector2);
    }

    /**
     * Sets the inverse covariance matrix directly.
     *
     * @throws IllegalArgumentException if {@code matrix} is {@code null}
     */
    public void setInverseCovarianceMatrix(Matrix matrix) {
        Preconditions.checkArgument(matrix != null, "inverseCovarianceMatrix not initialized");
        this.inverseCovarianceMatrix = matrix;
    }

    /**
     * Computes and stores the inverse of a covariance matrix via SVD as {@code U * S^-1 * U^T}.
     * That expression equals the true inverse only when {@code U == V}, i.e. for a symmetric
     * (covariance-like) matrix.
     *
     * @throws IllegalArgumentException if {@code matrix} is {@code null}
     * @throws CardinalityException if {@code matrix} is not square
     * @throws IllegalStateException if a singular value is not strictly positive (matrix is singular)
     */
    public void setCovarianceMatrix(Matrix matrix) {
        Preconditions.checkArgument(matrix != null, "covarianceMatrix not initialized");
        if (matrix.numRows() != matrix.numCols()) {
            throw new CardinalityException(matrix.numRows(), matrix.numCols());
        }
        SingularValueDecomposition svd = new SingularValueDecomposition(matrix);
        // Invert the diagonal of S in place.
        Matrix sInverse = svd.getS();
        for (int i = 0; i < sInverse.numRows(); i++) {
            double singularValue = sInverse.get(i, i);
            if (singularValue <= 0.0d) {
                throw new IllegalStateException("Eigen Value equals to 0 found.");
            }
            sInverse.set(i, i, 1.0d / singularValue);
        }
        Matrix u = svd.getU();
        this.inverseCovarianceMatrix = u.times(sInverse.times(u.transpose()));
    }

    /** Returns the inverse covariance matrix, or {@code null} if none has been set. */
    public Matrix getInverseCovarianceMatrix() {
        return this.inverseCovarianceMatrix;
    }

    /**
     * Sets the mean vector used by {@link #distance(Vector)}.
     *
     * @throws IllegalArgumentException if {@code vector} is {@code null}
     */
    public void setMeanVector(Vector vector) {
        Preconditions.checkArgument(vector != null, "meanVector not initialized");
        this.meanVector = vector;
    }

    /** Returns the mean vector, or {@code null} if none has been set. */
    public Vector getMeanVector() {
        return this.meanVector;
    }
}
