package ch.resear.thiriot.knime.bayesiannetworks.lib.bn;

import ch.resear.thiriot.knime.bayesiannetworks.lib.ILogger;
import ch.resear.thiriot.knime.bayesiannetworks.lib.LogIntoJavaLogger;
import ch.resear.thiriot.knime.bayesiannetworks.lib.inference.Factor;
import ch.resear.thiriot.knime.bayesiannetworks.lib.inference.InferencePerformanceUtils;
import java.util.ArrayList;
import java.util.Arrays;
import java.util.Collections;
import java.util.HashMap;
import java.util.HashSet;
import java.util.Iterator;
import java.util.LinkedList;
import java.util.List;
import java.util.Map;
import java.util.Set;
import java.util.TreeSet;
import java.util.regex.Pattern;
import java.util.stream.Collectors;
import org.apache.commons.lang3.StringEscapeUtils;

/* loaded from: input_file:readbnfromxmlbif.jar:ch/resear/thiriot/knime/bayesiannetworks/lib/bn/NodeCategorical.class */
/**
 * A discrete node of a {@link CategoricalBayesianNetwork} holding its conditional probability
 * table p(this | parents).
 *
 * <p>The table is stored flat in {@code content}. For this node's value index {@code v} and the
 * parents' value indices {@code p[0..k-1]} (in the order the parents were added), the cell is
 * {@code v + p[0] * multipliers[0] + ... + p[k-1] * multipliers[k-1]}, where
 * {@code multipliers[i]} is this node's domain size times the domain sizes of parents
 * {@code i+1..k-1}. This node's own value therefore varies fastest, then the last-added parent.
 *
 * <p>NOTE(review): the table is re-sized when this node's domain or parent list changes; a later
 * change to a parent's domain is not detected here and leaves the table stale.
 */
public final class NodeCategorical extends FiniteNode<NodeCategorical> {

    /** Domains with more values than this are truncated by {@link #getStrRepresentationOfDomain(List)}. */
    private static final int MAX_DOMAIN_VALUES_SHOWN = 20;

    /** Absolute deviation of a distribution's sum from 1 above which {@link #normalize()} rescales it. */
    private static final double NORMALIZATION_TOLERANCE = 1.0E-7d;

    /** Runs of characters that are not ASCII letters or digits. */
    private static final Pattern NON_ALPHANUMERIC_RUN = Pattern.compile("[^a-zA-Z0-9]+");

    /** Matches strings whose first character is an ASCII letter. */
    private static final Pattern STARTS_WITH_LETTER = Pattern.compile("^[a-zA-Z].*");

    private final ILogger logger;
    /** Parents in insertion order; position i corresponds to multipliers[i]. */
    protected NodeCategorical[] parentsArray;
    /** Position of each parent inside parentsArray. */
    protected Map<NodeCategorical, Integer> parent2index;
    /** Lazily computed number of zero cells in content; null when it must be recomputed. */
    private Integer countZeros;
    /** Flat conditional probability table (see the class comment for the layout). */
    private double[] content;
    /** Stride of each parent's value index inside content. */
    private int[] multipliers;
    protected final CategoricalBayesianNetwork cNetwork;
    /** Cached table size (domain size times parents cardinality); negative when stale. */
    private int cachedCardinality;

    /**
     * Renders a domain for display, truncating long domains.
     *
     * @param list the domain values
     * @return {@code list.toString()} when it has at most 20 values, otherwise the first 19
     *     values followed by {@code "..."}
     */
    public static String getStrRepresentationOfDomain(List<String> list) {
        if (list.size() <= MAX_DOMAIN_VALUES_SHOWN) {
            return list.toString();
        }
        // 19 values + the ellipsis keeps the rendering at 20 entries.
        List<String> truncated = new ArrayList<>(list.subList(0, MAX_DOMAIN_VALUES_SHOWN - 1));
        truncated.add("...");
        return truncated.toString();
    }

