package ch.resear.thiriot.knime.bayesiannetworks.enumerate;

import ch.resear.thiriot.knime.bayesiannetworks.DataTableToBNMapper;
import ch.resear.thiriot.knime.bayesiannetworks.LogIntoNodeLogger;
import ch.resear.thiriot.knime.bayesiannetworks.lib.ILogger;
import ch.resear.thiriot.knime.bayesiannetworks.lib.bn.CategoricalBayesianNetwork;
import ch.resear.thiriot.knime.bayesiannetworks.lib.bn.IteratorCategoricalVariables;
import ch.resear.thiriot.knime.bayesiannetworks.lib.bn.NodeCategorical;
import ch.resear.thiriot.knime.bayesiannetworks.port.BayesianNetworkPortObject;
import java.io.File;
import java.io.IOException;
import java.util.ArrayList;
import java.util.Comparator;
import java.util.HashMap;
import java.util.Iterator;
import java.util.LinkedList;
import java.util.List;
import java.util.Map;
import java.util.stream.Collectors;
import org.apache.commons.math3.stat.descriptive.rank.Median;
import org.knime.core.data.DataCell;
import org.knime.core.data.DataColumnSpec;
import org.knime.core.data.DataColumnSpecCreator;
import org.knime.core.data.DataTableSpec;
import org.knime.core.data.RowKey;
import org.knime.core.data.def.DefaultRow;
import org.knime.core.data.def.DoubleCell;
import org.knime.core.node.BufferedDataContainer;
import org.knime.core.node.BufferedDataTable;
import org.knime.core.node.CanceledExecutionException;
import org.knime.core.node.ExecutionContext;
import org.knime.core.node.ExecutionMonitor;
import org.knime.core.node.InvalidSettingsException;
import org.knime.core.node.NodeLogger;
import org.knime.core.node.NodeModel;
import org.knime.core.node.NodeSettingsRO;
import org.knime.core.node.NodeSettingsWO;
import org.knime.core.node.defaultnodesettings.SettingsModelBoolean;
import org.knime.core.node.defaultnodesettings.SettingsModelDoubleBounded;
import org.knime.core.node.port.PortObject;
import org.knime.core.node.port.PortObjectSpec;
import org.knime.core.node.port.PortType;

/* loaded from: input_file:readbnfromxmlbif.jar:ch/resear/thiriot/knime/bayesiannetworks/enumerate/EnumerateBNNodeModel.class */
/**
 * KNIME node model which enumerates every combination of values of the variables of a categorical
 * Bayesian network, and outputs one row per combination with one column per network node plus a
 * {@code probability} column holding the product of each node's probability given its parents.
 *
 * <p>Rows may be filtered: when {@code skip_on_epsilon} is enabled with a positive epsilon, the
 * combinations whose probability is below that epsilon are skipped (with early pruning of the
 * product); combinations of probability exactly 0 are skipped when {@code skip_null} is enabled, or
 * when {@code skip_on_epsilon} is enabled with an epsilon of 0.
 */
public class EnumerateBNNodeModel extends NodeModel {

    private static final NodeLogger logger = NodeLogger.getLogger(EnumerateBNNodeModel.class);
    private static final ILogger ilogger = new LogIntoNodeLogger(logger);

    /** When enabled, combinations of probability exactly 0 are not written. */
    private final SettingsModelBoolean m_skipNull;
    /** When enabled, combinations of probability below {@link #m_skipEpsilon} are not written. */
    private final SettingsModelBoolean m_skipOnEpsilon;
    /** Probability threshold used with {@link #m_skipOnEpsilon}; bounded to [0, 1]. */
    private final SettingsModelDoubleBounded m_skipEpsilon;
    /** Per-node helpers creating the output column spec and cells; rebuilt by createSpecsForBN. */
    private final Map<NodeCategorical, DataTableToBNMapper> node2mapper = new HashMap<>();

    /** Creates a node model with one Bayesian network input port and one data table output port. */
    public EnumerateBNNodeModel() {
        super(new PortType[]{BayesianNetworkPortObject.TYPE}, new PortType[]{BufferedDataTable.TYPE});
        this.m_skipNull = new SettingsModelBoolean("skip_null", true);
        this.m_skipOnEpsilon = new SettingsModelBoolean("skip_on_epsilon", true);
        this.m_skipEpsilon = new SettingsModelDoubleBounded("skip_epsilon", 1.0E-6d, 0.0d, 1.0d);
    }

    /**
     * Rebuilds the node-to-mapper table for the given network and returns the output column specs:
     * one column per node (in {@code enumerateNodes()} order) followed by a {@code probability}
     * double column.
     *
     * @param categoricalBayesianNetwork the network whose nodes become output columns
     * @return the column specs of the output table
     */
    protected DataColumnSpec[] createSpecsForBN(CategoricalBayesianNetwork categoricalBayesianNetwork) {
        node2mapper.clear();
        node2mapper.putAll(DataTableToBNMapper.createMapper(categoricalBayesianNetwork, ilogger));
        List<DataColumnSpec> specs = new ArrayList<>();
        for (NodeCategorical node : categoricalBayesianNetwork.enumerateNodes()) {
            specs.add(node2mapper.get(node).getSpecForNode());
        }
        specs.add(new DataColumnSpecCreator("probability", DoubleCell.TYPE).createSpec());
        return specs.toArray(new DataColumnSpec[0]);
    }

