/*
 * Decompiled with CFR 0.152.
 */
package smile.feature.imputation;

import java.lang.reflect.Array;
import java.util.HashMap;
import java.util.Map;
import java.util.stream.Collectors;
import java.util.stream.IntStream;
import smile.data.AbstractTuple;
import smile.data.DataFrame;
import smile.data.Tuple;
import smile.data.measure.NominalScale;
import smile.data.transform.Transform;
import smile.data.type.DataType;
import smile.data.type.DataTypes;
import smile.data.type.StructField;
import smile.data.type.StructType;
import smile.data.vector.BaseVector;
import smile.data.vector.BooleanVector;
import smile.data.vector.ByteVector;
import smile.data.vector.CharVector;
import smile.data.vector.DoubleVector;
import smile.data.vector.FloatVector;
import smile.data.vector.IntVector;
import smile.data.vector.LongVector;
import smile.data.vector.ShortVector;
import smile.data.vector.Vector;
import smile.math.MathEx;
import smile.sort.IQAgent;

/**
 * Replaces missing values with a fixed per-column value.
 *
 * <p>A value is treated as missing when it is {@code null} or a {@link Number}
 * whose {@code doubleValue()} is NaN. The replacement for each column is looked
 * up by column name in the map passed to the constructor, e.g. the map built by
 * {@link #fit(DataFrame, double, double, String...)}.
 */
public class SimpleImputer implements Transform {
    /** Replacement value of each column, keyed by column name. */
    private final Map<String, Object> values;

    /**
     * Constructor.
     *
     * @param values the replacement value of each column, keyed by column name.
     *               The map is stored by reference, not copied.
     */
    public SimpleImputer(Map<String, Object> values) {
        this.values = values;
    }

    /**
     * Tests whether a cell value is missing.
     *
     * @param x the cell value.
     * @return true if {@code x} is null or a number whose double value is NaN.
     */
    static boolean isMissing(Object x) {
        if (x == null) {
            return true;
        }
        if (x instanceof Number) {
            return Double.isNaN(((Number) x).doubleValue());
        }
        return false;
    }

    /**
     * Tests whether a tuple has at least one missing field.
     *
     * @param x the tuple.
     * @return true if any field of {@code x} is missing.
     */
    public static boolean hasMissing(Tuple x) {
        int n = x.length();
        for (int i = 0; i < n; i++) {
            if (isMissing(x.get(i))) {
                return true;
            }
        }
        return false;
    }

    /**
     * Returns a lazy view of the tuple in which missing fields are replaced.
     * Each field is looked up by name in the replacement map on every access.
     * A missing field with no configured replacement yields {@code null}.
     *
     * @param x the input tuple.
     * @return a tuple view with the same schema as {@code x}.
     */
    public Tuple apply(final Tuple x) {
        final StructType schema = x.schema();
        return new AbstractTuple() {
            // Qualified references avoid shadowing by members inherited from AbstractTuple.
            public Object get(int i) {
                Object xi = x.get(i);
                return SimpleImputer.isMissing(xi) ? SimpleImputer.this.values.get(schema.field(i).name) : xi;
            }

            public StructType schema() {
                return schema;
            }
        };
    }

