package weka.classifiers.functions;

import java.util.Collections;
import java.util.Enumeration;
import no.uib.cipr.matrix.DenseCholesky;
import no.uib.cipr.matrix.DenseVector;
import no.uib.cipr.matrix.Matrices;
import no.uib.cipr.matrix.Matrix;
import no.uib.cipr.matrix.UpperSPDDenseMatrix;
import no.uib.cipr.matrix.Vector;
import weka.classifiers.ConditionalDensityEstimator;
import weka.classifiers.IntervalEstimator;
import weka.classifiers.RandomizableClassifier;
import weka.classifiers.functions.supportVector.CachedKernel;
import weka.classifiers.functions.supportVector.Kernel;
import weka.classifiers.functions.supportVector.PolyKernel;
import weka.classifiers.lazy.kstar.KStarConstants;
import weka.core.Capabilities;
import weka.core.Instance;
import weka.core.Instances;
import weka.core.Option;
import weka.core.SelectedTag;
import weka.core.Statistics;
import weka.core.Tag;
import weka.core.TechnicalInformation;
import weka.core.TechnicalInformationHandler;
import weka.core.TestInstances;
import weka.core.Utils;
import weka.core.WeightedInstancesHandler;
import weka.core.json.JSONInstances;
import weka.filters.Filter;
import weka.filters.unsupervised.attribute.NominalToBinary;
import weka.filters.unsupervised.attribute.Normalize;
import weka.filters.unsupervised.attribute.ReplaceMissingValues;
import weka.filters.unsupervised.attribute.Standardize;
import weka.gui.knowledgeflow.KnowledgeFlowApp;

/* loaded from: input_file:weka/classifiers/functions/GaussianProcesses.class */
/**
 * Gaussian process regression without hyperparameter tuning.
 *
 * <p>Training (see {@link #buildClassifier(Instances)}) runs these steps:
 * <ul>
 *   <li>Unless checks are turned off, instances with a missing class are dropped and the
 *       remaining missing values are replaced.</li>
 *   <li>Nominal attributes are binarized when the kernel handles numeric attributes.</li>
 *   <li>The data, including the target, is optionally normalized or standardized.</li>
 * </ul>
 *
 * <p>The model keeps two quantities. With {@code W = diag(sqrt(weight_i))}, it stores
 * {@code m_L = (W K W + delta^2 I)^-1} and {@code m_t = m_L W (y - mean(y))}. Predictions are
 * computed in the transformed target space and mapped back through the inverse of the affine
 * class transform, {@code (value - m_Blin) / m_Alin}.
 */
public class GaussianProcesses extends RandomizableClassifier implements IntervalEstimator, ConditionalDensityEstimator, TechnicalInformationHandler, WeightedInstancesHandler {
    static final long serialVersionUID = -8620066949967678545L;

    /** Converts nominal attributes to binary ones; null when no conversion is needed. */
    protected NominalToBinary m_NominalToBinary;

    /** Filter type: normalize the training data (the default). */
    public static final int FILTER_NORMALIZE = 0;
    /** Filter type: standardize the training data. */
    public static final int FILTER_STANDARDIZE = 1;
    /** Filter type: leave the training data unchanged. */
    public static final int FILTER_NONE = 2;

    /** Tags describing the available filter types. */
    public static final Tag[] TAGS_FILTER = {
        new Tag(FILTER_NORMALIZE, "Normalize training data"),
        new Tag(FILTER_STANDARDIZE, "Standardize training data"),
        new Tag(FILTER_NONE, "No normalization/standardization")
    };

    /** Replaces missing values; null when checks are turned off. */
    protected ReplaceMissingValues m_Missing;

    /** Slope of the (assumed affine) map from raw to transformed class values. */
    protected double m_Alin;

    /** Intercept of the (assumed affine) map from raw to transformed class values. */
    protected double m_Blin;

    /** Copy of {@link #m_kernel} that is built on the transformed training data. */
    protected Kernel m_actualKernel;

    /** Weighted mean of the transformed training class values. */
    protected double m_avg_target;

    /** Inverse of the weighted kernel matrix with the squared noise added to its diagonal. */
    public Matrix m_L;

    /** {@link #m_L} multiplied by the weighted, mean-centered training targets. */
    protected Vector m_t;

    /** Square roots of the training instance weights. */
    protected double[] m_weights;

