package ch.resear.thiriot.knime.bayesiannetworks.lib.inference;

import cern.jet.random.engine.RandomEngine;
import ch.resear.thiriot.knime.bayesiannetworks.lib.ILogger;
import ch.resear.thiriot.knime.bayesiannetworks.lib.bn.CategoricalBayesianNetwork;
import ch.resear.thiriot.knime.bayesiannetworks.lib.bn.NodeCategorical;
import java.util.ArrayList;
import java.util.Collection;
import java.util.Collections;
import java.util.Comparator;
import java.util.HashMap;
import java.util.HashSet;
import java.util.Iterator;
import java.util.LinkedHashSet;
import java.util.List;
import java.util.Map;
import java.util.Set;
import java.util.stream.Collectors;
import org.apache.commons.collections4.map.LRUMap;

/* loaded from: input_file:readbnfromxmlbif.jar:ch/resear/thiriot/knime/bayesiannetworks/lib/inference/EliminationInferenceEngine.class */
/**
 * Inference engine that answers posterior queries by variable elimination:
 * the factors of the relevant nodes are reduced by the current evidence,
 * multiplied together and the non-queried variables are summed out one by one.
 *
 * <p>Results of {@link #computeFactorPosteriorMarginals(Set)} are kept in an LRU
 * cache keyed by the queried set of variables; the cache is emptied by
 * {@link #compute()}.
 */
public class EliminationInferenceEngine extends AbstractInferenceEngine {

    /** Lazily built list of the nodes that carry no evidence; reset by {@link #compute()}. */
    private List<NodeCategorical> eliminationOrderForEvidence;

    /** Lazily built evidence-reduced factor of each non-evidence node; reset by {@link #compute()}. */
    private Map<NodeCategorical, Factor> factorsForEvidence;

    /** LRU cache: queried variables -> factor computed for them under the current evidence. */
    private final LRUMap<Set<NodeCategorical>, Factor> cacheNode2factorForEvidence;

    /**
     * Creates the engine and sizes the result cache to 100 entries per node of the network.
     *
     * @param iLogger                    logger handed to the parent engine
     * @param randomEngine               random source handed to the parent engine
     * @param categoricalBayesianNetwork network on which inference is run
     */
    public EliminationInferenceEngine(ILogger iLogger, RandomEngine randomEngine, CategoricalBayesianNetwork categoricalBayesianNetwork) {
        super(iLogger, randomEngine, categoricalBayesianNetwork);
        this.eliminationOrderForEvidence = null;
        this.factorsForEvidence = null;
        // LRUMap rejects a maximum size below 1, which an empty network would produce.
        this.cacheNode2factorForEvidence = new LRUMap<>(Math.max(1, categoricalBayesianNetwork.getNodes().size() * 100));
    }

    /**
     * Collects the network factor of every given variable.
     *
     * @param set variables whose factors are wanted
     * @return the set of factors returned by the network for these variables
     */
    protected Set<Factor> getFactorsWithVariables(Set<NodeCategorical> set) {
        Set<Factor> factors = new HashSet<>(set.size());
        for (NodeCategorical node : set) {
            factors.add(this.bn.getFactor(node));
        }
        return factors;
    }

    /**
     * Orders the given variables by decreasing cardinality.
     *
     * @param set variables to order
     * @return a new list holding the variables, highest cardinality first
     */
    protected List<NodeCategorical> getEliminationOrderOptimalForZero(Set<NodeCategorical> set) {
        List<NodeCategorical> order = new ArrayList<>(set);
        // Integer.compare instead of a subtraction: no overflow, same descending order.
        order.sort((first, second) -> Integer.compare(second.getCardinality(), first.getCardinality()));
        return order;
    }

    /**
     * Returns the (unnormalised) factor over the queried variables given the current
     * evidence, computing it by elimination on a cache miss.
     *
     * <p>NOTE(review): the cached instance itself is returned, and callers in this class
     * invoke {@code normalize()} on it; if that call mutates the factor in place the
     * cached entry is affected too - confirm against {@code Factor}.
     *
     * @param set variables to keep in the resulting factor
     * @return the factor for these variables
     */
    @Override
    public Factor computeFactorPosteriorMarginals(Set<NodeCategorical> set) {
        Factor cached = this.cacheNode2factorForEvidence.get(set);
        if (cached != null) {
            InferencePerformanceUtils.singleton.incCacheHit();
            return cached;
        }
        InferencePerformanceUtils.singleton.incCacheMiss();

        // Relevant variables, kept in the network's enumeration order.
        Set<NodeCategorical> toEliminate = new LinkedHashSet<>(this.bn.enumerateNodes());
        toEliminate.retainAll(selectRelevantVariables(set, this.evidenceVariable2value, this.bn.getNodes()));
        Set<NodeCategorical> relevant = new HashSet<>(toEliminate);
        toEliminate.removeAll(set);

        Factor result = computeFactorPosteriorMarginals(relevant, getEliminationOrderOptimalForZero(toEliminate));
        // Store a private copy of the key so that a caller mutating its set later cannot corrupt the cache.
        this.cacheNode2factorForEvidence.put(new HashSet<>(set), result);
        return result;
    }

