package org.apache.mahout.math;

import java.util.Iterator;
import org.apache.mahout.math.Vector;
import org.apache.mahout.math.function.DoubleDoubleFunction;
import org.apache.mahout.math.set.OpenIntHashSet;

/* loaded from: input_file:BOOT-INF/lib/libarx-3.8.0.jar:org/apache/mahout/math/VectorBinaryAggregate.class */
public abstract class VectorBinaryAggregate {
    /**
     * All aggregation strategies known to this class. {@link #getBestOperation} scans this array in
     * order, so when two valid strategies report the same estimated cost the earlier entry wins.
     */
    public static final VectorBinaryAggregate[] OPERATIONS = {new AggregateNonzerosIterateThisLookupThat(), new AggregateNonzerosIterateThatLookupThis(), new AggregateIterateIntersection(), new AggregateIterateUnionSequential(), new AggregateIterateUnionRandom(), new AggregateAllIterateSequential(), new AggregateAllIterateThisLookupThat(), new AggregateAllIterateThatLookupThis(), new AggregateAllLoop()};

    /* loaded from: input_file:BOOT-INF/lib/libarx-3.8.0.jar:org/apache/mahout/math/VectorBinaryAggregate$AggregateAllIterateSequential.class */
    /**
     * Strategy that steps through {@code all()} elements of both vectors in lockstep, combines each
     * pair of values with {@code doubleDoubleFunction2} and folds the combined values with
     * {@code doubleDoubleFunction}.
     */
    public static class AggregateAllIterateSequential extends VectorBinaryAggregate {
        /**
         * Usable only when both vectors report sequential access and neither reports being dense.
         */
        @Override
        public boolean isValid(Vector vector, Vector vector2, DoubleDoubleFunction doubleDoubleFunction, DoubleDoubleFunction doubleDoubleFunction2) {
            if (!vector.isSequentialAccess() || !vector2.isSequentialAccess()) {
                return false;
            }
            return !(vector.isDense() || vector2.isDense());
        }

        /**
         * Cost is the larger of the two full-iteration costs (size times iterator advance cost).
         */
        @Override
        public double estimateCost(Vector vector, Vector vector2, DoubleDoubleFunction doubleDoubleFunction, DoubleDoubleFunction doubleDoubleFunction2) {
            double firstWalkCost = vector.size() * vector.getIteratorAdvanceCost();
            double secondWalkCost = vector2.size() * vector2.getIteratorAdvanceCost();
            return Math.max(firstWalkCost, secondWalkCost);
        }

        /**
         * Folds the combined values of the paired elements; the walk stops as soon as either iterator
         * is exhausted.
         *
         * @return the folded value, or {@code 0.0} when no pair was produced
         */
        @Override
        public double aggregate(Vector vector, Vector vector2, DoubleDoubleFunction doubleDoubleFunction, DoubleDoubleFunction doubleDoubleFunction2) {
            final Iterator<Vector.Element> left = vector.all().iterator();
            final Iterator<Vector.Element> right = vector2.all().iterator();
            if (!left.hasNext() || !right.hasNext()) {
                return 0.0d;
            }
            // The first combined value seeds the result without going through the aggregator.
            double result = doubleDoubleFunction2.apply(left.next().get(), right.next().get());
            while (left.hasNext() && right.hasNext()) {
                final double combined = doubleDoubleFunction2.apply(left.next().get(), right.next().get());
                result = doubleDoubleFunction.apply(result, combined);
            }
            return result;
        }
    }

    /* loaded from: input_file:BOOT-INF/lib/libarx-3.8.0.jar:org/apache/mahout/math/VectorBinaryAggregate$AggregateAllIterateThatLookupThis.class */
    /**
     * Strategy that iterates {@code all()} elements of the second vector and looks up the matching
     * index in the first vector with {@code getQuick}.
     */
    public static class AggregateAllIterateThatLookupThis extends VectorBinaryAggregate {
        /**
         * Usable when the second vector is not dense and either the aggregator reports being
         * associative and commutative or the second vector reports sequential access.
         */
        @Override
        public boolean isValid(Vector vector, Vector vector2, DoubleDoubleFunction doubleDoubleFunction, DoubleDoubleFunction doubleDoubleFunction2) {
            final boolean orderRequirementMet = doubleDoubleFunction.isAssociativeAndCommutative() || vector2.isSequentialAccess();
            return orderRequirementMet && !vector2.isDense();
        }