    /**
     * Returns a single null output spec: the output columns depend on the network content, which is
     * only available at execution time.
     */
    protected PortObjectSpec[] configure(PortObjectSpec[] portObjectSpecArr) throws InvalidSettingsException {
        return new DataTableSpec[1];
    }

    /**
     * Computes the product, over the given nodes, of each node's probability given the values of
     * its parents in the assignment.
     *
     * @param list the nodes to multiply; their order only changes how early a zero factor is met
     * @param map the value of each node in the assignment
     * @return the product, or 0 as soon as one factor is 0
     */
    protected double computeProbability(List<NodeCategorical> list, Map<NodeCategorical, String> map) {
        double product = 1.0d;
        for (NodeCategorical node : list) {
            double probability = probabilityGivenParents(node, map);
            if (probability == 0.0d) {
                return 0.0d;
            }
            product *= probability;
        }
        return product;
    }

    /**
     * Same as {@link #computeProbability(List, Map)}, but stops as soon as the partial product falls
     * below {@code d}.
     *
     * @param list the nodes to multiply; putting likely-small factors first prunes earlier
     * @param map the value of each node in the assignment
     * @param d the threshold below which the computation is abandoned
     * @return 0 if a factor is 0, -1 if the partial product fell below {@code d}, else the product
     */
    protected double computeProbabilityPruned(List<NodeCategorical> list, Map<NodeCategorical, String> map, double d) {
        double product = 1.0d;
        for (NodeCategorical node : list) {
            double probability = probabilityGivenParents(node, map);
            if (probability == 0.0d) {
                return 0.0d;
            }
            product *= probability;
            if (product < d) {
                // sentinel: the caller discards this combination
                return -1.0d;
            }
        }
        return product;
    }

    /** Looks up the probability of the node's value in {@code assignment} given its parents' values. */
    private static double probabilityGivenParents(NodeCategorical node, Map<NodeCategorical, String> assignment) {
        Map<NodeCategorical, String> parentValues = new HashMap<>(assignment);
        parentValues.keySet().retainAll(node.getParents());
        return node.getProbability(assignment.get(node), parentValues);
    }

    /**
     * Enumerates all the value combinations of the input network and writes the retained ones, with
     * their probability, to the output table.
     *
     * @throws IllegalArgumentException if there is not exactly one input, or it is not a Bayesian network
     * @throws CanceledExecutionException if the user cancels the execution
     */
    protected PortObject[] execute(PortObject[] portObjectArr, ExecutionContext executionContext) throws Exception {
        if (portObjectArr.length == 0) {
            throw new IllegalArgumentException("No Bayesian network found as input");
        }
        if (portObjectArr.length > 1) {
            throw new IllegalArgumentException("Only one Bayesian network expected as input");
        }
        if (!(portObjectArr[0] instanceof BayesianNetworkPortObject)) {
            throw new IllegalArgumentException("The input should be a Bayesian network");
        }
        CategoricalBayesianNetwork bn = ((BayesianNetworkPortObject) portObjectArr[0]).getBN();

        double epsilon = m_skipEpsilon.getDoubleValue();
        boolean pruneBelowEpsilon = m_skipOnEpsilon.getBooleanValue() && epsilon > 0.0d;
        boolean skipZeros = m_skipNull.getBooleanValue() || (m_skipOnEpsilon.getBooleanValue() && epsilon == 0.0d);

        executionContext.setMessage("preparing the output table");
        BufferedDataContainer container = executionContext.createDataContainer(new DataTableSpec(createSpecsForBN(bn)));
        executionContext.checkCanceled();
        executionContext.setProgress(0.0d, "generating rows");

        IteratorCategoricalVariables combinations = bn.iterateDomains();
        // columnOrder matches the column specs; evaluationOrder is only used to compute probabilities
        List<NodeCategorical> columnOrder = new ArrayList<>(bn.enumerateNodes());
        List<NodeCategorical> evaluationOrder = sortForEvaluation(columnOrder, pruneBelowEpsilon);

        long expectedCombinations = countCombinations(columnOrder);
        logger.debug("worse total expected " + expectedCombinations);

        long explored = 0;
        long created = 0;
        while (combinations.hasNext()) {
            explored++;
            executionContext.setProgress(
                    Math.min(1.0d, (double) explored / expectedCombinations),
                    "Exploring combination " + explored + ", " + created + " rows created");
            executionContext.checkCanceled();

            Map<NodeCategorical, String> assignment = combinations.next();
            double probability = pruneBelowEpsilon
                    ? computeProbabilityPruned(evaluationOrder, assignment, epsilon)
                    : computeProbability(evaluationOrder, assignment);

            // the -1 pruning sentinel is below any positive epsilon, hence discarded here
            boolean aboveThreshold = !pruneBelowEpsilon || probability >= epsilon;
            boolean acceptedZero = !skipZeros || probability != 0.0d;
            if (aboveThreshold && acceptedZero) {
                container.addRowToTable(new DefaultRow(
                        new RowKey("Row " + explored), createCells(columnOrder, assignment, probability)));
                created++;
            }
        }

        executionContext.setProgress(1.0d, "closing outputs");
        container.close();
        return new BufferedDataTable[]{container.getTable()};
    }

