package water.rapids.ast.prims.advmath;

import hex.quantile.QuantileModel;
import java.util.Arrays;
import java.util.HashSet;
import java.util.Locale;
import java.util.Set;
import water.Freezable;
import water.H2O;
import water.MRTask;
import water.fvec.Chunk;
import water.fvec.Frame;
import water.fvec.Vec;
import water.rapids.Env;
import water.rapids.Val;
import water.rapids.ast.AstFrame;
import water.rapids.ast.AstPrimitive;
import water.rapids.ast.AstRoot;
import water.rapids.ast.params.AstNum;
import water.rapids.ast.params.AstNumList;
import water.rapids.ast.params.AstStr;
import water.rapids.ast.params.AstStrList;
import water.rapids.ast.prims.mungers.AstGroup;
import water.rapids.ast.prims.reducers.AstMean;
import water.rapids.ast.prims.reducers.AstMedian;
import water.rapids.vals.ValFrame;
import water.rapids.vals.ValNums;
import water.util.ArrayUtils;
import water.util.IcedDouble;
import water.util.IcedHashMap;

/* loaded from: input_file:water/rapids/ast/prims/advmath/AstImpute.class */
/**
 * Rapids primitive {@code (h2o.impute ary col method combineMethod groupByCols groupByFrame values)}.
 *
 * <p>Replaces missing values of {@code ary} in place. {@code col} selects one column, or -1 for all
 * columns.
 *
 * <ul>
 *   <li>Without group-by columns and without a group-by frame, every NA of a column is replaced by
 *       one value per column (user-supplied {@code values}, or a statistic of the column). The
 *       per-column values are returned as numbers.
 *   <li>Otherwise a group-by aggregate frame is computed (or {@code groupByFrame} is used). Each NA
 *       is replaced by the aggregate of its row's group, and the aggregate frame is returned.
 * </ul>
 */
public class AstImpute extends AstPrimitive {

    @Override // water.rapids.ast.AstPrimitive
    public String[] args() {
        return new String[]{"ary", "col", "method", "combineMethod", "groupByCols", "groupByFrame", "values"};
    }

    @Override // water.rapids.ast.AstRoot
    public String str() {
        return "h2o.impute";
    }

    /** The primitive itself plus the seven arguments listed in {@link #args()}. */
    @Override // water.rapids.ast.AstRoot
    public int nargs() {
        return 1 + 7;
    }

    /**
     * Parses the arguments, then imputes either per column or per group.
     *
     * @throws IllegalArgumentException on an out-of-range column, an unknown method, an unresolvable
     *     group-by column, or a group-by frame that does not match the data
     */
    @Override // water.rapids.ast.AstRoot
    public Val apply(Env env, Env.StackHelp stk, AstRoot[] asts) {
        // Frame being imputed; its chunks are overwritten in place below.
        Frame fr = stk.track(asts[1].exec(env)).getFrame();

        // Column to impute, or -1 for every column.
        final int col = (int) asts[2].exec(env).getNum();
        if (col < -1 || col >= fr.numCols()) {
            throw new IllegalArgumentException("Column not -1 or in range 0 to " + fr.numCols());
        }

        // Locale.ROOT: default-locale upper-casing turns "median" into "MEDİAN" under Turkish locales.
        AstRoot method = null;
        boolean fillMethod = false;
        switch (asts[3].exec(env).getStr().toUpperCase(Locale.ROOT)) {
            case "MEAN":
                method = new AstMean();
                break;
            case "MEDIAN":
                method = new AstMedian();
                break;
            case "MODE":
                method = new AstMode();
                break;
            case "FFILL":
            case "BFILL":
                fillMethod = true;
                break;
            default:
                throw new IllegalArgumentException("Method must be one of mean, median or mode");
        }

        // How the median is combined on even sample sizes (only used by the median method).
        QuantileModel.CombineMethod combine =
                QuantileModel.CombineMethod.valueOf(asts[4].exec(env).getStr().toUpperCase(Locale.ROOT));

        // Group-by columns; an empty list is normal and means "no grouping".
        AstNumList by = parseGroupByColumns(fr, asts[5]);

        // Optional precomputed group-by frame; "_" means none.
        Frame groupByFrame = asts[6].str().equals("_") ? null : stk.track(asts[6].exec(env)).getFrame();

        // Optional user-supplied impute values, one per column of `fr`.
        AstRoot valuesAst = asts[7];
        AstNumList values = null;
        if (valuesAst instanceof AstNumList) {
            values = (AstNumList) valuesAst;
        } else if (valuesAst instanceof AstNum) {
            values = new AstNumList(((AstNum) valuesAst).getNum());
        }

        // Forward/backward fill is not implemented on either path; reject it before any work is done.
        if (fillMethod) {
            throw H2O.unimpl("No ffill or bfill imputation supported");
        }

        if (by.isEmpty() && groupByFrame == null) {
            return imputeWithoutGroups(fr, col, method, combine, values);
        }
        return imputeByGroups(env, stk, fr, col, method, by, groupByFrame);
    }

