package ai.idylnlp.nlp.documents.dl4j;

import ai.idylnlp.model.manifest.DocumentModelManifest;
import ai.idylnlp.model.nlp.documents.DeepLearningDocumentClassificationRequest;
import ai.idylnlp.model.nlp.documents.DocumentClassificationEvaluationRequest;
import ai.idylnlp.model.nlp.documents.DocumentClassificationEvaluationResponse;
import ai.idylnlp.model.nlp.documents.DocumentClassificationResponse;
import ai.idylnlp.model.nlp.documents.DocumentClassificationScores;
import ai.idylnlp.model.nlp.documents.DocumentClassifier;
import ai.idylnlp.model.nlp.documents.DocumentClassifierException;
import ai.idylnlp.nlp.documents.dl4j.model.DeepLearningDocumentClassifierConfiguration;
import ai.idylnlp.nlp.documents.dl4j.utils.LabelSeeker;
import ai.idylnlp.nlp.documents.dl4j.utils.MeansBuilder;
import com.neovisionaries.i18n.LanguageCode;
import java.io.File;
import java.io.IOException;
import java.util.HashMap;
import java.util.LinkedHashMap;
import java.util.List;
import java.util.Map;
import java.util.concurrent.atomic.AtomicInteger;
import org.apache.logging.log4j.LogManager;
import org.apache.logging.log4j.Logger;
import org.deeplearning4j.berkeley.Pair;
import org.deeplearning4j.models.embeddings.inmemory.InMemoryLookupTable;
import org.deeplearning4j.models.embeddings.loader.WordVectorSerializer;
import org.deeplearning4j.models.paragraphvectors.ParagraphVectors;
import org.deeplearning4j.text.documentiterator.FileLabelAwareIterator;
import org.deeplearning4j.text.documentiterator.LabelledDocument;
import org.deeplearning4j.text.tokenization.tokenizer.preprocessor.CommonPreprocessor;
import org.deeplearning4j.text.tokenization.tokenizerfactory.DefaultTokenizerFactory;

/* loaded from: input_file:ai/idylnlp/nlp/documents/dl4j/DeepLearningDocumentClassifier.class */
/**
 * Document classifier backed by DL4J {@link ParagraphVectors} models, one model per language.
 *
 * <p>Models are read from the files named by the configuration's manifests when the classifier
 * is constructed and are selected by the request's language code at classification time.
 */
public class DeepLearningDocumentClassifier implements DocumentClassifier<DeepLearningDocumentClassifierConfiguration, DeepLearningDocumentClassificationRequest> {
    private static final Logger LOGGER = LogManager.getLogger(DeepLearningDocumentClassifier.class);

    /** Configuration supplying the model manifests (model file name, language code, labels). */
    private final DeepLearningDocumentClassifierConfiguration configuration;

    /** Successfully loaded models, keyed by the language code of the manifest they came from. */
    private final Map<LanguageCode, ParagraphVectors> models = new HashMap<>();

    /**
     * Creates the classifier and loads every model listed in the configuration.
     *
     * <p>A model whose file cannot be read is logged and skipped so the remaining models still
     * load; a later {@link #classify} call for that language fails with "No model for language".
     *
     * @param deepLearningDocumentClassifierConfiguration configuration listing the model manifests
     * @throws DocumentClassifierException declared by the signature; the load failures handled
     *         here are logged rather than thrown
     */
    public DeepLearningDocumentClassifier(DeepLearningDocumentClassifierConfiguration deepLearningDocumentClassifierConfiguration) throws DocumentClassifierException {
        this.configuration = deepLearningDocumentClassifierConfiguration;
        for (DocumentModelManifest documentModelManifest : deepLearningDocumentClassifierConfiguration.getModels()) {
            File file = new File(documentModelManifest.getModelFileName());
            try {
                LOGGER.info("Loading model {}", file.getAbsolutePath());
                this.models.put(documentModelManifest.getLanguageCode(), WordVectorSerializer.readParagraphVectors(file));
            } catch (IOException e) {
                // The file name must fill the placeholder; the throwable goes last so Log4j
                // records it as the exception (with stack trace) instead of formatting it
                // into the message and dropping the file name.
                LOGGER.error("Unable to load document classification model {}. Verify the file exists.", documentModelManifest.getModelFileName(), e);
            }
        }
    }

