package ai.idylnlp.models.deeplearning.training;

import ai.idylnlp.model.nlp.annotation.AnnotationTypes;
import ai.idylnlp.model.nlp.subjects.CoNLL2003SubjectOfTrainingOrEvaluation;
import ai.idylnlp.model.nlp.subjects.IdylNLPSubjectOfTrainingOrEvaluation;
import ai.idylnlp.model.nlp.subjects.OpenNLPSubjectOfTrainingOrEvaluation;
import ai.idylnlp.model.nlp.subjects.SubjectOfTrainingOrEvaluation;
import ai.idylnlp.models.ObjectStreamUtils;
import ai.idylnlp.models.deeplearning.training.model.DeepLearningTrainingDefinition;
import ai.idylnlp.models.deeplearning.training.model.HyperParameters;
import ai.idylnlp.nlp.recognizer.deep.NameSampleDataSetIterator;
import com.google.gson.Gson;
import com.google.gson.GsonBuilder;
import java.io.File;
import java.io.IOException;
import java.util.ArrayList;
import java.util.UUID;
import java.util.concurrent.TimeUnit;
import org.apache.commons.lang3.StringUtils;
import org.apache.logging.log4j.LogManager;
import org.apache.logging.log4j.Logger;
import org.deeplearning4j.earlystopping.EarlyStoppingConfiguration;
import org.deeplearning4j.earlystopping.EarlyStoppingResult;
import org.deeplearning4j.earlystopping.saver.LocalFileModelSaver;
import org.deeplearning4j.earlystopping.scorecalc.DataSetLossCalculator;
import org.deeplearning4j.earlystopping.termination.EpochTerminationCondition;
import org.deeplearning4j.earlystopping.termination.IterationTerminationCondition;
import org.deeplearning4j.earlystopping.termination.MaxTimeIterationTerminationCondition;
import org.deeplearning4j.earlystopping.termination.ScoreImprovementEpochTerminationCondition;
import org.deeplearning4j.earlystopping.trainer.EarlyStoppingTrainer;
import org.deeplearning4j.eval.Evaluation;
import org.deeplearning4j.models.embeddings.loader.WordVectorSerializer;
import org.deeplearning4j.models.embeddings.wordvectors.WordVectors;
import org.deeplearning4j.nn.conf.LearningRatePolicy;
import org.deeplearning4j.nn.conf.MultiLayerConfiguration;
import org.deeplearning4j.nn.conf.NeuralNetConfiguration;
import org.deeplearning4j.nn.conf.layers.GravesLSTM;
import org.deeplearning4j.nn.conf.layers.RnnOutputLayer;
import org.deeplearning4j.nn.multilayer.MultiLayerNetwork;
import org.deeplearning4j.optimize.listeners.ScoreIterationListener;
import org.deeplearning4j.parallelism.ParallelWrapper;
import org.deeplearning4j.ui.stats.StatsListener;
import org.deeplearning4j.ui.storage.FileStatsStorage;
import org.deeplearning4j.util.ModelSerializer;
import org.nd4j.linalg.activations.Activation;
import org.nd4j.linalg.api.ndarray.INDArray;
import org.nd4j.linalg.dataset.DataSet;
import org.nd4j.linalg.lossfunctions.LossFunctions;

/* loaded from: input_file:ai/idylnlp/models/deeplearning/training/DeepLearningEntityModelOperations.class */
/**
 * Trains LSTM-based entity models described by a {@link DeepLearningTrainingDefinition} and
 * serializes the resulting network to the configured output file.
 *
 * <p>The training mode is chosen from the definition:
 * <ul>
 *   <li>early-stopping training when an early-termination section is present;</li>
 *   <li>otherwise parallel training when a parallel-training section is present;</li>
 *   <li>otherwise single-node training.</li>
 * </ul>
 */
public class DeepLearningEntityModelOperations {
    private static final Logger LOGGER = LogManager.getLogger(DeepLearningEntityModelOperations.class);

    /** Number of units in the GravesLSTM layer; also the input size of the output layer. */
    private static final int LSTM_LAYER_SIZE = 256;

    private final Gson gson;