    /**
     * Converts the group-by argument into a column-index list.
     *
     * <p>Accepts a number list, a single number, or a list of column names; names are resolved
     * against {@code fr} and sorted.
     */
    private static AstNumList parseGroupByColumns(Frame fr, AstRoot ast) {
        if (ast instanceof AstNumList) {
            return (AstNumList) ast;
        }
        if (ast instanceof AstNum) {
            return new AstNumList(((AstNum) ast).getNum());
        }
        if (ast instanceof AstStrList) {
            String[] names = ((AstStrList) ast)._strs;
            double[] indices = new double[names.length];
            for (int i = 0; i < names.length; i++) {
                int found = fr.find(names[i]);
                if (found < 0) {
                    throw new IllegalArgumentException("Group-by column not found: " + names[i]);
                }
                indices[i] = found;
            }
            Arrays.sort(indices);
            return new AstNumList(indices);
        }
        throw new IllegalArgumentException("Requires a number-list, but found a " + ast.getClass());
    }

    /** Returns true for the column types that the imputation statistics handle: numeric and categorical. */
    private static boolean isImputable(Vec v) {
        return v.isNumeric() || v.isCategorical();
    }

    /**
     * Replaces NAs with one value per column and returns those values.
     *
     * <p>Columns whose value is NaN are left untouched.
     */
    private static Val imputeWithoutGroups(Frame fr, int col, AstRoot method,
                                           QuantileModel.CombineMethod combine, AstNumList values) {
        final double[] res;
        if (values != null) {
            res = values.expand();
            if (res.length < fr.numCols()) {
                throw new IllegalArgumentException(
                        "Expected " + fr.numCols() + " impute values, but found " + res.length);
            }
        } else {
            res = new double[fr.numCols()];
            // NaN marks "do not impute"; ineligible columns (e.g. string, time) keep their NAs.
            Arrays.fill(res, Double.NaN);
            if (col == -1) {
                // NOTE(review): `method` is not consulted here. Numeric columns get their mean and
                // categorical columns their most frequent level.
                for (int i = 0; i < res.length; i++) {
                    Vec v = fr.vec(i);
                    if (v.isNumeric()) {
                        res[i] = v.mean();
                    } else if (v.isCategorical()) {
                        res[i] = ArrayUtils.maxIndex(v.bins());
                    }
                }
            } else {
                Vec vec = fr.vec(col);
                if (method instanceof AstMean) {
                    res[col] = vec.mean();
                } else if (method instanceof AstMedian) {
                    res[col] = AstMedian.median(new Frame(vec), combine);
                } else if (method instanceof AstMode) {
                    res[col] = AstMode.mode(vec);
                }
            }
        }
        new MRTask() {
            @Override // water.MRTask
            public void map(Chunk[] cs) {
                int len = cs[0]._len;
                for (int c = 0; c < cs.length; c++) {
                    if (Double.isNaN(res[c])) {
                        continue;
                    }
                    for (int row = 0; row < len; row++) {
                        if (cs[c].isNA(row)) {
                            cs[c].set(row, res[c]);
                        }
                    }
                }
            }
        }.doAll(fr);
        return new ValNums(res);
    }

