package ch.resear.thiriot.knime.bayesiannetworks.lib.inference;

import cern.jet.random.engine.RandomEngine;
import ch.resear.thiriot.knime.bayesiannetworks.lib.ILogger;
import ch.resear.thiriot.knime.bayesiannetworks.lib.bn.CategoricalBayesianNetwork;
import ch.resear.thiriot.knime.bayesiannetworks.lib.bn.IteratorCategoricalVariables;
import ch.resear.thiriot.knime.bayesiannetworks.lib.bn.NodeCategorical;
import java.util.Arrays;
import java.util.Collections;
import java.util.HashMap;
import java.util.HashSet;
import java.util.Iterator;
import java.util.Map;
import java.util.Set;
import org.apache.commons.collections4.map.LRUMap;

/* loaded from: input_file:readbnfromxmlbif.jar:ch/resear/thiriot/knime/bayesiannetworks/lib/inference/SimpleConditionningInferenceEngine.class */
/**
 * Inference engine for categorical Bayesian networks which answers posterior queries by
 * conditioning. A conditional probability is obtained in three steps:
 * <ol>
 *   <li>enumerate the assignments of the variables left free;</li>
 *   <li>sum the joint probabilities returned by the network;</li>
 *   <li>divide by the probability of the evidence.</li>
 * </ol>
 *
 * <p>Three memoization layers are used:
 * <ul>
 *   <li>{@code computed}: per-node posterior arrays for the current evidence, cleared by
 *       {@link #compute()}. {@code -1} marks a value which is not computed yet.</li>
 *   <li>{@code known2nuisance2value}: LRU cache of sums, keyed first by the fixed assignment and
 *       then by the set of summed-out variables.</li>
 *   <li>{@code evidence2proba}: LRU cache of the probability of an evidence assignment.</li>
 * </ul>
 *
 * <p>NOTE(review): no synchronization is performed here; presumably an instance is used by one
 * thread at a time — confirm against callers.
 */
public class SimpleConditionningInferenceEngine extends AbstractInferenceEngine {

    /** Sentinel stored in {@code computed} arrays for values which were not computed yet. */
    private static final double NOT_COMPUTED = -1.0d;

    /** Maximum number of fixed assignments kept in {@code known2nuisance2value}. */
    private static final int CACHE_MAXITEMS = 100;
    /** Maximum number of nuisance sets kept per fixed assignment in {@code known2nuisance2value}. */
    private static final int CACHE_MAXITEMS2 = 100;
    /** Maximum number of evidence assignments kept in {@code evidence2proba}. */
    private static final int CACHE_EVIDENCE = 5000;

    /** Posterior probabilities per node for the current evidence, indexed by domain index. */
    private final Map<NodeCategorical, double[]> computed;

    /** Lazily created cache: fixed assignment -> summed-out variables -> sum of joint probabilities. */
    private LRUMap<Map<NodeCategorical, String>, Map<Set<NodeCategorical>, Double>> known2nuisance2value;

    /** Lazily created cache: evidence assignment -> probability of that evidence. */
    private LRUMap<Map<NodeCategorical, String>, Double> evidence2proba;

    /** Whether the logger had debug enabled when this engine was constructed. */
    final boolean debug;

    /**
     * Creates an engine for the given network.
     *
     * @param iLogger logger used for debug and error messages
     * @param randomEngine random engine, forwarded to the parent class
     * @param categoricalBayesianNetwork the network to run inference on
     */
    public SimpleConditionningInferenceEngine(ILogger iLogger, RandomEngine randomEngine, CategoricalBayesianNetwork categoricalBayesianNetwork) {
        super(iLogger, randomEngine, categoricalBayesianNetwork);
        this.computed = new HashMap<>();
        this.known2nuisance2value = null;
        this.evidence2proba = null;
        this.debug = iLogger.isDebugEnabled();
    }

