package hex.genmodel;

import hex.ModelCategory;
import java.awt.Color;
import java.awt.Graphics2D;
import java.awt.image.BufferedImage;
import java.awt.image.ImageObserver;
import java.io.IOException;
import java.io.Serializable;
import java.util.ArrayList;
import java.util.Arrays;
import java.util.EnumSet;
import java.util.Iterator;
import java.util.Random;
import water.genmodel.IGeneratedModel;

/**
 * Abstract base class for generated models.
 *
 * <p>Holds a model's column names and categorical domains and implements the
 * {@link IGenModel} / {@link IGeneratedModel} metadata queries on top of them. It also provides
 * static helper routines (class prediction with tie-breaking, probability correction, bit-set
 * split tests, KMeans distances, GBM/GLM inverse links, one-hot input encoding and image
 * preprocessing) that are presumably called from the scoring code of generated subclasses.
 */
public abstract class GenModel implements IGenModel, IGeneratedModel, Serializable {
    /** Column names. For supervised models the response index is {@code _domains.length - 1}. */
    public final String[] _names;
    /** Categorical levels per column; an entry may be {@code null} for a column without levels. */
    public final String[][] _domains;
    /** Presumably the name of the offset column; {@code null} by default. */
    public String _offsetColumn = null;

    /**
     * @param strArr column names (stored as-is, not copied)
     * @param strArr2 categorical levels per column (stored as-is, not copied)
     */
    public GenModel(String[] strArr, String[][] strArr2) {
        this._names = strArr;
        this._domains = strArr2;
    }

    /** Returns {@code false}; subclasses may override to declare a supervised model. */
    @Override
    public boolean isSupervised() {
        return false;
    }

    /** Returns the number of entries in {@link #_names}. */
    @Override
    public int nfeatures() {
        return this._names.length;
    }

    /** Returns 0 by default; subclasses may override. */
    @Override
    public int nclasses() {
        return 0;
    }

    @Override
    public abstract ModelCategory getModelCategory();

    /** Returns a set containing only {@link #getModelCategory()}. */
    @Override
    public EnumSet<ModelCategory> getModelCategories() {
        return EnumSet.of(getModelCategory());
    }

    @Override
    public abstract String getUUID();

    /** Same as {@link #nfeatures()}. */
    @Override
    public int getNumCols() {
        return nfeatures();
    }

    /** Returns the internal names array (not a copy). */
    @Override
    public String[] getNames() {
        return this._names;
    }

    /**
     * Returns the name at {@link #getResponseIdx()}.
     *
     * @throws UnsupportedOperationException for unsupervised models
     */
    @Override
    public String getResponseName() {
        return this._names[getResponseIdx()];
    }

    /**
     * Returns the response column index, i.e. the last index of {@link #_domains}.
     *
     * @throws UnsupportedOperationException if {@link #isSupervised()} is false
     */
    @Override
    public int getResponseIdx() {
        if (!isSupervised()) {
            throw new UnsupportedOperationException("Cannot provide response index for unsupervised models.");
        }
        return this._domains.length - 1;
    }

    /** Returns the number of levels of column {@code i}, or -1 if it has no domain. */
    @Override
    public int getNumClasses(int i) {
        String[] domainValues = getDomainValues(i);
        return domainValues == null ? -1 : domainValues.length;
    }

    /**
     * Returns {@link #nclasses()} for classifiers.
     *
     * @throws UnsupportedOperationException if {@link #isClassifier()} is false
     */
    @Override
    public int getNumResponseClasses() {
        if (!isClassifier()) {
            throw new UnsupportedOperationException("Cannot provide number of response classes for non-classifiers.");
        }
        return nclasses();
    }

    /** True when the model category is Binomial or Multinomial. */
    @Override
    public boolean isClassifier() {
        ModelCategory modelCategory = getModelCategory();
        return modelCategory == ModelCategory.Binomial || modelCategory == ModelCategory.Multinomial;
    }

    /** True when the model category is AutoEncoder. */
    @Override
    public boolean isAutoEncoder() {
        return getModelCategory() == ModelCategory.AutoEncoder;
    }