    /**
     * Replaces NAs with the aggregate of the row's group.
     *
     * <p>The aggregates come from {@code groupByFrame}, or are computed with {@link AstGroup}. Rows
     * whose group does not appear in the aggregate frame keep their NAs.
     *
     * @return the group-by aggregate frame
     */
    private static Val imputeByGroups(Env env, Env.StackHelp stk, Frame fr, int col, AstRoot method,
                                      AstNumList by, Frame groupByFrame) {
        if (!by.isEmpty()) {
            for (int b : by.expand4()) {
                if (b < 0 || b >= fr.numCols()) {
                    throw new IllegalArgumentException(
                            "Group-by column not in range 0 to " + (fr.numCols() - 1) + ": " + b);
                }
            }
        }

        Frame imputes = groupByFrame != null
                ? groupByFrame
                : computeGroupAggregates(env, stk, fr, col, method, by);
        // With more than two columns it is ambiguous which are keys and which are aggregates.
        if (by.isEmpty() && imputes.numCols() > 2) {
            throw new IllegalArgumentException("Ambiguous group-by frame. Supply the `by` columns to proceed.");
        }

        // The group key occupies the leading column(s) of the aggregate frame.
        final int[] keyCols = ArrayUtils.seq(0, Math.max((int) by.cnt(), 1));

        // Without explicit `by` columns, the aggregate frame's key column is located in `fr` by name.
        final int[] byCols;
        if (by.isEmpty()) {
            byCols = new int[imputes.numCols() - 1];
            for (int i = 0; i < byCols.length; i++) {
                byCols[i] = fr.find(imputes.name(i));
                if (byCols[i] < 0) {
                    throw new IllegalArgumentException("Group-by column not found: " + imputes.name(i));
                }
            }
        } else {
            byCols = by.expand4();
        }

        int[] aggColOf = aggregateColumnLayout(fr, col, byCols, keyCols.length, imputes.numCols());
        final IcedHashMap<AstGroup.G, Freezable[]> groupImputes =
                new Gather(keyCols, aggColOf).doAll(imputes)._group_impute_map;
        if (groupImputes == null) {
            // The aggregate frame produced no groups, so there is nothing to impute from.
            return new ValFrame(imputes);
        }

        new MRTask() {
            @Override // water.MRTask
            public void map(Chunk[] cs) {
                boolean[] isKey = new boolean[cs.length];
                for (int b : byCols) {
                    isKey[b] = true;
                }
                AstGroup.G key = new AstGroup.G(byCols.length, null);
                for (int row = 0; row < cs[0]._len; row++) {
                    Freezable[] rowImputes = null;
                    boolean looked = false;
                    for (int c = 0; c < cs.length; c++) {
                        if (isKey[c] || !cs[c].isNA(row)) {
                            continue;
                        }
                        // Look the group up lazily, once per row, and only when the row has an NA.
                        if (!looked) {
                            rowImputes = groupImputes.get(key.fill(row, cs, byCols));
                            looked = true;
                        }
                        if (rowImputes == null) {
                            break; // group unknown to the aggregate frame: leave this row's NAs alone
                        }
                        double v = ((IcedDouble) rowImputes[c])._val;
                        if (!Double.isNaN(v)) {
                            cs[c].set(row, v);
                        }
                    }
                }
            }
        }.doAll(fr);
        return new ValFrame(imputes);
    }

