package hex.deeplearning;

import hex.FrameTask;
import hex.deeplearning.DeepLearningModel;
import hex.deeplearning.Neurons;
import java.util.Arrays;
import java.util.Random;
import java.util.concurrent.ThreadLocalRandom;
import water.H2O;
import water.Key;
import water.util.Log;

/* loaded from: input_file:hex/deeplearning/DeepLearningTask.class */
/**
 * Training pass of a Deep Learning model over a frame.
 *
 * <p>Lifecycle as visible here: the model state is shipped in {@code _input}; {@link #setupLocal()}
 * moves it to {@code _output}; {@link #chunkInit()} builds the network layers; every row is run
 * through {@link #step} (forward pass, then back-propagation because {@code _training} is true);
 * {@link #reduce} merges task results; {@link #postGlobal()} divides the merged model by the number
 * of merged contributions (presumably averaging — confirm against {@code DeepLearningModelInfo.div})
 * unless training data is replicated on a multi-node cloud.
 */
public class DeepLearningTask extends FrameTask<DeepLearningTask> {
    /** Set to {@code true} by every constructor in this class; gates back-propagation and row counting. */
    private final boolean _training;
    /** Model state handed to this task; moved to {@code _output} and cleared in {@link #setupLocal()}. */
    private DeepLearningModel.DeepLearningModelInfo _input;
    /** Model state being trained and, after {@link #reduce}, the merged result. */
    DeepLearningModel.DeepLearningModelInfo _output;
    /** Network layers, rebuilt by {@link #chunkInit()}; {@code transient}. */
    transient Neurons[] _neurons;
    /** Number of contributions merged by {@link #reduce}; used as the divisor in {@link #postGlobal()}. */
    int _chunk_node_count;
    /** Wall-clock time (ms) of the last "nodes not contributing" warning. */
    static long _lastWarn;
    /** Number of "nodes not contributing" warnings emitted so far; warnings stop after 3. */
    static long _warnCount;

    /** Returns the model state this task trains (its {@code _output}). */
    public final DeepLearningModel.DeepLearningModelInfo model_info() {
        return this._output;
    }

    /**
     * Creates a training task.
     *
     * @param key key forwarded to {@link FrameTask} (presumably the job key — confirm)
     * @param deepLearningModelInfo model state to train; its {@code data_info()} is handed to the frame task
     * @param f stored into {@code _useFraction} (presumably the fraction of rows to train on — confirm)
     */
    public DeepLearningTask(Key key, DeepLearningModel.DeepLearningModelInfo deepLearningModelInfo, float f) {
        this(key, deepLearningModelInfo, f, null);
    }

    private DeepLearningTask(Key key, DeepLearningModel.DeepLearningModelInfo deepLearningModelInfo, float f,
                             H2O.H2OCountedCompleter h2OCountedCompleter) {
        super(key, deepLearningModelInfo.data_info(), h2OCountedCompleter);
        this._chunk_node_count = 1;
        this._training = true;
        this._input = deepLearningModelInfo;
        this._useFraction = f;
        // Whether rows are shuffled is taken from the model's own training parameters.
        this._shuffle = this._input.get_params()._shuffle_training_data;
        assert this._output == null;
    }

    /**
     * Local setup: takes over the shipped model state and resets its local row counter.
     *
     * <p>NOTE(review): the decompiler reports this was originally {@code protected}; kept {@code public}
     * so existing callers keep compiling.
     */
    @Override
    public void setupLocal() {
        super.setupLocal();
        this._output = this._input;
        // Cleared so only _output refers to the model; postGlobal() asserts this.
        this._input = null;
        this._output.set_processed_local(0L);
    }

    /**
     * Builds a fresh set of training layers for the current model state.
     *
     * <p>NOTE(review): the decompiler reports this was originally {@code protected}; kept {@code public}
     * so existing callers keep compiling.
     */
    @Override
    public void chunkInit() {
        this._neurons = makeNeuronsForTraining(this._output);
    }

