Source code for drugex.training.generators.sequence_rnn

import pandas as pd

import torch
from torch import nn
from torch import optim
from tqdm.auto import tqdm

from drugex import DEFAULT_DEVICE, DEFAULT_GPUS
from drugex.training.generators.interfaces import Generator


[docs]class SequenceRNN(Generator): """ Sequence RNN model for molecule generation. """ def __init__(self, voc, embed_size=128, hidden_size=512, is_lstm=True, lr=1e-3, device=DEFAULT_DEVICE, use_gpus=DEFAULT_GPUS): super(SequenceRNN, self).__init__(device=device, use_gpus=use_gpus) self.voc = voc self.embed_size = embed_size self.hidden_size = hidden_size self.output_size = voc.size self.embed = nn.Embedding(voc.size, embed_size) self.is_lstm = is_lstm rnn_layer = nn.LSTM if is_lstm else nn.GRU self.rnn = rnn_layer(embed_size, hidden_size, num_layers=3, batch_first=True) self.linear = nn.Linear(hidden_size, voc.size) self.optim = optim.Adam(self.parameters(), lr=lr) self.attachToGPUs(self.gpus) self.model_name = 'SequenceRNN'
[docs] def attachToGPUs(self, gpus): """ This model currently uses only one GPU. Therefore, only the first one from the list will be used. Parameters: ---------- gpus: `tuple` A tuple of GPU indices. Returns: ------- None """ self.device = torch.device(f'cuda:{gpus[0]}') self.to(self.device) self.gpus = (gpus[0],)
[docs] def forward(self, input, h): """ Forward pass of the model. Parameters: ---------- input: `torch.Tensor` Input tensor of shape (batch_size, 1). h: `torch.Tensor` # TODO: Verify h shape. Hidden state tensor of shape (num_layers, batch_size, hidden_size). Returns: ------- TODO: fill outputs """ output = self.embed(input.unsqueeze(-1)) output, h_out = self.rnn(output, h) output = self.linear(output).squeeze(1) return output, h_out
[docs] def init_h(self, batch_size, labels=None): """ Initialize hidden state of the model. Hidden state is initialized with random values. If labels are provided, the first hidden state will be set to the labels. Parameters: ---------- batch_size: `int` Batch size. labels: `torch.Tensor` Labels tensor of shape (batch_size, 1). Returns: ------- TODO: fill outputs """ h = torch.rand(3, batch_size, 512).to(self.device) if labels is not None: h[0, batch_size, 0] = labels if self.is_lstm: c = torch.rand(3, batch_size, self.hidden_size).to(self.device) return (h, c) if self.is_lstm else h
[docs] def likelihood(self, target): """ Calculate the likelihood of the target sequence. Parameters: ---------- target: `torch.Tensor` Target tensor of shape (batch_size, seq_len). Returns: ------- scores: `torch.Tensor` Scores tensor of shape (batch_size, seq_len). """ batch_size, seq_len = target.size() x = torch.LongTensor([self.voc.tk2ix['GO']] * batch_size).to(self.device) h = self.init_h(batch_size) scores = torch.zeros(batch_size, seq_len).to(self.device) for step in range(seq_len): logits, h = self(x, h) logits = logits.log_softmax(dim=-1) score = logits.gather(1, target[:, step:step+1]).squeeze() scores[:, step] = score x = target[:, step] return scores
[docs] def sample(self, batch_size): """ Sample a SMILES from the model. Parameters: ---------- batch_size: `int` Batch size. Returns: ------- smiles: `list` List of SMILES. """ x = torch.LongTensor([self.voc.tk2ix['GO']] * batch_size).to(self.device) h = self.init_h(batch_size) sequences = torch.zeros(batch_size, self.voc.max_len).long().to(self.device) isEnd = torch.zeros(batch_size).bool().to(self.device) for step in range(self.voc.max_len): logit, h = self(x, h) proba = logit.softmax(dim=-1) x = torch.multinomial(proba, 1).view(-1) x[isEnd] = self.voc.tk2ix['EOS'] sequences[:, step] = x end_token = (x == self.voc.tk2ix['EOS']) isEnd = torch.ge(isEnd + end_token, 1) if (isEnd == 1).all(): break return [self.voc.decode(s, is_tk = False) for s in sequences]
[docs] def evolve(self, batch_size, epsilon=0.01, crover=None, mutate=None): """ Evolve a SMILES from the model by sequantial addition of tokens. Parameters: ---------- batch_size: `int` Batch size. epsilon: `float` Probability using the mutate network to generate the next token. crover: `drugex.models.Crover` Crover network. mutate: `drugex.models.Mutate` Mutate network. Returns: ------- TODO: check if ouput smiles are still encoded""" # Start tokens x = torch.LongTensor([self.voc.tk2ix['GO']] * batch_size).to(self.device) # Hidden states initialization for exploitation network (agent) hA = self.init_h(batch_size) # Hidden states initialization for exploration networks (mutate and crover) hM = self.init_h(batch_size) hC = self.init_h(batch_size) # Initialization of output matrix sequences = torch.zeros(batch_size, self.voc.max_len).long().to(self.device) # labels to judge and record which sample is ended is_end = torch.zeros(batch_size).bool().to(self.device) for step in range(self.voc.max_len): # Get unscaled logits and hidden states from agent network and convert them to probabilities with softmax logitA, hA = self(x, hA) proba = logitA.softmax(dim=-1) # If crover combine probablities from agent network and crover network if crover is not None: ratio = torch.rand(batch_size, 1).to(self.device) logitC, hC = crover(x, hC) proba = proba * ratio + logitC.softmax(dim=-1) * (1 - ratio) # If mutate replace with the epsilon-rate the probabilities with the ones from the mutation network if mutate is not None: logitM, hM = mutate(x, hM) is_mutate = (torch.rand(batch_size) < epsilon).to(self.device) proba[is_mutate, :] = logitM.softmax(dim=-1)[is_mutate, :] # sampling based on output probability distribution x = torch.multinomial(proba, 1).view(-1) is_end |= x == self.voc.tk2ix['EOS'] x[is_end] = self.voc.tk2ix['EOS'] sequences[:, step] = x if is_end.all(): break return sequences
[docs] def trainNet(self, loader, epoch, epochs): """ Train the RNN network for one epoch Parameters: ----------- loader : torch.utils.data.DataLoader The data loader for the training set epoch : int The current epoch epochs : int The total number of epochs Returns ------- loss : float The training loss of the epoch """ total_steps = len(loader) for i, batch in enumerate(loader): self.optim.zero_grad() loss = self.likelihood(batch.to(self.device)) loss = -loss.mean() loss.backward() self.optim.step() self.monitor.saveProgress(i, epoch, total_steps, epochs, loss=loss.item()) return loss.item()
[docs] def validateNet(self, loader=None, evaluator=None, no_multifrag_smiles=True, n_samples=128): """ Validate the network Parameters ---------- loader : torch.utils.data.DataLoader A dataloader object to iterate over the validation data to compute the validation loss evaluator : Evaluator An evaluator object to evaluate the generated SMILES no_multifrag_smiles : bool If `True`, only single-fragment SMILES are considered valid n_samples : int The number of SMILES to sample from the model Returns ------- valid_metrics : dict Dictionary containing the validation metrics scores : pandas.DataFrame DataFrame containing Smiles, frags and the scores for each SMILES Notes ----- The validation metrics are: - valid_ratio: the ratio of valid SMILES - accurate_ratio: the ratio of SMILES that are valid and have the desired fragments - loss_valid: the validation loss """ valid_metrics = {} smiles = self.sample(n_samples) scores = self.evaluate(smiles, evaluator=evaluator) scores['SMILES'] = smiles valid_metrics['valid_ratio'] = scores.Valid.mean() # If a separate validation set is provided, use it to compute the validation loss if loader is not None: loss_valid, size = 0, 0 for j, batch in enumerate(loader): size += batch.size(0) loss_valid += -self.likelihood(batch.to(self.device)).sum().item() valid_metrics['loss_valid'] = loss_valid / size / self.voc.max_len return valid_metrics, scores
[docs] def generate(self, num_samples=100, batch_size=32, n_proc=1, drop_duplicates=True, drop_invalid=True, evaluator=None, no_multifrag_smiles=True, drop_undesired=False, raw_scores=True, progress=True, tqdm_kwargs={}): if progress: tqdm_kwargs.update({'total': num_samples, 'desc': 'Generating molecules'}) pbar = tqdm(**tqdm_kwargs) df_all = pd.DataFrame(columns=['SMILES', 'Frags']) # Frags column to be compatible with other generators while not len(df_all) >= num_samples: with torch.no_grad(): new_smiles = self.sample(batch_size) df_new = pd.DataFrame({'SMILES': new_smiles, 'Frags': None}) # If drop_invalid is True, invalid (and inaccurate) SMILES are dropped # valid molecules are canonicalized and optionally extra filtering is applied # else invalid molecules are kept and no filtering is applied if drop_invalid: df_new = self.filterNewMolecules(df_all, df_new, with_frags=False, drop_duplicates=drop_duplicates, drop_undesired=drop_undesired, evaluator=evaluator, no_multifrag_smiles=no_multifrag_smiles) # Update the dataframe with the new molecules df_all = pd.concat([df_all, df_new], axis=0, ignore_index=True) # Update progress bar if progress: pbar.update(len(df_new) if pbar.n + len(df_new) <= num_samples else num_samples - pbar.n) if progress: pbar.close() # Post-processing df = df_all.head(num_samples) if evaluator: df = pd.concat([df, self.evaluate(df.SMILES.tolist(), evaluator=evaluator, no_multifrag_smiles=no_multifrag_smiles, unmodified_scores=raw_scores)[evaluator.getScorerKeys()]], axis=1) return df.drop('Frags', axis=1).round(decimals=3)