package meka.classifiers.multilabel;

import java.util.ArrayList;
import java.util.Enumeration;
import java.util.List;
import java.util.Vector;
import meka.classifiers.multitarget.CR;
import meka.core.OptionUtils;
import org.kramerlab.autoencoder.math.matrix.Mat;
import org.kramerlab.autoencoder.neuralnet.autoencoder.Autoencoder;
import org.kramerlab.autoencoder.package$;
import weka.classifiers.Classifier;
import weka.classifiers.functions.LinearRegression;
import weka.core.Instance;
import weka.core.Instances;
import weka.core.TechnicalInformation;
import weka.core.TechnicalInformationHandler;
import weka.filters.Filter;
import weka.filters.unsupervised.instance.SparseToNonSparse;

/* loaded from: input_file:meka/classifiers/multilabel/Maniac.class */
/**
 * Maniac - multi-label classification using autoencoders.
 *
 * <p>The label part of a data set is converted to a dense matrix and compressed by an
 * {@link Autoencoder}; the compressed columns replace the original labels in front of the
 * remaining attributes (see {@link #transformLabels(Instances)}). Predictions made in the
 * compressed space are mapped back with the same autoencoder
 * (see {@link #transformPredictionsBack(double[])}).
 *
 * <p>Three options are exposed: the compression factor, the number of autoencoders, and whether
 * the number of autoencoders should be selected on a held-out fold.
 */
public class Maniac extends LabelTransformationClassifier implements TechnicalInformationHandler {
    protected static final long serialVersionUID = 585507197229071545L;

    /** Autoencoder used for compressing labels and decompressing predictions; null until set or trained. */
    private Autoencoder ae;

    /** Copy of the compressed label data produced by the last {@link #transformLabels(Instances)} call. */
    private Instances compressedTemplateInst;

    /** Whether the number of autoencoders is selected on a held-out fold. */
    protected boolean optimizeAE = getDefaultOptimizeAE();

    /** Compression factor handed to the autoencoder library. */
    protected double compression = getDefaultCompression();

    /** Number of autoencoders to train (upper bound when {@link #optimizeAE} is set). */
    protected int numberAutoencoders = getDefaultNumberAutoencoders();

    /**
     * Sets the autoencoder to use. A non-null autoencoder is never retrained by
     * {@link #transformLabels(Instances)}.
     *
     * @param autoencoder the autoencoder to use
     */
    protected void setAE(Autoencoder autoencoder) {
        this.ae = autoencoder;
    }

    /**
     * Returns the autoencoder currently in use.
     *
     * @return the autoencoder, or null if none has been set or trained yet
     */
    private Autoencoder getAE() {
        return this.ae;
    }

    /**
     * Returns the default number of autoencoders.
     *
     * @return 4
     */
    protected int getDefaultNumberAutoencoders() {
        return 4;
    }

    /**
     * Returns the configured number of autoencoders.
     *
     * @return the number of autoencoders
     */
    public final int getNumberAutoencoders() {
        return this.numberAutoencoders;
    }

    /**
     * Sets the number of autoencoders.
     *
     * @param i the number of autoencoders
     */
    public final void setNumberAutoencoders(int i) {
        this.numberAutoencoders = i;
    }

    /**
     * Returns the description of the numberAutoencoders option.
     *
     * @return the description text
     */
    public String numberAutoencodersToolTip() {
        return "Number of autoencoders, i.e. number of hidden layers +1. Note that this can be also used as the number of autoencoders to use in the optimization search, autoencoders will be added until this number is reached  and then the best configuration in terms of number of layers is selects.";
    }

    /**
     * Returns the tip text of the numberAutoencoders option; delegates to
     * {@link #numberAutoencodersToolTip()}.
     *
     * @return the description text
     */
    public String numberAutoencodersTipText() {
        return numberAutoencodersToolTip();
    }

    /**
     * Returns the configured compression factor.
     *
     * @return the compression factor
     */
    public final double getCompression() {
        return this.compression;
    }

    /**
     * Sets the compression factor.
     *
     * @param d the compression factor
     */
    public final void setCompression(double d) {
        this.compression = d;
    }

    /**
     * Returns the default compression factor.
     *
     * @return 0.85
     */
    protected double getDefaultCompression() {
        return 0.85d;
    }

    /**
     * Returns the description of the compression option.
     *
     * @return the description text
     */
    public String compressionToolTip() {
        return "Compression factor of the autoencoders, each level of autoencoders will compress the labels to factor times previous layer size.";
    }

    /**
     * Returns the tip text of the compression option; delegates to {@link #compressionToolTip()}.
     *
     * @return the description text
     */
    public String compressionTipText() {
        return compressionToolTip();
    }

