package ch.resear.thiriot.knime.bayesiannetworks.lib.sampling;

import cern.colt.function.DoubleFunction;
import ch.resear.thiriot.knime.bayesiannetworks.lib.ILogger;
import ch.resear.thiriot.knime.bayesiannetworks.lib.bn.CategoricalBayesianNetwork;
import ch.resear.thiriot.knime.bayesiannetworks.lib.bn.NodeCategorical;
import ch.resear.thiriot.knime.bayesiannetworks.lib.inference.AbstractInferenceEngine;
import java.util.ArrayList;
import java.util.Arrays;
import java.util.Collections;
import java.util.HashMap;
import java.util.List;
import java.util.Map;
import java.util.Spliterator;
import java.util.function.Consumer;
import java.util.stream.Collectors;
import org.knime.core.node.CanceledExecutionException;
import org.knime.core.node.ExecutionMonitor;

/* loaded from: input_file:readbnfromxmlbif.jar:ch/resear/thiriot/knime/bayesiannetworks/lib/sampling/RecursiveSamplingSpliterator.class */
/**
 * Enumerates, node by node, the combinations of values of a categorical Bayesian network to
 * generate, together with the number of entities to generate for each combination.
 *
 * <p>Each instance is responsible for one node. At construction it asks the inference engine for
 * p(node | evidence) over the node's domain, turns these probabilities into integer counts with
 * {@link #getCounts(int, double[])}, and keeps the values whose count is not zero. Traversal then
 * visits those values in domain order. When no node remains, one {@link EntitiesAndCount} is
 * emitted per value. Otherwise, a sub-iterator built by {@link #createSubIterator} explores the
 * next node, with the chosen value added to the evidence.
 *
 * <p>Each constructor clears and replaces the evidence of the inference engine it receives. If
 * sub-iterators share one engine (presumably, through {@code createSubIterator} implementations),
 * building them concurrently is unsafe. NOTE(review): confirm before using parallel streams.
 *
 * @param <R> type of the {@code rng} handed to subclasses; this class only stores it
 */
public abstract class RecursiveSamplingSpliterator<R extends DoubleFunction> implements Spliterator<EntitiesAndCount> {

    /** Path of node names from the root down to this node, used in debug messages. */
    protected final String name;
    /** Source of randomness for subclasses; not used directly by this class. */
    protected final R rng;
    /** Monitor checked for cancellation before each traversal step. */
    protected ExecutionMonitor exec;
    protected final ILogger logger;
    /** Whether debug messages are produced; read from the logger at construction. */
    protected final boolean debug;
    /** Node whose values this spliterator enumerates. */
    private final NodeCategorical node;
    /** Number of entities to generate per value of {@link #node}; zero counts are not stored. */
    private final Map<String, Integer> value2count;
    /** Values of {@link #node} not visited yet, in encounter order. */
    private List<String> domainToExplore;
    /** Sub-iterator over the remaining nodes for the value being explored, or {@code null}. */
    private RecursiveSamplingSpliterator<R> itSub;
    /** Values already fixed for the nodes above this one. */
    private final Map<NodeCategorical, String> evidence;
    /** Nodes still to explore below this one, in order. */
    private final List<NodeCategorical> remaining;
    /** Conditional distributions computed so far on this branch; passed down to sub-iterators. */
    private final Map<NodeCategorical, Map<String, Double>> alreadyComputedNow;
    protected final AbstractInferenceEngine engine;
    /** Number of entities this spliterator is asked to generate (updated when split). */
    private int count;

    /**
     * Creates the root spliterator that generates {@code i} entities over all the nodes returned
     * by {@link CategoricalBayesianNetwork#enumerateNodes()}, in that order.
     *
     * @param i number of entities to generate; must not be negative
     * @param categoricalBayesianNetwork network whose nodes are enumerated
     * @param r random source stored for subclasses
     * @param abstractInferenceEngine engine used to compute conditional probabilities
     * @param executionMonitor monitor checked for cancellation
     * @param iLogger logger for debug output
     * @throws IllegalArgumentException if {@code i} is negative or the network has no node
     */
    public RecursiveSamplingSpliterator(int i, CategoricalBayesianNetwork categoricalBayesianNetwork, R r, AbstractInferenceEngine abstractInferenceEngine, ExecutionMonitor executionMonitor, ILogger iLogger) {
        this(i, categoricalBayesianNetwork.enumerateNodes(), r, abstractInferenceEngine, executionMonitor, iLogger);
    }