        /**
         * Cost is the full iteration cost of the second vector scaled by the first vector's lookup cost.
         */
        @Override
        public double estimateCost(Vector vector, Vector vector2, DoubleDoubleFunction doubleDoubleFunction, DoubleDoubleFunction doubleDoubleFunction2) {
            double iterationCost = vector2.size() * vector2.getIteratorAdvanceCost();
            return iterationCost * vector.getLookupCost();
        }

        /**
         * Combines every element of the second vector with the same-index value of the first vector
         * (first vector's value is passed as the left argument) and folds the results.
         *
         * @return the folded value, or {@code 0.0} when the iteration yields nothing
         */
        @Override
        public double aggregate(Vector vector, Vector vector2, DoubleDoubleFunction doubleDoubleFunction, DoubleDoubleFunction doubleDoubleFunction2) {
            final Iterator<Vector.Element> elements = vector2.all().iterator();
            if (!elements.hasNext()) {
                return 0.0d;
            }
            Vector.Element current = elements.next();
            // The first combined value seeds the result without going through the aggregator.
            double result = doubleDoubleFunction2.apply(vector.getQuick(current.index()), current.get());
            while (elements.hasNext()) {
                current = elements.next();
                final double combined = doubleDoubleFunction2.apply(vector.getQuick(current.index()), current.get());
                result = doubleDoubleFunction.apply(result, combined);
            }
            return result;
        }
    }

    /* loaded from: input_file:BOOT-INF/lib/libarx-3.8.0.jar:org/apache/mahout/math/VectorBinaryAggregate$AggregateAllIterateThisLookupThat.class */
    /**
     * Strategy that iterates {@code all()} elements of the first vector and looks up the matching
     * index in the second vector with {@code getQuick}.
     */
    public static class AggregateAllIterateThisLookupThat extends VectorBinaryAggregate {
        /**
         * Usable when the first vector is not dense and either the aggregator reports being
         * associative and commutative or the first vector reports sequential access.
         */
        @Override
        public boolean isValid(Vector vector, Vector vector2, DoubleDoubleFunction doubleDoubleFunction, DoubleDoubleFunction doubleDoubleFunction2) {
            final boolean orderRequirementMet = doubleDoubleFunction.isAssociativeAndCommutative() || vector.isSequentialAccess();
            return orderRequirementMet && !vector.isDense();
        }

        /**
         * Cost is the full iteration cost of the first vector scaled by the second vector's lookup cost.
         */
        @Override
        public double estimateCost(Vector vector, Vector vector2, DoubleDoubleFunction doubleDoubleFunction, DoubleDoubleFunction doubleDoubleFunction2) {
            double iterationCost = vector.size() * vector.getIteratorAdvanceCost();
            return iterationCost * vector2.getLookupCost();
        }

        /**
         * Combines every element of the first vector with the same-index value of the second vector
         * and folds the results.
         *
         * @return the folded value, or {@code 0.0} when the iteration yields nothing
         */
        @Override
        public double aggregate(Vector vector, Vector vector2, DoubleDoubleFunction doubleDoubleFunction, DoubleDoubleFunction doubleDoubleFunction2) {
            final Iterator<Vector.Element> elements = vector.all().iterator();
            if (!elements.hasNext()) {
                return 0.0d;
            }
            Vector.Element current = elements.next();
            // The first combined value seeds the result without going through the aggregator.
            double result = doubleDoubleFunction2.apply(current.get(), vector2.getQuick(current.index()));
            while (elements.hasNext()) {
                current = elements.next();
                final double combined = doubleDoubleFunction2.apply(current.get(), vector2.getQuick(current.index()));
                result = doubleDoubleFunction.apply(result, combined);
            }
            return result;
        }
    }

    /* loaded from: input_file:BOOT-INF/lib/libarx-3.8.0.jar:org/apache/mahout/math/VectorBinaryAggregate$AggregateAllLoop.class */
    /**
     * Fallback strategy that is always valid: visits every index from {@code 0} to
     * {@code vector.size() - 1} and reads both vectors with {@code getQuick}.
     */
    public static class AggregateAllLoop extends VectorBinaryAggregate {
        /**
         * Always returns {@code true}; this strategy places no requirement on vectors or functions.
         */
        @Override
        public boolean isValid(Vector vector, Vector vector2, DoubleDoubleFunction doubleDoubleFunction, DoubleDoubleFunction doubleDoubleFunction2) {
            return true;
        }

