package ai.libs.jaicore.ml.core.dataset.sampling.inmemory.stratified.sampling;

import ai.libs.jaicore.ml.core.dataset.IDataset;
import ai.libs.jaicore.ml.core.dataset.IInstance;
import ai.libs.jaicore.ml.core.dataset.sampling.inmemory.stratified.sampling.DiscretizationHelper;
import com.google.common.collect.Lists;
import com.google.common.collect.Sets;
import java.util.ArrayList;
import java.util.Arrays;
import java.util.Collection;
import java.util.Collections;
import java.util.HashMap;
import java.util.HashSet;
import java.util.Iterator;
import java.util.List;
import java.util.Map;
import java.util.Set;
import java.util.concurrent.ExecutionException;
import java.util.concurrent.ExecutorService;
import java.util.concurrent.Executors;
import java.util.concurrent.Future;
import org.apache.commons.collections4.keyvalue.MultiKey;
import org.apache.commons.collections4.map.MultiKeyMap;
import org.slf4j.Logger;
import org.slf4j.LoggerFactory;

/**
 * Stratification component that both determines the number of strati and assigns instances to them, based on the
 * values of a configurable set of attributes.
 *
 * <p>The distinct values of every selected attribute are collected from the data set, using a pool of
 * {@link #getNumCPUs()} threads. Attributes that have a discretization policy have their values mapped to integer
 * categories by {@link DiscretizationHelper}. Every element of the cartesian product of the resulting value sets
 * becomes one stratum. An attribute index equal to {@code dataset.getNumberOfAttributes()} denotes the target
 * attribute.
 *
 * <p>NOTE(review): instances hold mutable per-data-set state without synchronization; presumably they are meant to be
 * used by a single thread at a time — confirm with callers.
 *
 * @param <I> the instance type of the data sets being stratified
 */
public class AttributeBasedStratiAmountSelectorAndAssigner<I extends IInstance> implements IStratiAmountSelector<I>, IStratiAssigner<I> {
    private static final Logger LOG = LoggerFactory.getLogger(AttributeBasedStratiAmountSelectorAndAssigner.class);
    private static final DiscretizationHelper.DiscretizationStrategy DEFAULT_DISCRETIZATION_STRATEGY = DiscretizationHelper.DiscretizationStrategy.EQUAL_SIZE;
    private static final int DEFAULT_DISCRETIZATION_CATEGORY_AMOUNT = 5;

    /** Distinct indices of the stratification attributes; {@code null} means "use the target attribute only". */
    private List<Integer> attributeIndices;
    /** Maps each combination of attribute values, ordered like {@link #attributeIndices}, to its stratum number. */
    private MultiKeyMap<Object, Integer> stratumAssignments;
    private int numCPUs = 1;
    private IDataset<I> dataset;
    /** Discretization policy per attribute index; {@code null} until provided or computed with defaults. */
    private Map<Integer, AttributeDiscretizationPolicy> discretizationPolicies;
    /** Distinct (possibly discretized) values observed per attribute index. */
    private Map<Integer, Set<Object>> attributeValues;
    // Initialized here so that every constructor, including the (List, Map) one, starts from usable defaults.
    private DiscretizationHelper.DiscretizationStrategy discretizationStrategy = DEFAULT_DISCRETIZATION_STRATEGY;
    private int numberOfCategories = DEFAULT_DISCRETIZATION_CATEGORY_AMOUNT;

    /**
     * Creates an instance that stratifies by the target attribute only, with default discretization settings.
     */
    public AttributeBasedStratiAmountSelectorAndAssigner() {
        // attributeIndices stays null; computeAttributeValues() then falls back to the target attribute
    }

    /**
     * Creates an instance that stratifies by the given attributes, with default discretization settings.
     *
     * @param list indices of the attributes to stratify by; duplicates are ignored
     * @throws IllegalArgumentException if {@code list} is {@code null} or empty
     */
    public AttributeBasedStratiAmountSelectorAndAssigner(List<Integer> list) {
        this(list, null);
    }