    /**
     * Creates a spliterator that generates {@code i} entities over {@code list}, starting with
     * its first node and with no evidence.
     *
     * @throws IllegalArgumentException if {@code i} is negative or {@code list} is empty
     */
    protected RecursiveSamplingSpliterator(int i, List<NodeCategorical> list, R r, AbstractInferenceEngine abstractInferenceEngine, ExecutionMonitor executionMonitor, ILogger iLogger) {
        this(i, firstNode(list), list.subList(1, list.size()), Collections.emptyMap(), Collections.emptyMap(), r, abstractInferenceEngine, executionMonitor, iLogger, "", firstNode(list).getDomain());
    }

    /**
     * Splits {@code i} entities among the values of the current node.
     *
     * <p>Entry {@code k} of the result is the count for the {@code k}-th domain value of the node,
     * parallel to {@code dArr}. Values with a zero count are skipped. Presumably the counts sum to
     * {@code i}; NOTE(review): confirm against implementations.
     *
     * <p>This method is called from the constructor, so implementations run before the
     * subclass's own fields are initialized and must not rely on them.
     *
     * @param i number of entities to distribute
     * @param dArr conditional probability of each domain value of the node
     * @return one count per domain value
     */
    protected abstract int[] getCounts(int i, double[] dArr);

    /**
     * Creates a spliterator for {@code nodeCategorical}. It computes p(node | {@code map}) and
     * the counts for {@code i} entities, and explores only the values in {@code list2}.
     *
     * <p>This constructor is {@code public} in the compiled class; it is meant for subclasses.
     *
     * @param i number of entities to generate; must not be negative
     * @param nodeCategorical node enumerated by this spliterator
     * @param list nodes to explore below this one
     * @param map evidence (values fixed for the nodes above)
     * @param map2 distributions computed so far on this branch; copied, not modified
     * @param r random source stored for subclasses
     * @param abstractInferenceEngine engine whose evidence is replaced by {@code map}
     * @param executionMonitor monitor checked for cancellation
     * @param iLogger logger for debug output
     * @param str name of the parent, used to build {@link #name}
     * @param list2 values of {@code nodeCategorical} to explore; copied
     * @throws IllegalArgumentException if {@code i} is negative
     * @throws RuntimeException if the probabilities or the counts cannot be computed
     */
    public RecursiveSamplingSpliterator(int i, NodeCategorical nodeCategorical, List<NodeCategorical> list, Map<NodeCategorical, String> map, Map<NodeCategorical, Map<String, Double>> map2, R r, AbstractInferenceEngine abstractInferenceEngine, ExecutionMonitor executionMonitor, ILogger iLogger, String str, List<String> list2) {
        if (i < 0) {
            throw new IllegalArgumentException("count should not be negative");
        }
        this.itSub = null;
        this.exec = executionMonitor;
        this.node = nodeCategorical;
        this.name = String.valueOf(str) + " -- " + this.node.name;
        this.rng = r;
        this.evidence = map;
        this.remaining = list;
        this.logger = iLogger;
        this.engine = abstractInferenceEngine;
        this.debug = iLogger.isDebugEnabled();
        this.count = i;
        this.domainToExplore = new ArrayList<>(list2);
        if (this.debug) {
            this.logger.debug("iterator " + this.name + (list.isEmpty() ? " -| " : "") + " (generate " + i + ")");
            this.logger.debug("\tcomputing p(" + nodeCategorical.name + "|" + formatEvidence(map) + ")");
        }
        double[] probabilities = computeProbabilities(nodeCategorical, map, abstractInferenceEngine);
        this.alreadyComputedNow = new HashMap<>(map2);
        this.alreadyComputedNow.put(nodeCategorical, toDistribution(nodeCategorical, probabilities));
        this.value2count = computeValue2Count(nodeCategorical, i, probabilities);
    }

    /**
     * Creates a spliterator for {@code nodeCategorical} from counts that are already computed.
     * No inference is run. {@link #trySplit()} uses it (through
     * {@link #createSubIterator(int, NodeCategorical, List, Map, Map, List, Map)}) to hand part
     * of the domain to a new spliterator.
     *
     * <p>This constructor is {@code public} in the compiled class; it is meant for subclasses.
     *
     * @param map2 distributions computed so far; stored as-is (shared)
     * @param list2 values of {@code nodeCategorical} to explore; copied
     * @param map3 counts per value; stored as-is (shared)
     * @throws IllegalArgumentException if {@code i} is negative
     */
    public RecursiveSamplingSpliterator(int i, NodeCategorical nodeCategorical, List<NodeCategorical> list, Map<NodeCategorical, String> map, Map<NodeCategorical, Map<String, Double>> map2, R r, AbstractInferenceEngine abstractInferenceEngine, ExecutionMonitor executionMonitor, ILogger iLogger, String str, List<String> list2, Map<String, Integer> map3) {
        if (i < 0) {
            throw new IllegalArgumentException("count should not be negative");
        }
        this.itSub = null;
        this.exec = executionMonitor;
        this.node = nodeCategorical;
        this.name = String.valueOf(str) + " -- " + this.node.name;
        this.rng = r;
        this.evidence = map;
        this.remaining = list;
        this.logger = iLogger;
        this.engine = abstractInferenceEngine;
        // consistent with the other constructor: follow the logger instead of forcing debug on
        this.debug = iLogger.isDebugEnabled();
        this.count = i;
        this.domainToExplore = new ArrayList<>(list2);
        this.alreadyComputedNow = map2;
        this.value2count = map3;
    }