    /** Returns the levels of the named column, or {@code null} if no column has that name. */
    @Override
    public String[] getDomainValues(String str) {
        int colIdx = getColIdx(str);
        return colIdx == -1 ? null : getDomainValues(colIdx);
    }

    /** Returns the levels of column {@code i} (may be {@code null}). */
    @Override
    public String[] getDomainValues(int i) {
        return getDomainValues()[i];
    }

    /** Returns the internal domains array (not a copy). */
    @Override
    public String[][] getDomainValues() {
        return this._domains;
    }

    /** Returns the index of the first column named {@code str}, or -1 if absent. */
    @Override
    public int getColIdx(String str) {
        String[] names = getNames();
        for (int i = 0; i < names.length; i++) {
            if (names[i].equals(str)) {
                return i;
            }
        }
        return -1;
    }

    /**
     * Maps a level name to its index within column {@code i}'s domain.
     *
     * @return the level index, or -1 if the column has no domain or the level is unknown
     */
    @Override
    public int mapEnum(int i, String str) {
        String[] domainValues = getDomainValues(i);
        if (domainValues == null) {
            return -1;
        }
        for (int level = 0; level < domainValues.length; level++) {
            if (str.equals(domainValues[level])) {
                return level;
            }
        }
        return -1;
    }

    /** Size of the prediction array: 1 + number of classes for classifiers, 2 otherwise. */
    @Override
    public int getPredsSize() {
        return isClassifier() ? 1 + getNumResponseClasses() : 2;
    }

    /**
     * Size of the prediction array for the given category: {@link #nclasses()} for
     * DimReduction, {@link #nfeatures()} for AutoEncoder, otherwise {@link #getPredsSize()}.
     */
    public int getPredsSize(ModelCategory modelCategory) {
        if (modelCategory == ModelCategory.DimReduction) {
            return nclasses();
        }
        if (modelCategory == ModelCategory.AutoEncoder) {
            return nfeatures();
        }
        return getPredsSize();
    }

    /** Returns {@code str + ".aux"}. */
    public static String createAuxKey(String str) {
        return str + ".aux";
    }

    /** Scores one row {@code dArr} into {@code dArr2}; implemented by generated subclasses. */
    public abstract double[] score0(double[] dArr, double[] dArr2);

    /**
     * Scores one row with an offset. Not supported by default.
     *
     * @throws UnsupportedOperationException always, unless overridden
     */
    public double[] score0(double[] dArr, double d, double[] dArr2) {
        throw new UnsupportedOperationException("`offset` column is not supported");
    }

    /**
     * Rescales class probabilities in place and renormalizes them to sum to 1.
     *
     * <p>Each {@code dArr[i]} (i >= 1) is multiplied by {@code dArr2[i-1] / dArr3[i-1]} when both
     * are non-zero. NOTE(review): presumably dArr2 is the prior class distribution and dArr3 the
     * model's training class distribution — confirm against callers.
     *
     * @param dArr predictions; index 0 is untouched, indices 1.. are class probabilities
     * @return {@code dArr}
     */
    public static double[] correctProbabilities(double[] dArr, double[] dArr2, double[] dArr3) {
        double total = 0.0d;
        for (int i = 1; i < dArr.length; i++) {
            double numerator = dArr2[i - 1];
            double denominator = dArr3[i - 1];
            assert !Double.isNaN(dArr[i]) : "Predicted NaN class probability";
            if (numerator != 0.0d && denominator != 0.0d) {
                dArr[i] *= numerator / denominator;
            }
            total += dArr[i];
        }
        // Leave the values as-is if nothing positive remains to normalize by.
        if (total > 0.0d) {
            for (int i = 1; i < dArr.length; i++) {
                dArr[i] /= total;
            }
        }
        return dArr;
    }

