/*
 * Decompiled with CFR 0.152.
 */
package org.apache.spark.ml.regression;

import java.io.Serializable;
import org.apache.spark.ml.linalg.DenseMatrix;
import org.apache.spark.ml.linalg.DenseVector;
import org.apache.spark.ml.linalg.Matrix;
import org.apache.spark.ml.linalg.Vector;
import org.apache.spark.ml.linalg.Vectors$;
import org.apache.spark.ml.regression.AdamWUpdater;
import org.apache.spark.ml.regression.BaseFactorizationMachinesGradient;
import org.apache.spark.ml.regression.LogisticFactorizationMachinesGradient;
import org.apache.spark.ml.regression.MSEFactorizationMachinesGradient;
import org.apache.spark.mllib.optimization.SquaredL2Updater;
import org.apache.spark.mllib.optimization.Updater;
import scala.Array$;
import scala.Function0;
import scala.Function1;
import scala.Function2;
import scala.MatchError;
import scala.Predef$;
import scala.Tuple2;
import scala.Tuple3;
import scala.collection.GenTraversableOnce;
import scala.collection.Seq;
import scala.collection.immutable.Nil$;
import scala.collection.mutable.ArrayOps;
import scala.reflect.ClassTag$;
import scala.runtime.BoxedUnit;
import scala.runtime.BoxesRunTime;
import scala.runtime.DoubleRef;
import scala.runtime.RichInt$;
import scala.runtime.java8.JFunction1;
import scala.runtime.java8.JFunction2;

