/*
 * Decompiled with CFR 0.152.
 */
package org.apache.spark.ml.optim.aggregator;

import java.io.Serializable;
import java.util.Arrays;
import org.apache.spark.broadcast.Broadcast;
import org.apache.spark.ml.feature.InstanceBlock;
import org.apache.spark.ml.linalg.BLAS$;
import org.apache.spark.ml.linalg.DenseVector;
import org.apache.spark.ml.linalg.DenseVector$;
import org.apache.spark.ml.linalg.Vector;
import org.apache.spark.ml.optim.aggregator.DifferentiableLossAggregator;
import scala.Array$;
import scala.Function0;
import scala.Function1;
import scala.Option;
import scala.Predef$;
import scala.collection.mutable.ArrayOps;
import scala.math.package$;
import scala.reflect.ClassTag$;
import scala.reflect.ScalaSignature;
import scala.runtime.java8.JFunction1;

@ScalaSignature(bytes="\u0006\u0001\t4Q!\u0004\b\u0001%iA\u0001\u0002\f\u0001\u0003\u0002\u0003\u0006IA\f\u0005\tu\u0001\u0011\t\u0011)A\u0005w!Aa\b\u0001B\u0001B\u0003%q\bC\u0003G\u0001\u0011\u0005q\tC\u0004M\u0001\t\u0007I\u0011K'\t\rE\u0003\u0001\u0015!\u0003O\u0011\u001d\u0011\u0006A1A\u0005\n5Caa\u0015\u0001!\u0002\u0013q\u0005\u0002\u0003+\u0001\u0011\u000b\u0007I\u0011B+\t\u000fi\u0003!\u0019!C\u00057\"1A\f\u0001Q\u0001\n]BQ!\u0018\u0001\u0005\u0002y\u0013!#\u0011$U\u00052|7m[!hOJ,w-\u0019;pe*\u0011q\u0002E\u0001\u000bC\u001e<'/Z4bi>\u0014(BA\t\u0013\u0003\u0015y\u0007\u000f^5n\u0015\t\u0019B#\u0001\u0002nY*\u0011QCF\u0001\u0006gB\f'o\u001b\u0006\u0003/a\ta!\u00199bG\",'\"A\r\u0002\u0007=\u0014xmE\u0002\u00017\u0005\u0002\"\u0001H\u0010\u000e\u0003uQ\u0011AH\u0001\u0006g\u000e\fG.Y\u0005\u0003Au\u0011a!\u00118z%\u00164\u0007\u0003\u0002\u0012$K-j\u0011AD\u0005\u0003I9\u0011A\u0004R5gM\u0016\u0014XM\u001c;jC\ndW\rT8tg\u0006;wM]3hCR|'\u000f\u0005\u0002'S5\tqE\u0003\u0002)%\u00059a-Z1ukJ,\u0017B\u0001\u0016(\u00055Ien\u001d;b]\u000e,'\t\\8dWB\u0011!\u0005A\u0001\rE\u000e\u001c6-\u00197fI6+\u0017M\\\u0002\u0001!\ry#\u0007N\u0007\u0002a)\u0011\u0011\u0007F\u0001\nEJ|\u0017\rZ2bgRL!a\r\u0019\u0003\u0013\t\u0013x.\u00193dCN$\bc\u0001\u000f6o%\u0011a'\b\u0002\u0006\u0003J\u0014\u0018-\u001f\t\u00039aJ!!O\u000f\u0003\r\u0011{WO\u00197f\u000311\u0017\u000e^%oi\u0016\u00148-\u001a9u!\taB(\u0003\u0002>;\t9!i\\8mK\u0006t\u0017A\u00042d\u0007>,gMZ5dS\u0016tGo\u001d\t\u0004_I\u0002\u0005CA!E\u001b\u0005\u0011%BA\"\u0013\u0003\u0019a\u0017N\\1mO&\u0011QI\u0011\u0002\u0007-\u0016\u001cGo\u001c:\u0002\rqJg.\u001b;?)\rA%j\u0013\u000b\u0003W%CQA\u0010\u0003A\u0002}BQ\u0001\f\u0003A\u00029BQA\u000f\u0003A\u0002m\n1\u0001Z5n+\u0005q\u0005C\u0001\u000fP\u0013\t\u0001VDA\u0002J]R\fA\u0001Z5nA\u0005Ya.^7GK\u0006$XO]3t\u00031qW/\u001c$fCR,(/Z:!\u0003E\u0019w.\u001a4gS\u000eLWM\u001c;t\u0003J\u0014\u0018-_\u000b\u0002i!\u0012\u0011b\u0016\t\u00039aK!!W\u000f\u0003\u0013Q\u0014\u0018M\\
:jK:$\u0018\u0001D7be\u001eLgn\u00144gg\u0016$X#A\u001c\u0002\u001b5\f'oZ5o\u001f\u001a47/\u001a;!\u0003\r\tG\r\u001a\u000b\u0003?\u0002l\u0011\u0001\u0001\u0005\u0006C2\u0001\r!J\u0001\u0006E2|7m\u001b")
/*
 * Block-wise loss/gradient aggregator for AFT survival regression (presumably "accelerated
 * failure time"; labels are lifetimes, see add). This file was decompiled from Scala by CFR,
 * so Scala-compiler artifacts are kept verbatim: lazy-val bitmap fields, trait "$init$" and
 * "...$" static forwarders, and "x$1" parameter names.
 *
 * Coefficient layout as read by this class (dim = bcCoefficients.value().size()):
 *   [0, numFeatures)  feature coefficients, with numFeatures = dim - 2 (always, see ctor)
 *   dim - 2           intercept slot (only read/updated when fitIntercept is true)
 *   dim - 1           log(sigma); the scale is sigma = exp(coefficientsArray[dim - 1])
 *
 * For row i of a block, with label t_i (> 0), per-row weight delta_i and margin m_i:
 *   eps_i  = (log(t_i) - m_i) / sigma
 *   loss_i = delta_i * log(sigma) - delta_i * eps_i + exp(eps_i)
 * NOTE(review): this looks like the Weibull AFT negative log-likelihood (up to constants),
 * with the instance weight carrying the censoring indicator delta_i -- confirm with callers.
 *
 * When fitIntercept is true the margin is taken against mean-centered features:
 *   m_i = x_i . w + coefficientsArray[dim - 2] - w . scaledMean,  scaledMean = bcScaledMean.value()
 * (scaledMean is presumably the feature means divided by feature std devs -- TODO confirm).
 */
public class AFTBlockAggregator
implements DifferentiableLossAggregator<InstanceBlock, AFTBlockAggregator> {
    // Cached backing array of bcCoefficients; transient so it is re-extracted lazily after deserialization.
    private transient double[] coefficientsArray;
    // Broadcast per-feature offsets used to center the margins when fitIntercept is true.
    private final Broadcast<double[]> bcScaledMean;
    private final boolean fitIntercept;
    // Broadcast coefficients; must hold a DenseVector (checked in coefficientsArray$lzycompute).
    private final Broadcast<Vector> bcCoefficients;
    // Total coefficient count: numFeatures + intercept slot + log(sigma) slot.
    private final int dim;
    private final int numFeatures;
    // Constant margin term when fitIntercept; NaN otherwise (never read in that case).
    private final double marginOffset;
    private double weightSum;
    private double lossSum;
    private double[] gradientSumArray;
    // Lazy-init flag for coefficientsArray; transient so a deserialized copy starts un-initialized.
    private volatile transient boolean bitmap$trans$0;
    // Lazy-init flag for gradientSumArray.
    private volatile boolean bitmap$0;

    /**
     * Combines another aggregator into this one by delegating to the trait's default
     * implementation (not visible here; presumably sums the loss/weight/gradient accumulators).
     */
    @Override
    public DifferentiableLossAggregator merge(DifferentiableLossAggregator other) {
        return DifferentiableLossAggregator.merge$(this, other);
    }

    /** Delegates to the trait's default gradient computation (presumably gradientSumArray scaled by weightSum). */
    @Override
    public Vector gradient() {
        return DifferentiableLossAggregator.gradient$(this);
    }

    /** Delegates to the trait's default weight accessor. */
    @Override
    public double weight() {
        return DifferentiableLossAggregator.weight$(this);
    }

    /** Delegates to the trait's default loss computation (presumably lossSum scaled by weightSum). */
    @Override
    public double loss() {
        return DifferentiableLossAggregator.loss$(this);
    }

    /** Scala-style getter for the accumulated weight; in this class it counts rows (see add). */
    @Override
    public double weightSum() {
        return this.weightSum;
    }

    /** Scala-style setter ("weightSum_=") for the accumulated weight. */
    @Override
    public void weightSum_$eq(double x$1) {
        this.weightSum = x$1;
    }

    /** Scala-style getter for the accumulated (unnormalized) loss. */
    @Override
    public double lossSum() {
        return this.lossSum;
    }

    /** Scala-style setter ("lossSum_=") for the accumulated loss. */
    @Override
    public void lossSum_$eq(double x$1) {
        this.lossSum = x$1;
    }

    /**
     * Initializes gradientSumArray exactly once, under this object's monitor, using the trait's
     * default initializer (presumably a zero array of length dim()), then sets bitmap$0.
     */
    private double[] gradientSumArray$lzycompute() {
        AFTBlockAggregator aFTBlockAggregator = this;
        synchronized (aFTBlockAggregator) {
            if (!this.bitmap$0) {
                this.gradientSumArray = DifferentiableLossAggregator.gradientSumArray$(this);
                this.bitmap$0 = true;
            }
        }
        return this.gradientSumArray;
    }

    /**
     * Returns the gradient accumulator, initializing it on first access. Once the volatile
     * bitmap$0 is set, this returns the field without taking the lock.
     */
    @Override
    public double[] gradientSumArray() {
        return !this.bitmap$0 ? this.gradientSumArray$lzycompute() : this.gradientSumArray;
    }

    /** Total number of coefficients (features + intercept slot + log(sigma) slot). */
    @Override
    public int dim() {
        return this.dim;
    }

    /** Number of feature coefficients, always dim() - 2. */
    private int numFeatures() {
        return this.numFeatures;
    }

    /**
     * Extracts the values array of the broadcast DenseVector once, under this object's monitor,
     * then sets bitmap$trans$0. The array is returned as produced by DenseVector.unapply,
     * presumably the vector's own backing storage, so it must not be mutated.
     *
     * @throws IllegalArgumentException if the broadcast coefficients are not a DenseVector
     */
    private double[] coefficientsArray$lzycompute() {
        AFTBlockAggregator aFTBlockAggregator = this;
        synchronized (aFTBlockAggregator) {
            if (!this.bitmap$trans$0) {
                // values / denseVector / option are decompiler residue of a Scala pattern match.
                double[] values;
                DenseVector denseVector;
                Option option;
                Vector vector = (Vector)this.bcCoefficients.value();
                if (!(vector instanceof DenseVector) || (option = DenseVector$.MODULE$.unapply(denseVector = (DenseVector)vector)).isEmpty()) {
                    throw new IllegalArgumentException(new StringBuilder(54).append("coefficients only supports dense vector").append(" but got type ").append(this.bcCoefficients.value().getClass()).append(".").toString());
                }
                double[] dArray = values = (double[])option.get();
                this.coefficientsArray = dArray;
                this.bitmap$trans$0 = true;
            }
        }
        return this.coefficientsArray;
    }

    /** Returns the coefficient values, extracting them from the broadcast on first access. */
    private double[] coefficientsArray() {
        return !this.bitmap$trans$0 ? this.coefficientsArray$lzycompute() : this.coefficientsArray;
    }

    /**
     * Constant margin term shared by all rows when fitIntercept:
     * intercept - dot(feature coefficients, scaledMean). NaN when fitIntercept is false.
     */
    private double marginOffset() {
        return this.marginOffset;
    }

    /**
     * Adds one block of instances to the loss and gradient accumulators.
     *
     * The block's labels are lifetimes (must be > 0). Its per-row weights are used as delta in
     * the loss (presumably the censoring indicator -- confirm). weightSum grows by the number
     * of rows, not by the sum of weights.
     *
     * @param block instances to aggregate; its matrix must report isTransposed() and have
     *              numFeatures() columns
     * @return this aggregator, for chaining
     * @throws IllegalArgumentException (via Scala's Predef.require) if a precondition fails
     */
    @Override
    public AFTBlockAggregator add(InstanceBlock block) {
        // Preconditions: transposed storage, matching feature count, and strictly positive labels
        // (log(ti) is taken below).
        Predef$.MODULE$.require(block.matrix().isTransposed());
        Predef$.MODULE$.require(this.numFeatures() == block.numFeatures(), (Function0 & Serializable & scala.Serializable)() -> new StringBuilder(66).append("Dimensions mismatch when adding new ").append("instance. Expecting ").append(this.numFeatures()).append(" but got ").append(block.numFeatures()).append(".").toString());
        Predef$.MODULE$.require(new ArrayOps.ofDouble(Predef$.MODULE$.doubleArrayOps(block.labels())).forall((Function1)(JFunction1.mcZD.sp & Serializable & scala.Serializable)x$1 -> x$1 > 0.0), (Function0 & Serializable & scala.Serializable)() -> "The lifetime or label should be greater than 0.");
        int size = block.size();
        // The last coefficient is log(sigma); exponentiating keeps sigma > 0.
        double sigma = package$.MODULE$.exp(this.coefficientsArray()[this.dim() - 1]);
        // arr first holds per-row margins. It starts at zero, or at the centered intercept
        // offset when fitIntercept.
        double[] arr = (double[])Array$.MODULE$.ofDim(size, ClassTag$.MODULE$.Double());
        if (this.fitIntercept) {
            Arrays.fill(arr, this.marginOffset());
        }
        // Presumably arr := matrix * coefficients + arr (alpha = 1, beta = 1).
        // NOTE(review): coefficientsArray has dim = numFeatures + 2 entries. This relies on
        // gemv reading only the first numFeatures of them -- confirm in BLAS$.
        BLAS$.MODULE$.gemv(1.0, block.matrix(), this.coefficientsArray(), 1.0, arr);
        double localLossSum = 0.0;
        double sigmaGradSum = 0.0;
        double multiplierSum = 0.0;
        // Per row: accumulate the loss, and overwrite arr[i] in place with
        // multiplier = dLoss_i/dMargin_i = (delta - exp(eps)) / sigma. arr is then reused as the
        // gradient-scale vector.
        // sigmaGradSum accumulates dLoss_i/d(log sigma) = delta + multiplier * sigma * eps.
        for (int i = 0; i < size; ++i) {
            double multiplier;
            double ti = block.getLabel(i);
            double delta = block.getWeight().apply$mcDI$sp(i);
            double margin = arr[i];
            double epsilon = (package$.MODULE$.log(ti) - margin) / sigma;
            double expEpsilon = package$.MODULE$.exp(epsilon);
            localLossSum += delta * package$.MODULE$.log(sigma) - delta * epsilon + expEpsilon;
            arr[i] = multiplier = (delta - expEpsilon) / sigma;
            multiplierSum += multiplier;
            sigmaGradSum += delta + multiplier * sigma * epsilon;
        }
        this.lossSum_$eq(this.lossSum() + localLossSum);
        // Counts rows, not the sum of the per-row weights (those were used as delta above).
        this.weightSum_$eq(this.weightSum() + (double)size);
        // Feature gradient: gradientSumArray += matrix^T * multipliers. This presumably touches
        // only the first numFeatures entries of the dim-length accumulator -- confirm.
        BLAS$.MODULE$.gemv(1.0, block.matrix().transpose(), arr, 1.0, this.gradientSumArray());
        if (this.fitIntercept) {
            // Margins used centered features, so dMargin/dw_j = x_ij - scaledMean_j. The gemv
            // above added only the x_ij part, so subtract multiplierSum * scaledMean_j here.
            BLAS$.MODULE$.javaBLAS().daxpy(this.numFeatures(), -multiplierSum, (double[])this.bcScaledMean.value(), 1, this.gradientSumArray(), 1);
            // dMargin/dIntercept = 1, so the intercept slot receives the sum of multipliers.
            int n = this.dim() - 2;
            this.gradientSumArray()[n] = this.gradientSumArray()[n] + multiplierSum;
        }
        // Slot dim - 1 is log(sigma).
        int n = this.dim() - 1;
        this.gradientSumArray()[n] = this.gradientSumArray()[n] + sigmaGradSum;
        return this;
    }

    /**
     * Creates an empty aggregator for the given broadcast coefficients.
     *
     * Statement order matters here:
     *   1. The trait initializer runs first (presumably resetting lossSum/weightSum).
     *   2. dim is set before numFeatures.
     *   3. marginOffset forces the lazy coefficientsArray, so a non-dense coefficient vector
     *      fails here when fitIntercept is true.
     *
     * @param bcScaledMean   per-feature offsets used to center margins (length >= numFeatures assumed)
     * @param fitIntercept   whether the intercept slot (dim - 2) is used
     * @param bcCoefficients coefficients laid out as [features..., intercept, log(sigma)]
     */
    public AFTBlockAggregator(Broadcast<double[]> bcScaledMean, boolean fitIntercept, Broadcast<Vector> bcCoefficients) {
        this.bcScaledMean = bcScaledMean;
        this.fitIntercept = fitIntercept;
        this.bcCoefficients = bcCoefficients;
        DifferentiableLossAggregator.$init$(this);
        this.dim = ((Vector)bcCoefficients.value()).size();
        // The intercept and log(sigma) slots are always reserved, even without fitIntercept.
        this.numFeatures = this.dim() - 2;
        // Centered offset: intercept - dot(feature coefficients, scaledMean); NaN when unused.
        this.marginOffset = fitIntercept ? this.coefficientsArray()[this.dim() - 2] - BLAS$.MODULE$.getBLAS(this.numFeatures()).ddot(this.numFeatures(), this.coefficientsArray(), 1, (double[])bcScaledMean.value(), 1) : Double.NaN;
    }
}