    /**
     * Picks the predicted (zero-based) class from a prediction array.
     *
     * <p>With exactly two classes ({@code dArr.length == 3}) returns 1 iff
     * {@code dArr[2] >= d}. Otherwise returns the arg-max over {@code dArr[1..]}. Ties are broken
     * deterministically, seeded by a hash of the row {@code dArr3}. If {@code dArr2} is given, the
     * tie-break is a random draw weighted by {@code dArr2} over the tied classes only. Otherwise,
     * or if that draw does not select a class (e.g. all tied weights are zero), the tie-break is a
     * hash-selected tied class.
     *
     * @param dArr predictions; index 0 is ignored, indices 1.. are per-class scores
     * @param dArr2 optional per-class weights (length {@code dArr.length - 1}), or {@code null}
     * @param dArr3 optional row data used only to seed tie-breaking, or {@code null}
     * @param d threshold used in the two-class case
     * @return zero-based class index
     */
    public static int getPrediction(double[] dArr, double[] dArr2, double[] dArr3, double d) {
        if (dArr.length == 3) {
            return dArr[2] >= d ? 1 : 0;
        }
        // Arg-max over dArr[1..], tracking every zero-based class tied with the current best.
        ArrayList<Integer> ties = new ArrayList<Integer>();
        ties.add(0);
        int best = 1;
        for (int c = 2; c < dArr.length; c++) {
            if (dArr[best] < dArr[c]) {
                best = c;
                // A strictly larger value invalidates all earlier ties.
                ties.clear();
                ties.add(c - 1);
            } else if (dArr[best] == dArr[c]) {
                ties.add(c - 1);
            }
        }
        int tieCnt = ties.size() - 1;
        if (tieCnt == 0) {
            return best - 1;
        }

        // Deterministic per-row seed: drop the 6 lowest bits of each value and XOR together.
        long hash = 0;
        if (dArr3 != null) {
            for (double v : dArr3) {
                hash ^= Double.doubleToRawLongBits(v) >> 6;
            }
        }

        if (dArr2 != null) {
            assert dArr.length == dArr2.length + 1;
            double weightSum = 0.0d;
            for (int t : ties) {
                weightSum += dArr2[t];
            }
            double draw = new Random(hash).nextDouble();
            double cumulative = 0.0d;
            for (int t : ties) {
                cumulative += dArr2[t] / weightSum;
                if (draw <= cumulative) {
                    return t;
                }
            }
        }

        // Fallback: choose one of the (tieCnt + 1) tied classes from the hash. Normalize the
        // remainder into [0, tieCnt] because a negative hash would otherwise always pick the first.
        double bestValue = dArr[best];
        int choices = tieCnt + 1;
        int idx = ((int) hash) % choices;
        if (idx < 0) {
            idx += choices;
        }
        for (int c = 1; c < dArr.length; c++) {
            if (bestValue == dArr[c] && --idx < 0) {
                return c - 1;
            }
        }
        throw new RuntimeException("Should Not Reach Here");
    }

    /**
     * Tests bit {@code (int) d - i2} of a little-endian-within-byte bit set.
     *
     * @param bArr bit set bytes
     * @param i number of valid bits
     * @param i2 offset subtracted from {@code (int) d}
     * @param d value to test; must not be NaN (checked only with assertions enabled)
     */
    public static boolean bitSetContains(byte[] bArr, int i, int i2, double d) {
        assert !Double.isNaN(d);
        int bit = ((int) d) - i2;
        assert bit >= 0 && bit < i : "Must have " + i2 + " <= idx <= " + ((i2 + i) - 1) + ": " + bit;
        return (bArr[bit >> 3] & (1 << (bit & 7))) != 0;
    }

    /** True when {@code (int) d - i2} lies in {@code [0, i)}; {@code d} must not be NaN. */
    public static boolean bitSetIsInRange(int i, int i2, double d) {
        assert !Double.isNaN(d);
        int bit = ((int) d) - i2;
        return bit >= 0 && bit < i;
    }

    /**
     * Returns the index of the center closest to {@code dArr2} according to
     * {@link #KMeans_distance(double[], double[], String[][], double[], double[])}, or -1 if
     * no distance is below {@link Double#MAX_VALUE} (e.g. no centers).
     */
    public static int KMeans_closest(double[][] dArr, double[] dArr2, String[][] strArr, double[] dArr3, double[] dArr4) {
        int closest = -1;
        double minDist = Double.MAX_VALUE;
        for (int c = 0; c < dArr.length; c++) {
            double dist = KMeans_distance(dArr[c], dArr2, strArr, dArr3, dArr4);
            if (dist < minDist) {
                closest = c;
                minDist = dist;
            }
        }
        return closest;
    }