    /**
     * Returns the tip text of the optimizeAE option; delegates to {@link #optimizeAEToolTip()}.
     *
     * @return the description text
     */
    public String optimizeAETipText() {
        return optimizeAEToolTip();
    }

    /**
     * Returns whether the number of autoencoders is optimized.
     *
     * @return true if the optimization is enabled
     */
    public final boolean isOptimizeAE() {
        return this.optimizeAE;
    }

    /**
     * Enables or disables the optimization of the number of autoencoders.
     *
     * @param z true to enable the optimization
     */
    public final void setOptimizeAE(boolean z) {
        this.optimizeAE = z;
    }

    /**
     * Returns the default of the optimizeAE option.
     *
     * @return false
     */
    protected boolean getDefaultOptimizeAE() {
        return false;
    }

    /**
     * Returns the description of the optimizeAE option.
     *
     * @return the description text
     */
    public String optimizeAEToolTip() {
        return "Optimize the number of layers of autoencoders. If set to true the number of layers will internally be optimized using a validation set.";
    }

    /**
     * Returns a short description of this classifier followed by its technical information.
     *
     * @return the description text
     */
    public String globalInfo() {
        return "Maniac - Multi-lAbel classificatioN using AutoenCoders.Transforms the labels using layers of autoencoders.For more information see:\n" + getTechnicalInformation();
    }

    /**
     * Lists the options of this classifier (compression, numberAutoencoders, optimizeAE)
     * followed by the options of the superclass.
     *
     * @return an enumeration of all options
     */
    public Enumeration listOptions() {
        // Raw Vector kept on purpose: it is passed straight to the OptionUtils helpers.
        Vector options = new Vector();
        OptionUtils.addOption(options, compressionTipText(), "" + getDefaultCompression(), "compression");
        OptionUtils.addOption(options, numberAutoencodersTipText(), "" + getDefaultNumberAutoencoders(), "numberAutoencoders");
        OptionUtils.addOption(options, optimizeAETipText(), "" + getDefaultOptimizeAE(), "optimizeAE");
        OptionUtils.add(options, super.listOptions());
        return OptionUtils.toEnumeration(options);
    }

    /**
     * Returns the default base classifier: a {@link CR} wrapping a {@link LinearRegression}.
     *
     * @return the default classifier
     */
    @Override // meka.classifiers.multilabel.LabelTransformationClassifier
    protected Classifier getDefaultClassifier() {
        CR classifierChain = new CR();
        classifierChain.setClassifier(new LinearRegression());
        return classifierChain;
    }

    /**
     * Returns the current option settings, followed by the options of the superclass.
     *
     * @return the options as a string array
     */
    public String[] getOptions() {
        // Typed list: no raw type and no unchecked cast needed to select the OptionUtils overloads.
        List<String> options = new ArrayList<String>();
        OptionUtils.add(options, "compression", getCompression());
        OptionUtils.add(options, "optimizeAE", isOptimizeAE());
        OptionUtils.add(options, "numberAutoencoders", getNumberAutoencoders());
        OptionUtils.add(options, super.getOptions());
        return OptionUtils.toArray(options);
    }

    /**
     * Parses the given options; options that are absent fall back to their defaults. The array is
     * then handed to the superclass.
     *
     * @param strArr the options to parse
     * @throws Exception if parsing fails here or in the superclass
     */
    public void setOptions(String[] strArr) throws Exception {
        setCompression(OptionUtils.parse(strArr, "compression", getDefaultCompression()));
        setNumberAutoencoders(OptionUtils.parse(strArr, "numberAutoencoders", getDefaultNumberAutoencoders()));
        setOptimizeAE(OptionUtils.parse(strArr, "optimizeAE", getDefaultOptimizeAE()));
        super.setOptions(strArr);
    }

    /**
     * Returns the bibliographic reference of the method.
     *
     * @return the technical information (an INPROCEEDINGS entry)
     */
    public TechnicalInformation getTechnicalInformation() {
        TechnicalInformation info = new TechnicalInformation(TechnicalInformation.Type.INPROCEEDINGS);
        info.setValue(TechnicalInformation.Field.AUTHOR, "J\"org Wicker, Andrey Tyukin, Stefan Kramer");
        info.setValue(TechnicalInformation.Field.TITLE, "A Nonlinear Label Compression and Transformation Method for Multi-Label Classification using Autoencoders");
        info.setValue(TechnicalInformation.Field.BOOKTITLE, "The 20th Pacific Asia Conference on Knowledge Discovery and Data Mining (PAKDD)");
        info.setValue(TechnicalInformation.Field.YEAR, "2016");
        info.setValue(TechnicalInformation.Field.PAGES, "328-340");
        return info;
    }

