package ch.resear.thiriot.knime.bayesiannetworks.sample;

import cern.colt.Version;
import cern.jet.random.Binomial;
import cern.jet.random.engine.MersenneTwister;
import cern.jet.random.engine.RandomEngine;
import ch.resear.thiriot.knime.bayesiannetworks.DataTableToBNMapper;
import ch.resear.thiriot.knime.bayesiannetworks.LogIntoNodeLogger;
import ch.resear.thiriot.knime.bayesiannetworks.lib.ILogger;
import ch.resear.thiriot.knime.bayesiannetworks.lib.bn.CategoricalBayesianNetwork;
import ch.resear.thiriot.knime.bayesiannetworks.lib.bn.NodeCategorical;
import ch.resear.thiriot.knime.bayesiannetworks.lib.inference.AbstractInferenceEngine;
import ch.resear.thiriot.knime.bayesiannetworks.lib.inference.SimpleConditionningInferenceEngine;
import ch.resear.thiriot.knime.bayesiannetworks.lib.sampling.EntitiesAndCount;
import ch.resear.thiriot.knime.bayesiannetworks.lib.sampling.ForwardSamplingIterator;
import ch.resear.thiriot.knime.bayesiannetworks.lib.sampling.MultinomialRecursiveSamplingIterator;
import ch.resear.thiriot.knime.bayesiannetworks.lib.sampling.RoundAndSampleRecursiveSamplingIterator;
import ch.resear.thiriot.knime.bayesiannetworks.port.BayesianNetworkPortObject;
import java.io.File;
import java.io.IOException;
import java.util.ArrayList;
import java.util.HashMap;
import java.util.Iterator;
import java.util.LinkedList;
import java.util.List;
import java.util.Map;
import java.util.concurrent.Callable;
import java.util.concurrent.ExecutorService;
import java.util.concurrent.Executors;
import java.util.concurrent.Future;
import java.util.concurrent.atomic.AtomicLong;
import org.knime.core.data.DataCell;
import org.knime.core.data.DataColumnSpec;
import org.knime.core.data.DataColumnSpecCreator;
import org.knime.core.data.DataTableSpec;
import org.knime.core.data.RowKey;
import org.knime.core.data.def.DefaultRow;
import org.knime.core.data.def.IntCell;
import org.knime.core.node.BufferedDataContainer;
import org.knime.core.node.BufferedDataTable;
import org.knime.core.node.CanceledExecutionException;
import org.knime.core.node.ExecutionContext;
import org.knime.core.node.ExecutionMonitor;
import org.knime.core.node.InvalidSettingsException;
import org.knime.core.node.NodeLogger;
import org.knime.core.node.NodeModel;
import org.knime.core.node.NodeSettingsRO;
import org.knime.core.node.NodeSettingsWO;
import org.knime.core.node.defaultnodesettings.SettingsModelBoolean;
import org.knime.core.node.defaultnodesettings.SettingsModelIntegerBounded;
import org.knime.core.node.defaultnodesettings.SettingsModelSeed;
import org.knime.core.node.defaultnodesettings.SettingsModelString;
import org.knime.core.node.port.PortObject;
import org.knime.core.node.port.PortObjectSpec;
import org.knime.core.node.port.PortType;

/* loaded from: input_file:readbnfromxmlbif.jar:ch/resear/thiriot/knime/bayesiannetworks/sample/SampleFromBNNodeModel.class */
/**
 * KNIME node model that samples synthetic rows from a categorical Bayesian network.
 *
 * <p>The single input port carries a {@link BayesianNetworkPortObject}. The single output port is a
 * table with one column per network node, in the order of
 * {@code CategoricalBayesianNetwork.enumerateNodes()}. When row grouping is enabled and the
 * generation method is not forward sampling, an extra integer {@code count} column is appended.
 *
 * <p>Sampling can be split over several worker threads. Each worker owns its own
 * {@link MersenneTwister}, seeded from a master generator, and produces its own table. The tables are
 * concatenated at the end.
 */
public class SampleFromBNNodeModel extends NodeModel {

    private static final NodeLogger logger = NodeLogger.getLogger(SampleFromBNNodeModel.class);
    private static final ILogger ilogger = new LogIntoNodeLogger(logger);

    /** An additional worker thread is only added while every worker would still get more than this many rows. */
    private static final int MIN_ROWS_FOR_PARALLEL = 19;

    static final String CFGKEY_COUNT = "Count";
    static final int DEFAULT_COUNT = 100;

