package ch.resear.thiriot.knime.bayesiannetworks.learn;

import ch.resear.thiriot.knime.bayesiannetworks.DataTableToBNMapper;
import ch.resear.thiriot.knime.bayesiannetworks.LogIntoNodeLogger;
import ch.resear.thiriot.knime.bayesiannetworks.lib.ILogger;
import ch.resear.thiriot.knime.bayesiannetworks.lib.bn.CategoricalBayesianNetwork;
import ch.resear.thiriot.knime.bayesiannetworks.lib.bn.IteratorCategoricalVariables;
import ch.resear.thiriot.knime.bayesiannetworks.lib.bn.NodeCategorical;
import ch.resear.thiriot.knime.bayesiannetworks.port.BayesianNetworkPortObject;
import ch.resear.thiriot.knime.bayesiannetworks.port.BayesianNetworkPortSpec;
import java.io.File;
import java.io.IOException;
import java.util.HashMap;
import java.util.HashSet;
import java.util.Iterator;
import java.util.Map;
import java.util.function.Function;
import java.util.stream.Collectors;
import org.knime.core.data.DataRow;
import org.knime.core.data.container.CloseableRowIterator;
import org.knime.core.node.BufferedDataTable;
import org.knime.core.node.CanceledExecutionException;
import org.knime.core.node.ExecutionContext;
import org.knime.core.node.ExecutionMonitor;
import org.knime.core.node.InvalidSettingsException;
import org.knime.core.node.NodeLogger;
import org.knime.core.node.NodeModel;
import org.knime.core.node.NodeSettingsRO;
import org.knime.core.node.NodeSettingsWO;
import org.knime.core.node.defaultnodesettings.SettingsModelBoolean;
import org.knime.core.node.defaultnodesettings.SettingsModelColumnName;
import org.knime.core.node.defaultnodesettings.SettingsModelIntegerBounded;
import org.knime.core.node.defaultnodesettings.SettingsModelString;
import org.knime.core.node.port.PortObject;
import org.knime.core.node.port.PortObjectSpec;
import org.knime.core.node.port.PortType;

/* loaded from: input_file:readbnfromxmlbif.jar:ch/resear/thiriot/knime/bayesiannetworks/learn/LearnBNFromSampleNodeModel.class */
/**
 * KNIME node model that re-estimates the conditional probabilities of a categorical Bayesian
 * network from a data table.
 *
 * <p>Input port 0 carries the network and input port 1 the sample; the single output port
 * carries a clone of the network whose probabilities were updated. A node is only learnt when
 * the table has a column named after the node and after each of its ancestors; all other nodes
 * keep the probabilities of the input network.
 */
public class LearnBNFromSampleNodeModel extends NodeModel {
    private static final NodeLogger logger = NodeLogger.getLogger(LearnBNFromSampleNodeModel.class);
    private static final ILogger ilogger = new LogIntoNodeLogger(logger);

    /** Value of the "no case" setting: unobserved parent combinations get a uniform distribution. */
    public static final String METHOD_NOCASE_EQUIPROBABILITY = "assume equiprobability";

    /** Value of the "no case" setting: unobserved parent combinations keep their probabilities. */
    public static final String METHOD_NOCASE_PREVIOUS = "keep previous probabilities";

    /** Maximum number of distinct "no observation" warnings that are logged individually. */
    private static final int MAX_WARNINGS = 30;

    // NOTE: the configuration keys below are persisted in saved workflows; their spelling
    // (including the typos) is kept on purpose so that existing settings still load.

    /** Constant added to every count before the counts are turned into probabilities. */
    private final SettingsModelIntegerBounded m_constant;

    /** Whether each row is counted with the weight found in the weight column instead of 1. */
    private final SettingsModelBoolean m_useWeightColumn;

    /** Name of the column holding the row weights. */
    private final SettingsModelColumnName m_colnameWeight;

    /** What to do for parent combinations without observation; one of the METHOD_NOCASE_* values. */
    private final SettingsModelString m_methodNoCase;

