package ai.h2o.targetencoding;

import ai.h2o.targetencoding.BroadcastJoinForTargetEncoder;
import com.pholser.junit.quickcheck.Mode;
import com.pholser.junit.quickcheck.Property;
import com.pholser.junit.quickcheck.generator.InRange;
import com.pholser.junit.quickcheck.runner.JUnitQuickcheck;
import java.util.Random;
import org.junit.After;
import org.junit.Assert;
import org.junit.Assume;
import org.junit.BeforeClass;
import org.junit.Test;
import org.junit.runner.RunWith;
import water.Key;
import water.Scope;
import water.TestUtil;
import water.fvec.CategoricalWrappedVec;
import water.fvec.Frame;
import water.fvec.TestFrameBuilder;
import water.fvec.Vec;
import water.rapids.Merge;
import water.util.IcedHashMapGeneric;

/**
 * Tests for {@link BroadcastJoinForTargetEncoder}: joining encoding data (numerator/denominator
 * columns) onto a frame by a categorical column and an optional fold column.
 *
 * <p>Runs with {@link JUnitQuickcheck} so that {@link Property} methods receive generated arguments.
 * Column types are passed to {@link TestFrameBuilder} as raw bytes (4 and 3) — presumably
 * {@code Vec.T_CAT} and {@code Vec.T_NUM}; confirm against {@link Vec}.
 */
@RunWith(JUnitQuickcheck.class)
public class BroadcastJoinTest extends TestUtil {
    /** Frame created by some tests; {@link #afterEach()} deletes it after every test when non-null. */
    private Frame fr = null;

    /** Waits until the cloud has at least one node before any test runs. */
    @BeforeClass
    public static void setup() {
        stall_till_cloudsize(1);
    }

    /**
     * Joins a three-row frame (ColA = a, c, b) with a three-row encoding frame (ColA = a, b, c) on
     * ColA and the fold column, then checks that the left row order is kept, that rows 0 and 1
     * received the matching numerator/denominator values and that row 2 (no matching fold in the
     * encoding frame) is NA.
     *
     * @param z selects which of the two fold-value sets is used for both frames
     */
    @Property(trials = 2, mode = Mode.EXHAUSTIVE)
    public void joinPerformsWithoutLoosingOriginalOrderTest(boolean z) {
        Scope.enter();
        try {
            long[] leftFolds = z ? new long[]{1, 0, 1} : new long[]{2, 1, 2};
            Frame left = new TestFrameBuilder()
                    .withName("testFrame")
                    .withColNames(new String[]{"ColA", "fold"})
                    .withVecTypes(new byte[]{4, 3})
                    .withDataForCol(0, ar(new String[]{"a", "c", "b"}))
                    .withDataForCol(1, leftFolds)
                    .withChunkLayout(new long[]{1, 1, 1})
                    .build();

            long[] rightFolds = z ? new long[]{1, 0, 0} : new long[]{2, 1, 1};
            Frame right = new TestFrameBuilder()
                    .withName("testFrame2")
                    .withColNames(new String[]{"ColA", "fold", TargetEncoder.NUMERATOR_COL_NAME, TargetEncoder.DENOMINATOR_COL_NAME})
                    .withVecTypes(new byte[]{4, 3, 3, 3})
                    .withDataForCol(0, ar(new String[]{"a", "b", "c"}))
                    .withDataForCol(1, rightFolds)
                    .withDataForCol(2, ar(new long[]{22, 33, 42}))
                    .withDataForCol(3, ar(new long[]{44, 66, 84}))
                    .withChunkLayout(new long[]{1, 1, 1})
                    .build();

            // Placeholder columns on the left frame for the joined numerator/denominator values.
            Vec numerator = left.anyVec().makeCon(0.0d);
            left.add(TargetEncoder.NUMERATOR_COL_NAME, numerator);
            Vec denominator = left.anyVec().makeCon(0.0d);
            left.add(TargetEncoder.DENOMINATOR_COL_NAME, denominator);
            Scope.track(numerator);
            Scope.track(denominator);

            Frame joined = BroadcastJoinForTargetEncoder.join(left, new int[]{0}, 1, right, new int[]{0}, 1, 2);

            assertStringVecEquals(cvec(new String[]{"a", "c", "b"}), joined.vec("ColA"));
            Assert.assertEquals(22.0d, joined.vec(TargetEncoder.NUMERATOR_COL_NAME).at(0L), 1.0E-5d);
            Assert.assertEquals(44.0d, joined.vec(TargetEncoder.DENOMINATOR_COL_NAME).at(0L), 1.0E-5d);
            Assert.assertEquals(42.0d, joined.vec(TargetEncoder.NUMERATOR_COL_NAME).at(1L), 1.0E-5d);
            Assert.assertEquals(84.0d, joined.vec(TargetEncoder.DENOMINATOR_COL_NAME).at(1L), 1.0E-5d);
            Assert.assertTrue(joined.vec(TargetEncoder.NUMERATOR_COL_NAME).isNA(2L));
            Assert.assertTrue(joined.vec(TargetEncoder.DENOMINATOR_COL_NAME).isNA(2L));
        } finally {
            Scope.exit(new Key[0]);
        }
    }

