package ch.resear.thiriot.knime.bayesiannetworks.lib.inference;

import cern.jet.random.engine.RandomEngine;
import ch.resear.thiriot.knime.bayesiannetworks.lib.ILogger;
import ch.resear.thiriot.knime.bayesiannetworks.lib.bn.CategoricalBayesianNetwork;
import ch.resear.thiriot.knime.bayesiannetworks.lib.bn.NodeCategorical;
import java.util.HashMap;
import java.util.Iterator;
import java.util.List;
import java.util.Map;
import java.util.Set;
import java.util.stream.Collectors;

/* loaded from: input_file:readbnfromxmlbif.jar:ch/resear/thiriot/knime/bayesiannetworks/lib/inference/RecursiveConditionningEngine.class */
/**
 * Inference engine which answers probability queries by evaluating decomposition trees
 * ({@link DNode}) through {@code DNode.recursiveConditionning(Map)}.
 *
 * <p>Two trees are kept, both built lazily and both discarded by {@link #compute()}:
 * <ul>
 *   <li>{@link #dtreeWithoutEvidence}: built from an elimination order over the whole network;
 *       used to evaluate the probability of the evidence ({@link #norm});</li>
 *   <li>{@link #dtreeWithEvidence}: built from an elimination order computed for the nodes which
 *       are not part of the evidence, then reduced with the evidence; used for the queries.</li>
 * </ul>
 */
public class RecursiveConditionningEngine extends AbstractInferenceEngine {
    /** Tree built without taking evidence into account; {@code null} until first needed. */
    protected DNode dtreeWithoutEvidence;
    /** Elimination order from which {@link #dtreeWithoutEvidence} was built. */
    private List<NodeCategorical> eliminationOrder;
    /** Tree reduced with the current evidence; {@code null} until first needed. */
    protected DNode dtreeWithEvidence;
    /** Elimination order from which {@link #dtreeWithEvidence} was built. */
    private List<NodeCategorical> eliminationOrderWithEvidence;
    /** Probability of the current evidence; {@code null} while not computed yet. */
    protected Double norm;

    /**
     * Creates an engine for the given network. Nothing is computed at construction time:
     * trees, elimination orders and the evidence probability all start as {@code null}.
     *
     * @param iLogger logger forwarded to the parent class
     * @param randomEngine random generator forwarded to the parent class
     * @param categoricalBayesianNetwork network forwarded to the parent class
     */
    public RecursiveConditionningEngine(ILogger iLogger, RandomEngine randomEngine, CategoricalBayesianNetwork categoricalBayesianNetwork) {
        super(iLogger, randomEngine, categoricalBayesianNetwork);
        this.dtreeWithoutEvidence = null;
        this.eliminationOrder = null;
        this.dtreeWithEvidence = null;
        this.eliminationOrderWithEvidence = null;
        this.norm = null;
    }

    /** Does nothing in this implementation. */
    public void internalizeEvidence() {
    }

    /**
     * Discards both cached trees, recomputes the probability of the evidence and then delegates
     * to the parent implementation.
     *
     * <p>Without evidence the probability is set to 1 directly and no tree is built here; in that
     * case the debug output below prints {@code null} for the tree.
     */
    @Override // ch.resear.thiriot.knime.bayesiannetworks.lib.inference.AbstractInferenceEngine
    public void compute() {
        this.dtreeWithEvidence = null;
        this.dtreeWithoutEvidence = null;
        this.norm = null;
        if (this.evidenceVariable2value.isEmpty()) {
            this.norm = Double.valueOf(1.0d);
        } else {
            this.norm = Double.valueOf(getDtreeWithoutEvidence().recursiveConditionning(this.evidenceVariable2value));
        }
        if (this.logger.isDebugEnabled()) {
            this.logger.debug("dtree is:\n " + this.dtreeWithoutEvidence);
            this.logger.debug("probability for evidence  p(" + this.evidenceVariable2value + ")=" + this.norm);
        }
        super.compute();
    }