    /** Returns the first node of {@code nodes}, rejecting an empty list with a clear message. */
    private static NodeCategorical firstNode(List<NodeCategorical> nodes) {
        if (nodes.isEmpty()) {
            throw new IllegalArgumentException("the list of nodes to sample should not be empty");
        }
        return nodes.get(0);
    }

    /** Formats evidence as {@code name=value} pairs separated by commas, for messages. */
    private static String formatEvidence(Map<NodeCategorical, String> evidenceMap) {
        return evidenceMap.entrySet().stream()
                .map(entry -> String.valueOf(entry.getKey().name) + "=" + entry.getValue())
                .collect(Collectors.joining(","));
    }

    /**
     * Sets {@code evidenceMap} as the engine's evidence and returns p(value | evidence) for every
     * domain value of {@code target}, in domain order.
     *
     * @throws RuntimeException if the engine fails or yields no or NaN probabilities
     */
    private double[] computeProbabilities(NodeCategorical target, Map<NodeCategorical, String> evidenceMap, AbstractInferenceEngine inference) {
        inference.clearEvidence();
        inference.addEvidence(evidenceMap);
        try {
            double[] probabilities = target.getDomain().stream()
                    .mapToDouble(value -> inference.getConditionalProbability(target, value))
                    .toArray();
            if (this.debug) {
                this.logger.debug("\tprobabilities: " + Arrays.toString(probabilities));
            }
            // any NaN (not only the first entry) means the engine could not answer
            if (probabilities.length == 0 || Arrays.stream(probabilities).anyMatch(Double::isNaN)) {
                throw new RuntimeException("unable to compute p(" + target.name + "|" + formatEvidence(evidenceMap) + ")");
            }
            return probabilities;
        } catch (RuntimeException e) {
            // the cause is preserved, so no stack trace is printed here
            throw new RuntimeException("error when computing conditional probabilities for variable " + target + ": " + e.getMessage(), e);
        }
    }

    /** Maps each domain value of {@code target} to its probability. */
    private static Map<String, Double> toDistribution(NodeCategorical target, double[] probabilities) {
        Map<String, Double> distribution = new HashMap<>();
        for (int k = 0; k < probabilities.length; k++) {
            distribution.put(target.getDomain(k), probabilities[k]);
        }
        return distribution;
    }

    /**
     * Computes the counts with {@link #getCounts(int, double[])} and keeps the non-zero ones,
     * keyed by domain value.
     *
     * @throws RuntimeException if the counts cannot be computed
     */
    private Map<String, Integer> computeValue2Count(NodeCategorical target, int total, double[] probabilities) {
        try {
            int[] counts = getCounts(total, probabilities);
            if (this.debug) {
                this.logger.debug("\tcounts:");
                this.logger.debug("\t" + Arrays.toString(counts));
            }
            Map<String, Integer> result = new HashMap<>();
            for (int k = 0; k < probabilities.length; k++) {
                if (counts[k] != 0) {
                    result.put(target.getDomain(k), counts[k]);
                }
            }
            return result;
        } catch (RuntimeException e) {
            throw new RuntimeException("error computing the counts to generate for variable " + target + ": " + e.getMessage(), e);
        }
    }

