drugex.training.generators package
Submodules
drugex.training.generators.graph_transformer module
- class drugex.training.generators.graph_transformer.AtomLayer(d_model=512, n_head=8, d_inner=1024, n_layer=12)[source]
Bases: Module
- forward(x: Tensor, key_mask=None, attn_mask=None)[source]
Defines the computation performed at every call.
Should be overridden by all subclasses.
Note
Although the recipe for forward pass needs to be defined within this function, one should call the Module instance afterwards instead of this, since the former takes care of running the registered hooks while the latter silently ignores them.
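To illustrate the note, here is a minimal pure-Python sketch of why calling the instance differs from calling `forward` directly. `MiniModule` and `Doubler` are hypothetical stand-ins written for this example; they are not part of PyTorch or DrugEx, and the real `torch.nn.Module.__call__` does considerably more.

```python
class MiniModule:
    """Toy stand-in for a module with forward hooks (hypothetical, not torch.nn.Module)."""

    def __init__(self):
        self._forward_hooks = []

    def register_forward_hook(self, hook):
        # Hooks receive (module, inputs, output), loosely mirroring PyTorch's convention.
        self._forward_hooks.append(hook)

    def forward(self, x):
        raise NotImplementedError("subclasses define the computation here")

    def __call__(self, x):
        # Calling the instance runs forward() and then every registered hook.
        out = self.forward(x)
        for hook in self._forward_hooks:
            hook(self, (x,), out)
        return out


class Doubler(MiniModule):
    def forward(self, x):
        return 2 * x


seen = []
layer = Doubler()
layer.register_forward_hook(lambda module, inputs, output: seen.append(output))

layer(3)          # hooks run: the output is recorded in `seen`
layer.forward(4)  # hooks are skipped: nothing is recorded
```

The same reasoning applies to every `forward` method documented on this page: prefer `model(x)` over `model.forward(x)`.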
- class drugex.training.generators.graph_transformer.Block(d_model, n_head, d_inner)[source]
Bases: Module
- forward(x, key_mask=None, attn_mask=None)[source]
Defines the computation performed at every call.
Should be overridden by all subclasses.
Note
Although the recipe for forward pass needs to be defined within this function, one should call the Module instance afterwards instead of this, since the former takes care of running the registered hooks while the latter silently ignores them.
- class drugex.training.generators.graph_transformer.GraphTransformer(voc_trg, d_emb=512, d_model=512, n_head=8, d_inner=1024, n_layer=12, pad_idx=0, device=device(type='cuda'), use_gpus=(0,))[source]
Bases: FragGenerator
Graph Transformer for molecule generation from fragments
- attachToGPUs(gpus)[source]
Attach model to GPUs
Parameters:
- gpus: tuple
A tuple of GPU ids to use
Returns:
None
- forward(src, is_train=False)[source]
Forward pass
Parameters:
- src: torch.Tensor
Input tensor of shape [batch_size, 80, 5] (transpose of the encoded graphs as drawn in the paper)
- is_train: bool
Whether the model is in training mode
Returns:
TODO : fill outputs
- init_states()[source]
Initialize model parameters
Notes:
Xavier initialization for all parameters except for the embedding layer
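As a rough illustration of the note above, the following sketch re-initialises every weight matrix except embeddings using the uniform variant of Xavier (Glorot) initialization. The choice of the uniform variant, the `init_params` helper, and the convention of spotting embeddings by an `"emb"` substring in the parameter name are all assumptions made for this example; the source does not say how DrugEx selects parameters.

```python
import math
import random


def xavier_uniform(fan_in, fan_out, rng):
    """Return a fan_out x fan_in matrix drawn from U(-a, a), a = sqrt(6 / (fan_in + fan_out))."""
    bound = math.sqrt(6.0 / (fan_in + fan_out))
    return [[rng.uniform(-bound, bound) for _ in range(fan_in)] for _ in range(fan_out)]


def init_params(params, rng):
    """Re-initialise every 2-D parameter except those named like embeddings."""
    for name, weight in list(params.items()):
        if "emb" in name:  # hypothetical naming convention for embedding layers
            continue
        fan_out, fan_in = len(weight), len(weight[0])
        params[name] = xavier_uniform(fan_in, fan_out, rng)
    return params


rng = random.Random(0)
params = {
    "token_emb.weight": [[1.0] * 4 for _ in range(10)],  # left untouched
    "attn.weight": [[0.0] * 4 for _ in range(4)],        # re-initialised
}
params = init_params(params, rng)
```

In PyTorch the corresponding building block would be `torch.nn.init.xavier_uniform_`, applied to parameters with more than one dimension.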
- loaderFromFrags(frags, batch_size=32, n_proc=1)[source]
Encode the input fragments and create a dataloader object
Parameters:
Returns:
- loader: torch.utils.data.DataLoader
A dataloader object to iterate over the input fragments
- sample(loader)[source]
Sample SMILES from the network
- Parameters:
loader (torch.utils.data.DataLoader) – The data loader for the input fragments
- Returns:
smiles (list) – List of SMILES
frags (list) – List of fragments
- validateNet(loader, evaluator=None, no_multifrag_smiles=True, n_samples=None)[source]
Validate the network
- Parameters:
- Returns:
valid_metrics (dict) – Dictionary containing the validation metrics
scores (pandas.DataFrame) – DataFrame containing SMILES, frags and the scores for each SMILES
Notes
- The validation metrics are:
valid_ratio: the ratio of valid SMILES
accurate_ratio: the ratio of SMILES that are valid and have the desired fragments
loss_valid: the validation loss
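The two ratios listed above can be sketched as follows. The `validation_ratios` helper and the `"valid"` / `"has_frags"` field names are hypothetical, and the sketch assumes both ratios are taken over the total number of sampled SMILES, which the notes do not state explicitly.

```python
def validation_ratios(records):
    """Compute valid_ratio and accurate_ratio from per-molecule boolean flags.

    `records` is a list of dicts with "valid" and "has_frags" keys
    (hypothetical field names chosen for this sketch).
    """
    total = len(records)
    if total == 0:
        return {"valid_ratio": 0.0, "accurate_ratio": 0.0}
    n_valid = sum(1 for r in records if r["valid"])
    # A molecule counts as accurate only if it is valid AND contains the fragments.
    n_accurate = sum(1 for r in records if r["valid"] and r["has_frags"])
    return {"valid_ratio": n_valid / total, "accurate_ratio": n_accurate / total}


records = [
    {"valid": True, "has_frags": True},
    {"valid": True, "has_frags": False},
    {"valid": False, "has_frags": False},
    {"valid": True, "has_frags": True},
]
metrics = validation_ratios(records)
```

With these four records, three are valid and two are both valid and fragment-containing, so `accurate_ratio` can never exceed `valid_ratio`.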
drugex.training.generators.interfaces module
interfaces
Created by: Martin Sicho On: 01.06.22, 11:29
- class drugex.training.generators.interfaces.FragGenerator(device=device(type='cuda'), use_gpus=(0,))[source]
Bases: Generator
A generator for fragment-based molecules.
- generate(input_frags: List[str] | None = None, input_dataset: DataSet | None = None, num_samples=100, batch_size=32, n_proc=1, keep_frags=True, drop_duplicates=True, drop_invalid=True, evaluator=None, no_multifrag_smiles=True, drop_undesired=False, raw_scores=True, progress=True, tqdm_kwargs={})[source]
Generate SMILES from either a list of input fragments (input_frags) or a dataset object directly (input_dataset). You have to specify either one or the other. Various other options are available to filter, score and show generation progress (see below).
- Parameters:
input_frags (list) – a list of input fragments to incorporate in the generated molecules (as molecules in SMILES format)
input_dataset (GraphFragDataSet) – a GraphFragDataSet object to use to provide the input fragments
num_samples – the number of SMILES to generate, default is 100
batch_size – the batch size to use for generation, default is 32
n_proc – the number of processes to use for encoding the fragments if input_frags is provided, default is 1
keep_frags – if True, the fragments are kept in the generated SMILES, default is True
drop_duplicates – if True, duplicate SMILES are dropped, default is True
drop_invalid – if True, invalid SMILES are dropped, default is True
evaluator (Environment) – an Environment object to score the generated SMILES against; if None, no scoring is performed; required if drop_undesired is True; default is None
no_multifrag_smiles – if True, only single-fragment SMILES are considered valid, default is True
drop_undesired – if True, SMILES that do not contain the desired fragments are dropped, default is False
raw_scores – if True, raw scores (without modifiers) are calculated if evaluator is specified; these values are also used for filtering if drop_undesired is True; default is True
progress – if True, a progress bar is shown, default is True
tqdm_kwargs – keyword arguments to pass to the tqdm progress bar, default is an empty dict
Returns:
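The interplay of `num_samples`, `batch_size`, `drop_duplicates` and `no_multifrag_smiles` can be pictured with the sketch below. It assumes that generation keeps sampling batches until `num_samples` molecules survive the filters, which the description does not state explicitly. `generate_filtered`, `toy_sampler` and `max_batches` are hypothetical names, and the `"."` test only detects multi-fragment SMILES; real validity checking would need a chemistry toolkit.

```python
import itertools


def is_single_fragment(smiles):
    # In SMILES notation, "." separates disconnected components.
    return "." not in smiles


def generate_filtered(sample_batch, num_samples=100, batch_size=32,
                      drop_duplicates=True, no_multifrag_smiles=True,
                      max_batches=1000):
    """Collect SMILES batch by batch until num_samples pass the filters."""
    kept, seen = [], set()
    for _ in range(max_batches):
        for smiles in sample_batch(batch_size):
            if no_multifrag_smiles and not is_single_fragment(smiles):
                continue
            if drop_duplicates and smiles in seen:
                continue
            seen.add(smiles)
            kept.append(smiles)
            if len(kept) == num_samples:
                return kept
    return kept  # batch budget exhausted before reaching num_samples


# Hypothetical sampler cycling through a fixed pool instead of a trained model.
pool = itertools.cycle(["CCO", "CCO", "CC.O", "c1ccccc1", "CCN", "CC(=O)O"])


def toy_sampler(n):
    return [next(pool) for _ in range(n)]


result = generate_filtered(toy_sampler, num_samples=4, batch_size=3)
```

Here the first batch contributes only one molecule (one duplicate and one multi-fragment SMILES are rejected), so a second batch is needed to reach four results.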
- init_states()[source]
Initialize model parameters
Notes:
Xavier initialization for all parameters except for the embedding layer
- class drugex.training.generators.interfaces.Generator(device=device(type='cuda'), use_gpus=(0,))[source]
The base generator class for fitting and evaluating a DrugEx generator.
- evaluate(smiles: List[str], frags: List[str] | None = None, evaluator=None, no_multifrag_smiles: bool = True, unmodified_scores: bool = False)[source]
Evaluate molecules by using the given evaluator or checking for validity.
Parameters:
- smiles: List
List of SMILES to evaluate
- frags: List
List of fragments used to generate the SMILES
- evaluator: Environment
An Environment instance used to evaluate the molecules
- no_multifrag_smiles: bool
If True, only single-fragment SMILES are considered valid
- unmodified_scores: bool
If True, the scores are not modified by the evaluator
- Returns:
scores – A DataFrame with the scores for each molecule
- Return type:
DataFrame
- filterNewMolecules(df_old, df_new, with_frags=True, drop_duplicates=True, drop_undesired=True, evaluator=None, no_multifrag_smiles=True)[source]
Filter the generated SMILES
Parameters:
- smiles: list
A list of previous SMILES
- new_smiles: list
A list of additional generated SMILES
- frags: list
A list of additional input fragments
- drop_duplicates: bool
If True, duplicate SMILES are dropped
- drop_undesired: bool
If True, SMILES that do not fulfill the desired objectives are dropped
- evaluator: Evaluator
An evaluator object to evaluate the generated SMILES
- no_multifrag_smiles: bool
If True, only single-fragment SMILES are considered valid
Returns:
- fit(train_loader, valid_loader, epochs=100, patience=50, evaluator=None, monitor=None, no_multifrag_smiles=True)[source]
Fit the generator.
- Parameters:
train_loader (DataLoader) – a DataLoader instance to use for training
valid_loader (DataLoader) – a DataLoader instance to use for validation
epochs (int) – the number of epochs to train for
patience (int) – the number of epochs to wait for improvement before early stopping
evaluator (ModelEvaluator) – a ModelEvaluator instance to use for validation. TODO: maybe the evaluator should be hard coded to None here as during PT/FT training we don’t need it
monitor (Monitor) – a Monitor instance to use for saving the model and performance info
no_multifrag_smiles (bool) – if True, only single-fragment SMILES are considered valid
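The relationship between `epochs` and `patience` can be sketched with a generic early-stopping loop. This is an illustration under assumptions, not the DrugEx implementation: `fit_with_patience` is a hypothetical helper, and the monitored quantity (validation loss, lower is better) and the exact stopping comparison are guesses.

```python
def fit_with_patience(validation_losses, epochs=100, patience=50):
    """Walk through per-epoch validation losses and stop early when they stall.

    Returns (best_epoch, best_loss, last_epoch_run).
    """
    best_loss = float("inf")
    best_epoch = 0
    last_epoch = 0
    for epoch, loss in enumerate(validation_losses[:epochs], start=1):
        last_epoch = epoch
        if loss < best_loss:
            # Improvement: remember it (a real trainer would save the model here).
            best_loss, best_epoch = loss, epoch
        elif epoch - best_epoch >= patience:
            # No improvement for `patience` epochs: stop training.
            break
    return best_epoch, best_loss, last_epoch


losses = [1.0, 0.8, 0.7, 0.75, 0.72, 0.71, 0.9, 0.6]
summary = fit_with_patience(losses, epochs=100, patience=3)
```

In this run the best loss occurs at epoch 3 and training stops at epoch 6, so the later improvement at epoch 8 is never reached. A larger `patience` trades extra training time for a lower risk of stopping too soon.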
- abstract generate(*args, **kwargs)[source]
Generate molecules from the generator.
- Returns:
df_smiles – a DataFrame with the generated molecules (and their scores)
- Return type:
DataFrame
- getModel()[source]
Return a copy of this model as a state dictionary.
- Returns:
model – A serializable copy of this model as a state dictionary
- Return type:
- logPerformanceAndCompounds(epoch, metrics, scores)[source]
Log performance and compounds
Parameters:
- abstract sample(*args, **kwargs)[source]
Sample molecules from the generator.
- Returns:
smiles (List) – List of SMILES strings
frags (List, optional) – List of fragments used to generate the molecules
- abstract validateNet(loader=None, evaluator=None, no_multifrag_smiles=True, n_samples=None)[source]
Validate the performance of the generator.
- Parameters:
loader (DataLoader) – a DataLoader instance to use for validation
evaluator (ModelEvaluator) – a ModelEvaluator instance to use for validation
no_multifrag_smiles (bool) – if True, only single-fragment SMILES are considered valid
n_samples (int) – the number of samples to use for validation. Not used by transformers.
- Returns:
valid_metrics (dict) – a dictionary with the validation metrics
smiles_scores (DataFrame) – a DataFrame with the scores for each molecule
drugex.training.generators.sequence_rnn module
- class drugex.training.generators.sequence_rnn.SequenceRNN(voc, embed_size=128, hidden_size=512, is_lstm=True, lr=0.001, device=device(type='cuda'), use_gpus=(0,))[source]
Bases: Generator
Sequence RNN model for molecule generation.
- attachToGPUs(gpus)[source]
This model currently uses only one GPU. Therefore, only the first one from the list will be used.
Parameters:
- gpus: tuple
A tuple of GPU indices.
Returns:
None
- evolve(batch_size, epsilon=0.01, crover=None, mutate=None)[source]
Evolve a SMILES from the model by sequential addition of tokens.
Parameters:
Returns:
TODO: check if output SMILES are still encoded
- forward(input, h)[source]
Forward pass of the model.
Parameters:
- input: torch.Tensor
Input tensor of shape (batch_size, 1).
- h: torch.Tensor
Hidden state tensor of shape (num_layers, batch_size, hidden_size). TODO: Verify h shape.
Returns:
TODO: fill outputs
- generate(num_samples=100, batch_size=32, n_proc=1, drop_duplicates=True, drop_invalid=True, evaluator=None, no_multifrag_smiles=True, drop_undesired=False, raw_scores=True, progress=True, tqdm_kwargs={})[source]
Generate molecules from the generator.
- Returns:
df_smiles – a DataFrame with the generated molecules (and their scores)
- Return type:
DataFrame
- init_h(batch_size, labels=None)[source]
Initialize hidden state of the model.
Hidden state is initialized with random values. If labels are provided, the first hidden state will be set to the labels.
Parameters:
- batch_size: int
Batch size.
- labels: torch.Tensor
Labels tensor of shape (batch_size, 1).
Returns:
TODO: fill outputs
- likelihood(target)[source]
Calculate the likelihood of the target sequence.
Parameters:
- target: torch.Tensor
Target tensor of shape (batch_size, seq_len).
Returns:
- scores: torch.Tensor
Scores tensor of shape (batch_size, seq_len).
- sample(batch_size)[source]
Sample a SMILES from the model.
Parameters:
- batch_size: int
Batch size.
Returns:
- smiles: list
List of SMILES.
- trainNet(loader, epoch, epochs)[source]
Train the RNN network for one epoch
Parameters:
- loader: torch.utils.data.DataLoader
The data loader for the training set
- epoch: int
The current epoch
- epochs: int
The total number of epochs
- Returns:
loss – The training loss of the epoch
- Return type:
float
- validateNet(loader=None, evaluator=None, no_multifrag_smiles=True, n_samples=128)[source]
Validate the network
- Parameters:
loader (torch.utils.data.DataLoader) – A dataloader object to iterate over the validation data to compute the validation loss
evaluator (Evaluator) – An evaluator object to evaluate the generated SMILES
no_multifrag_smiles (bool) – If True, only single-fragment SMILES are considered valid
n_samples (int) – The number of SMILES to sample from the model
- Returns:
valid_metrics (dict) – Dictionary containing the validation metrics
scores (pandas.DataFrame) – DataFrame containing SMILES, frags and the scores for each SMILES
Notes
- The validation metrics are:
valid_ratio: the ratio of valid SMILES
accurate_ratio: the ratio of SMILES that are valid and have the desired fragments
loss_valid: the validation loss
drugex.training.generators.sequence_transformer module
- class drugex.training.generators.sequence_transformer.Block(d_model, n_head, d_inner)[source]
Bases: Module
- forward(x, key_mask=None, atn_mask=None)[source]
Defines the computation performed at every call.
Should be overridden by all subclasses.
Note
Although the recipe for forward pass needs to be defined within this function, one should call the Module instance afterwards instead of this, since the former takes care of running the registered hooks while the latter silently ignores them.
- class drugex.training.generators.sequence_transformer.GPT2Layer(voc, d_emb=512, d_model=512, n_head=12, d_inner=1024, n_layer=12, pad_idx=0)[source]
Bases: Module
- forward(input: Tensor, key_mask=None, atn_mask=None)[source]
Defines the computation performed at every call.
Should be overridden by all subclasses.
Note
Although the recipe for forward pass needs to be defined within this function, one should call the Module instance afterwards instead of this, since the former takes care of running the registered hooks while the latter silently ignores them.
- class drugex.training.generators.sequence_transformer.SequenceTransformer(voc_trg, d_emb=512, d_model=512, n_head=8, d_inner=1024, n_layer=12, pad_idx=0, device=device(type='cuda'), use_gpus=(0,))[source]
Bases: FragGenerator
Sequence Transformer for molecule generation from fragments
- forward(src, trg=None)[source]
Forward pass of the model
Parameters:
- src: torch.Tensor
Source tensor of shape [batch_size, 200]. TODO: check that the shape is correct
- trg: torch.Tensor
Target tensor of shape [batch_size, 200]
Returns:
TODO: fill outputs
- loaderFromFrags(frags, batch_size=32, n_proc=1)[source]
Encode the input fragments and create a dataloader object
Parameters:
Returns:
- loader: torch.utils.data.DataLoader
A dataloader object to iterate over the input fragments
- sample(loader)[source]
Sample SMILES from the model
Parameters:
- loader: torch.utils.data.DataLoader
A dataloader object to iterate over the input fragments
Returns:
- trainNet(loader, epoch, epochs)[source]
Train the model for one epoch
Parameters:
Returns:
- loss: float
The loss value for the current epoch
- validateNet(loader, evaluator=None, no_multifrag_smiles=True, n_samples=None)[source]
Validate the model
Parameters:
Returns:
- valid_metrics: dict
A dictionary containing the validation metrics
- scores: pandas.DataFrame
DataFrame containing SMILES, frags and the scores for each SMILES
Notes:
- The validation metrics are:
valid_ratio: ratio of valid SMILES
accurate_ratio: ratio of SMILES that are valid and have the desired fragments
loss_valid: loss on the validation set
drugex.training.generators.utils module
Define the Layers
- class drugex.training.generators.utils.PositionalEmbedding(d_model: int, max_len=100, batch_first=False)[source]
Bases: Module
Positional embedding for sequence transformer
- forward(x)[source]
Defines the computation performed at every call.
Should be overridden by all subclasses.
Note
Although the recipe for forward pass needs to be defined within this function, one should call the Module instance afterwards instead of this, since the former takes care of running the registered hooks while the latter silently ignores them.
- class drugex.training.generators.utils.PositionalEncoding(d_model: int, max_len=100, batch_first=False)[source]
Bases: Module
Positional encoding for graph transformer
- forward(x)[source]
Defines the computation performed at every call.
Should be overridden by all subclasses.
Note
Although the recipe for forward pass needs to be defined within this function, one should call the Module instance afterwards instead of this, since the former takes care of running the registered hooks while the latter silently ignores them.
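For orientation, a positional encoding with `d_model` and `max_len` parameters is commonly built as a table of sinusoids. The sketch below assumes that standard sinusoidal formulation; the source does not confirm which scheme `PositionalEncoding` uses, and `sinusoidal_encoding` is a hypothetical helper written in plain Python rather than with tensors.

```python
import math


def sinusoidal_encoding(max_len=100, d_model=512):
    """Return a max_len x d_model table of sinusoidal position encodings."""
    table = [[0.0] * d_model for _ in range(max_len)]
    for pos in range(max_len):
        for i in range(0, d_model, 2):
            # Each even/odd pair of dimensions shares one wavelength.
            angle = pos / (10000 ** (i / d_model))
            table[pos][i] = math.sin(angle)
            if i + 1 < d_model:
                table[pos][i + 1] = math.cos(angle)
    return table


pe = sinusoidal_encoding(max_len=4, d_model=6)
```

Each row would then be added to the embedding of the token (or graph row) at that position, so sequences longer than `max_len` cannot be encoded with a fixed table like this one.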
- class drugex.training.generators.utils.PositionwiseFeedForward(d_in, d_hid)[source]
Bases: Module
A two-feed-forward-layer module
- forward(x)[source]
Defines the computation performed at every call.
Should be overridden by all subclasses.
Note
Although the recipe for forward pass needs to be defined within this function, one should call the Module instance afterwards instead of this, since the former takes care of running the registered hooks while the latter silently ignores them.
- class drugex.training.generators.utils.SublayerConnection(size, dropout=0.1)[source]
Bases: Module
A residual connection followed by a layer norm
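The pattern "residual connection followed by a layer norm" can be sketched on a single vector as shown below. This is a simplified illustration: `layer_norm` and `sublayer_connection` are hypothetical helpers, the learned scale and shift of a real layer norm are omitted, and the `dropout` argument from the class signature is left out entirely.

```python
import math


def layer_norm(vector, eps=1e-6):
    """Normalise a vector to zero mean and unit variance (no learned scale or shift)."""
    mean = sum(vector) / len(vector)
    var = sum((v - mean) ** 2 for v in vector) / len(vector)
    return [(v - mean) / math.sqrt(var + eps) for v in vector]


def sublayer_connection(x, sublayer):
    """Apply `sublayer`, add the input back (residual), then layer-normalise."""
    residual = [xi + yi for xi, yi in zip(x, sublayer(x))]
    return layer_norm(residual)


# Toy sublayer that halves every element; a real one would be attention or feed-forward.
out = sublayer_connection([1.0, 2.0, 3.0, 4.0], lambda v: [0.5 * vi for vi in v])
```

The residual path lets the input bypass the sublayer, while the normalisation keeps the output on a stable scale regardless of what the sublayer produced.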