    /**
     * Returns the probability of {@code nodeCategorical} taking value {@code str} given the
     * current evidence.
     *
     * <p>For a node which is part of the evidence the answer is 1 when {@code str} is the
     * observed value and 0 otherwise. For any other node the tree value is divided by the
     * probability of the evidence, exactly as the array overload does, so that both overloads
     * return the same quantity.
     *
     * <p>(The decompiler reports that this method was declared {@code protected} originally.)
     *
     * @param nodeCategorical the queried node
     * @param str the queried value of that node
     * @return the conditional probability of that value
     */
    @Override // ch.resear.thiriot.knime.bayesiannetworks.lib.inference.AbstractInferenceEngine
    public double retrieveConditionalProbability(NodeCategorical nodeCategorical, String str) {
        String observed = this.evidenceVariable2value.get(nodeCategorical);
        if (observed != null) {
            return observed.equals(str) ? 1.0d : 0.0d;
        }
        return conditionalGivenEvidence(nodeCategorical, str);
    }

    /**
     * Returns the probabilities of every value of {@code nodeCategorical} given the current
     * evidence, indexed like {@code nodeCategorical.getValueIndexed(int)}.
     *
     * <p>Only the first {@code domainSize - 1} entries are evaluated; the last entry is set to
     * one minus their sum.
     *
     * <p>(The decompiler reports that this method was declared {@code protected} originally.)
     *
     * @param nodeCategorical the queried node
     * @return one probability per value of the node's domain
     * @throws RuntimeException wrapping any {@link ArithmeticException} raised during evaluation
     */
    @Override // ch.resear.thiriot.knime.bayesiannetworks.lib.inference.AbstractInferenceEngine
    public double[] retrieveConditionalProbability(NodeCategorical nodeCategorical) {
        final int domainSize = nodeCategorical.getDomainSize();
        final double[] probabilities = new double[domainSize];
        // whether the node is observed does not change while iterating: look it up once
        final boolean observedNode = this.evidenceVariable2value.containsKey(nodeCategorical);
        double cumulated = 0.0d;
        for (int i = 0; i < domainSize - 1; i++) {
            String value = nodeCategorical.getValueIndexed(i);
            double p;
            if (observedNode) {
                p = this.evidenceVariable2value.get(nodeCategorical).equals(value) ? 1.0d : 0.0d;
            } else {
                try {
                    p = conditionalGivenEvidence(nodeCategorical, value);
                } catch (ArithmeticException e) {
                    // Note: a double division never throws (a zero evidence probability yields
                    // NaN or Infinity); this only covers exceptions raised by the called code.
                    throw new RuntimeException("error during the recursive conditioning of p(" + nodeCategorical.name + "=" + value + ")", e);
                }
            }
            probabilities[i] = p;
            cumulated += p;
        }
        probabilities[domainSize - 1] = 1.0d - cumulated;
        return probabilities;
    }

    /**
     * Evaluates the tree with evidence for "current evidence plus {@code node = value}" and
     * divides the result by the probability of the evidence.
     */
    private double conditionalGivenEvidence(NodeCategorical node, String value) {
        // work on a copy so the engine's evidence is left untouched
        Map<NodeCategorical, String> assignment = new HashMap<>(this.evidenceVariable2value);
        assignment.put(node, value);
        return getDtreeWithEvidence().recursiveConditionning(assignment) / getProbabilityEvidence();
    }

    /**
     * Returns the probability of the current evidence, evaluating it on the tree without
     * evidence if {@link #norm} is not available yet.
     *
     * <p>(The decompiler reports that this method was declared {@code protected} originally.)
     *
     * @return the probability of the evidence
     */
    @Override // ch.resear.thiriot.knime.bayesiannetworks.lib.inference.AbstractInferenceEngine
    public double computeProbabilityEvidence() {
        if (this.norm == null) {
            this.norm = Double.valueOf(getDtreeWithoutEvidence().recursiveConditionning(this.evidenceVariable2value));
        }
        return this.norm.doubleValue();
    }

    /** Returns the tree built over the whole network, creating it (and its elimination order) on first use. */
    private DNode getDtreeWithoutEvidence() {
        if (this.dtreeWithoutEvidence == null) {
            this.eliminationOrder = EliminationOrderBestFirstSearch.computeEliminationOrder(this.logger, this.bn);
            this.logger.debug("building the generic dtree without evidence...");
            this.dtreeWithoutEvidence = DNode.eliminationOrder2DTree(this.bn, this.eliminationOrder);
            this.logger.info("created dtree:\n" + this.dtreeWithoutEvidence);
        }
        return this.dtreeWithoutEvidence;
    }

