/*
 * Decompiled with CFR 0.152.
 */
package smile.feature.extraction;

import java.util.Arrays;
import java.util.function.Function;
import smile.data.DataFrame;
import smile.data.Tuple;
import smile.data.measure.CategoricalMeasure;
import smile.data.measure.Measure;
import smile.data.type.StructField;
import smile.data.type.StructType;
import smile.util.SparseArray;

/**
 * Encodes the selected numeric and categorical columns of a tuple into a
 * sparse feature vector.
 *
 * <p>Columns are laid out in the order given. Each numeric column occupies
 * one feature slot holding the column's {@code double} value. Each
 * categorical column occupies {@code size()} consecutive slots (as reported
 * by its {@link CategoricalMeasure}), and the slot at offset
 * {@code x.getInt(column)} within that block is set to {@code 1.0}.
 */
public class SparseEncoder
implements Function<Tuple, SparseArray> {
    /** The schema used to resolve column types. */
    private final StructType schema;
    /** The encoded column names, in feature-layout order. */
    private final String[] columns;
    /** The starting feature index of each column. */
    private final int[] base;
    /**
     * Whether each column is numeric ({@code true}) or categorical
     * ({@code false}). It is computed once in the constructor so that
     * {@link #apply(Tuple)} does not look up the schema for every row.
     */
    private final boolean[] numeric;

    /**
     * Constructor.
     *
     * @param schema the schema of the data to encode.
     * @param columns the columns to encode. If null or empty, all numeric
     *                and categorical columns of the schema are used.
     * @throws IllegalArgumentException if a column is neither numeric nor
     *                                  categorical.
     */
    public SparseEncoder(StructType schema, String ... columns) {
        this.schema = schema;
        if (columns == null || columns.length == 0) {
            columns = Arrays.stream(schema.fields())
                    .filter(field -> field.isNumeric() || field.measure instanceof CategoricalMeasure)
                    .map(field -> field.name)
                    .toArray(String[]::new);
        } else {
            // Defensive copy: the caller may keep and later mutate the varargs array.
            columns = columns.clone();
        }
        this.columns = columns;
        this.base = new int[columns.length];
        this.numeric = new boolean[columns.length];

        // Running start index of the next column's block of feature slots.
        int offset = 0;
        for (int i = 0; i < columns.length; i++) {
            StructField field = schema.field(columns[i]);
            this.base[i] = offset;
            if (field.isNumeric()) {
                this.numeric[i] = true;
                offset += 1;
                continue;
            }
            Measure measure = field.measure;
            if (measure instanceof CategoricalMeasure) {
                offset += ((CategoricalMeasure) measure).size();
                continue;
            }
            throw new IllegalArgumentException(String.format("Column '%s' is neither numeric nor categorical", field.name));
        }
    }

    /**
     * Encodes a tuple into a sparse feature vector.
     *
     * @param x the tuple to encode. It is read by column name.
     * @return the sparse feature vector.
     */
    @Override
    public SparseArray apply(Tuple x) {
        SparseArray features = new SparseArray();
        for (int i = 0; i < this.columns.length; i++) {
            if (this.numeric[i]) {
                features.append(this.base[i], x.getDouble(this.columns[i]));
            } else {
                // One-hot within this column's block.
                // NOTE(review): this assumes getInt returns a level index in
                // [0, size()) -- TODO confirm against CategoricalMeasure's encoding.
                features.append(this.base[i] + x.getInt(this.columns[i]), 1.0);
            }
        }
        return features;
    }

    /**
     * Encodes every row of a data frame.
     *
     * <p>This is an overload, not an override: {@link Function} declares no
     * {@code apply(DataFrame)} method.
     *
     * @param data the data frame to encode.
     * @return one sparse feature vector per row, in row order.
     */
    public SparseArray[] apply(DataFrame data) {
        return data.stream().map(this).toArray(SparseArray[]::new);
    }
}

