package ai.idylnlp.nlp.features;

import ai.idylnlp.nlp.utils.ngrams.NgramUtils;
import java.util.HashMap;
import java.util.HashSet;
import java.util.Iterator;
import java.util.Map;
import java.util.Set;
import java.util.stream.IntStream;
import org.apache.commons.collections4.Bag;
import org.apache.commons.collections4.bag.HashBag;

/* loaded from: input_file:ai/idylnlp/nlp/features/BagOfWords.class */
/**
 * A bag (multiset) of string entries - either raw tokens or token n-grams -
 * that tracks how many times each distinct entry was added.
 *
 * <p>All query methods delegate to a backing {@link Bag}. Instances are
 * mutable (see {@link #clear()} and {@link #iterator()}) and no
 * synchronization is performed by this class.
 */
public class BagOfWords {

    /** Backing multiset holding every entry added by the constructors. */
    private final Bag<String> bag = new HashBag<>();

    /**
     * Creates a bag containing every element of the given array.
     *
     * @param strArr the tokens to add; duplicates increase the entry's count
     * @throws NullPointerException if {@code strArr} is {@code null}
     */
    public BagOfWords(String[] strArr) {
        for (String token : strArr) {
            this.bag.add(token);
        }
    }

    /**
     * Creates a bag containing every element of the given array and then
     * discards all entries that occur fewer than {@code i} times.
     *
     * @param strArr the tokens to add; duplicates increase the entry's count
     * @param i the minimum number of occurrences an entry needs to be kept
     * @throws NullPointerException if {@code strArr} is {@code null}
     */
    public BagOfWords(String[] strArr, int i) {
        this(strArr);
        removeBelowMinimum(i);
    }

    /**
     * Creates a bag of the n-grams that
     * {@code NgramUtils.getNgrams(strArr, i2)} produces for the given tokens
     * and then discards all n-grams that occur fewer than {@code i} times.
     *
     * @param strArr the tokens from which n-grams are generated
     * @param i the minimum number of occurrences an n-gram needs to be kept
     * @param i2 the n-gram length passed to {@code NgramUtils}; must be at least 2
     * @throws IllegalArgumentException if {@code i2} is less than 2
     */
    public BagOfWords(String[] strArr, int i, int i2) {
        if (i2 < 2) {
            throw new IllegalArgumentException("Length of n-grams must be at least 2.");
        }
        for (String ngram : NgramUtils.getNgrams(strArr, i2)) {
            this.bag.add(ngram);
        }
        removeBelowMinimum(i);
    }

    /**
     * Computes, for every entry that appears in at least one of the given
     * bags, the fraction of that entry's total occurrences found in each bag.
     *
     * <p>Position {@code k} of each returned array belongs to the {@code k}-th
     * bag in the iteration order of {@code set}. An empty {@code set} yields
     * an empty map.
     *
     * @param set the bags to normalize against each other
     * @return a map from entry to its per-bag share of the entry's total count
     * @throws NullPointerException if {@code set} is {@code null}
     */
    public static Map<String, double[]> normalize(Set<BagOfWords> set) {
        // Union of all distinct entries across the bags.
        Set<String> vocabulary = new HashSet<>();
        for (BagOfWords bagOfWords : set) {
            vocabulary.addAll(bagOfWords.uniqueSet());
        }

        Map<String, double[]> normalized = new HashMap<>();
        for (String term : vocabulary) {
            int[] counts = new int[set.size()];
            // Accumulate in a long so the total cannot overflow int.
            long total = 0L;
            int position = 0;
            for (BagOfWords bagOfWords : set) {
                int count = bagOfWords.getCount(term);
                counts[position++] = count;
                total += count;
            }

            double[] shares = new double[counts.length];
            for (int k = 0; k < counts.length; k++) {
                // Divide in floating point: int / int would truncate every
                // share below 1.0 down to 0.0.
                shares[k] = (double) counts[k] / total;
            }
            normalized.put(term, shares);
        }
        return normalized;
    }

    /**
     * Returns the backing bag's set of distinct entries.
     *
     * @return the distinct entries, as reported by the backing bag
     */
    public Set<String> uniqueSet() {
        return this.bag.uniqueSet();
    }

    /**
     * Returns how many times the given entry is held in the backing bag.
     *
     * @param str the entry to look up
     * @return the occurrence count reported by the backing bag
     */
    public int getCount(String str) {
        return this.bag.getCount(str);
    }

    /**
     * Returns the size of the backing bag.
     *
     * @return the size reported by the backing bag
     */
    public int size() {
        return this.bag.size();
    }

    /**
     * Returns the backing bag's own iterator.
     *
     * @return an iterator obtained directly from the backing bag
     */
    public Iterator<String> iterator() {
        return this.bag.iterator();
    }

    /**
     * Reports whether the backing bag holds no entries.
     *
     * @return {@code true} if the backing bag is empty
     */
    public boolean isEmpty() {
        return this.bag.isEmpty();
    }

    /**
     * Reports whether the backing bag holds the given entry.
     *
     * @param str the entry to look up
     * @return {@code true} if the backing bag contains {@code str}
     */
    public boolean contains(String str) {
        return this.bag.contains(str);
    }

    /** Removes every entry from the backing bag. */
    public void clear() {
        this.bag.clear();
    }

    /**
     * Removes every entry whose occurrence count is lower than {@code i}.
     *
     * @param i the minimum number of occurrences an entry needs to be kept
     */
    private void removeBelowMinimum(int i) {
        // Decide first, mutate afterwards, so the counts being tested are
        // never changed while the bag is still being inspected.
        Set<String> tooRare = new HashSet<>();
        for (String term : this.bag.uniqueSet()) {
            if (this.bag.getCount(term) < i) {
                tooRare.add(term);
            }
        }
        for (String term : tooRare) {
            // Bag.remove(Object) drops all occurrences of the entry.
            this.bag.remove(term);
        }
    }
}
