package ch.resear.thiriot.knime.bayesiannetworks.lib.sampling;

import cern.colt.function.DoubleFunction;
import ch.resear.thiriot.knime.bayesiannetworks.lib.ILogger;
import ch.resear.thiriot.knime.bayesiannetworks.lib.bn.CategoricalBayesianNetwork;
import ch.resear.thiriot.knime.bayesiannetworks.lib.bn.NodeCategorical;
import ch.resear.thiriot.knime.bayesiannetworks.lib.inference.AbstractInferenceEngine;
import java.util.Arrays;
import java.util.Collections;
import java.util.HashMap;
import java.util.Iterator;
import java.util.List;
import java.util.Map;
import java.util.NoSuchElementException;
import java.util.stream.Collectors;
import org.knime.core.node.CanceledExecutionException;
import org.knime.core.node.ExecutionMonitor;

/* loaded from: input_file:readbnfromxmlbif.jar:ch/resear/thiriot/knime/bayesiannetworks/lib/sampling/RecursiveSamplingIterator.class */
/**
 * Lazily enumerates entities generated from a categorical Bayesian network by recursing over its nodes.
 *
 * <p>Each level does the following:
 * <ul>
 *   <li>handles one node;</li>
 *   <li>computes p(node | evidence accumulated by the parent levels);</li>
 *   <li>asks {@link #getCounts(int, double[])} how many entities to generate for each domain value;</li>
 *   <li>for every value with a non-zero count, lazily creates (from {@link #next()}) a sub-iterator
 *       over the remaining nodes, with that value added to the evidence.</li>
 * </ul>
 * The last level emits an {@link EntitiesAndCount} holding the full assignment and the count drawn
 * for it.
 *
 * <p>NOTE(review): the constructor resets the evidence of the given {@link AbstractInferenceEngine}.
 * Sub-iterators are built from {@link #next()} and presumably receive the same engine. Therefore the
 * engine should not be used elsewhere while iterating.
 *
 * <p>NOTE(review): {@link #getCounts(int, double[])} is invoked from the constructor. Subclass fields
 * are not yet initialized at that point, so implementations must not rely on them.
 *
 * @param <R> random generator stored in {@link #rng}; presumably used by subclasses to draw counts
 */
public abstract class RecursiveSamplingIterator<R extends DoubleFunction> implements Iterator<EntitiesAndCount> {

    /** True when constructed without evidence, i.e. the first level when built via the list constructor. */
    private final boolean isRoot;
    /** Path of node names from the root down to this level; used in debug and error messages. */
    protected final String name;
    protected final R rng;
    protected final ExecutionMonitor exec;
    protected final ILogger logger;
    /** Verbose logging switch; currently hard-coded to false. */
    protected final boolean debug;
    /** Node whose domain values are enumerated at this level. */
    private final NodeCategorical node;
    /** Domain values of {@link #node} having a non-zero count, with that count. */
    private final Iterator<Map.Entry<String, Integer>> itDomainAndCount;
    /** Iterator over the remaining nodes for the domain value being drained, or null. */
    private RecursiveSamplingIterator<R> itSub;
    /** Values already fixed by the parent levels. */
    private final Map<NodeCategorical, String> evidence;
    /** Nodes still to be handled below this level. */
    private final List<NodeCategorical> remaining;
    /** Probabilities received from the parents, plus those computed for {@link #node} at this level. */
    private final Map<NodeCategorical, Map<String, Double>> alreadyComputedNow;
    protected final AbstractInferenceEngine engine;

    /**
     * Creates a root iterator over all the nodes of {@code categoricalBayesianNetwork}, in the order
     * returned by {@code enumerateNodes()}.
     *
     * @param i number of entities to generate
     */
    public RecursiveSamplingIterator(int i, CategoricalBayesianNetwork categoricalBayesianNetwork, R r, AbstractInferenceEngine abstractInferenceEngine, ExecutionMonitor executionMonitor, ILogger iLogger) {
        this(i, categoricalBayesianNetwork.enumerateNodes(), r, abstractInferenceEngine, executionMonitor, iLogger);
    }