    private final SettingsModelIntegerBounded m_count;
    private final SettingsModelSeed m_seed;
    private final SettingsModelBoolean m_groupRows;
    private final SettingsModelString m_generationMethod;
    private final SettingsModelBoolean m_threadsAuto;
    private final SettingsModelIntegerBounded m_threads;
    private final SettingsModelBoolean m_noStorage;

    /** Maps every network node to the helper that builds its output column and cells; rebuilt per execution. */
    private final Map<NodeCategorical, DataTableToBNMapper> node2mapper;

    /** Rows generated so far by all workers; incremented concurrently by the workers, hence atomic. */
    private final AtomicLong totalRowsGenerated;

    /**
     * Wall-clock start of the sampling phase.
     *
     * <p>Written before the workers are submitted and only read by them afterwards, so executor
     * submission provides the needed visibility.
     */
    private long timestampStart;

    /** Whether the current execution emits one row per distinct entity plus a count column. */
    private boolean groupRows;

    /**
     * Callable that samples {@code countToSample} entities from the network into its own table.
     * Row keys are numbered from {@code firstId}.
     */
    private class BNToTableSampler implements Callable<BufferedDataTable> {
        private final DataTableSpec outputSpec;
        private final ExecutionContext exec;
        private final int countToSample;
        private final CategoricalBayesianNetwork bn;
        private final int firstId;
        private final RandomEngine random;
        private final String method;
        private final boolean nostorage;

        /**
         * Creates a worker.
         *
         * @param randomEngine master generator; one {@code nextInt()} is drawn to seed this worker's own generator
         * @param categoricalBayesianNetwork the network to sample from
         * @param dataTableSpec spec of the table this worker produces
         * @param executionContext (sub-)context used for progress, cancellation and table creation
         * @param i number of entities this worker must sample
         * @param i2 index used for the first row key of this worker
         * @param str name of the generation method
         * @param z if {@code true}, entities are sampled but no rows are stored
         */
        public BNToTableSampler(RandomEngine randomEngine, CategoricalBayesianNetwork categoricalBayesianNetwork,
                DataTableSpec dataTableSpec, ExecutionContext executionContext, int i, int i2, String str, boolean z) {
            this.outputSpec = dataTableSpec;
            this.exec = executionContext;
            this.countToSample = i;
            this.bn = categoricalBayesianNetwork;
            this.firstId = i2;
            // each worker gets an independent generator so workers never share mutable RNG state
            this.random = new MersenneTwister(randomEngine.nextInt());
            this.method = str;
            this.nostorage = z;
        }

        /**
         * Samples all entities and stores them into a fresh table.
         *
         * @return the table of sampled rows (empty when storage is disabled)
         * @throws CanceledExecutionException if the user cancels the node
         * @throws RuntimeException wrapping any runtime failure of the sampling iterator, or for an unknown method
         */
        @Override
        public BufferedDataTable call() throws Exception {
            final Iterator<EntitiesAndCount> entities = createIterator();
            final BufferedDataContainer container = exec.createDataContainer(outputSpec);
            int sampled = 0;
            int nextRowIndex = 0;
            while (entities.hasNext()) {
                reportProgress(sampled);
                try {
                    final EntitiesAndCount entity = entities.next();
                    final int count = entity.count.intValue();
                    sampled += count;
                    totalRowsGenerated.addAndGet(count);
                    if (!nostorage) {
                        nextRowIndex = storeEntity(container, entity, nextRowIndex);
                    }
                    exec.checkCanceled();
                } catch (RuntimeException e) {
                    // the cause is chained, so its stack trace is preserved without printing it here
                    throw new RuntimeException("Error when sampling the next entity: " + e.getMessage(), e);
                }
            }
            container.close();
            return container.getTable();
        }

        /** Builds the sampling iterator matching {@link #method}. */
        private Iterator<EntitiesAndCount> createIterator() {
            if (RoundAndSampleRecursiveSamplingIterator.GENERATION_METHOD_NAME.equals(method)) {
                return new RoundAndSampleRecursiveSamplingIterator(countToSample, bn, random,
                        (AbstractInferenceEngine) new SimpleConditionningInferenceEngine(ilogger, null, bn),
                        (ExecutionMonitor) exec, ilogger);
            }
            if (MultinomialRecursiveSamplingIterator.GENERATION_METHOD_NAME.equals(method)) {
                // NOTE(review): the Binomial's (42, 0.1) look like placeholders, presumably overridden per draw
                // by the iterator — confirm in MultinomialRecursiveSamplingIterator.
                return new MultinomialRecursiveSamplingIterator(countToSample, bn, new Binomial(42, 0.1d, random),
                        (AbstractInferenceEngine) new SimpleConditionningInferenceEngine(ilogger, null, bn),
                        (ExecutionMonitor) exec, ilogger);
            }
            if (ForwardSamplingIterator.GENERATION_METHOD_NAME.equals(method)) {
                return new ForwardSamplingIterator(random, bn, countToSample, ilogger);
            }
            throw new RuntimeException("Unknown generation method " + method);
        }