    /**
     * Returns p(node=value | current evidence), reusing and filling the per-node cache.
     *
     * <p>If every other value of the node is already known, the result is derived from them
     * instead of being recomputed. The shortcuts are:
     * <ul>
     *   <li>if the known values already sum to 1 or more, the result is 0;</li>
     *   <li>if this is the only missing value, the result is 1 minus their sum.</li>
     * </ul>
     *
     * @param nodeCategorical the queried node
     * @param str the queried value of the node
     * @return the posterior probability of that value
     * @throws RuntimeException if a negative probability is computed
     */
    @Override
    public double retrieveConditionalProbability(NodeCategorical nodeCategorical, String str) {
        if (this.debug) {
            this.logger.debug("p(" + nodeCategorical.name + "=" + str + "|" + this.evidenceVariable2value + ")");
        }

        // an observed variable is deterministic: 1 for its observed value, 0 for any other
        String observed = this.evidenceVariable2value.get(nodeCategorical);
        if (observed != null) {
            return observed.equals(str) ? 1.0d : 0.0d;
        }

        int index = nodeCategorical.getDomainIndex(str);
        double[] cache = this.computed.get(nodeCategorical);
        if (cache == null) {
            cache = new double[nodeCategorical.getDomainSize()];
            Arrays.fill(cache, NOT_COMPUTED);
            this.computed.put(nodeCategorical, cache);
        } else if (cache[index] >= 0.0d) {
            return cache[index];
        }

        // sum the probabilities already known for the other values, ignoring the sentinels
        double knownSum = 0.0d;
        int knownCount = 0;
        for (double p : cache) {
            if (p >= 0.0d) {
                knownSum += p;
                knownCount++;
            }
        }

        double result;
        if (knownSum >= 1.0d) {
            if (this.debug) {
                this.logger.debug("we can save one computation here by doing p(X=x)=0 because sum(p(X=^x))>=1");
            }
            result = 0.0d;
        } else if (knownCount == cache.length - 1) {
            // every value but the queried one is known: probabilities sum to 1
            if (this.debug) {
                this.logger.debug("we can save one computation here by doing p(X=x)=1 - sum(p(X=^x))");
            }
            result = Math.max(0.0d, 1.0d - knownSum);
        } else {
            if (this.debug) {
                this.logger.debug("no value computed for p(" + nodeCategorical.name + "=" + str + "|" + this.evidenceVariable2value + "), starting computation...");
            }
            result = computePosteriorConditionalProbability(nodeCategorical, str, this.evidenceVariable2value);
        }

        if (result < 0.0d) {
            throw new RuntimeException("negative probability computed for p(" + nodeCategorical.name + "=" + str + "|" + this.evidenceVariable2value + ")");
        }
        cache[index] = result;
        if (this.debug) {
            this.logger.debug("returning p(" + nodeCategorical.name + "=" + str + "|" + this.evidenceVariable2value + ")=" + result);
        }
        return result;
    }

    /**
     * Returns the posterior distribution p(node=* | current evidence), indexed by domain index.
     *
     * <p>For an observed node this is a fresh array with 1 at the observed value and 0
     * elsewhere. Otherwise the cached array is returned when it is complete. An array which was
     * only partially filled by the single-value overload is recomputed in full, so callers never
     * see the "not computed" sentinel.
     *
     * @param nodeCategorical the queried node
     * @return the posterior probability of each value of the node
     */
    @Override
    public double[] retrieveConditionalProbability(NodeCategorical nodeCategorical) {
        String observed = this.evidenceVariable2value.get(nodeCategorical);
        if (observed != null) {
            double[] pointMass = new double[nodeCategorical.getDomainSize()];
            for (int i = 0; i < pointMass.length; i++) {
                if (observed.equals(nodeCategorical.getValueIndexed(i))) {
                    pointMass[i] = 1.0d;
                    break;
                }
            }
            return pointMass;
        }

        double[] distribution = this.computed.get(nodeCategorical);
        if (distribution != null && isComplete(distribution)) {
            return distribution;
        }
        distribution = computePosteriorConditionalProbability(nodeCategorical, this.evidenceVariable2value);
        this.computed.put(nodeCategorical, distribution);
        if (this.debug) {
            this.logger.debug("returning p(" + nodeCategorical.name + "=*|" + this.evidenceVariable2value + ") : " + Arrays.toString(distribution));
        }
        return distribution;
    }

    /** Returns true when no entry of the cached array is still the "not computed" sentinel. */
    private static boolean isComplete(double[] probabilities) {
        for (double p : probabilities) {
            if (p < 0.0d) {
                return false;
            }
        }
        return true;
    }

    /**
     * Returns the nodes of {@code set} which are not a parent of any node of {@code set}.
     *
     * @param set the nodes to inspect (not modified)
     * @return a new set containing the leaves of {@code set}
     */
    protected Set<NodeCategorical> getLeaf(Set<NodeCategorical> set) {
        if (this.debug) {
            this.logger.debug("searching for the leafs of " + set);
        }
        Set<NodeCategorical> leaves = new HashSet<>(set);
        for (NodeCategorical node : set) {
            leaves.removeAll(node.getParents());
        }
        if (this.debug) {
            this.logger.debug("leafs of " + set + " are " + leaves);
        }
        return leaves;
    }

    /** Returns the sum cache, creating it on first use. */
    private LRUMap<Map<NodeCategorical, String>, Map<Set<NodeCategorical>, Double>> nuisanceCache() {
        if (this.known2nuisance2value == null) {
            this.known2nuisance2value = new LRUMap<>(CACHE_MAXITEMS);
        }
        return this.known2nuisance2value;
    }