        /**
         * Cost is the first vector's size multiplied by the lookup cost of each vector.
         */
        @Override
        public double estimateCost(Vector vector, Vector vector2, DoubleDoubleFunction doubleDoubleFunction, DoubleDoubleFunction doubleDoubleFunction2) {
            return vector.size() * vector.getLookupCost() * vector2.getLookupCost();
        }

        /**
         * Combines the values at every index with {@code doubleDoubleFunction2} and folds the results
         * with {@code doubleDoubleFunction}, in increasing index order.
         *
         * @return the folded value, or {@code 0.0} when the first vector has size zero
         */
        @Override
        public double aggregate(Vector vector, Vector vector2, DoubleDoubleFunction doubleDoubleFunction, DoubleDoubleFunction doubleDoubleFunction2) {
            final int size = vector.size();
            // Index 0 does not exist in an empty vector; return 0.0 as the other strategies in this
            // file do when they have nothing to iterate, instead of reading out of range.
            if (size == 0) {
                return 0.0d;
            }
            // The value at index 0 seeds the result without going through the aggregator.
            double result = doubleDoubleFunction2.apply(vector.getQuick(0), vector2.getQuick(0));
            for (int i = 1; i < size; i++) {
                result = doubleDoubleFunction.apply(result, doubleDoubleFunction2.apply(vector.getQuick(i), vector2.getQuick(i)));
            }
            return result;
        }
    }

    /* loaded from: input_file:BOOT-INF/lib/libarx-3.8.0.jar:org/apache/mahout/math/VectorBinaryAggregate$AggregateIterateIntersection.class */
    /**
     * Strategy that walks the {@code nonZeroes()} of both vectors side by side and only combines
     * elements whose indices are equal; an element whose index is smaller than the other side's
     * current index is skipped.
     */
    public static class AggregateIterateIntersection extends VectorBinaryAggregate {
        /**
         * Usable when the aggregator reports {@code isLikeRightPlus}, the combiner reports
         * {@code isLikeMult}, and both vectors report sequential access.
         */
        @Override
        public boolean isValid(Vector vector, Vector vector2, DoubleDoubleFunction doubleDoubleFunction, DoubleDoubleFunction doubleDoubleFunction2) {
            final boolean functionsQualify = doubleDoubleFunction.isLikeRightPlus() && doubleDoubleFunction2.isLikeMult();
            return functionsQualify && vector.isSequentialAccess() && vector2.isSequentialAccess();
        }

        /**
         * Cost is the smaller of the two non-default iteration costs.
         */
        @Override
        public double estimateCost(Vector vector, Vector vector2, DoubleDoubleFunction doubleDoubleFunction, DoubleDoubleFunction doubleDoubleFunction2) {
            double firstWalkCost = vector.getNumNondefaultElements() * vector.getIteratorAdvanceCost();
            double secondWalkCost = vector2.getNumNondefaultElements() * vector2.getIteratorAdvanceCost();
            return Math.min(firstWalkCost, secondWalkCost);
        }

        /**
         * Merge-style walk over both non-zero iterations. The walk ends as soon as the side that has
         * to advance is exhausted.
         *
         * @return the folded value of all equal-index pairs, or {@code 0.0} when there were none
         */
        @Override
        public double aggregate(Vector vector, Vector vector2, DoubleDoubleFunction doubleDoubleFunction, DoubleDoubleFunction doubleDoubleFunction2) {
            final Iterator<Vector.Element> leftIt = vector.nonZeroes().iterator();
            final Iterator<Vector.Element> rightIt = vector2.nonZeroes().iterator();
            Vector.Element left = null;
            Vector.Element right = null;
            boolean advanceLeft = true;
            boolean advanceRight = true;
            boolean hasResult = false;
            double result = 0.0d;
            for (;;) {
                if (advanceLeft) {
                    if (!leftIt.hasNext()) {
                        return result;
                    }
                    left = leftIt.next();
                }
                if (advanceRight) {
                    if (!rightIt.hasNext()) {
                        return result;
                    }
                    right = rightIt.next();
                }
                final int leftIndex = left.index();
                final int rightIndex = right.index();
                // The side holding the smaller index moves on; on a tie both sides move.
                advanceLeft = leftIndex <= rightIndex;
                advanceRight = rightIndex <= leftIndex;
                if (advanceLeft && advanceRight) {
                    final double combined = doubleDoubleFunction2.apply(left.get(), right.get());
                    // The first match seeds the result without going through the aggregator.
                    result = hasResult ? doubleDoubleFunction.apply(result, combined) : combined;
                    hasResult = true;
                }
            }
        }
    }

