/*
 * Decompiled with CFR 0.152.
 */
package org.apache.spark.mllib.optimization;

import breeze.linalg.DenseVector;
import breeze.linalg.DenseVector$;
import breeze.storage.Zero;
import org.apache.spark.Logging;
import org.apache.spark.annotation.DeveloperApi;
import org.apache.spark.broadcast.Broadcast;
import org.apache.spark.mllib.linalg.Vector;
import org.apache.spark.mllib.linalg.Vectors$;
import org.apache.spark.mllib.optimization.Gradient;
import org.apache.spark.mllib.optimization.GradientDescent$;
import org.apache.spark.mllib.optimization.Updater;
import org.apache.spark.mllib.rdd.RDDFunctions;
import org.apache.spark.mllib.rdd.RDDFunctions$;
import org.apache.spark.rdd.RDD;
import org.slf4j.Logger;
import scala.Function0;
import scala.Function1;
import scala.MatchError;
import scala.Predef$;
import scala.Serializable;
import scala.Tuple2;
import scala.collection.Seq;
import scala.collection.TraversableOnce;
import scala.collection.immutable.StringOps;
import scala.collection.mutable.ArrayBuffer;
import scala.reflect.ClassTag$;
import scala.runtime.BoxesRunTime;
import scala.runtime.DoubleRef;
import scala.runtime.ObjectRef;
import scala.runtime.RichInt$;

/**
 * Singleton module object for {@code GradientDescent}, as emitted by the Scala
 * compiler for an {@code object} (trailing {@code $} in the class name, static
 * {@code MODULE$} instance) and rendered back to Java by the CFR decompiler.
 *
 * <p>Exposes {@link #runMiniBatchSGD}, which runs a fixed number of iterations of
 * mini-batch stochastic gradient descent over an RDD of (label, features) pairs,
 * and forwards the Spark {@code Logging} trait methods.
 *
 * <p>NOTE(review): this is decompiler output, not hand-written source. Several
 * constructs here (anonymous {@code Serializable} subclasses taking constructor
 * arguments, {@code Logging.class.xxx(...)} calls, the raw {@code RDDFunctions<T>})
 * are not valid Java as written; prefer editing the original Scala source.
 */