    /**
     * Builds a random frame, prepares an encoding map for "ColA" with a k-fold column, joins the
     * encoding frame back onto the original frame and checks that (a) the "ColA" column of the join
     * result equals the one of the input frame and (b) for one pseudo-randomly picked row, the
     * numerator/denominator found for its (ColA, fold) pair in the input frame equal those found
     * in the encoding frame.
     *
     * <p>The trial is skipped (via {@link Assume}) when the generated response is not binary or the
     * fold column does not contain exactly {@code i2} distinct values.
     *
     * @param i  number of rows in the generated frame
     * @param i2 number of folds
     */
    @Property(trials = 5)
    public void joinWorksWithoutLoosingOriginalOrderTest(@InRange(minInt = 2, maxInt = 10000) int i, @InRange(minInt = 1, maxInt = 1000) int i2) {
        Scope.enter();
        IcedHashMapGeneric encodingMap = null;
        try {
            Frame left = new TestFrameBuilder()
                    .withName("testFrame")
                    .withColNames(new String[]{"ColA", "response"})
                    .withVecTypes(new byte[]{4, 4})
                    .withDataForCol(0, randomArrOfStrings(i))
                    .withRandomBinaryDataForCol(1, i, 1234L)
                    .withChunkLayout(new long[]{i / 2, i - (i / 2)})
                    .build();
            Assume.assumeTrue(left.vec("response").cardinality() == 2);

            TargetEncoderFrameHelper.addKFoldColumn(left, "fold", i2, 1234L);
            Assume.assumeTrue(left.vec("fold").clone().toCategoricalVec().cardinality() == i2);

            encodingMap = new TargetEncoder(new String[]{"ColA"}).prepareEncodingMap(left, "response", "fold");
            Frame encodings = (Frame) encodingMap.get("ColA");

            // Placeholder columns on the left frame for the joined numerator/denominator values.
            Vec numerator = left.anyVec().makeCon(0.0d);
            left.add(TargetEncoder.NUMERATOR_COL_NAME, numerator);
            Vec denominator = left.anyVec().makeCon(0.0d);
            left.add(TargetEncoder.DENOMINATOR_COL_NAME, denominator);
            Scope.track(numerator);
            Scope.track(denominator);

            Frame joined = BroadcastJoinForTargetEncoder.join(left, new int[]{0}, left.find("fold"), encodings, new int[]{0}, encodings.find("fold"), i2);
            Scope.track(new Frame[]{joined});
            assertStringVecEquals(left.vec("ColA"), joined.vec("ColA"));

            // Spot-check a single row chosen with a fixed seed.
            int rowIdx = new Random(1234L).nextInt(i);
            double category = left.vec("ColA").at(rowIdx);
            double fold = left.vec("fold").at(rowIdx);

            Frame leftByCategory = TargetEncoderFrameHelper.filterByValue(left, 0, category);
            Frame leftByCategoryAndFold = TargetEncoderFrameHelper.filterByValue(leftByCategory, leftByCategory.find("fold"), fold);
            Frame encodingsByCategory = TargetEncoderFrameHelper.filterByValue(encodings, 0, category);
            Frame encodingsByCategoryAndFold = TargetEncoderFrameHelper.filterByValue(encodingsByCategory, encodingsByCategory.find("fold"), fold);
            Scope.track(new Frame[]{leftByCategory, leftByCategoryAndFold, encodingsByCategory, encodingsByCategoryAndFold});

            Assert.assertEquals(
                    leftByCategoryAndFold.vec(TargetEncoder.NUMERATOR_COL_NAME).at(0L),
                    encodingsByCategoryAndFold.vec(TargetEncoder.NUMERATOR_COL_NAME).at(0L),
                    1.0E-5d);
            Assert.assertEquals(
                    leftByCategoryAndFold.vec(TargetEncoder.DENOMINATOR_COL_NAME).at(0L),
                    encodingsByCategoryAndFold.vec(TargetEncoder.DENOMINATOR_COL_NAME).at(0L),
                    1.0E-5d);
        } finally {
            if (encodingMap != null) {
                TargetEncoderFrameHelper.encodingMapCleanUp(encodingMap);
            }
            Scope.exit(new Key[0]);
        }
    }