    /* JADX INFO: Access modifiers changed from: protected */
    /**
     * Creates the model with two inputs (Bayesian network, data table) and one output
     * (Bayesian network), and initialises the settings with their defaults.
     */
    public LearnBNFromSampleNodeModel() {
        super(new PortType[]{BayesianNetworkPortObject.TYPE, BufferedDataTable.TYPE}, new PortType[]{BayesianNetworkPortObject.TYPE});
        this.m_constant = new SettingsModelIntegerBounded("m_addconstant", 0, 0, 1000);
        this.m_useWeightColumn = new SettingsModelBoolean("m_use_weight_colum", false);
        this.m_colnameWeight = new SettingsModelColumnName("m_colname", (String) null);
        this.m_methodNoCase = new SettingsModelString("m_method_no_vase", METHOD_NOCASE_PREVIOUS);
    }

    /**
     * Learns the probabilities of a clone of the input network from the input table.
     *
     * <p>The work is done in four phases: selecting the nodes that can be learnt, creating one
     * zeroed counter per (node, value, parent combination), counting the rows of the table, and
     * turning the counts into probabilities.
     *
     * @param portObjectArr the Bayesian network at index 0 and the data table at index 1
     * @param executionContext used for progress reporting and cancellation checks
     * @return a one-element array holding the updated clone of the network
     * @throws IllegalArgumentException if an input does not have the expected type
     * @throws InvalidSettingsException if weights are enabled but the weight column is not in the table
     * @throws Exception on cancellation or on any failure of the underlying library
     */
    protected PortObject[] execute(PortObject[] portObjectArr, ExecutionContext executionContext) throws Exception {
        // Only the inputs themselves are type-checked here, so that a cast problem deeper in
        // the learning code is not misreported as a wrong input.
        if (!(portObjectArr[0] instanceof BayesianNetworkPortObject)) {
            throw new IllegalArgumentException("The first input should be a Bayesian network");
        }
        if (!(portObjectArr[1] instanceof BufferedDataTable)) {
            throw new IllegalArgumentException("The second input should be a data table");
        }
        final CategoricalBayesianNetwork inputNetwork = ((BayesianNetworkPortObject) portObjectArr[0]).getBN();
        final BufferedDataTable sample = (BufferedDataTable) portObjectArr[1];

        final int smoothingConstant = this.m_constant.getIntValue();
        final boolean useWeights = this.m_useWeightColumn.getBooleanValue();
        final boolean assumeEquiprobability = METHOD_NOCASE_EQUIPROBABILITY.equals(this.m_methodNoCase.getStringValue());
        // The weight column is only resolved when it is going to be read.
        final int weightColumnIndex = useWeights ? findWeightColumnIndex(sample) : -1;

        // The input network is left untouched; everything below works on the clone.
        final CategoricalBayesianNetwork learnedNetwork = inputNetwork.m77clone();

        final HashSet<NodeCategorical> learnableNodes = selectLearnableNodes(learnedNetwork, sample);

        executionContext.setMessage("initializing counters");
        final Map<NodeCategorical, Map<String, Map<Map<NodeCategorical, String>, Double>>> counters =
                createCounters(learnedNetwork, learnableNodes, executionContext);

        final HashSet<NodeCategorical> failedNodes =
                countObservations(learnedNetwork, sample, learnableNodes, counters, weightColumnIndex, executionContext);
        if (!failedNodes.isEmpty()) {
            logger.error("the following nodes were not measured: " + failedNodes);
            learnableNodes.removeAll(failedNodes);
        }

        executionContext.setProgress("aggregating statistics");
        final HashSet<String> warnings = new HashSet<>();
        final int suppressedWarnings = updateProbabilities(
                learnedNetwork, learnableNodes, counters, smoothingConstant, assumeEquiprobability, warnings, executionContext);

        executionContext.checkCanceled();
        executionContext.setMessage("processing warnings");
        for (String warning : warnings) {
            logger.warn(warning);
        }
        if (suppressedWarnings > 0) {
            logger.warn("(... " + suppressedWarnings + " additional warnings were ignored)");
        }
        return new BayesianNetworkPortObject[]{new BayesianNetworkPortObject(learnedNetwork)};
    }