    /**
     * Runs the elimination itself: reduces the factor of every involved node by the
     * evidence, then for each variable of {@code list} multiplies the remaining factors
     * that contain it and sums the variable out; the factors left at the end are
     * multiplied together.
     *
     * @param set  variables taking part in the computation in addition to {@code list}
     * @param list variables to sum out, in elimination order
     * @return the product of the factors remaining after elimination
     */
    protected Factor computeFactorPosteriorMarginals(Set<NodeCategorical> set, List<NodeCategorical> list) {
        int biggestFactorSize = 0;
        if (this.logger.isDebugEnabled()) {
            this.logger.debug("elimination on " + set + " with order " + list);
        }
        Set<NodeCategorical> involved = new HashSet<>(list);
        involved.addAll(set);
        Map<NodeCategorical, Factor> remaining = involved.stream().collect(Collectors.toMap(
                node -> node,
                node -> this.bn.getFactor(node).reduction(this.evidenceVariable2value)));
        if (this.logger.isDebugEnabled()) {
            this.logger.debug("reduced: " + remaining.entrySet().stream()
                    .map(entry -> entry.getKey().name + ":" + entry.getValue().toStringLong())
                    .collect(Collectors.joining(",")));
        }

        for (NodeCategorical eliminated : list) {
            if (this.logger.isDebugEnabled()) {
                this.logger.debug("processing " + eliminated);
                this.logger.debug("product for " + eliminated.name);
            }
            // Multiply (and withdraw from the pool) every remaining factor mentioning the variable.
            Factor product = null;
            for (NodeCategorical owner : involved) {
                Factor candidate = remaining.get(owner);
                if (candidate == null || !candidate.contains(eliminated)) {
                    continue;
                }
                if (product == null) {
                    product = candidate;
                } else {
                    if (this.logger.isDebugEnabled()) {
                        this.logger.debug("mult " + product.toStringLong() + " X " + candidate.toStringLong());
                    }
                    product = product.multiply(candidate);
                    if (this.logger.isDebugEnabled()) {
                        this.logger.debug("=" + product.toStringLong());
                    }
                }
                remaining.remove(owner);
            }
            if (product == null) {
                continue;
            }
            biggestFactorSize = Math.max(biggestFactorSize, product.size());
            if (this.logger.isDebugEnabled()) {
                this.logger.debug("sum " + eliminated.name + " for " + product.toStringLong());
            }
            Factor summed = product.sumOut(eliminated);
            if (this.logger.isDebugEnabled()) {
                this.logger.debug("=" + summed.toStringLong());
            }
            remaining.put(eliminated, summed);
        }

        // Start from a factor over no variable and fold in whatever is left.
        Factor result = new Factor(this.bn, Collections.emptySet());
        for (Factor leftover : remaining.values()) {
            if (this.logger.isDebugEnabled()) {
                this.logger.debug("mult " + result + " X " + leftover.toStringLong());
            }
            result = result.multiply(leftover);
            if (this.logger.isDebugEnabled()) {
                this.logger.debug("= " + result.toStringLong());
            }
        }
        if (this.logger.isDebugEnabled()) {
            this.logger.debug("perf: biggest CPT was " + biggestFactorSize + " with order " + list);
        }
        return result;
    }

