package ai.preferred.regression.pe;

import ai.preferred.regression.io.CSVInputData;
import ai.preferred.regression.io.CSVUtils;
import ai.preferred.regression.pe.data.Vocabulary;
import com.google.common.collect.HashMultiset;
import com.google.common.collect.ImmutableMultiset;
import com.google.common.collect.Multiset;
import java.io.IOException;
import java.util.ArrayList;
import java.util.Arrays;
import java.util.Collections;
import java.util.Comparator;
import java.util.Iterator;
import java.util.List;
import java.util.Locale;
import java.util.regex.Pattern;
import org.apache.commons.csv.CSVPrinter;
import org.kohsuke.args4j.Option;
import org.slf4j.Logger;
import org.slf4j.LoggerFactory;

/* loaded from: input_file:ai/preferred/regression/pe/EncodeTextAsFrequency.class */
public class EncodeTextAsFrequency extends ProcessingElement {
    private static final Logger LOGGER = LoggerFactory.getLogger(EncodeTextAsFrequency.class);

    @Option(name = "-c", aliases = {"--column"}, usage = "the index of the input column", required = true)
    private int column;

    @Option(name = "-s", aliases = {"--separator"}, usage = "specifies regular expression for splitting text into words")
    private String separator = "\\W+";

    @Option(name = "-n", aliases = {"--number-of-words"}, usage = "the maximum number of words to keep")
    private int numberOfWords = 1000;

    @Option(name = "-p", aliases = {"--prefix"}, usage = "the prefix of the new columns")
    private String prefix = "WORD:";

    private static <T> Comparator<Multiset.Entry<T>> getDecreasingCountComparator() {
        return (entry, entry2) -> {
            return Integer.compare(entry2.getCount(), entry.getCount());
        };
    }

    private static String[] toLowerCase(String[] strArr) {
        String[] strArr2 = new String[strArr.length];
        for (int i = 0; i < strArr.length; i++) {
            strArr2[i] = strArr[i].toLowerCase();
        }
        return strArr2;
    }

    private static String[] trimEmpty(String[] strArr) {
        ArrayList arrayList = new ArrayList();
        for (String str : strArr) {
            if (!str.trim().isEmpty()) {
                arrayList.add(str);
            }
        }
        return (String[]) arrayList.toArray(new String[0]);
    }

    private static Multiset<String> toBagOfWords(String str, String str2) {
        return ImmutableMultiset.copyOf(toLowerCase(toLowerCase(trimEmpty(Pattern.compile(str2).split(str)))));
    }

    private Vocabulary buildVocabulary(CSVInputData cSVInputData, int i) {
        HashMultiset create = HashMultiset.create();
        Iterator<ArrayList<String>> it = cSVInputData.iterator();
        while (it.hasNext()) {
            create.addAll(toBagOfWords(it.next().get(this.column), this.separator));
        }
        ArrayList arrayList = new ArrayList(create.entrySet());
        arrayList.sort(getDecreasingCountComparator());
        ArrayList arrayList2 = new ArrayList(i);
        Iterator it2 = arrayList.subList(0, Math.min(arrayList.size(), i)).iterator();
        while (it2.hasNext()) {
            arrayList2.add((String) ((Multiset.Entry) it2.next()).getElement());
        }
        return new Vocabulary(arrayList2);
    }

    @Override // ai.preferred.regression.pe.ProcessingElement
    protected void process(CSVInputData cSVInputData, CSVPrinter cSVPrinter) throws IOException {
        Vocabulary buildVocabulary = buildVocabulary(cSVInputData, this.numberOfWords);
        if (cSVInputData.hasHeader()) {
            ArrayList<String> header = cSVInputData.getHeader();
            header.remove(this.column);
            Iterator<String> it = buildVocabulary.getVocabularyList().iterator();
            while (it.hasNext()) {
                header.add(this.prefix + it.next());
            }
            cSVPrinter.printRecord(header);
        }
        Iterator<ArrayList<String>> it2 = cSVInputData.iterator();
        while (it2.hasNext()) {
            ArrayList<String> next = it2.next();
            Multiset<String> bagOfWords = toBagOfWords(next.get(this.column), this.separator);
            Integer[] numArr = new Integer[buildVocabulary.size()];
            Arrays.fill((Object[]) numArr, (Object) 0);
            for (Multiset.Entry entry : bagOfWords.entrySet()) {
                int index = buildVocabulary.getIndex((String) entry.getElement());
                if (index != -1) {
                    numArr[index] = Integer.valueOf(entry.getCount());
                }
            }
            next.remove(this.column);
            Collections.addAll(next, CSVUtils.toStringArray(numArr));
            cSVPrinter.printRecord(next);
        }
    }

    public static void main(String[] strArr) {
        parseArgsAndRun(EncodeTextAsFrequency.class, strArr);
    }
}
