package ch.resear.thiriot.knime.bayesiannetworks.lib.inference;

import cern.jet.random.AbstractContinousDistribution;
import cern.jet.random.Uniform;
import cern.jet.random.engine.RandomEngine;
import ch.resear.thiriot.knime.bayesiannetworks.lib.ILogger;
import ch.resear.thiriot.knime.bayesiannetworks.lib.bn.CategoricalBayesianNetwork;
import ch.resear.thiriot.knime.bayesiannetworks.lib.bn.NodeCategorical;
import java.util.Arrays;
import java.util.HashMap;
import java.util.HashSet;
import java.util.Iterator;
import java.util.List;
import java.util.Map;
import java.util.Objects;
import java.util.Set;
import java.util.stream.Collectors;

/* loaded from: input_file:readbnfromxmlbif.jar:ch/resear/thiriot/knime/bayesiannetworks/lib/inference/AbstractInferenceEngine.class */
/**
 * Base class for inference engines working on a {@link CategoricalBayesianNetwork}.
 *
 * <p>This class stores the current evidence (the observed value of some nodes) and tracks, through
 * {@link #dirty}, whether that evidence changed since the last {@link #compute()}. On top of the
 * abstract retrieval methods it answers conditional probability queries and draws samples.
 * Subclasses implement {@link #retrieveConditionalProbability(NodeCategorical, String)},
 * {@link #retrieveConditionalProbability(NodeCategorical)} and {@link #computeProbabilityEvidence()},
 * and presumably override {@link #compute()} to run the actual inference.
 *
 * <p>Instances hold mutable state without any synchronization.
 */
public abstract class AbstractInferenceEngine {

    /** Logger used for progress and debug messages. */
    protected final ILogger logger;

    /** Uniform distribution, built from the engine given to the constructor, used when sampling. */
    protected final AbstractContinousDistribution rng;

    /** The Bayesian network inference runs on. */
    protected final CategoricalBayesianNetwork bn;

    /** Current evidence: the observed value of each observed node. */
    protected Map<NodeCategorical, String> evidenceVariable2value = new HashMap<>();

    /** True (initially) when the evidence changed since the last call to {@link #compute()}. */
    protected boolean dirty = true;

    /** Cached cumulated distributions used to sample nodes without parents. */
    private final Map<NodeCategorical, double[]> node2cumulatedProbability = new HashMap<>();

    /**
     * Evidence under which each entry of {@link #node2cumulatedProbability} was computed.
     * A cached entry is only valid while the evidence is unchanged.
     */
    private final Map<NodeCategorical, Map<NodeCategorical, String>> node2cumulatedProbabilityEvidence = new HashMap<>();

    /**
     * Creates an engine for the given network.
     *
     * @param iLogger logger receiving progress and debug messages
     * @param randomEngine source of randomness, wrapped into a {@link Uniform} distribution
     * @param categoricalBayesianNetwork the network to run inference on
     */
    public AbstractInferenceEngine(ILogger iLogger, RandomEngine randomEngine, CategoricalBayesianNetwork categoricalBayesianNetwork) {
        this.bn = categoricalBayesianNetwork;
        this.logger = iLogger;
        this.rng = new Uniform(randomEngine);
    }

    /** @return the network this engine runs inference on */
    public CategoricalBayesianNetwork getBN() {
        return this.bn;
    }

    /**
     * Observes {@code str} as the value of {@code nodeCategorical}, replacing any previous
     * observation of that node. The engine becomes dirty only if the observed value actually changed.
     *
     * @throws IllegalArgumentException if the node is not in the network or the value is not in its domain
     */
    public void addEvidence(NodeCategorical nodeCategorical, String str) {
        if (!this.bn.containsNode(nodeCategorical)) {
            throw new IllegalArgumentException("this node is not in the bn: " + nodeCategorical);
        }
        if (!nodeCategorical.contains(str)) {
            throw new IllegalArgumentException("value \"" + str + "\" unknown in node " + nodeCategorical + " (it contains " + nodeCategorical.getDomain() + ")");
        }
        final String previous = this.evidenceVariable2value.put(nodeCategorical, str);
        // compare by value: an equal String held by another instance is the same evidence
        if (!Objects.equals(previous, str)) {
            this.dirty = true;
        }
    }