    /**
     * Returns a new data frame with missing values replaced.
     *
     * <p>Only columns that have a replacement value are rebuilt:
     * <ul>
     *   <li>primitive double/float columns get their NaN cells replaced;</li>
     *   <li>boxed Boolean/Byte/Character/Double/Float/Integer/Long/Short columns
     *       get their null (and, for Double/Float, NaN) cells replaced, and are
     *       converted to primitive vectors;</li>
     *   <li>other object columns get their null cells replaced.</li>
     * </ul>
     * All other columns are reused unchanged.
     *
     * @param data the input data frame.
     * @return the imputed data frame.
     */
    public DataFrame apply(DataFrame data) {
        int n = data.nrow();
        StructType schema = data.schema();
        BaseVector[] vectors = new BaseVector[schema.length()];
        // Each task writes only its own slot vectors[j], so the parallel fill is race-free.
        IntStream.range(0, schema.length()).parallel().forEach(j -> {
            StructField field = schema.field(j);
            Object value = this.values.get(field.name);
            if (value != null) {
                if (field.type.id() == DataType.ID.Double) {
                    double x = ((Number) value).doubleValue();
                    double[] column = data.doubleVector(j).array();
                    double[] vector = new double[n];
                    for (int i = 0; i < n; i++) {
                        vector[i] = Double.isNaN(column[i]) ? x : column[i];
                    }
                    vectors[j] = DoubleVector.of(field, vector);
                } else if (field.type.id() == DataType.ID.Float) {
                    float x = ((Number) value).floatValue();
                    float[] column = data.floatVector(j).array();
                    float[] vector = new float[n];
                    for (int i = 0; i < n; i++) {
                        vector[i] = Float.isNaN(column[i]) ? x : column[i];
                    }
                    vectors[j] = FloatVector.of(field, vector);
                } else if (field.type.isObject()) {
                    if (field.type == DataTypes.BooleanObjectType) {
                        boolean x = (Boolean) value;
                        boolean[] vector = new boolean[n];
                        for (int i = 0; i < n; i++) {
                            Boolean cell = (Boolean) data.get(i, j);
                            vector[i] = cell == null ? x : cell;
                        }
                        vectors[j] = BooleanVector.of(field, vector);
                    } else if (field.type == DataTypes.ByteObjectType) {
                        byte x = ((Number) value).byteValue();
                        byte[] vector = new byte[n];
                        for (int i = 0; i < n; i++) {
                            Byte cell = (Byte) data.get(i, j);
                            vector[i] = cell == null ? x : cell;
                        }
                        vectors[j] = ByteVector.of(field, vector);
                    } else if (field.type == DataTypes.CharObjectType) {
                        char x = (Character) value;
                        char[] vector = new char[n];
                        for (int i = 0; i < n; i++) {
                            Character cell = (Character) data.get(i, j);
                            vector[i] = cell == null ? x : cell;
                        }
                        vectors[j] = CharVector.of(field, vector);
                    } else if (field.type == DataTypes.DoubleObjectType) {
                        double x = ((Number) value).doubleValue();
                        double[] vector = new double[n];
                        for (int i = 0; i < n; i++) {
                            Double cell = (Double) data.get(i, j);
                            vector[i] = cell == null || cell.isNaN() ? x : cell;
                        }
                        vectors[j] = DoubleVector.of(field, vector);
                    } else if (field.type == DataTypes.FloatObjectType) {
                        float x = ((Number) value).floatValue();
                        float[] vector = new float[n];
                        for (int i = 0; i < n; i++) {
                            Float cell = (Float) data.get(i, j);
                            vector[i] = cell == null || cell.isNaN() ? x : cell;
                        }
                        vectors[j] = FloatVector.of(field, vector);
                    } else if (field.type == DataTypes.IntegerObjectType) {
                        int x = ((Number) value).intValue();
                        int[] vector = new int[n];
                        for (int i = 0; i < n; i++) {
                            Integer cell = (Integer) data.get(i, j);
                            vector[i] = cell == null ? x : cell;
                        }
                        vectors[j] = IntVector.of(field, vector);
                    } else if (field.type == DataTypes.LongObjectType) {
                        long x = ((Number) value).longValue();
                        long[] vector = new long[n];
                        for (int i = 0; i < n; i++) {
                            Long cell = (Long) data.get(i, j);
                            vector[i] = cell == null ? x : cell;
                        }
                        vectors[j] = LongVector.of(field, vector);
                    } else if (field.type == DataTypes.ShortObjectType) {
                        short x = ((Number) value).shortValue();
                        short[] vector = new short[n];
                        for (int i = 0; i < n; i++) {
                            Short cell = (Short) data.get(i, j);
                            vector[i] = cell == null ? x : cell;
                        }
                        vectors[j] = ShortVector.of(field, vector);
                    } else {
                        // The array's component type is taken from the replacement value.
                        // NOTE(review): cells of a different runtime class would raise
                        // ArrayStoreException here.
                        Object[] vector = (Object[]) Array.newInstance(value.getClass(), n);
                        for (int i = 0; i < n; i++) {
                            Object cell = data.get(i, j);
                            vector[i] = cell == null ? value : cell;
                        }
                        vectors[j] = Vector.of(field, vector);
                    }
                }
            }
            if (vectors[j] == null) {
                // No replacement value, or a column type not handled above: reuse the column.
                vectors[j] = data.column(j);
            }
        });
        return DataFrame.of(vectors);
    }

    @Override
    public String toString() {
        return this.values.keySet().stream()
                .map(key -> key + " -> " + String.valueOf(this.values.get(key)))
                .collect(Collectors.joining(",\n  ", "SimpleImputer(\n  ", "\n)"));
    }

    /**
     * Fits the imputer with the median of each numeric column
     * (equivalent to {@code fit(data, 0.5, 0.5, columns)}).
     *
     * @param data the training data.
     * @param columns the columns to impute. If empty, all columns are used.
     * @return the imputer.
     */
    public static SimpleImputer fit(DataFrame data, String... columns) {
        return fit(data, 0.5, 0.5, columns);
    }