    /**
     * Creates a node without parents.
     *
     * @param categoricalBayesianNetwork the owning network; its logger is reused when present
     * @param str the node name
     */
    public NodeCategorical(CategoricalBayesianNetwork categoricalBayesianNetwork, String str) {
        super(categoricalBayesianNetwork, str);
        this.parentsArray = new NodeCategorical[0];
        this.parent2index = new HashMap<>(50);
        this.countZeros = null;
        this.cachedCardinality = -1;
        if (categoricalBayesianNetwork == null || categoricalBayesianNetwork.logger == null) {
            this.logger = LogIntoJavaLogger.getLogger((Class<?>) NodeCategorical.class);
        } else {
            this.logger = categoricalBayesianNetwork.logger;
        }
        this.cNetwork = categoricalBayesianNetwork;
    }

    /** Returns a copy of the flat probability table. */
    public final double[] getContent() {
        return Arrays.copyOf(this.content, this.content.length);
    }

    /** Returns the network this node was created for. */
    public final CategoricalBayesianNetwork getNetwork() {
        return this.cNetwork;
    }

    /** Returns the number of cells of the table that are exactly 0 (computed lazily). */
    public Integer getCountOfZeros() {
        if (this.countZeros == null) {
            computeCountOfZeros();
        }
        return this.countZeros;
    }

    private void computeCountOfZeros() {
        int zeros = 0;
        for (double cell : this.content) {
            if (cell == 0.0d) {
                zeros++;
            }
        }
        this.countZeros = zeros;
    }

    /**
     * Returns the number of parent configurations (product of the parents' domain sizes, 1
     * without parents).
     *
     * @throws ArithmeticException if the product overflows an {@code int}
     */
    protected int getParentsCardinality() {
        return this.parents.stream().mapToInt((v0) -> {
            return v0.getDomainSize();
        }).reduce(1, Math::multiplyExact);
    }

    @Override
    public void addDomain(String str) {
        super.addDomain(str);
        this.cachedCardinality = -1;
    }

    /**
     * Returns the size of the probability table: domain size times parents cardinality.
     *
     * @throws ArithmeticException if the size overflows an {@code int}
     */
    public int getCardinality() {
        if (this.cachedCardinality < 0) {
            this.cachedCardinality = Math.multiplyExact(this.domain.size(), getParentsCardinality());
        }
        return this.cachedCardinality;
    }

    /**
     * Appends a parent and rebuilds (and therefore zeroes) the probability table.
     *
     * @param nodeCategorical the new parent
     */
    @Override
    public void addParent(NodeCategorical nodeCategorical) {
        super.addParent(nodeCategorical);
        this.parentsArray = Arrays.copyOf(this.parentsArray, this.parentsArray.length + 1);
        this.parentsArray[this.parentsArray.length - 1] = nodeCategorical;
        this.parent2index.put(nodeCategorical, this.parent2index.size());
        // Invalidate before resizing so the new table is sized with the new parent.
        this.cachedCardinality = -1;
        adaptContentSize();
    }

    protected final int _getIndex(String str, Object... objArr) {
        return _getIndex(getDomainIndex(str), _getParentIndices(objArr));
    }

    protected final int _getIndex(String str, Map<NodeCategorical, String> map) {
        return _getIndex(getDomainIndex(str), _getParentIndices(map));
    }

    /**
     * Reallocates a zeroed probability table sized for the current domain and parents, and
     * recomputes the per-parent strides.
     */
    @Override
    protected void adaptContentSize() {
        // The structure changed: any cached size or zero count describes the old table.
        this.cachedCardinality = -1;
        this.countZeros = null;
        this.content = new double[getCardinality()];
        this.multipliers = new int[this.parents.size()];
        int stride = getDomainSize();
        for (int i = this.parentsArray.length - 1; i >= 0; i--) {
            this.multipliers[i] = stride;
            stride *= this.parentsArray[i].getDomainSize();
        }
    }