        /**
         * Publishes this worker's progress.
         *
         * <p>Only the worker whose rows start at 0 adds a text message; that message reports the
         * throughput of all workers once more than 10 s have elapsed.
         */
        private void reportProgress(int sampled) {
            // floating-point division; guard against a zero-sized share
            final double fraction = countToSample > 0 ? (double) sampled / countToSample : 0.0d;
            if (firstId != 0) {
                exec.setProgress(fraction);
                return;
            }
            String message = "entity " + sampled;
            final long elapsedSeconds = (System.currentTimeMillis() - timestampStart) / 1000;
            if (elapsedSeconds > 10) {
                message = message + " (" + ((int) (totalRowsGenerated.get() / elapsedSeconds)) + "/s)";
            }
            exec.setProgress(fraction, message);
        }

        /**
         * Appends the rows for one sampled entity.
         *
         * <p>When grouping, one row with the count is written. Otherwise the same row is repeated
         * {@code count} times.
         *
         * @return the row index to use for the next row
         */
        private int storeEntity(BufferedDataContainer container, EntitiesAndCount entity, int rowIndex) {
            final int count = entity.count.intValue();
            final int valueColumns = entity.node2value.size();
            final DataCell[] cells = new DataCell[groupRows ? valueColumns + 1 : valueColumns];
            int column = 0;
            for (NodeCategorical node : bn.enumerateNodes()) {
                cells[column++] = node2mapper.get(node).createCellForStringValue(entity.node2value.get(node));
            }
            int nextRow = rowIndex;
            if (groupRows) {
                cells[column] = IntCell.IntCellFactory.create(count);
                container.addRowToTable(new DefaultRow(new RowKey("Row " + (firstId + nextRow++)), cells));
            } else {
                for (int k = 0; k < count; k++) {
                    container.addRowToTable(new DefaultRow(new RowKey("Row " + (firstId + nextRow++)), cells));
                }
            }
            return nextRow;
        }
    }

    /** Creates the node model with one Bayesian-network input port and one table output port. */
    public SampleFromBNNodeModel() {
        super(new PortType[]{BayesianNetworkPortObject.TYPE}, new PortType[]{BufferedDataTable.TYPE});
        this.m_count = new SettingsModelIntegerBounded(CFGKEY_COUNT, DEFAULT_COUNT, 0, Integer.MAX_VALUE);
        this.m_seed = new SettingsModelSeed("seed", (int) System.currentTimeMillis(), false);
        this.m_groupRows = new SettingsModelBoolean("m_grouprows", true);
        this.m_generationMethod = new SettingsModelString("m_generation_method",
                MultinomialRecursiveSamplingIterator.GENERATION_METHOD_NAME);
        this.m_threadsAuto = new SettingsModelBoolean("m_threads_auto", true);
        this.m_threads = new SettingsModelIntegerBounded("m_threads", Runtime.getRuntime().availableProcessors(), 1, 128);
        this.m_noStorage = new SettingsModelBoolean("m_nostorage", false);
        this.node2mapper = new HashMap<>();
        this.totalRowsGenerated = new AtomicLong(0L);
        this.timestampStart = 0L;
        this.groupRows = false;
    }

    /**
     * Tells whether rows should be grouped (one row per distinct entity plus a count column).
     * Forward sampling never groups.
     */
    private static boolean shouldGroupRows(boolean groupRowsSetting, String method) {
        return groupRowsSetting && !ForwardSamplingIterator.GENERATION_METHOD_NAME.equals(method);
    }