    /**
     * Trains on one row: loads it into the input layer and runs {@link #step} on it.
     *
     * @param j row identifier; combined with the global progress to form a deterministic seed when
     *          {@code _reproducible} is set (presumably the global row index — confirm)
     * @param dArr numeric values handed to the input layer
     * @param i count handed to the input layer alongside {@code iArr} (presumably number of categoricals — confirm)
     * @param iArr indices handed to the input layer (presumably categorical levels — confirm)
     * @param dArr2 response values; {@code dArr2[0]} is the target used by {@link #step}
     */
    @Override
    public final void processRow(long j, double[] dArr, int i, int[] iArr, double[] dArr2) {
        // Reproducible runs derive the seed from the row and the global progress. Otherwise, draw from
        // the thread-local generator instead of allocating and seeding a new Random for every row.
        final long seed = model_info().get_params()._reproducible
                ? j + model_info().get_processed_global()
                : ThreadLocalRandom.current().nextLong();
        ((Neurons.Input) this._neurons[0]).setInput(seed, dArr, i, iArr);
        step(seed, this._neurons, this._output, this._training, dArr2);
    }

    /**
     * Records chunk progress on the local model when training.
     *
     * @param j value added to the model's local processed counter (presumably rows in the chunk — confirm)
     */
    @Override
    protected void chunkDone(long j) {
        if (this._training) {
            this._output.add_processed_local(j);
        }
    }

    /**
     * Merges another task's result into this one.
     *
     * <p>A result that processed no rows, or that shares this task's model object, contributes nothing.
     * If this task has not processed any rows yet, the other model and its count are adopted as-is.
     * Otherwise the models are combined with {@code add} and the contribution counts are summed.
     * Numerical instability is propagated in every case.
     *
     * @param deepLearningTask the other task result
     */
    public void reduce(DeepLearningTask deepLearningTask) {
        final DeepLearningModel.DeepLearningModelInfo other = deepLearningTask._output;
        if (other.get_processed_local() > 0 && other != this._output) {
            if (this._output.get_processed_local() == 0) {
                this._output = other;
                this._chunk_node_count = deepLearningTask._chunk_node_count;
            } else {
                this._output.add(other);
                this._chunk_node_count += deepLearningTask._chunk_node_count;
            }
        }
        if (other.unstable()) {
            this._output.set_unstable();
        }
    }

    /**
     * Finalizes the merged model after all results have been reduced.
     *
     * <p>On a multi-node cloud without replicated training data, a warning is logged when fewer
     * contributions than nodes were merged. The warning is throttled to once per 5 s and at most 3 times
     * overall. Unless training data is replicated on a multi-node cloud, the model is divided by the
     * contribution count and the local row counter is folded into the global one.
     */
    protected void postGlobal() {
        final int cloudSize = H2O.CLOUD.size();
        final boolean replicated = this._output.get_params()._replicate_training_data;
        if (cloudSize > 1 && !replicated) {
            final long now = System.currentTimeMillis();
            if (this._chunk_node_count < cloudSize && now - _lastWarn > 5000 && _warnCount < 3) {
                Log.warn(new Object[]{(cloudSize - this._chunk_node_count) + " node(s) (out of " + cloudSize
                        + ") are not contributing to model updates. Consider setting replicate_training_data"
                        + " to true or using a larger training dataset (or fewer H2O nodes)."});
                _lastWarn = now;
                _warnCount++;
            }
        }
        if (!replicated || cloudSize == 1) {
            this._output.div(this._chunk_node_count);
            this._output.add_processed_global(this._output.get_processed_local());
            this._output.set_processed_local(0L);
        }
        assert this._input == null;
    }

    /** Builds the layer stack for training (passes {@code true} to each layer's {@code init}). */
    public static Neurons[] makeNeuronsForTraining(DeepLearningModel.DeepLearningModelInfo deepLearningModelInfo) {
        return makeNeurons(deepLearningModelInfo, true);
    }

    /** Builds the layer stack for scoring (passes {@code false} to each layer's {@code init}). */
    public static Neurons[] makeNeuronsForTesting(DeepLearningModel.DeepLearningModelInfo deepLearningModelInfo) {
        return makeNeurons(deepLearningModelInfo, false);
    }