    /**
     * Builds a factor for one node from the factors of its ancestors, taken in the
     * network's enumeration order: each factor is multiplied into the running product,
     * after which the variables the product had before the multiplication are summed
     * out; finally every variable except {@code nodeCategorical} is summed out.
     *
     * <p>NOTE(review): this relies on {@code getAllAncestors()} for the set of factors
     * to combine; whether that includes the node itself is not visible here - confirm.
     *
     * @param nodeCategorical node for which the factor is wanted
     * @return the resulting factor
     * @throws IllegalStateException if no factor could be collected for the node
     */
    protected Factor getFactorByEliminationFor(NodeCategorical nodeCategorical) {
        List<Factor> ancestorFactors = nodeCategorical.getAllAncestors().stream()
                .map(ancestor -> this.bn.getFactor(ancestor))
                .collect(Collectors.toList());
        Set<Factor> orderedFactors = new LinkedHashSet<>(this.bn.enumerateNodes().stream()
                .map(node -> this.bn.getFactor(node))
                .collect(Collectors.toList()));
        orderedFactors.retainAll(ancestorFactors);

        Factor product = null;
        Set<NodeCategorical> toSumOut = new HashSet<>(this.bn.getNodes().size());
        for (Factor current : orderedFactors) {
            if (this.logger.isDebugEnabled()) {
                this.logger.debug("eliminating variable " + current);
            }
            if (product == null) {
                product = current;
                continue;
            }
            if (this.logger.isDebugEnabled()) {
                this.logger.debug("multiply " + product + " by " + current);
            }
            toSumOut.addAll(product.getVariables());
            product = product.multiply(current);
            for (NodeCategorical variable : toSumOut) {
                if (this.logger.isDebugEnabled()) {
                    this.logger.debug("summing " + variable.name + " in " + product);
                }
                product = product.sumOut(variable);
            }
            toSumOut.clear();
        }
        if (product == null) {
            // Previously surfaced as a bare NullPointerException on the next line.
            throw new IllegalStateException("no factor found among the ancestors of " + nodeCategorical.name);
        }

        toSumOut.addAll(product.getVariables());
        toSumOut.remove(nodeCategorical);
        for (NodeCategorical variable : toSumOut) {
            if (this.logger.isDebugEnabled()) {
                this.logger.debug("summing " + variable.name + " in " + product);
            }
            product = product.sumOut(variable);
        }
        if (this.logger.isDebugEnabled()) {
            this.logger.debug("so factor is " + product);
        }
        return product;
    }

    /**
     * Returns p(node = value | evidence).
     *
     * <p>Without evidence the node is asked directly; for a node that is itself part of
     * the evidence the answer is 1 or 0; otherwise the normalised factor obtained by
     * elimination is read.
     *
     * @param nodeCategorical queried node
     * @param str             queried value of the node
     * @return the conditional probability
     * @throws IllegalArgumentException if the computation failed and Pr(evidence) is 0
     * @throws RuntimeException         wrapping the failure when Pr(evidence) is not 0
     */
    @Override
    public double retrieveConditionalProbability(NodeCategorical nodeCategorical, String str) {
        if (this.evidenceVariable2value.isEmpty()) {
            return nodeCategorical.getConditionalProbabilityPosterior(str);
        }
        try {
            if (this.logger.isDebugEnabled()) {
                this.logger.debug("computing conditional probability p(" + nodeCategorical.name + "=" + str + "|evidence)");
            }
            String observed = this.evidenceVariable2value.get(nodeCategorical);
            if (observed != null) {
                return observed.equals(str) ? 1.0d : 0.0d;
            }
            Set<NodeCategorical> query = new HashSet<>(1);
            query.add(nodeCategorical);
            Factor posterior = computeFactorPosteriorMarginals(query);
            posterior.normalize();
            if (posterior.getVariables().isEmpty()) {
                this.logger.warn("going to fail here for " + nodeCategorical + "=" + str + "?");
            }
            return posterior.get(nodeCategorical.name, str);
        } catch (IllegalArgumentException e) {
            // The failure is chained into whichever exception is rethrown below, so no stack trace is printed here.
            this.logger.warn("got exception while computing, checking if Pr(evidence)=0?");
            if (getProbabilityEvidence() == 0.0d) {
                throw new IllegalArgumentException("Pr(evidence)=0 with evidence=" + this.evidenceVariable2value + "; impossible to compute posterior probabilities", e);
            }
            throw new RuntimeException(e);
        }
    }

    /**
     * Returns the posterior distribution of a node as an array indexed like its domain.
     *
     * @param nodeCategorical queried node
     * @return one probability per value of the node, in value-index order
     */
    @Override
    public double[] retrieveConditionalProbability(NodeCategorical nodeCategorical) {
        if (this.evidenceVariable2value.containsKey(nodeCategorical)) {
            return getEvidenceAsDoubleArray(nodeCategorical);
        }
        Set<NodeCategorical> query = new HashSet<>(1);
        query.add(nodeCategorical);
        Factor posterior = computeFactorPosteriorMarginals(query);
        posterior.normalize();

        int domainSize = nodeCategorical.getDomainSize();
        double[] probabilities = new double[domainSize];
        for (int index = 0; index < domainSize; index++) {
            probabilities[index] = posterior.get(nodeCategorical.name, nodeCategorical.getValueIndexed(index));
        }
        return probabilities;
    }

