package ai.idylnlp.nlp.documents.dl4j.utils;

import java.util.Iterator;
import java.util.List;
import java.util.concurrent.atomic.AtomicInteger;
import org.deeplearning4j.models.embeddings.inmemory.InMemoryLookupTable;
import org.deeplearning4j.models.word2vec.VocabWord;
import org.deeplearning4j.models.word2vec.wordstore.VocabCache;
import org.deeplearning4j.text.documentiterator.LabelledDocument;
import org.deeplearning4j.text.tokenization.tokenizerfactory.TokenizerFactory;
import org.nd4j.linalg.api.ndarray.INDArray;
import org.nd4j.linalg.factory.Nd4j;

/* loaded from: input_file:ai/idylnlp/nlp/documents/dl4j/utils/MeansBuilder.class */
/**
 * Builds a single vector for a document by averaging the word vectors of its tokens.
 *
 * <p>The document content is tokenized with the supplied {@link TokenizerFactory}. Only tokens
 * contained in the lookup table's vocabulary contribute to the average; all other tokens are
 * skipped.
 */
public class MeansBuilder {
    private final VocabCache<VocabWord> vocabCache;
    private final InMemoryLookupTable<VocabWord> lookupTable;
    private final TokenizerFactory tokenizerFactory;

    /**
     * Creates a builder backed by the given lookup table.
     *
     * @param inMemoryLookupTable table supplying both the vocabulary (via {@code getVocab()}) and
     *     the per-word vectors
     * @param tokenizerFactory factory used to split document content into tokens
     */
    public MeansBuilder(InMemoryLookupTable<VocabWord> inMemoryLookupTable, TokenizerFactory tokenizerFactory) {
        this.lookupTable = inMemoryLookupTable;
        this.vocabCache = inMemoryLookupTable.getVocab();
        this.tokenizerFactory = tokenizerFactory;
    }

    /**
     * Computes the mean word vector of a document.
     *
     * <p>Each in-vocabulary token (duplicates included) becomes one row of an intermediate
     * {@code [knownTokens, layerSize]} matrix, and the result is the mean along dimension 0.
     *
     * <p>NOTE(review): when no token is in the vocabulary the intermediate matrix has zero rows;
     * whether {@code Nd4j.create}/{@code mean} then throw or yield NaNs depends on the ND4J
     * version in use — confirm and handle explicitly if callers can pass such documents.
     *
     * @param labelledDocument document whose {@code getContent()} is tokenized
     * @return the mean of the word vectors of all in-vocabulary tokens
     */
    public INDArray documentAsVector(LabelledDocument labelledDocument) {
        List<String> tokens = tokenizerFactory.create(labelledDocument.getContent()).getTokens();

        // First pass: count known tokens so the matrix is allocated at its exact size.
        int knownCount = 0;
        for (String token : tokens) {
            if (vocabCache.containsWord(token)) {
                knownCount++;
            }
        }

        // Second pass: one row per known token, in document order.
        INDArray wordVectors = Nd4j.create(knownCount, lookupTable.layerSize());
        int row = 0;
        for (String token : tokens) {
            if (vocabCache.containsWord(token)) {
                wordVectors.putRow(row++, lookupTable.vector(token));
            }
        }
        return wordVectors.mean(new int[] {0});
    }
}