    /**
     * Creates a root iterator (no evidence) over the given nodes, handled in list order.
     *
     * @param i number of entities to generate
     * @throws IllegalArgumentException if {@code list} is null or empty
     */
    protected RecursiveSamplingIterator(int i, List<NodeCategorical> list, R r, AbstractInferenceEngine abstractInferenceEngine, ExecutionMonitor executionMonitor, ILogger iLogger) {
        this(i, firstNode(list), list.subList(1, list.size()), Collections.emptyMap(), Collections.emptyMap(), r, abstractInferenceEngine, executionMonitor, iLogger, "");
    }

    /** Returns the first node of the list, failing with a clear message when there is none. */
    private static NodeCategorical firstNode(List<NodeCategorical> nodes) {
        if (nodes == null || nodes.isEmpty()) {
            throw new IllegalArgumentException("cannot sample from an empty list of nodes");
        }
        return nodes.get(0);
    }

    /**
     * Distributes {@code count} entities among the domain values of the current node.
     *
     * @param count number of entities to distribute
     * @param probabilities conditional probability of each domain value, in domain order
     * @return one count per domain value. It must be at least as long as {@code probabilities}.
     *     Presumably the counts sum to {@code count}.
     */
    protected abstract int[] getCounts(int count, double[] probabilities);

    /**
     * Creates the iterator for one level and immediately computes the counts of its domain values.
     *
     * <p>NOTE(review): the decompiler reports this constructor was originally {@code protected}.
     *
     * @param count number of entities to generate at this level
     * @param currentNode node handled by this level
     * @param remainingNodes nodes handled by the sub-iterators
     * @param evidenceSoFar values fixed by the parent levels (empty for the root)
     * @param alreadyComputed probabilities computed by the parent levels
     * @param parentName name of the parent iterator, used to build {@link #name}
     * @throws RuntimeException if the probabilities or the counts cannot be computed
     */
    public RecursiveSamplingIterator(int count, NodeCategorical currentNode, List<NodeCategorical> remainingNodes, Map<NodeCategorical, String> evidenceSoFar, Map<NodeCategorical, Map<String, Double>> alreadyComputed, R random, AbstractInferenceEngine inferenceEngine, ExecutionMonitor executionMonitor, ILogger iLogger, String parentName) {
        this.itSub = null;
        this.isRoot = evidenceSoFar.isEmpty();
        this.exec = executionMonitor;
        this.node = currentNode;
        this.name = String.valueOf(parentName) + " -- " + currentNode.name;
        this.rng = random;
        this.evidence = evidenceSoFar;
        this.remaining = remainingNodes;
        this.logger = iLogger;
        this.engine = inferenceEngine;
        this.debug = false;
        if (this.debug) {
            this.logger.debug("iterator " + this.name + (remainingNodes.isEmpty() ? " -| " : "") + " (generate " + count + ")");
            this.logger.debug("\tcomputing p(" + currentNode.name + "|" + formatEvidence(evidenceSoFar) + ")");
        }

        // Restrict the engine to the evidence of this branch before querying it.
        inferenceEngine.clearEvidence();
        inferenceEngine.addEvidence(evidenceSoFar);

        final double[] probabilities;
        try {
            probabilities = currentNode.getDomain().stream()
                    .mapToDouble(value -> inferenceEngine.getConditionalProbability(currentNode, value))
                    .toArray();
            if (this.debug) {
                this.logger.debug("\tprobabilities: " + Arrays.toString(probabilities));
            }
            // Any NaN (not only the first value) means the inference failed for this evidence.
            if (probabilities.length == 0 || Arrays.stream(probabilities).anyMatch(Double::isNaN)) {
                throw new RuntimeException("unable to compute p(" + currentNode.name + "|" + formatEvidence(evidenceSoFar) + ")");
            }
        } catch (RuntimeException e) {
            throw new RuntimeException("error when computing conditional probabilities for variable " + currentNode + ": " + e.getMessage(), e);
        }

        this.alreadyComputedNow = new HashMap<>(alreadyComputed);
        Map<String, Double> probabilityOfValue = new HashMap<>();
        for (int k = 0; k < probabilities.length; k++) {
            probabilityOfValue.put(currentNode.getDomain(k), probabilities[k]);
        }
        this.alreadyComputedNow.put(currentNode, probabilityOfValue);

        // Kept a HashMap on purpose. Its iteration order decides the order in which next() creates the
        // sub-iterators. Changing the map type would change the generation order.
        Map<String, Integer> countOfValue = new HashMap<>();
        try {
            int[] counts = getCounts(count, probabilities);
            if (this.debug) {
                this.logger.debug("\tcounts:");
                this.logger.debug("\t" + Arrays.toString(counts));
            }
            for (int k = 0; k < probabilities.length; k++) {
                if (counts[k] != 0) {
                    countOfValue.put(currentNode.getDomain(k), counts[k]);
                }
            }
        } catch (RuntimeException e) {
            throw new RuntimeException("error computing the counts to generate for variable " + currentNode + ": " + e.getMessage(), e);
        }
        this.itDomainAndCount = countOfValue.entrySet().iterator();
    }