    /**
     * Fits the imputer.
     *
     * <ul>
     *   <li>String columns are imputed with the empty string.</li>
     *   <li>Boolean, char and nominal columns are imputed with the mode of the
     *       non-missing values.</li>
     *   <li>Other numeric columns are imputed with the {@code lower} quantile if
     *       {@code lower == upper}. Otherwise they are imputed with the mean of
     *       the values between the {@code lower} and {@code upper} quantiles
     *       (inclusive).</li>
     *   <li>Columns of any other type are skipped.</li>
     * </ul>
     *
     * @param data the training data.
     * @param lower the lower quantile bound, in [0, 1].
     * @param upper the upper quantile bound, in [lower, 1].
     * @param columns the columns to impute. If empty, all columns are used.
     * @return the imputer.
     * @throws IllegalArgumentException if {@code data} is empty or the bounds are invalid.
     */
    public static SimpleImputer fit(DataFrame data, double lower, double upper, String... columns) {
        if (data.isEmpty()) {
            throw new IllegalArgumentException("Empty data frame");
        }
        if (lower < 0.0) {
            throw new IllegalArgumentException("Invalid lower: " + lower);
        }
        if (upper > 1.0) {
            throw new IllegalArgumentException("Invalid upper: " + upper);
        }
        if (lower > upper) {
            throw new IllegalArgumentException(String.format("Invalid lower=%f > upper=%f", lower, upper));
        }

        StructType schema = data.schema();
        if (columns.length == 0) {
            columns = data.names();
        }

        Map<String, Object> values = new HashMap<>();
        for (String column : columns) {
            StructField field = schema.field(column);
            if (field.type.isString()) {
                values.put(field.name, "");
            } else if (field.type.isBoolean()) {
                // Integer.MIN_VALUE marks missing cells in the int encoding; omit them.
                int[] vector = MathEx.omit(data.column(column).toIntArray(), Integer.MIN_VALUE);
                values.put(field.name, MathEx.mode(vector) != 0);
            } else if (field.type.isChar()) {
                int[] vector = MathEx.omit(data.column(column).toIntArray(), Integer.MIN_VALUE);
                values.put(field.name, Character.valueOf((char) MathEx.mode(vector)));
            } else if (field.measure instanceof NominalScale) {
                int[] vector = MathEx.omit(data.column(column).toIntArray(), Integer.MIN_VALUE);
                values.put(field.name, MathEx.mode(vector));
            } else if (field.type.isNumeric()) {
                double[] vector = MathEx.omitNaN(data.column(column).toDoubleArray());
                values.put(field.name, robustCenter(vector, lower, upper));
            }
        }
        return new SimpleImputer(values);
    }

    /**
     * Computes either a single quantile or a trimmed mean of the values.
     * Quantiles are estimated with {@link IQAgent}.
     *
     * @param vector the non-missing values.
     * @param lower the lower quantile bound.
     * @param upper the upper quantile bound.
     * @return the {@code lower} quantile if {@code lower == upper}; otherwise the
     *         mean of the values in [q(lower), q(upper)]. This is NaN if no value
     *         falls in that range.
     */
    private static double robustCenter(double[] vector, double lower, double upper) {
        IQAgent agent = new IQAgent();
        for (double xi : vector) {
            agent.add(xi);
        }

        if (lower == upper) {
            return agent.quantile(lower);
        }

        double lo = agent.quantile(lower);
        double hi = agent.quantile(upper);
        int n = 0;
        double sum = 0.0;
        for (double xi : vector) {
            if (xi >= lo && xi <= hi) {
                n++;
                sum += xi;
            }
        }
        return sum / n;
    }

    /**
     * Imputes the NaN cells of a matrix with the column means of the non-missing
     * values. The input matrix is not modified.
     *
     * @param data the data matrix, one row per sample. All rows are assumed to
     *             be at least {@code data[0].length} long.
     * @return a copy of {@code data} with NaN cells replaced by column means.
     * @throws IllegalArgumentException if {@code data} is empty, or if a whole
     *                                  row or a whole column is missing.
     */
    public static double[][] impute(double[][] data) {
        if (data.length == 0) {
            throw new IllegalArgumentException("Empty data");
        }

        int d = data[0].length;
        double[] mean = new double[d];
        int[] count = new int[d];
        for (int i = 0; i < data.length; i++) {
            double[] x = data[i];
            int missing = 0;
            for (int j = 0; j < d; j++) {
                if (Double.isNaN(x[j])) {
                    missing++;
                } else {
                    count[j]++;
                    mean[j] += x[j];
                }
            }
            if (missing == d) {
                throw new IllegalArgumentException("The whole row " + i + " is missing");
            }
        }

        for (int j = 0; j < d; j++) {
            if (count[j] == 0) {
                throw new IllegalArgumentException("The whole column " + j + " is missing");
            }
            mean[j] /= count[j];
        }

        double[][] full = MathEx.clone(data);
        for (double[] x : full) {
            for (int j = 0; j < d; j++) {
                if (Double.isNaN(x[j])) {
                    x[j] = mean[j];
                }
            }
        }
        return full;
    }
}