    /** Normalize/Standardize filter in use; null when {@link #FILTER_NONE} is selected. */
    protected Filter m_Filter = null;

    /** The selected filter type, one of the {@code FILTER_*} constants. */
    protected int m_filterType = FILTER_NORMALIZE;

    /** When true, capability checks and missing-value handling are skipped. */
    protected boolean m_checksTurnedOff = false;

    /** Noise level in the transformed target space; its square is added to the diagonal. */
    protected double m_delta = 1.0d;

    /** Square of {@link #m_delta}; refreshed by {@link #buildClassifier(Instances)}. */
    protected double m_deltaSquared = 1.0d;

    /** The kernel configured by the user; serves as a template for {@link #m_actualKernel}. */
    protected Kernel m_kernel = new PolyKernel();

    /** Number of training instances the current model was built from. */
    protected int m_NumTrain = 0;

    /**
     * Returns a description of this classifier for display in the GUI.
     *
     * @return the description
     */
    public String globalInfo() {
        return " Implements Gaussian processes for regression without hyperparameter-tuning. To make choosing an appropriate noise level easier, this implementation applies normalization/standardization to the target attribute as well as the other attributes (if  normalization/standardizaton is turned on). Missing values are replaced by the global mean/mode. Nominal attributes are converted to binary ones. Note that kernel caching is turned off if the kernel used implements CachedKernel.";
    }

    /**
     * Returns the bibliographic reference this implementation is based on.
     *
     * @return the technical information
     */
    @Override
    public TechnicalInformation getTechnicalInformation() {
        TechnicalInformation info = new TechnicalInformation(TechnicalInformation.Type.MISC);
        info.setValue(TechnicalInformation.Field.AUTHOR, "David J.C. Mackay");
        info.setValue(TechnicalInformation.Field.YEAR, "1998");
        info.setValue(TechnicalInformation.Field.TITLE, "Introduction to Gaussian Processes");
        info.setValue(TechnicalInformation.Field.ADDRESS, "Dept. of Physics, Cambridge University, UK");
        info.setValue(TechnicalInformation.Field.PS, "http://wol.ra.phy.cam.ac.uk/mackay/gpB.ps.gz");
        return info;
    }

    /**
     * Returns the capabilities of this classifier.
     *
     * <p>The attribute capabilities come from the kernel. Nominal attributes are added when the
     * kernel handles numeric ones, because they get binarized. Missing values are accepted
     * because they get replaced. The only class types allowed are numeric and date, and missing
     * class values are allowed (those instances are dropped).
     *
     * @return the capabilities
     */
    @Override
    public Capabilities getCapabilities() {
        Capabilities result = getKernel().getCapabilities();
        result.setOwner(this);
        result.enableAllAttributeDependencies();
        if (result.handles(Capabilities.Capability.NUMERIC_ATTRIBUTES)) {
            result.enable(Capabilities.Capability.NOMINAL_ATTRIBUTES);
        }
        result.enable(Capabilities.Capability.MISSING_VALUES);
        result.disableAllClasses();
        result.disableAllClassDependencies();
        result.disable(Capabilities.Capability.NO_CLASS);
        result.enable(Capabilities.Capability.NUMERIC_CLASS);
        result.enable(Capabilities.Capability.DATE_CLASS);
        result.enable(Capabilities.Capability.MISSING_CLASS_VALUES);
        return result;
    }