    /** Formats evidence as {@code name=value} pairs separated by commas. */
    private static String formatEvidence(Map<NodeCategorical, String> evidence) {
        return evidence.entrySet().stream()
                .map(entry -> entry.getKey().name + "=" + entry.getValue())
                .collect(Collectors.joining(","));
    }

    /**
     * Tells whether this iterator can still produce an element.
     *
     * <p>It does not check for cancellation and does not modify the iterator.
     */
    private boolean hasPending() {
        return (this.itSub != null && this.itSub.hasPending()) || this.itDomainAndCount.hasNext();
    }

    /**
     * Returns true while entities remain to be generated.
     *
     * <p>It returns false as soon as the execution monitor reports a cancellation.
     */
    @Override
    public boolean hasNext() {
        try {
            this.exec.checkCanceled();
        } catch (CanceledExecutionException ignored) {
            // A user cancellation simply ends the iteration.
            return false;
        }
        if (this.itSub != null && this.itSub.hasNext()) {
            return true;
        }
        this.itSub = null;
        return this.itDomainAndCount.hasNext();
    }

    /**
     * Returns the next generated assignment and its count.
     *
     * <p>It works whether or not {@link #hasNext()} was called first.
     *
     * @throws NoSuchElementException when all the entities have been enumerated
     * @throws RuntimeException if a sub-iterator cannot be created
     */
    @Override
    public EntitiesAndCount next() {
        // Finish draining the branch of the current domain value first.
        if (this.itSub != null && this.itSub.hasPending()) {
            return this.itSub.next();
        }
        this.itSub = null;
        if (!this.itDomainAndCount.hasNext()) {
            throw new NoSuchElementException("no more entities to generate in iterator " + this.name);
        }

        Map.Entry<String, Integer> domainAndCount = this.itDomainAndCount.next();
        HashMap<NodeCategorical, String> assignment = new HashMap<>(this.evidence);
        assignment.put(this.node, domainAndCount.getKey());
        if (this.remaining.isEmpty()) {
            // Last node: the assignment is complete.
            return new EntitiesAndCount(assignment, domainAndCount.getValue());
        }

        try {
            this.itSub = createSubIterator(domainAndCount.getValue().intValue(), this.remaining.get(0), this.remaining.subList(1, this.remaining.size()), assignment, this.alreadyComputedNow);
        } catch (RuntimeException e) {
            throw new RuntimeException("Error computing next recursive iterator " + e.getMessage(), e);
        }
        return this.itSub.next();
    }

    /**
     * Creates the iterator handling the next node of a branch.
     *
     * @param count number of entities to generate in the branch
     * @param nextNode node handled by the sub-iterator
     * @param nodesAfter nodes handled below the sub-iterator
     * @param branchEvidence evidence of the branch, including the value chosen at this level
     * @param computed probabilities computed so far
     */
    protected abstract RecursiveSamplingIterator<R> createSubIterator(int count, NodeCategorical nextNode, List<NodeCategorical> nodesAfter, Map<NodeCategorical, String> branchEvidence, Map<NodeCategorical, Map<String, Double>> computed);
}