    /**
     * Returns per-center weights summing to 1, proportional to inverse distance.
     *
     * <p>If the first center at distance 0 exists, it gets weight 1 and the rest 0. If all
     * distances are zero (total 0), a uniformly random center gets weight 1; this uses an unseeded
     * {@link Random}, so it is non-deterministic.
     */
    public static double[] KMeans_simplex(double[][] dArr, double[] dArr2, String[][] strArr, double[] dArr3, double[] dArr4) {
        int k = dArr.length;
        double[] dist = new double[k];
        double distSum = 0.0d;
        double invSum = 0.0d;
        for (int c = 0; c < k; c++) {
            dist[c] = KMeans_distance(dArr[c], dArr2, strArr, dArr3, dArr4);
            distSum += dist[c];
            invSum += 1.0d / dist[c];
        }
        double[] weights = new double[k];
        if (distSum == 0.0d) {
            weights[new Random().nextInt(k)] = 1.0d;
            return weights;
        }
        int exact = -1;
        for (int c = 0; c < k; c++) {
            if (dist[c] == 0.0d) {
                exact = c;
                break;
            }
        }
        if (exact != -1) {
            weights[exact] = 1.0d;
        } else {
            for (int c = 0; c < k; c++) {
                weights[c] = 1.0d / (dist[c] * invSum);
            }
        }
        return weights;
    }

    /**
     * Float-point variant of the KMeans distance that also accumulates per-column statistics.
     *
     * <p>For each column with a non-NaN value, the (optionally standardized) value {@code f} is
     * added to {@code dArr4[col]} and {@code f * f} to {@code dArr5[col]}. Numeric columns
     * ({@code strArr[col] == null}) contribute squared difference; categorical ones contribute 1
     * on mismatch. When some values are NaN, the sum is scaled by
     * {@code total / valid} dimensions.
     *
     * @param dArr center
     * @param fArr point
     * @param strArr per-column domains; non-null marks a categorical column
     * @param dArr2 means subtracted before scaling (used only if {@code dArr3 != null})
     * @param dArr3 per-column multipliers, or {@code null} for no standardization
     * @param dArr4 accumulator for per-column sums (mutated)
     * @param dArr5 accumulator for per-column sums of squares (mutated)
     */
    public static double KMeans_distance(double[] dArr, float[] fArr, String[][] strArr, double[] dArr2, double[] dArr3, double[] dArr4, double[] dArr5) {
        double sqr = 0.0d;
        int valid = fArr.length;
        for (int col = 0; col < dArr.length; col++) {
            float f = fArr[col];
            if (Float.isNaN(f)) {
                valid--;
                continue;
            }
            if (strArr[col] == null) {
                if (dArr3 != null) {
                    f = (float) (((float) (f - dArr2[col])) * dArr3[col]);
                }
                double delta = f - dArr[col];
                sqr += delta * delta;
            } else if (f != dArr[col]) {
                sqr += 1.0d;
            }
            dArr4[col] += f;
            dArr5[col] += f * f;
        }
        // Scale up "as if" the missing dimensions contributed the average error. Must be
        // floating-point division: integer division truncates (e.g. 3/2 == 1).
        if (0 < valid && valid < fArr.length) {
            sqr *= (double) fArr.length / valid;
        }
        return sqr;
    }

