package hex.genmodel.algos.tree;

import ai.h2o.a.a.a;
import ai.h2o.a.a.b;
import java.io.Serializable;

/* loaded from: input_file:hex/genmodel/algos/tree/TreeSHAP.class */
/**
 * Computes per-feature contributions for a single decision tree using a path-based TreeSHAP
 * recursion.
 *
 * <p>The tree is given as two parallel arrays: {@code nodes[k]} describes node {@code k} and
 * {@code stats[k]} carries the weight of that node. The node and stat types are obfuscated; as
 * used in this class:
 * <ul>
 *   <li>{@code N.a()} is {@code true} for a terminal node (no children are visited),</li>
 *   <li>{@code N.b()} is the value of a terminal node,</li>
 *   <li>{@code N.c()} is the feature index the node splits on,</li>
 *   <li>{@code N.d()} and {@code N.e()} are the indices of the two children in the node array,</li>
 *   <li>{@code N.a(row)} is the index of the child the given row descends into,</li>
 *   <li>{@code S.f()} is the weight of the node.</li>
 * </ul>
 *
 * @param <R> row type handed to the nodes to decide which child is followed
 * @param <N> node type
 * @param <S> node statistics type
 */
public class TreeSHAP<R, N extends a<R>, S extends b> implements TreeSHAPPredictor<R> {

    /** Index of the root node in {@link #nodes}; always 0 in this implementation. */
    private final int rootIndex = 0;

    /** Tree nodes, indexed by node id. */
    private final N[] nodes;

    /** Node statistics, parallel to {@link #nodes}. */
    private final S[] stats;

    /** Weighted mean of the terminal values of the whole tree (see {@link #nodeMeanValue}). */
    private final float expectedValue;

    /**
     * One entry of the feature path maintained while walking the tree.
     *
     * <p>NOTE(review): in TreeSHAP terminology the fields appear to be the feature index, the
     * "zero fraction", the "one fraction" and the path weight - confirm against the algorithm
     * description before relying on these names.
     */
    public static class PathElement implements Serializable {

        /** Feature index of the split that produced this element (-1 for the initial element). */
        int f1042a;

        /** Share of the parent's weight that flows into the branch taken (weight ratio). */
        float f1043b;

        /** Inherited indicator for the branch the row follows; set to 0 for the other branch. */
        float f1044c;

        /** Accumulated path weight of this element. */
        float d;

        private PathElement() {
        }

        /**
         * Compiler-generated accessor constructor kept for compatibility; the argument is ignored.
         */
        /* synthetic */ PathElement(byte b2) {
            this();
        }
    }

    /** A view into a shared {@link PathElement} array starting at a given offset. */
    public static class PathPointer {

        /** Backing array shared by all views of one workspace. */
        PathElement[] f1045a;

        /** Offset of this view's first element inside {@link #f1045a}. */
        int f1046b;

        PathPointer(PathElement[] pathElementArr) {
            this.f1045a = pathElementArr;
        }

        PathPointer(PathElement[] pathElementArr, int i) {
            this.f1045a = pathElementArr;
            this.f1046b = i;
        }

        /** Returns the element at position {@code i} relative to this view's offset. */
        final PathElement a(int i) {
            return this.f1045a[this.f1046b + i];
        }
    }

    /**
     * Creates a predictor for one tree and pre-computes its weighted mean value.
     *
     * @param nArr tree nodes, indexed by node id
     * @param sArr node statistics, parallel to {@code nArr}
     * @param i    currently ignored; the root is always taken to be node 0
     */
    public TreeSHAP(N[] nArr, S[] sArr, int i) {
        this.nodes = nArr;
        this.stats = sArr;
        this.expectedValue = nodeMeanValue(this.nodes, this.stats, 0);
    }

    /**
     * Copies the first {@code depth} elements of {@code parent} to the slots directly behind them
     * and returns a view that starts at the copy.
     */
    private static PathPointer copyPath(PathPointer parent, int depth) {
        final PathElement[] elements = parent.f1045a;
        final int base = parent.f1046b;
        for (int k = 0; k < depth; k++) {
            final PathElement src = elements[base + k];
            final PathElement dst = elements[base + depth + k];
            dst.f1042a = src.f1042a;
            dst.f1043b = src.f1043b;
            dst.f1044c = src.f1044c;
            dst.d = src.d;
        }
        return new PathPointer(elements, base + depth);
    }

    /**
     * Appends a new element at position {@code depth} and redistributes the path weights of all
     * earlier elements accordingly.
     */
    private static void extendPath(PathPointer path, int depth, float zeroFraction, float oneFraction,
                                   int feature) {
        final PathElement added = path.a(depth);
        added.f1042a = feature;
        added.f1043b = zeroFraction;
        added.f1044c = oneFraction;
        added.d = depth == 0 ? 1.0f : 0.0f;
        for (int k = depth - 1; k >= 0; k--) {
            path.a(k + 1).d += ((oneFraction * path.a(k).d) * (k + 1)) / (depth + 1);
            path.a(k).d = ((zeroFraction * path.a(k).d) * (depth - k)) / (depth + 1);
        }
    }