    /**
     * Builds and initializes the full layer stack: input, one hidden layer per entry of
     * {@code _hidden}, and an output layer (Softmax for classification, single-unit Linear otherwise).
     *
     * @throws IllegalArgumentException if the configured activation has no hidden-layer implementation
     *         here (previously this left a {@code null} layer and failed later with an NPE)
     */
    private static Neurons[] makeNeurons(DeepLearningModel.DeepLearningModelInfo deepLearningModelInfo,
                                         boolean training) {
        final FrameTask.DataInfo dinfo = deepLearningModelInfo.data_info();
        final DeepLearningModel.DeepLearningParameters params = deepLearningModelInfo.get_params();
        final int[] hidden = params._hidden;
        final Neurons[] neurons = new Neurons[hidden.length + 2];
        neurons[0] = new Neurons.Input(dinfo.fullN(), dinfo);
        for (int h = 0; h < hidden.length; h++) {
            final int units = hidden[h];
            switch (params._activation) {
                case Tanh:
                    neurons[h + 1] = new Neurons.Tanh(units);
                    break;
                case TanhWithDropout:
                    neurons[h + 1] = new Neurons.TanhDropout(units);
                    break;
                case Rectifier:
                    neurons[h + 1] = new Neurons.Rectifier(units);
                    break;
                case RectifierWithDropout:
                    neurons[h + 1] = new Neurons.RectifierDropout(units);
                    break;
                case Maxout:
                    neurons[h + 1] = new Neurons.Maxout(units);
                    break;
                case MaxoutWithDropout:
                    neurons[h + 1] = new Neurons.MaxoutDropout(units);
                    break;
                default:
                    throw new IllegalArgumentException("Unsupported activation function: " + params._activation);
            }
        }
        final int last = neurons.length - 1;
        if (deepLearningModelInfo._classification) {
            neurons[last] = new Neurons.Softmax(deepLearningModelInfo.units[deepLearningModelInfo.units.length - 1]);
        } else {
            neurons[last] = new Neurons.Linear(1);
        }
        for (int layer = 0; layer < neurons.length; layer++) {
            neurons[layer].init(neurons, layer, params, deepLearningModelInfo, training);
        }
        return neurons;
    }

    /**
     * Runs one forward pass through the layers and, when {@code z} is true, one backward pass.
     *
     * @param j seed handed to each hidden layer's {@code fprop}
     * @param neuronsArr layer stack as built by {@code makeNeurons}: input at index 0, output last
     * @param deepLearningModelInfo model state; marked unstable if a hidden layer's forward pass fails
     * @param z whether to back-propagate
     * @param dArr responses; {@code dArr[0]} must be an integral class index for classification and is
     *             used as a float target otherwise
     * @throws RuntimeException "Canceling job due to numerical instability." when a hidden layer's
     *         forward pass throws; the original exception is attached as the cause
     */
    public static void step(long j, Neurons[] neuronsArr, DeepLearningModel.DeepLearningModelInfo deepLearningModelInfo,
                            boolean z, double[] dArr) {
        final int outIdx = neuronsArr.length - 1;
        for (int layer = 1; layer < outIdx; layer++) {
            try {
                neuronsArr[layer].fprop(j, z);
            } catch (RuntimeException e) {
                Log.warn(new Object[]{e.getMessage()});
                deepLearningModelInfo.set_unstable();
                // Chain the cause so the failing layer's stack trace is not lost.
                throw new RuntimeException("Canceling job due to numerical instability.", e);
            }
        }
        final boolean classification = deepLearningModelInfo._classification;
        if (classification) {
            ((Neurons.Softmax) neuronsArr[outIdx]).fprop();
        } else {
            ((Neurons.Linear) neuronsArr[outIdx]).fprop();
        }
        if (!z) {
            return;
        }
        // Zero the hidden layers' _e buffers before the output layer back-propagates into them
        // (presumably per-unit error/gradient accumulators — confirm in Neurons).
        for (int layer = 1; layer < outIdx; layer++) {
            Arrays.fill(neuronsArr[layer]._e.raw(), 0.0f);
        }
        if (classification) {
            assert ((int) dArr[0]) == dArr[0] : "class index must be integral: " + dArr[0];
            ((Neurons.Softmax) neuronsArr[outIdx]).bprop((int) dArr[0]);
        } else {
            ((Neurons.Linear) neuronsArr[outIdx]).bprop((float) dArr[0]);
        }
        for (int layer = outIdx - 1; layer > 0; layer--) {
            neuronsArr[layer].bprop();
        }
    }
}