    /**
     * Distance between a center and a point: squared difference on numeric columns
     * ({@code strArr[col] == null}, optionally standardized as {@code (x - dArr3) * dArr4}),
     * plus 1 per mismatching categorical column. NaN values in the point are skipped, and the sum
     * is scaled by {@code total / valid} dimensions.
     *
     * @param dArr center
     * @param dArr2 point
     * @param strArr per-column domains; non-null marks a categorical column
     * @param dArr3 means (used only if {@code dArr4 != null})
     * @param dArr4 multipliers, or {@code null} for no standardization
     */
    public static double KMeans_distance(double[] dArr, double[] dArr2, String[][] strArr, double[] dArr3, double[] dArr4) {
        double sqr = 0.0d;
        int valid = dArr2.length;
        for (int col = 0; col < dArr.length; col++) {
            double x = dArr2[col];
            if (Double.isNaN(x)) {
                valid--;
            } else if (strArr[col] == null) {
                if (dArr4 != null) {
                    x = (x - dArr3[col]) * dArr4[col];
                }
                double delta = x - dArr[col];
                sqr += delta * delta;
            } else if (x != dArr[col]) {
                sqr += 1.0d;
            }
        }
        // Floating-point ratio; integer division would truncate the scale factor.
        if (0 < valid && valid < dArr2.length) {
            sqr *= (double) dArr2.length / valid;
        }
        return sqr;
    }

    /**
     * Replaces {@code dArr[1..]} in place with {@code exp(dArr[i] - max)} and returns their sum.
     * Index 0 is untouched. Subtracting the max avoids overflow in {@code exp}.
     */
    public static double log_rescale(double[] dArr) {
        double max = Double.NEGATIVE_INFINITY;
        for (int i = 1; i < dArr.length; i++) {
            max = Math.max(max, dArr[i]);
        }
        assert !Double.isInfinite(max) : "Something is wrong with GBM trees since returned prediction is " + Arrays.toString(dArr);
        double sum = 0.0d;
        for (int i = 1; i < dArr.length; i++) {
            double e = Math.exp(dArr[i] - max);
            dArr[i] = e;
            sum += e;
        }
        return sum;
    }

    /** Applies a softmax in place to {@code dArr[1..]} (index 0 untouched). */
    public static void GBM_rescale(double[] dArr) {
        double sum = log_rescale(dArr);
        for (int i = 1; i < dArr.length; i++) {
            dArr[i] /= sum;
        }
    }

    /** Identity inverse link. */
    public static double GLM_identityInv(double d) {
        return d;
    }

    /** Logistic inverse link: {@code 1 / (1 + exp(-d))}. */
    public static double GLM_logitInv(double d) {
        return 1.0d / (Math.exp(-d) + 1.0d);
    }

    /** Log inverse link: {@code exp(d)}. */
    public static double GLM_logInv(double d) {
        return Math.exp(d);
    }

    /** Inverse link for the inverse link: {@code 1 / d}, with |d| clamped to at least 1e-5 (sign kept; 0 maps to +1e-5). */
    public static double GLM_inverseInv(double d) {
        return 1.0d / (d < 0.0d ? Math.min(-1.0E-5d, d) : Math.max(1.0E-5d, d));
    }

    /**
     * Tweedie inverse link: {@code max(2e-16, exp(d))} when {@code d2 == 0}, else
     * {@code pow(d, 1 / d2)}.
     */
    public static double GLM_tweedieInv(double d, double d2) {
        return d2 == 0.0d ? Math.max(2.0E-16d, Math.exp(d)) : Math.pow(d, 1.0d / d2);
    }

    /** Returns {@code null} by default; subclasses may override. */
    public String getHeader() {
        return null;
    }