    /** Creates an instance whose Gson pretty-prints and serializes NaN/Infinity values. */
    public DeepLearningEntityModelOperations() {
        this.gson = new GsonBuilder()
                .serializeSpecialFloatingPointValues()
                .setPrettyPrinting()
                .create();
    }

    /**
     * Trains an entity model as described by the given definition and writes it to the
     * definition's output file.
     *
     * @param deepLearningTrainingDefinition training data, evaluation data, hyperparameters,
     *     training mode and output settings
     * @return a randomly generated identifier; it is not derived from, or stored in, the model
     * @throws IOException if writing the model (or another I/O step such as opening the input data) fails
     * @throws IllegalStateException if no vector can be obtained for the first vocabulary entry
     */
    public String train(DeepLearningTrainingDefinition deepLearningTrainingDefinition) throws IOException {
        LOGGER.info("Starting training.");
        if (LOGGER.isDebugEnabled()) {
            // The field Gson has the same configuration; avoid rebuilding it and avoid
            // serializing the whole definition when debug logging is off.
            LOGGER.debug(gson.toJson(deepLearningTrainingDefinition, DeepLearningTrainingDefinition.class));
        }

        HyperParameters hyperParameters = deepLearningTrainingDefinition.getHyperParameters();

        WordVectors wordVectors = WordVectorSerializer.loadStaticModel(
                new File(deepLearningTrainingDefinition.getTrainingData().getWordVectorsFile()));
        int vectorSize = getVectorSize(wordVectors);
        String[] labels = getLabels(deepLearningTrainingDefinition.getEntityType());
        LOGGER.debug("Using vector size: {}", vectorSize);

        NameSampleDataSetIterator trainingIterator = createIterator(
                getSubjectOfTraining(deepLearningTrainingDefinition), wordVectors, vectorSize, hyperParameters, labels);
        NameSampleDataSetIterator evaluationIterator = createIterator(
                getSubjectOfEvaluation(deepLearningTrainingDefinition), wordVectors, vectorSize, hyperParameters, labels);

        // The output layer size is tied to the label set so the two cannot drift apart.
        final MultiLayerNetwork network = buildNetwork(
                buildNetworkConfiguration(hyperParameters, vectorSize, labels.length), deepLearningTrainingDefinition);

        final MultiLayerNetwork trainedNetwork;
        if (deepLearningTrainingDefinition.getEarlyTermination() != null) {
            trainedNetwork = trainWithEarlyTermination(
                    deepLearningTrainingDefinition, network, trainingIterator, evaluationIterator);
        } else if (deepLearningTrainingDefinition.getParallelTraining() != null) {
            LOGGER.info("Doing parallel training.");
            final ParallelWrapper wrapper = new ParallelWrapper.Builder(network)
                    .prefetchBuffer(deepLearningTrainingDefinition.getParallelTraining().getPrefetchBuffer())
                    .workers(deepLearningTrainingDefinition.getParallelTraining().getWorkers())
                    .reportScoreAfterAveraging(deepLearningTrainingDefinition.getParallelTraining().isReportScoreAfterAveraging())
                    .averagingFrequency(deepLearningTrainingDefinition.getParallelTraining().getAveragingFrequency())
                    .useLegacyAveraging(deepLearningTrainingDefinition.getParallelTraining().isLegacyAveraging())
                    .build();
            try {
                trainForEpochs(hyperParameters.getEpochs(), () -> wrapper.fit(trainingIterator),
                        trainingIterator, network, evaluationIterator);
            } finally {
                // Stop the wrapper's worker threads; previously they were never released.
                wrapper.shutdown();
            }
            trainedNetwork = network;
        } else {
            LOGGER.info("Doing single node training.");
            trainForEpochs(hyperParameters.getEpochs(), () -> network.fit(trainingIterator),
                    trainingIterator, network, evaluationIterator);
            trainedNetwork = network;
        }

        File outputFile = new File(deepLearningTrainingDefinition.getOutput().getOutputFile());
        // 'false': the updater state is not saved with the model.
        ModelSerializer.writeModel(trainedNetwork, outputFile, false);
        LOGGER.info("Model serialized to {}", outputFile.getAbsolutePath());
        return UUID.randomUUID().toString();
    }

    /** Returns the Gson instance used for (de)serializing training definitions. */
    public Gson getGson() {
        return this.gson;
    }

