/*
 * Decompiled with CFR 0.152.
 */
package smile.validation;

import java.util.Arrays;
import java.util.Objects;
import java.util.function.BiFunction;
import java.util.function.Function;
import smile.classification.Classifier;
import smile.classification.DataFrameClassifier;
import smile.data.DataFrame;
import smile.data.Tuple;
import smile.math.MathEx;
import smile.regression.DataFrameRegression;
import smile.regression.Regression;
import smile.sort.QuickSort;

/**
 * Group k-fold cross validation.
 *
 * <p>Samples are split into {@code k} rounds so that every sample sharing a
 * group label lands in the same test fold. A group therefore never appears in
 * both the training set and the test set of the same round. Groups are
 * assigned greedily: largest group first, each to the fold that currently
 * holds the fewest test samples. This keeps the test folds roughly balanced
 * in size.
 */
public class GroupKFold {
    /** The number of folds (rounds). */
    public final int k;
    /** {@code train[i]} holds the sample indices used for training in round {@code i}. */
    public final int[][] train;
    /** {@code test[i]} holds the sample indices held out for testing in round {@code i}. */
    public final int[][] test;
    /** The number of samples the folds were built for; inputs to the CV methods must match it. */
    private final int n;

    /**
     * Builds the group k-fold split.
     *
     * @param n the number of samples.
     * @param k the number of folds; must be at least 2 and at most the number of distinct groups.
     * @param groups the group label of each sample. Labels must be dense: every
     *               integer in {@code [0, numGroups)} must occur at least once.
     * @throws NullPointerException if {@code groups} is null.
     * @throws IllegalArgumentException if {@code n}, {@code k} or {@code groups} is invalid.
     */
    public GroupKFold(int n, int k, int[] groups) {
        Objects.requireNonNull(groups, "groups");
        if (n <= 0) {
            throw new IllegalArgumentException("Invalid sample size: " + n);
        }
        // k == 0 would crash while assigning folds, and k == 1 leaves every
        // training set empty, so at least two folds are required.
        if (k < 2) {
            throw new IllegalArgumentException("Invalid number of folds: " + k);
        }
        if (groups.length != n) {
            throw new IllegalArgumentException("Groups array must be of size n, but length is " + groups.length);
        }

        int[] uniqueGroups = MathEx.unique(groups);
        int numGroups = uniqueGroups.length;
        if (k > numGroups) {
            throw new IllegalArgumentException("Number of splits mustn't be greater than number of groups");
        }

        // Group labels are used directly as array indices below, so they must be
        // exactly 0, 1, ..., numGroups - 1 (this also rejects negative labels).
        Arrays.sort(uniqueGroups);
        for (int i = 0; i < numGroups; ++i) {
            if (uniqueGroups[i] != i) {
                throw new IllegalArgumentException("Invalid encoding of groups, all group indices between [0, numGroups) have to exist");
            }
        }

        this.n = n;
        this.k = k;
        this.train = new int[k][];
        this.test = new int[k][];

        TestFolds folds = calculateTestFolds(groups, numGroups, k);
        for (int i = 0; i < k; ++i) {
            int testSize = folds.numTestSamplesPerFold[i];
            int[] trainFold = new int[n - testSize];
            int[] testFold = new int[testSize];
            int trainIndex = 0;
            int testIndex = 0;
            for (int j = 0; j < n; ++j) {
                if (folds.groupToTestFoldIndex[groups[j]] == i) {
                    testFold[testIndex++] = j;
                } else {
                    trainFold[trainIndex++] = j;
                }
            }
            this.train[i] = trainFold;
            this.test[i] = testFold;
        }
    }

    /**
     * Assigns each group to a test fold, balancing the number of test samples per fold.
     *
     * @param groups the (validated, dense) group label of each sample.
     * @param numGroups the number of distinct groups.
     * @param k the number of folds.
     * @return the per-fold test sample counts and the group-to-fold mapping.
     */
    private static TestFolds calculateTestFolds(int[] groups, int numGroups, int k) {
        int[] groupSizes = new int[numGroups];
        for (int g : groups) {
            groupSizes[g]++;
        }

        // QuickSort.sort sorts groupSizes ascending in place and returns, for each
        // sorted position, the original group index; the loop below relies on both.
        int[] sortedToGroup = QuickSort.sort(groupSizes);

        int[] numTestSamplesPerFold = new int[k];
        int[] groupToTestFoldIndex = new int[numGroups];
        // Greedy balancing: go from the largest group to the smallest and place
        // each one in the fold holding the fewest test samples so far.
        for (int i = numGroups - 1; i >= 0; --i) {
            int fold = MathEx.whichMin(numTestSamplesPerFold);
            numTestSamplesPerFold[fold] += groupSizes[i];
            groupToTestFoldIndex[sortedToGroup[i]] = fold;
        }
        return new TestFolds(numTestSamplesPerFold, groupToTestFoldIndex);
    }

