package ai.libs.jaicore.ml.tsc.classifier.shapelets;

import ai.libs.jaicore.basic.TimeOut;
import ai.libs.jaicore.basic.algorithm.IRandomAlgorithmConfig;
import ai.libs.jaicore.basic.algorithm.events.AlgorithmEvent;
import ai.libs.jaicore.basic.algorithm.exceptions.AlgorithmException;
import ai.libs.jaicore.ml.core.exception.TrainingException;
import ai.libs.jaicore.ml.tsc.classifier.ASimplifiedTSCLearningAlgorithm;
import ai.libs.jaicore.ml.tsc.dataset.TimeSeriesDataset;
import ai.libs.jaicore.ml.tsc.util.MathUtil;
import ai.libs.jaicore.ml.tsc.util.TimeSeriesUtil;
import ai.libs.jaicore.ml.tsc.util.WekaUtil;
import com.google.common.collect.Iterables;
import java.util.ArrayList;
import java.util.Collections;
import java.util.HashMap;
import java.util.Iterator;
import java.util.List;
import java.util.Map;
import java.util.Random;
import java.util.concurrent.TimeUnit;
import java.util.stream.Collectors;
import java.util.stream.IntStream;
import org.aeonbits.owner.Config;
import org.slf4j.Logger;
import org.slf4j.LoggerFactory;
import weka.clusterers.SimpleKMeans;
import weka.core.Instances;

/* loaded from: input_file:ai/libs/jaicore/ml/tsc/classifier/shapelets/LearnShapeletsLearningAlgorithm.class */
public class LearnShapeletsLearningAlgorithm extends ASimplifiedTSCLearningAlgorithm<Integer, LearnShapeletsClassifier> {
    private static final Logger LOGGER = LoggerFactory.getLogger(LearnShapeletsLearningAlgorithm.class);
    private int numInstances;
    private int q;
    private int numClasses;
    public static final boolean USE_BIAS_CORRECTION = false;
    public static final double ALPHA = -30.0d;
    private static final double EPS = 1.0E-21d;
    private TimeOut timeout;
    private boolean useInstanceReordering;

    /* loaded from: input_file:ai/libs/jaicore/ml/tsc/classifier/shapelets/LearnShapeletsLearningAlgorithm$ILearnShapeletsLearningAlgorithmConfig.class */
    public interface ILearnShapeletsLearningAlgorithmConfig extends IRandomAlgorithmConfig {
        public static final String K_NUMSHAPELETS = "numshapelets";
        public static final String K_LEARNINGRATE = "learningrate";
        public static final String K_REGULARIZATION = "regularization";
        public static final String K_SHAPELETLENGTH_MIN = "minshapeletlength";
        public static final String K_SHAPELETLENGTH_RELMIN = "relativeminshapeletlength";
        public static final String K_SCALER = "scaler";
        public static final String K_MAXITER = "maxiter";
        public static final String K_GAMMA = "gamma";
        public static final String K_ESTIMATEK = "estimatek";

        @Config.Key("numshapelets")
        int numShapelets();

        @Config.Key("learningrate")
        double learningRate();

        @Config.Key("regularization")
        double regularization();

        @Config.Key("minshapeletlength")
        int minShapeletLength();

        @Config.Key("relativeminshapeletlength")
        double minShapeLengthPercentage();

        @Config.Key("scaler")
        int scaleR();

        @Config.Key("maxiter")
        int maxIterations();

        @Config.DefaultValue("0.5")
        @Config.Key("gamma")
        double gamma();

        @Config.DefaultValue("false")
        @Config.Key("estimatek")
        boolean estimateK();
    }

    public LearnShapeletsLearningAlgorithm(ILearnShapeletsLearningAlgorithmConfig iLearnShapeletsLearningAlgorithmConfig, LearnShapeletsClassifier learnShapeletsClassifier, TimeSeriesDataset timeSeriesDataset) {
        super(iLearnShapeletsLearningAlgorithmConfig, learnShapeletsClassifier, timeSeriesDataset);
        this.timeout = new TimeOut(2147483647L, TimeUnit.SECONDS);
        this.useInstanceReordering = true;
    }