    /**
     * Resolves the index of the configured weight column in the sample.
     *
     * @throws InvalidSettingsException if no column is configured or the table does not contain it
     */
    private int findWeightColumnIndex(BufferedDataTable sample) throws InvalidSettingsException {
        final String columnName = this.m_colnameWeight.getColumnName();
        final int index = columnName == null ? -1 : sample.getDataTableSpec().findColumnIndex(columnName);
        if (index < 0) {
            throw new InvalidSettingsException("the weight column \"" + columnName + "\" was not found in the input table");
        }
        return index;
    }

    /**
     * Returns the nodes for which the table has a column named after the node and after every
     * one of its ancestors; logs for each node whether it will be learnt.
     */
    private static HashSet<NodeCategorical> selectLearnableNodes(CategoricalBayesianNetwork network, BufferedDataTable sample) {
        final HashSet<NodeCategorical> learnable = new HashSet<>();
        for (NodeCategorical node : network.getNodes()) {
            final boolean columnsAvailable = sample.getDataTableSpec().containsName(node.getName())
                    && node.getAllAncestors().stream()
                            .map(ancestor -> ancestor.getName())
                            .noneMatch(name -> !sample.getDataTableSpec().containsName(name));
            if (columnsAvailable) {
                logger.info("will learn the node " + node.getName());
                learnable.add(node);
            } else {
                logger.warn("will not learn the node " + node.getName() + " for which columns are not available");
            }
        }
        return learnable;
    }

    /**
     * Creates, for every given node, one counter set to zero per value of the node's domain and
     * per combination of values of its parents.
     */
    private static Map<NodeCategorical, Map<String, Map<Map<NodeCategorical, String>, Double>>> createCounters(
            CategoricalBayesianNetwork network, HashSet<NodeCategorical> nodes, ExecutionContext executionContext)
            throws CanceledExecutionException {
        final Map<NodeCategorical, Map<String, Map<Map<NodeCategorical, String>, Double>>> counters = new HashMap<>();
        for (NodeCategorical node : nodes) {
            final Map<String, Map<Map<NodeCategorical, String>, Double>> countsByValue = new HashMap<>();
            for (String value : node.getDomain()) {
                final Map<Map<NodeCategorical, String>, Double> countsByParents = new HashMap<>();
                final IteratorCategoricalVariables combinations = network.iterateDomains(node.getParents());
                while (combinations.hasNext()) {
                    // Each combination is copied before being used as a key, as the original
                    // code did (presumably because the iterator may reuse its map - confirm).
                    countsByParents.put(new HashMap<NodeCategorical, String>(combinations.next()), Double.valueOf(0.0d));
                }
                countsByValue.put(value, countsByParents);
                executionContext.checkCanceled();
            }
            counters.put(node, countsByValue);
        }
        return counters;
    }