    /**
     * Rebuilds {@link #node2mapper} for the given network and returns the output column specs.
     *
     * @param categoricalBayesianNetwork the network whose nodes become columns
     * @return one spec per node in enumeration order, plus a trailing integer {@code count} column when grouping
     */
    protected DataColumnSpec[] createSpecsForBN(CategoricalBayesianNetwork categoricalBayesianNetwork) {
        node2mapper.clear();
        node2mapper.putAll(DataTableToBNMapper.createMapper(categoricalBayesianNetwork, ilogger));
        final List<DataColumnSpec> specs = new ArrayList<>();
        for (NodeCategorical node : categoricalBayesianNetwork.enumerateNodes()) {
            specs.add(node2mapper.get(node).getSpecForNode());
        }
        if (shouldGroupRows(m_groupRows.getBooleanValue(), m_generationMethod.getStringValue())) {
            specs.add(new DataColumnSpecCreator("count", IntCell.TYPE).createSpec());
        }
        return specs.toArray(new DataColumnSpec[0]);
    }

    /**
     * Returns an unknown ({@code null}) output spec; the columns are only computed in
     * {@link #execute} once the network itself is available.
     */
    @Override
    protected PortObjectSpec[] configure(PortObjectSpec[] portObjectSpecArr) throws InvalidSettingsException {
        return new DataTableSpec[1];
    }

    /**
     * Samples the configured number of entities from the input network.
     *
     * @param portObjectArr exactly one {@link BayesianNetworkPortObject}
     * @param executionContext KNIME execution context
     * @return a single table with the sampled rows
     * @throws IllegalArgumentException if the input is missing, duplicated, or not a Bayesian network
     * @throws Exception propagated from the workers (wrapped by {@link Future#get()}) or from cancellation
     */
    @Override
    protected PortObject[] execute(PortObject[] portObjectArr, ExecutionContext executionContext) throws Exception {
        if (portObjectArr.length == 0) {
            throw new IllegalArgumentException("No Bayesian network found as input");
        }
        if (portObjectArr.length > 1) {
            throw new IllegalArgumentException("Only one Bayesian network expected as input");
        }
        final CategoricalBayesianNetwork bn;
        try {
            bn = ((BayesianNetworkPortObject) portObjectArr[0]).getBN();
        } catch (ClassCastException e) {
            // only the port-object cast is translated; later CCEs keep their own meaning
            throw new IllegalArgumentException("The input should be a Bayesian network", e);
        }

        final int count = m_count.getIntValue();
        final String method = m_generationMethod.getStringValue();
        groupRows = shouldGroupRows(m_groupRows.getBooleanValue(), method);
        final boolean noStorage = m_noStorage.getBooleanValue();
        if (noStorage) {
            setWarningMessage("storage is disabled; no data will be produced");
        }
        final int seed = resolveSeed();

        executionContext.setMessage("preparing the output table");
        final DataTableSpec outputSpec = new DataTableSpec(createSpecsForBN(bn));

        executionContext.setMessage("init of the random engine");
        logger.info("generating random numbers using the MersenneTwister pseudo-random number generator with seed "
                + seed + ", as implemented in the COLT library " + Version.getMajorVersion() + "."
                + Version.getMinorVersion() + "." + Version.getMicroVersion());
        final MersenneTwister masterRandom = new MersenneTwister(seed);
        executionContext.checkCanceled();

        final int threads = computeThreadCount(count);
        logger.debug("will use " + threads + " threads to generate the data");
        executionContext.setProgress(0.0d, "generating rows");
        executionContext.setMessage("sampling");

        // workers are created sequentially so each draws its seed from the master generator in a fixed order;
        // the last worker also takes the remainder of the integer division
        final List<BNToTableSampler> samplers = new ArrayList<>(threads);
        final int share = count / threads;
        int firstId = 0;
        for (int t = 0; t < threads; t++) {
            final int rows = (t == threads - 1) ? count - firstId : share;
            samplers.add(new BNToTableSampler(masterRandom, bn, outputSpec,
                    executionContext.createSubExecutionContext(0.9d / threads), rows, firstId, method, noStorage));
            firstId += rows;
        }
        executionContext.checkCanceled();

        final ExecutorService pool = Executors.newFixedThreadPool(threads);
        final List<Future<BufferedDataTable>> futures;
        final long elapsedMs;
        try {
            totalRowsGenerated.set(0L);
            timestampStart = System.currentTimeMillis();
            futures = pool.invokeAll(samplers);
            elapsedMs = System.currentTimeMillis() - timestampStart;
        } finally {
            // release the worker threads even if invokeAll is interrupted
            pool.shutdown();
        }

        // floating-point rate; max(1, ...) avoids a division by zero for sub-millisecond runs
        final int entitiesPerSecond = (int) (count * 1000.0d / Math.max(1L, elapsedMs));
        logger.info("generation of " + count + " entities on " + threads + " CPUs with method " + method + " took "
                + elapsedMs + "ms, that is on average " + entitiesPerSecond + " entities/s");

        final BufferedDataTable result;
        if (threads > 1) {
            executionContext.setProgress("merging tables");
            final BufferedDataTable[] parts = new BufferedDataTable[threads];
            for (int t = 0; t < parts.length; t++) {
                parts[t] = futures.get(t).get();
            }
            result = executionContext.createConcatenateTable(executionContext.createSubProgress(0.1d), parts);
        } else {
            result = futures.get(0).get();
        }
        pushFlowVariableInt("sampled_count", count);
        pushFlowVariableInt("sampling_performance_entities_per_second", entitiesPerSecond);
        return new BufferedDataTable[]{result};
    }

