"""
tests
Created by: Martin Sicho
On: 31.05.22, 10:20
"""
import json
import logging
import os.path
import tempfile
from collections import OrderedDict
from unittest import TestCase
import numpy as np
import pandas as pd
from drugex.data.corpus.corpus import SequenceCorpus
from drugex.data.corpus.vocabulary import VocGraph, VocSmiles
from drugex.data.datasets import GraphFragDataSet, SmilesDataSet, SmilesFragDataSet
from drugex.data.fragments import (
FragmentCorpusEncoder,
FragmentPairsSplitter,
GraphFragmentEncoder,
SequenceFragmentEncoder,
)
from drugex.data.processing import (
CorpusEncoder,
RandomTrainTestSplitter,
Standardization,
)
from drugex.molecules.converters.dummy_molecules import dummyMolsFromFragments
from drugex.molecules.converters.fragmenters import Fragmenter
from drugex.training.environment import DrugExEnvironment
from drugex.training.explorers import (
FragGraphExplorer,
FragSequenceExplorer,
SequenceExplorer,
)
from drugex.training.generators import (
GraphTransformer,
SequenceRNN,
SequenceTransformer,
)
from drugex.training.interfaces import TrainingMonitor
from drugex.training.monitors import FileMonitor
from drugex.training.rewards import ParetoCrowdingDistance
from drugex.training.scorers.interfaces import Scorer
from drugex.training.scorers.modifiers import ClippedScore
from drugex.training.scorers.properties import Property
from rdkit import Chem
class TestModelMonitor(TrainingMonitor):
    """Training monitor used by the tests.

    Records which monitoring callbacks were executed (so tests can assert the
    full monitoring life cycle ran) and forwards every call to a list of
    submonitors — by default a single `FileMonitor` writing to a temporary
    file.
    """

    def __init__(self, submonitors=None):
        self.model = None
        # tracks which callbacks have been invoked at least once
        # NOTE(review): nothing in this class ever sets 'performance' to True,
        # so allMethodsExecuted() cannot succeed from the code visible here --
        # confirm whether a performance callback is missing from this class
        self.execution = {
            'model' : False,
            'progress' : False,
            'performance' : False,
            'end' : False,
            'close' : False,
        }
        self.submonitors = submonitors if submonitors else [
            FileMonitor(tempfile.NamedTemporaryFile().name, save_smiles=True)
        ]

    def passToSubmonitors(self, method, *args, **kwargs):
        """Invoke the callback named `method` on every submonitor.

        Parameters
        ----------
        method : str
            Name of the monitor method to call.
        """
        for monitor in self.submonitors:
            # BUGFIX: the original rebound `method` to the call's return value
            # (`method = getattr(...)(...)`), which broke the attribute lookup
            # for every submonitor after the first
            getattr(monitor, method)(*args, **kwargs)

    def saveModel(self, model):
        # keep the raw state so tests can assert on it via getModel()
        self.model = model.getModel()
        self.execution['model'] = True
        self.passToSubmonitors('saveModel', model)

    def saveProgress(self, current_step=None, current_epoch=None, total_steps=None, total_epochs=None, *args, **kwargs):
        print("Test Progress Monitor:")
        print(json.dumps({
            'current_step' : current_step,
            'current_epoch' : current_epoch,
            'total_steps' : total_steps,
            'total_epochs' : total_epochs,
        }, indent=4))
        if args:
            print("Args:", args)
        if kwargs:
            print("Kwargs:", json.dumps(kwargs, indent=4))
        self.execution['progress'] = True
        self.passToSubmonitors('saveProgress', current_step, current_epoch, total_steps, total_epochs, *args, **kwargs)

    def endStep(self, step, epoch):
        print(f"Finished step {step} of epoch {epoch}.")
        self.execution['end'] = True
        self.passToSubmonitors('endStep', step, epoch)

    def close(self):
        print("Training done.")
        self.execution['close'] = True
        self.passToSubmonitors('close')

    def getModel(self):
        """Return the last model state recorded by saveModel()."""
        return self.model

    def allMethodsExecuted(self):
        """Return True only if every tracked callback has been invoked."""
        return all(self.execution[key] for key in self.execution)
class MockScorer(Scorer):
    """Scorer stub that assigns a uniform random score to each molecule."""

    def getScores(self, mols, frags=None):
        # one random value per input molecule; fragments are ignored
        scores = np.random.random(len(mols))
        return list(scores)

    def getKey(self):
        return "MockScorer"