    /**
     * Runs {@link AstGroup} to build the aggregate frame.
     *
     * <p>In single-column mode the frame aggregates {@code col} with {@code method}. In all-columns
     * mode it aggregates every non-key numeric column with its mean and every non-key categorical
     * column with its mode.
     */
    private static Frame computeGroupAggregates(Env env, Env.StackHelp stk, Frame fr, int col,
                                                AstRoot method, AstNumList by) {
        AstGroup grp = new AstGroup();
        if (col != -1) {
            return grp.apply(env, stk, new AstRoot[]{
                    grp, new AstFrame(fr), by, method, new AstNumList(col, col + 1), new AstStr("rm")}).getFrame();
        }
        // Size the argument list exactly, so skipped column types leave no null slots.
        int nAggs = 0;
        for (int i = 0; i < fr.numCols(); i++) {
            if (!by.has(i) && isImputable(fr.vec(i))) {
                nAggs++;
            }
        }
        AstRoot[] aggs = new AstRoot[3 + 3 * nAggs];
        aggs[0] = grp;
        aggs[1] = new AstFrame(fr);
        aggs[2] = by;
        int k = 3;
        for (int i = 0; i < fr.numCols(); i++) {
            if (!by.has(i) && isImputable(fr.vec(i))) {
                aggs[k] = fr.vec(i).isNumeric() ? new AstMean() : new AstMode();
                aggs[k + 1] = new AstNumList(i, i + 1);
                aggs[k + 2] = new AstStr("rm");
                k += 3;
            }
        }
        return grp.apply(env, stk, aggs).getFrame();
    }

    /**
     * Maps each column of {@code fr} to the aggregate-frame column that holds its imputation value,
     * or -1 if the column has none.
     *
     * <p>Single-column mode uses the last aggregate column. All-columns mode assumes the key columns
     * come first, followed by one aggregate per non-key numeric/categorical column in frame order.
     */
    private static int[] aggregateColumnLayout(Frame fr, int col, int[] byCols, int nKeyCols, int nImputeCols) {
        int[] aggColOf = new int[fr.numCols()];
        Arrays.fill(aggColOf, -1);
        if (col != -1) {
            aggColOf[col] = nImputeCols - 1;
            return aggColOf;
        }
        boolean[] isBy = new boolean[fr.numCols()];
        for (int b : byCols) {
            isBy[b] = true;
        }
        int next = nKeyCols; // key columns are skipped; only imputed columns consume an aggregate
        for (int c = 0; c < aggColOf.length; c++) {
            if (!isBy[c] && isImputable(fr.vec(c))) {
                aggColOf[c] = next++;
            }
        }
        if (next > nImputeCols) {
            throw new IllegalArgumentException(
                    "Group-by frame has " + nImputeCols + " columns, but " + next + " are required");
        }
        return aggColOf;
    }

    /**
     * Flattens a group-by aggregate frame into a map from group key to per-column imputation values.
     *
     * <p>Each value array has one {@link IcedDouble} per column of the data frame; NaN means "no
     * value".
     */
    private static class Gather extends MRTask<Gather> {
        // Indices of the group key columns within the aggregate frame.
        private final int[] _keyCols;
        // For each data column: its aggregate-frame column index, or -1.
        private final int[] _aggColOf;
        private IcedHashMap<AstGroup.G, Freezable[]> _group_impute_map;

        Gather(int[] keyCols, int[] aggColOf) {
            this._keyCols = keyCols;
            this._aggColOf = aggColOf;
        }

        @Override // water.MRTask
        public void map(Chunk[] cs) {
            this._group_impute_map = new IcedHashMap<>();
            int ncol = this._aggColOf.length;
            for (int row = 0; row < cs[0]._len; row++) {
                IcedDouble[] imputes = new IcedDouble[ncol];
                for (int c = 0; c < ncol; c++) {
                    int src = this._aggColOf[c];
                    imputes[c] = new IcedDouble(src < 0 ? Double.NaN : cs[src].atd(row));
                }
                this._group_impute_map.put(
                        new AstGroup.G(this._keyCols.length, null).fill(row, cs, this._keyCols), imputes);
            }
        }

        @Override // water.MRTask
        public void reduce(Gather other) {
            if (this._group_impute_map == null) {
                this._group_impute_map = other._group_impute_map;
            } else if (other._group_impute_map != null) {
                this._group_impute_map.putAll(other._group_impute_map);
            }
        }
    }
}
