package ai.h2o.targetencoding;

import ai.h2o.targetencoding.TargetEncoder;
import ai.h2o.targetencoding.TargetEncoderModel;
import org.apache.commons.lang.ArrayUtils;
import org.junit.Assert;
import org.junit.Before;
import org.junit.Test;
import water.Key;
import water.Scope;
import water.TestUtil;
import water.fvec.Frame;

/**
 * Tests for {@link TargetEncoderModel}: trains a model with {@link TargetEncoderBuilder} on the
 * airlines training set and, where applicable, scores the airlines test set with it.
 *
 * <p>Every test runs inside an H2O {@link Scope}. Parsed frames, the trained model and the scored
 * frame are tracked in that scope, and the scope is always exited in a {@code finally} block. This
 * means cleanup happens exactly once, whether the test passes or fails.
 */
public class TargetEncoderModelTest extends TestUtil {

    private static final String TRAIN_FILE = "./smalldata/testng/airlines_train.csv";
    private static final String TEST_FILE = "./smalldata/testng/airlines_test.csv";
    private static final String RESPONSE_COLUMN = "IsDepDelayed";
    /** The only predictor left un-ignored in the blending tests. */
    private static final String ENCODED_COLUMN = "Origin";
    private static final long SEED = 65261L;

    @Before
    public void setUp() {
        TestUtil.stall_till_cloudsize(1);
    }

    /** Trains with explicit blending parameters (k=0.3, f=0.7) and checks that "Origin_te" is added on scoring. */
    @Test
    public void testTargetEncoderModel() {
        // Enter before the try: if enter() itself fails there is no scope to exit.
        Scope.enter();
        try {
            Frame train = parseTracked(TRAIN_FILE);
            Frame test = parseTracked(TEST_FILE);

            TargetEncoderModel.TargetEncoderParameters params = originOnlyParameters(train);
            params._k = 0.3d;
            params._f = 0.7d;

            TargetEncoderModel model = trainTracked(params);
            assertOriginEncodingAppended(model, train, test, params);
        } finally {
            Scope.exit();
        }
    }

    /** Trains with blending enabled but without setting k/f, relying on their default values. */
    @Test
    public void testTargetEncoderModel_noBlendingParameters() {
        Scope.enter();
        try {
            Frame train = parseTracked(TRAIN_FILE);
            Frame test = parseTracked(TEST_FILE);

            TargetEncoderModel.TargetEncoderParameters params = originOnlyParameters(train);

            TargetEncoderModel model = trainTracked(params);
            assertOriginEncodingAppended(model, train, test, params);
        } finally {
            Scope.exit();
        }
    }

    /** With no ignored columns, checks the exact column names recorded in the model output. */
    @Test
    public void testTargetEncoderModel_dropNonCategoricalCols() {
        Scope.enter();
        try {
            Frame train = parseTracked(TRAIN_FILE);

            TargetEncoderModel.TargetEncoderParameters params = new TargetEncoderModel.TargetEncoderParameters();
            params._data_leakage_handling = TargetEncoder.DataLeakageHandlingStrategy.None;
            params._response_column = RESPONSE_COLUMN;
            params._ignored_columns = null;
            params._train = train._key;
            params._seed = SEED;

            TargetEncoderModel model = trainTracked(params);
            Assert.assertArrayEquals(
                    new String[]{"fYear", "fMonth", "fDayofMonth", "fDayOfWeek", "UniqueCarrier", "Origin", "Dest", "IsDepDelayed"},
                    model._output._names);
        } finally {
            Scope.exit();
        }
    }

    /** Parses a test data file and registers the resulting frame with the current scope. */
    private Frame parseTracked(String path) {
        Frame frame = parse_test_file(path);
        Scope.track(frame);
        return frame;
    }

    /**
     * Builds parameters that ignore every column of {@code train} except {@link #ENCODED_COLUMN}
     * and the response. Data-leakage handling is {@code None}, blending is enabled and the seed is fixed.
     */
    private TargetEncoderModel.TargetEncoderParameters originOnlyParameters(Frame train) {
        TargetEncoderModel.TargetEncoderParameters params = new TargetEncoderModel.TargetEncoderParameters();
        params._data_leakage_handling = TargetEncoder.DataLeakageHandlingStrategy.None;
        params._blending = true;
        params._response_column = RESPONSE_COLUMN;
        params._ignored_columns = ignoredColumns(train, ENCODED_COLUMN, RESPONSE_COLUMN);
        params._train = train._key;
        params._seed = SEED;
        return params;
    }

    /** Trains a model synchronously, asserts that it exists, then tracks it for cleanup. */
    private TargetEncoderModel trainTracked(TargetEncoderModel.TargetEncoderParameters params) {
        TargetEncoderModel model = new TargetEncoderBuilder(params).trainModel().get();
        // Assert before tracking so a missing model is reported as a test failure, not a tracking error.
        Assert.assertNotNull(model);
        Scope.track_generic(model);
        return model;
    }

    /**
     * Scores {@code test} with {@code model} and checks the result.
     *
     * <p>The scored frame must have the training columns plus one extra column for each non-ignored,
     * non-response column. It must also contain a numeric {@code Origin_te} column.
     */
    private void assertOriginEncodingAppended(TargetEncoderModel model, Frame train, Frame test,
                                              TargetEncoderModel.TargetEncoderParameters params) {
        Frame scored = model.score(test);
        Assert.assertNotNull(scored);
        Scope.track(scored);

        int encodedColumns = train.numCols() - params._ignored_columns.length - 1;
        Assert.assertEquals(train.numCols() + encodedColumns, scored.numCols());

        int teIndex = ArrayUtils.indexOf(scored.names(), ENCODED_COLUMN + "_te");
        Assert.assertNotEquals("Origin_te column missing from scored frame", -1, teIndex);
        Assert.assertTrue(scored.vec(teIndex).isNumeric());
    }
}