    /**
     * Returns the configured seed truncated to {@code int}, or the current time when no seed is set.
     * Logs when truncation changes the value.
     */
    private int resolveSeed() {
        if (!m_seed.getIsActive()) {
            return (int) System.currentTimeMillis();
        }
        final long configured = m_seed.getLongValue();
        final int seed = (int) configured;
        if (seed != configured) {
            logger.info("the seed was converted from long " + configured + " to int " + seed
                    + "; this should have no impact for you");
        }
        return seed;
    }

    /**
     * Picks the number of worker threads.
     *
     * <p>The limit is the configured maximum, or the number of available processors in auto mode.
     * A thread is only added while each thread would still get more than
     * {@link #MIN_ROWS_FOR_PARALLEL} rows.
     */
    private int computeThreadCount(int rowCount) {
        final int maxThreads = m_threadsAuto.getBooleanValue()
                ? Runtime.getRuntime().availableProcessors()
                : m_threads.getIntValue();
        int threads = 1;
        while (threads < maxThreads && rowCount / threads > MIN_ROWS_FOR_PARALLEL) {
            threads++;
        }
        return threads;
    }

    /** Nothing to reset: all per-execution state is reinitialised in {@link #execute}. */
    @Override
    protected void reset() {
    }

    /** Persists every setting model of this node. */
    @Override
    protected void saveSettingsTo(NodeSettingsWO nodeSettingsWO) {
        this.m_count.saveSettingsTo(nodeSettingsWO);
        this.m_seed.saveSettingsTo(nodeSettingsWO);
        this.m_generationMethod.saveSettingsTo(nodeSettingsWO);
        this.m_threads.saveSettingsTo(nodeSettingsWO);
        this.m_threadsAuto.saveSettingsTo(nodeSettingsWO);
        this.m_groupRows.saveSettingsTo(nodeSettingsWO);
        this.m_noStorage.saveSettingsTo(nodeSettingsWO);
    }

    /** Loads every setting model of this node from already-validated settings. */
    @Override
    protected void loadValidatedSettingsFrom(NodeSettingsRO nodeSettingsRO) throws InvalidSettingsException {
        this.m_count.loadSettingsFrom(nodeSettingsRO);
        this.m_seed.loadSettingsFrom(nodeSettingsRO);
        this.m_generationMethod.loadSettingsFrom(nodeSettingsRO);
        this.m_threads.loadSettingsFrom(nodeSettingsRO);
        this.m_threadsAuto.loadSettingsFrom(nodeSettingsRO);
        this.m_groupRows.loadSettingsFrom(nodeSettingsRO);
        this.m_noStorage.loadSettingsFrom(nodeSettingsRO);
    }

    /** Validates every setting model of this node without applying it. */
    @Override
    protected void validateSettings(NodeSettingsRO nodeSettingsRO) throws InvalidSettingsException {
        this.m_count.validateSettings(nodeSettingsRO);
        this.m_seed.validateSettings(nodeSettingsRO);
        this.m_generationMethod.validateSettings(nodeSettingsRO);
        this.m_threads.validateSettings(nodeSettingsRO);
        this.m_threadsAuto.validateSettings(nodeSettingsRO);
        this.m_groupRows.validateSettings(nodeSettingsRO);
        this.m_noStorage.validateSettings(nodeSettingsRO);
    }

    /** No internal state is persisted by this node. */
    @Override
    protected void loadInternals(File file, ExecutionMonitor executionMonitor) throws IOException, CanceledExecutionException {
    }

    /** No internal state is persisted by this node. */
    @Override
    protected void saveInternals(File file, ExecutionMonitor executionMonitor) throws IOException, CanceledExecutionException {
    }
}