    /**
     * Builds the Gaussian process model from the training data.
     *
     * @param instances the training data
     * @throws IllegalArgumentException if no training instances remain after preprocessing
     * @throws Exception if the data fails the capability test, a filter cannot be applied,
     *     or the kernel cannot be built or evaluated
     */
    @Override
    public void buildClassifier(Instances instances) throws Exception {
        // Step 1: capability check, drop instances with a missing class, replace missing values.
        if (m_checksTurnedOff) {
            m_Missing = null;
        } else {
            getCapabilities().testWithFail(instances);
            Instances withClass = new Instances(instances);
            withClass.deleteWithMissingClass();
            m_Missing = new ReplaceMissingValues();
            m_Missing.setInputFormat(withClass);
            instances = Filter.useFilter(withClass, m_Missing);
        }
        // Without this guard, an empty set fails later with an index error at instance(0)
        // or yields a NaN mean target.
        if (instances.numInstances() == 0) {
            throw new IllegalArgumentException(
                    "No training instances left (instances with a missing class value are ignored).");
        }

        // Step 2: binarize nominal attributes when the kernel works on numeric ones.
        // With checks turned off, the data is trusted to be all-numeric already.
        if (getCapabilities().handles(Capabilities.Capability.NUMERIC_ATTRIBUTES)
                && !m_checksTurnedOff
                && !hasOnlyNumericAttributes(instances)) {
            m_NominalToBinary = new NominalToBinary();
            m_NominalToBinary.setInputFormat(instances);
            instances = Filter.useFilter(instances, m_NominalToBinary);
        } else {
            m_NominalToBinary = null;
        }

        // Step 3: optional normalization/standardization. setIgnoreClass(true) makes the
        // filter transform the target as well.
        if (m_filterType == FILTER_STANDARDIZE) {
            m_Filter = new Standardize();
            ((Standardize) m_Filter).setIgnoreClass(true);
            m_Filter.setInputFormat(instances);
            instances = Filter.useFilter(instances, m_Filter);
        } else if (m_filterType == FILTER_NORMALIZE) {
            m_Filter = new Normalize();
            ((Normalize) m_Filter).setIgnoreClass(true);
            m_Filter.setInputFormat(instances);
            instances = Filter.useFilter(instances, m_Filter);
        } else {
            m_Filter = null;
        }

        m_NumTrain = instances.numInstances();
        computeClassTransform(instances);

        // Work on a private copy of the kernel, with caching switched off for cached kernels.
        m_actualKernel = Kernel.makeCopy(m_kernel);
        if (m_kernel instanceof CachedKernel) {
            ((CachedKernel) m_actualKernel).setCacheSize(-1);
        }
        m_actualKernel.buildKernel(instances);

        int n = instances.numInstances();

        // Weighted mean of the (transformed) target values.
        double weightedTargetSum = 0.0d;
        for (int i = 0; i < n; i++) {
            Instance current = instances.instance(i);
            weightedTargetSum += current.weight() * current.classValue();
        }
        m_avg_target = weightedTargetSum / instances.sumOfWeights();

        m_deltaSquared = m_delta * m_delta;
        m_weights = new double[n];
        for (int i = 0; i < n; i++) {
            m_weights[i] = Math.sqrt(instances.instance(i).weight());
        }

        // Upper triangle of W K W + delta^2 I, where W = diag(m_weights).
        UpperSPDDenseMatrix covariance = new UpperSPDDenseMatrix(n);
        for (int i = 0; i < n; i++) {
            for (int j = i + 1; j < n; j++) {
                covariance.set(i, j, m_weights[i] * m_weights[j] * m_actualKernel.eval(i, j, instances.instance(i)));
            }
            covariance.set(i, i, (m_weights[i] * m_weights[i] * m_actualKernel.eval(i, i, instances.instance(i))) + m_deltaSquared);
        }

        // Invert through a Cholesky factorization, then keep the symmetric result.
        m_L = new UpperSPDDenseMatrix(new DenseCholesky(n, true).factor(covariance).solve(Matrices.identity(n)));

        // m_t = m_L * W * (y - mean(y)).
        DenseVector centeredTargets = new DenseVector(n);
        for (int i = 0; i < n; i++) {
            centeredTargets.set(i, m_weights[i] * (instances.instance(i).classValue() - m_avg_target));
        }
        m_t = m_L.mult(centeredTargets, new DenseVector(n));
    }

    /**
     * Checks whether every non-class attribute of the data is numeric.
     *
     * @param data the data to inspect
     * @return true if all non-class attributes are numeric
     */
    private static boolean hasOnlyNumericAttributes(Instances data) {
        for (int i = 0; i < data.numAttributes(); i++) {
            if (i != data.classIndex() && !data.attribute(i).isNumeric()) {
                return false;
            }
        }
        return true;
    }