    /* JADX WARN: Type inference failed for: r0v11, types: [double[][], double[][][]] */
    public double[][][] initializeS(double[][] dArr) throws TrainingException {
        LOGGER.debug("Initializing S...");
        int scaleR = m83getConfig().scaleR();
        int seed = m83getConfig().seed();
        int minShapeletLength = m83getConfig().minShapeletLength();
        ?? r0 = new double[scaleR];
        for (int i = 0; i < scaleR; i++) {
            int numberOfSegments = getNumberOfSegments(this.q, minShapeletLength, i);
            if (numberOfSegments < 1) {
                throw new TrainingException("The number of segments is lower than 1. Can not train the LearnShapelets model.");
            }
            int i2 = (i + 1) * minShapeletLength;
            double[][] dArr2 = new double[dArr.length * numberOfSegments][i2];
            for (int i3 = 0; i3 < dArr.length; i3++) {
                for (int i4 = 0; i4 < numberOfSegments; i4++) {
                    for (int i5 = 0; i5 < i2; i5++) {
                        dArr2[(i3 * numberOfSegments) + i4][i5] = dArr[i3][i4 + i5];
                    }
                    dArr2[(i3 * numberOfSegments) + i4] = TimeSeriesUtil.zNormalize(dArr2[(i3 * numberOfSegments) + i4], false);
                }
            }
            Instances matrixToWekaInstances = WekaUtil.matrixToWekaInstances(dArr2);
            SimpleKMeans simpleKMeans = new SimpleKMeans();
            try {
                simpleKMeans.setNumClusters(m83getConfig().numShapelets());
                simpleKMeans.setSeed(seed);
                simpleKMeans.setMaxIterations(100);
                simpleKMeans.buildClusterer(matrixToWekaInstances);
                Instances clusterCentroids = simpleKMeans.getClusterCentroids();
                double[][] dArr3 = new double[clusterCentroids.numInstances()][clusterCentroids.numAttributes()];
                for (int i6 = 0; i6 < dArr3.length; i6++) {
                    double[] doubleArray = clusterCentroids.get(i6).toDoubleArray();
                    for (int i7 = 0; i7 < dArr3[i6].length; i7++) {
                        dArr3[i6][i7] = doubleArray[i7];
                    }
                }
                r0[i] = dArr3;
            } catch (Exception e) {
                LOGGER.warn("Could not initialize matrix S using kMeans clustering for r={} due to the following problem: {}. Using zero matrix instead (possibly leading to a poor training performance).", Integer.valueOf(i), e.getMessage());
                r0[i] = new double[m83getConfig().numShapelets()][i * minShapeletLength];
            }
        }
        LOGGER.debug("Initialized S.");
        return r0;
    }