    /**
     * Returns the flat table index for a value index of this node and parent value indices
     * given in parent insertion order.
     */
    protected int _getIndex(int i, int... iArr) {
        int index = i;
        for (int p = 0; p < iArr.length; p++) {
            index += iArr[p] * this.multipliers[p];
        }
        return index;
    }

    /**
     * Sets p(this=str | parents) where parents are given as alternating (parent, value) pairs;
     * a parent may be given as a node or by name.
     */
    public void setProbabilities(double d, String str, Object... objArr) {
        this.countZeros = null;
        this.content[_getIndex(str, objArr)] = d;
    }

    /** Sets p(this=str | parents) where the map assigns a value to every parent. */
    public void setProbabilities(double d, String str, Map<NodeCategorical, String> map) {
        this.countZeros = null;
        this.content[_getIndex(str, map)] = d;
    }

    /**
     * Replaces the whole table. The array is stored as-is (not copied).
     *
     * @throws IllegalArgumentException if the length differs from parents cardinality times
     *     domain size
     */
    public void setProbabilities(double[] dArr) {
        if (dArr.length != getParentsCardinality() * getDomainSize()) {
            throw new IllegalArgumentException("wrong size for the content");
        }
        this.countZeros = null;
        this.content = dArr;
    }

    /** Returns p(this=str | parents) with parents given as alternating (parent, value) pairs. */
    public double getProbability(String str, Object... objArr) {
        return this.content[_getIndex(str, objArr)];
    }

    /** Returns p(this=str | parents) where the map assigns a value to every parent. */
    public double getProbability(String str, Map<NodeCategorical, String> map) {
        return this.content[_getIndex(str, map)];
    }

    protected double getProbability(int i, int[] iArr) {
        return this.content[_getIndex(i, iArr)];
    }

    /**
     * Converts a full parent assignment into value indices ordered like parentsArray.
     *
     * @throws IllegalArgumentException if the map does not assign exactly the parents
     */
    protected int[] _getParentIndices(Map<NodeCategorical, String> map) {
        if (map.size() != this.parents.size() || !map.keySet().containsAll(this.parents)) {
            throw new IllegalArgumentException("expecting all the parents values to be defined");
        }
        int[] indices = new int[this.parents.size()];
        for (int i = 0; i < indices.length; i++) {
            indices[i] = this.parentsArray[i].getDomainIndex(map.get(this.parentsArray[i]));
        }
        return indices;
    }

    /**
     * Converts alternating (parent, value) pairs into value indices ordered like parentsArray.
     * Each parent is a {@link NodeCategorical} or a parent name; each value is a String.
     *
     * @throws IllegalArgumentException if the pairs are malformed, their count differs from the
     *     number of parents, or a key does not designate a parent of this node
     */
    protected final int[] _getParentIndices(Object... objArr) {
        if (objArr.length % 2 != 0) {
            throw new IllegalArgumentException("expecting a list of parameters such as gender, male, age, 0-15");
        }
        if (objArr.length / 2 != this.parents.size()) {
            throw new IllegalArgumentException("wrong number of parameters: expecting "
                    + this.parents.size() + " parent/value pairs, got " + (objArr.length / 2));
        }
        int[] indices = new int[objArr.length / 2];
        for (int i = 0; i < objArr.length; i += 2) {
            Object key = objArr[i];
            NodeCategorical parent;
            if (key instanceof NodeCategorical) {
                parent = (NodeCategorical) key;
            } else if (key instanceof String) {
                parent = (NodeCategorical) getParent((String) key);
            } else {
                throw new IllegalArgumentException("unable to find parent " + key);
            }
            // Unknown names or nodes that are not parents used to surface as NullPointerException.
            Integer position = parent == null ? null : this.parent2index.get(parent);
            if (position == null) {
                throw new IllegalArgumentException("unable to find parent " + key);
            }
            indices[position] = parent.getDomainIndex((String) objArr[i + 1]);
        }
        return indices;
    }

