package org.apache.mahout.classifier.naivebayes.training;

import com.google.common.base.Splitter;
import java.io.IOException;
import org.apache.hadoop.conf.Configuration;
import org.apache.hadoop.fs.Path;
import org.apache.hadoop.io.IntWritable;
import org.apache.hadoop.io.Text;
import org.apache.hadoop.mapreduce.Job;
import org.apache.hadoop.mapreduce.lib.input.SequenceFileInputFormat;
import org.apache.hadoop.mapreduce.lib.output.SequenceFileOutputFormat;
import org.apache.hadoop.util.ToolRunner;
import org.apache.mahout.classifier.naivebayes.BayesUtils;
import org.apache.mahout.classifier.naivebayes.NaiveBayesModel;
import org.apache.mahout.common.AbstractJob;
import org.apache.mahout.common.HadoopUtil;
import org.apache.mahout.common.commandline.DefaultOptionCreator;
import org.apache.mahout.common.iterator.sequencefile.PathFilters;
import org.apache.mahout.common.iterator.sequencefile.PathType;
import org.apache.mahout.common.iterator.sequencefile.SequenceFileDirIterable;
import org.apache.mahout.common.mapreduce.VectorSumReducer;
import org.apache.mahout.math.VectorWritable;
import org.apache.pdfbox.pdmodel.common.PDPageLabelRange;

/* loaded from: input_file:BOOT-INF/lib/libarx-3.8.0.jar:org/apache/mahout/classifier/naivebayes/training/TrainNaiveBayesJob.class */
/**
 * Command-line driver (a Hadoop {@code Tool} via {@link AbstractJob}) that trains a naive Bayes
 * model.
 *
 * <p>The visible pipeline is:
 * <ol>
 *   <li>write a label index (see {@link #createLabelIndex(Path)});</li>
 *   <li>run a job with {@code IndexInstancesMapper} and {@link VectorSumReducer} that writes to
 *       the {@value #SUMMED_OBSERVATIONS} temp path;</li>
 *   <li>run a job with {@code WeightsMapper} and {@link VectorSumReducer} that reads those
 *       results and writes to the {@value #WEIGHTS} temp path;</li>
 *   <li>read a {@link NaiveBayesModel} from the temp directory, validate it and serialize it to
 *       the output path.</li>
 * </ol>
 */
public final class TrainNaiveBayesJob extends AbstractJob {
    private static final String TRAIN_COMPLEMENTARY = "trainComplementary";
    private static final String ALPHA_I = "alphaI";
    private static final String LABEL_INDEX = "labelIndex";
    private static final String EXTRACT_LABELS = "extractLabels";
    private static final String LABELS = "labels";
    public static final String WEIGHTS_PER_FEATURE = "__SPF";
    public static final String WEIGHTS_PER_LABEL = "__SPL";
    public static final String LABEL_THETA_NORMALIZER = "_LTN";
    public static final String SUMMED_OBSERVATIONS = "summedObservations";
    public static final String WEIGHTS = "weights";
    public static final String THETAS = "thetas";

    /**
     * Runs this job through {@link ToolRunner} with a fresh {@link Configuration}.
     *
     * @param args command-line arguments forwarded to {@link #run(String[])}
     * @throws Exception if the tool fails
     */
    public static void main(String[] args) throws Exception {
        ToolRunner.run(new Configuration(), new TrainNaiveBayesJob(), args);
    }