    /**
     * Emits the next {@link EntitiesAndCount}, if any.
     *
     * <p>Follows the {@link Spliterator#tryAdvance} contract: it returns {@code true} exactly
     * when one element was passed to {@code consumer}. Cancellation of {@link #exec} ends the
     * traversal silently by returning {@code false}.
     */
    @Override
    public boolean tryAdvance(Consumer<? super EntitiesAndCount> consumer) {
        while (true) {
            try {
                this.exec.checkCanceled();
            } catch (CanceledExecutionException ignored) {
                // deliberate: a canceled execution simply ends the traversal
                return false;
            }
            if (this.itSub != null) {
                if (this.itSub.tryAdvance(consumer)) {
                    return true;
                }
                this.itSub = null;
            }
            // next value that has entities to generate (zero counts are absent from value2count)
            String value = null;
            Integer num = null;
            while (num == null && !this.domainToExplore.isEmpty()) {
                value = this.domainToExplore.remove(0);
                num = this.value2count.get(value);
            }
            if (num == null) {
                return false;
            }
            HashMap<NodeCategorical, String> extended = new HashMap<>(this.evidence);
            extended.put(this.node, value);
            if (this.remaining.isEmpty()) {
                consumer.accept(new EntitiesAndCount(extended, num));
                return true;
            }
            // loop again so that the new sub-iterator yields (at most) one element in this call
            this.itSub = createSubIterator(num.intValue(), this.remaining.get(0), this.remaining.subList(1, this.remaining.size()), extended, this.alreadyComputedNow, this.remaining.get(0).getDomain());
        }
    }

    /**
     * Splits off the first half of the values still to explore.
     *
     * <p>The returned spliterator covers a strict prefix of the remaining elements, as required
     * for an {@link #ORDERED} spliterator. That prefix is the element stream of the sub-iterator
     * in progress (if any), followed by the first half of the values. This spliterator keeps the
     * second half.
     *
     * @return the prefix spliterator, or {@code null} if fewer than two values remain
     */
    @Override
    public Spliterator<EntitiesAndCount> trySplit() {
        final int size = this.domainToExplore.size();
        if (size < 2) {
            return null;
        }
        final int half = size / 2;
        // independent copies rather than chained subList views of the same backing list
        List<String> prefix = new ArrayList<>(this.domainToExplore.subList(0, half));
        List<String> suffix = new ArrayList<>(this.domainToExplore.subList(half, size));
        if (this.logger.isDebugEnabled()) {
            this.logger.debug("splitting " + this.node.name + " between " + prefix + " and " + suffix);
        }
        RecursiveSamplingSpliterator<R> prefixSplit = createSubIterator(countOf(prefix), this.node, this.remaining, this.evidence, this.alreadyComputedNow, prefix, this.value2count);
        // the sub-iterator in progress precedes every value still to explore: hand it over
        prefixSplit.itSub = this.itSub;
        this.itSub = null;
        this.domainToExplore = suffix;
        this.count = countOf(suffix);
        return prefixSplit;
    }

    /** Sums the counts of {@code values}; values with a zero (absent) count contribute 0. */
    private int countOf(List<String> values) {
        return values.stream().mapToInt(value -> this.value2count.getOrDefault(value, 0)).sum();
    }

    /**
     * Creates the spliterator for the next node, {@code nodeCategorical}, given the evidence
     * {@code map}. It explores the values {@code list2} and generates {@code i} entities.
     */
    protected abstract RecursiveSamplingSpliterator<R> createSubIterator(int i, NodeCategorical nodeCategorical, List<NodeCategorical> list, Map<NodeCategorical, String> map, Map<NodeCategorical, Map<String, Double>> map2, List<String> list2);

    /**
     * Creates a spliterator for the same node with counts that are already computed
     * ({@code map3}). It is used by {@link #trySplit()}.
     */
    protected abstract RecursiveSamplingSpliterator<R> createSubIterator(int i, NodeCategorical nodeCategorical, List<NodeCategorical> list, Map<NodeCategorical, String> map, Map<NodeCategorical, Map<String, Double>> map2, List<String> list2, Map<String, Integer> map3);

    /**
     * Returns an upper bound on the remaining elements.
     *
     * <p>The bound is the number of values left times the product of the domain sizes of the
     * remaining nodes, plus the estimate of the sub-iterator in progress. It saturates at
     * {@link Long#MAX_VALUE} on overflow.
     */
    @Override
    public long estimateSize() {
        long size = this.domainToExplore.size();
        for (NodeCategorical next : this.remaining) {
            long domainSize = next.getDomainSize();
            if (domainSize > 0 && size > Long.MAX_VALUE / domainSize) {
                return Long.MAX_VALUE;
            }
            size *= domainSize;
        }
        if (this.itSub != null) {
            long pending = this.itSub.estimateSize();
            size = pending > Long.MAX_VALUE - size ? Long.MAX_VALUE : size + pending;
        }
        return size;
    }

    /**
     * Reports {@code ORDERED | DISTINCT | NONNULL}.
     *
     * <p>{@code SIZED} and {@code SUBSIZED} are not reported: combinations with a zero count are
     * skipped, so {@link #estimateSize()} is only an upper bound, not an exact count.
     */
    @Override
    public int characteristics() {
        return ORDERED | DISTINCT | NONNULL;
    }
}