    /* JADX WARN: Multi-variable type inference failed */
    /* JADX WARN: Type inference failed for: r0v43, types: [double[][], double[][][]] */
    /* renamed from: call, reason: merged with bridge method [inline-methods] */
    public LearnShapeletsClassifier m84call() throws AlgorithmException {
        long currentTimeMillis = System.currentTimeMillis();
        TimeSeriesDataset timeSeriesDataset = (TimeSeriesDataset) getInput();
        if (timeSeriesDataset.isMultivariate()) {
            throw new UnsupportedOperationException("Multivariate datasets are not supported.");
        }
        if (timeSeriesDataset.isEmpty()) {
            throw new IllegalArgumentException("The training dataset must not be null!");
        }
        double[][] valuesOrNull = timeSeriesDataset.getValuesOrNull(0);
        if (valuesOrNull == null) {
            throw new IllegalArgumentException("Timestamp matrix must be a valid 2D matrix containing the time series values for all instances!");
        }
        int[] targets = timeSeriesDataset.getTargets();
        List<Integer> classesInDataset = TimeSeriesUtil.getClassesInDataset(timeSeriesDataset);
        this.numInstances = timeSeriesDataset.getNumberOfInstances();
        this.q = valuesOrNull[0].length;
        this.numClasses = classesInDataset.size();
        m83getConfig().setProperty("minshapeletlength", "" + (m83getConfig().minShapeLengthPercentage() * this.q));
        int minShapeletLength = m83getConfig().minShapeletLength();
        int scaleR = m83getConfig().scaleR();
        int[][] iArr = new int[this.numInstances][this.numClasses];
        for (int i = 0; i < this.numInstances; i++) {
            iArr[i][classesInDataset.indexOf(Integer.valueOf(targets[i]))] = 1;
        }
        if (m83getConfig().estimateK()) {
            int i2 = 0;
            for (int i3 = 0; i3 < scaleR; i3++) {
                i2 += getNumberOfSegments(this.q, minShapeletLength, i3) * this.numInstances;
            }
            int log = (int) (Math.log(i2) * (this.numClasses - 1));
            m83getConfig().setProperty("numshapelets", "" + (log >= 0 ? log : 1));
        }
        int numShapelets = m83getConfig().numShapelets();
        LOGGER.info("Parameters: k={}, learningRate={}, reg={}, r={}, minShapeLength={}, maxIter={}, Q={}, C={}", new Object[]{Integer.valueOf(numShapelets), Double.valueOf(m83getConfig().learningRate()), Double.valueOf(m83getConfig().regularization()), Integer.valueOf(scaleR), Integer.valueOf(m83getConfig().minShapeletLength()), Integer.valueOf(m83getConfig().maxIterations()), Integer.valueOf(this.q), Integer.valueOf(this.numClasses)});
        try {
            double[][][] initializeS = initializeS(valuesOrNull);
            ?? r0 = new double[scaleR];
            for (int i4 = 0; i4 < scaleR; i4++) {
                r0[i4] = new double[initializeS[i4].length][initializeS[i4][0].length];
            }
            double[][][] dArr = new double[this.numClasses][scaleR][numShapelets];
            double[][][] dArr2 = new double[this.numClasses][scaleR][numShapelets];
            double[] dArr3 = new double[this.numClasses];
            double[] dArr4 = new double[this.numClasses];
            initializeWeights(dArr, dArr3);
            LOGGER.debug("Starting training for {} iterations...", Integer.valueOf(m83getConfig().maxIterations()));
            performSGD(dArr, dArr2, dArr3, dArr4, initializeS, r0, valuesOrNull, iArr, currentTimeMillis, targets);
            LOGGER.debug("Finished training.");
            LearnShapeletsClassifier learnShapeletsClassifier = (LearnShapeletsClassifier) getClassifier();
            learnShapeletsClassifier.setS(initializeS);
            learnShapeletsClassifier.setW(dArr);
            learnShapeletsClassifier.setW0(dArr3);
            learnShapeletsClassifier.setC(this.numClasses);
            return learnShapeletsClassifier;
        } catch (TrainingException e) {
            throw new AlgorithmException(e, "Can not train LearnShapelets model due to error during initialization of S.");
        }
    }

    public void initializeWeights(double[][][] dArr, double[] dArr2) {
        Random random = new Random(m83getConfig().seed());
        int scaleR = m83getConfig().scaleR();
        int numShapelets = m83getConfig().numShapelets();
        for (int i = 0; i < this.numClasses; i++) {
            dArr2[i] = EPS * random.nextDouble() * Math.pow(-1.0d, random.nextInt(2));
            for (int i2 = 0; i2 < scaleR; i2++) {
                for (int i3 = 0; i3 < numShapelets; i3++) {
                    dArr[i][i2][i3] = EPS * random.nextDouble() * Math.pow(-1.0d, random.nextInt(2));
                }
            }
        }
    }

