package ml.dmlc.xgboost4j.java;

import hex.tree.xgboost.XGBoost;
import hex.tree.xgboost.XGBoostExtension;
import hex.tree.xgboost.XGBoostModel;
import hex.tree.xgboost.XGBoostOutput;
import hex.tree.xgboost.XGBoostUtils;
import java.io.ByteArrayInputStream;
import java.io.Closeable;
import java.io.File;
import java.io.FileOutputStream;
import java.io.IOException;
import java.nio.charset.StandardCharsets;
import java.util.HashMap;
import java.util.Map;
import water.ExtensionManager;
import water.H2O;
import water.MRTask;
import water.util.FileUtils;
import water.util.IcedHashMapGeneric;
import water.util.Log;

/* loaded from: input_file:ml/dmlc/xgboost4j/java/XGBoostUpdateTask.class */
public class XGBoostUpdateTask extends MRTask<XGBoostUpdateTask> {
    private final XGBoostModelInfo _sharedModel;
    private final XGBoostOutput _output;
    private transient Booster _booster;
    private byte[] _rawBooster;
    private final XGBoostModel.XGBoostParameters _parms;
    private final int _tid;
    private IcedHashMapGeneric.IcedHashMapStringString rabitEnv = new IcedHashMapGeneric.IcedHashMapStringString();
    private String[] _featureMap;

    public XGBoostUpdateTask(Booster booster, XGBoostModelInfo xGBoostModelInfo, XGBoostOutput xGBoostOutput, XGBoostModel.XGBoostParameters xGBoostParameters, int i, Map<String, String> map, String[] strArr) {
        this._sharedModel = xGBoostModelInfo;
        this._output = xGBoostOutput;
        this._parms = xGBoostParameters;
        this._tid = i;
        this._featureMap = strArr;
        this._rawBooster = XGBoost.getRawArray(booster);
        this.rabitEnv.putAll(map);
    }

    protected void setupLocal() {
        if (H2O.ARGS.client) {
            return;
        }
        if (!ExtensionManager.getInstance().isCoreExtensionEnabled(XGBoostExtension.NAME)) {
            throw new IllegalStateException("XGBoost is not available on the node " + H2O.SELF);
        }
        try {
            update();
        } catch (XGBoostError e) {
            try {
                Rabit.shutdown();
            } catch (XGBoostError e2) {
                e2.printStackTrace();
            }
            e.printStackTrace();
            throw new IllegalStateException("Failed XGBoost training.", e);
        }
    }

    private void update() throws XGBoostError {
        HashMap<String, Object> createParams = XGBoostModel.createParams(this._parms, this._output);
        this.rabitEnv.put("DMLC_TASK_ID", String.valueOf(H2O.SELF.index()));
        DMatrix convertFrameToDMatrix = XGBoostUtils.convertFrameToDMatrix(this._sharedModel._dataInfoKey, this._fr, true, this._parms._response_column, this._parms._weights_column, this._parms._fold_column, this._featureMap, this._output._sparse);
        if (null == convertFrameToDMatrix) {
            return;
        }
        try {
            Rabit.init(this.rabitEnv);
            if (this._rawBooster == null) {
                this._booster = XGBoost.train(convertFrameToDMatrix, createParams, 0, new HashMap(), (IObjective) null, (IEvaluation) null);
            } else {
                try {
                    this._booster = Booster.loadModel(new ByteArrayInputStream(this._rawBooster));
                    this._booster.setParams(createParams);
                    this._booster.update(convertFrameToDMatrix, this._tid);
                } catch (IOException e) {
                    e.printStackTrace();
                    throw new IllegalStateException("Failed to load the booster.", e);
                }
            }
            this._rawBooster = this._booster.toByteArray();
            try {
                Rabit.shutdown();
            } catch (XGBoostError e2) {
                Log.debug(new Object[]{"Rabit shutdown during update failed", e2});
            }
        } catch (Throwable th) {
            try {
                Rabit.shutdown();
            } catch (XGBoostError e3) {
                Log.debug(new Object[]{"Rabit shutdown during update failed", e3});
            }
            throw th;
        }
    }

    public void reduce(XGBoostUpdateTask xGBoostUpdateTask) {
        if (null == this._rawBooster) {
            this._rawBooster = xGBoostUpdateTask._rawBooster;
            this._featureMap = xGBoostUpdateTask._featureMap;
        }
    }

    private void updateFeatureMapFile(File file) {
        FileOutputStream fileOutputStream = null;
        try {
            try {
                fileOutputStream = new FileOutputStream(file);
                fileOutputStream.write(this._featureMap[0].getBytes());
                fileOutputStream.close();
                FileUtils.close(new Closeable[]{fileOutputStream});
            } catch (IOException e) {
                H2O.fail("Cannot generate " + file, e);
                FileUtils.close(new Closeable[]{fileOutputStream});
            }
        } catch (Throwable th) {
            FileUtils.close(new Closeable[]{fileOutputStream});
            throw th;
        }
    }

    public Booster getBooster() {
        return getBooster(null);
    }

    public Booster getBooster(File file) {
        if (null == this._booster) {
            try {
                this._booster = Booster.loadModel(new ByteArrayInputStream(this._rawBooster));
            } catch (XGBoostError | IOException e) {
                throw new IllegalStateException("Failed to load the booster.", e);
            }
        }
        if (file != null) {
            updateFeatureMapFile(file);
        }
        return this._booster;
    }
}