    /**
     * Reads every row of the sample and increments the matching counter of each node, by the
     * row weight when {@code weightColumnIndex} is not negative and by 1 otherwise.
     *
     * <p>The first 70% of the node's progress are reported from here.
     *
     * @return the nodes for which at least one row had no matching counter
     */
    private static HashSet<NodeCategorical> countObservations(
            CategoricalBayesianNetwork network,
            BufferedDataTable sample,
            HashSet<NodeCategorical> nodes,
            Map<NodeCategorical, Map<String, Map<Map<NodeCategorical, String>, Double>>> counters,
            int weightColumnIndex,
            ExecutionContext executionContext) throws CanceledExecutionException {
        final Map<NodeCategorical, DataTableToBNMapper> mappers = DataTableToBNMapper.createMapper(network, ilogger);
        final Map<NodeCategorical, Integer> columnIndexes = new HashMap<>();
        for (NodeCategorical node : nodes) {
            columnIndexes.put(node, Integer.valueOf(sample.getDataTableSpec().findColumnIndex(node.getName())));
        }

        final HashSet<NodeCategorical> failedNodes = new HashSet<>();
        final CloseableRowIterator rows = sample.iterator();
        try {
            int rowIndex = 0;
            while (rows.hasNext()) {
                final DataRow row = rows.next();
                for (NodeCategorical node : nodes) {
                    final String value = mappers.get(node).getStringValueForCell(row.getCell(columnIndexes.get(node).intValue()));
                    final Map<NodeCategorical, String> parentValues = node.getParents().stream().collect(Collectors.toMap(
                            Function.identity(),
                            parent -> mappers.get(parent).getStringValueForCell(row.getCell(columnIndexes.get(parent).intValue()))));

                    final Map<Map<NodeCategorical, String>, Double> countsByParents = counters.get(node).get(value);
                    final Double current = countsByParents == null ? null : countsByParents.get(parentValues);
                    if (current == null) {
                        // No counter exists for this value or this combination of parent values.
                        logger.error("unknown value " + node.getName() + "=" + value + ", domain is " + node.getDomain());
                        failedNodes.add(node);
                        continue;
                    }
                    // NOTE(review): the weight is read exactly as in the decompiled code;
                    // confirm that the cell type returned by getCell exposes getDoubleValue().
                    final double weight = weightColumnIndex >= 0 ? row.getCell(weightColumnIndex).getDoubleValue() : 1.0d;
                    countsByParents.put(parentValues, Double.valueOf(current.doubleValue() + weight));
                }
                if (rowIndex % 10 == 0) {
                    executionContext.checkCanceled();
                    executionContext.setProgress((0.7d * rowIndex) / sample.size(), "reading sample " + rowIndex);
                }
                rowIndex++;
            }
        } finally {
            // Release the iterator also on cancellation or failure.
            rows.close();
        }
        return failedNodes;
    }

    /**
     * Replaces the probabilities of the given nodes by the relative frequencies found in the
     * counters, then normalises each node.
     *
     * <p>For a value v and a parent combination c the probability is
     * {@code (count(v, c) + smoothingConstant) / sum over v' of (count(v', c) + smoothingConstant)}.
     * When that sum equals {@code smoothingConstant * domainSize}, a warning is recorded and
     * either {@code 1 / domainSize} is used or the probability is left as it is, depending on
     * {@code assumeEquiprobability}. The last 30% of the node's progress are reported from here.
     *
     * @param warnings receives at most {@link #MAX_WARNINGS} distinct warning messages
     * @return the number of warnings that were dropped because {@code warnings} was full
     */
    private static int updateProbabilities(
            CategoricalBayesianNetwork network,
            HashSet<NodeCategorical> nodes,
            Map<NodeCategorical, Map<String, Map<Map<NodeCategorical, String>, Double>>> counters,
            int smoothingConstant,
            boolean assumeEquiprobability,
            HashSet<String> warnings,
            ExecutionContext executionContext) throws CanceledExecutionException {
        int suppressedWarnings = 0;
        int processedNodes = 0;
        for (NodeCategorical node : nodes) {
            executionContext.checkCanceled();
            executionContext.setProgress(0.7d + ((0.3d * processedNodes) / nodes.size()), "aggregating statistics for " + node.name);
            processedNodes++;

            final Map<String, Map<Map<NodeCategorical, String>, Double>> countsByValue = counters.get(node);
            for (String value : node.getDomain()) {
                final IteratorCategoricalVariables combinations = network.iterateDomains(node.getParents());
                while (combinations.hasNext()) {
                    final Map<NodeCategorical, String> parentValues = combinations.next();
                    // A private copy is handed to the node, as the original code did.
                    final Map<NodeCategorical, String> parentValuesCopy = new HashMap<NodeCategorical, String>(parentValues);

                    final double total = node.getDomain().stream()
                            .mapToDouble(other -> countsByValue.get(other).get(parentValues).doubleValue() + smoothingConstant)
                            .sum();

                    // A negative value means "leave the current probability untouched".
                    final double probability;
                    if (total != smoothingConstant * node.getDomainSize()) {
                        probability = (countsByValue.get(value).get(parentValues).doubleValue() + smoothingConstant) / total;
                    } else {
                        if (warnings.size() < MAX_WARNINGS) {
                            warnings.add("no observation for the case " + node.name + "=" + value + " given "
                                    + describeCase(parentValues)
                                    + (assumeEquiprobability ? "; will assume equiprobability" : "; will keep former probabilities"));
                        } else {
                            suppressedWarnings++;
                        }
                        probability = assumeEquiprobability ? 1.0d / node.getDomainSize() : -1.0d;
                    }
                    if (probability >= 0.0d) {
                        node.setProbabilities(probability, value, parentValuesCopy);
                    }
                }
            }
            node.normalize();
            // Diagnostic dump of the learnt table; goes to the node log instead of stdout.
            logger.debug(node.asFactor().toStringLong());
        }
        return suppressedWarnings;
    }