    /**
     * Parses the command line and executes the training pipeline.
     *
     * @param args command-line arguments
     * @return {@code 0} on success; {@code -1} if argument parsing yields nothing or if either
     *         submitted job does not complete successfully
     * @throws Exception if option handling, job execution or model I/O fails
     */
    @Override // org.apache.hadoop.util.Tool
    public int run(String[] args) throws Exception {
        addInputOption();
        addOutputOption();
        addOption(LABELS, "l", "comma-separated list of labels to include in training", false);
        addOption(buildOption(EXTRACT_LABELS, "el", "Extract the labels from the input", false, false, ""));
        // The short flag is the plain literal "a". The decompiled source borrowed an unrelated
        // PDFBox constant with the same value, which coupled this CLI flag to a PDF library.
        addOption(ALPHA_I, "a", "smoothing parameter", String.valueOf(1.0f));
        addOption(buildOption(TRAIN_COMPLEMENTARY, "c", "train complementary?", false, false, String.valueOf(false)));
        addOption(LABEL_INDEX, "li", "The path to store the label index in", false);
        addOption(DefaultOptionCreator.overwriteOption().create());
        if (parseArguments(args) == null) {
            return -1;
        }

        if (hasOption("overwrite")) {
            HadoopUtil.delete(getConf(), getOutputPath());
            HadoopUtil.delete(getConf(), getTempPath());
        }

        // Use the user-supplied label index location if given, otherwise a temp path.
        String labelIndexOption = getOption(LABEL_INDEX);
        Path labelIndexPath = labelIndexOption != null ? new Path(labelIndexOption) : getTempPath(LABEL_INDEX);
        long labelCount = createLabelIndex(labelIndexPath);
        float alphaI = Float.parseFloat(getOption(ALPHA_I));
        boolean trainComplementary = hasOption(TRAIN_COMPLEMENTARY);

        HadoopUtil.setSerializations(getConf());
        HadoopUtil.cacheFiles(labelIndexPath, getConf());

        // Job 1: input -> summed observations.
        Job indexInstances = prepareJob(getInputPath(), getTempPath(SUMMED_OBSERVATIONS), SequenceFileInputFormat.class, IndexInstancesMapper.class, IntWritable.class, VectorWritable.class, VectorSumReducer.class, IntWritable.class, VectorWritable.class, SequenceFileOutputFormat.class);
        indexInstances.setCombinerClass(VectorSumReducer.class);
        if (!indexInstances.waitForCompletion(true)) {
            return -1;
        }

        // Job 2: summed observations -> weights.
        Job weightSummer = prepareJob(getTempPath(SUMMED_OBSERVATIONS), getTempPath(WEIGHTS), SequenceFileInputFormat.class, WeightsMapper.class, Text.class, VectorWritable.class, VectorSumReducer.class, Text.class, VectorWritable.class, SequenceFileOutputFormat.class);
        weightSummer.getConfiguration().set(WeightsMapper.NUM_LABELS, String.valueOf(labelCount));
        weightSummer.setCombinerClass(VectorSumReducer.class);
        if (!weightSummer.waitForCompletion(true)) {
            return -1;
        }

        HadoopUtil.cacheFiles(getTempPath(WEIGHTS), getConf());

        // NOTE(review): the theta job below is configured but never submitted (there is no
        // waitForCompletion call), so nothing is written to the THETAS temp path by this method.
        // Presumably this is deliberately disabled upstream; confirm before enabling it.
        Job thetaSummer = prepareJob(getTempPath(SUMMED_OBSERVATIONS), getTempPath(THETAS), SequenceFileInputFormat.class, ThetaMapper.class, Text.class, VectorWritable.class, VectorSumReducer.class, Text.class, VectorWritable.class, SequenceFileOutputFormat.class);
        thetaSummer.setCombinerClass(VectorSumReducer.class);
        thetaSummer.getConfiguration().setFloat(ThetaMapper.ALPHA_I, alphaI);
        thetaSummer.getConfiguration().setBoolean(ThetaMapper.TRAIN_COMPLEMENTARY, trainComplementary);

        // The smoothing parameter is also placed in this tool's configuration before the model
        // is read back from the temp directory.
        getConf().setFloat(ThetaMapper.ALPHA_I, alphaI);
        NaiveBayesModel model = BayesUtils.readModelFromDir(getTempPath(), getConf());
        model.validate();
        model.serialize(getOutputPath(), getConf());
        return 0;
    }

    /**
     * Writes the label index to {@code path}, using the first label source that is configured.
     *
     * <p>If the {@code labels} option is present, its value is split on commas and written.
     * Otherwise, if the {@code extractLabels} option is present, the index is built from the
     * sequence files under the input path. If neither option is present nothing is written here
     * and {@code 0} is returned.
     *
     * @param path where the label index is written
     * @return the value returned by {@code BayesUtils.writeLabelIndex}, or {@code 0} when neither
     *         label option was supplied
     * @throws IOException if writing the index fails
     */
    private long createLabelIndex(Path path) throws IOException {
        if (hasOption(LABELS)) {
            Iterable<String> labels = Splitter.on(",").split(getOption(LABELS));
            return BayesUtils.writeLabelIndex(getConf(), labels, path);
        }
        if (hasOption(EXTRACT_LABELS)) {
            // NOTE(review): raw type retained from the decompiled source; the element type that
            // writeLabelIndex expects is not visible here.
            return BayesUtils.writeLabelIndex(getConf(), path, new SequenceFileDirIterable(getInputPath(), PathType.LIST, PathFilters.logsCRCFilter(), getConf()));
        }
        return 0;
    }
}