    /* loaded from: input_file:BOOT-INF/lib/libarx-3.8.0.jar:org/apache/mahout/math/VectorBinaryAggregate$AggregateIterateUnionRandom.class */
    /**
     * Strategy that covers the union of non-zero indices without requiring ordered iteration: it
     * first iterates the first vector's {@code nonZeroes()} while recording the visited indices, then
     * iterates the second vector's {@code nonZeroes()} and processes only indices not yet visited.
     */
    public static class AggregateIterateUnionRandom extends VectorBinaryAggregate {
        /**
         * Usable when the aggregator reports {@code isLikeRightPlus}, the combiner is not densifying,
         * and either the aggregator is associative and commutative or both vectors report
         * sequential access.
         */
        @Override
        public boolean isValid(Vector vector, Vector vector2, DoubleDoubleFunction doubleDoubleFunction, DoubleDoubleFunction doubleDoubleFunction2) {
            if (!doubleDoubleFunction.isLikeRightPlus() || doubleDoubleFunction2.isDensifying()) {
                return false;
            }
            if (doubleDoubleFunction.isAssociativeAndCommutative()) {
                return true;
            }
            return vector.isSequentialAccess() && vector2.isSequentialAccess();
        }

        /**
         * Cost is the larger of "iterate first, look up second" and "iterate second, look up first".
         */
        @Override
        public double estimateCost(Vector vector, Vector vector2, DoubleDoubleFunction doubleDoubleFunction, DoubleDoubleFunction doubleDoubleFunction2) {
            double firstPassCost = vector.getNumNondefaultElements() * vector.getIteratorAdvanceCost() * vector2.getLookupCost();
            double secondPassCost = vector2.getNumNondefaultElements() * vector2.getIteratorAdvanceCost() * vector.getLookupCost();
            return Math.max(firstPassCost, secondPassCost);
        }

        /**
         * Runs the two passes described on the class and folds every combined value.
         *
         * @return the folded value, or {@code 0.0} when neither pass produced a value
         */
        @Override
        public double aggregate(Vector vector, Vector vector2, DoubleDoubleFunction doubleDoubleFunction, DoubleDoubleFunction doubleDoubleFunction2) {
            final OpenIntHashSet visited = new OpenIntHashSet();
            boolean hasResult = false;
            double result = 0.0d;

            // Pass 1: every non-zero of the first vector, paired with a lookup into the second.
            final Iterator<Vector.Element> leftIt = vector.nonZeroes().iterator();
            while (leftIt.hasNext()) {
                final Vector.Element left = leftIt.next();
                final double combined = doubleDoubleFunction2.apply(left.get(), vector2.getQuick(left.index()));
                result = hasResult ? doubleDoubleFunction.apply(result, combined) : combined;
                hasResult = true;
                visited.add(left.index());
            }

            // Pass 2: non-zeros of the second vector whose index pass 1 did not already handle.
            final Iterator<Vector.Element> rightIt = vector2.nonZeroes().iterator();
            while (rightIt.hasNext()) {
                final Vector.Element right = rightIt.next();
                if (visited.contains(right.index())) {
                    continue;
                }
                final double combined = doubleDoubleFunction2.apply(vector.getQuick(right.index()), right.get());
                result = hasResult ? doubleDoubleFunction.apply(result, combined) : combined;
                hasResult = true;
            }
            return result;
        }
    }

