Skip to content

Commit 6972801

Browse files
A couple optimizations of the diversity picker (#82)
* some optimization of the diversity picker * Performance tuning for fingerprint calculation or conversion Implemented multi-threading for time-consuming code * Central bugfix to unleash performance power of RDKit nodes in general The cleanup tracker for RDKit objects got a performance boost by using now a set instead of a list to track RDKit objects marked for future cleanup - this saves us from checking if an object was already contained in the list, and apparently the set implementation is much better and faster, internally using hashes (for RDKit objects probably direct pointers) to compare objects. This resulted in the diversity picker in a performance boost of more than 75%! Co-authored-by: Manuel Schwarze <manuel.schwarze@novartis.com>
1 parent 3afe684 commit 6972801

File tree

3 files changed

+167
-104
lines changed

3 files changed

+167
-104
lines changed

org.rdkit.knime.nodes/src/org/rdkit/knime/nodes/AbstractRDKitNodeModel.java

Lines changed: 13 additions & 15 deletions
Original file line numberDiff line numberDiff line change
@@ -2575,7 +2575,7 @@ public void processFinished(final ComputationTask task) {
25752575
*
25762576
* @author Manuel Schwarze
25772577
*/
2578-
static class RDKitCleanupTracker extends HashMap<Long, List<Object>> {
2578+
static class RDKitCleanupTracker extends HashMap<Long, HashSet<Object>> {
25792579

25802580
//
25812581
// Constants
@@ -2661,28 +2661,26 @@ public synchronized <T extends Object> T markForCleanup(final T rdkitObject, fin
26612661
if (bRemoveFromOtherWave) {
26622662

26632663
// Loop through all waves to find the rdkitObject - we create a copy here, because
2664-
// we may remove empty wave lists which may blow up out iterator
2664+
// we may remove empty wave lists which may blow up our iterator
26652665
for (final long waveExisting : new HashSet<Long>(keySet())) {
2666-
final List<Object> list = get(waveExisting);
2666+
final HashSet<Object> list = get(waveExisting);
26672667
if (list.remove(rdkitObject) && list.isEmpty()) {
26682668
remove(waveExisting);
26692669
}
26702670
}
26712671
}
26722672

26732673
// Get the list of the target wave
2674-
List<Object> list = get(wave);
2674+
HashSet<Object> set = get(wave);
26752675

26762676
// Create a wave list, if not found yet
2677-
if (list == null) {
2678-
list = new ArrayList<Object>();
2679-
put(wave, list);
2677+
if (set == null) {
2678+
set = new HashSet<Object>();
2679+
put(wave, set);
26802680
}
26812681

2682-
// Add the object only once
2683-
if (!list.contains(rdkitObject)) {
2684-
list.add(rdkitObject);
2685-
}
2682+
// Add the object (only once, since it is a set)
2683+
set.add(rdkitObject);
26862684
}
26872685

26882686
return rdkitObject;
@@ -2708,11 +2706,11 @@ public synchronized void cleanupMarkedObjects() {
27082706
*/
27092707
public synchronized void cleanupMarkedObjects(final long wave) {
27102708
// Find the right wave list, if not found yet
2711-
final List<Object> list = get(wave);
2709+
final HashSet<Object> set = get(wave);
27122710

27132711
// If wave list was found, free all objects in it
2714-
if (list != null) {
2715-
for (final Object objForCleanup : list) {
2712+
if (set != null) {
2713+
for (final Object objForCleanup : set) {
27162714
Class<?> clazz = null;
27172715

27182716
try {
@@ -2739,7 +2737,7 @@ public synchronized void cleanupMarkedObjects(final long wave) {
27392737
}
27402738
}
27412739

2742-
list.clear();
2740+
set.clear();
27432741
remove(wave);
27442742
}
27452743
}

org.rdkit.knime.nodes/src/org/rdkit/knime/nodes/diversitypicker/RDKitDiversityPickerNodeModel.java

Lines changed: 151 additions & 83 deletions
Original file line numberDiff line numberDiff line change
@@ -51,6 +51,9 @@
5151
import java.util.ArrayList;
5252
import java.util.List;
5353
import java.util.Map;
54+
import java.util.concurrent.CancellationException;
55+
import java.util.concurrent.ExecutionException;
56+
import java.util.concurrent.atomic.AtomicLong;
5457

5558
import org.RDKit.EBV_Vect;
5659
import org.RDKit.ExplicitBitVect;
@@ -65,13 +68,15 @@
6568
import org.knime.core.data.RowIterator;
6669
import org.knime.core.data.vector.bitvector.BitVectorValue;
6770
import org.knime.core.node.BufferedDataTable;
71+
import org.knime.core.node.CanceledExecutionException;
6872
import org.knime.core.node.ExecutionContext;
6973
import org.knime.core.node.InvalidSettingsException;
7074
import org.knime.core.node.defaultnodesettings.SettingsModelInteger;
7175
import org.knime.core.node.defaultnodesettings.SettingsModelString;
7276
import org.knime.core.node.port.PortObjectSpec;
7377
import org.knime.core.node.port.PortType;
7478
import org.knime.core.node.port.PortTypeRegistry;
79+
import org.knime.core.util.MultiThreadWorker;
7580
import org.rdkit.knime.nodes.AbstractRDKitNodeModel;
7681
import org.rdkit.knime.nodes.AbstractRDKitSplitterNodeModel;
7782
import org.rdkit.knime.nodes.rdkfingerprint.DefaultFingerprintSettings;
@@ -363,84 +368,29 @@ protected void preProcessing(final BufferedDataTable[] inData,
363368
m_ebvRowsToKeep = null;
364369

365370
// Create sub execution contexts for pre-processing steps
366-
final ExecutionContext subExecReadingFingerprints = exec.createSubExecutionContext(0.25d);
367-
final ExecutionContext subExecReadingAdditionalFingerprints = exec.createSubExecutionContext(0.25d);
368-
final ExecutionContext subExecCheckDiversity = exec.createSubExecutionContext(0.25d);
369-
final ExecutionContext subExecProcessingDiversity = exec.createSubExecutionContext(0.25d);
371+
final ExecutionContext subExecReadingFingerprints = exec.createSubExecutionContext(0.75d);
372+
final ExecutionContext subExecReadingAdditionalFingerprints = exec.createSubExecutionContext(0.05d);
373+
final ExecutionContext subExecCheckDiversity = exec.createSubExecutionContext(0.10d);
374+
final ExecutionContext subExecProcessingDiversity = exec.createSubExecutionContext(0.10d);
370375

371376
final long lInputRowCount = inData[0].size();
372-
final List<Integer> listIndicesUsed = new ArrayList<Integer>();
377+
final List<Long> listIndicesUsed = new ArrayList<Long>();
373378
final EBV_Vect vFingerprints = markForCleanup(new EBV_Vect());
374379
Int_Vect firstPicks = null;
375-
long lFpLength = -1;
380+
AtomicLong alFpLength = new AtomicLong(-1);
376381

377382
// 1. Get all fingerprints in a form that we can process further
378-
final InputDataInfo inputDataInfo1 = arrInputDataInfo[0][INPUT_COLUMN_MAIN];
379-
final boolean bNeedsCalculation1 = inputDataInfo1.isCompatibleOrAdaptable(RDKitMolValue.class);
380-
String strInfoForProgress = (bNeedsCalculation1 ? " - Calculating fingerprints" : " - Reading fingerprints");
383+
final boolean bNeedsCalculation1 = arrInputDataInfo[0][INPUT_COLUMN_MAIN].isCompatibleOrAdaptable(RDKitMolValue.class);
381384
final FingerprintSettingsHeaderProperty fpSpec1 = (bNeedsCalculation1 ?
382385
new FingerprintSettingsHeaderProperty(DEFAULT_FINGERPRINT_SETTINGS) :
383386
new FingerprintSettingsHeaderProperty(arrInputDataInfo[0][INPUT_COLUMN_MAIN].getColumnSpec()));
384387
FingerprintSettingsHeaderProperty fpSpec2 = null;
385-
final FingerprintType fpTypeDefault = DEFAULT_FINGERPRINT_SETTINGS.getRdkitFingerprintType();
386-
387-
int iRowIndex = 0;
388-
final RowIterator it1 = inData[0].iterator();
389-
while (it1.hasNext()) {
390-
final DataRow row = it1.next();
391-
ExplicitBitVect expBitVector = null;
392-
393-
if (bNeedsCalculation1) {
394-
// Calculate the fingerprint for the molecule on the fly
395-
ROMol mol = null;
396-
397-
try {
398-
mol = arrInputDataInfo[0][INPUT_COLUMN_MAIN].getROMol(row);
399-
if (mol != null) {
400-
expBitVector = markForCleanup(fpTypeDefault.calculate(mol, DEFAULT_FINGERPRINT_SETTINGS));
401-
}
402-
else {
403-
warnings.saveWarning(WarningConsolidator.ROW_CONTEXT.getId(), "Encountered empty molecule cell in table 1 - ignored it.");
404-
}
405-
}
406-
finally {
407-
// Delete the molecule manually to free memory quickly
408-
if (mol != null) {
409-
mol.delete();
410-
}
411-
}
412-
}
413-
else {
414-
expBitVector = markForCleanup(arrInputDataInfo[0][INPUT_COLUMN_MAIN].getExplicitBitVector(row));
415-
if (expBitVector == null) {
416-
warnings.saveWarning(WarningConsolidator.ROW_CONTEXT.getId(), "Encountered empty fingerprint cell in table 1 - ignored it.");
417-
}
418-
}
419388

420-
if (expBitVector != null) {
421-
final long lNumBits = expBitVector.getNumBits();
422-
if (lFpLength == -1) {
423-
lFpLength = lNumBits;
424-
}
425-
426-
if (lFpLength == lNumBits){
427-
listIndicesUsed.add(iRowIndex);
428-
vFingerprints.add(expBitVector);
429-
}
430-
else {
431-
warnings.saveWarning(WarningConsolidator.ROW_CONTEXT.getId(),
432-
"Encountered fingerprint with invalid length (" +
433-
lNumBits + " instead of " + lFpLength + " bits) in table 1 - ignoring it.");
434-
}
435-
}
436-
437-
// Every 20 iterations report progress and check for cancel
438-
if (iRowIndex % 20 == 0) {
439-
AbstractRDKitNodeModel.reportProgress(subExecReadingFingerprints, iRowIndex, lInputRowCount, row, " - Reading fingerprints");
440-
}
441-
442-
iRowIndex++;
443-
}
389+
// Parallel processing to prepare fingerprints from main table (first table)
390+
prepareFingerprints(1, inData[0], arrInputDataInfo[0][INPUT_COLUMN_MAIN],
391+
bNeedsCalculation1, DEFAULT_FINGERPRINT_SETTINGS.getRdkitFingerprintType(),
392+
listIndicesUsed, vFingerprints, alFpLength,
393+
warnings, subExecReadingFingerprints);
444394

445395
// Check, if parameters of user make sense based on the found fingerprints in table 1 (NOT combined yet with table 2)
446396
final int iNumberToPick = m_modelNumberToPick.getIntValue();
@@ -464,7 +414,7 @@ protected void preProcessing(final BufferedDataTable[] inData,
464414
int iAdditionalRowIndex = 0;
465415
int iBiasAwayIndex = (int)vFingerprints.size();
466416
final long lAdditionalRowCount = inData[1].size();
467-
strInfoForProgress = (bNeedsCalculation2 ? " - Calculating additional fingerprints" : " - Reading additional fingerprints");
417+
final String strInfoForProgress = (bNeedsCalculation2 ? " - Calculating additional fingerprints" : " - Reading additional fingerprints");
468418

469419
if (!bNeedsCalculation2) {
470420
fpSpec2 = new FingerprintSettingsHeaderProperty(
@@ -513,24 +463,24 @@ else if (!FingerprintType.isCompatible(fpSpec1, fpSpec2)) {
513463
// Just add the fingerprint to bias away from and mark it as part of first picks
514464
if (expBitVector != null) {
515465
final long lNumBits = expBitVector.getNumBits();
516-
if (lFpLength == -1) {
517-
lFpLength = lNumBits;
466+
if (alFpLength.get() == -1) {
467+
alFpLength.set(lNumBits);
518468
}
519469

520-
if (lFpLength == lNumBits){
470+
if (alFpLength.get() == lNumBits){
521471
vFingerprints.add(expBitVector);
522472
firstPicks.add(iBiasAwayIndex);
523473
iBiasAwayIndex++;
524474
}
525475
else {
526476
warnings.saveWarning(ROW_CONTEXT_TABLE_2.getId(),
527477
"Encountered fingerprint with invalid length (" +
528-
lNumBits + " instead of " + lFpLength + " bits) in table 2 - ignoring it.");
478+
lNumBits + " instead of " + alFpLength.get() + " bits) in table 2 - ignoring it.");
529479
}
530480
}
531481

532-
// Every 20 iterations report progress and check for cancel
533-
if (iAdditionalRowIndex % 20 == 0) {
482+
// Every 1000 iterations report progress and check for cancel
483+
if (iAdditionalRowIndex % 1000 == 0) {
534484
AbstractRDKitNodeModel.reportProgress(subExecReadingAdditionalFingerprints, iAdditionalRowIndex, lAdditionalRowCount, row, strInfoForProgress);
535485
}
536486

@@ -555,13 +505,12 @@ else if (!FingerprintType.isCompatible(fpSpec1, fpSpec2)) {
555505
}
556506
}
557507
else if (firstPicks == null || firstPicks.isEmpty()) {
558-
intVector = markForCleanup(RDKFuncs.pickUsingFingerprints(vFingerprints,
559-
iNumberToPick, m_randomSeed.getIntValue()));
560-
}
561-
else {
562-
intVector = markForCleanup(RDKFuncs.pickUsingFingerprints(vFingerprints,
563-
iNumberToPick + firstPicks.size(), m_randomSeed.getIntValue(), firstPicks));
508+
firstPicks = new Int_Vect();
564509
}
510+
// the distance cache just slows things down with the new diversity picker implementation
511+
Boolean useDistanceCache = false;
512+
intVector = markForCleanup(RDKFuncs.pickUsingFingerprints(vFingerprints,
513+
iNumberToPick + firstPicks.size(), m_randomSeed.getIntValue(), firstPicks, useDistanceCache));
565514

566515
subExecCheckDiversity.setProgress(1.0d);
567516
subExecCheckDiversity.checkCanceled();
@@ -572,14 +521,14 @@ else if (firstPicks == null || firstPicks.isEmpty()) {
572521
for(int i = 0; i < iDiversityCount; i++) {
573522
final int pickedFingerprintIndex = intVector.get(i);
574523
if (pickedFingerprintIndex < listIndicesUsed.size()) {
575-
final int pickedRowIndex = listIndicesUsed.get(pickedFingerprintIndex);
524+
final long pickedRowIndex = listIndicesUsed.get(pickedFingerprintIndex);
576525
if (pickedRowIndex < lInputRowCount) {
577526
m_ebvRowsToKeep.setBit(pickedRowIndex);
578527
}
579528
}
580529

581-
// Every 20 iterations report progress and check for cancel
582-
if (i % 20 == 0) {
530+
// Every 1000 iterations report progress and check for cancel
531+
if (i % 1000 == 0) {
583532
AbstractRDKitNodeModel.reportProgress(subExecProcessingDiversity, i, iDiversityCount,
584533
null, " - Processing diversity results");
585534
}
@@ -588,6 +537,125 @@ else if (firstPicks == null || firstPicks.isEmpty()) {
588537
subExecProcessingDiversity.setProgress(1.0d);
589538
}
590539

540+
/**
541+
* Prepares fingerprints for diversity picking from an input table, either with a molecule column
542+
* or with a fingerprint column.
543+
*
544+
* @param iTableNumber Table index. Only used for warning generations.
545+
* @param inData Table data. Must not be null.
546+
* @param inputDataInfo Input data definition for the column to process. Must not be null.
547+
* @param bNeedsCalculation True to calculate fingerprints from molecules. False otherwise.
548+
* @param fpTypeDefault Fingerprint type used when we need to calculate fingerprints from molecules.
549+
* Must not be null.
550+
* @param listIndicesUsed IN/OUT: List of indices that will be filled with row indexes.
551+
* @param vFingerprints IN/OUT: List of fingerprints that will be filled with fingerprints.
552+
* @param alFpLength IN/OUT: Length of processed fingerprints. Must not be null.
553+
* @param warnings Warning consolidator. Must not be null.
554+
* @param subExecReadingFingerprints Execution context. Must not be null.
555+
*
556+
* @throws Exception Thrown, if something goes wrong.
557+
*/
558+
protected void prepareFingerprints(final int iTableNumber, final BufferedDataTable inData, final InputDataInfo inputDataInfo,
559+
final boolean bNeedsCalculation, final FingerprintType fpTypeDefault,
560+
final List<Long> listIndicesUsed, final EBV_Vect vFingerprints, final AtomicLong alFpLength,
561+
final WarningConsolidator warnings, final ExecutionContext subExecReadingFingerprints) throws Exception {
562+
563+
// Get settings and define data specific behavior
564+
final int iMaxParallelWorkers = (int)Math.ceil(1.5 * Runtime.getRuntime().availableProcessors());
565+
final int iQueueSize = 1000 * iMaxParallelWorkers;
566+
final long lTotalRowCount = inData.size();
567+
568+
// Calculate RDKit Fingerprints from molecule, or convert them from KNIME Fingerprints
569+
new MultiThreadWorker<DataRow, ExplicitBitVect>(iQueueSize, iMaxParallelWorkers) {
570+
571+
/**
572+
* Prepares a fingerprint from first table.
573+
*
574+
* @param row Input row.
575+
* @param index Index of row.
576+
*
577+
* @return Null, if fingerprint could not be determined.
578+
* Result fingerprint, if we have a valid fingerprint
579+
* to be used for diversity picking.
580+
*/
581+
@Override
582+
protected ExplicitBitVect compute(final DataRow row, final long index) throws Exception {
583+
ExplicitBitVect expBitVector = null;
584+
585+
if (bNeedsCalculation) {
586+
// Calculate the fingerprint for the molecule on the fly
587+
ROMol mol = null;
588+
589+
try {
590+
mol = inputDataInfo.getROMol(row);
591+
if (mol != null) {
592+
expBitVector = markForCleanup(fpTypeDefault.calculate(mol, DEFAULT_FINGERPRINT_SETTINGS));
593+
}
594+
else {
595+
warnings.saveWarning(WarningConsolidator.ROW_CONTEXT.getId(),
596+
"Encountered empty molecule cell in table " + iTableNumber + " - ignored it.");
597+
}
598+
}
599+
finally {
600+
// Delete the molecule manually to free memory quickly
601+
if (mol != null) {
602+
mol.delete();
603+
}
604+
}
605+
}
606+
else {
607+
expBitVector = markForCleanup(inputDataInfo.getExplicitBitVector(row));
608+
if (expBitVector == null) {
609+
warnings.saveWarning(WarningConsolidator.ROW_CONTEXT.getId(),
610+
"Encountered empty fingerprint cell in table " + iTableNumber + " - ignored it.");
611+
}
612+
}
613+
614+
return expBitVector;
615+
}
616+
617+
/**
618+
* Adds the fingerprint results to the fingerprint list for further processing.
619+
*
620+
* @param task Processing result for a row.
621+
*/
622+
@Override
623+
protected void processFinished(final ComputationTask task)
624+
throws ExecutionException, CancellationException, InterruptedException {
625+
final ExplicitBitVect expBitVector = task.get();
626+
final long lRowIndex = task.getIndex();
627+
628+
if (expBitVector != null) {
629+
final long lNumBits = expBitVector.getNumBits();
630+
if (alFpLength.get() == -1) {
631+
alFpLength.set(lNumBits);
632+
}
633+
634+
if (alFpLength.get() == lNumBits) {
635+
listIndicesUsed.add(lRowIndex);
636+
vFingerprints.add(expBitVector);
637+
} else {
638+
warnings.saveWarning(WarningConsolidator.ROW_CONTEXT.getId(),
639+
"Encountered fingerprint with invalid length (" + lNumBits + " instead of " + alFpLength.get()
640+
+ " bits) in table " + iTableNumber + " - ignoring it.");
641+
}
642+
}
643+
644+
// Check, if user pressed cancel (however, we will finish the method
645+
// nevertheless)
646+
// Update the progress only every 1000 rows
647+
if (task.getIndex() % 1000 == 0) {
648+
try {
649+
AbstractRDKitNodeModel.reportProgress(subExecReadingFingerprints, lRowIndex, lTotalRowCount, null,
650+
" - " + (bNeedsCalculation ? "Calculating" : "Reading") + " fingerprints");
651+
} catch (final CanceledExecutionException e) {
652+
cancel(true);
653+
}
654+
}
655+
};
656+
}.run(inData);
657+
}
658+
591659
/**
592660
* {@inheritDoc}
593661
*/

0 commit comments

Comments
 (0)