def getPredictor():
    """Return a QSAR-model-backed scorer, or a `MockScorer` fallback.

    Uses the bundled A2AR random forest model via qsprpred when that package
    is installed; otherwise falls back to random scoring.
    """
    try:
        from drugex.training.scorers.qsprpred import QSPRPredScorer
        from qsprpred.models.models import QSPRModel

        meta_path = os.path.join(os.path.dirname(__file__),
            "test_data/A2AR_RandomForestClassifier/A2AR_RandomForestClassifier_meta.json")
        return QSPRPredScorer(QSPRModel.fromFile(meta_path))
    except ImportError:
        # qsprpred is an optional dependency
        return MockScorer()
class TestScorer(TestCase):
    """Checks the scorer returned by `getPredictor` on a few molecule inputs."""

    def test_getScores(self):
        scorer = getPredictor()

        def scored(mols):
            # every input must produce exactly one score
            result = scorer.getScores(mols)
            self.assertEqual(len(result), len(mols))
            return result

        for mols in (
            ["CCO", "CC"],
            [Chem.MolFromSmiles("CCO"), Chem.MolFromSmiles("CC")],
        ):
            self.assertTrue(all(isinstance(s, float) and s > 0 for s in scored(mols)))
        # invalid SMILES must still yield a score entry (value unchecked)
        scored(["CCO", "XXXX"])  # test with invalid
class TrainingTestCase(TestCase):
    """Shared setup and helpers for the generator-training tests."""

    # input file information
    test_data_dir = os.path.join(os.path.dirname(__file__), 'test_data')
    pretraining_file = os.path.join(test_data_dir, 'ZINC_raw_small.txt')
    finetuning_file = os.path.join(test_data_dir, 'A2AR_raw_small.txt')

    # global calculation settings
    N_PROC = 2  # workers for parallel data processing
    N_EPOCHS = 2  # kept short so tests run quickly
    SEED = 42  # fixed seed for reproducible sampling
    MAX_SMILES = 16  # molecules sampled from each input file
    BATCH_SIZE = 8

    # environment objectives (TODO: we should test more options and combinations here)
    scorers = [
        Property(
            "MW",
            modifier=ClippedScore(lower_x=1000, upper_x=500)
        ),
        getPredictor()
    ]
    thresholds = [0.5, 0.99]
    def setUp(self):
        # fresh monitor for every test
        self.monitor = TestModelMonitor()