    /**
     * Same as {@link #addEvidence(NodeCategorical, String)}, with the node identified by its name.
     *
     * @throws IllegalArgumentException if no node has this name, or the value is not in its domain
     */
    public void addEvidence(String str, String str2) {
        final NodeCategorical variable = this.bn.getVariable(str);
        if (variable == null) {
            throw new IllegalArgumentException("unknown node " + str);
        }
        addEvidence(variable, str2);
    }

    /** Removes the observation of the given node, if any; the engine becomes dirty only if one existed. */
    public void removeEvidence(NodeCategorical nodeCategorical) {
        if (this.evidenceVariable2value.remove(nodeCategorical) != null) {
            this.dirty = true;
        }
    }

    /** Removes all observations; does nothing (and stays clean) if there was none. */
    public void clearEvidence() {
        if (this.evidenceVariable2value.isEmpty()) {
            return;
        }
        this.dirty = true;
        this.evidenceVariable2value = new HashMap<>();
    }

    /** Marks the inference state as up to date; subclasses presumably extend this to run inference. */
    public void compute() {
        this.dirty = false;
    }

    /**
     * Retrieves the conditional probabilities of every node of the network and logs, at info level,
     * each node being processed, followed by the performance statistics.
     */
    public void computeAll() {
        // the retrieve* methods expect an up-to-date state, as ensured by getConditionalProbability
        if (this.dirty) {
            compute();
        }
        for (NodeCategorical nodeCategorical : this.bn.getNodes()) {
            if (this.logger.isInfoEnabled()) {
                this.logger.info("computing probability for " + nodeCategorical + " (" + nodeCategorical.getDomainSize() + " values: " + nodeCategorical.getDomain() + ")");
            }
            retrieveConditionalProbability(nodeCategorical);
            InferencePerformanceUtils.singleton.display(this.logger);
        }
    }

    /**
     * Returns the probability of {@code str} for {@code nodeCategorical} given the current evidence,
     * once the state is computed. The decompiler reports this method was originally protected.
     */
    public abstract double retrieveConditionalProbability(NodeCategorical nodeCategorical, String str);

    /**
     * Returns the probabilities of all domain values of {@code nodeCategorical} given the current
     * evidence, once the state is computed. The decompiler reports this method was originally protected.
     */
    public abstract double[] retrieveConditionalProbability(NodeCategorical nodeCategorical);

    /**
     * Returns the probability that {@code nodeCategorical} takes the value {@code str} given the
     * current evidence.
     *
     * <ul>
     *   <li>A node without parents, queried without any evidence, returns its prior.</li>
     *   <li>An observed node returns 1 for its observed value and 0 otherwise.</li>
     *   <li>Otherwise the state is computed if dirty and the subclass is queried.</li>
     * </ul>
     *
     * @throws IllegalArgumentException if {@code str} is not in the node's domain
     */
    public final double getConditionalProbability(NodeCategorical nodeCategorical, String str) {
        if (!nodeCategorical.contains(str)) {
            throw new IllegalArgumentException("there is no value " + str + " in the domain of variable " + nodeCategorical + " (use one of " + nodeCategorical.getDomain() + ")");
        }
        if (this.evidenceVariable2value.isEmpty() && !nodeCategorical.hasParents()) {
            return nodeCategorical.getProbability(str, new Object[0]);
        }
        final String observed = this.evidenceVariable2value.get(nodeCategorical);
        if (observed != null) {
            return observed.equals(str) ? 1.0d : 0.0d;
        }
        if (this.dirty) {
            compute();
        }
        return retrieveConditionalProbability(nodeCategorical, str);
    }

    /**
     * Same as {@link #getConditionalProbability(NodeCategorical, String)}, with the node identified by its name.
     *
     * @throws IllegalArgumentException if no node has this name, or the value is not in its domain
     */
    public final double getConditionalProbability(String str, String str2) {
        final NodeCategorical variable = this.bn.getVariable(str);
        if (variable == null) {
            throw new IllegalArgumentException("this Bayesian network does not contains a variable named " + str);
        }
        return getConditionalProbability(variable, str2);
    }

    /**
     * Returns the union of the ancestors (as given by {@code getAllAncestors()}) of the given node,
     * if not null, and of every evidence node. The third parameter only presizes the result.
     * The decompiler reports this method was originally protected.
     */
    public Set<NodeCategorical> selectRelevantVariables(NodeCategorical nodeCategorical, Map<NodeCategorical, String> map, Set<NodeCategorical> set) {
        final Set<NodeCategorical> relevant = new HashSet<>(set.size());
        if (nodeCategorical != null) {
            relevant.addAll(nodeCategorical.getAllAncestors());
        }
        for (NodeCategorical evidenceNode : map.keySet()) {
            relevant.addAll(evidenceNode.getAllAncestors());
        }
        return relevant;
    }