    /**
     * Parses a JSON training definition.
     *
     * @param str the JSON text
     * @return the parsed definition
     * @throws IOException declared for API compatibility; not thrown by this implementation
     */
    public DeepLearningTrainingDefinition deserializeDefinition(String str) throws IOException {
        return this.gson.fromJson(str, DeepLearningTrainingDefinition.class);
    }

    /**
     * Returns the length of the vector of the first vocabulary entry, used as the network input size.
     *
     * @throws IllegalStateException if no vector is available for that entry
     */
    private static int getVectorSize(WordVectors wordVectors) {
        double[] firstVector = wordVectors.getWordVector(wordVectors.vocab().wordAtIndex(0));
        if (firstVector == null) {
            throw new IllegalStateException(
                    "Unable to determine the word vector size: no vector found for the first vocabulary entry.");
        }
        return firstVector.length;
    }

    /** Creates a name-sample iterator over the given subject using the shared vector and batch settings. */
    private static NameSampleDataSetIterator createIterator(SubjectOfTrainingOrEvaluation subject,
            WordVectors wordVectors, int vectorSize, HyperParameters hyperParameters, String[] labels)
            throws IOException {
        return new NameSampleDataSetIterator(ObjectStreamUtils.getObjectStream(subject), wordVectors, vectorSize,
                hyperParameters.getWindowSize(), labels, hyperParameters.getBatchSize());
    }

    /**
     * Trains using DL4J early stopping, scoring on the evaluation iterator after every epoch.
     *
     * @return the best model found by the early-stopping trainer
     */
    @SuppressWarnings({"rawtypes", "unchecked"}) // DL4J's documented usage builds the configuration with a raw builder
    private static MultiLayerNetwork trainWithEarlyTermination(DeepLearningTrainingDefinition definition,
            MultiLayerNetwork network, NameSampleDataSetIterator trainingIterator,
            NameSampleDataSetIterator evaluationIterator) {
        LOGGER.info("Enabling early-termination training.");
        EarlyStoppingConfiguration.Builder builder = new EarlyStoppingConfiguration.Builder();
        if (definition.getEarlyTermination().getMaxEpochs() != null) {
            // NOTE(review): ScoreImprovementEpochTerminationCondition stops after N epochs *without score
            // improvement*, not after N epochs in total. Confirm this matches the intent of "maxEpochs".
            builder.epochTerminationConditions(new EpochTerminationCondition[]{
                    new ScoreImprovementEpochTerminationCondition(
                            definition.getEarlyTermination().getMaxEpochs().intValue())});
        }
        if (definition.getEarlyTermination().getMaxMinutes() != null) {
            builder.iterationTerminationConditions(new IterationTerminationCondition[]{
                    new MaxTimeIterationTerminationCondition(
                            definition.getEarlyTermination().getMaxMinutes().intValue(), TimeUnit.MINUTES)});
        }
        builder.scoreCalculator(new DataSetLossCalculator(evaluationIterator, true));
        builder.evaluateEveryNEpochs(1);
        // Models produced during early stopping are saved under the JVM temp directory.
        builder.modelSaver(new LocalFileModelSaver(System.getProperty("java.io.tmpdir")));

        EarlyStoppingConfiguration<MultiLayerNetwork> configuration = builder.build();
        // A typed result is required: a raw EarlyStoppingResult returns an erased Model from getBestModel().
        EarlyStoppingResult<MultiLayerNetwork> result =
                new EarlyStoppingTrainer(configuration, network, trainingIterator).fit();

        LOGGER.info("Termination reason: {}", result.getTerminationReason());
        LOGGER.info("Termination details: {}", result.getTerminationDetails());
        LOGGER.info("Total epochs: {}", result.getTotalEpochs());
        LOGGER.info("Best epoch number: {}", result.getBestModelEpoch());
        LOGGER.info("Score at best epoch: {}", result.getBestModelScore());
        return result.getBestModel();
    }