    /** Returns the sum of all cells of the table. */
    public final double getSum() {
        double sum = 0.0d;
        for (double cell : this.content) {
            sum += cell;
        }
        return sum;
    }

    /** Same as {@link #getParentsCardinality()}: the number of parent configurations. */
    public int getParentsDimensionality() {
        return getParentsCardinality();
    }

    /**
     * Rescales, for every parent configuration, the distribution over this node's values so it
     * sums to 1; an all-zero distribution becomes uniform. Each adjustment is reported on
     * standard output.
     *
     * @throws IllegalStateException if the table is still invalid afterwards
     */
    public void normalize() {
        IteratorCategoricalVariables combos = this.cNetwork.iterateDomains(getParents());
        while (combos.hasNext()) {
            Map<NodeCategorical, String> parentValues = combos.next();
            double total = 0.0d;
            for (String value : this.domain) {
                total += getProbability(value, parentValues);
            }
            if (total == 0.0d) {
                System.out.println("equiprobability for p(" + this.name + "|" + describeAssignment(parentValues, ",") + ")");
                double uniform = 1.0d / this.domain.size();
                for (String value : this.domain) {
                    setProbabilities(uniform, value, parentValues);
                }
            } else if (Math.abs(total - 1.0d) > NORMALIZATION_TOLERANCE) {
                System.out.println("normalizing p(" + this.name + "|" + describeAssignment(parentValues, ",") + ")");
                for (String value : this.domain) {
                    setProbabilities(getProbability(value, parentValues) / total, value, parentValues);
                }
            }
        }
        if (!isValid()) {
            throw new IllegalStateException("the node is not valid after normalization: " + collectInvalidityReasons());
        }
    }

    /** Formats an assignment as {@code name=value} pairs joined by {@code separator}. */
    private static String describeAssignment(Map<NodeCategorical, String> assignment, String separator) {
        return assignment.entrySet().stream()
                .map(entry -> entry.getKey().getName() + "=" + entry.getValue())
                .collect(Collectors.joining(separator));
    }

    /**
     * Sums the table entries p(this=str | parents) over every parent configuration. Without
     * parents this is simply p(this=str).
     *
     * @param str a value of this node's domain
     * @return the sum of the matching table entries
     */
    public double getConditionalProbability(String str) {
        int domainIndex = getDomainIndex(str);
        int parentsCardinality = getParentsCardinality();
        int[] parentIndices = new int[this.parentsArray.length];
        double total = 0.0d;
        for (int combo = 0; combo < parentsCardinality; combo++) {
            total += getProbability(domainIndex, parentIndices);
            // Advance the odometer; the last parent varies fastest.
            for (int p = parentIndices.length - 1; p >= 0; p--) {
                parentIndices[p]++;
                if (parentIndices[p] < this.parentsArray[p].getDomainSize()) {
                    break;
                }
                parentIndices[p] = 0;
            }
        }
        InferencePerformanceUtils.singleton.incAdditions(parentsCardinality);
        return total;
    }

    /**
     * Computes the posterior of this=str given evidence, with no parent held fixed.
     *
     * @see #getConditionalProbabilityPosterior(String, Map, Map, Map)
     */
    public double getConditionalProbabilityPosterior(String str, Map<NodeCategorical, String> map, Map<NodeCategorical, Map<String, Double>> map2) {
        return getConditionalProbabilityPosterior(str, map, map2, Collections.emptyMap());
    }