    /* loaded from: input_file:BOOT-INF/lib/libarx-3.8.0.jar:org/apache/mahout/math/VectorBinaryAggregate$AggregateIterateUnionSequential.class */
    /**
     * Strategy that walks the {@code nonZeroes()} of both vectors side by side and produces a
     * combined value for every index seen on either side, substituting {@code 0.0} for the side that
     * has no element at that index.
     */
    public static class AggregateIterateUnionSequential extends VectorBinaryAggregate {
        /**
         * Usable when the aggregator reports {@code isLikeRightPlus}, the combiner is not densifying,
         * and both vectors report sequential access.
         */
        @Override
        public boolean isValid(Vector vector, Vector vector2, DoubleDoubleFunction doubleDoubleFunction, DoubleDoubleFunction doubleDoubleFunction2) {
            if (!doubleDoubleFunction.isLikeRightPlus() || doubleDoubleFunction2.isDensifying()) {
                return false;
            }
            return vector.isSequentialAccess() && vector2.isSequentialAccess();
        }

        /**
         * Cost is the larger of the two non-default iteration costs.
         */
        @Override
        public double estimateCost(Vector vector, Vector vector2, DoubleDoubleFunction doubleDoubleFunction, DoubleDoubleFunction doubleDoubleFunction2) {
            double firstWalkCost = vector.getNumNondefaultElements() * vector.getIteratorAdvanceCost();
            double secondWalkCost = vector2.getNumNondefaultElements() * vector2.getIteratorAdvanceCost();
            return Math.max(firstWalkCost, secondWalkCost);
        }

        /**
         * Merge-style walk over both non-zero iterations; an exhausted side is represented by
         * {@code null} and the walk ends once both sides are exhausted.
         *
         * @return the folded value, or {@code 0.0} when both iterations are empty
         */
        @Override
        public double aggregate(Vector vector, Vector vector2, DoubleDoubleFunction doubleDoubleFunction, DoubleDoubleFunction doubleDoubleFunction2) {
            final Iterator<Vector.Element> leftIt = vector.nonZeroes().iterator();
            final Iterator<Vector.Element> rightIt = vector2.nonZeroes().iterator();
            Vector.Element left = null;
            Vector.Element right = null;
            // After each step these flags say which side contributed its value and must move on.
            boolean advanceLeft = true;
            boolean advanceRight = true;
            boolean hasResult = false;
            double result = 0.0d;
            for (;;) {
                if (advanceLeft) {
                    left = leftIt.hasNext() ? leftIt.next() : null;
                }
                if (advanceRight) {
                    right = rightIt.hasNext() ? rightIt.next() : null;
                }
                if (left == null && right == null) {
                    return result;
                }
                if (left == null) {
                    // Only the second vector has elements left.
                    advanceLeft = false;
                    advanceRight = true;
                } else if (right == null) {
                    // Only the first vector has elements left.
                    advanceLeft = true;
                    advanceRight = false;
                } else {
                    final int leftIndex = left.index();
                    final int rightIndex = right.index();
                    // The side holding the smaller index contributes; on a tie both sides do.
                    advanceLeft = leftIndex <= rightIndex;
                    advanceRight = rightIndex <= leftIndex;
                }
                final double combined = doubleDoubleFunction2.apply(advanceLeft ? left.get() : 0.0d, advanceRight ? right.get() : 0.0d);
                // The first combined value seeds the result without going through the aggregator.
                result = hasResult ? doubleDoubleFunction.apply(result, combined) : combined;
                hasResult = true;
            }
        }
    }

    /* loaded from: input_file:BOOT-INF/lib/libarx-3.8.0.jar:org/apache/mahout/math/VectorBinaryAggregate$AggregateNonzerosIterateThatLookupThis.class */
    /**
     * Strategy that iterates only the {@code nonZeroes()} of the second vector and looks up the
     * matching index in the first vector with {@code getQuick}.
     */
    public static class AggregateNonzerosIterateThatLookupThis extends VectorBinaryAggregate {
        /**
         * Usable when the aggregator reports {@code isLikeRightPlus}, the combiner reports
         * {@code isLikeRightMult}, and either the aggregator is associative and commutative or the
         * second vector reports sequential access.
         */
        @Override
        public boolean isValid(Vector vector, Vector vector2, DoubleDoubleFunction doubleDoubleFunction, DoubleDoubleFunction doubleDoubleFunction2) {
            if (!doubleDoubleFunction.isLikeRightPlus()) {
                return false;
            }
            if (!doubleDoubleFunction.isAssociativeAndCommutative() && !vector2.isSequentialAccess()) {
                return false;
            }
            return doubleDoubleFunction2.isLikeRightMult();
        }