    /**
     * Verifies that an input to one of the CV methods has exactly {@code n} samples.
     *
     * @param size the number of samples in the input.
     * @param name the input's name, used in the error message.
     * @throws IllegalArgumentException if {@code size != n}.
     */
    private void checkSize(int size, String name) {
        if (size != n) {
            throw new IllegalArgumentException(String.format(
                    "%s has %d samples, but the folds were built for %d samples", name, size, n));
        }
    }

    /**
     * Runs group k-fold cross validation of a classification model.
     *
     * @param x the samples.
     * @param y the class labels.
     * @param trainer the lambda that trains a model on a training subset.
     * @param <T> the data type of samples.
     * @return the out-of-fold prediction for every sample.
     * @throws IllegalArgumentException if {@code x} or {@code y} does not have {@code n} samples.
     */
    public <T> int[] classification(T[] x, int[] y, BiFunction<T[], int[], Classifier<T>> trainer) {
        checkSize(x.length, "x");
        checkSize(y.length, "y");
        int[] prediction = new int[x.length];
        for (int i = 0; i < k; ++i) {
            T[] trainx = MathEx.slice(x, train[i]);
            int[] trainy = MathEx.slice(y, train[i]);
            Classifier<T> model = trainer.apply(trainx, trainy);
            for (int j : test[i]) {
                prediction[j] = model.predict(x[j]);
            }
        }
        return prediction;
    }

    /**
     * Runs group k-fold cross validation of a data frame classification model.
     *
     * @param data the data frame with the samples.
     * @param trainer the lambda that trains a model on a training subset.
     * @return the out-of-fold prediction for every row.
     * @throws IllegalArgumentException if {@code data} does not have {@code n} rows.
     */
    public int[] classification(DataFrame data, Function<DataFrame, DataFrameClassifier> trainer) {
        checkSize(data.size(), "data");
        int[] prediction = new int[data.size()];
        for (int i = 0; i < k; ++i) {
            DataFrameClassifier model = trainer.apply(data.of(train[i]));
            for (int j : test[i]) {
                prediction[j] = model.predict((Tuple) data.get(j));
            }
        }
        return prediction;
    }

    /**
     * Runs group k-fold cross validation of a regression model.
     *
     * @param x the samples.
     * @param y the response values.
     * @param trainer the lambda that trains a model on a training subset.
     * @param <T> the data type of samples.
     * @return the out-of-fold prediction for every sample.
     * @throws IllegalArgumentException if {@code x} or {@code y} does not have {@code n} samples.
     */
    public <T> double[] regression(T[] x, double[] y, BiFunction<T[], double[], Regression<T>> trainer) {
        checkSize(x.length, "x");
        checkSize(y.length, "y");
        double[] prediction = new double[x.length];
        for (int i = 0; i < k; ++i) {
            T[] trainx = MathEx.slice(x, train[i]);
            double[] trainy = MathEx.slice(y, train[i]);
            Regression<T> model = trainer.apply(trainx, trainy);
            for (int j : test[i]) {
                prediction[j] = model.predict(x[j]);
            }
        }
        return prediction;
    }

    /**
     * Runs group k-fold cross validation of a data frame regression model.
     *
     * @param data the data frame with the samples.
     * @param trainer the lambda that trains a model on a training subset.
     * @return the out-of-fold prediction for every row.
     * @throws IllegalArgumentException if {@code data} does not have {@code n} rows.
     */
    public double[] regression(DataFrame data, Function<DataFrame, DataFrameRegression> trainer) {
        checkSize(data.size(), "data");
        double[] prediction = new double[data.size()];
        for (int i = 0; i < k; ++i) {
            DataFrameRegression model = trainer.apply(data.of(train[i]));
            for (int j : test[i]) {
                prediction[j] = model.predict((Tuple) data.get(j));
            }
        }
        return prediction;
    }

    /** The result of assigning groups to test folds. Static: it needs no enclosing instance. */
    private static final class TestFolds {
        /** The number of test samples in each fold. */
        private final int[] numTestSamplesPerFold;
        /** The test fold index assigned to each group label. */
        private final int[] groupToTestFoldIndex;

        private TestFolds(int[] numTestSamplesPerFold, int[] groupToTestFoldIndex) {
            this.numTestSamplesPerFold = numTestSamplesPerFold;
            this.groupToTestFoldIndex = groupToTestFoldIndex;
        }
    }
}