    /**
     * Computes the posterior of this=str given evidence.
     *
     * <p>If this node is observed in {@code map}, the result is 1 or 0. Otherwise the result is
     * the sum, over configurations of the parents not fixed by {@code map3}, of
     * p(this=str | parents) times the product of each parent's own posterior (computed
     * recursively). NOTE(review): multiplying the parents' posteriors treats them as independent
     * given the evidence.
     *
     * @param str the value of this node
     * @param map evidence: observed nodes and their values
     * @param map2 memo of already computed posteriors (node to value to posterior); read and,
     *     when mutable, updated. May be null.
     * @param map3 parents whose value is held fixed for this computation
     * @return the computed posterior
     * @throws IllegalArgumentException if a fixed parent value is not in that parent's domain
     */
    public double getConditionalProbabilityPosterior(String str, Map<NodeCategorical, String> map, Map<NodeCategorical, Map<String, Double>> map2, Map<NodeCategorical, String> map3) {
        final boolean debug = this.logger.isDebugEnabled();
        if (debug) {
            this.logger.debug("computing posteriors for p(" + this.name + "=" + str + "|" + map + ")");
            this.logger.debug("alreadyComputed: " + map2);
        }
        if (map.containsKey(this)) {
            double evidenceProbability = map.get(this).equals(str) ? 1.0d : 0.0d;
            if (debug) {
                this.logger.debug("from evidence, posteriors p(" + this.name + "=" + str + ")=" + evidenceProbability);
            }
            return evidenceProbability;
        }
        // NOTE(review): the memo is keyed by node and value only, not by map3; confirm callers
        // never mix fixed-parent and free-parent queries on the same memo.
        if (map2 != null) {
            Map<String, Double> memoForNode = map2.get(this);
            if (memoForNode != null) {
                Double memoized = memoForNode.get(str);
                if (memoized != null) {
                    return memoized;
                }
            }
        }
        if (!hasParents()) {
            if (debug) {
                this.logger.debug("no parent, returning internal probability");
            }
            return getConditionalProbability(str);
        }

        int domainIndex = getDomainIndex(str);
        int parentCount = this.parents.size();
        int cursor = parentCount - 1;
        // Fixed parents keep their index; free parents start at their last value and count down.
        int[] idx = new int[parentCount];
        int combinations = 1;
        for (int i = 0; i < parentCount; i++) {
            NodeCategorical parent = this.parentsArray[i];
            if (map3.containsKey(parent)) {
                idx[i] = parent.getDomainIndex(map3.get(parent));
                if (idx[i] < 0) {
                    throw new IllegalArgumentException("unknown value " + map3.get(parent) + " for parent " + parent.getName());
                }
            } else {
                idx[i] = parent.getDomainSize() - 1;
                combinations *= parent.getDomainSize();
            }
        }

        List<NodeCategorical> parentsList = Arrays.asList(this.parentsArray);
        double total = 0.0d;
        for (int combo = 0; combo < combinations; combo++) {
            if (debug) {
                this.logger.debug("now cursor parents " + cursor + " idxParents " + Arrays.toString(idx));
                this.logger.debug("adding to probability p(" + this.name + "=" + str + "|*) from parents " + this.parents.stream().collect(Collectors.toMap((v0) -> {
                    return v0.getName();
                }, node -> {
                    return node.getValueIndexed(idx[parentsList.indexOf(node)]);
                })));
            }
            double probability = getProbability(domainIndex, idx);
            double parentsPosterior = 1.0d;
            for (NodeCategorical parent : this.parents) {
                String parentValue = parent.getValueIndexed(idx[parentsList.indexOf(parent)]);
                if (debug) {
                    this.logger.debug("computing posteriors for parent p(" + parent.name + "=" + parentValue + ")");
                }
                double parentPosterior = parent.getConditionalProbabilityPosterior(parentValue, map, map2);
                parentsPosterior *= parentPosterior;
                if (debug) {
                    this.logger.debug("cumulated * " + parentPosterior + " = " + parentsPosterior);
                }
                if (parentsPosterior == 0.0d) {
                    // The product can no longer change: skip the remaining parents.
                    if (debug) {
                        this.logger.debug("reached p=0, stopping there");
                    }
                    break;
                }
            }
            total += probability * parentsPosterior;
            if (debug) {
                this.logger.debug("the probability p(" + this.name + "=" + str + "|*) is now after addition " + total);
                this.logger.debug("initial cursor parents " + cursor + " idxParents " + Arrays.toString(idx));
            }
            // Step the odometer over the free parents (first parent varies fastest).
            for (int i = 0; i < idx.length; i++) {
                NodeCategorical parent = this.parentsArray[i];
                if (map3.containsKey(parent)) {
                    continue;
                }
                idx[i]--;
                if (idx[i] >= 0) {
                    break;
                }
                idx[i] = parent.getDomainSize() - 1;
            }
        }
        if (debug) {
            this.logger.debug("computed posteriors for p(" + this.name + "=" + str + "|" + map + ")=" + total);
        }
        if (map2 != null) {
            Map<String, Double> memoForNode = map2.get(this);
            if (memoForNode == null) {
                memoForNode = new HashMap<>();
                try {
                    map2.put(this, memoForNode);
                } catch (UnsupportedOperationException ignored) {
                    // Read-only memo (e.g. Collections.emptyMap()): nothing can be cached.
                }
            }
            memoForNode.put(str, total);
        }
        return total;
    }