    /**
     * Returns the union of the ancestors (as given by {@code getAllAncestors()}) of every queried
     * node in {@code set} and of every evidence node. {@code set2} only presizes the result and
     * appears in debug logs. The decompiler reports this method was originally protected.
     */
    public Set<NodeCategorical> selectRelevantVariables(Set<NodeCategorical> set, Map<NodeCategorical, String> map, Set<NodeCategorical> set2) {
        if (this.logger.isDebugEnabled()) {
            this.logger.debug("select relevant variables to compute " + joinNames(set)
                    + " for evidence " + map.entrySet().stream()
                            .map(entry -> entry.getKey().getName() + "=" + entry.getValue())
                            .collect(Collectors.joining(","))
                    + " among " + joinNames(set2));
        }
        final Set<NodeCategorical> relevant = new HashSet<>(set2.size());
        for (NodeCategorical queried : set) {
            relevant.addAll(queried.getAllAncestors());
        }
        for (NodeCategorical evidenceNode : map.keySet()) {
            relevant.addAll(evidenceNode.getAllAncestors());
        }
        if (this.logger.isDebugEnabled()) {
            this.logger.debug("selected relevant variables " + joinNames(relevant));
        }
        return relevant;
    }

    /** Joins node names with commas, for log messages. */
    private static String joinNames(Set<NodeCategorical> nodes) {
        return nodes.stream().map(NodeCategorical::getName).collect(Collectors.joining(","));
    }

    /**
     * Resolves the node names, then delegates to {@link #computeFactorPosteriorMarginals(Set)}
     * (so the result reflects the current evidence).
     *
     * @throws IllegalArgumentException if a name matches no node of the network
     */
    public Factor computeFactorPriorMarginalsFromString(Set<String> set) {
        final Set<NodeCategorical> nodes = new HashSet<>(set.size());
        for (String name : set) {
            final NodeCategorical node = this.bn.getVariable(name);
            if (node == null) {
                // fail early instead of handing a null node to the subclass
                throw new IllegalArgumentException("unknown node " + name);
            }
            nodes.add(node);
        }
        return computeFactorPosteriorMarginals(nodes);
    }

    /**
     * Computes the marginals of the given nodes as a factor. Unsupported by default.
     *
     * @throws UnsupportedOperationException unless overridden by a subclass
     */
    public Factor computeFactorPosteriorMarginals(Set<NodeCategorical> set) {
        throw new UnsupportedOperationException("this inference engine does not computes prior marginals as factors");
    }

    /** Adds every entry of {@code map} as evidence, see {@link #addEvidence(NodeCategorical, String)}. */
    public void addEvidence(Map<NodeCategorical, String> map) {
        for (Map.Entry<NodeCategorical, String> entry : map.entrySet()) {
            addEvidence(entry.getKey(), entry.getValue());
        }
    }

    /**
     * Returns the cumulated conditional probabilities of the node's domain values, in domain order.
     * The result is cached per node together with the evidence it was computed under, and is
     * recomputed whenever the current evidence differs.
     */
    private double[] getCumulatedProbabilities(NodeCategorical nodeCategorical) {
        double[] cumulated = this.node2cumulatedProbability.get(nodeCategorical);
        if (cumulated != null
                && this.evidenceVariable2value.equals(this.node2cumulatedProbabilityEvidence.get(nodeCategorical))) {
            return cumulated;
        }
        final List<String> domain = nodeCategorical.getDomain();
        cumulated = new double[nodeCategorical.getDomainSize()];
        double total = 0.0d;
        for (int i = 0; i < domain.size(); i++) {
            total += getConditionalProbability(nodeCategorical, domain.get(i));
            cumulated[i] = total;
        }
        this.node2cumulatedProbability.put(nodeCategorical, cumulated);
        // snapshot: the live evidence map is mutated (or replaced) later
        this.node2cumulatedProbabilityEvidence.put(nodeCategorical, new HashMap<>(this.evidenceVariable2value));
        return cumulated;
    }