public final class FactorizationMachines$
implements scala.Serializable {
    public static FactorizationMachines$ MODULE$;
    private final String GD;
    private final String AdamW;
    private final String[] supportedSolvers;
    private final String LogisticLoss;
    private final String SquaredError;
    private final String[] supportedRegressorLosses;
    private final String[] supportedClassifierLosses;
    private final String[] supportedLosses;

    static {
        new FactorizationMachines$();
    }

    public String GD() {
        return this.GD;
    }

    public String AdamW() {
        return this.AdamW;
    }

    public String[] supportedSolvers() {
        return this.supportedSolvers;
    }

    public String LogisticLoss() {
        return this.LogisticLoss;
    }

    public String SquaredError() {
        return this.SquaredError;
    }

    public String[] supportedRegressorLosses() {
        return this.supportedRegressorLosses;
    }

    public String[] supportedClassifierLosses() {
        return this.supportedClassifierLosses;
    }

    public String[] supportedLosses() {
        return this.supportedLosses;
    }

    public Updater parseSolver(String solver, int coefficientsSize) {
        Updater updater;
        String string = solver;
        String string2 = this.GD();
        String string3 = string;
        if (!(string2 != null ? !string2.equals(string3) : string3 != null)) {
            updater = new SquaredL2Updater();
        } else {
            String string4 = this.AdamW();
            String string5 = string;
            if (!(string4 != null ? !string4.equals(string5) : string5 != null)) {
                updater = new AdamWUpdater(coefficientsSize);
            } else {
                throw new MatchError((Object)string);
            }
        }
        return updater;
    }

    public BaseFactorizationMachinesGradient parseLoss(String lossFunc, int factorSize, boolean fitIntercept, boolean fitLinear, int numFeatures) {
        BaseFactorizationMachinesGradient baseFactorizationMachinesGradient;
        String string = lossFunc;
        String string2 = this.LogisticLoss();
        String string3 = string;
        if (!(string2 != null ? !string2.equals(string3) : string3 != null)) {
            baseFactorizationMachinesGradient = new LogisticFactorizationMachinesGradient(factorSize, fitIntercept, fitLinear, numFeatures);
        } else {
            String string4 = this.SquaredError();
            String string5 = string;
            if (!(string4 != null ? !string4.equals(string5) : string5 != null)) {
                baseFactorizationMachinesGradient = new MSEFactorizationMachinesGradient(factorSize, fitIntercept, fitLinear, numFeatures);
            } else {
                throw new IllegalArgumentException(new StringBuilder(35).append("loss function type ").append(lossFunc).append(" is invalidation").toString());
            }
        }
        return baseFactorizationMachinesGradient;
    }

    public Tuple3<Object, Vector, Matrix> splitCoefficients(Vector coefficients, int numFeatures, int factorSize, boolean fitIntercept, boolean fitLinear) {
        int coefficientsSize = numFeatures * factorSize + (fitLinear ? numFeatures : 0) + (fitIntercept ? 1 : 0);
        Predef$.MODULE$.require(coefficientsSize == coefficients.size(), (Function0 & Serializable & scala.Serializable)() -> new StringBuilder(50).append("coefficients.size did not match the excepted size ").append(coefficientsSize).toString());
        double intercept = fitIntercept ? coefficients.apply(coefficients.size() - 1) : 0.0;
        DenseVector linear = fitLinear ? new DenseVector((double[])new ArrayOps.ofDouble(Predef$.MODULE$.doubleArrayOps(coefficients.toArray())).slice(numFeatures * factorSize, numFeatures * factorSize + numFeatures)) : Vectors$.MODULE$.sparse(numFeatures, (Seq)Nil$.MODULE$);
        DenseMatrix factors = new DenseMatrix(numFeatures, factorSize, (double[])new ArrayOps.ofDouble(Predef$.MODULE$.doubleArrayOps(coefficients.toArray())).slice(0, numFeatures * factorSize), true);
        return new Tuple3((Object)BoxesRunTime.boxToDouble((double)intercept), (Object)linear, (Object)factors);
    }

    public Vector combineCoefficients(double intercept, Vector linear, Matrix factors, boolean fitIntercept, boolean fitLinear) {
        double[] coefficients = (double[])new ArrayOps.ofDouble(Predef$.MODULE$.doubleArrayOps((double[])new ArrayOps.ofDouble(Predef$.MODULE$.doubleArrayOps(factors.toDense().values())).$plus$plus((GenTraversableOnce)(fitLinear ? new ArrayOps.ofDouble(Predef$.MODULE$.doubleArrayOps(linear.toArray())) : new ArrayOps.ofDouble(Predef$.MODULE$.doubleArrayOps(Array$.MODULE$.emptyDoubleArray()))), Array$.MODULE$.canBuildFrom(ClassTag$.MODULE$.Double())))).$plus$plus((GenTraversableOnce)(fitIntercept ? new ArrayOps.ofDouble(Predef$.MODULE$.doubleArrayOps(new double[]{intercept})) : new ArrayOps.ofDouble(Predef$.MODULE$.doubleArrayOps(Array$.MODULE$.emptyDoubleArray()))), Array$.MODULE$.canBuildFrom(ClassTag$.MODULE$.Double()));
        return new DenseVector(coefficients);
    }

    public double getRawPrediction(Vector features, double intercept, Vector linear, Matrix factors) {
        DoubleRef rawPrediction = DoubleRef.create((double)(intercept + features.dot(linear)));
        RichInt$.MODULE$.until$extension0(Predef$.MODULE$.intWrapper(0), factors.numCols()).foreach$mVc$sp((Function1)(JFunction1.mcVI.sp & Serializable & scala.Serializable)f -> {
            DoubleRef sumSquare = DoubleRef.create((double)0.0);
            DoubleRef sum = DoubleRef.create((double)0.0);
            features.foreachNonZero((Function2)(JFunction2.mcVID.sp & Serializable & scala.Serializable)(x0$1, x1$1) -> {
                Tuple2.mcID.sp sp2 = new Tuple2.mcID.sp(x0$1, x1$1);
                if (sp2 != null) {
                    int index = sp2._1$mcI$sp();
                    double value = sp2._2$mcD$sp();
                    double vx = factors.apply(index, f) * value;
                    sumSquare$1.elem += vx * vx;
                    sum$1.elem += vx;
                } else {
                    throw new MatchError((Object)sp2);
                }
                BoxedUnit boxedUnit = BoxedUnit.UNIT;
            });
            rawPrediction$1.elem += 0.5 * (sum.elem * sum.elem - sumSquare.elem);
        });
        return rawPrediction.elem;
    }

    private Object readResolve() {
        return MODULE$;
    }

    private FactorizationMachines$() {
        MODULE$ = this;
        this.GD = "gd";
        this.AdamW = "adamW";
        this.supportedSolvers = (String[])((Object[])new String[]{this.GD(), this.AdamW()});
        this.LogisticLoss = "logisticLoss";
        this.SquaredError = "squaredError";
        this.supportedRegressorLosses = (String[])((Object[])new String[]{this.SquaredError()});
        this.supportedClassifierLosses = (String[])((Object[])new String[]{this.LogisticLoss()});
        this.supportedLosses = (String[])new ArrayOps.ofRef(Predef$.MODULE$.refArrayOps((Object[])this.supportedRegressorLosses())).$plus$plus((GenTraversableOnce)new ArrayOps.ofRef(Predef$.MODULE$.refArrayOps((Object[])this.supportedClassifierLosses())), Array$.MODULE$.canBuildFrom(ClassTag$.MODULE$.apply(String.class)));
    }
}