@DeveloperApi
public final class GradientDescent$
implements Logging,
Serializable {
    // The singleton instance; assigned by the private constructor, which the
    // static initializer below invokes.
    public static final GradientDescent$ MODULE$;
    // Logger slot read/written through the accessor pair below on behalf of the
    // Logging trait; transient so it is not carried along when this object is serialized.
    private transient Logger org$apache$spark$Logging$$log_;

    // Class initialization eagerly builds the singleton (the constructor stores it in MODULE$).
    static {
        new GradientDescent$();
    }

    // Accessor for the trait-managed logger field.
    public Logger org$apache$spark$Logging$$log_() {
        return this.org$apache$spark$Logging$$log_;
    }

    // Mutator for the trait-managed logger field (Scala setter naming: `_$eq`).
    public void org$apache$spark$Logging$$log__$eq(Logger x$1) {
        this.org$apache$spark$Logging$$log_ = x$1;
    }

    // ---- Logging trait forwarders ----
    // Each method below delegates to the trait implementation, passing `this`.
    // The `Logging.class` spelling is a decompiler artifact; presumably the real
    // target is the Scala trait implementation class (Logging$class) — confirm.
    // Message arguments are Function0 thunks, so the string is built only when invoked.

    public String logName() {
        return Logging.class.logName((Logging)this);
    }

    public Logger log() {
        return Logging.class.log((Logging)this);
    }

    public void logInfo(Function0<String> msg) {
        Logging.class.logInfo((Logging)this, msg);
    }

    public void logDebug(Function0<String> msg) {
        Logging.class.logDebug((Logging)this, msg);
    }

    public void logTrace(Function0<String> msg) {
        Logging.class.logTrace((Logging)this, msg);
    }

    public void logWarning(Function0<String> msg) {
        Logging.class.logWarning((Logging)this, msg);
    }

    public void logError(Function0<String> msg) {
        Logging.class.logError((Logging)this, msg);
    }

    public void logInfo(Function0<String> msg, Throwable throwable) {
        Logging.class.logInfo((Logging)this, msg, (Throwable)throwable);
    }

    public void logDebug(Function0<String> msg, Throwable throwable) {
        Logging.class.logDebug((Logging)this, msg, (Throwable)throwable);
    }

    public void logTrace(Function0<String> msg, Throwable throwable) {
        Logging.class.logTrace((Logging)this, msg, (Throwable)throwable);
    }

    public void logWarning(Function0<String> msg, Throwable throwable) {
        Logging.class.logWarning((Logging)this, msg, (Throwable)throwable);
    }

    public void logError(Function0<String> msg, Throwable throwable) {
        Logging.class.logError((Logging)this, msg, (Throwable)throwable);
    }

    public boolean isTraceEnabled() {
        return Logging.class.isTraceEnabled((Logging)this);
    }

    /**
     * Runs mini-batch stochastic gradient descent for {@code numIterations} iterations.
     *
     * <p>In iteration {@code i} (1-based), a sample of {@code data} is drawn without
     * replacement with fraction {@code miniBatchFraction} and seed {@code 42 + i}. Over
     * that sample, per-example gradients and losses are summed with
     * {@code treeAggregate}. {@code updater} then produces the next weights from the
     * summed gradient divided by {@code miniBatchSize}. Because the seeds are fixed,
     * repeated runs on the same input draw the same samples.
     *
     * <p>NOTE(review): {@code miniBatchSize} is the <em>expected</em> sample size
     * ({@code numExamples * miniBatchFraction}), not the number of rows actually
     * drawn by {@code sample}. The per-iteration gradient and loss averages are
     * therefore approximate. If that product is {@code 0.0} on non-empty data (for
     * example, {@code miniBatchFraction == 0.0}), the divisions below produce
     * non-finite values. Consider counting sampled rows inside the aggregate instead.
     *
     * <p>NOTE(review): {@code data} is counted once and then sampled on every
     * iteration, but this method never persists it. Presumably callers are expected
     * to cache it — confirm.
     *
     * @param data RDD of (label, features) pairs. The label is read as a double
     *     ({@code _1$mcD$sp()}) and the features as a {@link Vector} ({@code _2()}).
     * @param gradient computes the loss for one example given features, label and
     *     current weights. It is handed a view of the running gradient accumulator;
     *     since the accumulator object itself is passed through unchanged, it
     *     presumably adds into that vector in place — confirm against the Gradient
     *     implementations.
     * @param updater maps (weights, averaged gradient, stepSize, iteration,
     *     regParam) to (new weights, regularization value).
     * @param stepSize step size forwarded to {@code updater} each iteration.
     * @param numIterations number of iterations to run (iterations 1..numIterations);
     *     also used as the initial capacity of the loss history buffer.
     * @param regParam regularization parameter forwarded to {@code updater}.
     * @param miniBatchFraction sampling fraction per iteration; also scales
     *     {@code miniBatchSize}.
     * @param initialWeights starting weights. They are copied into a fresh dense
     *     vector before iterating, and returned unchanged (same object) when
     *     {@code data} is empty.
     * @return a pair of (final weights, stochastic loss history). The history has
     *     one entry per completed iteration, equal to {@code lossSum / miniBatchSize}
     *     plus the regularization value from the preceding update. It is empty when
     *     {@code data} is empty.
     */
    public Tuple2<Vector, double[]> runMiniBatchSGD(RDD<Tuple2<Object, Vector>> data, Gradient gradient, Updater updater, double stepSize, int numIterations, double regParam, double miniBatchFraction, Vector initialWeights) {
        ArrayBuffer stochasticLossHistory = new ArrayBuffer(numIterations);
        // Full pass over the RDD to learn its size.
        long numExamples = data.count();
        // Expected (not actual) mini-batch size; see NOTE(review) in the Javadoc.
        double miniBatchSize = (double)numExamples * miniBatchFraction;
        // Nothing to train on: return the caller's weights object and an empty history.
        if (numExamples == 0L) {
            this.logInfo((Function0<String>)new Serializable(){
                public static final long serialVersionUID = 0L;

                public final String apply() {
                    return "GradientDescent.runMiniBatchSGD returning initial weights, no data found";
                }
            });
            return new Tuple2((Object)initialWeights, stochasticLossHistory.toArray(ClassTag$.MODULE$.Double()));
        }
        // Mutable reference (Scala `var`) captured by the loop closure; starts as a
        // dense copy of initialWeights so the caller's vector is not aliased.
        ObjectRef weights = new ObjectRef((Object)Vectors$.MODULE$.dense(initialWeights.toArray()));
        int n = ((Vector)weights.elem).size();
        // Initial regularization value. The updater is called with a zero gradient,
        // step size 0.0 and iteration 1, and only the second component (regVal) is
        // kept. Presumably this yields the regularization term of the starting
        // weights without moving them — confirm per Updater implementation.
        DoubleRef regVal = new DoubleRef(updater.compute((Vector)weights.elem, Vectors$.MODULE$.dense(new double[((Vector)weights.elem).size()]), 0.0, 1, regParam)._2$mcD$sp());
        // Iterate i = 1..numIterations (inclusive range `1 to numIterations`).
        RichInt$.MODULE$.to$extension0(Predef$.MODULE$.intWrapper(1), numIterations).foreach$mVc$sp((Function1)new Serializable(data, gradient, updater, stepSize, regParam, miniBatchFraction, stochasticLossHistory, miniBatchSize, weights, n, regVal){
            public static final long serialVersionUID = 0L;
            // Captured variables of the enclosing method (decompiled closure fields).
            private final RDD data$1;
            public final Gradient gradient$1;
            private final Updater updater$1;
            private final double stepSize$1;
            private final double regParam$1;
            private final double miniBatchFraction$1;
            private final ArrayBuffer stochasticLossHistory$1;
            private final double miniBatchSize$1;
            private final ObjectRef weights$1;
            private final int n$1;
            private final DoubleRef regVal$1;

            public final void apply(int i) {
                this.apply$mcVI$sp(i);
            }

            // One SGD iteration; `i` is the 1-based iteration number.
            public void apply$mcVI$sp(int i) {
                int x$5;
                Serializable x$4;
                Serializable x$3;
                Tuple2 x$2;
                // Ship the current weights to executors once per iteration.
                // NOTE(review): this broadcast is never unpersisted/destroyed in this
                // method, so one broadcast per iteration is left behind.
                Broadcast bcWeights = this.data$1.context().broadcast((Object)((Vector)this.weights$1.elem), ClassTag$.MODULE$.apply(Vector.class));
                // Sample without replacement; the seed depends only on the iteration number.
                RDDFunctions<T> qual$1 = RDDFunctions$.MODULE$.fromRDD(this.data$1.sample(false, this.miniBatchFraction$1, (long)(42 + i)), ClassTag$.MODULE$.apply(Tuple2.class));
                // treeAggregate(zero = (zero vector of length n, 0.0), seqOp = x$3,
                // combOp = x$4, depth = default) -> (gradientSum, lossSum).
                Tuple2 tuple2 = qual$1.treeAggregate(x$2 = new Tuple2((Object)DenseVector$.MODULE$.zeros$mDc$sp(this.n$1, ClassTag$.MODULE$.Double(), (Zero)Zero.DoubleZero$.MODULE$), (Object)BoxesRunTime.boxToDouble((double)0.0)), x$3 = new Serializable(this, bcWeights){
                    // seqOp: folds one (label, features) example into the (grad, loss) accumulator.
                    public static final long serialVersionUID = 0L;
                    private final /* synthetic */ anonfun.runMiniBatchSGD.1 $outer;
                    private final Broadcast bcWeights$1;

                    public final Tuple2<DenseVector<Object>, Object> apply(Tuple2<DenseVector<Object>, Object> c, Tuple2<Object, Vector> v) {
                        // Decompiled pattern match on ((grad, loss), (label, features)).
                        Tuple2 tuple2 = new Tuple2(c, v);
                        if (tuple2 != null) {
                            Tuple2 tuple22 = (Tuple2)tuple2._1();
                            Tuple2 tuple23 = (Tuple2)tuple2._2();
                            if (tuple22 != null) {
                                DenseVector grad = (DenseVector)tuple22._1();
                                double loss = tuple22._2$mcD$sp();
                                if (tuple23 != null) {
                                    double label = tuple23._1$mcD$sp();
                                    Vector features = (Vector)tuple23._2();
                                    // `grad` is passed (wrapped via fromBreeze) as the
                                    // gradient accumulator and is returned as-is below,
                                    // so any gradient contribution must be written into
                                    // it in place; presumably fromBreeze shares its
                                    // storage — confirm.
                                    double l = this.$outer.gradient$1.compute(features, label, (Vector)this.bcWeights$1.value(), Vectors$.MODULE$.fromBreeze((breeze.linalg.Vector<Object>)grad));
                                    Tuple2 tuple24 = new Tuple2((Object)grad, (Object)BoxesRunTime.boxToDouble((double)(loss + l)));
                                    return tuple24;
                                }
                            }
                        }
                        // Unreachable for well-formed input; mirrors Scala's non-exhaustive match failure.
                        throw new MatchError((Object)tuple2);
                    }
                    {
                        if ($outer == null) {
                            throw new NullPointerException();
                        }
                        this.$outer = $outer;
                        this.bcWeights$1 = bcWeights$1;
                    }
                }, x$4 = new Serializable(this){
                    // combOp: merges two partial (grad, loss) accumulators.
                    public static final long serialVersionUID = 0L;

                    public final Tuple2<DenseVector<Object>, Object> apply(Tuple2<DenseVector<Object>, Object> c1, Tuple2<DenseVector<Object>, Object> c2) {
                        Tuple2 tuple2 = new Tuple2(c1, c2);
                        if (tuple2 != null) {
                            Tuple2 tuple22 = (Tuple2)tuple2._1();
                            Tuple2 tuple23 = (Tuple2)tuple2._2();
                            if (tuple22 != null) {
                                DenseVector grad1 = (DenseVector)tuple22._1();
                                double loss1 = tuple22._2$mcD$sp();
                                if (tuple23 != null) {
                                    DenseVector grad2 = (DenseVector)tuple23._1();
                                    double loss2 = tuple23._2$mcD$sp();
                                    // `+=` adds grad2 into grad1 in place (no new vector allocated).
                                    Tuple2 tuple24 = new Tuple2(grad1.$plus$eq((Object)grad2, DenseVector$.MODULE$.canAddIntoD()), (Object)BoxesRunTime.boxToDouble((double)(loss1 + loss2)));
                                    return tuple24;
                                }
                            }
                        }
                        throw new MatchError((Object)tuple2);
                    }
                }, x$5 = qual$1.treeAggregate$default$4(x$2), ClassTag$.MODULE$.apply(Tuple2.class));
                if (tuple2 != null) {
                    Tuple2 tuple22;
                    DenseVector gradientSum = (DenseVector)tuple2._1();
                    double lossSum = tuple2._2$mcD$sp();
                    // Decompiler artifact: the (gradientSum, lossSum) pair is re-tupled and re-read.
                    Tuple2 tuple23 = tuple22 = new Tuple2((Object)gradientSum, (Object)BoxesRunTime.boxToDouble((double)lossSum));
                    DenseVector gradientSum2 = (DenseVector)tuple23._1();
                    double lossSum2 = tuple23._2$mcD$sp();
                    // Record this iteration's stochastic loss: the averaged data loss plus
                    // the regularization value from the previous update (or the initial one).
                    this.stochasticLossHistory$1.append((Seq)Predef$.MODULE$.wrapDoubleArray(new double[]{lossSum2 / this.miniBatchSize$1 + this.regVal$1.elem}));
                    // Step: pass the gradient averaged over the expected mini-batch size;
                    // the updater returns (new weights, new regularization value).
                    Tuple2<Vector, Object> update = this.updater$1.compute((Vector)this.weights$1.elem, Vectors$.MODULE$.fromBreeze((breeze.linalg.Vector<Object>)((breeze.linalg.Vector)gradientSum2.$div((Object)BoxesRunTime.boxToDouble((double)this.miniBatchSize$1), DenseVector$.MODULE$.dv_s_Op_Double_OpDiv()))), this.stepSize$1, i, this.regParam$1);
                    this.weights$1.elem = (Vector)update._1();
                    this.regVal$1.elem = update._2$mcD$sp();
                    return;
                }
                throw new MatchError((Object)tuple2);
            }
            {
                this.data$1 = data$1;
                this.gradient$1 = gradient$1;
                this.updater$1 = updater$1;
                this.stepSize$1 = stepSize$1;
                this.regParam$1 = regParam$1;
                this.miniBatchFraction$1 = miniBatchFraction$1;
                this.stochasticLossHistory$1 = stochasticLossHistory$1;
                this.miniBatchSize$1 = miniBatchSize$1;
                this.weights$1 = weights$1;
                this.n$1 = n$1;
                this.regVal$1 = regVal$1;
            }
        });
        // Log at most the last 10 recorded losses.
        this.logInfo((Function0<String>)new Serializable(stochasticLossHistory){
            public static final long serialVersionUID = 0L;
            private final ArrayBuffer stochasticLossHistory$1;

            public final String apply() {
                return new StringOps(Predef$.MODULE$.augmentString("GradientDescent.runMiniBatchSGD finished. Last 10 stochastic losses %s")).format((Seq)Predef$.MODULE$.genericWrapArray((Object)new Object[]{((TraversableOnce)this.stochasticLossHistory$1.takeRight(10)).mkString(", ")}));
            }
            {
                this.stochasticLossHistory$1 = stochasticLossHistory$1;
            }
        });
        return new Tuple2((Object)((Vector)weights.elem), stochasticLossHistory.toArray(ClassTag$.MODULE$.Double()));
    }

    // Java serialization hook: a deserialized copy resolves back to the singleton.
    private Object readResolve() {
        return MODULE$;
    }

    // Publishes the singleton, then runs the Logging trait's initializer.
    private GradientDescent$() {
        MODULE$ = this;
        Logging.class.$init$((Logging)this);
    }
}