    /**
     * Runs the given fit step once per epoch.
     *
     * <p>After each epoch, the training iterator is reset and the network is evaluated.
     *
     * @param epochs number of epochs to run (epochs are numbered from 1)
     * @param fitOneEpoch performs one pass of training over {@code trainingIterator}
     */
    private static void trainForEpochs(long epochs, Runnable fitOneEpoch, NameSampleDataSetIterator trainingIterator,
            MultiLayerNetwork network, NameSampleDataSetIterator evaluationIterator) {
        for (int epoch = 1; epoch <= epochs; epoch++) {
            fitOneEpoch.run();
            trainingIterator.reset();
            LOGGER.info("Finished epoch {}", epoch);
            logEvaluation(network, evaluationIterator);
        }
    }

    /**
     * Evaluates the network on every batch of the evaluation iterator and logs the statistics.
     *
     * <p>The iterator is reset afterwards so it can be reused for the next epoch.
     */
    private static void logEvaluation(MultiLayerNetwork network, NameSampleDataSetIterator evaluationIterator) {
        Evaluation evaluation = new Evaluation();
        while (evaluationIterator.hasNext()) {
            DataSet dataSet = (DataSet) evaluationIterator.next();
            INDArray labelsMask = dataSet.getLabelsMaskArray();
            INDArray predicted = network.output(
                    dataSet.getFeatureMatrix(), false, dataSet.getFeaturesMaskArray(), labelsMask);
            evaluation.evalTimeSeries(dataSet.getLabels(), predicted, labelsMask);
        }
        evaluationIterator.reset();
        LOGGER.info("Evaluation statistics:\n{}", evaluation.stats());
    }

    /**
     * Creates and initializes the network and attaches its listeners.
     *
     * <p>A stats listener is attached only when a stats file is configured. A score listener is always
     * attached, using the monitoring score-iteration setting.
     */
    private MultiLayerNetwork buildNetwork(MultiLayerConfiguration multiLayerConfiguration,
            DeepLearningTrainingDefinition deepLearningTrainingDefinition) {
        MultiLayerNetwork multiLayerNetwork = new MultiLayerNetwork(multiLayerConfiguration);
        multiLayerNetwork.init();
        ArrayList arrayList = new ArrayList();
        if (StringUtils.isNotEmpty(deepLearningTrainingDefinition.getOutput().getStatsFile())) {
            arrayList.add(new StatsListener(new FileStatsStorage(
                    new File(deepLearningTrainingDefinition.getOutput().getStatsFile()))));
        }
        arrayList.add(new ScoreIterationListener(deepLearningTrainingDefinition.getMonitoring().getScoreIteration()));
        multiLayerNetwork.setListeners(arrayList);
        return multiLayerNetwork;
    }