    /** Formats a combination of parent values as {@code name=value, name=value, ...}. */
    private static String describeCase(Map<NodeCategorical, String> parentValues) {
        return parentValues.entrySet().stream()
                .map(entry -> String.valueOf(entry.getKey().name) + "=" + entry.getValue())
                .collect(Collectors.joining(", "));
    }

    /** Nothing to reset: the model keeps no state between executions. */
    protected void reset() {
    }

    /* JADX INFO: Access modifiers changed from: protected */
    /* renamed from: configure, reason: merged with bridge method [inline-methods] */
    /**
     * Returns the spec of the single output port, a fresh {@link BayesianNetworkPortSpec};
     * the input specs are not inspected.
     *
     * <p>NOTE(review): the name {@code m76configure} was assigned by the decompiler (see the
     * comment above); confirm how this method is bound to the framework's {@code configure}.
     */
    public BayesianNetworkPortSpec[] m76configure(PortObjectSpec[] portObjectSpecArr) throws InvalidSettingsException {
        return new BayesianNetworkPortSpec[]{new BayesianNetworkPortSpec()};
    }

    /** Writes the four settings of this node into {@code nodeSettingsWO}. */
    protected void saveSettingsTo(NodeSettingsWO nodeSettingsWO) {
        this.m_constant.saveSettingsTo(nodeSettingsWO);
        this.m_useWeightColumn.saveSettingsTo(nodeSettingsWO);
        this.m_colnameWeight.saveSettingsTo(nodeSettingsWO);
        this.m_methodNoCase.saveSettingsTo(nodeSettingsWO);
    }

    /** Loads the four settings of this node from {@code nodeSettingsRO}. */
    protected void loadValidatedSettingsFrom(NodeSettingsRO nodeSettingsRO) throws InvalidSettingsException {
        this.m_constant.loadSettingsFrom(nodeSettingsRO);
        this.m_useWeightColumn.loadSettingsFrom(nodeSettingsRO);
        this.m_colnameWeight.loadSettingsFrom(nodeSettingsRO);
        this.m_methodNoCase.loadSettingsFrom(nodeSettingsRO);
    }

    /** Validates the four settings of this node against {@code nodeSettingsRO}. */
    protected void validateSettings(NodeSettingsRO nodeSettingsRO) throws InvalidSettingsException {
        this.m_constant.validateSettings(nodeSettingsRO);
        this.m_useWeightColumn.validateSettings(nodeSettingsRO);
        this.m_colnameWeight.validateSettings(nodeSettingsRO);
        this.m_methodNoCase.validateSettings(nodeSettingsRO);
    }

    /** Nothing to load: the model has no internal data. */
    protected void loadInternals(File file, ExecutionMonitor executionMonitor) throws IOException, CanceledExecutionException {
    }

    /** Nothing to save: the model has no internal data. */
    protected void saveInternals(File file, ExecutionMonitor executionMonitor) throws IOException, CanceledExecutionException {
    }
}