    /**
     * Returns a copy of {@code nodes} sorted so that factors likely to be 0 (highest share of zero
     * probabilities) come first. Ties are broken by ascending median probability then descending
     * cardinality when {@code byMedian} is set, else by ascending cardinality.
     */
    private static List<NodeCategorical> sortForEvaluation(List<NodeCategorical> nodes, boolean byMedian) {
        List<NodeCategorical> sorted = new ArrayList<>(nodes);
        Map<NodeCategorical, Double> zeroRatio = new HashMap<>();
        for (NodeCategorical node : sorted) {
            // floating-point division: the integer quotient would almost always be 0
            zeroRatio.put(node, node.getCountOfZeros().intValue() / (double) node.getCardinality());
        }
        Comparator<NodeCategorical> mostZerosFirst =
                Comparator.comparing((NodeCategorical n) -> zeroRatio.get(n)).reversed();

        if (byMedian) {
            Median median = new Median();
            Map<NodeCategorical, Double> medians = new HashMap<>();
            for (NodeCategorical node : sorted) {
                medians.put(node, median.evaluate(node.getContent()));
            }
            sorted.sort(mostZerosFirst
                    .thenComparing((NodeCategorical n) -> medians.get(n))
                    .thenComparing(Comparator.comparingInt((NodeCategorical n) -> n.getCardinality()).reversed()));
            logger.debug("nodes sorted according to median:\n" + sorted.stream()
                    .map(n -> n.getName() + ": median " + medians.get(n) + ", " + n.getCountOfZeros()
                            + " zeros, " + n.getCardinality() + " values")
                    .collect(Collectors.joining("\n")));
        } else {
            sorted.sort(mostZerosFirst.thenComparingInt(n -> n.getCardinality()));
            logger.debug("nodes sorted according to 0:\n" + sorted.stream()
                    .map(n -> n.getName() + ": " + n.getCountOfZeros() + " zeros, "
                            + n.getCardinality() + " values")
                    .collect(Collectors.joining("\n")));
        }
        return sorted;
    }

    /** Returns the product of the nodes' domain sizes, saturating at {@link Long#MAX_VALUE}. */
    private static long countCombinations(List<NodeCategorical> nodes) {
        long total = 1L;
        for (NodeCategorical node : nodes) {
            try {
                total = Math.multiplyExact(total, node.getDomainSize());
            } catch (ArithmeticException overflow) {
                return Long.MAX_VALUE;
            }
        }
        return total;
    }

    /** Builds the cells of one output row: one per node in {@code columnOrder}, then the probability. */
    private DataCell[] createCells(List<NodeCategorical> columnOrder, Map<NodeCategorical, String> assignment,
            double probability) {
        DataCell[] cells = new DataCell[columnOrder.size() + 1];
        int column = 0;
        for (NodeCategorical node : columnOrder) {
            cells[column++] = node2mapper.get(node).createCellForStringValue(assignment.get(node));
        }
        cells[column] = DoubleCell.DoubleCellFactory.create(probability);
        return cells;
    }

    /** Nothing to reset: the node-to-mapper table is rebuilt at every execution. */
    protected void reset() {
    }

    /** Saves the three settings ({@code skip_null}, {@code skip_epsilon}, {@code skip_on_epsilon}). */
    protected void saveSettingsTo(NodeSettingsWO nodeSettingsWO) {
        m_skipNull.saveSettingsTo(nodeSettingsWO);
        m_skipEpsilon.saveSettingsTo(nodeSettingsWO);
        m_skipOnEpsilon.saveSettingsTo(nodeSettingsWO);
    }

    /** Loads the three settings, which are assumed to have been validated beforehand. */
    protected void loadValidatedSettingsFrom(NodeSettingsRO nodeSettingsRO) throws InvalidSettingsException {
        m_skipNull.loadSettingsFrom(nodeSettingsRO);
        m_skipEpsilon.loadSettingsFrom(nodeSettingsRO);
        m_skipOnEpsilon.loadSettingsFrom(nodeSettingsRO);
    }

    /** Validates the three settings without applying them. */
    protected void validateSettings(NodeSettingsRO nodeSettingsRO) throws InvalidSettingsException {
        m_skipNull.validateSettings(nodeSettingsRO);
        m_skipEpsilon.validateSettings(nodeSettingsRO);
        m_skipOnEpsilon.validateSettings(nodeSettingsRO);
    }

    /** No internal state is persisted. */
    protected void loadInternals(File file, ExecutionMonitor executionMonitor) throws IOException, CanceledExecutionException {
    }

    /** No internal state is persisted. */
    protected void saveInternals(File file, ExecutionMonitor executionMonitor) throws IOException, CanceledExecutionException {
    }
}