    /**
     * Looks up a previously computed sum.
     *
     * <p>A hit is counted only when a value is actually found. Finding the fixed assignment
     * without the requested nuisance set counts as a miss.
     *
     * @param map the fixed assignment
     * @param set the summed-out variables
     * @return the cached sum, or null if unknown
     */
    private Double getCached(Map<NodeCategorical, String> map, Set<NodeCategorical> set) {
        Map<Set<NodeCategorical>, Double> byNuisance = nuisanceCache().get(map);
        Double value = (byNuisance == null) ? null : byNuisance.get(set);
        if (value == null) {
            InferencePerformanceUtils.singleton.incCacheMiss();
        } else {
            InferencePerformanceUtils.singleton.incCacheHit();
        }
        return value;
    }

    /**
     * Stores a computed sum.
     *
     * <p>The assignment map is copied before being used as a key. Callers pass maps, such as the
     * current evidence, which may be modified later, and a hash key must not change after
     * insertion.
     *
     * @param map the fixed assignment
     * @param set the summed-out variables (expected not to be modified afterwards)
     * @param d the sum to store
     */
    private void storeCache(Map<NodeCategorical, String> map, Set<NodeCategorical> set, Double d) {
        LRUMap<Map<NodeCategorical, String>, Map<Set<NodeCategorical>, Double>> cache = nuisanceCache();
        Map<Set<NodeCategorical>, Double> byNuisance = cache.get(map);
        if (byNuisance == null) {
            byNuisance = new LRUMap<>(CACHE_MAXITEMS2);
            cache.put(new HashMap<>(map), byNuisance);
        }
        byNuisance.put(set, d);
    }

    /**
     * Sums the joint probabilities of the fixed assignment {@code map} over every assignment of
     * the variables of {@code set} which {@code map} does not already fix.
     *
     * <p>The summation stops as soon as the running total reaches 1, clamping it to 1. Results
     * are cached.
     *
     * @param map the fixed assignment (not modified)
     * @param set the variables to consider; those also present in {@code map} stay fixed
     * @return the summed probability, or 1 when both arguments are empty
     */
    protected double sumProbabilities(Map<NodeCategorical, String> map, Set<NodeCategorical> set) {
        if (set.isEmpty() && map.isEmpty()) {
            return 1.0d;
        }
        Set<NodeCategorical> nuisance = new HashSet<>(set);
        nuisance.removeAll(map.keySet());

        Double cached = getCached(map, nuisance);
        if (cached != null) {
            return cached.doubleValue();
        }
        if (this.debug) {
            this.logger.debug("summing probabilities for known " + map + ", and nuisance " + nuisance);
        }

        double total = 0.0d;
        IteratorCategoricalVariables assignments = this.bn.iterateDomains(nuisance);
        while (assignments.hasNext()) {
            Map<NodeCategorical, String> assignment = assignments.next();
            assignment.putAll(map);
            double jointProbability = this.bn.jointProbability(assignment, Collections.emptyMap());
            if (this.debug) {
                this.logger.debug("p(" + assignment + ")=" + jointProbability);
            }
            total += jointProbability;
            InferencePerformanceUtils.singleton.incAdditions();
            if (total >= 1.0d) {
                // a probability cannot exceed 1: stop early and absorb rounding overshoot
                total = 1.0d;
                break;
            }
        }

        storeCache(map, nuisance, total);
        if (this.debug) {
            this.logger.debug("total " + total);
        }
        return total;
    }