        /**
         * Cost is the second vector's non-default iteration cost scaled by the first vector's lookup
         * cost, applied twice.
         */
        @Override
        public double estimateCost(Vector vector, Vector vector2, DoubleDoubleFunction doubleDoubleFunction, DoubleDoubleFunction doubleDoubleFunction2) {
            // NOTE(review): the first vector's lookup cost is multiplied in twice here, whereas
            // AggregateNonzerosIterateThisLookupThat applies the other vector's lookup cost once.
            // Confirm whether the extra factor is intentional before changing it.
            double iterationCost = vector2.getNumNondefaultElements() * vector2.getIteratorAdvanceCost();
            return iterationCost * vector.getLookupCost() * vector.getLookupCost();
        }

        /**
         * Combines every non-zero of the second vector with the same-index value of the first vector
         * (first vector's value is passed as the left argument) and folds the results.
         *
         * @return the folded value, or {@code 0.0} when the second vector has no non-zero elements
         */
        @Override
        public double aggregate(Vector vector, Vector vector2, DoubleDoubleFunction doubleDoubleFunction, DoubleDoubleFunction doubleDoubleFunction2) {
            boolean hasResult = false;
            double result = 0.0d;
            for (Vector.Element current : vector2.nonZeroes()) {
                final double combined = doubleDoubleFunction2.apply(vector.getQuick(current.index()), current.get());
                if (hasResult) {
                    result = doubleDoubleFunction.apply(result, combined);
                } else {
                    // The first combined value seeds the result without going through the aggregator.
                    result = combined;
                    hasResult = true;
                }
            }
            return result;
        }
    }

    /* loaded from: input_file:BOOT-INF/lib/libarx-3.8.0.jar:org/apache/mahout/math/VectorBinaryAggregate$AggregateNonzerosIterateThisLookupThat.class */
    /**
     * Strategy that iterates only the {@code nonZeroes()} of the first vector and looks up the
     * matching index in the second vector with {@code getQuick}.
     */
    public static class AggregateNonzerosIterateThisLookupThat extends VectorBinaryAggregate {
        /**
         * Usable when the aggregator reports {@code isLikeRightPlus}, the combiner reports
         * {@code isLikeLeftMult}, and either the aggregator is associative and commutative or the
         * first vector reports sequential access.
         */
        @Override
        public boolean isValid(Vector vector, Vector vector2, DoubleDoubleFunction doubleDoubleFunction, DoubleDoubleFunction doubleDoubleFunction2) {
            if (!doubleDoubleFunction.isLikeRightPlus()) {
                return false;
            }
            if (!doubleDoubleFunction.isAssociativeAndCommutative() && !vector.isSequentialAccess()) {
                return false;
            }
            return doubleDoubleFunction2.isLikeLeftMult();
        }

        /**
         * Cost is the first vector's non-default iteration cost scaled by the second vector's lookup
         * cost.
         */
        @Override
        public double estimateCost(Vector vector, Vector vector2, DoubleDoubleFunction doubleDoubleFunction, DoubleDoubleFunction doubleDoubleFunction2) {
            double iterationCost = vector.getNumNondefaultElements() * vector.getIteratorAdvanceCost();
            return iterationCost * vector2.getLookupCost();
        }

        /**
         * Combines every non-zero of the first vector with the same-index value of the second vector
         * and folds the results.
         *
         * @return the folded value, or {@code 0.0} when the first vector has no non-zero elements
         */
        @Override
        public double aggregate(Vector vector, Vector vector2, DoubleDoubleFunction doubleDoubleFunction, DoubleDoubleFunction doubleDoubleFunction2) {
            boolean hasResult = false;
            double result = 0.0d;
            for (Vector.Element current : vector.nonZeroes()) {
                final double combined = doubleDoubleFunction2.apply(current.get(), vector2.getQuick(current.index()));
                if (hasResult) {
                    result = doubleDoubleFunction.apply(result, combined);
                } else {
                    // The first combined value seeds the result without going through the aggregator.
                    result = combined;
                    hasResult = true;
                }
            }
            return result;
        }
    }

    /**
     * Reports whether this strategy may be used for the given vectors and functions;
     * {@link #getBestOperation} only considers strategies for which this returns {@code true}.
     *
     * @param vector the first vector
     * @param vector2 the second vector
     * @param doubleDoubleFunction the aggregator used to fold combined values together
     * @param doubleDoubleFunction2 the combiner applied to a pair of element values
     * @return {@code true} if the strategy is applicable
     */
    public abstract boolean isValid(Vector vector, Vector vector2, DoubleDoubleFunction doubleDoubleFunction, DoubleDoubleFunction doubleDoubleFunction2);