    /** Computes the posterior of this=str without evidence. */
    public double getConditionalProbabilityPosterior(String str) {
        return getConditionalProbabilityPosterior(str, Collections.emptyMap(), Collections.emptyMap(), Collections.emptyMap());
    }

    /** Returns the table entry p(this=str | parents); identical to {@link #getProbability(String, Object...)}. */
    public double getPosterior(String str, Object... objArr) {
        return getProbability(str, objArr);
    }

    /**
     * Returns true when every cell is in [0, 1] and the table sums (rounded) to the number of
     * parent configurations.
     */
    @Override
    public boolean isValid() {
        for (double cell : this.content) {
            if (cell < 0.0d || cell > 1.0d) {
                return false;
            }
        }
        return ((int) Math.round(getSum())) == getParentsDimensionality();
    }

    /**
     * Lists the reasons why {@link #isValid()} fails.
     *
     * @return the sorted reasons, or {@code null} when the node is valid
     */
    @Override
    public List<String> collectInvalidityReasons() {
        Set<String> reasons = new TreeSet<>();
        for (double cell : this.content) {
            if (cell < 0.0d || cell > 1.0d) {
                reasons.add("there is a value not in [0:1]: " + cell);
            }
        }
        double sum = getSum();
        int expected = getParentsCardinality();
        if (((int) Math.round(sum)) != expected) {
            reasons.add("the sum is " + sum + " instead of " + expected);
        }
        if (reasons.isEmpty()) {
            return null;
        }
        return new LinkedList<>(reasons);
    }

    /** Appends this node's VARIABLE and DEFINITION elements in XMLBIF format. */
    @Override
    public void toXMLBIF(StringBuffer stringBuffer) {
        stringBuffer.append("<VARIABLE TYPE=\"").append("nature").append("\">\n");
        stringBuffer.append("\t<NAME>").append(StringEscapeUtils.escapeXml10(getName())).append("</NAME>\n");
        for (String value : this.domain) {
            stringBuffer.append("\t<OUTCOME>").append(StringEscapeUtils.escapeXml10(value)).append("</OUTCOME>\n");
        }
        stringBuffer.append("</VARIABLE>\n");
        stringBuffer.append("\n");
        stringBuffer.append("<DEFINITION>\n");
        stringBuffer.append("\t<FOR>").append(StringEscapeUtils.escapeXml10(getName())).append("</FOR>\n");
        for (NodeCategorical parent : this.parentsArray) {
            stringBuffer.append("\t<GIVEN>").append(StringEscapeUtils.escapeXml10(parent.getName())).append("</GIVEN>\n");
        }
        stringBuffer.append("\t<TABLE>");
        // Raw table order: this node's value varies fastest (see class comment).
        for (double cell : this.content) {
            stringBuffer.append(Double.toString(cell)).append(" ");
        }
        stringBuffer.append("</TABLE>\n");
        stringBuffer.append("</DEFINITION>\n");
    }

