package ai.libs.jaicore.ml.core.dataset.sampling.inmemory.stratified.sampling;

import ai.libs.jaicore.ml.clustering.GMeans;
import ai.libs.jaicore.ml.core.dataset.IDataset;
import ai.libs.jaicore.ml.core.dataset.INumericArrayInstance;
import java.util.Objects;
import org.apache.commons.math3.ml.clustering.KMeansPlusPlusClusterer;
import org.apache.commons.math3.ml.distance.DistanceMeasure;
import org.apache.commons.math3.ml.distance.ManhattanDistance;
import org.apache.commons.math3.random.JDKRandomGenerator;

/**
 * Strati amount selector and strati assigner that derives strati from a clustering of the dataset.
 *
 * <p>{@link #selectStratiAmount(IDataset)} clusters the dataset with {@link GMeans} and returns the
 * number of clusters found; those clusters are kept and reused by {@link #init(IDataset, int)}.
 * If {@code init} is called without a prior {@code selectStratiAmount} call, the dataset is instead
 * clustered with Commons Math's {@link KMeansPlusPlusClusterer} using the requested strati amount.
 *
 * @param <I> the instance type
 * @param <D> the dataset type
 */
public class GMeansStratiAmountSelectorAndAssigner<I extends INumericArrayInstance, D extends IDataset<I>> extends ClusterStratiAssigner<I, D> implements IStratiAmountSelector<D> {

    /** Max-iterations argument for {@link KMeansPlusPlusClusterer}; a negative value means "no limit". */
    private static final int UNLIMITED_ITERATIONS = -1;

    /** The GMeans instance created by the last {@link #selectStratiAmount} call; {@code null} before that. */
    private GMeans<I> clusterer;

    /**
     * Creates a selector/assigner that measures distances with {@link ManhattanDistance}.
     *
     * @param randomSeed seed passed to GMeans and to the k-means++ random generator
     */
    public GMeansStratiAmountSelectorAndAssigner(int randomSeed) {
        this(new ManhattanDistance(), randomSeed);
    }

    /**
     * Creates a selector/assigner with a custom distance measure.
     *
     * @param distanceMeasure distance used by both GMeans and the k-means++ fallback; must not be null
     * @param randomSeed seed passed to GMeans and to the k-means++ random generator
     * @throws NullPointerException if {@code distanceMeasure} is {@code null}
     */
    public GMeansStratiAmountSelectorAndAssigner(DistanceMeasure distanceMeasure, int randomSeed) {
        this.randomSeed = randomSeed;
        // Fail fast here instead of with an NPE deep inside the clustering code later on.
        this.distanceMeasure = Objects.requireNonNull(distanceMeasure, "distanceMeasure");
    }

    /**
     * Clusters {@code dataset} with GMeans, stores the resulting clusters, and returns their count.
     *
     * @param dataset the dataset to cluster
     * @return the number of clusters produced by GMeans, used as the strati amount
     */
    @Override
    public int selectStratiAmount(D dataset) {
        this.clusterer = new GMeans<>(dataset, this.distanceMeasure, this.randomSeed);
        this.clusters = this.clusterer.cluster();
        return this.clusters.size();
    }

    /**
     * Prepares the clusters used for strati assignment.
     *
     * <p>If {@link #selectStratiAmount} has already produced clusters, they are reused unchanged and
     * {@code stratiAmount} is not consulted. Otherwise the dataset is clustered with k-means++ into
     * {@code stratiAmount} clusters, seeded with this object's random seed.
     *
     * @param dataset the dataset whose instances will be assigned to strati
     * @param stratiAmount number of clusters for the k-means++ fallback
     */
    @Override
    public void init(D dataset, int stratiAmount) {
        if (this.clusterer != null && this.clusters != null) {
            // Clusters already computed by selectStratiAmount.
            // NOTE(review): they are reused even if `dataset` differs from the one passed there.
            return;
        }
        JDKRandomGenerator rng = new JDKRandomGenerator();
        rng.setSeed(this.randomSeed);
        // Typed clusterer instead of the raw type, so the result is assigned without an unchecked conversion.
        KMeansPlusPlusClusterer<I> kMeans = new KMeansPlusPlusClusterer<>(stratiAmount, UNLIMITED_ITERATIONS, this.distanceMeasure, rng);
        this.clusters = kMeans.cluster(dataset);
    }
}