    /**
     * Creates an instance that stratifies by the given attributes. Default discretization policies are derived
     * with the given strategy and number of categories.
     *
     * @param list indices of the attributes to stratify by; duplicates are ignored
     * @param discretizationStrategy strategy used when computing default discretization policies
     * @param i number of categories used when computing default discretization policies
     * @throws IllegalArgumentException if {@code list} is {@code null} or empty
     */
    public AttributeBasedStratiAmountSelectorAndAssigner(List<Integer> list, DiscretizationHelper.DiscretizationStrategy discretizationStrategy, int i) {
        this(list, null);
        this.discretizationStrategy = discretizationStrategy;
        this.numberOfCategories = i;
    }

    /**
     * Creates an instance that stratifies by the given attributes using explicit discretization policies.
     *
     * @param list indices of the attributes to stratify by; duplicates are ignored
     * @param map discretization policy per attribute index, or {@code null} to compute default policies
     * @throws IllegalArgumentException if {@code list} is {@code null} or empty
     */
    public AttributeBasedStratiAmountSelectorAndAssigner(List<Integer> list, Map<Integer, AttributeDiscretizationPolicy> map) {
        if (list == null || list.isEmpty()) {
            throw new IllegalArgumentException("No attribute indices are provided!");
        }
        // Copy (so later caller mutations cannot interfere) and drop duplicates: a repeated index would make the
        // lookup keys built in assignToStrati longer than the combinations built in init.
        this.attributeIndices = distinctInOrder(list);
        this.discretizationPolicies = map;
    }

    /** Returns the distinct elements of {@code indices}, keeping the order of their first occurrence. */
    private static List<Integer> distinctInOrder(List<Integer> indices) {
        List<Integer> distinct = new ArrayList<>(indices.size());
        for (Integer index : indices) {
            if (!distinct.contains(index)) {
                distinct.add(index);
            }
        }
        return distinct;
    }

    /**
     * Computes the attribute values of the given data set and returns the number of strati.
     *
     * <p>The number of strati is the product of the numbers of distinct (possibly discretized) values of the
     * selected attributes.
     *
     * @param iDataset the data set to stratify
     * @return the number of strati needed
     * @throws ArithmeticException if the number of strati does not fit into an {@code int}
     */
    @Override
    public int selectStratiAmount(IDataset<I> iDataset) {
        this.dataset = iDataset;
        computeAttributeValues();
        int amount = 1;
        for (Set<Object> values : this.attributeValues.values()) {
            // multiplyExact turns a silent int overflow into an explicit failure
            amount = Math.multiplyExact(amount, values.size());
        }
        LOG.info("{} strati are needed", amount);
        return amount;
    }

    /**
     * (Re)computes {@link #attributeValues} for {@link #dataset}.
     *
     * <p>Steps:
     * <ol>
     * <li>default the attribute indices to the target attribute if none are set;</li>
     * <li>validate the indices;</li>
     * <li>collect the distinct values;</li>
     * <li>discretize them according to the (possibly default) policies.</li>
     * </ol>
     */
    private void computeAttributeValues() {
        LOG.debug("computeAttributeValues(): enter");
        int numberOfAttributes = this.dataset.getNumberOfAttributes();
        if (this.attributeIndices == null || this.attributeIndices.isEmpty()) {
            LOG.info("No attribute indices provided. Working with target attribute only (index: {})", numberOfAttributes);
            this.attributeIndices = Collections.singletonList(numberOfAttributes);
        }
        LOG.debug("Computing attribute values for attribute indices {}", this.attributeIndices);
        validateAttributeIndices(numberOfAttributes);
        this.attributeValues = new HashMap<>();
        for (Integer index : this.attributeIndices) {
            this.attributeValues.put(index, new HashSet<>());
        }
        collectAttributeValues();
        discretizeAttributeValues();
        LOG.debug("computeAttributeValues(): leave");
    }

    /**
     * Checks that every attribute index lies in {@code [0, numberOfAttributes]}. The upper bound itself is allowed
     * because it addresses the target attribute.
     *
     * @throws IndexOutOfBoundsException for the first index outside that range
     */
    private void validateAttributeIndices(int numberOfAttributes) {
        for (int index : this.attributeIndices) {
            if (index < 0 || index > numberOfAttributes) {
                throw new IndexOutOfBoundsException(String.format("Attribute index %d is out of bounds for the delivered data set!", index));
            }
        }
    }