    /**
     * Drops every evidence-dependent cached structure, then delegates to the parent.
     */
    @Override
    public void compute() {
        this.eliminationOrderForEvidence = null;
        this.factorsForEvidence = null;
        this.cacheNode2factorForEvidence.clear();
        super.compute();
    }

    /**
     * Computes Pr(evidence) by eliminating every relevant variable.
     *
     * @return 1.0 when there is no evidence or when the computation fails (best effort,
     *         a warning is logged in the latter case); the computed value otherwise
     */
    @Override
    public double computeProbabilityEvidence() {
        if (this.evidenceVariable2value.isEmpty()) {
            return 1.0d;
        }
        try {
            return computeFactorPosteriorMarginals(Collections.emptySet()).getUniqueValue();
        } catch (RuntimeException e) {
            // Best effort is kept, but the failure is no longer silent.
            this.logger.warn("unable to compute Pr(evidence), assuming 1.0: " + e);
            return 1.0d;
        }
    }

    /**
     * Returns, building it on first use, the network's nodes that carry no evidence,
     * in the network's enumeration order.
     */
    private List<NodeCategorical> getEliminationOrderForEvidence() {
        if (this.eliminationOrderForEvidence == null) {
            this.eliminationOrderForEvidence = this.bn.enumerateNodes().stream()
                    .filter(node -> !this.evidenceVariable2value.containsKey(node))
                    .collect(Collectors.toList());
        }
        return this.eliminationOrderForEvidence;
    }

    /**
     * Returns, building it on first use, the evidence-reduced factor of every
     * non-evidence node.
     */
    private Map<NodeCategorical, Factor> getFactorsForEvidence() {
        if (this.factorsForEvidence == null) {
            this.factorsForEvidence = getEliminationOrderForEvidence().stream().collect(Collectors.toMap(
                    node -> node,
                    node -> this.bn.getFactor(node).reduction(this.evidenceVariable2value)));
        }
        return this.factorsForEvidence;
    }

    /**
     * Draws one value for every non-evidence node, in enumeration order: a value is
     * picked from the node's current factor by cumulating its entries against a random
     * draw, then the factors of all later nodes are reduced by the picked value.
     *
     * @return the evidence completed with one sampled value per remaining node
     * @throws RuntimeException if a node's factor still spans several variables when
     *                          its turn comes, or if no value could be picked
     */
    public final Map<NodeCategorical, String> sampleOneTODO() {
        Map<NodeCategorical, String> sampled = new HashMap<>(this.bn.getNodes().size());
        sampled.putAll(this.evidenceVariable2value);
        List<NodeCategorical> order = getEliminationOrderForEvidence();
        // Work on a copy: the cached factors must not be altered by the sampling.
        Map<NodeCategorical, Factor> factors = new HashMap<>(getFactorsForEvidence());

        for (int position = 0; position < order.size(); position++) {
            NodeCategorical node = order.get(position);
            Factor factor = factors.get(node);
            if (factor.getVariables().size() > 1) {
                throw new RuntimeException("wrong iteration order... factor " + factor + " for variable " + node + " should not have more than one variable anymore !");
            }

            String picked = null;
            double threshold = this.rng.nextDouble();
            double cumulated = 0.0d;
            for (String value : node.getDomain()) {
                cumulated += factor.get(node.name, value);
                if (cumulated >= threshold) {
                    picked = value;
                    break;
                }
            }
            if (picked == null) {
                throw new RuntimeException("oops, should have picked a value based on postererior probabilities, but they sum to " + cumulated);
            }
            sampled.put(node, picked);

            // Propagate the choice to the nodes still to be sampled (index loop: no indexOf lookup per node).
            for (int later = position + 1; later < order.size(); later++) {
                NodeCategorical laterNode = order.get(later);
                Factor laterFactor = factors.get(laterNode);
                if (this.logger.isDebugEnabled()) {
                    this.logger.debug("reducing factor " + laterFactor + " knowing " + node + "=" + picked);
                }
                factors.put(laterNode, laterFactor.reduction(node, picked));
            }
        }
        return sampled;
    }
}