    /**
     * Returns the tree dedicated to the current evidence, creating it on first use: the
     * elimination order is computed for the nodes which are not in the evidence, the tree is
     * built from that order and finally reduced with the evidence.
     */
    private DNode getDtreeWithEvidence() {
        if (this.dtreeWithEvidence == null) {
            this.logger.debug("creating dtree for evidence " + this.evidenceVariable2value);
            Set<NodeCategorical> unobserved = this.bn.getNodes().stream()
                    .filter(node -> !this.evidenceVariable2value.containsKey(node))
                    .collect(Collectors.toSet());
            this.eliminationOrderWithEvidence = EliminationOrderBestFirstSearch.computeEliminationOrder(this.logger, this.bn, unobserved);
            this.dtreeWithEvidence = DNode.eliminationOrder2DTree(this.bn, this.eliminationOrderWithEvidence);
            this.dtreeWithEvidence.reduce(this.evidenceVariable2value);
            this.logger.info("generated dtree with evidence " + this.evidenceVariable2value + ":\n" + this.dtreeWithEvidence);
        }
        return this.dtreeWithEvidence;
    }

    /**
     * Draws one value for every node of the network, visiting nodes in the order given by
     * {@code bn.enumerateNodes()}. Nodes which are part of the evidence keep their observed
     * value; every other node gets a value drawn from the tree with evidence, conditioned on
     * the values drawn so far.
     *
     * <p>NOTE(review): the "TODO" suffix suggests this method is unfinished - confirm before
     * relying on it. In particular the assignment passed to the tree only contains the values
     * drawn here (not the evidence itself) and the expected mass starts at 1; presumably this
     * relies on what {@code DNode.reduce} does with the evidence - to be verified.
     *
     * @return a map holding one value per node, evidence included
     * @throws RuntimeException if the masses of the values of a node do not add up to the
     *         expected mass (tolerance 1e-8), or if no value could be selected
     */
    public Map<NodeCategorical, String> sampleOneTODO() {
        if (this.dirty) {
            compute();
        }
        this.logger.debug("sampling for evidence" + this.evidenceVariable2value);
        DNode dtree = getDtreeWithEvidence();
        int nodeCount = this.bn.getNodes().size();
        Map<NodeCategorical, String> sample = new HashMap<>(nodeCount);
        sample.putAll(this.evidenceVariable2value);
        this.logger.debug("from evidence, we know: " + sample);
        // mass of the values drawn so far; the masses of the next node's values must add up to it
        double expectedMass = 1.0d;
        Map<NodeCategorical, String> drawn = new HashMap<>(nodeCount);
        for (NodeCategorical node : this.bn.enumerateNodes()) {
            if (this.evidenceVariable2value.containsKey(node)) {
                continue;
            }
            if (this.logger.isDebugEnabled()) {
                this.logger.debug("selecting a value for " + node + "=?");
            }
            double threshold = this.rng.nextDouble() * expectedMass;

            // sanity check: the masses of all the values of this node should add up to expectedMass
            double total = 0.0d;
            for (String candidate : node.getDomain()) {
                drawn.put(node, candidate);
                total += dtree.recursiveConditionning(drawn);
            }
            if (Math.abs(total - expectedMass) > 1.0E-8d) {
                throw new RuntimeException("not summing to norm evidence " + expectedMass + " but to " + total + "...");
            }

            // walk the cumulated masses until the random threshold is reached
            String selected = null;
            double cumulated = 0.0d;
            for (String candidate : node.getDomain()) {
                drawn.put(node, candidate);
                dtree.resetCache();
                dtree.resetCacheChildren();
                double p = dtree.recursiveConditionning(drawn);
                if (this.logger.isDebugEnabled()) {
                    this.logger.debug("p(" + node + "=" + candidate + "|" + drawn + ")=" + p);
                }
                cumulated += p;
                if (cumulated >= threshold) {
                    selected = candidate;
                    break;
                }
            }
            if (this.logger.isDebugEnabled()) {
                this.logger.debug("selected " + node + "=" + selected);
            }
            if (selected == null) {
                throw new RuntimeException("oops, should have picked a value based on postererior probabilities, but they sum to " + cumulated);
            }
            sample.put(node, selected);
            if (this.logger.isDebugEnabled()) {
                this.logger.debug("asserting evidence " + drawn);
            }
            // 'drawn' now maps this node to the selected value (last put before the break)
            expectedMass = dtree.recursiveConditionning(drawn);
        }
        return sample;
    }
}