    /**
     * Turns an arbitrary string into an identifier: each run of non-alphanumeric characters
     * becomes {@code _}, and {@code z} is prepended unless the result starts with a letter.
     */
    public static String convertDomainValueToNormalizedIdentifier(String str) {
        String identifier = NON_ALPHANUMERIC_RUN.matcher(str).replaceAll("_");
        if (!STARTS_WITH_LETTER.matcher(identifier).matches()) {
            identifier = "z" + identifier;
        }
        return identifier;
    }

    /** Appends this node's variable and probability declarations in BIF format. */
    @Override
    public void toBIF(StringBuffer stringBuffer) {
        stringBuffer.append("variable ").append(convertDomainValueToNormalizedIdentifier(this.name)).append(" {\n");
        stringBuffer.append("   type discrete [ ").append(this.domain.size()).append(" ] { ")
                .append(this.domain.stream()
                        .map(value -> convertDomainValueToNormalizedIdentifier(value))
                        .collect(Collectors.joining(", ")))
                .append(" };\n");
        stringBuffer.append("}\n");
        stringBuffer.append("probability ( ").append(convertDomainValueToNormalizedIdentifier(this.name));
        if (!this.parents.isEmpty()) {
            stringBuffer.append(" | ").append(this.parents.stream()
                    .map(parent -> convertDomainValueToNormalizedIdentifier(parent.getName()))
                    .collect(Collectors.joining(", ")));
        }
        stringBuffer.append(" ) {\n");
        if (this.parents.isEmpty()) {
            stringBuffer.append("   table ").append(Arrays.stream(this.content)
                    .mapToObj(cell -> Double.toString(cell))
                    .collect(Collectors.joining(", "))).append(";\n");
        } else {
            IteratorCategoricalVariables combos = this.cNetwork.iterateDomains(getParents());
            while (combos.hasNext()) {
                Map<NodeCategorical, String> parentValues = combos.next();
                // Row labels must follow the parent order declared in the header above,
                // independent of the iteration order of the assignment map.
                stringBuffer.append(" (").append(this.parents.stream()
                        .map(parent -> convertDomainValueToNormalizedIdentifier(parentValues.get(parent)))
                        .collect(Collectors.joining(", "))).append(") ");
                stringBuffer.append(this.domain.stream()
                        .map(value -> Double.toString(getProbability(value, parentValues)))
                        .collect(Collectors.joining(", ")));
                stringBuffer.append(";\n");
            }
        }
        stringBuffer.append("}\n");
    }

    /** Appends {@code token} {@code times} times. */
    private static void appendRepeated(StringBuffer stringBuffer, String token, int times) {
        for (int i = 0; i < times; i++) {
            stringBuffer.append(token);
        }
    }