    /**
     * Encodes a raw row into a dense float input vector: one-hot categoricals followed by
     * (optionally standardized) numerics.
     *
     * <p>A NaN categorical maps to the last slot of its block, which is presumably the NA level.
     * Levels past the end of the block also map there. When {@code z} is false, level 0 is the
     * dropped reference level and sets no slot.
     *
     * @param dArr raw row: {@code i2} categorical values followed by {@code i} numeric values
     * @param fArr output, length {@code i + iArr[i2]}; fully overwritten
     * @param i number of numeric columns
     * @param i2 number of categorical columns
     * @param iArr cumulative one-hot offsets per categorical column (length {@code i2 + 1})
     * @param dArr2 numeric multipliers, or {@code null} for no standardization
     * @param dArr3 numeric values subtracted before multiplying (used only if {@code dArr2 != null})
     * @param z whether all factor levels are used (no dropped reference level)
     */
    public static void setInput(double[] dArr, float[] fArr, int i, int i2, int[] iArr, double[] dArr2, double[] dArr3, boolean z) {
        float[] nums = new float[i];
        // Absolute one-hot column per categorical, or -1 when no column should be set.
        int[] hot = new int[i2];
        for (int c = 0; c < i2; c++) {
            int lastSlot = iArr[c + 1] - 1;
            if (Double.isNaN(dArr[c])) {
                hot[c] = lastSlot;
                continue;
            }
            int level = (int) dArr[c];
            if (z) {
                hot[c] = level + iArr[c];
            } else if (level != 0) {
                hot[c] = (level + iArr[c]) - 1;
            } else {
                // Dropped reference level: leave the whole block zero.
                hot[c] = -1;
            }
            if (hot[c] >= iArr[c + 1]) {
                hot[c] = lastSlot;
            }
        }
        for (int k = 0; k < i; k++) {
            double v = dArr[i2 + k];
            if (dArr2 != null) {
                v = (v - dArr3[k]) * dArr2[k];
            }
            nums[k] = (float) v;
        }
        assert fArr.length == i + iArr[i2];
        Arrays.fill(fArr, 0.0f);
        for (int c = 0; c < i2; c++) {
            if (hot[c] >= 0) {
                fArr[hot[c]] = 1.0f;
            }
        }
        int numStart = iArr[i2];
        for (int k = 0; k < i; k++) {
            fArr[numStart + k] = Float.isNaN(nums[k]) ? 0.0f : nums[k];
        }
    }

    /**
     * Resizes an image to {@code i x i2} and writes its pixels into {@code fArr} starting at
     * {@code i4}.
     *
     * <p>With {@code i3 == 1}, writes the integer-truncated mean of R, G and B per pixel.
     * Otherwise writes three planes (R, then G, then B), each {@code i * i2} long. If
     * {@code fArr2} is non-null, it is subtracted element-wise and indexed relative to
     * {@code i4}.
     *
     * @param bufferedImage source image
     * @param i target width
     * @param i2 target height
     * @param i3 number of channels (1 for grayscale, otherwise RGB)
     * @param fArr output buffer
     * @param i4 start offset in {@code fArr}
     * @param fArr2 optional mean image, or {@code null}
     * @throws IOException declared for compatibility; not thrown here
     */
    public static void img2pixels(BufferedImage bufferedImage, int i, int i2, int i3, float[] fArr, int i4, float[] fArr2) throws IOException {
        // BufferedImage(int, int, int) rejects TYPE_CUSTOM; only getRGB is read below, so a
        // plain RGB raster is an adequate substitute.
        int imageType = bufferedImage.getType();
        if (imageType == BufferedImage.TYPE_CUSTOM) {
            imageType = BufferedImage.TYPE_INT_RGB;
        }
        BufferedImage scaled = new BufferedImage(i, i2, imageType);
        Graphics2D graphics = scaled.createGraphics();
        try {
            graphics.drawImage(bufferedImage, 0, 0, i, i2, (ImageObserver) null);
        } finally {
            graphics.dispose();
        }
        int plane = i * i2;
        int r = i4;
        int g = r + plane;
        int b = g + plane;
        for (int y = 0; y < i2; y++) {
            for (int x = 0; x < i; x++) {
                Color color = new Color(scaled.getRGB(x, y));
                int red = color.getRed();
                int green = color.getGreen();
                int blue = color.getBlue();
                if (i3 == 1) {
                    fArr[r] = (red + green + blue) / 3;
                    if (fArr2 != null) {
                        // Mean is indexed relative to the start offset, as in the RGB branch.
                        fArr[r] -= fArr2[r - i4];
                    }
                } else {
                    fArr[r] = red;
                    fArr[g] = green;
                    fArr[b] = blue;
                    if (fArr2 != null) {
                        fArr[r] -= fArr2[r - i4];
                        fArr[g] -= fArr2[g - i4];
                        fArr[b] -= fArr2[b - i4];
                    }
                }
                r++;
                g++;
                b++;
            }
        }
    }
}