    /**
     * Builds a two-layer configuration from the hyperparameters.
     *
     * <p>The layers are a GravesLSTM (tanh) followed by a softmax RNN output layer with MCXENT loss.
     *
     * @param vectorSize input size (word vector length)
     * @param labelCount number of output classes; must match the labels passed to the iterators
     */
    private MultiLayerConfiguration buildNetworkConfiguration(HyperParameters hyperParameters, int vectorSize,
            int labelCount) {
        NeuralNetConfiguration.Builder builder = new NeuralNetConfiguration.Builder();
        builder.seed(hyperParameters.getSeed());
        builder.biasInit(hyperParameters.getNetworkConfigurationParameters().getBiasInit());
        builder.convolutionMode(hyperParameters.getConvolutionModeParam());
        builder.dropOut(hyperParameters.getNetworkConfigurationParameters().getDropOut());
        builder.iterations(hyperParameters.getNetworkConfigurationParameters().getIterations());
        builder.regularization(hyperParameters.getNetworkConfigurationParameters().getRegularizationParameters().getRegularization());
        builder.l1(hyperParameters.getNetworkConfigurationParameters().getRegularizationParameters().getL1().doubleValue());
        builder.l1Bias(hyperParameters.getNetworkConfigurationParameters().getRegularizationParameters().getL1Bias().doubleValue());
        builder.l2(hyperParameters.getNetworkConfigurationParameters().getRegularizationParameters().getL2().doubleValue());
        builder.l2Bias(hyperParameters.getNetworkConfigurationParameters().getRegularizationParameters().getL2Bias().doubleValue());
        builder.updater(hyperParameters.getNetworkConfigurationParameters().getUpdaterParameters().getUpdaterParam());
        builder.useDropConnect(hyperParameters.getNetworkConfigurationParameters().isUseDropConnect().booleanValue());
        builder.optimizationAlgo(hyperParameters.getNetworkConfigurationParameters().getOptimizationAlgorithmParam());
        builder.gradientNormalization(hyperParameters.getNetworkConfigurationParameters().getGradientNormalizationParam());
        builder.gradientNormalizationThreshold(hyperParameters.getNetworkConfigurationParameters().getGradientNormalizationThreshold());
        builder.weightInit(hyperParameters.getNetworkConfigurationParameters().getWeightInitParam());
        return builder.list()
                .layer(0, new GravesLSTM.Builder()
                        .nIn(vectorSize)
                        .nOut(LSTM_LAYER_SIZE)
                        .activation(Activation.TANH)
                        .learningRateDecayPolicy(LearningRatePolicy.Schedule)
                        .learningRateSchedule(hyperParameters.getNetworkConfigurationParameters().getLayers().getLayer1().getLearningRateScheduleParam())
                        .biasLearningRate(hyperParameters.getNetworkConfigurationParameters().getLayers().getLayer1().getBiasLearningRate())
                        .build())
                .layer(1, new RnnOutputLayer.Builder()
                        .nIn(LSTM_LAYER_SIZE)
                        .nOut(labelCount)
                        .activation(Activation.SOFTMAX)
                        .lossFunction(LossFunctions.LossFunction.MCXENT)
                        .learningRateDecayPolicy(LearningRatePolicy.Schedule)
                        .learningRateSchedule(hyperParameters.getNetworkConfigurationParameters().getLayers().getLayer2().getLearningRateScheduleParam())
                        .biasLearningRate(hyperParameters.getNetworkConfigurationParameters().getLayers().getLayer2().getBiasLearningRate())
                        .build())
                .pretrain(hyperParameters.getNetworkConfigurationParameters().isPretrain())
                .backprop(hyperParameters.getNetworkConfigurationParameters().isBackprop())
                .build();
    }

    /**
     * Returns the three labels for an entity type: "{type}-start", "{type}-cont" and "other".
     */
    private String[] getLabels(String str) {
        return new String[]{str + "-start", str + "-cont", "other"};
    }

    /** Creates the training subject from the definition's training-data section. */
    private SubjectOfTrainingOrEvaluation getSubjectOfTraining(DeepLearningTrainingDefinition deepLearningTrainingDefinition) {
        return createSubject(
                deepLearningTrainingDefinition.getTrainingData().getFormat(),
                deepLearningTrainingDefinition.getTrainingData().getInputFile(),
                deepLearningTrainingDefinition.getTrainingData().getAnnotationsFile(),
                "training");
    }

    /** Creates the evaluation subject from the definition's evaluation-data section. */
    private SubjectOfTrainingOrEvaluation getSubjectOfEvaluation(DeepLearningTrainingDefinition deepLearningTrainingDefinition) {
        // NOTE(review): the IDYLNLP annotations file is read from the *training* data section, as in the
        // original code. If the evaluation section has its own annotations file, it should likely be used here.
        return createSubject(
                deepLearningTrainingDefinition.getEvaluationData().getFormat(),
                deepLearningTrainingDefinition.getEvaluationData().getInputFile(),
                deepLearningTrainingDefinition.getTrainingData().getAnnotationsFile(),
                "evaluation");
    }

    /**
     * Maps an annotation format name (case-insensitive) to a subject implementation.
     *
     * <p>Unknown formats, including a null format, fall back to OpenNLP.
     *
     * @param purpose "training" or "evaluation"; used only in the fallback log message
     */
    private static SubjectOfTrainingOrEvaluation createSubject(String format, String inputFile,
            String annotationsFile, String purpose) {
        if (AnnotationTypes.IDYLNLP.getName().equalsIgnoreCase(format)) {
            return new IdylNLPSubjectOfTrainingOrEvaluation(inputFile, annotationsFile);
        }
        if (AnnotationTypes.CONLL2003.getName().equalsIgnoreCase(format)) {
            return new CoNLL2003SubjectOfTrainingOrEvaluation(inputFile);
        }
        LOGGER.info("Defaulting to OpenNLP subject of {}.", purpose);
        return new OpenNLPSubjectOfTrainingOrEvaluation(inputFile);
    }
}