    /**
     * Appends this node's node and potential declarations in Hugin NET format. Parents are
     * listed in the network's enumeration order, and the data block is nested with one pair of
     * parentheses per parent.
     */
    @Override
    public void toNet(StringBuffer stringBuffer) {
        stringBuffer.append("node ").append(convertDomainValueToNormalizedIdentifier(this.name)).append("\n{\n");
        stringBuffer.append("  states = ( ").append(this.domain.stream()
                .map(value -> "\"" + convertDomainValueToNormalizedIdentifier(value) + "\"")
                .collect(Collectors.joining(" "))).append(" );\n");
        stringBuffer.append("}\n");
        @SuppressWarnings({"unchecked", "rawtypes"})
        List<NodeCategorical> orderedParents = new ArrayList(getNetwork().enumerateNodes());
        orderedParents.retainAll(this.parents);
        stringBuffer.append("potential ( ").append(convertDomainValueToNormalizedIdentifier(this.name));
        if (!this.parents.isEmpty()) {
            stringBuffer.append(" | ").append(orderedParents.stream()
                    .map(parent -> convertDomainValueToNormalizedIdentifier(parent.getName()))
                    .collect(Collectors.joining(" ")));
        }
        stringBuffer.append(" )\n{\n");
        stringBuffer.append("  data = ( ");
        if (this.parents.isEmpty()) {
            stringBuffer.append(Arrays.stream(this.content)
                    .mapToObj(cell -> Double.toString(cell))
                    .collect(Collectors.joining(" ")));
        } else {
            IteratorCategoricalVariables combos = this.cNetwork.iterateDomains(new ArrayList<>(orderedParents));
            appendRepeated(stringBuffer, "(", orderedParents.size());
            Map<NodeCategorical, String> previous = null;
            while (combos.hasNext()) {
                Map<NodeCategorical, String> current = combos.next();
                if (previous != null) {
                    // Close and reopen one nesting level per parent whose value changed.
                    int changed = 0;
                    for (NodeCategorical parent : orderedParents) {
                        if (!previous.get(parent).equals(current.get(parent))) {
                            changed++;
                        }
                    }
                    appendRepeated(stringBuffer, ")", changed);
                    appendRepeated(stringBuffer, "(", changed);
                }
                previous = current;
                stringBuffer.append(this.domain.stream()
                        .map(value -> Double.toString(getProbability(value, current)))
                        .collect(Collectors.joining(" ")));
            }
            appendRepeated(stringBuffer, ")", orderedParents.size());
        }
        stringBuffer.append(" );\n");
        stringBuffer.append("}\n");
    }

    /** Builds a factor over this node and its parents holding a copy of the table's values. */
    public Factor asFactor() {
        Set<NodeCategorical> variables = new HashSet<>(this.parents);
        variables.add(this);
        Factor factor = new Factor(this.cNetwork, variables);
        IteratorCategoricalVariables combos = this.cNetwork.iterateDomains(this.parents);
        while (combos.hasNext()) {
            Map<NodeCategorical, String> parentValues = combos.next();
            for (String value : this.domain) {
                double probability = getProbability(value, parentValues);
                Map<NodeCategorical, String> assignment = new HashMap<>(parentValues);
                assignment.put(this, value);
                factor.setFactor(assignment, probability);
            }
        }
        return factor;
    }

    /** Returns a new mutable set containing this node and its parents. */
    public Set<NodeCategorical> family() {
        Set<NodeCategorical> members = new HashSet<>(getParents());
        members.add(this);
        return members;
    }

    /** Returns {@code p( name )} or {@code p( name | parent1, parent2 )}. */
    @Override
    public String toString() {
        StringBuilder builder = new StringBuilder();
        builder.append("p( ");
        builder.append(getName());
        if (hasParents()) {
            builder.append(" | ");
        }
        boolean first = true;
        for (NodeCategorical parent : getParents()) {
            if (!first) {
                builder.append(", ");
            }
            first = false;
            builder.append(parent.getName());
        }
        builder.append(" )");
        return builder.toString();
    }

    /** Returns every table entry, one {@code p( x=v | ... ) = p} line each. */
    public String toStringComplete() {
        StringBuffer stringBuffer = new StringBuffer();
        toStringComplete(stringBuffer);
        return stringBuffer.toString();
    }

    /** Appends every table entry, one {@code p( x=v | ... ) = p} line each. */
    public void toStringComplete(StringBuffer stringBuffer) {
        for (String value : getDomain()) {
            IteratorCategoricalVariables combos = this.cNetwork.iterateDomains(getParents());
            while (combos.hasNext()) {
                Map<NodeCategorical, String> parentValues = combos.next();
                stringBuffer.append("p( ").append(getName()).append("=").append(value);
                if (hasParents()) {
                    stringBuffer.append(" | ");
                    stringBuffer.append(describeAssignment(parentValues, ", "));
                }
                stringBuffer.append(" ) = ").append(getProbability(value, parentValues)).append("\n");
            }
        }
    }
}