    /**
     * Estimates the cost of running this strategy on the given inputs; {@link #getBestOperation}
     * prefers the valid strategy with the lowest estimate.
     *
     * @param vector the first vector
     * @param vector2 the second vector
     * @param doubleDoubleFunction the aggregator used to fold combined values together
     * @param doubleDoubleFunction2 the combiner applied to a pair of element values
     * @return the estimated cost; only its ordering relative to other strategies is used here
     */
    public abstract double estimateCost(Vector vector, Vector vector2, DoubleDoubleFunction doubleDoubleFunction, DoubleDoubleFunction doubleDoubleFunction2);

    /**
     * Runs the aggregation. The implementations in this file apply {@code doubleDoubleFunction2} to
     * pairs of element values (first vector's value on the left) and fold the results with
     * {@code doubleDoubleFunction}, using the first combined value as the initial result.
     *
     * @param vector the first vector
     * @param vector2 the second vector
     * @param doubleDoubleFunction the aggregator used to fold combined values together
     * @param doubleDoubleFunction2 the combiner applied to a pair of element values
     * @return the aggregated value
     */
    public abstract double aggregate(Vector vector, Vector vector2, DoubleDoubleFunction doubleDoubleFunction, DoubleDoubleFunction doubleDoubleFunction2);

    /**
     * Picks the valid strategy from {@link #OPERATIONS} with the lowest estimated cost. Ties are
     * resolved in favour of the strategy that appears first in the array.
     *
     * <p>If no valid strategy reports a cost strictly below {@code Double.POSITIVE_INFINITY} (for
     * example because every estimate is {@code NaN} or infinite), the first valid strategy is
     * returned instead of indexing the array with {@code -1}.
     *
     * @param vector the first vector
     * @param vector2 the second vector
     * @param doubleDoubleFunction the aggregator used to fold combined values together
     * @param doubleDoubleFunction2 the combiner applied to a pair of element values
     * @return the selected strategy, never {@code null}
     * @throws IllegalStateException if no strategy in {@link #OPERATIONS} is valid for the inputs
     */
    public static VectorBinaryAggregate getBestOperation(Vector vector, Vector vector2, DoubleDoubleFunction doubleDoubleFunction, DoubleDoubleFunction doubleDoubleFunction2) {
        VectorBinaryAggregate cheapest = null;
        VectorBinaryAggregate firstValid = null;
        double cheapestCost = Double.POSITIVE_INFINITY;
        for (VectorBinaryAggregate candidate : OPERATIONS) {
            if (!candidate.isValid(vector, vector2, doubleDoubleFunction, doubleDoubleFunction2)) {
                continue;
            }
            if (firstValid == null) {
                firstValid = candidate;
            }
            final double cost = candidate.estimateCost(vector, vector2, doubleDoubleFunction, doubleDoubleFunction2);
            // Strict comparison keeps the earliest strategy on ties; it is false for NaN costs.
            if (cost < cheapestCost) {
                cheapestCost = cost;
                cheapest = candidate;
            }
        }
        if (cheapest != null) {
            return cheapest;
        }
        if (firstValid != null) {
            return firstValid;
        }
        throw new IllegalStateException("No valid aggregate operation for the given vectors and functions");
    }

    /**
     * Selects a strategy with {@link #getBestOperation} and immediately runs it on the same inputs.
     *
     * @param vector the first vector
     * @param vector2 the second vector
     * @param doubleDoubleFunction the aggregator used to fold combined values together
     * @param doubleDoubleFunction2 the combiner applied to a pair of element values
     * @return the value returned by the selected strategy's {@code aggregate}
     */
    public static double aggregateBest(Vector vector, Vector vector2, DoubleDoubleFunction doubleDoubleFunction, DoubleDoubleFunction doubleDoubleFunction2) {
        final VectorBinaryAggregate chosen = getBestOperation(vector, vector2, doubleDoubleFunction, doubleDoubleFunction2);
        return chosen.aggregate(vector, vector2, doubleDoubleFunction, doubleDoubleFunction2);
    }
}