    /**
     * Converts a single instance into the compressed-label format: the non-label part of the
     * instance is appended to one row taken from the compressed template data.
     *
     * <p>Requires that {@link #transformLabels(Instances)} has been called before, because the
     * template data is created there.
     *
     * @param instance the instance to transform; must belong to a data set
     * @return the transformed instance
     * @throws Exception if the thread has been interrupted or the transformation fails
     */
    @Override // meka.classifiers.multilabel.LabelTransformationClassifier
    public Instance transformInstance(Instance instance) throws Exception {
        // Wrap the instance in an otherwise empty copy of its data set so extractPart can be used.
        Instances single = new Instances(instance.dataset());
        single.delete();
        single.add(instance);
        Instances features = extractPart(single, false);

        // Reduce a copy of the template to its first row.
        Instances pseudoLabels = new Instances(this.compressedTemplateInst);
        Instance firstRow = pseudoLabels.instance(0);
        pseudoLabels.delete();
        pseudoLabels.add(firstRow);

        // Blank every attribute in front of the template's class index.
        for (int att = 0; att < pseudoLabels.classIndex(); att++) {
            checkInterrupted();
            pseudoLabels.instance(0).setMissing(att);
        }

        Instances merged = Instances.mergeInstances(pseudoLabels, features);
        merged.setClassIndex(pseudoLabels.numAttributes());
        return merged.instance(0);
    }

    /**
     * Replaces the labels of the given data by their autoencoder-compressed representation.
     *
     * <p>If {@link #isOptimizeAE()} is set, the number of autoencoders is first selected on a
     * held-out fold. If no autoencoder is set yet, one is trained on the labels of the full data;
     * an autoencoder that is already set is reused as is (NOTE(review): this also applies when the
     * same object is built a second time on different data).
     *
     * @param instances the data whose labels are to be compressed
     * @return the compressed labels followed by the remaining attributes, with the class index
     *     set to the number of compressed attributes
     * @throws IllegalStateException if no autoencoder is available after training, i.e. the
     *     autoencoder stream ended before the requested number of autoencoders was reached
     * @throws Exception if the thread has been interrupted or any step of the transformation fails
     */
    @Override // meka.classifiers.multilabel.LabelTransformationClassifier
    public Instances transformLabels(Instances instances) throws Exception {
        // 0 means "no preference": training then runs up to getNumberAutoencoders().
        int preferredDepth = isOptimizeAE() ? selectBestDepth(instances) : 0;

        Instances features = extractPart(instances, false);
        Instances labels = extractPart(instances, true);
        Mat labelMatrix = toDenseMatrix(labels);

        if (getAE() == null) {
            trainAutoencoder(labelMatrix, preferredDepth);
        }
        Autoencoder encoder = getAE();
        if (encoder == null) {
            // Previously surfaced as a bare NullPointerException on the compress call.
            throw new IllegalStateException("No autoencoder available: the autoencoder stream ended before reaching "
                    + (preferredDepth > 0 ? preferredDepth : getNumberAutoencoders()) + " autoencoder(s).");
        }

        Instances compressed = org.kramerlab.autoencoder.wekacompatibility.package$.MODULE$.matToInstances(encoder.compress(labelMatrix));
        // Kept as the template from which transformInstance builds its rows.
        this.compressedTemplateInst = new Instances(compressed);

        Instances merged = Instances.mergeInstances(compressed, features);
        merged.setClassIndex(compressed.numAttributes());
        return merged;
    }

    /**
     * Selects the number of autoencoders on a held-out fold: the data is split with
     * {@code trainCV(3, 1)} / {@code testCV(3, 1)}, and every autoencoder produced by the stream
     * on the training labels is evaluated through a fresh {@code Maniac} using the "Accuracy"
     * measure.
     *
     * @param instances the full training data
     * @return the 1-based position in the stream of the autoencoder with the highest accuracy, or
     *     0 if no candidate improved on negative infinity
     * @throws Exception if the thread has been interrupted or the evaluation fails
     */
    private int selectBestDepth(Instances instances) throws Exception {
        package$ aeApi = package$.MODULE$;
        Instances trainFold = instances.trainCV(3, 1);
        Instances testFold = instances.testCV(3, 1);
        Mat trainLabels = toDenseMatrix(extractPart(trainFold, true));

        double bestAccuracy = Double.NEGATIVE_INFINITY;
        int bestDepth = 0;
        int depth = 0;
        for (Autoencoder candidate : aeApi.deepAutoencoderStream_java(aeApi.Sigmoid(), getNumberAutoencoders(), getCompression(), trainLabels, true, aeApi.HintonsMiraculousStrategy(), true, aeApi.NoObservers())) {
            checkInterrupted();
            depth++;

            // The probe gets the candidate injected, so it neither retrains nor optimizes again.
            Maniac probe = new Maniac();
            probe.setOptimizeAE(false);
            probe.setNumberAutoencoders(getNumberAutoencoders());
            probe.setCompression(getCompression());
            probe.setClassifier(getClassifier());
            probe.setAE(candidate);

            double accuracy = ((Double) Evaluation.evaluateModel(probe, trainFold, testFold).getValue("Accuracy")).doubleValue();
            if (bestAccuracy < accuracy) {
                bestAccuracy = accuracy;
                bestDepth = depth;
            }
        }
        return bestDepth;
    }