[docs] def getTestEnvironment(self, scheme=None):
"""
Get the testing environment
Parameters
----------
scheme: RewardScheme
The reward scheme to use. If None, the default ParetoTanimotoDistance is used.
Returns
-------
DrugExEnvironment
"""
scheme = ParetoCrowdingDistance() if not scheme else scheme
return DrugExEnvironment(self.scorers, thresholds=self.thresholds, reward_scheme=scheme)
[docs] def getSmiles(self, _file):
"""
Read and standardize SMILES from a file.
Parameters
----------
_file: str
The file to read from (must be a .tsv file with a column named "CANONICAL_SMILES")
Returns
-------
list
The list of SMILES
"""
return self.standardize(pd.read_csv(_file, header=0, sep='\t')['CANONICAL_SMILES'].sample(self.MAX_SMILES, random_state=self.SEED).tolist())
[docs] def standardize(self, smiles):
"""
Standardize the input SMILES
Parameters
----------
smiles: list
The list of SMILES to standardize
Returns
-------
list
The list of standardized SMILES
"""
return Standardization(n_proc=self.N_PROC).apply(smiles)
[docs] def setUpSmilesFragData(self):
"""
Create inputs for the fragment-based SMILES models.
Returns
-------
tuple
The tuple of (pretraining training dataloader, pretraining test dataloader, finetuning training dataloader, finetuning test dataloader, vocabulary)
"""
pre_smiles = self.getSmiles(self.pretraining_file)
ft_smiles = self.getSmiles(self.finetuning_file)
# create and encode fragments
splitter = FragmentPairsSplitter(0.1, 1e4, seed=self.SEED)
encoder = FragmentCorpusEncoder(
fragmenter=Fragmenter(4, 4, 'brics'),
encoder=SequenceFragmentEncoder(
VocSmiles(True)
),
pairs_splitter=splitter,
n_proc=self.N_PROC
)
# get training data
pr_data_set_test = SmilesFragDataSet(self.getRandomFile())
pr_data_set_train = SmilesFragDataSet(self.getRandomFile())
encoder.apply(pre_smiles, encodingCollectors=[pr_data_set_test, pr_data_set_train])
ft_data_set_test = SmilesFragDataSet(self.getRandomFile())
ft_data_set_train = SmilesFragDataSet(self.getRandomFile())
encoder.apply(ft_smiles, encodingCollectors=[ft_data_set_test, ft_data_set_train])
# get vocabulary (we will join all generated vocabularies to make sure the one used to create data loaders contains all tokens)
vocabulary = pr_data_set_test.getVoc() + pr_data_set_train.getVoc() + ft_data_set_train.getVoc() + ft_data_set_test.getVoc()
pr_data_set_test.setVoc(vocabulary)
pr_data_set_train.setVoc(vocabulary)
ft_data_set_train.setVoc(vocabulary)
ft_data_set_test.setVoc(vocabulary)
pr_loader_train = pr_data_set_train.asDataLoader(self.BATCH_SIZE)
pr_loader_test = pr_data_set_test.asDataLoader(self.BATCH_SIZE)
# pr_loader_test = pr_data_set_test.asDataLoader(32, split_converter=SmilesFragDataSet.TargetCreator())
self.assertTrue(pr_loader_train)
self.assertTrue(pr_loader_test)
ft_loader_train = pr_data_set_train.asDataLoader(self.BATCH_SIZE)
ft_loader_test = pr_data_set_test.asDataLoader(self.BATCH_SIZE)
# ft_loader_test = pr_data_set_test.asDataLoader(self.BATCH_SIZE, split_converter=SmilesFragDataSet.TargetCreator())
self.assertTrue(ft_loader_train)
self.assertTrue(ft_loader_test)
return pr_loader_train, pr_loader_test, ft_loader_train, ft_loader_test, vocabulary
[docs] @staticmethod
def getRandomFile():
"""
Generate a random temporary file and return its path.
Returns
-------
str
The path to the temporary file
"""
return tempfile.NamedTemporaryFile().name
[docs] def fitTestModel(self, model, train_loader, test_loader):
"""
Fit a model and return the best model.
Parameters
----------
model: Model
The model to fit
train_loader: DataLoader
The training data loader
test_loader: DataLoader
The test data loader
Returns
-------
tuple
The tuple of (fitted model, monitor)
"""
monitor = TestModelMonitor()
model.fit(train_loader, test_loader, epochs=self.N_EPOCHS, monitor=monitor)
pr_model = monitor.getModel()
self.assertTrue(type(pr_model) == OrderedDict)
self.assertTrue(monitor.allMethodsExecuted())
model.loadStates(pr_model) # initialize from the best state
return model, monitor
    def test_sequence_rnn(self):
        """
        Test sequence RNN.

        Covers the full workflow: corpus encoding, pretraining, fine-tuning,
        reinforcement learning with an explorer, and sampling.
        """
        pre_smiles = self.getSmiles(self.pretraining_file)
        ft_smiles = self.getSmiles(self.finetuning_file)
        # get training data
        encoder = CorpusEncoder(
            SequenceCorpus,
            {
                'vocabulary': VocSmiles(False)
            },
            n_proc=self.N_PROC
        )
        pre_data_set = SmilesDataSet(self.getRandomFile())
        encoder.apply(pre_smiles, pre_data_set)
        ft_data_set = SmilesDataSet(self.getRandomFile())
        encoder.apply(ft_smiles, ft_data_set)
        # get common vocabulary (join so both loaders share all tokens)
        vocabulary = pre_data_set.getVoc() + ft_data_set.getVoc()
        # pretraining
        splitter = RandomTrainTestSplitter(0.1)
        pr_loader_train, pr_loader_test = pre_data_set.asDataLoader(self.BATCH_SIZE, splitter=splitter)
        self.assertTrue(pr_loader_train)
        self.assertTrue(pr_loader_test)
        pretrained = SequenceRNN(vocabulary, is_lstm=True)
        pretrained, monitor = self.fitTestModel(pretrained, pr_loader_train, pr_loader_test)
        # fine-tuning: start from the pretrained weights
        splitter = RandomTrainTestSplitter(0.1)
        ft_loader_train, ft_loader_test = ft_data_set.asDataLoader(self.BATCH_SIZE, splitter=splitter)
        self.assertTrue(ft_loader_train)
        self.assertTrue(ft_loader_test)
        finetuned = SequenceRNN(vocabulary, is_lstm=True)
        finetuned.loadStates(pretrained.getModel())
        finetuned, monitor = self.fitTestModel(finetuned, ft_loader_train, ft_loader_test)
        # RL: the finetuned net mutates, the pretrained net does crossover
        environment = self.getTestEnvironment()
        explorer = SequenceExplorer(pretrained, env=environment, mutate=finetuned, crover=pretrained, n_samples=10)
        monitor = TestModelMonitor()
        explorer.fit(ft_loader_train, ft_loader_test, monitor=monitor, epochs=self.N_EPOCHS)
        self.assertTrue(type(monitor.getModel()) == OrderedDict)
        self.assertTrue(monitor.allMethodsExecuted())
        # finally, sampling must run (invalid molecules kept on purpose)
        pretrained.generate(num_samples=10, evaluator=environment, drop_invalid=False)