    /**
     * Classifies the request's text with the model loaded for the request's language.
     *
     * @param deepLearningDocumentClassificationRequest the text to classify and its language code
     * @return a response wrapping the score computed for each label of the matching manifest
     * @throws DocumentClassifierException if no model or manifest exists for the language, or
     *         if any other failure occurs while classifying (the cause is attached)
     */
    public DocumentClassificationResponse classify(DeepLearningDocumentClassificationRequest deepLearningDocumentClassificationRequest) throws DocumentClassifierException {
        try {
            LanguageCode languageCode = deepLearningDocumentClassificationRequest.getLanguageCode();
            ParagraphVectors paragraphVectors = this.models.get(languageCode);
            if (paragraphVectors == null) {
                throw new DocumentClassifierException("No model for language " + languageCode.getAlpha3() + ".");
            }
            // The manifest carries the label set for this language's model.
            DocumentModelManifest documentModelManifest = this.configuration.getModels().stream()
                    .filter(manifest -> manifest.getLanguageCode().equals(languageCode))
                    .findFirst()
                    .orElseThrow(() -> new DocumentClassifierException("No model manifest for language " + languageCode.getAlpha3() + "."));
            DefaultTokenizerFactory defaultTokenizerFactory = new DefaultTokenizerFactory();
            defaultTokenizerFactory.setTokenPreProcessor(new CommonPreprocessor());
            InMemoryLookupTable lookupTable = paragraphVectors.getLookupTable();
            MeansBuilder meansBuilder = new MeansBuilder(lookupTable, defaultTokenizerFactory);
            LabelSeeker labelSeeker = new LabelSeeker(documentModelManifest.getLabels(), lookupTable);
            LabelledDocument labelledDocument = new LabelledDocument();
            labelledDocument.setContent(deepLearningDocumentClassificationRequest.getText());
            List<Pair<String, Double>> scores = labelSeeker.getScores(meansBuilder.documentAsVector(labelledDocument));
            Map<String, Double> scoresByLabel = new HashMap<>();
            for (Pair<String, Double> pair : scores) {
                scoresByLabel.put(pair.getFirst(), pair.getSecond());
            }
            return new DocumentClassificationResponse(new DocumentClassificationScores(scoresByLabel));
        } catch (DocumentClassifierException e) {
            // Already carries a specific message; do not bury it under the generic wrapper.
            throw e;
        } catch (Exception e) {
            throw new DocumentClassifierException("Unable to classify document.", e);
        }
    }

    /**
     * Classifies every labelled document under the request's directory and tallies the outcomes.
     *
     * <p>The result maps each document's first label to a map from predicted category to the
     * number of documents with that label that received that prediction. Labels appear in the
     * order they are first encountered.
     *
     * @param documentClassificationEvaluationRequest the directory of labelled documents and the
     *        language code to classify them with
     * @return a response wrapping the per-label prediction counts
     * @throws DocumentClassifierException if classifying any document fails
     */
    public DocumentClassificationEvaluationResponse evaluate(DocumentClassificationEvaluationRequest documentClassificationEvaluationRequest) throws DocumentClassifierException {
        Map<String, Map<String, AtomicInteger>> countsByActualLabel = new LinkedHashMap<>();
        FileLabelAwareIterator documents = new FileLabelAwareIterator.Builder().addSourceFolder(new File(documentClassificationEvaluationRequest.getDirectory())).build();
        LOGGER.info("Beginning model evaluation using directory {}", documentClassificationEvaluationRequest.getDirectory());
        while (documents.hasNext()) {
            LabelledDocument document = documents.nextDocument();
            DocumentClassificationResponse response = classify(new DeepLearningDocumentClassificationRequest(document.getContent(), documentClassificationEvaluationRequest.getLanguageCode()));
            String actualLabel = (String) document.getLabels().get(0);
            String predictedLabel = (String) response.getScores().getPredictedCategory().getLeft();
            countsByActualLabel
                    .computeIfAbsent(actualLabel, key -> new HashMap<>())
                    .computeIfAbsent(predictedLabel, key -> new AtomicInteger(0))
                    .incrementAndGet();
        }
        return new DocumentClassificationEvaluationResponse(countsByActualLabel);
    }
}
