drugex.training.explorers package
Submodules
drugex.training.explorers.frag_graph_explorer module
- class drugex.training.explorers.frag_graph_explorer.FragGraphExplorer(agent, env, mutate=None, crover=None, batch_size=128, epsilon=0.1, beta=0.0, n_samples=-1, optim=None, device=device(type='cuda'), use_gpus=(0,), no_multifrag_smiles=True)[source]
Bases: FragExplorer
Explorer to optimize a graph-based fragment-using agent with the given Environment.
- forward(src)[source]
Forward pass of the agent.
- Parameters:
src (torch.Tensor) – Input tensor of shape (batch_size, 80, 5).
- Returns:
Output tensor of shape (batch_size, 80, 5).
- Return type:
torch.Tensor
- getBatchOutputs(net, src)[source]
Outputs (frags, smiles) and loss of the agent for a batch of fragments-molecule pairs.
- Parameters:
net (torch.nn.Module) – Agent
src (torch.Tensor) – Fragments-molecule pairs
- Returns:
frags (list) – List of fragments
smiles (list) – List of SMILES
loss (torch.Tensor) – Loss of the agent
- sampleEncodedPairsToLoader(net, loader)[source]
Sample new fragments-molecule pairs from a data loader.
- Parameters:
net (torch.nn.Module) – Agent
loader (torch.utils.data.DataLoader) – Data loader for original fragments-molecule pairs
- Returns:
Data loader for sampled fragments-molecule pairs
- Return type:
torch.utils.data.DataLoader
- sample_input(loader, is_test=False)[source]
Sample a batch of fragments-molecule pairs from the dataset.
- Parameters:
loader (torch.utils.data.DataLoader) – Data loader for original fragments-molecule pairs
is_test (bool) – Whether to sample from the validation set or not
- Returns:
Data loader for sampled fragments-molecule pairs
- Return type:
torch.utils.data.DataLoader
drugex.training.explorers.frag_sequence_explorer module
- class drugex.training.explorers.frag_sequence_explorer.FragSequenceExplorer(agent, env=None, crover=None, mutate=None, batch_size=128, epsilon=0.1, beta=0.0, n_samples=-1, optim=None, device=device(type='cuda'), use_gpus=(0,), no_multifrag_smiles=True)[source]
Bases: FragExplorer
Explorer to optimize a sequence-based fragment-using agent with the given Environment.
- forward(src)[source]
Forward pass of the agent.
- Parameters:
src (torch.Tensor) – TODO: check the shape of the input tensor.
- Returns:
TODO: check the shape of the output tensor.
- Return type:
torch.Tensor
- getBatchOutputs(net, src)[source]
Outputs (frags, smiles) and loss of the agent for a batch of fragments-molecule pairs.
- Parameters:
net (torch.nn.Module) – Agent
src (torch.Tensor) – Fragments-molecule pairs
- Returns:
frags (list) – List of fragments
smiles (list) – List of SMILES
loss (torch.Tensor) – Loss of the agent
- sampleEncodedPairsToLoader(net, loader)[source]
Sample new fragments-molecule pairs from a data loader.
- Parameters:
net (torch.nn.Module) – Agent
loader (torch.utils.data.DataLoader) – Data loader for original fragments-molecule pairs
- Returns:
Data loader for sampled fragments-molecule pairs
- Return type:
torch.utils.data.DataLoader
- sample_input(loader, is_test=False)[source]
Sample a batch of fragments-molecule pairs from the dataset.
- Parameters:
loader (torch.utils.data.DataLoader) – Data loader for original fragments-molecule pairs
is_test (bool) – Whether to sample from the validation set or not
- Returns:
Data loader for sampled fragments-molecule pairs
- Return type:
torch.utils.data.DataLoader
drugex.training.explorers.interfaces module
interfaces
Created by: Martin Sicho On: 01.06.22, 11:29
- class drugex.training.explorers.interfaces.Explorer(agent, env, mutate=None, crover=None, no_multifrag_smiles=True, batch_size=128, epsilon=0.1, beta=0.0, n_samples=-1, device=device(type='cuda'), use_gpus=(0,))[source]
Implements the DrugEx exploration strategy for DrugEx models under the reinforcement learning framework.
- attachToGPUs(gpus)[source]
Attach the model to GPUs
- Parameters:
gpus (tuple) – The GPUs to use for training.
- abstract fit(train_loader, valid_loader=None, epochs=1000, monitor=None)[source]
Train and validate the model with a given training and validation loader (see the docs of DataSet and its implementations to learn how to generate them).
- Parameters:
train_loader (torch.utils.data.DataLoader) – The training data loader.
valid_loader (torch.utils.data.DataLoader) – The validation data loader.
epochs (int, optional) – The number of epochs to train the model for.
monitor (TrainingMonitor, optional) – A TrainingMonitor instance to monitor the training process.
**kwargs – Additional keyword arguments to pass to the training loop.
- getModel()[source]
Returns the current state of the agent
- Returns:
The current state of the agent
- Return type:
torch.nn.Module
- getNovelMoleculeMetrics(scores)[source]
Get metrics for novel molecules
- Parameters:
scores (pd.DataFrame) – The scores for each molecule.
- Returns:
- The metrics:
valid_ratio (float): ratio of valid molecules
unique_ratio (float): ratio of valid and unique molecules
desired_ratio (float): ratio of valid, unique and desired molecules
avg_amean (float): average arithmetic mean score of valid and unique molecules
avg_gmean (float): average geometric mean score of valid and unique molecules
- Return type:
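The metrics listed above can be illustrated with a small, library-independent sketch. The record layout (keys `smiles`, `valid`, `desired`, `scores`), the helper name `novel_molecule_metrics`, and the choice to divide every ratio by the total number of generated molecules are assumptions made for illustration; the actual DrugEx implementation works on a pd.DataFrame and may define the denominators differently.

```python
import math

def novel_molecule_metrics(records):
    """Summarize a batch of generated molecules.

    Each record is a dict with hypothetical keys:
      'smiles' (str), 'valid' (bool), 'desired' (bool),
      'scores' (list of positive floats, one per objective).
    """
    total = len(records)
    if total == 0:
        raise ValueError("no molecules to score")

    # Count valid molecules and keep the first record per valid SMILES.
    unique = {}
    n_valid = 0
    for rec in records:
        if rec["valid"]:
            n_valid += 1
            unique.setdefault(rec["smiles"], rec)

    n_unique = len(unique)
    n_desired = sum(1 for rec in unique.values() if rec["desired"])

    # Per-molecule arithmetic and geometric means over the objective scores.
    ameans, gmeans = [], []
    for rec in unique.values():
        s = rec["scores"]
        ameans.append(sum(s) / len(s))
        gmeans.append(math.exp(sum(math.log(x) for x in s) / len(s)))

    return {
        "valid_ratio": n_valid / total,
        "unique_ratio": n_unique / total,
        "desired_ratio": n_desired / total,
        "avg_amean": sum(ameans) / n_unique if n_unique else 0.0,
        "avg_gmean": sum(gmeans) / n_unique if n_unique else 0.0,
    }

records = [
    {"smiles": "CCO", "valid": True, "desired": True, "scores": [0.5, 0.5]},
    {"smiles": "CCO", "valid": True, "desired": True, "scores": [0.5, 0.5]},
    {"smiles": "c1ccccc1", "valid": True, "desired": False, "scores": [1.0, 0.25]},
    {"smiles": "C(", "valid": False, "desired": False, "scores": [0.1, 0.1]},
]
metrics = novel_molecule_metrics(records)
```

In this toy batch, the duplicate `CCO` counts once toward the unique and desired ratios, and the invalid string is excluded from the averaged scores.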
- class drugex.training.explorers.interfaces.FragExplorer(agent, env, mutate=None, crover=None, no_multifrag_smiles=True, batch_size=128, epsilon=0.1, beta=0.0, n_samples=-1, device=device(type='cuda'), use_gpus=(0,))[source]
Bases:
Explorer
Implements the DrugEx exploration strategy for DrugEx models under the reinforcement learning framework for fragment-based generators
- fit(train_loader, valid_loader=None, epochs=1000, patience=50, criteria='desired_ratio', min_epochs=100, monitor=None)[source]
Fit the explorer to the training data.
- Parameters:
train_loader (torch.utils.data.DataLoader) – Data loader for training data
valid_loader (torch.utils.data.DataLoader) – Data loader for validation data
epochs (int) – Number of epochs to train for
patience (int) – Number of epochs to wait for improvement before early stopping
criteria (str) – Criteria to use for early stopping
min_epochs (int) – Minimum number of epochs to train for
monitor (Monitor) – Monitor to use for logging and saving model
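How `patience`, `min_epochs` and `criteria` might interact can be sketched as follows. This is a minimal illustration under stated assumptions: the function `run_with_early_stopping` and the callback `evaluate_epoch` are hypothetical, and the exact stopping rule used by DrugEx may differ.

```python
def run_with_early_stopping(evaluate_epoch, epochs=1000, patience=50,
                            criteria="desired_ratio", min_epochs=100):
    """Return (best_epoch, best_value, last_epoch_run).

    evaluate_epoch(epoch) is a hypothetical callback that returns a dict
    of metrics for that epoch, e.g. {"desired_ratio": 0.4}.
    """
    best_value = float("-inf")
    best_epoch = 0
    last_epoch = 0
    for epoch in range(1, epochs + 1):
        last_epoch = epoch
        value = evaluate_epoch(epoch)[criteria]
        if value > best_value:
            best_value, best_epoch = value, epoch
        # Stop only after min_epochs, and only once `patience` epochs
        # have passed without an improvement in the chosen metric.
        if epoch >= min_epochs and epoch - best_epoch >= patience:
            break
    return best_epoch, best_value, last_epoch

# Toy metric that improves until epoch 5 and then plateaus.
def toy_metrics(epoch):
    return {"desired_ratio": min(epoch, 5) / 10}

result = run_with_early_stopping(toy_metrics, epochs=100, patience=3, min_epochs=4)
```

With these toy settings the best value is reached at epoch 5 and the loop stops at epoch 8, three epochs after the last improvement.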
- abstract getBatchOutputs(src, net)[source]
Outputs (frags, smiles) and loss of the agent for a batch of fragments-molecule pairs.
- policy_gradient(loader)[source]
Policy gradient training.
Novel molecules are generated by the agent and are scored by the environment. The policy gradient is calculated using the REINFORCE algorithm and the agent is updated.
- Parameters:
loader (torch.utils.data.DataLoader) – Data loader for training data
- Returns:
The average loss of the agent
- Return type:
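The REINFORCE update described above can be illustrated on a toy problem. The sketch below is not the DrugEx implementation: a three-action softmax policy stands in for the agent, a fixed reward function stands in for the environment's scoring, and the names `reinforce_step`, `softmax` and `reward` are hypothetical.

```python
import math
import random

def softmax(logits):
    m = max(logits)
    exps = [math.exp(x - m) for x in logits]
    z = sum(exps)
    return [e / z for e in exps]

def reinforce_step(logits, reward_fn, rng, lr=0.5):
    """One REINFORCE update for a softmax policy over discrete actions.

    Samples an action, scores it, and moves the logits along
    reward * grad(log pi(action)). Returns the surrogate loss
    -reward * log pi(action).
    """
    probs = softmax(logits)
    action = rng.choices(range(len(logits)), weights=probs)[0]
    reward = reward_fn(action)
    # For a softmax policy: d log pi(a) / d logit_k = 1[k == a] - pi(k).
    for k in range(len(logits)):
        grad_log_prob = (1.0 if k == action else 0.0) - probs[k]
        logits[k] += lr * reward * grad_log_prob
    return -reward * math.log(probs[action])

rng = random.Random(0)
logits = [0.0, 0.0, 0.0]                      # uniform initial policy
reward = lambda a: 1.0 if a == 0 else 0.0     # stand-in for the environment
losses = [reinforce_step(logits, reward, rng) for _ in range(200)]
avg_loss = sum(losses) / len(losses)          # analogous to the returned average loss
final_probs = softmax(logits)
```

Because only action 0 is rewarded, every rewarded sample shifts probability mass toward it, which mirrors how highly scored molecules become more likely under the updated agent.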
drugex.training.explorers.sequence_explorer module
- class drugex.training.explorers.sequence_explorer.SequenceExplorer(agent, env, mutate=None, crover=None, no_multifrag_smiles=True, batch_size=128, epsilon=0.1, beta=0.0, n_samples=1000, optim=None, device=device(type='cuda'), use_gpus=(0,))[source]
Bases: Explorer
Explorer to optimize a sequence-based agent (RNN) with the given Environment.
- Reference: Liu, X., Ye, K., van Vlijmen, H.W.T. et al. DrugEx v2: De Novo Design of Drug Molecule by
Pareto-based Multi-Objective Reinforcement Learning in Polypharmacology. J Cheminform (2021). https://doi.org/10.1186/s13321-019-0355-6
- fit(train_loader=None, valid_loader=None, monitor=None, epochs=1000, patience=50, reload_interval=50, criteria='desired_ratio', min_epochs=100)[source]
Fit the sequence explorer. The data loaders are accepted only for compatibility with FragExplorer and are ignored.
- Parameters:
train_loader (torch.utils.data.DataLoader) – ignored, for compatibility with FragExplorer
valid_loader (torch.utils.data.DataLoader) – ignored, for compatibility with FragExplorer
epochs (int) – Number of epochs to train for
patience (int) – Number of epochs to wait for improvement before early stopping
reload_interval (int) – Every nth epoch reset the agent (and the crover) network to the best state
criteria (str) – Criteria to use for early stopping: ‘desired_ratio’, ‘avg_amean’ or ‘avg_gmean’
min_epochs (int) – Minimum number of epochs to train for
monitor (Monitor) – Monitor to use for logging and saving model
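The effect of `reload_interval` can be sketched with a toy training loop. Everything here is an assumption for illustration: `train_with_reload` and `toy_step` are hypothetical names, and a plain dict stands in for the network state that DrugEx would snapshot and restore.

```python
import copy

def train_with_reload(step, state, epochs, reload_interval):
    """Toy loop. `step(state, epoch)` is a hypothetical callback that
    mutates `state` in place and returns a score (higher is better)."""
    best_score = float("-inf")
    best_state = copy.deepcopy(state)
    reloads = []
    for epoch in range(1, epochs + 1):
        score = step(state, epoch)
        if score > best_score:
            best_score = score
            best_state = copy.deepcopy(state)
        if epoch % reload_interval == 0:
            # Discard any drift since the best epoch by restoring the snapshot.
            state.clear()
            state.update(copy.deepcopy(best_state))
            reloads.append(epoch)
    return best_score, reloads

def toy_step(state, epoch):
    # The "weights" improve for three epochs, then degrade.
    state["w"] += 1 if epoch <= 3 else -1
    return state["w"]

state = {"w": 0}
best_score, reloads = train_with_reload(toy_step, state, epochs=6, reload_interval=3)
```

After the degrading epochs 4 to 6, the reload at epoch 6 puts the state back to the snapshot taken at epoch 3.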
- forward()[source]
Generate molecules with the given agent network.
- Returns:
smiles (list) – The generated SMILES.
seqs (torch.Tensor) – The generated encoded sequences.