    /**
     * Adds the distinct values of the selected attributes to {@link #attributeValues}.
     *
     * <p>The data set is split into at most {@link #numCPUs} chunks. The chunks are processed by {@link ListProcessor}
     * tasks on a pool of {@link #numCPUs} threads. A task that fails is logged and skipped, so its values are missing
     * from the result. If the calling thread is interrupted, waiting stops and the interrupt flag is restored; the
     * values gathered so far are kept.
     */
    private void collectAttributeValues() {
        ExecutorService threadPool = Executors.newFixedThreadPool(this.numCPUs);
        try {
            LOG.info("Starting {} threads for computation..", this.numCPUs);
            // Ceiling division yields at most numCPUs chunks. The lower bound of 1 avoids the
            // IllegalArgumentException that Lists.partition throws for size 0, which happens for
            // empty data sets or data sets smaller than numCPUs.
            int chunkSize = Math.max(1, (this.dataset.size() - 1) / this.numCPUs + 1);
            List<Future<Map<Integer, Set<Object>>>> futures = new ArrayList<>();
            for (List<I> chunk : Lists.partition(this.dataset, chunkSize)) {
                futures.add(threadPool.submit(new ListProcessor<>(chunk, new HashSet<>(this.attributeIndices), this.dataset)));
            }
            for (Future<Map<Integer, Set<Object>>> future : futures) {
                try {
                    Map<Integer, Set<Object>> partialValues = future.get();
                    for (Map.Entry<Integer, Set<Object>> entry : this.attributeValues.entrySet()) {
                        entry.getValue().addAll(partialValues.get(entry.getKey()));
                    }
                } catch (InterruptedException e) {
                    LOG.error("Thread has been interrupted");
                    Thread.currentThread().interrupt();
                    // With the interrupt flag set, every further get() would fail immediately, so stop waiting.
                    break;
                } catch (ExecutionException e) {
                    LOG.error("Exception while waiting for future to complete..", e);
                }
            }
        } finally {
            // On the normal path all futures have completed; otherwise this cancels the outstanding tasks.
            threadPool.shutdownNow();
        }
    }

    /**
     * Computes default discretization policies if none were provided. Then, if any policies exist, it passes them
     * together with {@link #attributeValues} to {@link DiscretizationHelper#discretizeAttributeValues}.
     */
    private void discretizeAttributeValues() {
        DiscretizationHelper discretizationHelper = new DiscretizationHelper();
        if (this.discretizationPolicies == null) {
            LOG.info("No discretization policies provided. Computing defaults..");
            this.discretizationPolicies = discretizationHelper.createDefaultDiscretizationPolicies(this.dataset, this.attributeIndices, this.attributeValues, this.discretizationStrategy, this.numberOfCategories);
        }
        if (!this.discretizationPolicies.isEmpty()) {
            LOG.info("Discretizing numeric attributes using policies: {}", this.discretizationPolicies);
            discretizationHelper.discretizeAttributeValues(this.discretizationPolicies, this.attributeValues);
        }
    }

    /**
     * Sets the number of threads used to collect attribute values.
     *
     * @param i number of threads; must be at least 1
     * @throws IllegalArgumentException if {@code i < 1}
     */
    @Override
    public void setNumCPUs(int i) {
        if (i < 1) {
            throw new IllegalArgumentException(String.format("Number of CPU cores must be positive, but was %d", i));
        }
        this.numCPUs = i;
    }

    /** Returns the number of threads used to collect attribute values. */
    @Override
    public int getNumCPUs() {
        return this.numCPUs;
    }

    /**
     * Initializes the stratum assignments for the given data set.
     *
     * <p>The requested strati amount {@code i} is ignored. The strati are derived from the attribute values,
     * see {@link #selectStratiAmount(IDataset)}.
     *
     * @param iDataset the data set whose instances will be assigned
     * @param i ignored
     */
    @Override
    public void init(IDataset<I> iDataset, int i) {
        init(iDataset);
    }