    /**
     * Determines the class transform applied by {@link #m_Filter}.
     *
     * <p>The transform is assumed to be affine ({@code transformed = m_Alin * raw + m_Blin}).
     * It is measured by passing class values 0 and 1 through the filter on a copy of the first
     * training instance. Without a filter, the identity transform is used.
     *
     * <p>NOTE(review): if the filter maps every class value to the same output (possibly the
     * case for a constant class under normalization), {@code m_Alin} becomes 0 and later
     * predictions divide by zero. Confirm the filter's behavior for that case.
     *
     * @param data the filtered training data (must be non-empty when a filter is set)
     * @throws Exception if the filter fails
     */
    private void computeClassTransform(Instances data) throws Exception {
        if (m_Filter == null) {
            m_Alin = 1.0d;
            m_Blin = 0.0d;
            return;
        }
        int classIndex = data.classIndex();
        Instance probe = (Instance) data.instance(0).copy();
        probe.setValue(classIndex, 0.0d);
        m_Filter.input(probe);
        m_Filter.batchFinished();
        m_Blin = m_Filter.output().value(classIndex);
        probe.setValue(classIndex, 1.0d);
        m_Filter.input(probe);
        m_Filter.batchFinished();
        m_Alin = m_Filter.output().value(classIndex) - m_Blin;
    }

    /**
     * Computes the weight-scaled kernel values between a preprocessed instance and every
     * training instance.
     *
     * <p>Kernel index -1 presumably tells the kernel to use the supplied instance instead of a
     * stored training instance (this follows the kernel's eval contract; verify there).
     *
     * @param filtered an instance already passed through {@link #filterInstance(Instance)}
     * @return a vector of length {@link #m_NumTrain} with entries {@code m_weights[i] * k(x, x_i)}
     * @throws Exception if the kernel evaluation fails
     */
    private DenseVector weightedKernelVector(Instance filtered) throws Exception {
        DenseVector kernelValues = new DenseVector(m_NumTrain);
        for (int i = 0; i < m_NumTrain; i++) {
            kernelValues.set(i, m_weights[i] * m_actualKernel.eval(-1, i, filtered));
        }
        return kernelValues;
    }

    /**
     * Predicts the target value for an instance, on the original target scale.
     *
     * @param instance the instance to predict
     * @return the predicted (posterior mean) target value
     * @throws Exception if preprocessing or kernel evaluation fails
     */
    @Override
    public double classifyInstance(Instance instance) throws Exception {
        DenseVector kernelValues = weightedKernelVector(filterInstance(instance));
        return ((kernelValues.dot(m_t) + m_avg_target) - m_Blin) / m_Alin;
    }

    /**
     * Applies the preprocessing filters fitted during training to an instance: missing-value
     * replacement (unless checks are off), then nominal-to-binary conversion, then
     * normalization/standardization, each only if it is in use.
     *
     * @param instance the raw instance
     * @return the preprocessed instance
     * @throws Exception if a filter fails
     */
    protected Instance filterInstance(Instance instance) throws Exception {
        if (!m_checksTurnedOff) {
            m_Missing.input(instance);
            m_Missing.batchFinished();
            instance = m_Missing.output();
        }
        if (m_NominalToBinary != null) {
            m_NominalToBinary.input(instance);
            m_NominalToBinary.batchFinished();
            instance = m_NominalToBinary.output();
        }
        if (m_Filter != null) {
            m_Filter.input(instance);
            m_Filter.batchFinished();
            instance = m_Filter.output();
        }
        return instance;
    }

    /**
     * Computes the predictive standard deviation in the transformed target space.
     *
     * <p>The variance is {@code k(x, x) + delta^2 - v^T m_L v}. If rounding makes it
     * non-positive, {@link #m_delta} is returned instead.
     *
     * @param instance the preprocessed instance
     * @param vector the weighted kernel vector for {@code instance}
     * @return the predictive standard deviation (transformed scale)
     * @throws Exception if the kernel evaluation fails
     */
    protected double computeStdDev(Instance instance, Vector vector) throws Exception {
        double priorVariance = m_actualKernel.eval(-1, -1, instance) + m_deltaSquared;
        double explained = m_L.mult(vector, new DenseVector(vector.size())).dot(vector);
        return priorVariance > explained ? Math.sqrt(priorVariance - explained) : m_delta;
    }