    /**
     * Walks the autoencoder stream on the given label matrix and stores the autoencoder at the
     * preferred position, or at position {@link #getNumberAutoencoders()} if that is reached
     * first or no preference is given. Nothing is stored if the stream ends earlier.
     *
     * @param labelMatrix the dense label matrix to train on
     * @param preferredDepth the 1-based position to stop at; values &lt;= 0 mean no preference
     * @throws Exception if the thread has been interrupted or training fails
     */
    private void trainAutoencoder(Mat labelMatrix, int preferredDepth) throws Exception {
        package$ aeApi = package$.MODULE$;
        int depth = 0;
        for (Autoencoder candidate : aeApi.deepAutoencoderStream_java(aeApi.Sigmoid(), getNumberAutoencoders(), getCompression(), labelMatrix, true, aeApi.HintonsMiraculousStrategy(), true, aeApi.NoObservers())) {
            checkInterrupted();
            depth++;
            boolean reachedPreferred = preferredDepth > 0 && depth == preferredDepth;
            if (reachedPreferred || depth == getNumberAutoencoders()) {
                setAE(candidate);
                return;
            }
        }
    }

    /**
     * Runs the given data through a {@link SparseToNonSparse} filter and converts the result to a
     * matrix.
     *
     * @param data the data to convert
     * @return the matrix representation of the filtered data
     * @throws Exception if filtering fails
     */
    private static Mat toDenseMatrix(Instances data) throws Exception {
        SparseToNonSparse densify = new SparseToNonSparse();
        densify.setInputFormat(data);
        return org.kramerlab.autoencoder.wekacompatibility.package$.MODULE$.instancesToMat(Filter.useFilter(data, densify));
    }

    /**
     * Aborts with an {@link InterruptedException} if the current thread's interrupt flag is set.
     * The flag itself is left untouched.
     *
     * @throws InterruptedException if the current thread has been interrupted
     */
    private static void checkInterrupted() throws InterruptedException {
        if (Thread.currentThread().isInterrupted()) {
            throw new InterruptedException("Thread has been interrupted.");
        }
    }

    /**
     * Maps a prediction in the compressed space back to the original label space by
     * decompressing it with the autoencoder.
     *
     * <p>Only the second half of the input is used: element {@code length / 2 + i} becomes
     * column {@code i} of the matrix handed to the autoencoder. NOTE(review): presumably the
     * first half carries other per-target output of the base classifier - confirm against the
     * prediction layout of the configured classifier.
     *
     * @param dArr the prediction vector of the base classifier
     * @return the first row of the decompressed matrix
     */
    @Override // meka.classifiers.multilabel.LabelTransformationClassifier
    public double[] transformPredictionsBack(double[] dArr) {
        int codeLength = dArr.length / 2;
        Mat code = new Mat(1, codeLength);
        for (int col = 0; col < codeLength; col++) {
            code.update(0, col, dArr[codeLength + col]);
        }

        Mat decoded = getAE().decompress(code);
        int numLabels = decoded.toArray()[0].length;
        double[] labels = new double[numLabels];
        for (int col = 0; col < numLabels; col++) {
            labels[col] = decoded.apply(0, col);
        }
        return labels;
    }

    /**
     * Returns a textual description of the model; this implementation always returns the empty
     * string.
     *
     * @return the empty string
     */
    @Override // meka.classifiers.MultiXClassifier
    public String getModel() {
        return "";
    }

    /**
     * Returns the same text as {@link #getModel()}.
     *
     * @return the model description
     */
    @Override
    public String toString() {
        return getModel();
    }

    /**
     * Command-line entry point: evaluates a new {@code Maniac} with the given arguments.
     *
     * @param strArr the command-line arguments
     * @throws Exception if the evaluation fails
     */
    public static void main(String[] strArr) throws Exception {
        AbstractMultiLabelClassifier.evaluation(new Maniac(), strArr);
    }
}