    /**
     * Initializes the stratum assignments for the given data set. Each combination in the cartesian product of the
     * attribute value sets receives a consecutive stratum number starting at 0. Attribute values are only recomputed
     * if the data set differs from the last one seen, or if none have been computed yet.
     *
     * @param iDataset the data set whose instances will be assigned
     * @throws IllegalStateException if a value combination occurs twice (should not happen for a cartesian product
     *         of sets)
     */
    public void init(IDataset<I> iDataset) {
        LOG.debug("init(): enter");
        if (this.dataset == null || !this.dataset.equals(iDataset) || this.attributeValues == null) {
            this.dataset = iDataset;
            computeAttributeValues();
        } else {
            LOG.info("No recomputation of the attribute values needed");
        }
        // assignToStrati builds its lookup keys in attributeIndices order. The iteration order of the
        // attributeValues hash map is unrelated to that order, so the value sets are listed explicitly here.
        List<Set<Object>> orderedValueSets = new ArrayList<>(this.attributeIndices.size());
        for (Integer index : this.attributeIndices) {
            orderedValueSets.add(this.attributeValues.get(index));
        }
        Set<List<Object>> cartesianProduct = Sets.cartesianProduct(orderedValueSets);
        this.stratumAssignments = new MultiKeyMap<>();
        LOG.info("There are {} elements in the cartesian product of the attribute values", cartesianProduct.size());
        LOG.info("Assigning stratum numbers to elements in the cartesian product..");
        int stratum = 0;
        for (List<Object> combination : cartesianProduct) {
            MultiKey<Object> multiKey = new MultiKey<>(combination.toArray());
            if (this.stratumAssignments.containsKey(multiKey)) {
                throw new IllegalStateException(String.format("Multikey %s occurred twice!", multiKey));
            }
            this.stratumAssignments.put(multiKey, stratum++);
        }
        LOG.debug("init(): leave");
    }

    /**
     * Returns the stratum of the given instance.
     *
     * <p>The instance's values of the selected attributes are looked up in the stratum assignments. Attributes with a
     * discretization policy are discretized first; such values must be numeric.
     *
     * @param iInstance the instance to assign
     * @return the stratum number of the instance's attribute value combination
     * @throws IllegalStateException if {@link #init(IDataset)} has not been called, or if the combination is unknown
     */
    @Override
    public int assignToStrati(IInstance iInstance) {
        if (this.stratumAssignments == null || this.stratumAssignments.isEmpty()) {
            throw new IllegalStateException("StratiAssigner has not been initialized!");
        }
        DiscretizationHelper discretizationHelper = new DiscretizationHelper();
        Object[] objArr = new Object[this.attributeIndices.size()];
        for (int i = 0; i < objArr.length; i++) {
            int attributeIndex = this.attributeIndices.get(i);
            Object value = readAttributeValue(iInstance, attributeIndex);
            if (toBeDiscretized(attributeIndex)) {
                // presumably yields the same integer categories as the discretized value sets used in init()
                value = discretizationHelper.discretize(((Number) value).doubleValue(), this.discretizationPolicies.get(attributeIndex));
            }
            objArr[i] = value;
        }
        if (LOG.isDebugEnabled()) {
            LOG.debug(String.format("Attribute values are: %s", Arrays.toString(objArr)));
        }
        // Stored values are never null, so a null result means "no assignment".
        Integer stratum = this.stratumAssignments.get(new MultiKey<>(objArr));
        if (stratum == null) {
            throw new IllegalStateException(String.format("No assignment available for attribute combination %s", Arrays.toString(objArr)));
        }
        LOG.debug("Assigned stratum {}", stratum);
        return stratum;
    }

    /**
     * Reads the raw value of an attribute of the given instance. The index {@code dataset.getNumberOfAttributes()}
     * refers to the target value.
     */
    private Object readAttributeValue(IInstance instance, int attributeIndex) {
        if (attributeIndex == this.dataset.getNumberOfAttributes()) {
            return instance.getTargetValue(Object.class).getValue();
        }
        return instance.getAttributeValue(attributeIndex, Object.class).getValue();
    }

    /** Returns whether a discretization policy exists for the given attribute index. */
    private boolean toBeDiscretized(int i) {
        return this.discretizationPolicies != null && this.discretizationPolicies.containsKey(i);
    }
}