    /**
     * Picks the first domain value whose cumulated probability is at least {@code d}, using the
     * cached cumulated distribution and a binary search.
     */
    private String sampleValueWithProbaCache(NodeCategorical nodeCategorical, double d) {
        final double[] cumulated = getCumulatedProbabilities(nodeCategorical);
        final int found = Arrays.binarySearch(cumulated, d);
        // exact hit (including index 0) -> that index; otherwise the insertion point,
        // i.e. the first entry strictly greater than d
        int index = found >= 0 ? found : -found - 1;
        // rounding may leave the total slightly below d: fall back to the last entry
        if (index >= cumulated.length) {
            index = cumulated.length - 1;
        }
        // entries equal to their predecessor belong to zero-probability values: move back to the
        // first entry holding this cumulated value, as the linear scan would
        while (index > 0 && cumulated[index - 1] == cumulated[index]) {
            index--;
        }
        return nodeCategorical.getDomain(index);
    }

    /**
     * Picks the first domain value whose cumulated conditional probability (under the current
     * evidence) reaches {@code d}.
     *
     * @throws RuntimeException if the cumulated probabilities never reach {@code d}
     */
    private String sampleValueFromPosterior(NodeCategorical nodeCategorical, double d) {
        double cumulated = 0.0d;
        for (String value : nodeCategorical.getDomain()) {
            cumulated += getConditionalProbability(nodeCategorical, value);
            if (cumulated >= d) {
                return value;
            }
        }
        throw new RuntimeException("oops, should have picked a value based on postererior probabilities for variable " + nodeCategorical + " knowing " + this.evidenceVariable2value);
    }

    /**
     * Draws one value for every node, compatible with the current evidence.
     *
     * <p>Nodes are visited in the order of {@code bn.enumerateNodes()} (presumably a topological
     * order; TODO confirm). Observed nodes keep their observed value. Other nodes are drawn from
     * their conditional distribution given the evidence and the values drawn so far. The caller's
     * evidence is restored afterwards, including when sampling fails.
     *
     * @return the drawn value of each visited node
     * @throws IllegalArgumentException if the probability of the evidence is 0
     */
    public Map<NodeCategorical, String> sampleOne() {
        if (getProbabilityEvidence() == 0.0d) {
            throw new IllegalArgumentException("cannot generate if the probability of evidence is 0 - evidence is not possible");
        }
        final Map<NodeCategorical, String> initialEvidence = new HashMap<>(this.evidenceVariable2value);
        final Map<NodeCategorical, String> sample = new HashMap<>();
        try {
            for (NodeCategorical nodeCategorical : this.bn.enumerateNodes()) {
                final String value;
                if (initialEvidence.containsKey(nodeCategorical)) {
                    value = initialEvidence.get(nodeCategorical);
                } else {
                    final double random = this.rng.nextDouble();
                    value = nodeCategorical.hasParents()
                            ? sampleValueFromPosterior(nodeCategorical, random)
                            : sampleValueWithProbaCache(nodeCategorical, random);
                }
                sample.put(nodeCategorical, value);
                // condition the nodes visited next on the value just drawn
                addEvidence(nodeCategorical, value);
            }
        } finally {
            // restore the caller's evidence even if sampling failed midway
            clearEvidence();
            addEvidence(initialEvidence);
        }
        return sample;
    }

    /**
     * Returns the probability of the current evidence: 1 when there is none, otherwise computed by
     * the subclass after bringing the state up to date.
     */
    public final double getProbabilityEvidence() {
        if (this.evidenceVariable2value.isEmpty()) {
            return 1.0d;
        }
        if (this.dirty) {
            compute();
        }
        return computeProbabilityEvidence();
    }

    /**
     * Computes the probability of the current (non-empty) evidence.
     * The decompiler reports this method was originally protected.
     */
    public abstract double computeProbabilityEvidence();

    /**
     * Returns the node's evidence as an indicator array over its domain: 1 at the observed value's
     * index, 0 elsewhere. The decompiler reports this method was originally protected.
     *
     * @throws IllegalArgumentException if the node has no evidence
     */
    public double[] getEvidenceAsDoubleArray(NodeCategorical nodeCategorical) {
        final String observed = this.evidenceVariable2value.get(nodeCategorical);
        if (observed == null) {
            throw new IllegalArgumentException("cannot return evidence as a double array if there is not evidence for the current node");
        }
        final double[] indicator = new double[nodeCategorical.getDomainSize()];
        indicator[nodeCategorical.getDomainIndex(observed)] = 1.0d;
        return indicator;
    }
}