    /**
     * Computes a symmetric normal prediction interval on the original target scale.
     *
     * @param instance the instance to predict
     * @param confidenceLevel the coverage probability, e.g. 0.95
     * @return a 1x2 array holding the lower and upper bound
     * @throws Exception if preprocessing or kernel evaluation fails
     */
    @Override
    public double[][] predictIntervals(Instance instance, double confidenceLevel) throws Exception {
        Instance filtered = filterInstance(instance);
        DenseVector kernelValues = weightedKernelVector(filtered);
        double mean = kernelValues.dot(m_t) + m_avg_target;
        double stdDev = computeStdDev(filtered, kernelValues);
        double z = Statistics.normalInverse(1.0d - ((1.0d - confidenceLevel) / 2.0d));
        double lower = ((mean - (z * stdDev)) - m_Blin) / m_Alin;
        double upper = ((mean + (z * stdDev)) - m_Blin) / m_Alin;
        return new double[][] {{lower, upper}};
    }

    /**
     * Returns the predictive standard deviation on the original target scale.
     *
     * @param instance the instance to predict
     * @return the standard deviation
     * @throws Exception if preprocessing or kernel evaluation fails
     */
    public double getStandardDeviation(Instance instance) throws Exception {
        Instance filtered = filterInstance(instance);
        return computeStdDev(filtered, weightedKernelVector(filtered)) / m_Alin;
    }

    /**
     * Returns the natural log of the predictive normal density at a target value on the
     * original scale. The density is evaluated in the transformed space, and the term
     * {@code log(m_Alin)} accounts for the change of variables.
     *
     * @param instance the instance to predict
     * @param value the target value (original scale)
     * @return the log density
     * @throws Exception if preprocessing or kernel evaluation fails
     */
    @Override
    public double logDensity(Instance instance, double value) throws Exception {
        Instance filtered = filterInstance(instance);
        DenseVector kernelValues = weightedKernelVector(filtered);
        double mean = kernelValues.dot(m_t) + m_avg_target;
        double stdDev = computeStdDev(filtered, kernelValues);
        double residual = ((value * m_Alin) + m_Blin) - mean;
        return ((-Math.log(stdDev * Math.sqrt(2.0d * Math.PI)))
                - ((residual * residual) / ((2.0d * stdDev) * stdDev)))
                + Math.log(m_Alin);
    }

    /**
     * Lists the command-line options: this class's own, then the superclass's, then the
     * kernel's.
     *
     * @return an enumeration of the options
     */
    @Override
    public Enumeration<Option> listOptions() {
        java.util.List<Option> options = new java.util.ArrayList<Option>();
        options.add(new Option("\tLevel of Gaussian Noise wrt transformed target. (default 1)", "L", 1, "-L <double>"));
        options.add(new Option("\tWhether to 0=normalize/1=standardize/2=neither. (default 0=normalize)", "N", 1, "-N"));
        options.add(new Option("\tThe Kernel to use.\n\t(default: weka.classifiers.functions.supportVector.PolyKernel)", "K", 1, "-K <classname and parameters>"));
        options.addAll(Collections.list(super.listOptions()));
        options.add(new Option("", "", 0, "\nOptions specific to kernel " + getKernel().getClass().getName() + ":"));
        options.addAll(Collections.list(getKernel().listOptions()));
        return Collections.enumeration(options);
    }

    /**
     * Parses the options -L (noise), -N (filter type) and -K (kernel spec), then delegates to
     * the superclass. A missing -L or -N resets the option to its default. A missing -K leaves
     * the current kernel unchanged.
     *
     * @param options the option array; recognized options are removed from it
     * @throws Exception if an option is invalid
     */
    @Override
    public void setOptions(String[] options) throws Exception {
        String noise = Utils.getOption('L', options);
        setNoise(noise.length() != 0 ? Double.parseDouble(noise) : 1.0d);

        String filterType = Utils.getOption('N', options);
        int filterId = filterType.length() != 0 ? Integer.parseInt(filterType) : FILTER_NORMALIZE;
        setFilterType(new SelectedTag(filterId, TAGS_FILTER));

        String[] kernelSpec = Utils.splitOptions(Utils.getOption('K', options));
        if (kernelSpec.length != 0) {
            String kernelClassName = kernelSpec[0];
            kernelSpec[0] = "";
            setKernel(Kernel.forName(kernelClassName, kernelSpec));
        }
        super.setOptions(options);
    }

