package hex.genmodel.algos.targetencoder;

import hex.genmodel.MojoModel;
import java.util.ArrayList;
import java.util.Collections;
import java.util.Comparator;
import java.util.HashMap;
import java.util.Iterator;
import java.util.LinkedHashMap;
import java.util.List;
import java.util.Map;

/* loaded from: input_file:hex/genmodel/algos/targetencoder/TargetEncoderMojoModel.class */
/**
 * MOJO scoring model that replaces categorical input columns with their target encodings.
 *
 * <p>For every column present in the encoding maps ({@link #r}), scoring writes one encoded
 * value into the output array. Outputs are written in ascending order of the column's index in
 * the input row.
 */
public class TargetEncoderMojoModel extends MojoModel {
    /** Per-column encoding maps; must be set before scoring (scoring throws if it is {@code null}). */
    public EncodingMaps r;
    /** Column name to index into the input row; populated by the constructor. */
    public Map<String, Integer> s;
    /**
     * Column name to flag. A value of {@code 1} means a NaN input for that column is encoded using
     * the last level ({@code size() - 1}) of the column's encoding map. Otherwise NaN inputs
     * receive {@link #x}.
     */
    public Map<String, Integer> t;
    /** When {@code true}, posterior means are blended with {@link #x}; otherwise they are used as-is. */
    public boolean u;
    /** Row count at which the blending weight lambda equals 0.5 (the sigmoid's inflection point). */
    public double v;
    /** Scale of the sigmoid that turns a level's row count into the blending weight lambda. */
    public double w;
    /**
     * Fallback encoding and the value posterior means are blended toward (presumably the global
     * prior target mean; TODO confirm against the reader).
     */
    public double x;
    // Not read anywhere in this class; presumably a flag for "impute NAs as a new category" — TODO confirm.
    private final boolean y = true;

    /**
     * Orders map entries by the index that {@code keyToIndex} associates with each entry's key.
     * Every key compared must be present in the index map, otherwise a NullPointerException is thrown.
     */
    public static class SortByKeyAssociatedIndex<K extends String, V> implements Comparator<Map.Entry<K, V>> {

        private final Map<String, Integer> keyToIndex;

        public SortByKeyAssociatedIndex(Map<String, Integer> map) {
            this.keyToIndex = map;
        }

        @Override
        public int compare(Map.Entry<K, V> first, Map.Entry<K, V> second) {
            return this.keyToIndex.get(first.getKey()).compareTo(this.keyToIndex.get(second.getKey()));
        }
    }

    /**
     * Creates the model and indexes the column names by position.
     *
     * @param strArr column names; the last entry is not indexed
     * @param strArr2 column domains, passed through to {@link MojoModel}
     * @param str response column name, passed through to {@link MojoModel}
     */
    public TargetEncoderMojoModel(String[] strArr, String[][] strArr2, String str) {
        super(strArr, strArr2, str);
        this.s = new HashMap<String, Integer>(strArr.length);
        // The last column is deliberately skipped; presumably it is the response column
        // (GenModel convention for supervised models) — TODO confirm.
        for (int i = 0; i < strArr.length - 1; i++) {
            this.s.put(strArr[i], Integer.valueOf(i));
        }
    }

    /**
     * Sigmoid blending weight for a level seen in {@code count} rows:
     * {@code 1 / (1 + exp((inflectionPoint - count) / smoothing))}.
     */
    private static double computeLambda(int count, double inflectionPoint, double smoothing) {
        return 1.0d / (1.0d + Math.exp((inflectionPoint - count) / smoothing));
    }

    /** Convex combination {@code lambda * posterior + (1 - lambda) * prior}. */
    private static double computeBlendedEncoding(double lambda, double posterior, double prior) {
        return (lambda * posterior) + ((1.0d - lambda) * prior);
    }

    /**
     * Encodes one input row.
     *
     * @param dArr input row. Values at encoded-column positions are categorical level indices,
     *     or NaN when missing.
     * @param dArr2 output buffer that receives one encoding per encoded column, in column-index order
     * @return {@code dArr2}
     * @throws IllegalStateException if the encoding maps have not been set
     */
    @Override // hex.genmodel.GenModel
    public final double[] a(double[] dArr, double[] dArr2) {
        if (this.r == null) {
            throw new IllegalStateException("Encoding map is missing.");
        }
        List<Map.Entry<String, EncodingMap>> entries =
                new ArrayList<Map.Entry<String, EncodingMap>>(this.r.a().entrySet());
        // Output positions follow the input-row order of the encoded columns.
        Collections.sort(entries, new SortByKeyAssociatedIndex<String, EncodingMap>(this.s));

        int outIdx = 0;
        for (Map.Entry<String, EncodingMap> entry : entries) {
            String column = entry.getKey();
            EncodingMap encodingMap = entry.getValue();
            double level = dArr[this.s.get(column).intValue()];
            if (!Double.isNaN(level)) {
                // Categorical levels are represented as integral doubles.
                dArr2[outIdx] = encode(encodingMap, (int) level);
            } else if (this.t.get(column).intValue() == 1) {
                // A dedicated NA level is stored as the last entry of the encoding map.
                dArr2[outIdx] = encode(encodingMap, encodingMap.f1010a.size() - 1);
            } else {
                dArr2[outIdx] = this.x;
            }
            outIdx++;
        }
        return dArr2;
    }

    /**
     * Encoding for a single level.
     *
     * <p>The level's stored {@code [numerator, denominator]} pair gives the posterior mean. That
     * mean is optionally blended with {@link #x} using a lambda derived from the denominator (the
     * level's row count). A level with no entry in the map falls back to {@link #x}, consistent
     * with the NaN fallback in scoring.
     */
    private double encode(EncodingMap encodingMap, int level) {
        int[] numAndDenom = encodingMap.f1010a.get(Integer.valueOf(level));
        if (numAndDenom == null) {
            return this.x;
        }
        // Cast before dividing: int/int would truncate the ratio.
        double posteriorMean = (double) numAndDenom[0] / numAndDenom[1];
        if (!this.u) {
            return posteriorMean;
        }
        double lambda = computeLambda(numAndDenom[1], this.v, this.w);
        return computeBlendedEncoding(lambda, posteriorMean, this.x);
    }
}
