drugex.training.generators package
Submodules
drugex.training.generators.graph_transformer module
- class drugex.training.generators.graph_transformer.AtomLayer(d_model=512, n_head=8, d_inner=1024, n_layer=12)[source]
Bases: Module
- forward(x: Tensor, key_mask=None, attn_mask=None)[source]
Defines the computation performed at every call.
Should be overridden by all subclasses.
Note
Although the recipe for the forward pass needs to be defined within this function, one should call the Module instance afterwards instead of this, since the former takes care of running the registered hooks while the latter silently ignores them.
- class drugex.training.generators.graph_transformer.Block(d_model, n_head, d_inner)[source]
Bases: Module
- forward(x, key_mask=None, attn_mask=None)[source]
Defines the computation performed at every call.
Should be overridden by all subclasses.
Note
Although the recipe for the forward pass needs to be defined within this function, one should call the Module instance afterwards instead of this, since the former takes care of running the registered hooks while the latter silently ignores them.
- class drugex.training.generators.graph_transformer.GraphTransformer(voc_trg, d_emb=512, d_model=512, n_head=8, d_inner=1024, n_layer=12, pad_idx=0, device=device(type='cuda'), use_gpus=(0,))[source]
Bases: FragGenerator
Graph Transformer for molecule generation from fragments
- attachToGPUs(gpus)[source]
Attach model to GPUs
Parameters:
- gpus:
tuple
A tuple of GPU ids to use
Returns:
None
- forward(src, is_train=False)[source]
Forward pass
Parameters:
- src:
torch.Tensor
Input tensor of shape [batch_size, 80, 5] (transpose of the encoded graphs as drawn in the paper)
- is_train:
bool
Whether the model is in training mode
Returns:
TODO: fill outputs
- init_states()[source]
Initialize model parameters
Notes:
Xavier initialization for all parameters except for the embedding layer
- loaderFromFrags(frags, batch_size=32, n_proc=1)[source]
Encode the input fragments and create a dataloader object
Parameters:
- frags:
list
A list of input fragments (SMILES)
- batch_size:
int
The batch size for the dataloader
- n_proc:
int
The number of processes to use for encoding the fragments
Returns:
- loader:
torch.utils.data.DataLoader
A dataloader object to iterate over the input fragments
- sample(loader)[source]
Sample SMILES from the network
- Parameters:
loader (torch.utils.data.DataLoader) – The data loader for the input fragments
- Returns:
smiles (list) – List of SMILES
frags (list) – List of fragments
- validateNet(loader, evaluator=None, no_multifrag_smiles=True, n_samples=None)[source]
Validate the network
- Parameters:
loader (torch.utils.data.DataLoader) – The data loader for the input fragments
evaluator (Environment) – An evaluator object to evaluate the generated SMILES
no_multifrag_smiles (bool) – If True, only single-fragment SMILES are considered valid
n_samples (int) – Not used by transformers
- Returns:
valid_metrics (dict) – Dictionary containing the validation metrics
scores (pandas.DataFrame) – DataFrame containing the SMILES, fragments and scores for each molecule
Notes
- The validation metrics are:
valid_ratio: the ratio of valid SMILES
accurate_ratio: the ratio of SMILES that are valid and have the desired fragments
loss_valid: the validation loss
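A minimal usage sketch for fragment-based sampling with this class; the vocabulary object and fragment SMILES below are placeholders, not part of the documented API:

    # Hypothetical sketch: grow molecules from two placeholder fragments.
    # `vocabulary` stands in for a graph vocabulary loaded elsewhere.
    from drugex.training.generators.graph_transformer import GraphTransformer

    generator = GraphTransformer(voc_trg=vocabulary, use_gpus=(0,))
    loader = generator.loaderFromFrags(["c1ccncc1", "CC(=O)N"], batch_size=32, n_proc=1)
    smiles, frags = generator.sample(loader)
    print(smiles[:5])  # a few generated SMILES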
drugex.training.generators.interfaces module
interfaces
Created by: Martin Sicho On: 01.06.22, 11:29
- class drugex.training.generators.interfaces.FragGenerator(device=device(type='cuda'), use_gpus=(0,))[source]
Bases: Generator
A generator for fragment-based molecules.
- generate(input_frags: List[str] | None = None, input_dataset: DataSet | None = None, num_samples=100, batch_size=32, n_proc=1, keep_frags=True, drop_duplicates=True, drop_invalid=True, evaluator=None, no_multifrag_smiles=True, drop_undesired=False, raw_scores=True, progress=True, tqdm_kwargs={})[source]
Generate SMILES from either a list of input fragments (input_frags) or a dataset object directly (input_dataset). You have to specify either one or the other. Various other options are available to filter, score and show generation progress (see below).
- Parameters:
input_frags (list) – a list of input fragments to incorporate in the generated molecules (as molecules in SMILES format)
input_dataset (GraphFragDataSet) – a GraphFragDataSet object to use to provide the input fragments
num_samples – the number of SMILES to generate, default is 100
batch_size – the batch size to use for generation, default is 32
n_proc – the number of processes to use for encoding the fragments if input_frags is provided, default is 1
keep_frags – if True, the fragments are kept in the generated SMILES, default is True
drop_duplicates – if True, duplicate SMILES are dropped, default is True
drop_invalid – if True, invalid SMILES are dropped, default is True
evaluator (Environment) – an Environment object to score the generated SMILES against; if None, no scoring is performed; required if drop_undesired is True; default is None
no_multifrag_smiles – if True, only single-fragment SMILES are considered valid, default is True
drop_undesired – if True, SMILES that do not contain the desired fragments are dropped, default is False
raw_scores – if True, raw scores (without modifiers) are calculated if evaluator is specified; these values are also used for filtering if drop_undesired is True; default is True
progress – if True, a progress bar is shown, default is True
tqdm_kwargs – keyword arguments to pass to the tqdm progress bar, default is an empty dict
- Returns:
df_smiles (DataFrame) – a DataFrame with the generated molecules (and their scores)
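A hedged usage sketch, assuming generator is an already fitted FragGenerator subclass such as GraphTransformer; the fragments and sample count are placeholders:

    # Hypothetical call with raw fragment SMILES; scoring via `evaluator`
    # is optional and omitted here.
    df = generator.generate(
        input_frags=["c1ccncc1", "CC(=O)N"],  # placeholder fragments
        num_samples=200,
        batch_size=64,
        drop_invalid=True,
        drop_duplicates=True,
    )
    print(df.head())  # generated SMILES (plus fragments if keep_frags=True)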
- init_states()[source]
Initialize model parameters
Notes:
Xavier initialization for all parameters except for the embedding layer
- class drugex.training.generators.interfaces.Generator(device=device(type='cuda'), use_gpus=(0,))[source]
The base generator class for fitting and evaluating a DrugEx generator.
- evaluate(smiles: List[str], frags: List[str] | None = None, evaluator=None, no_multifrag_smiles: bool = True, unmodified_scores: bool = False)[source]
Evaluate molecules by using the given evaluator or checking for validity.
Parameters:
- smiles: List
List of SMILES to evaluate
- frags: List
List of fragments used to generate the SMILES
- evaluator: Environment
An Environment instance used to evaluate the molecules
- no_multifrag_smiles: bool
If True, only single-fragment SMILES are considered valid
- unmodified_scores: bool
If True, the scores are not modified by the evaluator
Returns:
- scores: DataFrame
A DataFrame with the scores for each molecule
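A small sketch of a validity-only call, assuming generator is a fitted Generator instance; the molecules are placeholders, and with evaluator=None only validity is checked:

    # Hypothetical evaluation without an Environment: only validity is checked.
    scores = generator.evaluate(
        smiles=["CCO", "c1ccccc1", "not_a_smiles"],  # placeholder molecules
        evaluator=None,
    )
    print(scores)  # one row per molecule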
- filterNewMolecules(df_old, df_new, with_frags=True, drop_duplicates=True, drop_undesired=True, evaluator=None, no_multifrag_smiles=True)[source]
Filter the generated SMILES
Parameters:
- df_old:
pandas.DataFrame
Previously generated molecules
- df_new:
pandas.DataFrame
Newly generated molecules
- with_frags:
bool
If True, the fragment information is kept and used for filtering
- drop_duplicates:
bool
If True, duplicate SMILES are dropped
- drop_undesired:
bool
If True, SMILES that do not fulfill the desired objectives are dropped
- evaluator:
Evaluator
An evaluator object to evaluate the generated SMILES
- no_multifrag_smiles:
bool
If True, only single-fragment SMILES are considered valid
Returns:
- df_new:
pandas.DataFrame
The filtered new molecules
- fit(train_loader, valid_loader, epochs=100, patience=50, evaluator=None, monitor=None, no_multifrag_smiles=True)[source]
Fit the generator.
- Parameters:
train_loader (DataLoader) – a DataLoader instance to use for training
valid_loader (DataLoader) – a DataLoader instance to use for validation
epochs (int) – the number of epochs to train for
patience (int) – the number of epochs to wait for improvement before early stopping
evaluator (ModelEvaluator) – a ModelEvaluator instance to use for validation (TODO: maybe the evaluator should be hard coded to None here as during PT/FT training we don’t need it)
monitor (Monitor) – a Monitor instance to use for saving the model and performance info
no_multifrag_smiles (bool) – if True, only single-fragment SMILES are considered valid
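A minimal training sketch, assuming generator, both DataLoader objects and a Monitor implementation (monitor) are prepared elsewhere:

    # Hypothetical fitting loop with early stopping after 50 stale epochs.
    generator.fit(
        train_loader,
        valid_loader,
        epochs=100,
        patience=50,
        monitor=monitor,  # saves the model and performance info
    )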
- abstract generate(*args, **kwargs)[source]
Generate molecules from the generator.
- Returns:
df_smiles – a DataFrame with the generated molecules (and their scores)
- Return type:
DataFrame
- getModel()[source]
Return a copy of this model as a state dictionary.
- Returns:
model – A serializable copy of this model as a state dictionary
- Return type:
dict
- logPerformanceAndCompounds(epoch, metrics, scores)[source]
Log performance and compounds
Parameters:
- epoch:
int
The current epoch
- metrics:
dict
Dictionary with the performance metrics to log
- scores:
pandas.DataFrame
DataFrame with the generated compounds and their scores
- abstract sample(*args, **kwargs)[source]
Sample molecules from the generator.
- Returns:
smiles (List) – List of SMILES strings
frags (List, optional) – List of fragments used to generate the molecules
- abstract validateNet(loader=None, evaluator=None, no_multifrag_smiles=True, n_samples=None)[source]
Validate the performance of the generator.
- Parameters:
loader (DataLoader) – a DataLoader instance to use for validation
evaluator (ModelEvaluator) – a ModelEvaluator instance to use for validation
no_multifrag_smiles (bool) – if True, only single-fragment SMILES are considered valid
n_samples (int) – the number of samples to use for validation. Not used by transformers.
- Returns:
valid_metrics (dict) – a dictionary with the validation metrics
smiles_scores (DataFrame) – a DataFrame with the scores for each molecule
drugex.training.generators.sequence_rnn module
- class drugex.training.generators.sequence_rnn.SequenceRNN(voc, embed_size=128, hidden_size=512, is_lstm=True, lr=0.001, device=device(type='cuda'), use_gpus=(0,))[source]
Bases: Generator
Sequence RNN model for molecule generation.
- attachToGPUs(gpus)[source]
This model currently uses only one GPU. Therefore, only the first one from the list will be used.
Parameters:
- gpus:
tuple
A tuple of GPU indices.
Returns:
None
- evolve(batch_size, epsilon=0.01, crover=None, mutate=None)[source]
Evolve SMILES from the model by sequential addition of tokens.
Parameters:
- batch_size:
int
Batch size.
- epsilon:
float
Exploration rate: the probability of sampling a token from the mutation network.
- crover:
Generator
A crossover network (optional).
- mutate:
Generator
A mutation network used for exploration (optional).
Returns:
TODO: check if output smiles are still encoded
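A conceptual sketch of the epsilon-mixed token sampling behind evolve; this illustrates the idea only, and pick_next_token is a hypothetical helper, not DrugEx code:

    import torch

    def pick_next_token(agent_logits, mutate_logits=None, epsilon=0.01):
        # Sample the next token id from the agent's distribution, but with
        # probability `epsilon` use the mutation network instead (exploration).
        probs = torch.softmax(agent_logits, dim=-1)
        if mutate_logits is not None and torch.rand(1).item() < epsilon:
            probs = torch.softmax(mutate_logits, dim=-1)
        return torch.multinomial(probs, 1)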
- forward(input, h)[source]
Forward pass of the model.
Parameters:
- input:
torch.Tensor
Input tensor of shape (batch_size, 1).
- h:
torch.Tensor
# TODO: Verify h shape. Hidden state tensor of shape (num_layers, batch_size, hidden_size).
Returns:
TODO: fill outputs
- generate(num_samples=100, batch_size=32, n_proc=1, drop_duplicates=True, drop_invalid=True, evaluator=None, no_multifrag_smiles=True, drop_undesired=False, raw_scores=True, progress=True, tqdm_kwargs={})[source]
Generate molecules from the generator.
- Returns:
df_smiles – a DataFrame with the generated molecules (and their scores)
- Return type:
DataFrame
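A minimal de novo generation sketch; vocabulary is a placeholder for a SMILES vocabulary loaded elsewhere:

    # Hypothetical de novo generation with an LSTM-based SequenceRNN.
    from drugex.training.generators.sequence_rnn import SequenceRNN

    rnn = SequenceRNN(voc=vocabulary, is_lstm=True, use_gpus=(0,))
    df = rnn.generate(num_samples=100, batch_size=32, drop_invalid=True)
    print(len(df))  # number of molecules that survived filtering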
- init_h(batch_size, labels=None)[source]
Initialize hidden state of the model.
Hidden state is initialized with random values. If labels are provided, the first hidden state will be set to the labels.
Parameters:
- batch_size:
int
Batch size.
- labels:
torch.Tensor
Labels tensor of shape (batch_size, 1).
Returns:
TODO: fill outputs
- likelihood(target)[source]
Calculate the likelihood of the target sequence.
Parameters:
- target:
torch.Tensor
Target tensor of shape (batch_size, seq_len).
Returns:
- scores:
torch.Tensor
Scores tensor of shape (batch_size, seq_len).
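A conceptual sketch of a per-token sequence likelihood of this kind; sequence_log_likelihood is a hypothetical helper, not the SequenceRNN implementation:

    import torch
    import torch.nn.functional as F

    def sequence_log_likelihood(logits, target):
        # logits: (batch_size, seq_len, vocab_size); target: (batch_size, seq_len)
        log_probs = F.log_softmax(logits, dim=-1)
        # pick the log-probability of each target token -> (batch_size, seq_len)
        return log_probs.gather(-1, target.unsqueeze(-1)).squeeze(-1)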
- sample(batch_size)[source]
Sample SMILES from the model.
Parameters:
- batch_size:
int
Batch size.
Returns:
- smiles:
list
List of SMILES.
- trainNet(loader, epoch, epochs)[source]
Train the RNN network for one epoch
Parameters:
- loader:
torch.utils.data.DataLoader
The data loader for the training set
- epoch:
int
The current epoch
- epochs:
int
The total number of epochs
Returns:
- loss:
float
The training loss of the epoch
- validateNet(loader=None, evaluator=None, no_multifrag_smiles=True, n_samples=128)[source]
Validate the network
- Parameters:
loader (torch.utils.data.DataLoader) – A dataloader object to iterate over the validation data to compute the validation loss
evaluator (Evaluator) – An evaluator object to evaluate the generated SMILES
no_multifrag_smiles (bool) – If True, only single-fragment SMILES are considered valid
n_samples (int) – The number of SMILES to sample from the model
- Returns:
valid_metrics (dict) – Dictionary containing the validation metrics
scores (pandas.DataFrame) – DataFrame containing the SMILES, fragments and scores for each molecule
Notes
- The validation metrics are:
valid_ratio: the ratio of valid SMILES
accurate_ratio: the ratio of SMILES that are valid and have the desired fragments
loss_valid: the validation loss
drugex.training.generators.sequence_transformer module
- class drugex.training.generators.sequence_transformer.Block(d_model, n_head, d_inner)[source]
Bases: Module
- forward(x, key_mask=None, atn_mask=None)[source]
Defines the computation performed at every call.
Should be overridden by all subclasses.
Note
Although the recipe for the forward pass needs to be defined within this function, one should call the Module instance afterwards instead of this, since the former takes care of running the registered hooks while the latter silently ignores them.
- class drugex.training.generators.sequence_transformer.GPT2Layer(voc, d_emb=512, d_model=512, n_head=12, d_inner=1024, n_layer=12, pad_idx=0)[source]
Bases: Module
- forward(input: Tensor, key_mask=None, atn_mask=None)[source]
Defines the computation performed at every call.
Should be overridden by all subclasses.
Note
Although the recipe for the forward pass needs to be defined within this function, one should call the Module instance afterwards instead of this, since the former takes care of running the registered hooks while the latter silently ignores them.
- class drugex.training.generators.sequence_transformer.SequenceTransformer(voc_trg, d_emb=512, d_model=512, n_head=8, d_inner=1024, n_layer=12, pad_idx=0, device=device(type='cuda'), use_gpus=(0,))[source]
Bases: FragGenerator
Sequence Transformer for molecule generation from fragments
- forward(src, trg=None)[source]
Forward pass of the model
Parameters:
- src:
torch.Tensor
Source tensor of shape [batch_size, 200] (TODO: check that the shape is correct)
- trg:
torch.Tensor
Target tensor of shape [batch_size, 200]
Returns:
TODO: fill outputs
- loaderFromFrags(frags, batch_size=32, n_proc=1)[source]
Encode the input fragments and create a dataloader object
Parameters:
- frags:
list
A list of input fragments (SMILES)
- batch_size:
int
The batch size for the dataloader
- n_proc:
int
The number of processes to use for encoding the fragments
Returns:
- loader:
torch.utils.data.DataLoader
A dataloader object to iterate over the input fragments
- sample(loader)[source]
Sample SMILES from the model
Parameters:
- loader:
torch.utils.data.DataLoader
A dataloader object to iterate over the input fragments
Returns:
- smiles:
list
List of SMILES
- frags:
list
List of fragments
- trainNet(loader, epoch, epochs)[source]
Train the model for one epoch
Parameters:
- loader:
torch.utils.data.DataLoader
The data loader for the training set
- epoch:
int
The current epoch
- epochs:
int
The total number of epochs
Returns:
- loss:
float
The loss value for the current epoch
- validateNet(loader, evaluator=None, no_multifrag_smiles=True, n_samples=None)[source]
Validate the model
Parameters:
- loader:
torch.utils.data.DataLoader
A dataloader object to iterate over the validation fragments
- evaluator:
Environment
An evaluator object to evaluate the generated SMILES
- no_multifrag_smiles:
bool
If True, only single-fragment SMILES are considered valid
- n_samples:
int
Not used by transformers
Returns:
- valid_metrics:
dict
A dictionary containing the validation metrics
- scores:
pandas.DataFrame
DataFrame containing the SMILES, fragments and scores for each molecule
Notes:
- The validation metrics are:
valid_ratio: ratio of valid SMILES
accurate_ratio: ratio of SMILES that are valid and have the desired fragments
loss_valid: loss on the validation set
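A minimal validation sketch, assuming transformer is a fitted SequenceTransformer and valid_loader was built with loaderFromFrags on held-out fragments:

    # Hypothetical validation call; unpacking follows the documented returns.
    valid_metrics, scores = transformer.validateNet(valid_loader)
    print(valid_metrics["valid_ratio"])     # ratio of valid SMILES
    print(valid_metrics["accurate_ratio"])  # valid SMILES with the desired fragments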
drugex.training.generators.utils module
Define the Layers
- class drugex.training.generators.utils.PositionalEmbedding(d_model: int, max_len=100, batch_first=False)[source]
Bases: Module
Positional embedding for sequence transformer
- forward(x)[source]
Defines the computation performed at every call.
Should be overridden by all subclasses.
Note
Although the recipe for the forward pass needs to be defined within this function, one should call the Module instance afterwards instead of this, since the former takes care of running the registered hooks while the latter silently ignores them.
- class drugex.training.generators.utils.PositionalEncoding(d_model: int, max_len=100, batch_first=False)[source]
Bases: Module
Positional encoding for graph transformer
- forward(x)[source]
Defines the computation performed at every call.
Should be overridden by all subclasses.
Note
Although the recipe for the forward pass needs to be defined within this function, one should call the Module instance afterwards instead of this, since the former takes care of running the registered hooks while the latter silently ignores them.
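A minimal sinusoidal encoding sketch in the spirit of this class; the exact DrugEx implementation may differ in details such as batch layout:

    import math
    import torch

    def sinusoidal_encoding(max_len: int, d_model: int) -> torch.Tensor:
        # Standard transformer positional encoding: sin on even dimensions,
        # cos on odd dimensions, with geometrically spaced frequencies.
        pos = torch.arange(max_len).unsqueeze(1)  # (max_len, 1)
        div = torch.exp(torch.arange(0, d_model, 2) * (-math.log(10000.0) / d_model))
        pe = torch.zeros(max_len, d_model)
        pe[:, 0::2] = torch.sin(pos * div)
        pe[:, 1::2] = torch.cos(pos * div)
        return pe  # added to token embeddings to inject position information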
- class drugex.training.generators.utils.PositionwiseFeedForward(d_in, d_hid)[source]
Bases: Module
A two-layer feed-forward module
- forward(x)[source]
Defines the computation performed at every call.
Should be overridden by all subclasses.
Note
Although the recipe for the forward pass needs to be defined within this function, one should call the Module instance afterwards instead of this, since the former takes care of running the registered hooks while the latter silently ignores them.
- class drugex.training.generators.utils.SublayerConnection(size, dropout=0.1)[source]
Bases: Module
A residual connection followed by a layer norm
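A minimal sketch of the residual-plus-layer-norm pattern that SublayerConnection and PositionwiseFeedForward implement; TinySublayerConnection is a hypothetical stand-in with illustrative dimensions:

    import torch
    from torch import nn

    class TinySublayerConnection(nn.Module):
        def __init__(self, size: int, dropout: float = 0.1):
            super().__init__()
            self.norm = nn.LayerNorm(size)
            self.dropout = nn.Dropout(dropout)

        def forward(self, x, sublayer):
            # residual connection around any sublayer, followed by a layer norm
            return self.norm(x + self.dropout(sublayer(x)))

    # two-layer feed-forward sublayer, as in PositionwiseFeedForward
    ffn = nn.Sequential(nn.Linear(512, 1024), nn.ReLU(), nn.Linear(1024, 512))
    block = TinySublayerConnection(512)
    out = block(torch.randn(2, 10, 512), ffn)  # (batch, seq_len, d_model)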