package hex.tree.xgboost;

import biz.k11i.xgboost.Predictor;
import hex.genmodel.algos.tree.SharedTreeGraph;
import hex.genmodel.algos.tree.SharedTreeNode;
import hex.genmodel.algos.tree.SharedTreeSubgraph;
import hex.tree.xgboost.XGBoostModel;
import java.io.ByteArrayInputStream;
import org.junit.Assert;
import org.junit.Assume;
import org.junit.BeforeClass;
import org.junit.Test;
import water.DKV;
import water.ExtensionManager;
import water.Key;
import water.Scope;
import water.TestUtil;
import water.fvec.Frame;

/* loaded from: input_file:hex/tree/xgboost/XGBoostTreeConverterTest.class */
/**
 * Tests that a trained XGBoost model can be converted into H2O's {@link SharedTreeGraph}
 * representation via {@code XGBoostModel.convert}.
 */
public class XGBoostTreeConverterTest extends TestUtil {

    /**
     * Waits for a single-node cloud and skips every test in this class when the XGBoost core
     * extension is not enabled (e.g. the native library could not be loaded).
     */
    @BeforeClass
    public static void stall() {
        stall_till_cloudsize(1);
        Assume.assumeTrue("XGBoost was not loaded!\nH2O XGBoost needs binary compatible environment;Make sure that you have correct libraries installedand correctly configured LD_LIBRARY_PATH, especiallymake sure that CUDA libraries are available if you are running on GPU!", ExtensionManager.getInstance().isCoreExtensionsEnabled(XGBoostExtension.NAME));
    }

    /**
     * Trains a single depth-3 tree on the weather dataset and checks the converted tree 0
     * for response class {@code "down"}.
     */
    @Test
    public void convertXGBoostTree_weather() throws Exception {
        Frame frame = null;
        XGBoostModel model = null;
        Scope.enter();
        try {
            frame = parse_test_file("./smalldata/junit/weather.csv");
            toCategoricalResponse(frame, "PressureChange");
            XGBoostModel.XGBoostParameters parms = singleTreeParameters(frame, "PressureChange", 3);
            model = (XGBoostModel) new XGBoost(parms).trainModel().get();
            SharedTreeGraph graph = convertFirstTree(model, "down");
            // Smoke call only: the returned feature map is not inspected, the test just requires
            // that building it does not throw.
            XGBoostUtils.makeFeatureMap(frame, model.model_info()._dataInfoKey.get());
            assertSingleTreeOfMaxDepth(graph, parms);
        } finally {
            releaseTestResources(frame, model);
        }
    }

    /**
     * Trains a single depth-5 tree on the airlines dataset (with several columns ignored) and
     * checks the converted tree 0 for response class {@code "NO"}.
     */
    @Test
    public void convertXGBoostTree_airlines() throws Exception {
        Frame frame = null;
        XGBoostModel model = null;
        Scope.enter();
        try {
            frame = parse_test_file("./smalldata/testng/airlines_train.csv");
            toCategoricalResponse(frame, "IsDepDelayed");
            XGBoostModel.XGBoostParameters parms = singleTreeParameters(frame, "IsDepDelayed", 5);
            parms._ignored_columns = new String[]{"fYear", "fMonth", "fDayofMonth", "fDayOfWeek", "UniqueCarrier", "Dest"};
            model = (XGBoostModel) new XGBoost(parms).trainModel().get();
            SharedTreeGraph graph = convertFirstTree(model, "NO");
            assertSingleTreeOfMaxDepth(graph, parms);
        } finally {
            releaseTestResources(frame, model);
        }
    }

    /**
     * Replaces the response column of {@code frame} with a categorical copy and publishes the
     * updated frame to the DKV. The replaced (original) vec is tracked in the current
     * {@link Scope}, so it is removed on {@code Scope.exit()}.
     *
     * @param frame    parsed training frame; modified in place
     * @param response name of the column to convert
     */
    private static void toCategoricalResponse(Frame frame, String response) {
        int responseIdx = frame.find(response);
        Scope.track(frame.replace(responseIdx, frame.vecs()[responseIdx].toCategoricalVec()));
        DKV.put(frame);
    }

    /**
     * Builds XGBoost parameters for a single tree of the given maximum depth.
     *
     * @param frame    training frame (must already be in the DKV)
     * @param response response column name
     * @param maxDepth value for {@code _max_depth}
     * @return parameters with {@code _ntrees = 1}
     */
    private static XGBoostModel.XGBoostParameters singleTreeParameters(Frame frame, String response, int maxDepth) {
        XGBoostModel.XGBoostParameters parms = new XGBoostModel.XGBoostParameters();
        parms._ntrees = 1;
        parms._max_depth = maxDepth;
        parms._train = frame._key;
        parms._response_column = response;
        return parms;
    }

    /**
     * Converts tree 0 of {@code model} for the given response class.
     *
     * <p>Before converting, the raw booster bytes are parsed with the standalone
     * {@link Predictor} and the nodes of the first tree are read. The result is discarded; this
     * only requires that the bytes can be parsed without throwing.
     *
     * @param model     trained model
     * @param treeClass response class label passed to {@code convert}
     * @return the converted graph as returned by {@code XGBoostModel.convert}
     */
    private static SharedTreeGraph convertFirstTree(XGBoostModel model, String treeClass) throws Exception {
        new Predictor(new ByteArrayInputStream(model.model_info()._boosterBytes)).getBooster().getGroupedTrees()[0][0].getNodes();
        return model.convert(0, treeClass);
    }

    /**
     * Asserts that {@code graph} exists and has one subgraph per requested tree. It also asserts
     * that the last node of the first subgraph sits at depth {@code _max_depth}.
     *
     * <p>NOTE(review): the depth check assumes the last entry of {@code nodesArray} is on the
     * deepest level, i.e. that the trained tree actually reaches {@code _max_depth}.
     */
    private static void assertSingleTreeOfMaxDepth(SharedTreeGraph graph, XGBoostModel.XGBoostParameters parms) {
        Assert.assertNotNull(graph);
        Assert.assertEquals(parms._ntrees, graph.subgraphArray.size());
        SharedTreeSubgraph tree = (SharedTreeSubgraph) graph.subgraphArray.get(0);
        SharedTreeNode lastNode = (SharedTreeNode) tree.nodesArray.get(tree.nodesArray.size() - 1);
        Assert.assertEquals(parms._max_depth, lastNode.getDepth());
    }

    /**
     * Exits the current {@link Scope}, then removes the frame and deletes the model if they
     * were created. Both arguments may be {@code null}.
     */
    private static void releaseTestResources(Frame frame, XGBoostModel model) {
        Scope.exit();
        if (frame != null) {
            frame.remove();
        }
        if (model != null) {
            model.delete();
        }
    }
}