    /**
     * Merges two small frames with {@link Merge#merge} and asserts that the merged "ColB" column
     * equals the original order of the "leftFrame" categorical column. The test is declared to
     * expect an {@link AssertionError}, i.e. it passes only when that comparison fails.
     */
    @Test(expected = AssertionError.class)
    public void mergeWillUseRightFramesOrderAndGroupByValues() {
        Scope.enter();
        Frame merged = null;
        try {
            Frame leftFrame = new TestFrameBuilder()
                    .withName("leftFrame")
                    .withColNames(new String[]{"ColA", "ColB"})
                    .withVecTypes(new byte[]{4, 3})
                    .withDataForCol(0, ar(new String[]{"a", "b", "c", "e", "a"}))
                    .withDataForCol(1, ard(new double[]{-1.0d, 2.0d, 3.0d, 4.0d, 7.0d}))
                    .build();
            Frame holdoutEncodingMap = new TestFrameBuilder()
                    .withName("holdoutEncodingMap")
                    .withColNames(new String[]{"ColB", "ColC"})
                    .withVecTypes(new byte[]{4, 3})
                    .withDataForCol(0, ar(new String[]{"c", "a", "e", "b"}))
                    .withDataForCol(1, ard(new double[]{2.0d, 3.0d, 4.0d, 5.0d}))
                    .build();

            // One level-mapping array per join column (a single join column here).
            int[][] levelMaps = new int[][]{
                    CategoricalWrappedVec.computeMap(holdoutEncodingMap.vec(0).domain(), leftFrame.vec(0).domain())
            };
            merged = Merge.merge(holdoutEncodingMap, leftFrame, new int[]{0}, new int[]{0}, false, levelMaps);

            assertStringVecEquals(cvec(new String[]{"a", "b", "c", "e", "a"}), merged.vec("ColB"));
        } finally {
            // merged stays null when frame construction or the merge itself throws.
            if (merged != null) {
                merged.delete();
            }
            Scope.exit(new Key[0]);
        }
    }

    /**
     * Runs {@code FrameWithEncodingDataToArray} over a frame whose fold column contains
     * {@link Integer#MIN_VALUE}, passing {@link Integer#MIN_VALUE} as the last constructor argument,
     * and expects an {@link AssertionError}.
     */
    @Test(expected = AssertionError.class)
    public void foldValuesThatAreBiggerThanIntegerWillCauseExceptionTest() {
        this.fr = new TestFrameBuilder()
                .withName("testFrame")
                .withColNames(new String[]{"ColA", "fold", TargetEncoder.NUMERATOR_COL_NAME, TargetEncoder.DENOMINATOR_COL_NAME})
                .withVecTypes(new byte[]{4, 3, 3, 3})
                .withDataForCol(0, ar(new String[]{"a", "b", "c"}))
                .withDataForCol(1, ar(new long[]{Integer.MIN_VALUE, 33, 42}))
                .withDataForCol(2, ar(new long[]{44, 66, 84}))
                .withDataForCol(3, ar(new long[]{88, 132, 168}))
                .withChunkLayout(new long[]{2, 1})
                .build();
        int cardinality = this.fr.vec("ColA").cardinality();
        new BroadcastJoinForTargetEncoder.FrameWithEncodingDataToArray(0, 1, 2, 3, cardinality, Integer.MIN_VALUE)
                .doAll(this.fr)
                .getEncodingDataArray();
    }