    /* JADX WARN: Multi-variable type inference failed */
    public void performSGD(double[][][] dArr, double[][][] dArr2, double[] dArr3, double[] dArr4, double[][][] dArr5, double[][][] dArr6, double[][] dArr7, int[][] iArr, long j, int[] iArr2) {
        int scaleR = m83getConfig().scaleR();
        int minShapeletLength = m83getConfig().minShapeletLength();
        int maxIterations = m83getConfig().maxIterations();
        long seed = m83getConfig().seed();
        int numShapelets = m83getConfig().numShapelets();
        double learningRate = m83getConfig().learningRate();
        double regularization = m83getConfig().regularization();
        double gamma = m83getConfig().gamma();
        double[][][] dArr8 = new double[scaleR][];
        double[][][] dArr9 = new double[scaleR][];
        double[][][] dArr10 = new double[scaleR][];
        int[] iArr3 = new int[scaleR];
        for (int i = 0; i < scaleR; i++) {
            iArr3[i] = getNumberOfSegments(this.q, minShapeletLength, i);
            dArr8[i] = new double[this.numInstances][numShapelets][iArr3[i]];
            dArr9[i] = new double[this.numInstances][numShapelets][iArr3[i]];
            dArr10[i] = new double[this.numInstances][numShapelets][iArr3[i]];
        }
        double[][][] dArr11 = new double[scaleR][this.numInstances][numShapelets];
        double[][][] dArr12 = new double[scaleR][this.numInstances][numShapelets];
        double[][] dArr13 = new double[this.numInstances][this.numClasses];
        List<Integer> list = (List) IntStream.range(0, this.numInstances).boxed().collect(Collectors.toList());
        LOGGER.debug("Starting training for {} iterations...", Integer.valueOf(maxIterations));
        double[][][] dArr14 = new double[dArr.length][dArr[0].length][dArr[0][0].length];
        double[] dArr15 = new double[dArr3.length];
        double[][] dArr16 = new double[dArr5.length];
        for (int i2 = 0; i2 < dArr5.length; i2++) {
            dArr16[i2] = new double[dArr5[i2].length];
            for (int i3 = 0; i3 < dArr5[i2].length; i3++) {
                dArr16[i2][i3] = new double[dArr5[i2][i3].length];
            }
        }
        for (int i4 = 0; i4 < maxIterations; i4++) {
            if (this.useInstanceReordering) {
                list = shuffleAccordingToAlternatingClassScheme(list, iArr2, new Random(seed + i4));
            } else {
                Collections.shuffle(list, new Random(seed + i4));
            }
            for (int i5 = 0; i5 < this.numInstances; i5++) {
                int intValue = list.get(i5).intValue();
                for (int i6 = 0; i6 < scaleR; i6++) {
                    long length = dArr5[i6].length;
                    for (int i7 = 0; i7 < length; i7++) {
                        int i8 = iArr3[i6];
                        for (int i9 = 0; i9 < i8; i9++) {
                            double calculateD = calculateD(dArr5, minShapeletLength, i6, dArr7[intValue], i7, i9);
                            dArr8[i6][intValue][i7][i9] = calculateD;
                            dArr9[i6][intValue][i7][i9] = Math.exp((-30.0d) * calculateD);
                        }
                        double d = 0.0d;
                        double d2 = 0.0d;
                        for (int i10 = 0; i10 < i8; i10++) {
                            d += dArr9[i6][intValue][i7][i10];
                            d2 += dArr8[i6][intValue][i7][i10] * dArr9[i6][intValue][i7][i10];
                        }
                        dArr11[i6][intValue][i7] = d;
                        dArr12[i6][intValue][i7] = d2 / dArr11[i6][intValue][i7];
                    }
                }
                for (int i11 = 0; i11 < this.numClasses; i11++) {
                    double d3 = 0.0d;
                    for (int i12 = 0; i12 < scaleR; i12++) {
                        for (int i13 = 0; i13 < numShapelets; i13++) {
                            d3 += dArr12[i12][intValue][i13] * dArr[i11][i12][i13];
                        }
                    }
                    dArr13[intValue][i11] = iArr[intValue][i11] - MathUtil.sigmoid(d3);
                }
                for (int i14 = 0; i14 < this.numClasses; i14++) {
                    double d4 = dArr13[intValue][i14];
                    for (int i15 = 0; i15 < scaleR; i15++) {
                        for (int i16 = 0; i16 < dArr5[i15].length; i16++) {
                            double d5 = ((-1.0d) * dArr13[intValue][i14] * dArr12[i15][intValue][i16]) + (((2.0d * regularization) / this.numInstances) * dArr[i14][i15][i16]);
                            dArr14[i14][i15][i16] = (gamma * dArr14[i14][i15][i16]) + (learningRate * d5);
                            double[] dArr17 = dArr2[i14][i15];
                            int i17 = i16;
                            dArr17[i17] = dArr17[i17] + (d5 * d5);
                            double[] dArr18 = dArr[i14][i15];
                            int i18 = i16;
                            dArr18[i18] = dArr18[i18] - (dArr14[i14][i15][i16] / Math.sqrt(dArr2[i14][i15][i16] + EPS));
                            int i19 = iArr3[i15];
                            double d6 = 1.0d / (((i15 + 1.0d) * minShapeletLength) * dArr11[i15][intValue][i16]);
                            double[] dArr19 = new double[i19];
                            for (int i20 = 0; i20 < i19; i20++) {
                                dArr19[i20] = dArr9[i15][intValue][i16][i20] * (1.0d + ((-30.0d) * (dArr8[i15][intValue][i16][i20] - dArr12[i15][intValue][i16])));
                            }
                            for (int i21 = 0; i21 < (i15 + 1) * minShapeletLength; i21++) {
                                double d7 = 0.0d;
                                for (int i22 = 0; i22 < i19; i22++) {
                                    d7 += dArr19[i22] * (dArr5[i15][i16][i21] - dArr7[intValue][i22 + i21]);
                                }
                                double d8 = (-1.0d) * d4 * d7 * dArr[i14][i15][i16] * d6;
                                dArr16[i15][i16][i21] = (gamma * dArr16[i15][i16][i21]) + (learningRate * d8);
                                double[] dArr20 = dArr6[i15][i16];
                                int i23 = i21;
                                dArr20[i23] = dArr20[i23] + (d8 * d8);
                                double[] dArr21 = dArr5[i15][i16];
                                int i24 = i21;
                                dArr21[i24] = dArr21[i24] - (dArr16[i15][i16][i21] / Math.sqrt(dArr6[i15][i16][i21] + EPS));
                            }
                        }
                    }
                    dArr15[i14] = (gamma * dArr15[i14]) + (learningRate * d4);
                    int i25 = i14;
                    dArr4[i25] = dArr4[i25] + (d4 * d4);
                    int i26 = i14;
                    dArr3[i26] = dArr3[i26] + (dArr15[i14] / Math.sqrt(dArr4[i14] + EPS));
                }
            }
            if (i4 % 10 == 0) {
                LOGGER.debug("Iteration {}/{}", Integer.valueOf(i4), Integer.valueOf(maxIterations));
                if (System.currentTimeMillis() - j > this.timeout.milliseconds()) {
                    LOGGER.debug("Stopping training due to timeout.");
                    return;
                }
            }
        }
    }

