package ai.h2o.targetencoding.strategy;

import ai.h2o.targetencoding.TargetEncoderModel;
import hex.Model;
import hex.grid.Grid;
import hex.grid.GridSearch;
import hex.grid.HyperSpaceSearchCriteria;
import hex.grid.HyperSpaceWalker;
import java.util.Arrays;
import java.util.HashMap;
import java.util.Map;
import java.util.Objects;
import org.junit.Assert;
import org.junit.BeforeClass;
import org.junit.Test;
import org.junit.experimental.runners.Enclosed;
import org.junit.runner.RunWith;
import org.junit.runners.Parameterized;
import water.Job;
import water.Key;
import water.Scope;
import water.TestUtil;
import water.api.GridSearchHandler;
import water.fvec.Frame;

@RunWith(Enclosed.class)
/* Tests for random discrete grid search over target-encoder hyper-parameters; the nested test classes are executed via the Enclosed runner. */
public class TargetEncoderRGSTest {

    /**
     * Verifies hyper-space iteration over target-encoder parameters using only the
     * walker's iterator, i.e. without launching a grid search.
     */
    public static class TargetEncoderRGSNonParametrizedTest extends TestUtil {

        /** Number of combinations in the grid built by {@link #hyperSpace()}: 2 * 3 * 3 * 3. */
        private static final int EXPECTED_GRID_SIZE = 2 * 3 * 3 * 3;

        @BeforeClass
        public static void setup() {
            stall_till_cloudsize(1);
        }

        /**
         * Builds the hyper-parameter grid walked by the test.
         *
         * @return mapping from hyper-parameter name to the candidate values for that parameter
         */
        private static Map<String, Object[]> hyperSpace() {
            Map<String, Object[]> grid = new HashMap<>();
            grid.put("blending", new Boolean[]{true, false});
            grid.put("noise_level", new Double[]{0.0, 0.01, 0.1});
            grid.put("k", new Double[]{1.0, 2.0, 3.0});
            grid.put("f", new Double[]{5.0, 10.0, 20.0});
            return grid;
        }

        /**
         * Walks the whole hyper-space with a random discrete walker, prints every generated
         * parameter combination and asserts that exactly {@link #EXPECTED_GRID_SIZE}
         * combinations were produced.
         */
        @Test
        public void getTargetEncodingMapByTrainingTEBuilder() {
            Scope.enter();
            try {
                // The builder factory is left raw because its schema type argument is not needed here.
                // NOTE(review): it could be parameterized with the target-encoder schema type - confirm.
                @SuppressWarnings({"unchecked", "rawtypes"})
                HyperSpaceWalker.RandomDiscreteValueWalker<TargetEncoderModel.TargetEncoderParameters> walker =
                        new HyperSpaceWalker.RandomDiscreteValueWalker<TargetEncoderModel.TargetEncoderParameters>(
                                new TargetEncoderModel.TargetEncoderParameters(),
                                hyperSpace(),
                                new GridSearchHandler.DefaultModelParametersBuilderFactory(),
                                new HyperSpaceSearchCriteria.RandomDiscreteValueSearchCriteria());
                HyperSpaceWalker.HyperSpaceIterator<TargetEncoderModel.TargetEncoderParameters> iterator = walker.iterator();

                int visited = 0;
                // No previous model exists in this test, hence the typed null argument.
                while (iterator.hasNext((Model) null)) {
                    TargetEncoderModel.TargetEncoderParameters params = iterator.nextModelParameters((Model) null);
                    System.out.println(params._blending + ":" + params._noise_level + ":" + params._k + ":" + params._f);
                    visited++;
                }
                Assert.assertEquals("Unexpected number of grid items", EXPECTED_GRID_SIZE, visited);
            } finally {
                // Single cleanup point for both the success and the failure path.
                Scope.exit();
            }
        }
    }

    /**
     * Runs a random discrete grid search over target-encoder parameters on the titanic
     * dataset, once for every parallelism level supplied by {@link #parallelism()}.
     */
    @RunWith(Parameterized.class)
    public static class TargetEncoderRGSParametrizedTest extends TestUtil {

        /** Number of combinations in the grid built by {@link #hyperSpace()}: 2 * 3 * 3 * 3. */
        private static final int EXPECTED_GRID_SIZE = 2 * 3 * 3 * 3;

        /** Parallelism value passed to the grid search; injected by the Parameterized runner. */
        @Parameterized.Parameter
        public int parallelism;

        @BeforeClass
        public static void setup() {
            stall_till_cloudsize(1);
        }

        /**
         * Supplies the parallelism levels the test is executed with.
         *
         * @return one entry per test run
         */
        @Parameterized.Parameters(name = "RGS over TE parameters: parallelism = {0}")
        public static Object[] parallelism() {
            return new Integer[]{1, 2, 4};
        }

        /**
         * Builds the hyper-parameter grid searched by the test.
         *
         * @return mapping from hyper-parameter name to the candidate values for that parameter
         */
        private static Map<String, Object[]> hyperSpace() {
            Map<String, Object[]> grid = new HashMap<>();
            grid.put("blending", new Boolean[]{true, false});
            grid.put("noise_level", new Double[]{0.0, 0.01, 0.1});
            grid.put("k", new Double[]{1.0, 2.0, 3.0});
            grid.put("f", new Double[]{5.0, 10.0, 20.0});
            return grid;
        }

        /**
         * Starts a grid search with the configured parallelism and asserts that the resulting
         * grid holds {@link #EXPECTED_GRID_SIZE} models, none of which is {@code null}.
         */
        @Test
        public void regularGSOverTEParameters_parallel() {
            Scope.enter();
            try {
                Frame trainingFrame = parse_test_file("./smalldata/gbm_test/titanic.csv");
                Scope.track(trainingFrame);
                asFactor(trainingFrame, "survived");

                TargetEncoderModel.TargetEncoderParameters teParameters = new TargetEncoderModel.TargetEncoderParameters();
                teParameters._train = trainingFrame._key;
                teParameters._response_column = "survived";
                teParameters._ignored_columns =
                        ignoredColumns(trainingFrame, "home.dest", "embarked", teParameters._response_column);

                // The builder factory is left raw because its schema type argument is not needed here.
                // NOTE(review): it could be parameterized with the target-encoder schema type - confirm.
                @SuppressWarnings({"unchecked", "rawtypes"})
                HyperSpaceWalker.RandomDiscreteValueWalker<TargetEncoderModel.TargetEncoderParameters> walker =
                        new HyperSpaceWalker.RandomDiscreteValueWalker<TargetEncoderModel.TargetEncoderParameters>(
                                teParameters,
                                hyperSpace(),
                                new GridSearchHandler.DefaultModelParametersBuilderFactory(),
                                new HyperSpaceSearchCriteria.RandomDiscreteValueSearchCriteria());

                Job<Grid> gridSearchJob = GridSearch.startGridSearch(Key.make(), walker, this.parallelism);
                Scope.track_generic(gridSearchJob);
                Grid grid = gridSearchJob.get();
                Scope.track_generic(grid);

                Assert.assertEquals(EXPECTED_GRID_SIZE, grid.getModelCount());
                Assert.assertTrue(Arrays.stream(grid.getModels()).allMatch(Objects::nonNull));
            } finally {
                // Single cleanup point for both the success and the failure path.
                Scope.exit();
            }
        }
    }
}
