/*
 * Decompiled with CFR 0.152.
 */
package smile.feature.selection;

import java.util.Arrays;
import java.util.stream.IntStream;
import smile.classification.ClassLabels;
import smile.data.DataFrame;
import smile.data.type.StructField;
import smile.data.type.StructType;
import smile.data.vector.BaseVector;
import smile.math.MathEx;

/**
 * Signal-to-noise ratio (S2N) of a feature with respect to a binary class label.
 * The feature's values are split into two groups by class, and the ratio is
 * {@code |mu1 - mu2| / (sd1 + sd2)}, where {@code mu} and {@code sd} are the
 * per-group mean and standard deviation computed by {@code MathEx}.
 *
 * @param feature the name of the feature (data frame column).
 * @param ratio the signal-to-noise ratio of the feature.
 */
public record SignalNoiseRatio(String feature, double ratio) implements Comparable<SignalNoiseRatio> {
    /**
     * Orders instances by {@link #ratio} only, using {@link Double#compare}.
     * Note that this natural ordering is inconsistent with {@code equals}:
     * the record's {@code equals} also compares {@link #feature}, so two
     * different features with the same ratio compare as 0 but are not equal.
     */
    @Override
    public int compareTo(SignalNoiseRatio other) {
        return Double.compare(this.ratio, other.ratio);
    }

    @Override
    public String toString() {
        return String.format("SignalNoiseRatio(%s, %.4f)", this.feature, this.ratio);
    }

    /**
     * Computes the signal-to-noise ratio of every numeric column of a data
     * frame with respect to a binary class label column. Rows whose encoded
     * label is 0 form the first group, and all other rows form the second group.
     *
     * @param data the data frame holding both the features and the class label.
     * @param clazz the name of the class label column. It is excluded from the result.
     * @return one entry per numeric column other than {@code clazz}, in column order.
     *         If both groups are constant for a column ({@code sd1 + sd2 == 0}),
     *         its ratio is Infinity when the group means differ and NaN when they are equal.
     * @throws UnsupportedOperationException if the label does not have exactly two classes.
     */
    public static SignalNoiseRatio[] fit(DataFrame data, String clazz) {
        BaseVector y = data.column(clazz);
        ClassLabels codec = ClassLabels.fit(y);
        if (codec.k != 2) {
            throw new UnsupportedOperationException("Signal Noise Ratio is applicable only to binary classification");
        }

        int n = data.nrow();
        int n1 = 0;
        for (int yi : codec.y) {
            if (yi == 0) {
                ++n1;
            }
        }
        int n2 = n - n1;

        // Scratch buffers reused for every column to avoid per-column allocation.
        // Each pass below rewrites all n1 + n2 slots (assuming codec.y has one entry
        // per row), so no clearing is needed. Sharing them is safe only because
        // the stream below is sequential.
        double[] x1 = new double[n1];
        double[] x2 = new double[n2];

        StructType schema = data.schema();
        return IntStream.range(0, schema.length())
                // Select the feature columns up front, so the label column is never
                // scanned and no placeholder results need to be filtered out afterwards.
                .filter(i -> {
                    StructField field = schema.field(i);
                    return field.isNumeric() && !field.name.equals(clazz);
                })
                .mapToObj(i -> {
                    BaseVector xi = data.column(i);
                    int j = 0;
                    int k = 0;
                    for (int l = 0; l < n; ++l) {
                        if (codec.y[l] == 0) {
                            x1[j++] = xi.getDouble(l);
                        } else {
                            x2[k++] = xi.getDouble(l);
                        }
                    }
                    double mu1 = MathEx.mean(x1);
                    double mu2 = MathEx.mean(x2);
                    double sd1 = MathEx.sd(x1);
                    double sd2 = MathEx.sd(x2);
                    double s2n = Math.abs(mu1 - mu2) / (sd1 + sd2);
                    return new SignalNoiseRatio(schema.field(i).name, s2n);
                })
                .toArray(SignalNoiseRatio[]::new);
    }
}