    /**
     * Runs {@code FrameWithEncodingDataToArray} over a frame with fold values 0, 1, 2 and checks
     * only that no exception is thrown.
     *
     * @param i generated value placed into the numerator/denominator columns; {@code max(i, 42)}
     *          is passed as the last constructor argument (presumably the maximum fold value —
     *          confirm against {@code FrameWithEncodingDataToArray})
     */
    @Property(trials = 100)
    public void foldValuesThatAreInRangeWouldNotCauseExceptionTest(@InRange(minInt = 1, maxInt = 1000) int i) {
        this.fr = new TestFrameBuilder()
                .withName("testFrame")
                .withColNames(new String[]{"ColA", "fold", TargetEncoder.NUMERATOR_COL_NAME, TargetEncoder.DENOMINATOR_COL_NAME})
                .withVecTypes(new byte[]{4, 3, 3, 3})
                .withDataForCol(0, ar(new String[]{"a", "b", "c"}))
                .withDataForCol(1, ar(new long[]{0, 1, 2}))
                .withDataForCol(2, ar(new long[]{i, 66, 84}))
                .withDataForCol(3, ar(new long[]{88, 132, i}))
                .withChunkLayout(new long[]{2, 1})
                .build();
        int cardinality = this.fr.vec("ColA").cardinality();
        new BroadcastJoinForTargetEncoder.FrameWithEncodingDataToArray(0, 1, 2, 3, cardinality, Math.max(i, 42))
                .doAll(this.fr)
                .getEncodingDataArray();
    }

    /**
     * Joins without a fold column (fold indices of -1) and checks that the left row order
     * (a, c, b) is kept and that each row received the numerator/denominator of its category.
     */
    @Test
    public void joinWithoutFoldColumnTest() {
        Frame encodings = null;
        try {
            this.fr = new TestFrameBuilder()
                    .withName("testFrame")
                    .withColNames(new String[]{"ColA"})
                    .withVecTypes(new byte[]{4})
                    .withDataForCol(0, ar(new String[]{"a", "c", "b"}))
                    .build();
            encodings = new TestFrameBuilder()
                    .withName("testFrame2")
                    .withColNames(new String[]{"ColA", TargetEncoder.NUMERATOR_COL_NAME, TargetEncoder.DENOMINATOR_COL_NAME})
                    .withVecTypes(new byte[]{4, 3, 3})
                    .withDataForCol(0, ar(new String[]{"a", "b", "c"}))
                    .withDataForCol(1, ar(new long[]{22, 33, 42}))
                    .withDataForCol(2, ar(new long[]{44, 66, 84}))
                    .withChunkLayout(new long[]{1, 1, 1})
                    .build();
            this.fr.add(TargetEncoder.NUMERATOR_COL_NAME, Vec.makeZero(this.fr.numRows()));
            this.fr.add(TargetEncoder.DENOMINATOR_COL_NAME, Vec.makeZero(this.fr.numRows()));

            Frame joined = BroadcastJoinForTargetEncoder.join(this.fr, new int[]{0}, -1, encodings, new int[]{0}, -1, 0);

            Scope.enter();
            try {
                assertStringVecEquals(cvec(new String[]{"a", "c", "b"}), joined.vec("ColA"));
                assertVecEquals(vec(new int[]{22, 42, 33}), joined.vec(TargetEncoder.NUMERATOR_COL_NAME), 1.0E-5d);
                assertVecEquals(vec(new int[]{44, 84, 66}), joined.vec(TargetEncoder.DENOMINATOR_COL_NAME), 1.0E-5d);
            } finally {
                // Always close the scope, even when one of the assertions above fails.
                Scope.exit(new Key[0]);
            }
        } finally {
            if (encodings != null) {
                encodings.delete();
            }
        }
    }

    /**
     * Creates {@code i} strings, each the decimal form of a random integer in
     * {@code [0, max(1, i / 2))}. The generator is not seeded, so the content differs between runs.
     *
     * @param i number of strings to generate
     * @return array of length {@code i}
     */
    private static String[] randomArrOfStrings(int i) {
        String[] values = new String[i];
        Random random = new Random();
        int bound = Math.max(1, i / 2);
        for (int idx = 0; idx < i; idx++) {
            values[idx] = Integer.toString(random.nextInt(bound));
        }
        return values;
    }

    /** Deletes the shared frame, if a test created one. */
    @After
    public void afterEach() {
        if (this.fr != null) {
            this.fr.delete();
        }
    }
}