    /**
     * Computes the full posterior distribution p(node=* | map).
     *
     * <p>Each value's joint probability with {@code map} is divided by the evidence probability
     * from {@code getProbabilityEvidence()}. If that probability is 0 the posterior is undefined:
     * an error is logged and an all-zero array is returned instead of NaN values.
     *
     * @param nodeCategorical the queried node
     * @param map the evidence assignment (not modified)
     * @return the posterior probability of each value of the node, indexed by domain index
     */
    protected double[] computePosteriorConditionalProbability(NodeCategorical nodeCategorical, Map<NodeCategorical, String> map) {
        int domainSize = nodeCategorical.getDomainSize();
        double[] posterior = new double[domainSize];
        double probabilityEvidence = getProbabilityEvidence();
        if (probabilityEvidence == 0.0d) {
            this.logger.error("unable to compute probabilities p(" + nodeCategorical.name + "=*|*): the evidence has probability 0");
            return posterior;
        }

        // the relevant variables depend only on the node and the evidence: select them once
        Set<NodeCategorical> relevant = selectRelevantVariables(nodeCategorical, map, (Set<NodeCategorical>) this.bn.nodes);

        for (int i = 0; i < domainSize; i++) {
            String value = nodeCategorical.getValueIndexed(i);
            if (this.debug) {
                this.logger.debug("computing p(*=*|" + nodeCategorical.name + "=" + value + ")");
            }
            Map<NodeCategorical, String> withValue = new HashMap<>(map);
            withValue.put(nodeCategorical, value);
            double joint = sumProbabilities(withValue, relevant);
            if (this.debug) {
                this.logger.debug("computed p(" + nodeCategorical.name + "=" + value + "|" + withValue + "," + nodeCategorical.name + "=" + value + ")=" + joint);
                this.logger.debug("computed p(*=*|" + nodeCategorical.name + "=" + value + ")=" + joint);
            }
            posterior[i] = joint;
        }

        if (this.debug) {
            this.logger.debug("now computing the overall probas");
        }
        for (int i = 0; i < domainSize; i++) {
            double joint = posterior[i];
            posterior[i] = joint / probabilityEvidence;
            if (this.debug) {
                String value = nodeCategorical.getValueIndexed(i);
                this.logger.debug("computed p({" + nodeCategorical.name + "=" + value + "|evidence)= p(" + nodeCategorical.name + "=" + value + "|evidence)/p(" + nodeCategorical.name + "|evidence)=" + joint + "/" + probabilityEvidence + "=" + posterior[i]);
            }
        }
        return posterior;
    }

    /**
     * Computes p(node=value | map) as p(node=value, map) / p(map).
     *
     * <p>Floating-point division never throws in Java; dividing by zero silently yields NaN or
     * Infinity. A zero p(map) is therefore checked explicitly: an error is logged and 0 is
     * returned.
     *
     * @param nodeCategorical the queried node
     * @param str the queried value
     * @param map the evidence assignment (not modified)
     * @return the posterior probability, or 0 if the evidence has probability 0
     */
    protected double computePosteriorConditionalProbability(NodeCategorical nodeCategorical, String str, Map<NodeCategorical, String> map) {
        double probabilityEvidence = sumProbabilities(map, selectRelevantVariables((NodeCategorical) null, map, (Set<NodeCategorical>) this.bn.nodes));
        if (this.debug) {
            this.logger.debug("computing p(*=*|" + nodeCategorical.name + "=" + str + ")");
        }
        Map<NodeCategorical, String> withValue = new HashMap<>(map);
        withValue.put(nodeCategorical, str);
        double joint = sumProbabilities(withValue, selectRelevantVariables(nodeCategorical, map, (Set<NodeCategorical>) this.bn.nodes));
        if (this.debug) {
            this.logger.debug("computed p(" + nodeCategorical.name + "=" + str + "|" + withValue + "," + nodeCategorical.name + "=" + str + ")=" + joint);
            this.logger.debug("computed p(*=*|" + nodeCategorical.name + "=" + str + ")=" + joint);
            this.logger.debug("now computing the overall probas");
        }

        double result;
        if (probabilityEvidence == 0.0d) {
            this.logger.error("unable to compute probability p(" + nodeCategorical.name + "=" + str + "|*): pfree=" + probabilityEvidence + ", p=" + joint);
            result = 0.0d;
        } else {
            result = joint / probabilityEvidence;
        }
        if (this.debug) {
            this.logger.debug("computed p(" + nodeCategorical.name + "=" + str + "|evidence)= p(" + nodeCategorical.name + "=" + str + "|evidence)/p(" + nodeCategorical.name + "|evidence)=" + joint + "/" + probabilityEvidence + "=" + result);
        }
        return result;
    }

    /** Drops the per-node posteriors of the previous evidence, then delegates to the parent. */
    @Override
    public void compute() {
        this.computed.clear();
        super.compute();
    }

    /**
     * Returns the probability of the current evidence, using the evidence LRU cache.
     *
     * <p>The cache is looked up with the current evidence map. The key is stored as a copy,
     * because the evidence map may be modified after this call.
     *
     * @return p(current evidence)
     */
    @Override
    public double computeProbabilityEvidence() {
        if (this.evidence2proba == null) {
            this.evidence2proba = new LRUMap<>(CACHE_EVIDENCE);
        }
        Double cached = this.evidence2proba.get(this.evidenceVariable2value);
        if (cached != null) {
            InferencePerformanceUtils.singleton.incCacheHit();
            return cached.doubleValue();
        }
        InferencePerformanceUtils.singleton.incCacheMiss();
        double probability = sumProbabilities(this.evidenceVariable2value, selectRelevantVariables((NodeCategorical) null, this.evidenceVariable2value, (Set<NodeCategorical>) this.bn.nodes));
        this.evidence2proba.put(new HashMap<>(this.evidenceVariable2value), probability);
        return probability;
    }
}