    /**
     * Removes the element at {@code pathIndex} from the path: first reverts its effect on the path
     * weights, then shifts the following elements (all fields except the weight) one slot down.
     */
    private static void unwindPath(PathPointer path, int depth, int pathIndex) {
        final float oneFraction = path.a(pathIndex).f1044c;
        final float zeroFraction = path.a(pathIndex).f1043b;
        float nextOnePortion = path.a(depth).d;
        for (int k = depth - 1; k >= 0; k--) {
            final PathElement el = path.a(k);
            if (oneFraction != 0.0f) {
                final float previous = el.d;
                el.d = (nextOnePortion * (depth + 1)) / ((k + 1) * oneFraction);
                nextOnePortion = previous - (((el.d * zeroFraction) * (depth - k)) / (depth + 1));
            } else {
                el.d = (el.d * (depth + 1)) / (zeroFraction * (depth - k));
            }
        }
        for (int k = pathIndex; k < depth; k++) {
            final PathElement dst = path.a(k);
            final PathElement src = path.a(k + 1);
            dst.f1042a = src.f1042a;
            dst.f1043b = src.f1043b;
            dst.f1044c = src.f1044c;
        }
    }

    /**
     * Returns the sum of the path weights that would result from removing the element at
     * {@code pathIndex}, without modifying the path.
     *
     * @throws IllegalStateException if both fractions of the element are zero while a path weight
     *                               is non-zero
     */
    private static float unwoundPathSum(PathPointer path, int depth, int pathIndex) {
        final float oneFraction = path.a(pathIndex).f1044c;
        final float zeroFraction = path.a(pathIndex).f1043b;
        float nextOnePortion = path.a(depth).d;
        float total = 0.0f;
        for (int k = depth - 1; k >= 0; k--) {
            if (oneFraction != 0.0f) {
                final float term = (nextOnePortion * (depth + 1)) / ((k + 1) * oneFraction);
                total += term;
                // The ratio must be computed in floating point: (depth - k) < (depth + 1), so an
                // integer division would always yield 0.
                nextOnePortion = path.a(k).d - ((term * zeroFraction) * ((depth - k) / (float) (depth + 1)));
            } else if (zeroFraction != 0.0f) {
                // Same floating-point ratio as above; integer division would divide by zero here.
                total += (path.a(k).d / zeroFraction) / ((depth - k) / (float) (depth + 1));
            } else if (path.a(k).d != 0.0f) {
                throw new IllegalStateException("Unique path " + k + " must have zero getWeight");
            }
        }
        return total;
    }

    /**
     * Recursive worker: walks the subtree rooted at {@code node} and accumulates the contributions
     * of every terminal node into {@code phi}.
     *
     * @param row                row being explained
     * @param phi                contribution accumulator, indexed by feature index
     * @param node               current node
     * @param nodeStat           statistics of the current node
     * @param depth              number of path elements inherited from the parent
     * @param parentPath         path view of the parent call
     * @param parentZeroFraction weight share of the branch that led to {@code node}
     * @param parentOneFraction  indicator value of the branch that led to {@code node}
     * @param parentFeature      feature index the parent split on (-1 at the root)
     * @param condition          0 for the unconditional computation; positive or negative values
     *                           change how splits on {@code conditionFeature} are handled
     * @param conditionFeature   feature index the computation is conditioned on
     * @param conditionFraction  weight multiplier arriving at this node; nothing is done when it
     *                           is exactly 0
     */
    private void treeShap(R row, float[] phi, N node, S nodeStat, int depth, PathPointer parentPath,
                          float parentZeroFraction, float parentOneFraction, int parentFeature,
                          int condition, int conditionFeature, float conditionFraction) {
        // Nothing flows down this branch.
        if (conditionFraction == 0.0f) {
            return;
        }

        final PathPointer path = copyPath(parentPath, depth);
        if (condition == 0 || conditionFeature != parentFeature) {
            extendPath(path, depth, parentZeroFraction, parentOneFraction, parentFeature);
        }

        final int splitFeature = node.c();

        if (node.a()) {
            // Terminal node: element 0 is the initial placeholder, so start at 1.
            for (int k = 1; k <= depth; k++) {
                final float weightSum = unwoundPathSum(path, depth, k);
                final PathElement el = path.a(k);
                phi[el.f1042a] += weightSum * (el.f1044c - el.f1043b) * node.b() * conditionFraction;
            }
            return;
        }

        // "hot" is the child the row follows, "cold" is the other one.
        final int hot = node.a(row);
        final int cold = hot == node.d() ? node.e() : node.d();
        final float weight = nodeStat.f();
        final float hotZeroFraction = this.stats[hot].f() / weight;
        final float coldZeroFraction = this.stats[cold].f() / weight;
        float incomingZeroFraction = 1.0f;
        float incomingOneFraction = 1.0f;

        // If the path already contains a split on this feature, undo it so it can be redone here.
        int pathIndex = 0;
        while (pathIndex <= depth && path.a(pathIndex).f1042a != splitFeature) {
            pathIndex++;
        }
        if (pathIndex != depth + 1) {
            incomingZeroFraction = path.a(pathIndex).f1043b;
            incomingOneFraction = path.a(pathIndex).f1044c;
            unwindPath(path, depth, pathIndex);
            depth--;
        }

        // Divide the incoming condition fraction between the two children.
        float hotConditionFraction = conditionFraction;
        float coldConditionFraction = conditionFraction;
        if (condition > 0 && splitFeature == conditionFeature) {
            coldConditionFraction = 0.0f;
            depth--;
        } else if (condition < 0 && splitFeature == conditionFeature) {
            hotConditionFraction *= hotZeroFraction;
            coldConditionFraction *= coldZeroFraction;
            depth--;
        }

        treeShap(row, phi, this.nodes[hot], this.stats[hot], depth + 1, path,
                hotZeroFraction * incomingZeroFraction, incomingOneFraction,
                splitFeature, condition, conditionFeature, hotConditionFraction);
        treeShap(row, phi, this.nodes[cold], this.stats[cold], depth + 1, path,
                coldZeroFraction * incomingZeroFraction, 0.0f,
                splitFeature, condition, conditionFeature, coldConditionFraction);
    }