    /**
     * Returns the current settings as a command-line option array.
     *
     * @return the options
     */
    @Override
    public String[] getOptions() {
        java.util.List<String> options = new java.util.ArrayList<String>();
        options.add("-L");
        options.add(String.valueOf(getNoise()));
        options.add("-N");
        options.add(String.valueOf(m_filterType));
        options.add("-K");
        options.add(m_kernel.getClass().getName() + " " + Utils.joinOptions(m_kernel.getOptions()));
        Collections.addAll(options, super.getOptions());
        return options.toArray(new String[0]);
    }

    /**
     * Returns the tip text for the kernel property.
     *
     * @return the tip text
     */
    public String kernelTipText() {
        return "The kernel to use.";
    }

    /**
     * Returns the configured kernel.
     *
     * @return the kernel
     */
    public Kernel getKernel() {
        return m_kernel;
    }

    /**
     * Sets the kernel to use. It is copied when the model is built.
     *
     * @param kernel the kernel
     */
    public void setKernel(Kernel kernel) {
        m_kernel = kernel;
    }

    /**
     * Returns the tip text for the filterType property.
     *
     * @return the tip text
     */
    public String filterTypeTipText() {
        return "Determines how/if the data will be transformed.";
    }

    /**
     * Returns the selected filter type.
     *
     * @return the filter type as a tag from {@link #TAGS_FILTER}
     */
    public SelectedTag getFilterType() {
        return new SelectedTag(m_filterType, TAGS_FILTER);
    }

    /**
     * Sets the filter type. Tags that do not come from {@link #TAGS_FILTER} are ignored.
     *
     * @param selectedTag the filter type
     */
    public void setFilterType(SelectedTag selectedTag) {
        if (selectedTag.getTags() == TAGS_FILTER) {
            m_filterType = selectedTag.getSelectedTag().getID();
        }
    }

    /**
     * Returns the tip text for the noise property.
     *
     * @return the tip text
     */
    public String noiseTipText() {
        return "The level of Gaussian Noise (added to the diagonal of the Covariance Matrix), after the target has been normalized/standardized/left unchanged).";
    }

    /**
     * Returns the noise level.
     *
     * @return the noise level
     */
    public double getNoise() {
        return m_delta;
    }

    /**
     * Sets the noise level (in the transformed target space).
     *
     * @param d the noise level
     */
    public void setNoise(double d) {
        m_delta = d;
    }

    /**
     * Summarizes the model: the kernel, the filtering, the mean target, and the value ranges
     * of the inverted covariance matrix and of {@link #m_t}.
     *
     * @return a textual description of the model
     */
    @Override
    public String toString() {
        if (m_t == null) {
            return "Gaussian Processes: No model built yet.";
        }
        StringBuilder text = new StringBuilder();
        try {
            text.append("Gaussian Processes\n\n");
            text.append("Kernel used:\n  ").append(m_kernel.toString()).append("\n\n");
            text.append("All values shown based on: ").append(TAGS_FILTER[m_filterType].getReadable()).append("\n\n");
            text.append("Average Target Value : ").append(m_avg_target).append("\n");

            text.append("Inverted Covariance Matrix:\n");
            double matrixMin = m_L.get(0, 0);
            double matrixMax = m_L.get(0, 0);
            for (int row = 0; row < m_NumTrain; row++) {
                for (int col = 0; col <= row; col++) {
                    double entry = m_L.get(row, col);
                    if (entry < matrixMin) {
                        matrixMin = entry;
                    } else if (entry > matrixMax) {
                        matrixMax = entry;
                    }
                }
            }
            text.append("    Lowest Value = ").append(matrixMin).append("\n");
            text.append("    Highest Value = ").append(matrixMax).append("\n");

            text.append("Inverted Covariance Matrix * Target-value Vector:\n");
            double vectorMin = m_t.get(0);
            double vectorMax = m_t.get(0);
            for (int i = 0; i < m_NumTrain; i++) {
                double entry = m_t.get(i);
                if (entry < vectorMin) {
                    vectorMin = entry;
                } else if (entry > vectorMax) {
                    vectorMax = entry;
                }
            }
            text.append("    Lowest Value = ").append(vectorMin).append("\n");
            text.append("    Highest Value = ").append(vectorMax).append("\n \n");
            return text.toString();
        } catch (Exception e) {
            return "Can't print the classifier.";
        }
    }

    /**
     * Runs the classifier from the command line.
     *
     * @param strArr the command-line options
     */
    public static void main(String[] strArr) {
        runClassifier(new GaussianProcesses(), strArr);
    }
}
