drugex.training package

Subpackages

Submodules

drugex.training.environment module

environment

Created by: Martin Sicho On: 06.06.22, 16:51

class drugex.training.environment.DrugExEnvironment(scorers, thresholds=None, reward_scheme=None)

Bases: Environment

Original implementation of the environment scoring strategy for DrugEx v3.

getScores(smiles, frags=None, no_multifrag_smiles=True)

Get the scores from the scorers and check the validity and desirability of each molecule.

Parameters:
  • smiles (list of str) – List of SMILES strings to score.

  • frags (list of str, optional) – List of SMILES strings of fragments to check for.

  • no_multifrag_smiles (bool, optional) – Whether to check for SMILES strings that contain more than one fragment.

Returns:

preds – DataFrame with the scores from the scorers and the validity and desirability of the molecules.

Return type:

pd.DataFrame

getUnmodifiedScores(smiles)

Get the scores from the scorers without applying any modifiers.

Parameters:

smiles (list of str) – List of SMILES strings to score.

Returns:

preds – DataFrame with the scores from the scorers.

Return type:

pd.DataFrame
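The getScores documentation above does not state how desirability is decided. The following sketch assumes, without confirmation from the DrugEx source, that a molecule is desirable when it is valid and every objective score reaches the matching entry in thresholds. qualify is a hypothetical helper, and it returns a NumPy array rather than the pd.DataFrame that getScores produces. The example thresholds are the ones listed for TrainingTestCase further down.

```python
import numpy as np


def qualify(scores: np.ndarray, valid: np.ndarray, thresholds) -> np.ndarray:
    """Hypothetical helper: flag valid molecules whose scores all meet their thresholds.

    scores     -- (n_molecules, n_objectives) matrix of modified scores
    valid      -- boolean vector, True where the SMILES could be parsed
    thresholds -- one minimum score per objective (assumed semantics)
    """
    meets_all = (scores >= np.asarray(thresholds)).all(axis=1)
    # In this sketch an invalid molecule can never be desirable.
    return meets_all & valid


scores = np.array([[0.6, 0.995], [0.4, 0.999], [0.7, 0.98]])
valid = np.array([True, True, False])
print(qualify(scores, valid, [0.5, 0.99]).tolist())  # [True, False, False]
```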

drugex.training.interfaces module

interfaces

Created by: Martin Sicho On: 01.06.22, 11:29

class drugex.training.interfaces.Environment(scorers, thresholds=None, reward_scheme=None)

Bases: ModelEvaluator

Definition of the generic environment class for DrugEx. Reference implementation is DrugExEnvironment.

getRewards(smiles, frags=None)

Calculate a single reward value for each molecule, to be used in reinforcement learning.

Parameters:
  • smiles (list) – List of SMILES strings of the molecules to score.

  • frags (list, optional) – List of SMILES strings of the fragments used to generate the molecules.

Returns:

Array of rewards for the molecules.

Return type:

np.ndarray
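getRewards has to turn the per-objective scores into one number per molecule. For the ranking-based schemes documented under drugex.training.rewards, one simple option is to spread rewards evenly along the worst-to-best ranking. The sketch below illustrates only that idea; rank_to_rewards is a hypothetical helper, and the linear mapping is an assumption, not the formula DrugEx uses.

```python
import numpy as np


def rank_to_rewards(rank) -> np.ndarray:
    """Hypothetical mapping from a worst-to-best ranking to rewards in [0, 1].

    rank[k] is the index of the molecule at position k, worst first, so the
    worst molecule gets 0.0, the best gets 1.0 and the rest are evenly spaced.
    """
    rank = np.asarray(rank, dtype=int)
    n = len(rank)
    rewards = np.zeros(n)
    if n == 1:
        rewards[rank] = 1.0
    elif n > 1:
        rewards[rank] = np.arange(n) / (n - 1)
    return rewards


print(rank_to_rewards([2, 0, 1]).tolist())  # [0.5, 1.0, 0.0]
```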

getScorerKeys()

Get the keys of the scorers.

Returns:

List of keys of the scorers.

Return type:

list

abstract getScores(smiles, frags=None)

Calculate the scores of all objectives per molecule and qualify generated molecules (valid, accurate, desired).

Parameters:
  • smiles (list) – List of SMILES strings of the molecules to score.

  • frags (list, optional) – List of SMILES strings of the fragments used to generate the molecules.

Returns:

A DataFrame with the scores and qualifications for the molecules.

Return type:

pd.DataFrame

abstract getUnmodifiedScores(smiles)

Calculate the scores of all objectives per molecule without applying modifiers.

Parameters:

smiles (list) – List of SMILES strings of the molecules to score.

Returns:

A DataFrame with the scores for the molecules.

Return type:

pd.DataFrame

class drugex.training.interfaces.Model(device=device(type='cuda'), use_gpus=(0,))

Bases: Module, ModelProvider, ABC

Generic base class for all PyTorch models in DrugEx. Manages the GPU or CPU devices available to the model.

abstract attachToGPUs(gpus)

Use this method to handle a request to change the used GPUs. This method is automatically called when the class is instantiated, but may need to be called again in subclasses to move all data to the required devices.

Subclasses should also make sure to set “self.device” to the currently used device and “self.gpus” to the GPU ids of the currently used GPUs.

Parameters:

gpus (tuple) – Tuple of GPU ids to use.

abstract fit(train_loader, valid_loader, epochs=1000, monitor=None, **kwargs)

Train and validate the model with the given training and validation loaders (see the documentation of DataSet and its implementations to learn how to generate them).

Parameters:
  • train_loader (torch.utils.data.DataLoader) – The training data loader.

  • valid_loader (torch.utils.data.DataLoader) – The validation data loader.

  • epochs (int, optional) – The number of epochs to train the model for.

  • monitor (TrainingMonitor, optional) – A TrainingMonitor instance to monitor the training process.

  • **kwargs – Additional keyword arguments to pass to the training loop.

loadStates(state_dict, strict=True)

Load the model states from a dictionary.

Parameters:
  • state_dict (dict) – The dictionary containing the model states.

  • strict (bool, optional) – Whether to raise an error if the dictionary contains keys that do not match the model.

loadStatesFromFile(path)

Load the model states from a file.

Parameters:

path (str) – The path to the file containing the model states.

updateDevices(device, gpus)

Update the device and GPUs used by the model.

Parameters:
  • device (torch.device) – The device to use for the model.

  • gpus (list) – List of GPUs to use for the model.

class drugex.training.interfaces.ModelEvaluator

Bases: ABC

Interface for scoring a model based on the generated molecules and, if applicable, the input fragments.

class drugex.training.interfaces.ModelProvider

Bases: ABC

Any instance that contains a DrugEx Model or its serialized form (i.e. a state dictionary).

abstract getModel()

Return the current model as a Model instance or in serialized form.

Returns:

The current model or its serialized form.

Return type:

Model or dict

class drugex.training.interfaces.RewardScheme

Bases: ABC

Reward scheme that enables ranking of molecules based on the calculated objectives and other criteria.

exception RewardException

Bases: Exception

Exception to catch errors in the calculation of rewards.

class drugex.training.interfaces.TrainingMonitor

Bases: ModelProvider, ABC

Interface used to monitor model training.

abstract close()

Close this monitor. Training has finished.

abstract endStep(step, epoch)

Notify the monitor that a step of the training has finished.

Parameters:
  • step (int) – The current training step (i.e. batch).

  • epoch (int) – The current epoch.

abstract saveModel(model)

Save the state dictionary of the Model instance currently being trained, or serialize the model in any other way.

Parameters:

model (Model) – The model to save.

abstract savePerformanceInfo(performance_dict, df_smiles=None)

Save the performance data for the current epoch.

Parameters:
  • performance_dict (dict) – A dictionary with the performance data.

  • df_smiles (pd.DataFrame) – A DataFrame with the SMILES of the molecules generated in the current epoch.

abstract saveProgress(current_step=None, current_epoch=None, total_steps=None, total_epochs=None, *args, **kwargs)

Notifies the monitor of the current progress of the training.

Parameters:
  • current_step (int, optional) – The current training step (i.e. batch).

  • current_epoch (int, optional) – The current epoch.

  • total_steps (int, optional) – The total number of training steps.

  • total_epochs (int, optional) – The total number of epochs.

  • *args – Additional arguments depending on the model type.

  • **kwargs – Additional keyword arguments depending on the model type.
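To make the call sequence concrete, here is a self-contained sketch of a monitor that only records in memory what it is told. It mirrors the method names of TrainingMonitor but deliberately does not subclass it, so it runs without DrugEx installed. InMemoryMonitor and its attributes are hypothetical, and the order of calls in the usage loop is an assumption about a typical training loop, not a description of DrugEx's trainers.

```python
from typing import Any, Dict, List, Optional, Tuple


class InMemoryMonitor:
    """Hypothetical monitor exposing the same method names as TrainingMonitor.

    It does not subclass drugex.training.interfaces.TrainingMonitor so that
    the sketch runs without DrugEx installed.
    """

    def __init__(self) -> None:
        self.progress: List[Dict[str, Any]] = []
        self.performance: List[Dict[str, Any]] = []
        self.steps: List[Tuple[int, int]] = []
        self.model: Optional[Any] = None
        self.closed = False

    def saveProgress(self, current_step=None, current_epoch=None,
                     total_steps=None, total_epochs=None, *args, **kwargs):
        # Extra keyword arguments (e.g. a loss) are stored alongside the counters.
        self.progress.append({"step": current_step, "epoch": current_epoch,
                              "total_steps": total_steps,
                              "total_epochs": total_epochs, **kwargs})

    def savePerformanceInfo(self, performance_dict, df_smiles=None):
        self.performance.append(dict(performance_dict))

    def saveModel(self, model):
        # A real monitor would persist a state dictionary; here we keep a reference.
        self.model = model

    def endStep(self, step, epoch):
        self.steps.append((step, epoch))

    def getModel(self):
        return self.model

    def close(self):
        self.closed = True


# Usage: an assumed call order for a tiny loop of 2 epochs x 3 batches.
monitor = InMemoryMonitor()
for epoch in range(2):
    for batch in range(3):
        step = epoch * 3 + batch
        monitor.saveProgress(step, epoch, 6, 2, loss=0.1 * batch)
        monitor.endStep(step, epoch)
    monitor.savePerformanceInfo({"valid_loss": 1.0 / (epoch + 1)})
    monitor.saveModel({"epoch": epoch})  # a dict stands in for a Model here
monitor.close()
print(len(monitor.steps), monitor.getModel())  # 6 {'epoch': 1}
```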

drugex.training.monitors module

monitors

Created by: Martin Sicho On: 02.06.22, 13:59

class drugex.training.monitors.FileMonitor(path, save_smiles=False, reset_directory=False)

Bases: TrainingMonitor

A simple TrainingMonitor implementation with file outputs.

static appendTableToFile(df, outfile)
close()

Close this monitor. Training has finished.

endStep(step, epoch)

Notify the monitor that a step of the training has finished.

Parameters:
  • step (int) – The current training step (i.e. batch).

  • epoch (int) – The current epoch.

getModel()

Return the current model as a Model instance or in serialized form.

Returns:

The current model or its serialized form.

Return type:

Model or dict

saveEpochData(df)
saveModel(model)

Save the model state.

saveMolecules(df)
savePerformanceInfo(performance_dict, df_smiles=None)

Save the performance data for the current epoch.

Parameters:
  • performance_dict (dict) – A dictionary with the performance data.

  • df_smiles (pd.DataFrame) – A DataFrame with the SMILES of the molecules generated in the current epoch.

saveProgress(current_step=None, current_epoch=None, total_steps=None, total_epochs=None, loss=None, *args, **kwargs)

Save the current training progress: epoch, step, loss.

Parameters:
  • current_step (int) – The current step.

  • current_epoch (int) – The current epoch.

  • total_steps (int) – The total number of steps.

  • total_epochs (int) – The total number of epochs.

  • loss (float) – The current training loss.
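appendTableToFile(df, outfile) has no docstring. The sketch below shows one common way such a helper can work: the header is written only when the file does not exist yet, and later calls append rows. It uses the standard-library csv module with rows as dictionaries instead of a pandas DataFrame; append_rows_to_file is a hypothetical name, and the tab-separated format is an assumption rather than something the documentation states.

```python
import csv
import os
import tempfile


def append_rows_to_file(rows: list, outfile: str) -> None:
    """Hypothetical stand-in for FileMonitor.appendTableToFile.

    Writes a header line only when the file does not exist yet, then appends rows.
    """
    if not rows:
        return
    write_header = not os.path.exists(outfile)
    with open(outfile, "a", newline="") as handle:
        writer = csv.DictWriter(handle, fieldnames=list(rows[0]),
                                delimiter="\t", lineterminator="\n")
        if write_header:
            writer.writeheader()
        writer.writerows(rows)


# Usage: two calls, as a monitor might make after two training steps.
path = os.path.join(tempfile.mkdtemp(), "progress.tsv")
append_rows_to_file([{"epoch": 0, "step": 0, "loss": 1.5}], path)
append_rows_to_file([{"epoch": 0, "step": 1, "loss": 1.2}], path)
with open(path) as handle:
    lines = handle.read().splitlines()
print(lines)  # ['epoch\tstep\tloss', '0\t0\t1.5', '0\t1\t1.2']
```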

class drugex.training.monitors.NullMonitor

Bases: TrainingMonitor

close()

Close this monitor. Training has finished.

endStep(step, epoch)

Notify the monitor that a step of the training has finished.

Parameters:
  • step (int) – The current training step (i.e. batch).

  • epoch (int) – The current epoch.

getModel()

Return the current model as a Model instance or in serialized form.

Returns:

The current model or its serialized form.

Return type:

Model or dict

saveModel(model)

Save the state dictionary of the Model instance currently being trained, or serialize the model in any other way.

Parameters:

model (Model) – The model to save.

savePerformanceInfo(performance_dict, df_smiles=None)

Save the performance data for the current epoch.

Parameters:
  • performance_dict (dict) – A dictionary with the performance data.

  • df_smiles (pd.DataFrame) – A DataFrame with the SMILES of the molecules generated in the current epoch.

saveProgress(current_step=None, current_epoch=None, total_steps=None, total_epochs=None, *args, **kwargs)

Notifies the monitor of the current progress of the training.

Parameters:
  • current_step (int, optional) – The current training step (i.e. batch).

  • current_epoch (int, optional) – The current epoch.

  • total_steps (int, optional) – The total number of training steps.

  • total_epochs (int, optional) – The total number of epochs.

  • *args – Additional arguments depending on the model type.

  • **kwargs – Additional keyword arguments depending on the model type.

drugex.training.rewards module

rewards

Created by: Martin Sicho On: 26.06.22, 18:07

class drugex.training.rewards.ParetoCrowdingDistance

Bases: ParetoRewardScheme

Reward scheme that uses the NSGA-II crowding distance ranking strategy to rank the solutions in the same Pareto frontier.

Paper: Deb, Kalyanmoy, et al. “A fast and elitist multiobjective genetic algorithm: NSGA-II.” IEEE Transactions on Evolutionary Computation 6.2 (2002): 182-197.

getMoleculeRank(fronts, smiles=None, scores=None)

Crowding distance algorithm to rank the solutions in the same Pareto front.

Parameters:
  • fronts (list) – List of Pareto fronts. Each front is a list of indices of the molecules in the Pareto front.

  • smiles (list) – List of SMILES sequences to be ranked (not used in the calculation; it is only required by the interface because some ranking strategies need it).

  • scores (np.ndarray) – Matrix of scores for the multiple objectives.

Returns:

rank – Indices of the SMILES sequences ranked with the NSGA-II crowding distance method from worst to best

Return type:

np.array
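The crowding distance from the NSGA-II paper cited above can be sketched in NumPy for a single front: for each objective, the boundary solutions get an infinite distance and every interior solution adds the normalized gap between its two neighbours. This illustrates the published technique, not DrugEx's own code. The names crowding_distance and rank_front are hypothetical, and ordering from most crowded (worst) to least crowded (best) follows the docstring's worst-to-best convention by assumption.

```python
from typing import List

import numpy as np


def crowding_distance(scores: np.ndarray) -> np.ndarray:
    """NSGA-II crowding distance for the rows of one Pareto front.

    scores -- (n_points, n_objectives) matrix for the molecules in the front
    """
    n, m = scores.shape
    distance = np.zeros(n)
    if n <= 2:
        # With two or fewer points every point is a boundary point.
        distance[:] = np.inf
        return distance
    for j in range(m):
        order = np.argsort(scores[:, j])
        values = scores[order, j]
        span = values[-1] - values[0]
        # Boundary solutions of each objective are always kept.
        distance[order[0]] = distance[order[-1]] = np.inf
        if span == 0:
            continue
        # Interior points: normalized gap between their two neighbours.
        distance[order[1:-1]] += (values[2:] - values[:-2]) / span
    return distance


def rank_front(scores: np.ndarray, front: List[int]) -> np.ndarray:
    """Order the molecule indices of one front from most to least crowded."""
    d = crowding_distance(scores[front])
    return np.asarray(front)[np.argsort(d, kind="stable")]


pts = np.array([[0.0, 1.0], [0.2, 0.8], [0.6, 0.4], [1.0, 0.0]])
print(rank_front(pts, [0, 1, 2, 3]).tolist())  # [1, 2, 0, 3]
```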

class drugex.training.rewards.ParetoRewardScheme

Bases: RewardScheme, ABC

abstract getMoleculeRank(fronts, smiles=None, scores=None)

Ranks molecules within each Pareto front and returns the indices of the molecules in the ranked order.

Parameters:
  • fronts (list) – List of Pareto fronts. Each front is a list of indices of the molecules in the Pareto front.

  • smiles (list) – List of SMILES sequences to be ranked.

  • scores (np.ndarray) – Matrix of scores for the multiple objectives.

Returns:

rank – Indices of the ranked SMILES sequences

Return type:

np.array

getParetoFronts(scores)

Compute the Pareto fronts from the objective scores.

Parameters:

scores (np.ndarray) – Matrix of scores for the multiple objectives.

Returns:

List of Pareto fronts. Each front is a list of indices of the molecules in the Pareto front. The most dominant front is the first one.

Return type:

list
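One straightforward way to obtain the fronts described above is repeated non-dominated filtering: collect every molecule that no other remaining molecule dominates, remove that group as a front, and repeat. The sketch assumes all objectives are maximized and uses plain pairwise comparisons, which is slow for large batches. It illustrates the technique only; dominates and pareto_fronts are hypothetical names, not DrugEx's implementation.

```python
from typing import List

import numpy as np


def dominates(a: np.ndarray, b: np.ndarray) -> bool:
    """True if a is at least as good as b everywhere and better somewhere (maximization)."""
    return bool(np.all(a >= b) and np.any(a > b))


def pareto_fronts(scores: np.ndarray) -> List[List[int]]:
    """Split molecules into Pareto fronts; the first front is the non-dominated one."""
    remaining = list(range(len(scores)))
    fronts = []
    while remaining:
        front = [i for i in remaining
                 if not any(dominates(scores[j], scores[i])
                            for j in remaining if j != i)]
        fronts.append(front)
        remaining = [i for i in remaining if i not in front]
    return fronts


pts = np.array([[1.0, 0.0], [0.0, 1.0], [0.5, 0.5], [0.2, 0.2], [0.0, 0.0]])
print(pareto_fronts(pts))  # [[0, 1, 2], [3], [4]]
```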

class drugex.training.rewards.ParetoTanimotoDistance(distance_metric: str = 'min')

Bases: ParetoRewardScheme

Reward scheme that uses the Tanimoto distance ranking strategy to rank the solutions in the same Pareto frontier.

static calc_fps(mols, fp_type='ECFP6')

Calculate fingerprints for a list of molecules.

Parameters:
  • mols (list) – List of RDKit molecules

  • fp_type (str) – Type of fingerprint to use

Returns:

fps – List of RDKit fingerprints

Return type:

list

getFPs(smiles)

Calculate fingerprints for a list of SMILES strings.

Parameters:

smiles – SMILES strings to calculate fingerprints for.

Returns:

list of RDKit fingerprints

getMoleculeRank(fronts, smiles=None, scores=None)

Get the rank of the molecules in the Pareto front based on the Tanimoto distance of molecules in a front.

Parameters:
  • fronts (list) – List of Pareto fronts

  • smiles (list) – List of SMILES sequences to be ranked

  • scores (np.ndarray) – Array of scores for each molecule (not used)

Returns:

rank – List of indices of molecules, ranked from worst to best

Return type:

list
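The Tanimoto-based ranking can be illustrated without RDKit by representing each fingerprint as a set of on-bit indices. This sketch assumes that distance_metric='min' means each molecule is scored by its smallest Tanimoto distance to the other members of its front, that less distinct molecules rank worse, and that fronts are emitted from least to most dominant. These semantics are assumptions, not read from the DrugEx source, and tanimoto and rank_by_tanimoto are hypothetical names.

```python
from typing import List, Set


def tanimoto(a: Set[int], b: Set[int]) -> float:
    """Tanimoto similarity of two fingerprints given as sets of on-bit indices."""
    union = len(a | b)
    return len(a & b) / union if union else 1.0


def rank_by_tanimoto(fronts: List[List[int]], fps: List[Set[int]]) -> List[int]:
    """Hypothetical ranking: worst front first; inside a front, least distinct first."""
    rank: List[int] = []
    # fronts[0] is the most dominant front, so walk the list backwards.
    for front in reversed(fronts):
        def min_distance(i: int) -> float:
            # Smallest Tanimoto distance from molecule i to the rest of its front.
            others = [1.0 - tanimoto(fps[i], fps[j]) for j in front if j != i]
            return min(others) if others else 1.0
        rank.extend(sorted(front, key=min_distance))
    return rank


fps = [{1, 2, 3}, {1, 2, 4}, {7, 8}, {1, 2, 3, 4}]
print(rank_by_tanimoto([[0, 1, 2], [3]], fps))  # [3, 0, 1, 2]
```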

class drugex.training.rewards.WeightedSum

Bases: RewardScheme

Reward scheme that uses the weighted sum ranking strategy to rank the solutions.
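A weighted sum collapses each molecule's objective scores into one value, which can then be sorted into a worst-to-best order. The documented WeightedSum class takes no constructor arguments, so how DrugEx picks its weights is not shown here. In this sketch the weights are passed in explicitly and default to equal weights; both choices, and the name weighted_sum, are assumptions.

```python
import numpy as np


def weighted_sum(scores: np.ndarray, weights=None) -> np.ndarray:
    """Combine an (n_molecules, n_objectives) score matrix into one value per molecule.

    Equal weights are used when none are given (a hypothetical default).
    """
    n_objectives = scores.shape[1]
    if weights is None:
        weights = np.full(n_objectives, 1.0 / n_objectives)
    weights = np.asarray(weights, dtype=float)
    return scores @ weights


S = np.array([[1.0, 0.0], [0.5, 0.5], [0.2, 0.8]])
combined = weighted_sum(S, [0.75, 0.25])
# argsort gives molecule indices from the lowest to the highest combined score.
print(np.argsort(combined).tolist())  # [2, 1, 0]
```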

drugex.training.tests module

tests

Created by: Martin Sicho On: 31.05.22, 10:20

class drugex.training.tests.MockScorer(modifier=None)

Bases: Scorer

getKey()
getScores(mols, frags=None)

Returns the raw scores for the input molecules.

Parameters:
  • mols (list of RDKit molecules) – The molecules to be scored.

  • frags (list of RDKit molecules, optional) – The fragments used to generate the molecules, by default None.

Returns:

scores – The scores for the molecules.

Return type:

np.array

class drugex.training.tests.TestModelMonitor(submonitors=None)

Bases: TrainingMonitor

allMethodsExecuted()
close()

Close this monitor. Training has finished.

endStep(step, epoch)

Notify the monitor that a step of the training has finished.

Parameters:
  • step (int) – The current training step (i.e. batch).

  • epoch (int) – The current epoch.

getModel()

Return the current model as a Model instance or in serialized form.

Returns:

The current model or its serialized form.

Return type:

Model or dict

passToSubmonitors(method, *args, **kwargs)
saveModel(model)

Save the state dictionary of the Model instance currently being trained, or serialize the model in any other way.

Parameters:

model (Model) – The model to save.

savePerformanceInfo(performance_dict, df_smiles=None)

Save the performance data for the current epoch.

Parameters:
  • performance_dict (dict) – A dictionary with the performance data.

  • df_smiles (pd.DataFrame) – A DataFrame with the SMILES of the molecules generated in the current epoch.

saveProgress(current_step=None, current_epoch=None, total_steps=None, total_epochs=None, *args, **kwargs)

Notifies the monitor of the current progress of the training.

Parameters:
  • current_step (int, optional) – The current training step (i.e. batch).

  • current_epoch (int, optional) – The current epoch.

  • total_steps (int, optional) – The total number of training steps.

  • total_epochs (int, optional) – The total number of epochs.

  • *args – Additional arguments depending on the model type.

  • **kwargs – Additional keyword arguments depending on the model type.

class drugex.training.tests.TestScorer(methodName='runTest')

Bases: TestCase

test_getScores()
class drugex.training.tests.TrainingTestCase(methodName='runTest')

Bases: TestCase

BATCH_SIZE = 8
MAX_SMILES = 16
N_EPOCHS = 2
N_PROC = 2
SEED = 42
finetuning_file = '/home/sichom/projects/DrugEx/drugex/training/test_data/A2AR_raw_small.txt'
fitTestModel(model, train_loader, test_loader)

Fit a model and return the best model.

Parameters:
  • model (Model) – The model to fit

  • train_loader (DataLoader) – The training data loader

  • test_loader (DataLoader) – The test data loader

Returns:

The tuple of (fitted model, monitor)

Return type:

tuple

static getRandomFile()

Generate a random temporary file and return its path.

Returns:

The path to the temporary file

Return type:

str

getSmiles(_file)

Read and standardize SMILES from a file.

Parameters:

_file (str) – The file to read from (must be a .tsv file with a column named “CANONICAL_SMILES”)

Returns:

The list of SMILES

Return type:

list
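Reading the “CANONICAL_SMILES” column from a tab-separated file, as getSmiles requires, might look like the following with the standard library. Standardization is left out, read_smiles_column is a hypothetical name, and the sketch writes its own small temporary file instead of relying on the test data files listed for this class.

```python
import csv
import tempfile
from typing import List


def read_smiles_column(path: str, column: str = "CANONICAL_SMILES") -> List[str]:
    """Return the non-empty values of one column from a tab-separated file."""
    with open(path, newline="") as handle:
        reader = csv.DictReader(handle, delimiter="\t")
        if reader.fieldnames is None or column not in reader.fieldnames:
            raise ValueError(f"{path} has no column named {column!r}")
        # Skip rows where the SMILES cell is empty or missing.
        return [row[column] for row in reader if row[column]]


# Usage: a two-row file with an extra ID column.
with tempfile.NamedTemporaryFile("w", suffix=".tsv", delete=False) as tmp:
    tmp.write("ID\tCANONICAL_SMILES\n1\tCCO\n2\tc1ccccc1\n")
smiles = read_smiles_column(tmp.name)
print(smiles)  # ['CCO', 'c1ccccc1']
```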

getTestEnvironment(scheme=None)

Get the test environment.

Parameters:

scheme (RewardScheme) – The reward scheme to use. If None, the default ParetoTanimotoDistance is used.

Return type:

DrugExEnvironment

pretraining_file = '/home/sichom/projects/DrugEx/drugex/training/test_data/ZINC_raw_small.txt'
scorers = [<drugex.training.scorers.properties.Property object>, <drugex.training.tests.MockScorer object>]
setUp()

Hook method for setting up the test fixture before exercising it.

setUpSmilesFragData()

Create inputs for the fragment-based SMILES models.

Returns:

The tuple of (pretraining training dataloader, pretraining test dataloader, finetuning training dataloader, finetuning test dataloader, vocabulary)

Return type:

tuple

standardize(smiles)

Standardize the input SMILES.

Parameters:

smiles (list) – The list of SMILES to standardize

Returns:

The list of standardized SMILES

Return type:

list

test_data_dir = '/home/sichom/projects/DrugEx/drugex/training/test_data'
test_graph_transformer()

Test fragment-based graph transformer model.

test_graph_transformer_scaffold()

Test RL with fragment-based graph transformer model with scaffold input.

test_sequence_rnn()

Test sequence RNN.

test_sequence_transformer()

Test fragment-based sequence transformer model.

test_sequence_transformer_scaffold()

Test RL with fragment-based sequence transformer model with scaffold input.

thresholds = [0.5, 0.99]
drugex.training.tests.getPredictor()

Module contents

__init__.py

Created by: Martin Sicho On: 31.05.22, 10:20