    /**
     * Adds this tree's contributions for {@code r} to {@code fArr} and returns {@code fArr}.
     *
     * <p>When {@code i == 0} the tree's weighted mean value is additionally added to the last
     * element of {@code fArr}.
     *
     * @param r    row being explained
     * @param fArr contribution accumulator; modified in place
     * @param i    condition selector (0 = unconditional)
     * @param i2   feature index the computation is conditioned on
     * @param obj  workspace previously obtained from {@link #b()}; must be a {@link PathPointer}
     * @return {@code fArr}
     */
    @Override // hex.genmodel.algos.tree.TreeSHAPPredictor
    public final float[] a(R r, float[] fArr, int i, int i2, Object obj) {
        if (i == 0) {
            fArr[fArr.length - 1] += this.expectedValue;
        }
        final PathPointer workspace = (PathPointer) obj;
        // Reset the first element of the backing array before starting a fresh walk.
        final PathElement first = workspace.f1045a[0];
        first.f1042a = 0;
        first.f1043b = 0.0f;
        first.f1044c = 0.0f;
        first.d = 0.0f;
        treeShap(r, fArr, this.nodes[this.rootIndex], this.stats[this.rootIndex], 0, workspace,
                1.0f, 1.0f, -1, i, i2, 1.0f);
        return fArr;
    }

    /**
     * Returns the number of {@link PathElement}s a workspace needs: the triangular number of
     * (tree depth + 2).
     */
    @Override // hex.genmodel.algos.tree.TreeSHAPPredictor
    public final int a() {
        final int levels = treeDepth() + 2;
        return (levels * (levels + 1)) / 2;
    }

    /** Returns the depth of the tree, measured from node 0 and counting a lone terminal node as 1. */
    private int treeDepth() {
        return nodeDepth(this.nodes, 0);
    }

    /** Returns the depth of the subtree rooted at {@code index}; a terminal node has depth 1. */
    private static int nodeDepth(a<?>[] nodes, int index) {
        final a<?> node = nodes[index];
        if (node.a()) {
            return 1;
        }
        return 1 + Math.max(nodeDepth(nodes, node.d()), nodeDepth(nodes, node.e()));
    }

    /**
     * Returns the weighted mean of the terminal values below {@code index}: the terminal value
     * itself for a terminal node, otherwise the children's means weighted by the children's
     * weights and divided by the node's own weight.
     */
    private static float nodeMeanValue(a<?>[] nodes, b[] stats, int index) {
        final a<?> node = nodes[index];
        if (node.a()) {
            return node.b();
        }
        final int first = node.d();
        final int second = node.e();
        return ((stats[first].f() * nodeMeanValue(nodes, stats, first))
                + (stats[second].f() * nodeMeanValue(nodes, stats, second))) / stats[index].f();
    }

    /**
     * Allocates a fresh workspace sized by {@link #a()} and returns it as a {@link PathPointer}
     * positioned at the start of the backing array.
     */
    @Override // hex.genmodel.algos.tree.TreeSHAPPredictor
    public final /* synthetic */ Object b() {
        final PathElement[] elements = new PathElement[a()];
        for (int k = 0; k < elements.length; k++) {
            elements[k] = new PathElement();
        }
        return new PathPointer(elements);
    }
}