    public List<Integer> shuffleAccordingToAlternatingClassScheme(List<Integer> list, int[] iArr, Random random) {
        if (list.size() != iArr.length) {
            throw new IllegalArgumentException("The number of instances must be equal to the number of available target values!");
        }
        HashMap hashMap = new HashMap();
        for (int i = 0; i < list.size(); i++) {
            int i2 = iArr[i];
            if (!hashMap.containsKey(Integer.valueOf(i2))) {
                hashMap.put(Integer.valueOf(i2), new ArrayList());
            }
            ((List) hashMap.get(Integer.valueOf(i2))).add(Integer.valueOf(i));
        }
        ArrayList arrayList = new ArrayList();
        for (List list2 : hashMap.values()) {
            Collections.shuffle(list2, random);
            arrayList.add(list2.iterator());
        }
        ArrayList arrayList2 = new ArrayList();
        Iterator it = Iterables.cycle(arrayList).iterator();
        for (int i3 = 0; i3 < list.size(); i3++) {
            int i4 = 0;
            while (true) {
                if (it.hasNext() && i4 < this.numClasses) {
                    Iterator it2 = (Iterator) it.next();
                    if (it2.hasNext()) {
                        arrayList2.add(it2.next());
                        break;
                    }
                    i4++;
                }
            }
        }
        return arrayList2;
    }

    public static double calculateMHat(double[][][] dArr, int i, int i2, double[] dArr2, int i3, int i4, double d) {
        double d2 = 0.0d;
        double d3 = 0.0d;
        for (int i5 = 0; i5 < getNumberOfSegments(i4, i, i2); i5++) {
            double calculateD = calculateD(dArr, i, i2, dArr2, i3, i5);
            double exp = Math.exp(d * calculateD);
            d2 += calculateD * exp;
            d3 += exp;
        }
        return d2 / (d3 == 0.0d ? EPS : d3);
    }

    public static double calculateD(double[][][] dArr, int i, int i2, double[] dArr2, int i3, int i4) {
        double d = 0.0d;
        for (int i5 = 0; i5 < (i2 + 1) * i; i5++) {
            d += Math.pow(dArr2[i4 + i5] - dArr[i2][i3][i5], 2.0d);
        }
        return d / ((i2 + 1) * i);
    }

    @Override // ai.libs.jaicore.ml.tsc.classifier.ASimplifiedTSCLearningAlgorithm
    public AlgorithmEvent nextWithException() {
        throw new UnsupportedOperationException("The operation to be performed is not supported.");
    }

    public static int getNumberOfSegments(int i, int i2, int i3) {
        return i - ((i3 + 1) * i2);
    }

    /* renamed from: getConfig, reason: merged with bridge method [inline-methods] */
    public ILearnShapeletsLearningAlgorithmConfig m83getConfig() {
        return super.getConfig();
    }

    public boolean isUseInstanceReordering() {
        return this.useInstanceReordering;
    }

    public void setUseInstanceReordering(boolean z) {
        this.useInstanceReordering = z;
    }

    public int getC() {
        return this.numClasses;
    }

    public void setC(int i) {
        this.numClasses = i;
    }
}
